Source code for thema.multiverse.universe.geodesics

# File: multiverse/universe/geodesics.py
# Last Updated: 10/21/25
# Updated By: SG

import os
import pickle
from typing import Callable

import numpy as np
import networkx as nx
from scott import Comparator

from .utils.starFilters import nofilterfunction


def stellar_curvature_distance(
    files: str | list,
    filterfunction: Callable | None = None,
    curvature="forman_curvature",
    vectorization="landscape",
):
    """
    Build a symmetric matrix of curvature-filtration distances between starGraphs.

    Every pair of loaded starGraphs is compared with a ``scott.Comparator``
    configured with the requested curvature measure. The diagonal of the
    result is left at zero.

    Parameters
    ----------
    files : str or list[str]
        A directory holding pickled starGraph files, or a list of paths to
        individual pickle files.
    filterfunction : Callable, optional
        Predicate used to keep only a subset of the cosmic graphs. When None
        (the default), the loader falls back to ``nofilterfunction``.
    curvature : str, optional
        Name of the edge-curvature measure handed to the comparator. Defaults
        to ``"forman_curvature"``. Options, roughly from cheapest to most
        expensive:

        - ``"forman_curvature"``: purely combinatorial and local; quick enough
          for large graphs and exploratory work.
        - ``"balanced_forman_curvature"``: a Forman refinement that balances
          edge contributions to cope better with heterogeneous degrees, while
          staying efficient.
        - ``"resistance_curvature"``: built on effective-resistance distances
          between nodes; captures global connectivity at a higher cost.
        - ``"ollivier_ricci_curvature"``: transport-based, reflecting how
          probability mass moves between node neighbourhoods; the most
          geometrically informative and the slowest.

        Background:
        https://github.com/aidos-lab/curvature-filtrations/blob/main/notebooks/bagpipeline.ipynb
    vectorization : str, optional
        Vectorization passed to ``Comparator.fit_transform`` as ``metric``.
        Defaults to ``"landscape"``.

    Returns
    -------
    keys : np.ndarray
        Identifiers of the compared models (the file paths the loader uses as
        keys), in the same order as the matrix rows and columns.
    distance_matrix : np.ndarray
        ``(n, n)`` symmetric array of pairwise distances with a zero diagonal.
    """
    loaded = _load_starGraphs(files, graph_filter=filterfunction)
    model_keys = list(loaded)

    # Node IDs are converted to integers before the comparator sees them.
    integer_graphs = _map_string_nodes_to_integers(
        [star.graph for star in loaded.values()]
    )[0]

    comparator = Comparator(measure=curvature, weight="weight")
    size = len(integer_graphs)
    distance_matrix = np.zeros((size, size))

    # Upper-triangle index pairs in row-major order. Each distance is mirrored
    # into the lower triangle, so every unordered pair is computed once.
    upper_rows, upper_cols = np.triu_indices(size, k=1)
    for row, col in zip(upper_rows.tolist(), upper_cols.tolist()):
        distance = comparator.fit_transform(
            [integer_graphs[row]],
            [integer_graphs[col]],
            metric=vectorization,
        )
        distance_matrix[row, col] = distance
        distance_matrix[col, row] = distance

    return np.array(model_keys), distance_matrix
def _load_starGraphs(dir: str | list, graph_filter: Callable | None = None) -> dict: """ Load starGraphs from a directory or a list of pickle files. Only returns starGraphs that satisfy the `graph_filter`. Parameters ---------- dir : str or list Directory containing .pkl graphs, or a list of .pkl file paths. graph_filter : Callable, optional Function that returns True for graphs to include. Defaults to nofilterfunction. Returns ------- dict Mapping of file path to starGraph object. """ if graph_filter is None: graph_filter = nofilterfunction # Handle list vs directory if isinstance(dir, list): files = [str(f) for f in dir] else: assert os.path.isdir(dir), "Invalid graph Directory" assert len(os.listdir(dir)) > 0, "Graph directory appears to be empty!" files = [os.path.join(dir, f) for f in os.listdir(dir) if f.endswith(".pkl")] if not files: raise ValueError("No .pkl files found to load.") starGraphs = {} for graph_file in files: with open(graph_file, "rb") as f: graph_object = pickle.load(f) if graph_filter(graph_object): if graph_object.starGraph is not None: starGraphs[graph_file] = graph_object.starGraph if not starGraphs: raise ValueError( "No valid starGraphs produced. Your filter function may be too stringent." ) return starGraphs def _map_string_nodes_to_integers(graphs): """ Map string node IDs to integers for GUDHI compatibility. GUDHI's SimplexTree requires integer node IDs, but jmapStar creates graphs with string node IDs ('a', 'b', 'c', etc.). This function creates a consistent mapping across all graphs. 
Parameters ---------- graphs : list List of networkx graphs that may have string node IDs Returns ------- tuple (mapped_graphs, node_mapping) where mapped_graphs have integer node IDs and node_mapping is the string->int mapping dict """ all_nodes = set() for graph in graphs: all_nodes.update(graph.nodes()) node_mapping = {node: i for i, node in enumerate(sorted(all_nodes))} mapped_graphs = [] for graph in graphs: if any(not isinstance(node, int) for node in graph.nodes()): mapped_graph = nx.relabel_nodes(graph, node_mapping) mapped_graphs.append(mapped_graph) else: mapped_graphs.append(graph.copy()) return mapped_graphs, node_mapping