# File: multiverse/universe.py
# Last Updated: 10/21/25
# Updated By: SG
import glob
import importlib
import itertools
import logging
import os
import pickle
import time
from collections import Counter
from typing import cast
import numpy as np
import networkx as nx
from omegaconf import OmegaConf
from sklearn.cluster import AgglomerativeClustering
from sklearn.manifold import MDS
from .utils import starFilters, starSelectors
from ... import config
from ...utils import (
create_file_name,
function_scheduler,
get_current_logging_config,
)
from . import geodesics
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())
class Galaxy:
"""
A space of stars.
The largest space of data representations, a galaxy can be searched to
find particular stars and systems most suitable for a particular explorer.
Galaxy generates a space of star objects from the distribution of
inner and outer systems.
Members
-------
data: str
Path to the original raw data file.
cleanDir: str
Path to a populated directory containing Moons.
projDir: str
Path to a populated directory containing Comets.
outDir: str
Path to an out directory to store star objects.
selection: dict
Dictionary mapping cluster labels to selected representative stars. Set by the collapse() method.
YAML_PATH: str
Path to yaml configuration file.
Functions
---------
get_data_path() -> str
Returns the path to the raw data file.
fit() -> None
Fits a space of Stars and saves them to outDir.
collapse() -> dict
Clusters star models and selects representatives.
get_galaxy_coordinates() -> np.ndarray
Computes a 2D coordinate map of stars in the galaxy using Multidimensional Scaling (MDS).
save() -> None
Saves the instance to a pickle file.
Example
--------
>>> cleanDir = <PATH TO MOON OBJECT FILES>
>>> data = <PATH TO RAW DATA FILE>
>>> projDir = <PATH TO COMET OBJECT FILES>
>>> outDir = <PATH TO OUT DIRECTORY OF PROJECTIONS>
>>> params = {
... "jmap": { "nCubes":[2,5,8],
... "percOverlap": [0.2, 0.4],
... "minIntersection":[-1],
... "clusterer": [["HDBSCAN", {"minDist":0.1}]]
... }
... }
>>> galaxy = Galaxy(params=params,
... data=data,
... cleanDir = cleanDir,
... projDir = projDir,
... outDir = outDir)
>>> galaxy.fit()
>>> # First, compute distances and cluster the stars
>>> selected_stars = galaxy.collapse()
>>> print(f"Selected {len(selected_stars)} representative stars")
>>>
>>> # Generate and visualize the galaxy coordinates with custom plotting
>>> import matplotlib.pyplot as plt
>>> import numpy as np
>>>
>>> # Manual plotting of the galaxy coordinates (NOTE: `Thema` does not have built-in visualization dependencies)
>>> coordinates = galaxy.get_galaxy_coordinates()
>>> plt.figure(figsize=(8, 6))
>>> plt.scatter(coordinates[:, 0], coordinates[:, 1], alpha=0.7)
>>> plt.title('2D Coordinate Map of Star Models')
>>> plt.xlabel('X Coordinate')
>>> plt.ylabel('Y Coordinate')
>>> plt.show()
"""
def __init__(
self,
params=None,
data=None,
cleanDir=None,
projDir=None,
outDir=None,
metric="stellar_curvature_distance",
selector="max_nodes",
nReps=3,
filter_fn=None,
YAML_PATH=None,
verbose=False,
):
"""
Constructs a Galaxy Instance
Parameters
----------
NOTE: all parameters can be provided via a YAML configuration file
passed as YAML_PATH. Please see docs/yaml_configuration.md.
data : str, optional
Path to input data
cleanDir : str, optional
Path to directory containing saved Moon Objects.
projDir : str, optional
Path to directory containing saved Comet Objects.
outDir : str, optional
The directory path where the stars will be saved.
params : dict, optional
A parameter dictionary mapping star names to dictionaries of
parameter lists. Expected format:
{"star0_name": {"star0_parameter0": [list of values],
"star0_parameter1": [list of values]},
"star1_name": {"star1_parameter0": [list of values]}}
metric : str, optional
Name of a distance function in geodesics used to compare star
graphs. Default is "stellar_curvature_distance".
selector : str, optional
Name of a selection function in starSelectors used to choose
cluster representatives. Default is "max_nodes".
nReps : int, optional
Default number of clusters used by collapse(). Default is 3.
filter_fn : str, callable, or None, optional
Filter function to apply to stars before distance calculations.
Can be a string name of a function in starFilters, a callable, or None for no filtering.
YAML_PATH : str, optional
The path to a YAML file containing configuration settings. Default is None.
verbose : bool, optional
Set to True to see warnings and progress messages.
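Example
-------
A minimal sketch using a YAML configuration file; the path is
illustrative (see docs/yaml_configuration.md for the expected schema):
>>> galaxy = Galaxy(YAML_PATH="params.yaml")
>>> galaxy.fit()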
"""
if YAML_PATH is not None:
assert os.path.isfile(YAML_PATH), "yaml parameter file could not be found."
try:
with open(YAML_PATH, "r") as f:
yamlParams = OmegaConf.load(f)
except Exception as e:
# yamlParams is required below, so a failed load cannot be ignored
print(e)
raise
data = yamlParams.data
cleanDir = os.path.join(yamlParams.outDir, yamlParams.runName + "/clean/")
projDir = os.path.join(
yamlParams.outDir, yamlParams.runName + "/projections/"
)
outDir = os.path.join(yamlParams.outDir, yamlParams.runName + "/models/")
metric = yamlParams.Galaxy.metric
selector = yamlParams.Galaxy.selector
nReps = yamlParams.Galaxy.nReps
filter_fn = yamlParams.Galaxy.get("filter", None)
if isinstance(yamlParams.Galaxy.stars, str):
stars = [yamlParams.Galaxy.stars]
else:
stars = yamlParams.Galaxy.stars
self.params = {}
for star in stars:
self.params[star] = yamlParams.Galaxy[star]
elif params is not None:
self.params = params
else:
raise ValueError("please provide a parameter dictionary")
self.data = data
self.cleanDir = cleanDir
self.projDir = projDir
self.outDir = outDir
self.YAML_PATH = YAML_PATH
self.metric = metric
self.selector = selector
self.nReps = nReps
# Store YAML params for filter setup later (avoid pickling issues)
self._yaml_filter = filter_fn
self._yamlParams = yamlParams if YAML_PATH is not None else None
self.keys = None
self.distances = None
self.verbose = verbose
self.selection = {}
assert self.data is not None, "Missing path to raw data file"
assert self.cleanDir is not None, "Missing 'cleanDir' parameter"
assert self.projDir is not None, "Missing 'projDir' parameter"
assert self.outDir is not None, "Missing 'outDir' parameter"
assert os.path.isdir(self.cleanDir), "Invalid clean data directory."
assert (
len(os.listdir(self.cleanDir)) > 0
), "No clean data found. Please make sure you generated clean data."
assert os.path.isdir(self.projDir), "Invalid projection directory."
assert (
len(os.listdir(self.projDir)) > 0
), "No projections found. Please make sure you have generated them correctly."
if not os.path.isdir(self.outDir):
try:
os.makedirs(self.outDir)
except Exception as e:
print(e)
self.data = cast(str, self.data)
self.cleanDir = cast(str, self.cleanDir)
self.projDir = cast(str, self.projDir)
self.outDir = cast(str, self.outDir)
def _setup_filter(self, yamlParams):
logger.info("Checking yaml for filter configuration.")
if yamlParams and yamlParams.Galaxy.get("filter"):
filter_type = yamlParams.Galaxy.get("filter")
if filter_type in config.filter_configs:
filter_config = config.filter_configs[filter_type]
logger.info(f"Loading supported filter function: `{filter_type}`")
params = {
**filter_config["params"],
**yamlParams.Galaxy.get("filter_params", {}),
}
logger.info(f"Using filter parameters: {params}")
func = getattr(starFilters, filter_config["function"])(**params)
# Tag the callable with a human-friendly name for logging
try:
setattr(func, "_display_name", str(filter_type))
except Exception:
pass
return func
# Default to no-op filter with a stable display name
nf = starFilters.nofilterfunction
try:
setattr(nf, "_display_name", "nofilterfunction")
except Exception:
pass
return nf
def _log_graph_distribution(self, files_to_use):
# Use the provided files_to_use argument: a list of paths, a directory, or a single path
if isinstance(files_to_use, (list, tuple)):
file_paths = [str(f) for f in files_to_use]
elif os.path.isdir(str(files_to_use)):
target_dir = cast(str, files_to_use)
file_paths = [
os.path.join(target_dir, f)
for f in os.listdir(target_dir)
if f.endswith(".pkl")
]
else:
file_paths = [str(files_to_use)]
component_counts = []
for file_path in file_paths:
try:
with open(file_path, "rb") as f:
star_obj = pickle.load(f)
if star_obj.starGraph and star_obj.starGraph.graph:
component_counts.append(
nx.number_connected_components(star_obj.starGraph.graph)
)
except Exception:
# Skip files that cannot be unpickled or lack a starGraph
continue
if component_counts:
counts = Counter(component_counts)
logger.debug("Component distribution:")
for n, count in sorted(counts.items()):
bar = "█" * count
logger.debug(f" {n:>2} components: {bar} ({count})")
def fit(self):
"""
Configure and generate space of Stars.
Uses the `function_scheduler` to spawn multiple star
instances and fit them in parallel.
Returns
-------
None
Star objects are saved to outDir; a warning is logged if any saves fail.
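Example
-------
A minimal usage sketch; fit() writes one star .pkl file per
(projection file, parameter combination) pair:
>>> galaxy.fit()
>>> import os
>>> star_files = [f for f in os.listdir(galaxy.outDir) if f.endswith(".pkl")]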
"""
# Get current logging config to pass to child processes
logging_config = get_current_logging_config()
subprocesses = []
for starName, starParamsDict in self.params.items():
star_configName = config.tag_to_class[starName]
cfg = getattr(config, star_configName)
module = importlib.import_module(cfg.module)
star = module.initialize()
# Load matching files
clean_dir = cast(str, self.cleanDir)
cleanfile_pattern = os.path.join(clean_dir, "*.pkl")
valid_cleanFiles = glob.glob(cleanfile_pattern)
proj_dir = cast(str, self.projDir)
projfile_pattern = os.path.join(proj_dir, "*.pkl")
valid_projFiles = glob.glob(projfile_pattern)
for j, projFilePath in enumerate(valid_projFiles):
# glob.glob already returns full paths; joining with proj_dir again
# would duplicate the directory prefix for relative paths
with open(projFilePath, "rb") as f:
cleanFile = pickle.load(f).get_clean_path()
param_attr_names = [
attr
for attr in sorted(cfg.__annotations__)
if attr not in ["name", "module"]
]
param_combinations = itertools.product(
*[starParamsDict[attr] for attr in param_attr_names]
)
for k, combination in enumerate(param_combinations):
starParameters = dict(zip(param_attr_names, combination))
subprocesses.append(
(
self._instantiate_star,
self.data,
cleanFile,
projFilePath,
star,
starParameters,
starName,
f"{k}_{j}",
logging_config,
)
)
# Run with function scheduler
results = function_scheduler(
subprocesses,
max_workers=4,
resilient=True,
verbose=self.verbose,
)
failed_saves = sum(1 for r in results if r is False)
if failed_saves > 0:
logger.warning(f"{failed_saves}/{len(results)} star saves failed")
def _instantiate_star(
self,
data_path,
cleanFile,
projFile,
star,
starParameters,
starName,
id,
logging_config,
):
"""Helper function for the fit() method. Creates a Star instances and fits it.
Parameters
----------
data_path: str
Path to input data
cleanFile: str
Path to a moon instance.
projFile: str
Path to comet instance.
star : class
A Star class object defined in stars/.
starParameters: dict
Parameter configuration for specified star.
starName: str
Name of star class
id : str
Identifier string appended to the output file name.
logging_config : dict or None
Logging configuration from parent process
Returns
-------
bool
True if saved successfully, False otherwise
See Also
--------
`Star` class and stars directory for more info on an individual fit.
"""
# Configure logging in this child process
from ...utils import configure_child_process_logging
configure_child_process_logging(logging_config)
try:
my_star = star(
data_path=data_path,
clean_path=cleanFile,
projection_path=projFile,
**starParameters,
)
my_star.fit()
output_file = create_file_name(starName, starParameters, id)
out_dir = cast(str, self.outDir)
output_file = os.path.join(out_dir, output_file)
return my_star.save(output_file)
except Exception as e:
logger.error(
f"Star {starName} #{id} failed - params: {starParameters}, error: {str(e)}"
)
return False
def collapse(
self,
metric=None,
nReps=None,
selector=None,
filter_fn=None,
files: list | None = None,
distance_threshold: float | None = None,
**kwargs,
):
"""
Collapses the space of Stars into representative Stars.
Clustering is controlled by either nReps (a fixed number of clusters)
or distance_threshold (an AgglomerativeClustering merge cutoff); if
both are provided, distance_threshold takes precedence.
Parameters
----------
metric : str, optional
Metric function name for comparing graphs. Defaults to self.metric.
nReps : int, optional
Number of clusters for AgglomerativeClustering. Ignored if distance_threshold is set.
selector : str, optional
Selection function name to choose representative stars. Defaults to self.selector.
filter_fn : callable, str, or None
Filter function to select a subset of graphs. Defaults to no filter.
files : list[str] or None
Optional list of file paths to process. Defaults to self.outDir.
distance_threshold : float, optional
AgglomerativeClustering distance threshold. Used if nReps is None.
**kwargs :
Additional arguments passed to the metric function.
Returns
-------
dict
Mapping from cluster labels to selected stars and cluster sizes.
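Examples
--------
A minimal sketch; parameter values are illustrative:
>>> selection = galaxy.collapse(nReps=3, selector="max_nodes")
>>> for label, info in selection.items():
...     print(label, info["star"], info["cluster_size"])
>>> # Alternatively, cluster by a merge-distance cutoff instead of a fixed count
>>> selection = galaxy.collapse(nReps=None, distance_threshold=0.5)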
"""
logger.info("Configuring Galaxy Collapse…")
metric = metric or self.metric
selector = selector or self.selector
# Set up filter when needed
if callable(filter_fn):
logger.info(
f"Using provided filter function: {getattr(filter_fn, '__name__', str(type(filter_fn)))}"
)
elif filter_fn is None:
filter_fn = self._setup_filter(self._yamlParams)
elif isinstance(filter_fn, str):
logger.info(
f"Function name provided, attempting to load from supported star filters: {filter_fn}"
)
filter_callable = getattr(
starFilters, filter_fn, starFilters.nofilterfunction
)
# Tag display name for logging
try:
setattr(filter_callable, "_display_name", str(filter_fn))
except Exception:
pass
filter_fn = filter_callable
logger.info(
f"Loaded filter function: {getattr(filter_fn, '__name__', str(type(filter_fn)))}"
)
else:
# None, callable, and str are handled above; anything else is unsupported
raise ValueError(
f"filter_fn must be None, callable, or a string, got {type(filter_fn)}"
)
metric_fn = getattr(geodesics, metric, geodesics.stellar_curvature_distance)
selector_fn = getattr(starSelectors, selector, starSelectors.max_nodes)
# Filter/metric/selector names for readability
filter_fn_name = getattr(
filter_fn,
"_display_name",
getattr(filter_fn, "__name__", str(type(filter_fn))),
)
logger.info(
f"Filter: {filter_fn_name} | Metric: {metric} | Selector: {selector}"
)
# Determine files to process
files_to_use = files if files is not None else self.outDir
# Build a robust view of file list for logging (without changing behavior)
file_list: list[str]
out_dir = cast(str, self.outDir)
if files is None:
file_list = [
os.path.join(out_dir, f)
for f in os.listdir(out_dir)
if f.endswith(".pkl")
]
else:
if isinstance(files, (list, tuple)):
file_list = list(files)
elif isinstance(files, str) and os.path.isdir(files):
dir_str = cast(str, files)
file_list = [
os.path.join(dir_str, f)
for f in os.listdir(dir_str)
if f.endswith(".pkl")
]
else:
# Fallback: treat as a single path
file_list = [str(files)]
total_files = len(file_list)
target_desc = (
f"directory '{self.outDir}'"
if files is None
else f"{total_files} provided file(s)"
)
logger.info(f"Scanning {total_files} candidate graph(s) from {target_desc}.")
# Show graph distribution before filtering if DEBUG enabled
if logger.isEnabledFor(logging.DEBUG):
self._log_graph_distribution(files_to_use)
# Determine the concrete argument to pass to the metric function:
# either a directory path (str) or an explicit list of file paths
if files is None:
metric_files: str | list[str] = out_dir
else:
if isinstance(files, (list, tuple)):
metric_files = [str(f) for f in files]
elif isinstance(files, str) and os.path.isdir(files):
metric_files = files
else:
metric_files = [str(files)]
# Compute distances with timing
t0 = time.perf_counter()
self.keys, self.distances = metric_fn(
files=metric_files, filterfunction=filter_fn, **kwargs
)
t1 = time.perf_counter()
filtered_count = len(self.keys)
logger.info(
f"Filter results: {filtered_count}/{total_files} graph(s) passed the filter in {t1 - t0:.2f}s"
)
# Distance matrix quick stats (off-diagonal)
try:
n = self.distances.shape[0]
if n == self.distances.shape[1] and n == filtered_count and n > 1:
mask = ~np.eye(n, dtype=bool)
dvals = self.distances[mask]
finite = np.isfinite(dvals)
if not np.all(finite):
bad = np.size(dvals) - np.count_nonzero(finite)
logger.warning(
f"Distance matrix contains {bad} non-finite value(s) (NaN/inf)."
)
if np.any(finite):
dvals_f = dvals[finite]
logger.debug(
"Distance stats (off-diagonal, finite): min=%.4f | mean=%.4f | max=%.4f | count=%d",
float(np.min(dvals_f)),
float(np.mean(dvals_f)),
float(np.max(dvals_f)),
int(dvals_f.size),
)
except Exception:
# Keep logging resilient
pass
# Check if we have enough graphs for clustering
if filtered_count < 2:
raise ValueError(
f"Only {filtered_count} graph(s) passed the filter. "
"Clustering requires at least 2 graphs. "
"Consider relaxing your filter criteria."
)
# Use nReps or distance_threshold for AgglomerativeClustering
# Handle clustering configuration clarity
if nReps is None and distance_threshold is None:
nReps = self.nReps
if nReps is not None and distance_threshold is not None:
logger.warning(
"Both nReps and distance_threshold provided; using distance_threshold and ignoring nReps."
)
nReps = None
# Check if nReps is valid for the number of filtered graphs
if nReps and nReps > filtered_count:
raise ValueError(
f"Cannot create {nReps} clusters from {filtered_count} graphs. "
f"Set nReps to {filtered_count} or fewer, or relax your filter."
)
model = AgglomerativeClustering(
metric="precomputed",
linkage="average",
compute_distances=True,
n_clusters=nReps,
distance_threshold=distance_threshold,
)
mode_desc = (
f"n_clusters={nReps}"
if nReps is not None
else f"distance_threshold={distance_threshold}"
)
logger.info(
f"Clustering {filtered_count} graph(s) with AgglomerativeClustering ({mode_desc})…"
)
t2 = time.perf_counter()
model.fit(self.distances)
t3 = time.perf_counter()
labels = model.labels_
subgroups = {label: self.keys[labels == label] for label in set(labels)}
# Log cluster size distribution
cluster_sizes = {
int(lbl): int(len(members)) for lbl, members in subgroups.items()
}
size_list = sorted(cluster_sizes.values(), reverse=True)
logger.info(
f"Formed {len(subgroups)} cluster(s) in {t3 - t2:.2f}s | sizes: {size_list}"
)
self.selection = {}
for label, subgroup in subgroups.items():
selected_star = selector_fn(subgroup)
self.selection[label] = {
"star": selected_star,
"cluster_size": len(subgroup),
}
# Keep detailed selection at DEBUG to avoid log spam
try:
star_name = os.path.basename(str(selected_star))
except Exception:
star_name = str(selected_star)
logger.debug(
"Cluster %s: selected representative '%s' from %d member(s)",
str(label),
star_name,
len(subgroup),
)
total_time = (t1 - t0) + (t3 - t2)
logger.info(
f"Galaxy Collapse complete: {len(self.selection)} representative model(s) selected "
f"({metric}, {mode_desc}). Total compute time ~{total_time:.2f}s"
)
logger.info(
"Access results: this Galaxy's 'selection' maps cluster -> {'star','cluster_size'}. "
"If using a Thema instance, check its 'selected_model_files' for the chosen file paths."
)
return self.selection
def get_galaxy_coordinates(self) -> np.ndarray:
"""
Computes a 2D coordinate system for stars in the galaxy, allowing visualization
of their relative positions. This function uses Multidimensional Scaling (MDS)
to project the high-dimensional distance matrix into a 2D space, preserving the
relative distances between stars as much as possible.
Note: This method requires that distances have been computed first, usually by calling the
collapse() method or directly computing distances with a metric function.
Returns
-------
np.ndarray
A 2D array of shape (n_stars, 2) containing the X,Y coordinates of each star in the galaxy.
Each row represents the 2D coordinates of one star.
Examples
--------
>>> # After fitting the galaxy and computing distances
>>> import matplotlib.pyplot as plt
>>> coordinates = galaxy.get_galaxy_coordinates()
>>>
>>> # Basic scatter plot
>>> plt.figure(figsize=(10, 8))
>>> plt.scatter(coordinates[:, 0], coordinates[:, 1], alpha=0.7)
>>> plt.title('Star Map of the Galaxy')
>>> plt.xlabel('X Coordinate')
>>> plt.ylabel('Y Coordinate')
>>> plt.show()
>>>
>>> # Advanced plot with cluster coloring
>>> if galaxy.selection:  # if collapse() has been called
...     plt.figure(figsize=(12, 10))
...     # Plot all stars
...     plt.scatter(coordinates[:, 0], coordinates[:, 1], c='lightgray', alpha=0.5)
...     # Highlight representative stars
...     for cluster_id, info in galaxy.selection.items():
...         # Find the index of the representative star in the keys array
...         rep_idx = np.where(galaxy.keys == info['star'])[0][0]
...         plt.scatter(coordinates[rep_idx, 0], coordinates[rep_idx, 1],
...                     s=100, c='red', edgecolor='black', label=f'Cluster {cluster_id}')
...     plt.legend()
...     plt.title('Star Map with Representative Stars')
...     plt.show()
"""
if self.distances is None:
raise ValueError("Distance matrix is not computed.")
mds = MDS(n_components=2, dissimilarity="precomputed")
coordinates = mds.fit_transform(self.distances)
return coordinates
def save(self, file_path):
"""
Save the current object instance to a file using pickle serialization.
Parameters
----------
file_path : str
The path to the file where the object will be saved.
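Example
-------
A round-trip sketch; the path is illustrative. Note that callable
attributes are dropped during pickling (see __getstate__):
>>> galaxy.save("models/galaxy.pkl")
>>> import pickle
>>> with open("models/galaxy.pkl", "rb") as f:
...     restored = pickle.load(f)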
"""
try:
# os.makedirs fails on an empty string, so guard against bare filenames
dir_name = os.path.dirname(file_path)
if dir_name:
os.makedirs(dir_name, exist_ok=True)
with open(file_path, "wb") as f:
pickle.dump(self, f)
print(f"Saved object to {file_path}")
except Exception as e:
print(f"Failed to save object: {e}")
def getParams(self):
"""
Returns the parameters of the Galaxy instance.
Returns
-------
dict
A dictionary containing the parameters of the Galaxy instance.
"""
params = {
"params": self.params,
"data": self.data,
"cleanDir": self.cleanDir,
"projDir": self.projDir,
"outDir": self.outDir,
"metric": self.metric,
"selector": self.selector,
"nReps": self.nReps,
"YAML_PATH": self.YAML_PATH,
"verbose": self.verbose,
}
return params
def writeParams_toYaml(self, YAML_PATH=None):
"""
Write the parameters of the Galaxy instance to a YAML file.
Parameters
----------
YAML_PATH: str, optional
The path to the YAML file. If not provided, the YAML_PATH
attribute of the instance will be used.
Returns
-------
None
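Example
-------
A minimal sketch; the path is illustrative:
>>> galaxy.writeParams_toYaml("params.yaml")
YAML file successfully updated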
"""
# Resolve yaml path to a non-None string for type checking
if YAML_PATH is None:
if self.YAML_PATH is None:
raise ValueError("Please provide a valid filepath to YAML")
yaml_path = cast(str, self.YAML_PATH)
else:
yaml_path = str(YAML_PATH)
if not os.path.isfile(yaml_path):
raise TypeError("File path does not point to a YAML file")
with open(yaml_path, "r") as f:
params = OmegaConf.load(f)
galaxy_params = self.getParams()["params"]
params.Galaxy = galaxy_params
params.Galaxy.stars = list(galaxy_params.keys())
with open(yaml_path, "w") as f:
OmegaConf.save(params, f)
print("YAML file successfully updated")
def summarize_graphClustering(self):
"""
Summarizes the graph clustering results.
Returns
-------
dict
A dictionary mapping each cluster label to its selected
representative star and cluster size, as computed by collapse().
"""
if not self.selection:
raise ValueError("No clustering results found. Run collapse() first.")
return {
label: {
"representative": info["star"],
"cluster_size": info["cluster_size"],
}
for label, info in self.selection.items()
}
# Ensure Galaxy instances are pickle-friendly for multiprocessing
def __getstate__(self):
state = self.__dict__.copy()
for k, v in list(state.items()):
if callable(v):
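# Callables (e.g., filter functions) may be closures or lambdas that
# cannot be pickled; drop them here and let collapse() re-resolve
# filters from their stored names when needed.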
state[k] = None
return state
def __setstate__(self, state):
self.__dict__.update(state)