Customizing Thema

Extend Thema with custom filters, graph builders, and large-scale workflows.

Custom Filter Functions

A filter is a function that receives a Star object and returns 1 to keep its graph or 0 to discard it.

Filter Structure

def custom_filter(star_object) -> int:
    """Template filter: return 1 to keep a graph, 0 to discard it.

    ``condition`` is a placeholder; replace it with your own boolean test.
    """
    # A star that has no graph is always discarded.
    if star_object.starGraph is None:
        return 0
    # Your logic here
    if condition:
        return 1
    return 0

Example: Component and Coverage Filter

import networkx as nx

def multi_criteria_filter(star_object, min_components=3, min_coverage=50):
    """Return 1 for graphs with enough components and enough distinct members.

    A graph passes when it has at least ``min_components`` connected
    components and the union of every node's ``"membership"`` attribute
    holds at least ``min_coverage`` distinct items. Stars without a graph
    are discarded (0).
    """
    star_graph = star_object.starGraph
    if star_graph is None:
        return 0

    g = star_graph.graph
    component_total = nx.number_connected_components(g)

    # Union the membership lists of all nodes to count distinct items.
    covered = set()
    for _, attrs in g.nodes(data=True):
        covered.update(attrs["membership"])

    passes = component_total >= min_components and len(covered) >= min_coverage
    if passes:
        return 1
    return 0

Usage:

from functools import partial
from thema.multiverse import Galaxy

# Build the Galaxy from a YAML parameter file and fit it.
galaxy = Galaxy(YAML_PATH="params.yaml")
galaxy.fit()

# functools.partial pins the two thresholds, leaving a callable that takes
# the star object as its only remaining argument.
my_filter = partial(multi_criteria_filter, min_components=4, min_coverage=100)
selection = galaxy.collapse(filter_fn=my_filter, nReps=3)

Example: Density Filter

def density_filter(star_object, min_density=0.1):
    """Keep graphs whose edge density is at least ``min_density``.

    Density is the edge count divided by ``n * (n - 1) / 2``, the number of
    node pairs in a graph with ``n`` nodes.

    Args:
        star_object: Object exposing ``starGraph``; the graph is read from
            ``starGraph.graph`` and must provide ``number_of_nodes()`` and
            ``number_of_edges()``.
        min_density: Inclusive lower bound on the density.

    Returns:
        1 to keep the graph, 0 to discard it. Stars without a graph and
        graphs with fewer than two nodes are always discarded.
    """
    if star_object.starGraph is None:
        return 0

    graph = star_object.starGraph.graph
    n_nodes = graph.number_of_nodes()

    # With fewer than two nodes there are no node pairs, so density is
    # undefined; returning here also guarantees a non-zero divisor below.
    if n_nodes < 2:
        return 0

    max_edges = n_nodes * (n_nodes - 1) / 2
    density = graph.number_of_edges() / max_edges

    return 1 if density >= min_density else 0

Custom Star Implementations

Create custom graph construction methods by subclassing Star.

Minimal Star Template

from thema.multiverse.universe.star import Star
from thema.multiverse.universe.utils.starGraph import starGraph
import networkx as nx

class CustomStar(Star):
    """Template for a Star subclass that builds its own graph in ``fit``."""

    def __init__(self, data_path, clean_path, projection_path, **kwargs):
        """Forward the three paths to ``Star`` and read ``custom_param`` from kwargs."""
        super().__init__(data_path, clean_path, projection_path)
        # NOTE: `default_value` is a placeholder that this snippet does not
        # define; substitute the default you want for `custom_param`.
        self.custom_param = kwargs.get("custom_param", default_value)

    def fit(self):
        """Build a graph and store it, wrapped in ``starGraph``, on ``self.starGraph``."""
        # Access data
        projection = self.projection  # np.ndarray from Oort
        clean_data = self.clean       # pd.DataFrame from Planet

        # Build graph
        G = nx.Graph()
        # ... your graph construction logic ...

        # Wrap in starGraph
        self.starGraph = starGraph(G)

Example: k-NN Graph

from sklearn.neighbors import kneighbors_graph
import numpy as np

class kNNStar(Star):
    """Star whose graph is built from a k-nearest-neighbour search on the projection."""

    def __init__(self, data_path, clean_path, projection_path, k=5):
        super().__init__(data_path, clean_path, projection_path)
        # Number of neighbours requested from kneighbors_graph in fit().
        self.k = k

    def fit(self):
        """Build the k-NN graph from ``self.projection`` and store it on ``self.starGraph``."""
        # Sparse connectivity matrix of the k nearest neighbours of each row.
        knn_matrix = kneighbors_graph(
            self.projection,
            n_neighbors=self.k,
            mode="connectivity",
        )
        knn_graph = nx.from_scipy_sparse_array(knn_matrix)

        # Every node gets a one-element membership list holding its position
        # in the graph's node order (membership = single item per node).
        memberships = {
            node_id: [position]
            for position, node_id in enumerate(knn_graph.nodes())
        }
        nx.set_node_attributes(knn_graph, memberships, name="membership")

        self.starGraph = starGraph(knn_graph)

Registering Custom Stars

To use custom stars in the pipeline:

  1. Add your star class to thema/multiverse/universe/projectiles/

  2. Register in thema/config.py via tag_to_class mapping

  3. Reference by tag in YAML or params dict

Scaling to Large Datasets

Performance Considerations

Parameter Grid Size

Combinatorial explosion occurs quickly. A grid with 5 values per parameter across 4 parameters yields 5 × 5 × 5 × 5 = 625 combinations.

Memory Management
  • Filter aggressively before collapse()

  • Process component counts separately to reduce memory footprint

  • Use distance_threshold instead of nReps for adaptive selection

Computational Cost
  • forman_curvature: Fast (∼1s per graph pair)

  • balanced_forman_curvature: Moderate (∼3s per graph pair)

  • ollivier_ricci_curvature: Slow (∼10s+ per graph pair)

Optimization Strategies

  1. Start Small: Validate pipeline with 10-20 graphs before expanding

  2. Progressive Refinement: Coarse grid -> filter -> fine grid on promising regions

  3. Parallel Stages: Planet, Oort, and Galaxy parallelize internally

  4. Disk I/O: Use SSD for output directories; avoid network drives

Example: Large-Scale Workflow

from thema.multiverse import Planet, Oort, Galaxy
from thema.multiverse.universe.utils.starFilters import (
    minimum_unique_items_filter,
    component_count_range_filter
)
from pathlib import Path

# Stage 1: Preprocessing (few seeds, multiple samples for robustness)
# outDir here is the same path that Oort and Galaxy receive as cleanDir below.
planet = Planet(
    data="/data/large_dataset.pkl",
    outDir="./outputs/large_run/clean",
    seeds=[42, 13, 99],
    numSamples=3,
    scaler="standard",
    encoding="one_hot",
    imputeColumns="auto",
    imputeMethods="auto"
)
planet.fit()

# Stage 2: Embeddings (sparse grid initially)
# Grid: two t-SNE perplexities and one PCA setting, all 2-D with seed 42.
# NOTE(review): presumably every listed value becomes its own projection
# run — confirm how Oort expands `params`.
# outDir here is the same path that Galaxy receives as projDir below.
oort = Oort(
    data="/data/large_dataset.pkl",
    cleanDir="./outputs/large_run/clean",
    outDir="./outputs/large_run/projections",
    params={
        "tsne": {
            "perplexity": [30, 66],
            "dimensions": [2],
            "seed": [42]
        },
        "pca": {
            "dimensions": [2],
            "seed": [42]
        }
    }
)
oort.fit()

# Stage 3: Graphs (moderate grid)
# Grid: 3 nCubes x 2 percOverlap x 1 minIntersection x 2 clusterers, i.e.
# 12 jmap settings if Galaxy takes the full cross product (TODO confirm).
galaxy = Galaxy(
    data="/data/large_dataset.pkl",
    cleanDir="./outputs/large_run/clean",
    projDir="./outputs/large_run/projections",
    outDir="./outputs/large_run/models",
    params={
        "jmap": {
            "nCubes": [5, 10, 20],
            "percOverlap": [0.5, 0.7],
            "minIntersection": [-1],
            "clusterer": [
                ["HDBSCAN", {"min_cluster_size": 5}],
                ["HDBSCAN", {"min_cluster_size": 10}]
            ]
        }
    }
)
galaxy.fit()

# Stage 4: Filter and select (use fast curvature metric)
# The coverage threshold is 90% of len(planet.data), truncated to an int.
# NOTE(review): assumes planet.data has one entry per item, and that
# component_count_range_filter(3, 8) keeps graphs with 3 to 8 components —
# confirm whether the bounds are inclusive.
total_items = len(planet.data)
coverage_filter = minimum_unique_items_filter(int(0.9 * total_items))
comp_filter = component_count_range_filter(3, 8)

def combined(star):
    """Keep a graph only when both the coverage and component filters keep it."""
    coverage_result = coverage_filter(star)
    # Short-circuit: a falsy coverage result is returned unchanged and the
    # component filter is not called.
    if not coverage_result:
        return coverage_result
    return comp_filter(star)

# Collapse using the combined filter; curvature, distance_threshold and
# selector are passed to Galaxy.collapse as keyword arguments.
selection = galaxy.collapse(
    filter_fn=combined,
    curvature="forman_curvature",  # Fast option
    distance_threshold=200,
    selector="max_nodes"
)

# Report len(selection) against the number of *.pkl files in the models directory.
print(f"Selected {len(selection)} representatives from {len(list(Path('./outputs/large_run/models').glob('*.pkl')))} graphs")

Debugging Tips

Check intermediate outputs

Verify Moon, Comet, and Star files exist and load correctly

Inspect graph properties

Load Star objects and check nodes, edges, components before filtering

Validate filters

Test custom filters on a single graph before applying to all

Monitor memory

Use top or htop during collapse() to catch memory issues

Log parameter combinations

Track which parameter sets produce usable graphs