tramdag.TramDagModel
Copyright 2025 Zurich University of Applied Sciences (ZHAW) Pascal Buehler, Beate Sick, Oliver Duerr
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
1""" 2Copyright 2025 Zurich University of Applied Sciences (ZHAW) 3Pascal Buehler, Beate Sick, Oliver Duerr 4 5Licensed under the Apache License, Version 2.0 (the "License"); 6you may not use this file except in compliance with the License. 7You may obtain a copy of the License at 8 9 http://www.apache.org/licenses/LICENSE-2.0 10 11Unless required by applicable law or agreed to in writing, software 12distributed under the License is distributed on an "AS IS" BASIS, 13WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14See the License for the specific language governing permissions and 15limitations under the License. 16""" 17 18import os 19os.environ['CUDA_LAUNCH_BLOCKING'] = '1' 20 21import pandas as pd 22import numpy as np 23import json 24import matplotlib.pyplot as plt 25 26import torch 27from torch.optim import Adam 28from joblib import Parallel, delayed 29 30from statsmodels.graphics.gofplots import qqplot_2samples 31from scipy.stats import logistic, probplot 32 33from .utils.model_helpers import train_val_loop,evaluate_tramdag_model, get_fully_specified_tram_model , model_train_val_paths ,ordered_parents 34from .utils.sampling import create_latent_df_for_full_dag, sample_full_dag, is_outcome_modelled_ordinal,is_outcome_modelled_continous, is_outcome_modelled_ordinal, show_hdag_continous,show_hdag_ordinal 35from .utils.continous import transform_intercepts_continous 36from .utils.ordinal import transform_intercepts_ordinal 37 38from .models.tram_models import SimpleIntercept 39 40from .TramDagConfig import TramDagConfig 41from .TramDagDataset import TramDagDataset 42 43# testserver 44#%pip install -i https://test.pypi.org/simple --extra-index-url https://pypi.org/simple tramdag 45 46 47# # Remove previous builds 48# rm -rf build dist *.egg-info 49 50# # Build new package 51# python -m build 52 53# # Upload to TestPyPI 54# python -m twine upload --repository testpypi dist/* 55 56# documentaiton 57# 1. 
# documentation
# 1. download new version of tramdag in env
# pip install pdoc
# 2. generate docs
# pdoc tramdag -o docs


class TramDagModel:
    """
    Probabilistic DAG model built from node-wise TRAMs (transformation models).

    This class manages:
    - Configuration and per-node model construction.
    - Data scaling (min-max).
    - Training (sequential or per-node parallel on CPU).
    - Diagnostics (loss history, intercepts, linear shifts, latents).
    - Sampling from the joint DAG and loading stored samples.
    - High-level summaries and plotting utilities.
    """

    # ---- defaults used at construction time (see from_config) ----
    DEFAULTS_CONFIG = {
        "set_initial_weights": False,
        "debug": False,
        "verbose": False,
        "device": 'auto',
        "initial_data": None,
        "overwrite_initial_weights": True,
    }

    # ---- defaults used at fit() time ----
    # NOTE(review): "overwrite_inital_weights" is misspelled ("inital"), but it is
    # the key users pass to fit(); renaming it would break existing callers, so it
    # is kept as-is here.
    DEFAULTS_FIT = {
        "epochs": 100,
        "train_list": None,
        "callbacks": None,
        "learning_rate": 0.01,
        "device": "auto",
        "optimizers": None,
        "schedulers": None,
        "use_scheduler": False,
        "save_linear_shifts": True,
        "save_simple_intercepts": True,
        "debug": False,
        "verbose": True,
        "train_mode": "sequential",  # or "parallel"
        "return_history": False,
        "overwrite_inital_weights": True,
        "num_workers": 4,
        "persistent_workers": True,
        "prefetch_factor": 4,
        "batch_size": 1000,
    }

    def __init__(self):
        """
        Initialize an empty TramDagModel shell.

        Notes
        -----
        This constructor does not build any node models and does not attach a
        configuration. Use `TramDagModel.from_config` or `TramDagModel.from_directory`
        to obtain a fully configured and ready-to-use instance.
        """
        # Conservative defaults; from_config()/fit() overwrite these.
        self.debug = False
        self.verbose = False
        self.device = 'auto'

    @staticmethod
    def get_device(settings):
        """
        Resolve the target device string from a settings dictionary.

        Parameters
        ----------
        settings : dict
            Dictionary containing at least a key ``"device"`` with one of
            {"auto", "cpu", "cuda"}. If missing, "auto" is assumed.

        Returns
        -------
        str
            Device string, either "cpu" or "cuda".

        Notes
        -----
        If ``device == "auto"``, CUDA is selected if available, otherwise CPU.
        """
        device_arg = settings.get("device", "auto")
        if device_arg == "auto":
            device_str = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            device_str = device_arg
        return device_str
130 131 Parameters 132 ---------- 133 settings : dict 134 Dictionary containing at least a key ``"device"`` with one of 135 {"auto", "cpu", "cuda"}. If missing, "auto" is assumed. 136 137 Returns 138 ------- 139 str 140 Device string, either "cpu" or "cuda". 141 142 Notes 143 ----- 144 If ``device == "auto"``, CUDA is selected if available, otherwise CPU. 145 """ 146 device_arg = settings.get("device", "auto") 147 if device_arg == "auto": 148 device_str = "cuda" if torch.cuda.is_available() else "cpu" 149 else: 150 device_str = device_arg 151 return device_str 152 153 def _validate_kwargs(self, kwargs: dict, defaults_attr: str = "DEFAULTS_FIT", context: str = None): 154 """ 155 Validate a kwargs dictionary against a class-level defaults dictionary. 156 157 Parameters 158 ---------- 159 kwargs : dict 160 Keyword arguments to validate. 161 defaults_attr : str, optional 162 Name of the attribute on this class that contains the allowed keys, 163 e.g. ``"DEFAULTS_CONFIG"`` or ``"DEFAULTS_FIT"``. Default is "DEFAULTS_FIT". 164 context : str or None, optional 165 Optional label (e.g. caller name) to prepend in error messages. 166 167 Raises 168 ------ 169 AttributeError 170 If the attribute named by ``defaults_attr`` does not exist. 171 ValueError 172 If any key in ``kwargs`` is not present in the corresponding defaults dict. 173 """ 174 defaults = getattr(self, defaults_attr, None) 175 if defaults is None: 176 raise AttributeError(f"{self.__class__.__name__} has no attribute '{defaults_attr}'") 177 178 unknown = set(kwargs) - set(defaults) 179 if unknown: 180 prefix = f"[{context}] " if context else "" 181 raise ValueError(f"{prefix}Unknown parameter(s): {', '.join(sorted(unknown))}") 182 183 ## CREATE A TRAMDADMODEL 184 @classmethod 185 def from_config(cls, cfg, **kwargs): 186 """ 187 Construct a TramDagModel from a TramDagConfig object. 188 189 This builds one TRAM model per node in the DAG and optionally writes 190 the initial model parameters to disk. 
191 192 Parameters 193 ---------- 194 cfg : TramDagConfig 195 Configuration wrapper holding the underlying configuration dictionary, 196 including at least: 197 - ``conf_dict["nodes"]``: mapping of node names to node configs. 198 - ``conf_dict["PATHS"]["EXPERIMENT_DIR"]``: experiment directory. 199 **kwargs 200 Node-level construction options. Each key must be present in 201 ``DEFAULTS_CONFIG``. Values can be: 202 - scalar: applied to all nodes. 203 - dict: mapping ``{node_name: value}`` for per-node overrides. 204 205 Common keys include: 206 device : {"auto", "cpu", "cuda"}, default "auto" 207 Device selection (CUDA if available when "auto"). 208 debug : bool, default False 209 If True, print debug messages. 210 verbose : bool, default False 211 If True, print informational messages. 212 set_initial_weights : bool 213 Passed to underlying TRAM model constructors. 214 overwrite_initial_weights : bool, default True 215 If True, overwrite any existing ``initial_model.pt`` files per node. 216 initial_data : Any 217 Optional object passed down to node constructors. 218 219 Returns 220 ------- 221 TramDagModel 222 Fully initialized instance with: 223 - ``cfg`` 224 - ``nodes_dict`` 225 - ``models`` (per-node TRAMs) 226 - ``settings`` (resolved per-node config) 227 228 Raises 229 ------ 230 ValueError 231 If any dict-valued kwarg does not provide values for exactly the set 232 of nodes in ``cfg.conf_dict["nodes"]``. 
233 """ 234 235 self = cls() 236 self.cfg = cfg 237 self.cfg.update() # ensure latest version from disk 238 self.cfg._verify_completeness() 239 240 241 try: 242 self.cfg.save() # persist back to disk 243 if getattr(self, "debug", False): 244 print("[DEBUG] Configuration updated and saved.") 245 except Exception as e: 246 print(f"[WARNING] Could not save configuration after update: {e}") 247 248 self.nodes_dict = self.cfg.conf_dict["nodes"] 249 250 self._validate_kwargs(kwargs, defaults_attr='DEFAULTS_CONFIG', context="from_config") 251 252 # update defaults with kwargs 253 settings = dict(cls.DEFAULTS_CONFIG) 254 settings.update(kwargs) 255 256 # resolve device 257 device_arg = settings.get("device", "auto") 258 if device_arg == "auto": 259 device_str = "cuda" if torch.cuda.is_available() else "cpu" 260 else: 261 device_str = device_arg 262 self.device = torch.device(device_str) 263 264 # set flags on the instance so they are accessible later 265 self.debug = settings.get("debug", False) 266 self.verbose = settings.get("verbose", False) 267 268 if self.debug: 269 print(f"[DEBUG] TramDagModel using device: {self.device}") 270 271 # initialize settings storage 272 self.settings = {k: {} for k in settings.keys()} 273 274 # validate dict-typed args 275 for k, v in settings.items(): 276 if isinstance(v, dict): 277 expected = set(self.nodes_dict.keys()) 278 given = set(v.keys()) 279 if expected != given: 280 raise ValueError( 281 f"[ERROR] the provided argument '{k}' keys are not same as in cfg.conf_dict['nodes'].keys().\n" 282 f"Expected: {expected}, but got: {given}\n" 283 f"Please provide values for all variables.") 284 285 # build one model per node 286 self.models = {} 287 for node in self.nodes_dict.keys(): 288 per_node_kwargs = {} 289 for k, v in settings.items(): 290 resolved = v[node] if isinstance(v, dict) else v 291 per_node_kwargs[k] = resolved 292 self.settings[k][node] = resolved 293 if self.debug: 294 print(f"\n[INFO] Building model for node '{node}' with 
settings: {per_node_kwargs}") 295 self.models[node] = get_fully_specified_tram_model( 296 node=node, 297 configuration_dict=self.cfg.conf_dict, 298 **per_node_kwargs) 299 300 try: 301 EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"] 302 NODE_DIR = os.path.join(EXPERIMENT_DIR, f"{node}") 303 os.makedirs(NODE_DIR, exist_ok=True) 304 305 model_path = os.path.join(NODE_DIR, "initial_model.pt") 306 overwrite = settings.get("overwrite_initial_weights", True) 307 308 if overwrite or not os.path.exists(model_path): 309 torch.save(self.models[node].state_dict(), model_path) 310 if self.debug: 311 print(f"[DEBUG] Saved initial model state for node '{node}' to {model_path} (overwrite={overwrite})") 312 else: 313 if self.debug: 314 print(f"[DEBUG] Skipped saving initial model for node '{node}' (already exists at {model_path})") 315 except Exception as e: 316 print(f"[ERROR] Could not save initial model state for node '{node}': {e}") 317 318 TEMP_DIR = "temp" 319 if os.path.isdir(TEMP_DIR) and not os.listdir(TEMP_DIR): 320 os.rmdir(TEMP_DIR) 321 322 return self 323 324 @classmethod 325 def from_directory(cls, EXPERIMENT_DIR: str, device: str = "auto", debug: bool = False, verbose: bool = False): 326 """ 327 Reconstruct a TramDagModel from an experiment directory on disk. 328 329 This method: 330 1. Loads the configuration JSON. 331 2. Wraps it in a TramDagConfig. 332 3. Builds all node models via `from_config`. 333 4. Loads the min–max scaling dictionary. 334 335 Parameters 336 ---------- 337 EXPERIMENT_DIR : str 338 Path to an experiment directory containing: 339 - ``configuration.json`` 340 - ``min_max_scaling.json``. 341 device : {"auto", "cpu", "cuda"}, optional 342 Device selection. Default is "auto". 343 debug : bool, optional 344 If True, enable debug messages. Default is False. 345 verbose : bool, optional 346 If True, enable informational messages. Default is False. 
347 348 Returns 349 ------- 350 TramDagModel 351 A TramDagModel instance with models, config, and scaling loaded. 352 353 Raises 354 ------ 355 FileNotFoundError 356 If configuration or min–max files cannot be found. 357 RuntimeError 358 If the min–max file cannot be read or parsed. 359 """ 360 361 # --- load config file --- 362 config_path = os.path.join(EXPERIMENT_DIR, "configuration.json") 363 if not os.path.exists(config_path): 364 raise FileNotFoundError(f"[ERROR] Config file not found at {config_path}") 365 366 with open(config_path, "r") as f: 367 cfg_dict = json.load(f) 368 369 # Create TramConfig wrapper 370 cfg = TramDagConfig(cfg_dict, CONF_DICT_PATH=config_path) 371 372 # --- build model from config --- 373 self = cls.from_config(cfg, device=device, debug=debug, verbose=verbose, overwrite_initial_weights=False) 374 375 # --- load minmax scaling --- 376 minmax_path = os.path.join(EXPERIMENT_DIR, "min_max_scaling.json") 377 if not os.path.exists(minmax_path): 378 raise FileNotFoundError(f"[ERROR] MinMax file not found at {minmax_path}") 379 380 with open(minmax_path, "r") as f: 381 self.minmax_dict = json.load(f) 382 383 if self.verbose or self.debug: 384 print(f"[INFO] Loaded TramDagModel from {EXPERIMENT_DIR}") 385 print(f"[INFO] Config loaded from {config_path}") 386 print(f"[INFO] MinMax scaling loaded from {minmax_path}") 387 388 return self 389 390 def _ensure_dataset(self, data, is_val=False,**kwargs): 391 """ 392 Ensure that the input data is represented as a TramDagDataset. 393 394 Parameters 395 ---------- 396 data : pandas.DataFrame, TramDagDataset, or None 397 Input data to be converted or passed through. 398 is_val : bool, optional 399 If True, the resulting dataset is treated as validation data 400 (e.g. no shuffling). Default is False. 401 **kwargs 402 Additional keyword arguments passed through to 403 ``TramDagDataset.from_dataframe``. 
    def load_or_compute_minmax(self, td_train_data=None, use_existing=False, write=True):
        """
        Load an existing min-max scaling dictionary from disk or compute a new one
        from the provided training dataset.

        Parameters
        ----------
        td_train_data : object, optional
            Training dataset used to compute scaling statistics. Passed through
            `_ensure_dataset(data=..., is_val=False)` before use, so a DataFrame
            is also accepted.
        use_existing : bool, optional (default=False)
            If True, load `min_max_scaling.json` from the experiment directory
            instead of computing; raises if the file is missing or unreadable.
        write : bool, optional (default=True)
            If True, write the computed dictionary to
            `<EXPERIMENT_DIR>/min_max_scaling.json`.

        Side Effects
        ------------
        Populates ``self.minmax_dict``; optionally reads/writes
        `min_max_scaling.json` under the experiment directory; prints
        diagnostics when ``self.debug`` or ``self.verbose`` is set.

        Raises
        ------
        FileNotFoundError
            If `use_existing=True` but the min-max file does not exist.
        RuntimeError
            If an existing min-max file cannot be read or parsed.

        Notes
        -----
        The dictionary holds per-node scaling statistics as produced by
        ``compute_scaling()`` on the dataset (consumed elsewhere as
        ``minmax_dict[node][0]`` = mins, ``minmax_dict[node][1]`` = maxs).
        """
        EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
        minmax_path = os.path.join(EXPERIMENT_DIR, "min_max_scaling.json")

        # load existing file if requested; any read/parse failure is fatal here
        if use_existing:
            if not os.path.exists(minmax_path):
                raise FileNotFoundError(f"MinMax file not found: {minmax_path}")
            try:
                with open(minmax_path, 'r') as f:
                    self.minmax_dict = json.load(f)
                if self.debug or self.verbose:
                    print(f"[INFO] Loaded existing minmax dict from {minmax_path}")
                return
            except Exception as e:
                raise RuntimeError(f"Could not load existing minmax dict: {e}")

        # otherwise compute fresh statistics from the training data
        if self.debug or self.verbose:
            print("[INFO] Computing new minmax dict from training data...")

        td_train_data=self._ensure_dataset( data=td_train_data, is_val=False)

        self.minmax_dict = td_train_data.compute_scaling()

        if write:
            os.makedirs(EXPERIMENT_DIR, exist_ok=True)
            with open(minmax_path, 'w') as f:
                json.dump(self.minmax_dict, f, indent=4)
            if self.debug or self.verbose:
                print(f"[INFO] Saved new minmax dict to {minmax_path}")
    ## FIT METHODS
    @staticmethod
    def _fit_single_node(node, self_ref, settings, td_train_data, td_val_data, device_str):
        """
        Train a single node model (helper for per-node training).

        This method is designed to be called either from the main process
        (sequential training) or from a joblib worker (parallel CPU training);
        it is a staticmethod taking ``self_ref`` explicitly so it pickles
        cleanly for worker processes.

        Parameters
        ----------
        node : str
            Name of the target node to train.
        self_ref : TramDagModel
            Reference to the TramDagModel instance containing models and config.
        settings : dict
            Training settings, typically ``DEFAULTS_FIT`` plus user overrides.
            Dict-valued entries are interpreted as per-node overrides.
        td_train_data : TramDagDataset
            Training dataset with node-specific DataLoaders in ``.loaders``.
        td_val_data : TramDagDataset or None
            Validation dataset or None (no validation loss).
        device_str : str
            Device string, e.g. "cpu" or "cuda".

        Returns
        -------
        tuple
            ``(node, history)`` where ``history`` is whatever
            ``train_val_loop`` returns for this node.
        """
        torch.set_num_threads(1)  # prevent thread oversubscription across joblib workers

        model = self_ref.models[node]

        # Resolve per-node settings: dict values are per-node maps, scalars apply to all
        def _resolve(key):
            val = settings[key]
            return val[node] if isinstance(val, dict) else val

        node_epochs = _resolve("epochs")
        node_lr = _resolve("learning_rate")
        node_debug = _resolve("debug")
        node_save_linear_shifts = _resolve("save_linear_shifts")
        save_simple_intercepts = _resolve("save_simple_intercepts")
        node_verbose = _resolve("verbose")

        # Optimizer & scheduler: user-supplied per-node optimizer wins over default Adam
        if settings["optimizers"] and node in settings["optimizers"]:
            optimizer = settings["optimizers"][node]
        else:
            optimizer = Adam(model.parameters(), lr=node_lr)

        # NOTE(review): settings["use_scheduler"] is not consulted here — a scheduler
        # is used iff one was supplied for this node.
        scheduler = settings["schedulers"].get(node, None) if settings["schedulers"] else None

        # Data loaders
        train_loader = td_train_data.loaders[node]
        val_loader = td_val_data.loaders[node] if td_val_data else None

        # Min-max scaling tensors; minmax_dict[node] is (mins, maxs), stacked to shape (2, ...)
        min_vals = torch.tensor(self_ref.minmax_dict[node][0], dtype=torch.float32)
        max_vals = torch.tensor(self_ref.minmax_dict[node][1], dtype=torch.float32)
        min_max = torch.stack([min_vals, max_vals], dim=0)

        # Node output directory; fall back to ./models/<node> if config lacks paths
        try:
            EXPERIMENT_DIR = self_ref.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
            NODE_DIR = os.path.join(EXPERIMENT_DIR, f"{node}")
        except Exception:
            NODE_DIR = os.path.join("models", node)
            print("[WARNING] No log directory specified in config, saving to default location.")
        os.makedirs(NODE_DIR, exist_ok=True)

        if node_verbose:
            print(f"\n[INFO] Training node '{node}' for {node_epochs} epochs on {device_str} (pid={os.getpid()})")

        # --- train ---
        history = train_val_loop(
            node=node,
            target_nodes=self_ref.nodes_dict,
            NODE_DIR=NODE_DIR,
            tram_model=model,
            train_loader=train_loader,
            val_loader=val_loader,
            epochs=node_epochs,
            optimizer=optimizer,
            use_scheduler=(scheduler is not None),
            scheduler=scheduler,
            save_linear_shifts=node_save_linear_shifts,
            save_simple_intercepts=save_simple_intercepts,
            verbose=node_verbose,
            device=torch.device(device_str),
            debug=node_debug,
            min_max=min_max)
        return node, history
torch.tensor(self_ref.minmax_dict[node][1], dtype=torch.float32) 577 min_max = torch.stack([min_vals, max_vals], dim=0) 578 579 # Node directory 580 try: 581 EXPERIMENT_DIR = self_ref.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"] 582 NODE_DIR = os.path.join(EXPERIMENT_DIR, f"{node}") 583 except Exception: 584 NODE_DIR = os.path.join("models", node) 585 print("[WARNING] No log directory specified in config, saving to default location.") 586 os.makedirs(NODE_DIR, exist_ok=True) 587 588 if node_verbose: 589 print(f"\n[INFO] Training node '{node}' for {node_epochs} epochs on {device_str} (pid={os.getpid()})") 590 591 # --- train --- 592 history = train_val_loop( 593 node=node, 594 target_nodes=self_ref.nodes_dict, 595 NODE_DIR=NODE_DIR, 596 tram_model=model, 597 train_loader=train_loader, 598 val_loader=val_loader, 599 epochs=node_epochs, 600 optimizer=optimizer, 601 use_scheduler=(scheduler is not None), 602 scheduler=scheduler, 603 save_linear_shifts=node_save_linear_shifts, 604 save_simple_intercepts=save_simple_intercepts, 605 verbose=node_verbose, 606 device=torch.device(device_str), 607 debug=node_debug, 608 min_max=min_max) 609 return node, history 610 611 def fit(self, train_data, val_data=None, **kwargs): 612 """ 613 Train TRAM models for all nodes in the DAG. 614 615 Coordinates dataset preparation, min–max scaling, and per-node training, 616 optionally in parallel on CPU. 617 618 Parameters 619 ---------- 620 train_data : pandas.DataFrame or TramDagDataset 621 Training data. If a DataFrame is given, it is converted into a 622 TramDagDataset using `_ensure_dataset`. 623 val_data : pandas.DataFrame or TramDagDataset or None, optional 624 Validation data. If a DataFrame is given, it is converted into a 625 TramDagDataset. If None, no validation loss is computed. 626 **kwargs 627 Overrides for ``DEFAULTS_FIT``. All keys must exist in 628 ``DEFAULTS_FIT``. Common options: 629 630 epochs : int, default 100 631 Number of training epochs per node. 
632 learning_rate : float, default 0.01 633 Learning rate for the default Adam optimizer. 634 train_list : list of str or None, optional 635 List of node names to train. If None, all nodes are trained. 636 train_mode : {"sequential", "parallel"}, default "sequential" 637 Training mode. "parallel" uses joblib-based CPU multiprocessing. 638 GPU forces sequential mode. 639 device : {"auto", "cpu", "cuda"}, default "auto" 640 Device selection. 641 optimizers : dict or None 642 Optional mapping ``{node_name: optimizer}``. If provided for a 643 node, that optimizer is used instead of creating a new Adam. 644 schedulers : dict or None 645 Optional mapping ``{node_name: scheduler}``. 646 use_scheduler : bool 647 If True, enable scheduler usage in the training loop. 648 num_workers : int 649 DataLoader workers in sequential mode (ignored in parallel). 650 persistent_workers : bool 651 DataLoader persistence in sequential mode (ignored in parallel). 652 prefetch_factor : int 653 DataLoader prefetch factor (ignored in parallel). 654 batch_size : int 655 Batch size for all node DataLoaders. 656 debug : bool 657 Enable debug output. 658 verbose : bool 659 Enable informational logging. 660 return_history : bool 661 If True, return a history dict. 662 663 Returns 664 ------- 665 dict or None 666 If ``return_history=True``, a dictionary mapping each node name 667 to its training history. Otherwise, returns None. 668 669 Raises 670 ------ 671 ValueError 672 If ``train_mode`` is not "sequential" or "parallel". 
673 """ 674 self._validate_kwargs(kwargs, defaults_attr='DEFAULTS_FIT', context="fit") 675 676 # --- merge defaults --- 677 settings = dict(self.DEFAULTS_FIT) 678 settings.update(kwargs) 679 680 681 self.debug = settings.get("debug", False) 682 self.verbose = settings.get("verbose", False) 683 684 # --- resolve device --- 685 device_str=self.get_device(settings) 686 self.device = torch.device(device_str) 687 688 # --- training mode --- 689 train_mode = settings.get("train_mode", "sequential").lower() 690 if train_mode not in ("sequential", "parallel"): 691 raise ValueError("train_mode must be 'sequential' or 'parallel'") 692 693 # --- DataLoader safety logic --- 694 if train_mode == "parallel": 695 # if user passed loader paralleling params, warn and override 696 for flag in ("num_workers", "persistent_workers", "prefetch_factor"): 697 if flag in kwargs: 698 print(f"[WARNING] '{flag}' is ignored in parallel mode " 699 f"(disabled to prevent nested multiprocessing).") 700 # disable unsafe loader multiprocessing options 701 settings["num_workers"] = 0 702 settings["persistent_workers"] = False 703 settings["prefetch_factor"] = None 704 else: 705 # sequential mode → respect user DataLoader settings 706 if self.debug: 707 print("[DEBUG] Sequential mode: using DataLoader kwargs as provided.") 708 709 # --- which nodes to train --- 710 train_list = settings.get("train_list") or list(self.models.keys()) 711 712 713 # --- dataset prep (receives adjusted settings) --- 714 td_train_data = self._ensure_dataset(train_data, is_val=False, **settings) 715 td_val_data = self._ensure_dataset(val_data, is_val=True, **settings) 716 717 # --- normalization --- 718 self.load_or_compute_minmax(use_existing=False, write=True, td_train_data=td_train_data) 719 720 # --- print header --- 721 if self.verbose or self.debug: 722 print(f"[INFO] Training {len(train_list)} nodes ({train_mode}) on {device_str}") 723 724 # ====================================================================== 725 # 
Sequential mode safe for GPU or debugging) 726 # ====================================================================== 727 if train_mode == "sequential" or "cuda" in device_str: 728 if "cuda" in device_str and train_mode == "parallel": 729 print("[WARNING] GPU device detected — forcing sequential mode.") 730 results = {} 731 for node in train_list: 732 node, history = self._fit_single_node( 733 node, self, settings, td_train_data, td_val_data, device_str 734 ) 735 results[node] = history 736 737 738 # ====================================================================== 739 # parallel mode (CPU only) 740 # ====================================================================== 741 if train_mode == "parallel": 742 743 n_jobs = min(len(train_list), os.cpu_count() // 2 or 1) 744 if self.verbose or self.debug: 745 print(f"[INFO] Using {n_jobs} CPU workers for parallel node training") 746 parallel_outputs = Parallel( 747 n_jobs=n_jobs, 748 backend="loky",#loky, multiprocessing 749 verbose=10, 750 prefer="processes" 751 )(delayed(self._fit_single_node)(node, self, settings, td_train_data, td_val_data, device_str) for node in train_list ) 752 753 results = {node: hist for node, hist in parallel_outputs} 754 755 if settings.get("return_history", False): 756 return results 757 758 ## FIT-DIAGNOSTICS 759 def loss_history(self): 760 """ 761 Load training and validation loss history for all nodes. 762 763 Looks for per-node JSON files: 764 765 - ``EXPERIMENT_DIR/{node}/train_loss_hist.json`` 766 - ``EXPERIMENT_DIR/{node}/val_loss_hist.json`` 767 768 Returns 769 ------- 770 dict 771 A dictionary mapping node names to: 772 773 .. code-block:: python 774 775 { 776 "train": list or None, 777 "validation": list or None 778 } 779 780 where each list contains NLL values per epoch, or None if not found. 781 782 Raises 783 ------ 784 ValueError 785 If the experiment directory cannot be resolved from the configuration. 
786 """ 787 try: 788 EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"] 789 except KeyError: 790 raise ValueError( 791 "[ERROR] Missing 'EXPERIMENT_DIR' in cfg.conf_dict['PATHS']. " 792 "History retrieval requires experiment logs." 793 ) 794 795 all_histories = {} 796 for node in self.nodes_dict.keys(): 797 node_dir = os.path.join(EXPERIMENT_DIR, node) 798 train_path = os.path.join(node_dir, "train_loss_hist.json") 799 val_path = os.path.join(node_dir, "val_loss_hist.json") 800 801 node_hist = {} 802 803 # --- load train history --- 804 if os.path.exists(train_path): 805 try: 806 with open(train_path, "r") as f: 807 node_hist["train"] = json.load(f) 808 except Exception as e: 809 print(f"[WARNING] Could not load {train_path}: {e}") 810 node_hist["train"] = None 811 else: 812 node_hist["train"] = None 813 814 # --- load val history --- 815 if os.path.exists(val_path): 816 try: 817 with open(val_path, "r") as f: 818 node_hist["validation"] = json.load(f) 819 except Exception as e: 820 print(f"[WARNING] Could not load {val_path}: {e}") 821 node_hist["validation"] = None 822 else: 823 node_hist["validation"] = None 824 825 all_histories[node] = node_hist 826 827 if self.verbose or self.debug: 828 print(f"[INFO] Loaded training/validation histories for {len(all_histories)} nodes.") 829 830 return all_histories 831 832 def linear_shift_history(self): 833 """ 834 Load linear shift term histories for all nodes. 835 836 Each node history is expected in a JSON file named 837 ``linear_shifts_all_epochs.json`` under the node directory. 838 839 Returns 840 ------- 841 dict 842 A mapping ``{node_name: pandas.DataFrame}``, where each DataFrame 843 contains linear shift weights across epochs. 844 845 Raises 846 ------ 847 ValueError 848 If the experiment directory cannot be resolved from the configuration. 849 850 Notes 851 ----- 852 If a history file is missing for a node, a warning is printed and the 853 node is omitted from the returned dictionary. 
854 """ 855 histories = {} 856 try: 857 EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"] 858 except KeyError: 859 raise ValueError( 860 "[ERROR] Missing 'EXPERIMENT_DIR' in cfg.conf_dict['PATHS']. " 861 "Cannot load histories without experiment directory." 862 ) 863 864 for node in self.nodes_dict.keys(): 865 node_dir = os.path.join(EXPERIMENT_DIR, node) 866 history_path = os.path.join(node_dir, "linear_shifts_all_epochs.json") 867 if os.path.exists(history_path): 868 histories[node] = pd.read_json(history_path) 869 else: 870 print(f"[WARNING] No linear shift history found for node '{node}' at {history_path}") 871 return histories 872 873 def simple_intercept_history(self): 874 """ 875 Load simple intercept histories for all nodes. 876 877 Each node history is expected in a JSON file named 878 ``simple_intercepts_all_epochs.json`` under the node directory. 879 880 Returns 881 ------- 882 dict 883 A mapping ``{node_name: pandas.DataFrame}``, where each DataFrame 884 contains intercept weights across epochs. 885 886 Raises 887 ------ 888 ValueError 889 If the experiment directory cannot be resolved from the configuration. 890 891 Notes 892 ----- 893 If a history file is missing for a node, a warning is printed and the 894 node is omitted from the returned dictionary. 895 """ 896 histories = {} 897 try: 898 EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"] 899 except KeyError: 900 raise ValueError( 901 "[ERROR] Missing 'EXPERIMENT_DIR' in cfg.conf_dict['PATHS']. " 902 "Cannot load histories without experiment directory." 
903 ) 904 905 for node in self.nodes_dict.keys(): 906 node_dir = os.path.join(EXPERIMENT_DIR, node) 907 history_path = os.path.join(node_dir, "simple_intercepts_all_epochs.json") 908 if os.path.exists(history_path): 909 histories[node] = pd.read_json(history_path) 910 else: 911 print(f"[WARNING] No simple intercept history found for node '{node}' at {history_path}") 912 return histories 913 914 def get_latent(self, df, verbose=False): 915 """ 916 Compute latent representations for all nodes in the DAG. 917 918 Parameters 919 ---------- 920 df : pandas.DataFrame 921 Input data frame with columns corresponding to nodes in the DAG. 922 verbose : bool, optional 923 If True, print informational messages during latent computation. 924 Default is False. 925 926 Returns 927 ------- 928 pandas.DataFrame 929 DataFrame containing the original columns plus latent variables 930 for each node (e.g. columns named ``f"{node}_U"``). 931 932 Raises 933 ------ 934 ValueError 935 If the experiment directory is missing from the configuration or 936 if ``self.minmax_dict`` has not been set. 937 """ 938 try: 939 EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"] 940 except KeyError: 941 raise ValueError( 942 "[ERROR] Missing 'EXPERIMENT_DIR' in cfg.conf_dict['PATHS']. " 943 "Latent extraction requires trained model checkpoints." 944 ) 945 946 # ensure minmax_dict is available 947 if not hasattr(self, "minmax_dict"): 948 raise ValueError( 949 "[ERROR] minmax_dict not found in the TramDagModel instance. " 950 "Either call .load_or_compute_minmax(td_train_data=train_df) or .fit() first." 
    ## PLOTTING FIT-DIAGNOSTICS

    def plot_loss_history(self, variable: str = None):
        """
        Plot training and validation loss evolution per node.

        Parameters
        ----------
        variable : str or None, optional
            If provided, plot loss history for this node only. If None, plot
            histories for all nodes that have both train and validation logs.

        Returns
        -------
        None

        Notes
        -----
        Two subplots are produced:
        - Full epoch history.
        - Last 10% of epochs (or only the last epoch if fewer than 5 epochs).
        """

        histories = self.loss_history()
        if not histories:
            print("[WARNING] No loss histories found.")
            return

        # Select which nodes to plot
        if variable is not None:
            if variable not in histories:
                raise ValueError(f"[ERROR] Node '{variable}' not found in histories.")
            nodes_to_plot = [variable]
        else:
            nodes_to_plot = list(histories.keys())

        # Filter out nodes with no valid history (both train and validation required)
        nodes_to_plot = [
            n for n in nodes_to_plot
            if histories[n].get("train") is not None and len(histories[n]["train"]) > 0
            and histories[n].get("validation") is not None and len(histories[n]["validation"]) > 0
        ]

        if not nodes_to_plot:
            print("[WARNING] No valid histories found to plot.")
            return

        plt.figure(figsize=(14, 12))

        # --- Full history (top plot) ---
        plt.subplot(2, 1, 1)
        for node in nodes_to_plot:
            node_hist = histories[node]
            train_hist, val_hist = node_hist["train"], node_hist["validation"]

            epochs = range(1, len(train_hist) + 1)
            plt.plot(epochs, train_hist, label=f"{node} - train", linestyle="--")
            plt.plot(epochs, val_hist, label=f"{node} - val")

        plt.title("Training and Validation NLL - Full History")
        plt.xlabel("Epoch")
        plt.ylabel("NLL")
        plt.legend()
        plt.grid(True)

        # --- Last 10% of epochs (bottom plot) ---
        plt.subplot(2, 1, 2)
        for node in nodes_to_plot:
            node_hist = histories[node]
            train_hist, val_hist = node_hist["train"], node_hist["validation"]

            total_epochs = len(train_hist)
            # With fewer than 5 epochs, only the last epoch is shown.
            start_idx = total_epochs - 1 if total_epochs < 5 else int(total_epochs * 0.9)

            epochs = range(start_idx + 1, total_epochs + 1)
            plt.plot(epochs, train_hist[start_idx:], label=f"{node} - train", linestyle="--")
            plt.plot(epochs, val_hist[start_idx:], label=f"{node} - val")

        plt.title("Training and Validation NLL - Last 10% of Epochs (or Last Epoch if <5)")
        plt.xlabel("Epoch")
        plt.ylabel("NLL")
        plt.legend()
        plt.grid(True)

        plt.tight_layout()
        plt.show()
- val") 1022 1023 plt.title("Training and Validation NLL - Full History") 1024 plt.xlabel("Epoch") 1025 plt.ylabel("NLL") 1026 plt.legend() 1027 plt.grid(True) 1028 1029 # --- Last 10% of epochs (bottom plot) --- 1030 plt.subplot(2, 1, 2) 1031 for node in nodes_to_plot: 1032 node_hist = histories[node] 1033 train_hist, val_hist = node_hist["train"], node_hist["validation"] 1034 1035 total_epochs = len(train_hist) 1036 start_idx = total_epochs - 1 if total_epochs < 5 else int(total_epochs * 0.9) 1037 1038 epochs = range(start_idx + 1, total_epochs + 1) 1039 plt.plot(epochs, train_hist[start_idx:], label=f"{node} - train", linestyle="--") 1040 plt.plot(epochs, val_hist[start_idx:], label=f"{node} - val") 1041 1042 plt.title("Training and Validation NLL - Last 10% of Epochs (or Last Epoch if <5)") 1043 plt.xlabel("Epoch") 1044 plt.ylabel("NLL") 1045 plt.legend() 1046 plt.grid(True) 1047 1048 plt.tight_layout() 1049 plt.show() 1050 1051 def plot_linear_shift_history(self, data_dict=None, node=None, ref_lines=None): 1052 """ 1053 Plot the evolution of linear shift terms over epochs. 1054 1055 Parameters 1056 ---------- 1057 data_dict : dict or None, optional 1058 Pre-loaded mapping ``{node_name: pandas.DataFrame}`` containing shift 1059 weights across epochs. If None, `linear_shift_history()` is called. 1060 node : str or None, optional 1061 If provided, plot only this node. Otherwise, plot all nodes 1062 present in ``data_dict``. 1063 ref_lines : dict or None, optional 1064 Optional mapping ``{node_name: list of float}``. For each specified 1065 node, horizontal reference lines are drawn at the given values. 1066 1067 Returns 1068 ------- 1069 None 1070 1071 Notes 1072 ----- 1073 The function flattens nested list-like entries in the DataFrames to scalars, 1074 converts epoch labels to numeric, and then draws one line per shift term. 
1075 """ 1076 1077 if data_dict is None: 1078 data_dict = self.linear_shift_history() 1079 if data_dict is None: 1080 raise ValueError("No shift history data provided or stored in the class.") 1081 1082 nodes = [node] if node else list(data_dict.keys()) 1083 1084 for n in nodes: 1085 df = data_dict[n].copy() 1086 1087 # Flatten nested lists or list-like cells 1088 def flatten(x): 1089 if isinstance(x, list): 1090 if len(x) == 0: 1091 return np.nan 1092 if all(isinstance(i, (int, float)) for i in x): 1093 return np.mean(x) # average simple list 1094 if all(isinstance(i, list) for i in x): 1095 # nested list -> flatten inner and average 1096 flat = [v for sub in x for v in (sub if isinstance(sub, list) else [sub])] 1097 return np.mean(flat) if flat else np.nan 1098 return x[0] if len(x) == 1 else np.nan 1099 return x 1100 1101 df = df.applymap(flatten) 1102 1103 # Ensure numeric columns 1104 df = df.apply(pd.to_numeric, errors='coerce') 1105 1106 # Convert epoch labels to numeric 1107 df.columns = [ 1108 int(c.replace("epoch_", "")) if isinstance(c, str) and c.startswith("epoch_") else c 1109 for c in df.columns 1110 ] 1111 df = df.reindex(sorted(df.columns), axis=1) 1112 1113 plt.figure(figsize=(10, 6)) 1114 for idx in df.index: 1115 plt.plot(df.columns, df.loc[idx], lw=1.4, label=f"shift_{idx}") 1116 1117 if ref_lines and n in ref_lines: 1118 for v in ref_lines[n]: 1119 plt.axhline(y=v, color="k", linestyle="--", lw=1.0) 1120 plt.text(df.columns[-1], v, f"{n}: {v}", va="bottom", ha="right", fontsize=8) 1121 1122 plt.xlabel("Epoch") 1123 plt.ylabel("Shift Value") 1124 plt.title(f"Shift Term History — Node: {n}") 1125 plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left", fontsize="small") 1126 plt.tight_layout() 1127 plt.show() 1128 1129 def plot_simple_intercepts_history(self, data_dict=None, node=None,ref_lines=None): 1130 """ 1131 Plot the evolution of simple intercept weights over epochs. 
1132 1133 Parameters 1134 ---------- 1135 data_dict : dict or None, optional 1136 Pre-loaded mapping ``{node_name: pandas.DataFrame}`` containing intercept 1137 weights across epochs. If None, `simple_intercept_history()` is called. 1138 node : str or None, optional 1139 If provided, plot only this node. Otherwise, plot all nodes present 1140 in ``data_dict``. 1141 ref_lines : dict or None, optional 1142 Optional mapping ``{node_name: list of float}``. For each specified 1143 node, horizontal reference lines are drawn at the given values. 1144 1145 Returns 1146 ------- 1147 None 1148 1149 Notes 1150 ----- 1151 Nested list-like entries in the DataFrames are reduced to scalars before 1152 plotting. One line is drawn per intercept parameter. 1153 """ 1154 if data_dict is None: 1155 data_dict = self.simple_intercept_history() 1156 if data_dict is None: 1157 raise ValueError("No intercept history data provided or stored in the class.") 1158 1159 nodes = [node] if node else list(data_dict.keys()) 1160 1161 for n in nodes: 1162 df = data_dict[n].copy() 1163 1164 def extract_scalar(x): 1165 if isinstance(x, list): 1166 while isinstance(x, list) and len(x) > 0: 1167 x = x[0] 1168 return float(x) if isinstance(x, (int, float, np.floating)) else np.nan 1169 1170 df = df.applymap(extract_scalar) 1171 1172 # Convert epoch labels → numeric 1173 df.columns = [ 1174 int(c.replace("epoch_", "")) if isinstance(c, str) and c.startswith("epoch_") else c 1175 for c in df.columns 1176 ] 1177 df = df.reindex(sorted(df.columns), axis=1) 1178 1179 plt.figure(figsize=(10, 6)) 1180 for idx in df.index: 1181 plt.plot(df.columns, df.loc[idx], lw=1.4, label=f"theta_{idx}") 1182 1183 if ref_lines and n in ref_lines: 1184 for v in ref_lines[n]: 1185 plt.axhline(y=v, color="k", linestyle="--", lw=1.0) 1186 plt.text(df.columns[-1], v, f"{n}: {v}", va="bottom", ha="right", fontsize=8) 1187 1188 plt.xlabel("Epoch") 1189 plt.ylabel("Intercept Weight") 1190 plt.title(f"Simple Intercept Evolution — 
Node: {n}") 1191 plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left", fontsize="small") 1192 plt.tight_layout() 1193 plt.show() 1194 1195 def plot_latents(self, df, variable: str = None, confidence: float = 0.95, simulations: int = 1000): 1196 """ 1197 Visualize latent U distributions for one or all nodes. 1198 1199 Parameters 1200 ---------- 1201 df : pandas.DataFrame 1202 Input data frame with raw node values. 1203 variable : str or None, optional 1204 If provided, only this node's latents are plotted. If None, all 1205 nodes with latent columns are processed. 1206 confidence : float, optional 1207 Confidence level for QQ-plot bands (0 < confidence < 1). 1208 Default is 0.95. 1209 simulations : int, optional 1210 Number of Monte Carlo simulations for QQ-plot bands. Default is 1000. 1211 1212 Returns 1213 ------- 1214 None 1215 1216 Notes 1217 ----- 1218 For each node, two plots are produced: 1219 - Histogram of the latent U values. 1220 - QQ-plot with simulation-based confidence bands under a logistic reference. 
1221 """ 1222 # Compute latent representations 1223 latents_df = self.get_latent(df) 1224 1225 # Select nodes 1226 nodes = [variable] if variable is not None else self.nodes_dict.keys() 1227 1228 for node in nodes: 1229 if f"{node}_U" not in latents_df.columns: 1230 print(f"[WARNING] No latent found for node {node}, skipping.") 1231 continue 1232 1233 sample = latents_df[f"{node}_U"].values 1234 1235 # --- Create plots --- 1236 fig, axs = plt.subplots(1, 2, figsize=(12, 5)) 1237 1238 # Histogram 1239 axs[0].hist(sample, bins=50, color="steelblue", alpha=0.7) 1240 axs[0].set_title(f"Latent Histogram ({node})") 1241 axs[0].set_xlabel("U") 1242 axs[0].set_ylabel("Frequency") 1243 1244 # QQ Plot with confidence bands 1245 probplot(sample, dist="logistic", plot=axs[1]) 1246 self._add_r_style_confidence_bands(axs[1], sample, dist=logistic,confidence=confidence, simulations=simulations) 1247 axs[1].set_title(f"Latent QQ Plot ({node})") 1248 1249 plt.suptitle(f"Latent Diagnostics for Node: {node}", fontsize=14) 1250 plt.tight_layout(rect=[0, 0.03, 1, 0.95]) 1251 plt.show() 1252 1253 def plot_hdag(self,df,variables=None, plot_n_rows=1,**kwargs): 1254 1255 """ 1256 Visualize the transformation function h() for selected DAG nodes. 1257 1258 Parameters 1259 ---------- 1260 df : pandas.DataFrame 1261 Input data containing node values or model predictions. 1262 variables : list of str or None, optional 1263 Names of nodes to visualize. If None, all nodes in ``self.models`` 1264 are considered. 1265 plot_n_rows : int, optional 1266 Maximum number of rows from ``df`` to visualize. Default is 1. 1267 **kwargs 1268 Additional keyword arguments forwarded to the underlying plotting 1269 helpers (`show_hdag_continous` / `show_hdag_ordinal`). 1270 1271 Returns 1272 ------- 1273 None 1274 1275 Notes 1276 ----- 1277 - For continuous outcomes, `show_hdag_continous` is called. 1278 - For ordinal outcomes, `show_hdag_ordinal` is called. 
1279 - Nodes that are neither continuous nor ordinal are skipped with a warning. 1280 """ 1281 1282 1283 if len(df)> 1: 1284 print("[WARNING] len(df)>1, set: plot_n_rows accordingly") 1285 1286 variables_list=variables if variables is not None else list(self.models.keys()) 1287 for node in variables_list: 1288 if is_outcome_modelled_continous(node, self.nodes_dict): 1289 show_hdag_continous(df,node=node,configuration_dict=self.cfg.conf_dict,minmax_dict=self.minmax_dict,device=self.device,plot_n_rows=plot_n_rows,**kwargs) 1290 1291 elif is_outcome_modelled_ordinal(node, self.nodes_dict): 1292 show_hdag_ordinal(df,node=node,configuration_dict=self.cfg.conf_dict,device=self.device,plot_n_rows=plot_n_rows,**kwargs) 1293 # plot_cutpoints_with_logistic(df,node=node,configuration_dict=self.cfg.conf_dict,device=self.device,plot_n_rows=plot_n_rows,**kwargs) 1294 # save_cutpoints_with_logistic(df,node=node,configuration_dict=self.cfg.conf_dict,device=self.device,**kwargs) 1295 else: 1296 print(f"[WARNING] Node {node} is wheter ordinal nor continous, not implemented yet") 1297 1298 @staticmethod 1299 def _add_r_style_confidence_bands(ax, sample, dist, confidence=0.95, simulations=1000): 1300 """ 1301 Add simulation-based confidence bands to a QQ-plot. 1302 1303 Parameters 1304 ---------- 1305 ax : matplotlib.axes.Axes 1306 Axes object on which to draw the QQ-plot and bands. 1307 sample : array-like 1308 Empirical sample used in the QQ-plot. 1309 dist : scipy.stats distribution 1310 Distribution object providing ``ppf`` and ``rvs`` methods (e.g. logistic). 1311 confidence : float, optional 1312 Confidence level (0 < confidence < 1) for the bands. Default is 0.95. 1313 simulations : int, optional 1314 Number of Monte Carlo simulations used to estimate the bands. Default is 1000. 1315 1316 Returns 1317 ------- 1318 None 1319 1320 Notes 1321 ----- 1322 The axes are cleared, and a new QQ-plot is drawn with: 1323 - Empirical vs. theoretical quantiles. 
1324 - 45-degree reference line. 1325 - Shaded confidence band region. 1326 """ 1327 1328 n = len(sample) 1329 if n == 0: 1330 return 1331 1332 quantiles = np.linspace(0, 1, n, endpoint=False) + 0.5 / n 1333 theo_q = dist.ppf(quantiles) 1334 1335 # Simulate order statistics from the theoretical distribution 1336 sim_data = dist.rvs(size=(simulations, n)) 1337 sim_order_stats = np.sort(sim_data, axis=1) 1338 1339 # Confidence bands 1340 lower = np.percentile(sim_order_stats, 100 * (1 - confidence) / 2, axis=0) 1341 upper = np.percentile(sim_order_stats, 100 * (1 + confidence) / 2, axis=0) 1342 1343 # Sort empirical sample 1344 sample_sorted = np.sort(sample) 1345 1346 # Re-draw points and CI (overwrite probplot defaults) 1347 ax.clear() 1348 ax.plot(theo_q, sample_sorted, 'o', markersize=3, alpha=0.6, label="Empirical Q-Q") 1349 ax.plot(theo_q, theo_q, 'b--', label="y = x") 1350 ax.fill_between(theo_q, lower, upper, color='gray', alpha=0.3, 1351 label=f'{int(confidence*100)}% CI') 1352 ax.legend() 1353 1354 ## SAMPLING METHODS 1355 def sample( 1356 self, 1357 do_interventions: dict = None, 1358 predefined_latent_samples_df: pd.DataFrame = None, 1359 **kwargs, 1360 ): 1361 """ 1362 Sample from the joint DAG using the trained TRAM models. 1363 1364 Allows for: 1365 1366 Oberservational sampling 1367 Interventional sampling via ``do()`` operations 1368 Counterfactial sampling using predefined latent draws and do() 1369 1370 Parameters 1371 ---------- 1372 do_interventions : dict or None, optional 1373 Mapping of node names to intervened (fixed) values. For example: 1374 ``{"x1": 1.0}`` represents ``do(x1 = 1.0)``. Default is None. 1375 predefined_latent_samples_df : pandas.DataFrame or None, optional 1376 DataFrame containing columns ``"{node}_U"`` with predefined latent 1377 draws to be used instead of sampling from the prior. Default is None. 
1378 **kwargs 1379 Sampling options overriding internal defaults: 1380 1381 number_of_samples : int, default 10000 1382 Total number of samples to draw. 1383 batch_size : int, default 32 1384 Batch size for internal sampling loops. 1385 delete_all_previously_sampled : bool, default True 1386 If True, delete old sampling files in node-specific sampling 1387 directories before writing new ones. 1388 verbose : bool 1389 If True, print informational messages. 1390 debug : bool 1391 If True, print debug output. 1392 device : {"auto", "cpu", "cuda"} 1393 Device selection for sampling. 1394 use_initial_weights_for_sampling : bool, default False 1395 If True, sample from initial (untrained) model parameters. 1396 1397 Returns 1398 ------- 1399 tuple 1400 A tuple ``(sampled_by_node, latents_by_node)``: 1401 1402 sampled_by_node : dict 1403 Mapping ``{node_name: torch.Tensor}`` of sampled node values. 1404 latents_by_node : dict 1405 Mapping ``{node_name: torch.Tensor}`` of latent U values used. 1406 1407 Raises 1408 ------ 1409 ValueError 1410 If the experiment directory cannot be resolved or if scaling 1411 information (``self.minmax_dict``) is missing. 1412 RuntimeError 1413 If min–max scaling has not been computed before calling `sample`. 1414 """ 1415 try: 1416 EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"] 1417 except KeyError: 1418 raise ValueError( 1419 "[ERROR] Missing 'EXPERIMENT_DIR' in cfg.conf_dict['PATHS']. " 1420 "Sampling requires trained model checkpoints." 
1421 ) 1422 1423 # ---- defaults ---- 1424 settings = { 1425 "number_of_samples": 10_000, 1426 "batch_size": 32, 1427 "delete_all_previously_sampled": True, 1428 "verbose": self.verbose if hasattr(self, "verbose") else False, 1429 "debug": self.debug if hasattr(self, "debug") else False, 1430 "device": self.device.type if hasattr(self, "device") else "auto", 1431 "use_initial_weights_for_sampling": False, 1432 1433 } 1434 1435 # self._validate_kwargs( kwargs, defaults_attr= "settings", context="sample") 1436 1437 settings.update(kwargs) 1438 1439 1440 if not hasattr(self, "minmax_dict"): 1441 raise RuntimeError( 1442 "[ERROR] minmax_dict not found. You must call .fit() or .load_or_compute_minmax() " 1443 "before sampling, so scaling info is available." 1444 ) 1445 1446 # ---- resolve device ---- 1447 device_str=self.get_device(settings) 1448 self.device = torch.device(device_str) 1449 1450 1451 if self.debug or settings["debug"]: 1452 print(f"[DEBUG] sample(): device: {self.device}") 1453 1454 # ---- perform sampling ---- 1455 sampled_by_node, latents_by_node = sample_full_dag( 1456 configuration_dict=self.cfg.conf_dict, 1457 EXPERIMENT_DIR=EXPERIMENT_DIR, 1458 device=self.device, 1459 do_interventions=do_interventions or {}, 1460 predefined_latent_samples_df=predefined_latent_samples_df, 1461 number_of_samples=settings["number_of_samples"], 1462 batch_size=settings["batch_size"], 1463 delete_all_previously_sampled=settings["delete_all_previously_sampled"], 1464 verbose=settings["verbose"], 1465 debug=settings["debug"], 1466 minmax_dict=self.minmax_dict, 1467 use_initial_weights_for_sampling=settings["use_initial_weights_for_sampling"] 1468 ) 1469 1470 return sampled_by_node, latents_by_node 1471 1472 def load_sampled_and_latents(self, EXPERIMENT_DIR: str = None, nodes: list = None): 1473 """ 1474 Load previously stored sampled values and latents for each node. 
1475 1476 Parameters 1477 ---------- 1478 EXPERIMENT_DIR : str or None, optional 1479 Experiment directory path. If None, it is taken from 1480 ``self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]``. 1481 nodes : list of str or None, optional 1482 Nodes for which to load samples. If None, use all nodes from 1483 ``self.nodes_dict``. 1484 1485 Returns 1486 ------- 1487 tuple 1488 A tuple ``(sampled_by_node, latents_by_node)``: 1489 1490 sampled_by_node : dict 1491 Mapping ``{node_name: torch.Tensor}`` of sampled values (on CPU). 1492 latents_by_node : dict 1493 Mapping ``{node_name: torch.Tensor}`` of latent values (on CPU). 1494 1495 Raises 1496 ------ 1497 ValueError 1498 If the experiment directory cannot be resolved or if no node list 1499 is available and ``nodes`` is None. 1500 1501 Notes 1502 ----- 1503 Nodes without both ``sampled.pt`` and ``latents.pt`` files are skipped 1504 with a warning. 1505 """ 1506 # --- resolve paths and node list --- 1507 if EXPERIMENT_DIR is None: 1508 try: 1509 EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"] 1510 except (AttributeError, KeyError): 1511 raise ValueError( 1512 "[ERROR] Could not resolve EXPERIMENT_DIR from cfg.conf_dict['PATHS']. " 1513 "Please provide EXPERIMENT_DIR explicitly." 1514 ) 1515 1516 if nodes is None: 1517 if hasattr(self, "nodes_dict"): 1518 nodes = list(self.nodes_dict.keys()) 1519 else: 1520 raise ValueError( 1521 "[ERROR] No node list found. Please provide `nodes` or initialize model with a config." 
    def plot_samples_vs_true(
        self,
        df,
        sampled: dict = None,
        variable: list = None,
        bins: int = 100,
        hist_true_color: str = "blue",
        hist_est_color: str = "orange",
        figsize: tuple = (14, 5),
    ):
        """
        Compare sampled vs. observed distributions for selected nodes.

        Parameters
        ----------
        df : pandas.DataFrame
            Data frame containing the observed node values.
        sampled : dict or None, optional
            Optional mapping ``{node_name: array-like or torch.Tensor}`` of sampled
            values. If None or if a node is missing, samples are loaded from
            ``EXPERIMENT_DIR/{node}/sampling/sampled.pt``.
        variable : list of str or None, optional
            Subset of nodes to plot. If None, all nodes in the configuration
            are considered.
        bins : int, optional
            Number of histogram bins for continuous variables. Default is 100.
        hist_true_color : str, optional
            Color name for the histogram of true values. Default is "blue".
        hist_est_color : str, optional
            Color name for the histogram of sampled values. Default is "orange".
        figsize : tuple, optional
            Figure size for the matplotlib plots. Default is (14, 5).

        Returns
        -------
        None

        Notes
        -----
        - Continuous outcomes: histogram overlay + QQ-plot.
        - Ordinal outcomes: side-by-side bar plot of relative frequencies.
        - Other categorical outcomes: side-by-side bar plot with category labels.
        - If samples are probabilistic (2D tensor), the argmax across classes is used.
        """

        # NOTE(review): other methods pass self.nodes_dict to the
        # is_outcome_modelled_* helpers; here conf_dict["nodes"] is used —
        # presumably the same mapping, verify.
        target_nodes = self.cfg.conf_dict["nodes"]
        experiment_dir = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        plot_list = variable if variable is not None else target_nodes

        for node in plot_list:
            # Load sampled data: prefer the in-memory dict, fall back to disk
            if sampled is not None and node in sampled:
                sdata = sampled[node]
                if isinstance(sdata, torch.Tensor):
                    sampled_vals = sdata.detach().cpu().numpy()
                else:
                    sampled_vals = np.asarray(sdata)
            else:
                sample_path = os.path.join(experiment_dir, f"{node}/sampling/sampled.pt")
                if not os.path.isfile(sample_path):
                    print(f"[WARNING] skip {node}: {sample_path} not found.")
                    continue

                try:
                    sampled_vals = torch.load(sample_path, map_location=device).cpu().numpy()
                except Exception as e:
                    print(f"[ERROR] Could not load {sample_path}: {e}")
                    continue

            # If logits/probabilities per sample, take argmax
            if sampled_vals.ndim == 2:
                print(f"[INFO] CAUTION! {node}: samples are probabilistic — each sample follows a probability "
                      f"distribution based on the valid latent range. "
                      f"Note that this frequency plot reflects only the distribution of the most probable "
                      f"class per sample.")
                sampled_vals = np.argmax(sampled_vals, axis=1)

            sampled_vals = sampled_vals[np.isfinite(sampled_vals)]

            if node not in df.columns:
                print(f"[WARNING] skip {node}: column not found in DataFrame.")
                continue

            true_vals = df[node].dropna().values
            true_vals = true_vals[np.isfinite(true_vals)]

            if sampled_vals.size == 0 or true_vals.size == 0:
                print(f"[WARNING] skip {node}: empty array after NaN/Inf removal.")
                continue

            fig, axs = plt.subplots(1, 2, figsize=figsize)

            if is_outcome_modelled_continous(node, target_nodes):
                # Continuous: density-normalized histogram overlay + 2-sample QQ plot
                axs[0].hist(true_vals, bins=bins, density=True, alpha=0.6,
                            color=hist_true_color, label=f"True {node}")
                axs[0].hist(sampled_vals, bins=bins, density=True, alpha=0.6,
                            color=hist_est_color, label="Sampled")
                axs[0].set_xlabel("Value")
                axs[0].set_ylabel("Density")
                axs[0].set_title(f"Histogram overlay for {node}")
                axs[0].legend()
                axs[0].grid(True, ls="--", alpha=0.4)

                qqplot_2samples(true_vals, sampled_vals, line="45", ax=axs[1])
                axs[1].set_xlabel("True quantiles")
                axs[1].set_ylabel("Sampled quantiles")
                axs[1].set_title(f"QQ plot for {node}")
                axs[1].grid(True, ls="--", alpha=0.4)

            elif is_outcome_modelled_ordinal(node, target_nodes):
                # Ordinal: side-by-side relative-frequency bars over the union of levels
                unique_vals = np.union1d(np.unique(true_vals), np.unique(sampled_vals))
                unique_vals = np.sort(unique_vals)
                true_counts = np.array([(true_vals == val).sum() for val in unique_vals])
                sampled_counts = np.array([(sampled_vals == val).sum() for val in unique_vals])

                axs[0].bar(unique_vals - 0.2, true_counts / true_counts.sum(),
                           width=0.4, color=hist_true_color, alpha=0.7, label="True")
                axs[0].bar(unique_vals + 0.2, sampled_counts / sampled_counts.sum(),
                           width=0.4, color=hist_est_color, alpha=0.7, label="Sampled")
                axs[0].set_xticks(unique_vals)
                axs[0].set_xlabel("Ordinal Level")
                axs[0].set_ylabel("Relative Frequency")
                axs[0].set_title(f"Ordinal bar plot for {node}")
                axs[0].legend()
                axs[0].grid(True, ls="--", alpha=0.4)
                axs[1].axis("off")

            else:
                # Other categorical: bars over string-sorted categories
                unique_vals = np.union1d(np.unique(true_vals), np.unique(sampled_vals))
                unique_vals = sorted(unique_vals, key=str)
                true_counts = np.array([(true_vals == val).sum() for val in unique_vals])
                sampled_counts = np.array([(sampled_vals == val).sum() for val in unique_vals])

                axs[0].bar(np.arange(len(unique_vals)) - 0.2, true_counts / true_counts.sum(),
                           width=0.4, color=hist_true_color, alpha=0.7, label="True")
                axs[0].bar(np.arange(len(unique_vals)) + 0.2, sampled_counts / sampled_counts.sum(),
                           width=0.4, color=hist_est_color, alpha=0.7, label="Sampled")
                axs[0].set_xticks(np.arange(len(unique_vals)))
                axs[0].set_xticklabels(unique_vals, rotation=45)
                axs[0].set_xlabel("Category")
                axs[0].set_ylabel("Relative Frequency")
                axs[0].set_title(f"Categorical bar plot for {node}")
                axs[0].legend()
                axs[0].grid(True, ls="--", alpha=0.4)
                axs[1].axis("off")

            plt.tight_layout()
            plt.show()
label="Sampled") 1677 axs[0].set_xticks(unique_vals) 1678 axs[0].set_xlabel("Ordinal Level") 1679 axs[0].set_ylabel("Relative Frequency") 1680 axs[0].set_title(f"Ordinal bar plot for {node}") 1681 axs[0].legend() 1682 axs[0].grid(True, ls="--", alpha=0.4) 1683 axs[1].axis("off") 1684 1685 else: 1686 unique_vals = np.union1d(np.unique(true_vals), np.unique(sampled_vals)) 1687 unique_vals = sorted(unique_vals, key=str) 1688 true_counts = np.array([(true_vals == val).sum() for val in unique_vals]) 1689 sampled_counts = np.array([(sampled_vals == val).sum() for val in unique_vals]) 1690 1691 axs[0].bar(np.arange(len(unique_vals)) - 0.2, true_counts / true_counts.sum(), 1692 width=0.4, color=hist_true_color, alpha=0.7, label="True") 1693 axs[0].bar(np.arange(len(unique_vals)) + 0.2, sampled_counts / sampled_counts.sum(), 1694 width=0.4, color=hist_est_color, alpha=0.7, label="Sampled") 1695 axs[0].set_xticks(np.arange(len(unique_vals))) 1696 axs[0].set_xticklabels(unique_vals, rotation=45) 1697 axs[0].set_xlabel("Category") 1698 axs[0].set_ylabel("Relative Frequency") 1699 axs[0].set_title(f"Categorical bar plot for {node}") 1700 axs[0].legend() 1701 axs[0].grid(True, ls="--", alpha=0.4) 1702 axs[1].axis("off") 1703 1704 plt.tight_layout() 1705 plt.show() 1706 1707 ## SUMMARY METHODS 1708 def nll(self,data,variables=None): 1709 """ 1710 Compute the Negative Log-Likelihood (NLL) for all or selected TRAM nodes. 1711 1712 This function evaluates trained TRAM models for each specified variable (node) 1713 on the provided dataset. It performs forward passes only—no training, no weight 1714 updates—and returns the mean NLL per node. 1715 1716 Parameters 1717 ---------- 1718 data : object 1719 Input dataset or data source compatible with `_ensure_dataset`, containing 1720 both inputs and targets for each node. 1721 variables : list[str], optional 1722 List of variable (node) names to evaluate. If None, all nodes in 1723 `self.models` are evaluated. 
1724 1725 Returns 1726 ------- 1727 dict[str, float] 1728 Dictionary mapping each node name to its average NLL value. 1729 1730 Notes 1731 ----- 1732 - Each model is evaluated independently on its respective DataLoader. 1733 - The normalization values (`min_max`) for each node are retrieved from 1734 `self.minmax_dict[node]`. 1735 - The function uses `evaluate_tramdag_model()` for per-node evaluation. 1736 - Expected directory structure: 1737 `<EXPERIMENT_DIR>/<node>/` 1738 where each node directory contains the trained model. 1739 """ 1740 1741 td_data = self._ensure_dataset(data, is_val=True) 1742 variables_list = variables if variables != None else list(self.models.keys()) 1743 nll_dict = {} 1744 for node in variables_list: 1745 min_vals = torch.tensor(self.minmax_dict[node][0], dtype=torch.float32) 1746 max_vals = torch.tensor(self.minmax_dict[node][1], dtype=torch.float32) 1747 min_max = torch.stack([min_vals, max_vals], dim=0) 1748 data_loader = td_data.loaders[node] 1749 model = self.models[node] 1750 EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"] 1751 NODE_DIR = os.path.join(EXPERIMENT_DIR, f"{node}") 1752 nll= evaluate_tramdag_model(node=node, 1753 target_nodes=self.nodes_dict, 1754 NODE_DIR=NODE_DIR, 1755 tram_model=model, 1756 data_loader=data_loader, 1757 min_max=min_max) 1758 nll_dict[node]=nll 1759 return nll_dict 1760 1761 def get_train_val_nll(self, node: str, mode: str) -> tuple[float, float]: 1762 """ 1763 Retrieve training and validation NLL for a node and a given model state. 1764 1765 Parameters 1766 ---------- 1767 node : str 1768 Node name. 1769 mode : {"best", "last", "init"} 1770 State of interest: 1771 - "best": epoch with lowest validation NLL. 1772 - "last": final epoch. 1773 - "init": first epoch (index 0). 1774 1775 Returns 1776 ------- 1777 tuple of (float or None, float or None) 1778 A tuple ``(train_nll, val_nll)`` for the requested mode. 1779 Returns ``(None, None)`` if loss files are missing or cannot be read. 
    def get_thetas(self, node: str, state: str = "best"):
        """
        Return transformed intercept (theta) parameters for a node and state.

        Parameters
        ----------
        node : str
            Node name.
        state : {"best", "last", "init"}, optional
            Model state for which to return parameters. Default is "best".

        Returns
        -------
        Any or None
            Transformed theta parameters for the requested node and state.
            The exact structure (scalar, list, or other) depends on the model.

        Raises
        ------
        ValueError
            If an invalid state is given (not in {"best", "last", "init"}).

        Notes
        -----
        Intercept dictionaries are cached on the instance under the attribute
        ``intercept_dicts``. If missing or incomplete, they are recomputed using
        `get_simple_intercepts_dict`.
        """

        state = state.lower()
        if state not in ["best", "last", "init"]:
            raise ValueError(f"[ERROR] Invalid state '{state}'. Must be one of ['best', 'last', 'init'].")

        dict_attr = "intercept_dicts"

        # If no cached intercepts exist, compute them
        if not hasattr(self, dict_attr):
            if getattr(self, "debug", False):
                print(f"[DEBUG] '{dict_attr}' not found, computing via get_simple_intercepts_dict().")
            setattr(self, dict_attr, self.get_simple_intercepts_dict())

        all_dicts = getattr(self, dict_attr)

        # If the requested state isn’t cached, recompute
        if state not in all_dicts:
            if getattr(self, "debug", False):
                print(f"[DEBUG] State '{state}' not found in cached intercepts, recomputing full dict.")
            setattr(self, dict_attr, self.get_simple_intercepts_dict())
            all_dicts = getattr(self, dict_attr)

        state_dict = all_dicts.get(state, {})

        # Return cached node intercept if present
        if node in state_dict:
            return state_dict[node]

        # If not found, recompute full dict as fallback (node may have been
        # trained after the cache was built); returns None if still absent.
        if getattr(self, "debug", False):
            print(f"[DEBUG] Node '{node}' not found in state '{state}', recomputing full dict.")
        setattr(self, dict_attr, self.get_simple_intercepts_dict())
        all_dicts = getattr(self, dict_attr)
        return all_dicts.get(state, {}).get(node, None)
1846 1847 Notes 1848 ----- 1849 Intercept dictionaries are cached on the instance under the attribute 1850 ``intercept_dicts``. If missing or incomplete, they are recomputed using 1851 `get_simple_intercepts_dict`. 1852 """ 1853 1854 state = state.lower() 1855 if state not in ["best", "last", "init"]: 1856 raise ValueError(f"[ERROR] Invalid state '{state}'. Must be one of ['best', 'last', 'init'].") 1857 1858 dict_attr = "intercept_dicts" 1859 1860 # If no cached intercepts exist, compute them 1861 if not hasattr(self, dict_attr): 1862 if getattr(self, "debug", False): 1863 print(f"[DEBUG] '{dict_attr}' not found, computing via get_simple_intercepts_dict().") 1864 setattr(self, dict_attr, self.get_simple_intercepts_dict()) 1865 1866 all_dicts = getattr(self, dict_attr) 1867 1868 # If the requested state isn’t cached, recompute 1869 if state not in all_dicts: 1870 if getattr(self, "debug", False): 1871 print(f"[DEBUG] State '{state}' not found in cached intercepts, recomputing full dict.") 1872 setattr(self, dict_attr, self.get_simple_intercepts_dict()) 1873 all_dicts = getattr(self, dict_attr) 1874 1875 state_dict = all_dicts.get(state, {}) 1876 1877 # Return cached node intercept if present 1878 if node in state_dict: 1879 return state_dict[node] 1880 1881 # If not found, recompute full dict as fallback 1882 if getattr(self, "debug", False): 1883 print(f"[DEBUG] Node '{node}' not found in state '{state}', recomputing full dict.") 1884 setattr(self, dict_attr, self.get_simple_intercepts_dict()) 1885 all_dicts = getattr(self, dict_attr) 1886 return all_dicts.get(state, {}).get(node, None) 1887 1888 def get_linear_shifts(self, node: str, state: str = "best"): 1889 """ 1890 Return learned linear shift terms for a node and a given state. 1891 1892 Parameters 1893 ---------- 1894 node : str 1895 Node name. 1896 state : {"best", "last", "init"}, optional 1897 Model state for which to return linear shift terms. Default is "best". 
1898 1899 Returns 1900 ------- 1901 dict or Any or None 1902 Linear shift terms for the given node and state. Usually a dict 1903 mapping term names to weights. 1904 1905 Raises 1906 ------ 1907 ValueError 1908 If an invalid state is given (not in {"best", "last", "init"}). 1909 1910 Notes 1911 ----- 1912 Linear shift dictionaries are cached on the instance under the attribute 1913 ``linear_shift_dicts``. If missing or incomplete, they are recomputed using 1914 `get_linear_shifts_dict`. 1915 """ 1916 state = state.lower() 1917 if state not in ["best", "last", "init"]: 1918 raise ValueError(f"[ERROR] Invalid state '{state}'. Must be one of ['best', 'last', 'init'].") 1919 1920 dict_attr = "linear_shift_dicts" 1921 1922 # If no global dicts cached, compute once 1923 if not hasattr(self, dict_attr): 1924 if getattr(self, "debug", False): 1925 print(f"[DEBUG] '{dict_attr}' not found, computing via get_linear_shifts_dict().") 1926 setattr(self, dict_attr, self.get_linear_shifts_dict()) 1927 1928 all_dicts = getattr(self, dict_attr) 1929 1930 # If the requested state isn't cached, compute all again (covers fresh runs) 1931 if state not in all_dicts: 1932 if getattr(self, "debug", False): 1933 print(f"[DEBUG] State '{state}' not found in cached linear shifts, recomputing full dict.") 1934 setattr(self, dict_attr, self.get_linear_shifts_dict()) 1935 all_dicts = getattr(self, dict_attr) 1936 1937 # Now fetch the dictionary for this state 1938 state_dict = all_dicts.get(state, {}) 1939 1940 # If the node is available, return its entry 1941 if node in state_dict: 1942 return state_dict[node] 1943 1944 # If missing, try recomputing (fallback) 1945 if getattr(self, "debug", False): 1946 print(f"[DEBUG] Node '{node}' not found in state '{state}', recomputing full dict.") 1947 setattr(self, dict_attr, self.get_linear_shifts_dict()) 1948 all_dicts = getattr(self, dict_attr) 1949 return all_dicts.get(state, {}).get(node, None) 1950 1951 def get_linear_shifts_dict(self): 1952 """ 
1953 Compute linear shift term dictionaries for all nodes and states. 1954 1955 For each node and each available state ("best", "last", "init"), this 1956 method loads the corresponding model checkpoint, extracts linear shift 1957 weights from the TRAM model, and stores them in a nested dictionary. 1958 1959 Returns 1960 ------- 1961 dict 1962 Nested dictionary of the form: 1963 1964 .. code-block:: python 1965 1966 { 1967 "best": {node: {...}}, 1968 "last": {node: {...}}, 1969 "init": {node: {...}}, 1970 } 1971 1972 where the innermost dict maps term labels (e.g. ``"ls(parent_name)"``) 1973 to their weights. 1974 1975 Notes 1976 ----- 1977 - If "best" or "last" checkpoints are unavailable for a node, only 1978 the "init" entry is populated. 1979 - Empty outer states (without any nodes) are removed from the result. 1980 """ 1981 1982 EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"] 1983 nodes_list = list(self.models.keys()) 1984 all_states = ["best", "last", "init"] 1985 all_linear_shift_dicts = {state: {} for state in all_states} 1986 1987 for node in nodes_list: 1988 NODE_DIR = os.path.join(EXPERIMENT_DIR, node) 1989 BEST_MODEL_PATH, LAST_MODEL_PATH, _, _ = model_train_val_paths(NODE_DIR) 1990 INIT_MODEL_PATH = os.path.join(NODE_DIR, "initial_model.pt") 1991 1992 state_paths = { 1993 "best": BEST_MODEL_PATH, 1994 "last": LAST_MODEL_PATH, 1995 "init": INIT_MODEL_PATH, 1996 } 1997 1998 for state, LOAD_PATH in state_paths.items(): 1999 if not os.path.exists(LOAD_PATH): 2000 if state != "init": 2001 # skip best/last if unavailable 2002 continue 2003 else: 2004 print(f"[WARNING] No models found for node '{node}'. Only initial model will be used.") 2005 if not os.path.exists(LOAD_PATH): 2006 if getattr(self, "debug", False): 2007 print(f"[DEBUG] Initial model also missing for node '{node}'. 
Skipping.") 2008 continue 2009 2010 # Load parents and model 2011 _, terms_dict, _ = ordered_parents(node, self.nodes_dict) 2012 state_dict = torch.load(LOAD_PATH, map_location=self.device) 2013 tram_model = self.models[node] 2014 tram_model.load_state_dict(state_dict) 2015 2016 epoch_weights = {} 2017 if hasattr(tram_model, "nn_shift") and tram_model.nn_shift is not None: 2018 for i, shift_layer in enumerate(tram_model.nn_shift): 2019 module_name = shift_layer.__class__.__name__ 2020 if ( 2021 hasattr(shift_layer, "fc") 2022 and hasattr(shift_layer.fc, "weight") 2023 and module_name == "LinearShift" 2024 ): 2025 term_name = list(terms_dict.keys())[i] 2026 epoch_weights[f"ls({term_name})"] = ( 2027 shift_layer.fc.weight.detach().cpu().squeeze().tolist() 2028 ) 2029 elif getattr(self, "debug", False): 2030 term_name = list(terms_dict.keys())[i] 2031 print(f"[DEBUG] ls({term_name}): missing 'fc' or 'weight' in LinearShift.") 2032 else: 2033 if getattr(self, "debug", False): 2034 print(f"[DEBUG] Tram model for node '{node}' has no nn_shift or it is None.") 2035 2036 all_linear_shift_dicts[state][node] = epoch_weights 2037 2038 # Remove empty states (e.g., when best/last not found for all nodes) 2039 all_linear_shift_dicts = {k: v for k, v in all_linear_shift_dicts.items() if v} 2040 2041 return all_linear_shift_dicts 2042 2043 def get_simple_intercepts_dict(self): 2044 """ 2045 Compute transformed simple intercept dictionaries for all nodes and states. 2046 2047 For each node and each available state ("best", "last", "init"), this 2048 method loads the corresponding model checkpoint, extracts simple intercept 2049 weights, transforms them into interpretable theta parameters, and stores 2050 them in a nested dictionary. 2051 2052 Returns 2053 ------- 2054 dict 2055 Nested dictionary of the form: 2056 2057 .. 
code-block:: python 2058 2059 { 2060 "best": {node: [[theta_1], [theta_2], ...]}, 2061 "last": {node: [[theta_1], [theta_2], ...]}, 2062 "init": {node: [[theta_1], [theta_2], ...]}, 2063 } 2064 2065 Notes 2066 ----- 2067 - For ordinal models (``self.is_ontram == True``), `transform_intercepts_ordinal` 2068 is used. 2069 - For continuous models, `transform_intercepts_continous` is used. 2070 - Empty outer states (without any nodes) are removed from the result. 2071 """ 2072 2073 EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"] 2074 nodes_list = list(self.models.keys()) 2075 all_states = ["best", "last", "init"] 2076 all_si_intercept_dicts = {state: {} for state in all_states} 2077 2078 debug = getattr(self, "debug", False) 2079 verbose = getattr(self, "verbose", False) 2080 is_ontram = getattr(self, "is_ontram", False) 2081 2082 for node in nodes_list: 2083 NODE_DIR = os.path.join(EXPERIMENT_DIR, node) 2084 BEST_MODEL_PATH, LAST_MODEL_PATH, _, _ = model_train_val_paths(NODE_DIR) 2085 INIT_MODEL_PATH = os.path.join(NODE_DIR, "initial_model.pt") 2086 2087 state_paths = { 2088 "best": BEST_MODEL_PATH, 2089 "last": LAST_MODEL_PATH, 2090 "init": INIT_MODEL_PATH, 2091 } 2092 2093 for state, LOAD_PATH in state_paths.items(): 2094 if not os.path.exists(LOAD_PATH): 2095 if state != "init": 2096 continue 2097 else: 2098 print(f"[WARNING] No models found for node '{node}'. Only initial model will be used.") 2099 if not os.path.exists(LOAD_PATH): 2100 if debug: 2101 print(f"[DEBUG] Initial model also missing for node '{node}'. 
Skipping.") 2102 continue 2103 2104 # Load model state 2105 state_dict = torch.load(LOAD_PATH, map_location=self.device) 2106 tram_model = self.models[node] 2107 tram_model.load_state_dict(state_dict) 2108 2109 # Extract and transform simple intercept weights 2110 si_weights = None 2111 if hasattr(tram_model, "nn_int") and tram_model.nn_int is not None and isinstance(tram_model.nn_int, SimpleIntercept): 2112 if hasattr(tram_model.nn_int, "fc") and hasattr(tram_model.nn_int.fc, "weight"): 2113 weights = tram_model.nn_int.fc.weight.detach().cpu().tolist() 2114 weights_tensor = torch.Tensor(weights) 2115 2116 if debug: 2117 print(f"[DEBUG] Node '{node}' ({state}) theta tilde shape: {weights_tensor.shape}") 2118 2119 if is_ontram: 2120 si_weights = transform_intercepts_ordinal(weights_tensor.reshape(1, -1))[:, 1:-1].reshape(-1, 1) 2121 else: 2122 si_weights = transform_intercepts_continous(weights_tensor.reshape(1, -1)).reshape(-1, 1) 2123 2124 si_weights = si_weights.tolist() 2125 2126 if debug: 2127 print(f"[DEBUG] Node '{node}' ({state}) theta transformed: {si_weights}") 2128 else: 2129 if debug: 2130 print(f"[DEBUG] Node '{node}' ({state}): missing 'fc' or 'weight' in SimpleIntercept.") 2131 else: 2132 if debug: 2133 print(f"[DEBUG] Tram model for node '{node}' has no nn_int or it is None.") 2134 2135 all_si_intercept_dicts[state][node] = si_weights 2136 2137 # Clean up empty states 2138 all_si_intercept_dicts = {k: v for k, v in all_si_intercept_dicts.items() if v} 2139 return all_si_intercept_dicts 2140 2141 def summary(self, verbose=False): 2142 """ 2143 Print a multi-part textual summary of the TramDagModel. 2144 2145 The summary includes: 2146 1. Training metrics overview per node (best/last NLL, epochs). 2147 2. Node-specific details (thetas, linear shifts, optional architecture). 2148 3. Basic information about the attached training DataFrame, if present. 
2149 2150 Parameters 2151 ---------- 2152 verbose : bool, optional 2153 If True, include extended per-node details such as the model 2154 architecture, parameter count, and availability of checkpoints 2155 and sampling results. Default is False. 2156 2157 Returns 2158 ------- 2159 None 2160 2161 Notes 2162 ----- 2163 This method prints to stdout and does not return structured data. 2164 It is intended for quick, human-readable inspection of the current 2165 training and model state. 2166 """ 2167 2168 # ---------- SETUP ---------- 2169 try: 2170 EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"] 2171 except KeyError: 2172 EXPERIMENT_DIR = None 2173 print("[WARNING] Missing EXPERIMENT_DIR in cfg.conf_dict['PATHS'].") 2174 2175 print("\n" + "=" * 120) 2176 print(f"{'TRAM DAG MODEL SUMMARY':^120}") 2177 print("=" * 120) 2178 2179 # ---------- METRICS OVERVIEW ---------- 2180 summary_data = [] 2181 for node in self.models.keys(): 2182 node_dir = os.path.join(self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"], node) 2183 train_path = os.path.join(node_dir, "train_loss_hist.json") 2184 val_path = os.path.join(node_dir, "val_loss_hist.json") 2185 2186 if os.path.exists(train_path) and os.path.exists(val_path): 2187 best_train_nll, best_val_nll = self.get_train_val_nll(node, "best") 2188 last_train_nll, last_val_nll = self.get_train_val_nll(node, "last") 2189 n_epochs_total = len(json.load(open(train_path))) 2190 else: 2191 best_train_nll = best_val_nll = last_train_nll = last_val_nll = None 2192 n_epochs_total = 0 2193 2194 summary_data.append({ 2195 "Node": node, 2196 "Best Train NLL": best_train_nll, 2197 "Best Val NLL": best_val_nll, 2198 "Last Train NLL": last_train_nll, 2199 "Last Val NLL": last_val_nll, 2200 "Epochs": n_epochs_total, 2201 }) 2202 2203 df_summary = pd.DataFrame(summary_data) 2204 df_summary = df_summary.round(4) 2205 2206 print("\n[1] TRAINING METRICS OVERVIEW") 2207 print("-" * 120) 2208 if not df_summary.empty: 2209 print( 2210 
df_summary.to_string( 2211 index=False, 2212 justify="center", 2213 col_space=14, 2214 float_format=lambda x: f"{x:7.4f}", 2215 ) 2216 ) 2217 else: 2218 print("No training history found for any node.") 2219 print("-" * 120) 2220 2221 # ---------- NODE DETAILS ---------- 2222 print("\n[2] NODE-SPECIFIC DETAILS") 2223 print("-" * 120) 2224 for node in self.models.keys(): 2225 print(f"\n{f'NODE: {node}':^120}") 2226 print("-" * 120) 2227 2228 # THETAS & SHIFTS 2229 for state in ["init", "last", "best"]: 2230 print(f"\n [{state.upper()} STATE]") 2231 2232 # ---- Thetas ---- 2233 try: 2234 thetas = getattr(self, "get_thetas", lambda n, s=None: None)(node, state) 2235 if thetas is not None: 2236 if isinstance(thetas, (list, np.ndarray, pd.Series)): 2237 thetas_flat = np.array(thetas).flatten() 2238 compact = np.round(thetas_flat, 4) 2239 arr_str = np.array2string( 2240 compact, 2241 max_line_width=110, 2242 threshold=np.inf, 2243 separator=", " 2244 ) 2245 lines = arr_str.split("\n") 2246 if len(lines) > 2: 2247 arr_str = "\n".join(lines[:2]) + " ..." 
2248 print(f" Θ ({len(thetas_flat)}): {arr_str}") 2249 elif isinstance(thetas, dict): 2250 for k, v in thetas.items(): 2251 print(f" Θ[{k}]: {v}") 2252 else: 2253 print(f" Θ: {thetas}") 2254 else: 2255 print(" Θ: not available") 2256 except Exception as e: 2257 print(f" [Error loading thetas] {e}") 2258 2259 # ---- Linear Shifts ---- 2260 try: 2261 linear_shifts = getattr(self, "get_linear_shifts", lambda n, s=None: None)(node, state) 2262 if linear_shifts is not None: 2263 if isinstance(linear_shifts, dict): 2264 for k, v in linear_shifts.items(): 2265 print(f" {k}: {np.round(v, 4)}") 2266 elif isinstance(linear_shifts, (list, np.ndarray, pd.Series)): 2267 arr = np.round(linear_shifts, 4) 2268 print(f" Linear shifts ({len(arr)}): {arr}") 2269 else: 2270 print(f" Linear shifts: {linear_shifts}") 2271 else: 2272 print(" Linear shifts: not available") 2273 except Exception as e: 2274 print(f" [Error loading linear shifts] {e}") 2275 2276 # ---- Verbose info directly below node ---- 2277 if verbose: 2278 print("\n [DETAILS]") 2279 node_dir = os.path.join(EXPERIMENT_DIR, node) if EXPERIMENT_DIR else None 2280 model = self.models[node] 2281 2282 print(f" Model Architecture:") 2283 arch_str = str(model).split("\n") 2284 for line in arch_str: 2285 print(f" {line}") 2286 print(f" Parameter count: {sum(p.numel() for p in model.parameters()):,}") 2287 2288 if node_dir and os.path.exists(node_dir): 2289 ckpt_exists = any(f.endswith(('.pt', '.pth')) for f in os.listdir(node_dir)) 2290 print(f" Checkpoints found: {ckpt_exists}") 2291 2292 sampling_dir = os.path.join(node_dir, "sampling") 2293 sampling_exists = os.path.isdir(sampling_dir) and len(os.listdir(sampling_dir)) > 0 2294 print(f" Sampling results found: {sampling_exists}") 2295 2296 for label, filename in [("Train", "train_loss_hist.json"), ("Validation", "val_loss_hist.json")]: 2297 path = os.path.join(node_dir, filename) 2298 if os.path.exists(path): 2299 try: 2300 with open(path, "r") as f: 2301 hist = json.load(f) 
2302 print(f" {label} history: {len(hist)} epochs") 2303 except Exception as e: 2304 print(f" {label} history: failed to load ({e})") 2305 else: 2306 print(f" {label} history: not found") 2307 else: 2308 print(" [INFO] No experiment directory defined or missing for this node.") 2309 print("-" * 120) 2310 2311 # ---------- TRAINING DATAFRAME ---------- 2312 print("\n[3] TRAINING DATAFRAME") 2313 print("-" * 120) 2314 try: 2315 self.train_df.info() 2316 except AttributeError: 2317 print("No training DataFrame attached to this TramDagModel.") 2318 print("=" * 120 + "\n")
64class TramDagModel: 65 """ 66 Probabilistic DAG model built from node-wise TRAMs (transformation models). 67 68 This class manages: 69 - Configuration and per-node model construction. 70 - Data scaling (min–max). 71 - Training (sequential or per-node parallel on CPU). 72 - Diagnostics (loss history, intercepts, linear shifts, latents). 73 - Sampling from the joint DAG and loading stored samples. 74 - High-level summaries and plotting utilities. 75 """ 76 77 # ---- defaults used at construction time ---- 78 DEFAULTS_CONFIG = { 79 "set_initial_weights": False, 80 "debug":False, 81 "verbose": False, 82 "device":'auto', 83 "initial_data":None, 84 "overwrite_initial_weights": True, 85 } 86 87 # ---- defaults used at fit() time ---- 88 DEFAULTS_FIT = { 89 "epochs": 100, 90 "train_list": None, 91 "callbacks": None, 92 "learning_rate": 0.01, 93 "device": "auto", 94 "optimizers": None, 95 "schedulers": None, 96 "use_scheduler": False, 97 "save_linear_shifts": True, 98 "save_simple_intercepts": True, 99 "debug":False, 100 "verbose": True, 101 "train_mode": "sequential", # or "parallel" 102 "return_history": False, 103 "overwrite_inital_weights": True, 104 "num_workers" : 4, 105 "persistent_workers" : True, 106 "prefetch_factor" : 4, 107 "batch_size":1000, 108 109 } 110 111 def __init__(self): 112 """ 113 Initialize an empty TramDagModel shell. 114 115 Notes 116 ----- 117 This constructor does not build any node models and does not attach a 118 configuration. Use `TramDagModel.from_config` or `TramDagModel.from_directory` 119 to obtain a fully configured and ready-to-use instance. 120 """ 121 122 self.debug = False 123 self.verbose = False 124 self.device = 'auto' 125 pass 126 127 @staticmethod 128 def get_device(settings): 129 """ 130 Resolve the target device string from a settings dictionary. 131 132 Parameters 133 ---------- 134 settings : dict 135 Dictionary containing at least a key ``"device"`` with one of 136 {"auto", "cpu", "cuda"}. If missing, "auto" is assumed. 
137 138 Returns 139 ------- 140 str 141 Device string, either "cpu" or "cuda". 142 143 Notes 144 ----- 145 If ``device == "auto"``, CUDA is selected if available, otherwise CPU. 146 """ 147 device_arg = settings.get("device", "auto") 148 if device_arg == "auto": 149 device_str = "cuda" if torch.cuda.is_available() else "cpu" 150 else: 151 device_str = device_arg 152 return device_str 153 154 def _validate_kwargs(self, kwargs: dict, defaults_attr: str = "DEFAULTS_FIT", context: str = None): 155 """ 156 Validate a kwargs dictionary against a class-level defaults dictionary. 157 158 Parameters 159 ---------- 160 kwargs : dict 161 Keyword arguments to validate. 162 defaults_attr : str, optional 163 Name of the attribute on this class that contains the allowed keys, 164 e.g. ``"DEFAULTS_CONFIG"`` or ``"DEFAULTS_FIT"``. Default is "DEFAULTS_FIT". 165 context : str or None, optional 166 Optional label (e.g. caller name) to prepend in error messages. 167 168 Raises 169 ------ 170 AttributeError 171 If the attribute named by ``defaults_attr`` does not exist. 172 ValueError 173 If any key in ``kwargs`` is not present in the corresponding defaults dict. 174 """ 175 defaults = getattr(self, defaults_attr, None) 176 if defaults is None: 177 raise AttributeError(f"{self.__class__.__name__} has no attribute '{defaults_attr}'") 178 179 unknown = set(kwargs) - set(defaults) 180 if unknown: 181 prefix = f"[{context}] " if context else "" 182 raise ValueError(f"{prefix}Unknown parameter(s): {', '.join(sorted(unknown))}") 183 184 ## CREATE A TRAMDADMODEL 185 @classmethod 186 def from_config(cls, cfg, **kwargs): 187 """ 188 Construct a TramDagModel from a TramDagConfig object. 189 190 This builds one TRAM model per node in the DAG and optionally writes 191 the initial model parameters to disk. 
192 193 Parameters 194 ---------- 195 cfg : TramDagConfig 196 Configuration wrapper holding the underlying configuration dictionary, 197 including at least: 198 - ``conf_dict["nodes"]``: mapping of node names to node configs. 199 - ``conf_dict["PATHS"]["EXPERIMENT_DIR"]``: experiment directory. 200 **kwargs 201 Node-level construction options. Each key must be present in 202 ``DEFAULTS_CONFIG``. Values can be: 203 - scalar: applied to all nodes. 204 - dict: mapping ``{node_name: value}`` for per-node overrides. 205 206 Common keys include: 207 device : {"auto", "cpu", "cuda"}, default "auto" 208 Device selection (CUDA if available when "auto"). 209 debug : bool, default False 210 If True, print debug messages. 211 verbose : bool, default False 212 If True, print informational messages. 213 set_initial_weights : bool 214 Passed to underlying TRAM model constructors. 215 overwrite_initial_weights : bool, default True 216 If True, overwrite any existing ``initial_model.pt`` files per node. 217 initial_data : Any 218 Optional object passed down to node constructors. 219 220 Returns 221 ------- 222 TramDagModel 223 Fully initialized instance with: 224 - ``cfg`` 225 - ``nodes_dict`` 226 - ``models`` (per-node TRAMs) 227 - ``settings`` (resolved per-node config) 228 229 Raises 230 ------ 231 ValueError 232 If any dict-valued kwarg does not provide values for exactly the set 233 of nodes in ``cfg.conf_dict["nodes"]``. 
234 """ 235 236 self = cls() 237 self.cfg = cfg 238 self.cfg.update() # ensure latest version from disk 239 self.cfg._verify_completeness() 240 241 242 try: 243 self.cfg.save() # persist back to disk 244 if getattr(self, "debug", False): 245 print("[DEBUG] Configuration updated and saved.") 246 except Exception as e: 247 print(f"[WARNING] Could not save configuration after update: {e}") 248 249 self.nodes_dict = self.cfg.conf_dict["nodes"] 250 251 self._validate_kwargs(kwargs, defaults_attr='DEFAULTS_CONFIG', context="from_config") 252 253 # update defaults with kwargs 254 settings = dict(cls.DEFAULTS_CONFIG) 255 settings.update(kwargs) 256 257 # resolve device 258 device_arg = settings.get("device", "auto") 259 if device_arg == "auto": 260 device_str = "cuda" if torch.cuda.is_available() else "cpu" 261 else: 262 device_str = device_arg 263 self.device = torch.device(device_str) 264 265 # set flags on the instance so they are accessible later 266 self.debug = settings.get("debug", False) 267 self.verbose = settings.get("verbose", False) 268 269 if self.debug: 270 print(f"[DEBUG] TramDagModel using device: {self.device}") 271 272 # initialize settings storage 273 self.settings = {k: {} for k in settings.keys()} 274 275 # validate dict-typed args 276 for k, v in settings.items(): 277 if isinstance(v, dict): 278 expected = set(self.nodes_dict.keys()) 279 given = set(v.keys()) 280 if expected != given: 281 raise ValueError( 282 f"[ERROR] the provided argument '{k}' keys are not same as in cfg.conf_dict['nodes'].keys().\n" 283 f"Expected: {expected}, but got: {given}\n" 284 f"Please provide values for all variables.") 285 286 # build one model per node 287 self.models = {} 288 for node in self.nodes_dict.keys(): 289 per_node_kwargs = {} 290 for k, v in settings.items(): 291 resolved = v[node] if isinstance(v, dict) else v 292 per_node_kwargs[k] = resolved 293 self.settings[k][node] = resolved 294 if self.debug: 295 print(f"\n[INFO] Building model for node '{node}' with 
settings: {per_node_kwargs}") 296 self.models[node] = get_fully_specified_tram_model( 297 node=node, 298 configuration_dict=self.cfg.conf_dict, 299 **per_node_kwargs) 300 301 try: 302 EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"] 303 NODE_DIR = os.path.join(EXPERIMENT_DIR, f"{node}") 304 os.makedirs(NODE_DIR, exist_ok=True) 305 306 model_path = os.path.join(NODE_DIR, "initial_model.pt") 307 overwrite = settings.get("overwrite_initial_weights", True) 308 309 if overwrite or not os.path.exists(model_path): 310 torch.save(self.models[node].state_dict(), model_path) 311 if self.debug: 312 print(f"[DEBUG] Saved initial model state for node '{node}' to {model_path} (overwrite={overwrite})") 313 else: 314 if self.debug: 315 print(f"[DEBUG] Skipped saving initial model for node '{node}' (already exists at {model_path})") 316 except Exception as e: 317 print(f"[ERROR] Could not save initial model state for node '{node}': {e}") 318 319 TEMP_DIR = "temp" 320 if os.path.isdir(TEMP_DIR) and not os.listdir(TEMP_DIR): 321 os.rmdir(TEMP_DIR) 322 323 return self 324 325 @classmethod 326 def from_directory(cls, EXPERIMENT_DIR: str, device: str = "auto", debug: bool = False, verbose: bool = False): 327 """ 328 Reconstruct a TramDagModel from an experiment directory on disk. 329 330 This method: 331 1. Loads the configuration JSON. 332 2. Wraps it in a TramDagConfig. 333 3. Builds all node models via `from_config`. 334 4. Loads the min–max scaling dictionary. 335 336 Parameters 337 ---------- 338 EXPERIMENT_DIR : str 339 Path to an experiment directory containing: 340 - ``configuration.json`` 341 - ``min_max_scaling.json``. 342 device : {"auto", "cpu", "cuda"}, optional 343 Device selection. Default is "auto". 344 debug : bool, optional 345 If True, enable debug messages. Default is False. 346 verbose : bool, optional 347 If True, enable informational messages. Default is False. 
348 349 Returns 350 ------- 351 TramDagModel 352 A TramDagModel instance with models, config, and scaling loaded. 353 354 Raises 355 ------ 356 FileNotFoundError 357 If configuration or min–max files cannot be found. 358 RuntimeError 359 If the min–max file cannot be read or parsed. 360 """ 361 362 # --- load config file --- 363 config_path = os.path.join(EXPERIMENT_DIR, "configuration.json") 364 if not os.path.exists(config_path): 365 raise FileNotFoundError(f"[ERROR] Config file not found at {config_path}") 366 367 with open(config_path, "r") as f: 368 cfg_dict = json.load(f) 369 370 # Create TramConfig wrapper 371 cfg = TramDagConfig(cfg_dict, CONF_DICT_PATH=config_path) 372 373 # --- build model from config --- 374 self = cls.from_config(cfg, device=device, debug=debug, verbose=verbose, overwrite_initial_weights=False) 375 376 # --- load minmax scaling --- 377 minmax_path = os.path.join(EXPERIMENT_DIR, "min_max_scaling.json") 378 if not os.path.exists(minmax_path): 379 raise FileNotFoundError(f"[ERROR] MinMax file not found at {minmax_path}") 380 381 with open(minmax_path, "r") as f: 382 self.minmax_dict = json.load(f) 383 384 if self.verbose or self.debug: 385 print(f"[INFO] Loaded TramDagModel from {EXPERIMENT_DIR}") 386 print(f"[INFO] Config loaded from {config_path}") 387 print(f"[INFO] MinMax scaling loaded from {minmax_path}") 388 389 return self 390 391 def _ensure_dataset(self, data, is_val=False,**kwargs): 392 """ 393 Ensure that the input data is represented as a TramDagDataset. 394 395 Parameters 396 ---------- 397 data : pandas.DataFrame, TramDagDataset, or None 398 Input data to be converted or passed through. 399 is_val : bool, optional 400 If True, the resulting dataset is treated as validation data 401 (e.g. no shuffling). Default is False. 402 **kwargs 403 Additional keyword arguments passed through to 404 ``TramDagDataset.from_dataframe``. 
405 406 Returns 407 ------- 408 TramDagDataset or None 409 A TramDagDataset if ``data`` is a DataFrame or TramDagDataset, 410 otherwise None if ``data`` is None. 411 412 Raises 413 ------ 414 TypeError 415 If ``data`` is not a DataFrame, TramDagDataset, or None. 416 """ 417 418 if isinstance(data, pd.DataFrame): 419 return TramDagDataset.from_dataframe(data, self.cfg, shuffle=not is_val,**kwargs) 420 elif isinstance(data, TramDagDataset): 421 return data 422 elif data is None: 423 return None 424 else: 425 raise TypeError( 426 f"[ERROR] data must be pd.DataFrame, TramDagDataset, or None, got {type(data)}" 427 ) 428 429 def load_or_compute_minmax(self, td_train_data=None,use_existing=False, write=True): 430 """ 431 Load an existing Min–Max scaling dictionary from disk or compute a new one 432 from the provided training dataset. 433 434 Parameters 435 ---------- 436 use_existing : bool, optional (default=False) 437 If True, attempts to load an existing `min_max_scaling.json` file 438 from the experiment directory. Raises an error if the file is missing 439 or unreadable. 440 441 write : bool, optional (default=True) 442 If True, writes the computed Min–Max scaling dictionary to 443 `<EXPERIMENT_DIR>/min_max_scaling.json`. 444 445 td_train_data : object, optional 446 Training dataset used to compute scaling statistics. If not provided, 447 the method will ensure or construct it via `_ensure_dataset(data=..., is_val=False)`. 448 449 Behavior 450 -------- 451 - If `use_existing=True`, loads the JSON file containing previously saved 452 min–max values and stores it in `self.minmax_dict`. 453 - If `use_existing=False`, computes a new scaling dictionary using 454 `td_train_data.compute_scaling()` and stores the result in 455 `self.minmax_dict`. 456 - Optionally writes the computed dictionary to disk. 457 458 Side Effects 459 ------------- 460 - Populates `self.minmax_dict` with scaling values. 
461 - Writes or loads the file `min_max_scaling.json` under 462 `<EXPERIMENT_DIR>`. 463 - Prints diagnostic output if `self.debug` or `self.verbose` is True. 464 465 Raises 466 ------ 467 FileNotFoundError 468 If `use_existing=True` but the min–max file does not exist. 469 470 RuntimeError 471 If an existing min–max file cannot be read or parsed. 472 473 Notes 474 ----- 475 The computed min–max dictionary is expected to contain scaling statistics 476 per feature, typically in the form: 477 { 478 "node": {"min": float, "max": float}, 479 ... 480 } 481 """ 482 EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"] 483 minmax_path = os.path.join(EXPERIMENT_DIR, "min_max_scaling.json") 484 485 # laod exisitng if possible 486 if use_existing: 487 if not os.path.exists(minmax_path): 488 raise FileNotFoundError(f"MinMax file not found: {minmax_path}") 489 try: 490 with open(minmax_path, 'r') as f: 491 self.minmax_dict = json.load(f) 492 if self.debug or self.verbose: 493 print(f"[INFO] Loaded existing minmax dict from {minmax_path}") 494 return 495 except Exception as e: 496 raise RuntimeError(f"Could not load existing minmax dict: {e}") 497 498 # 499 if self.debug or self.verbose: 500 print("[INFO] Computing new minmax dict from training data...") 501 502 td_train_data=self._ensure_dataset( data=td_train_data, is_val=False) 503 504 self.minmax_dict = td_train_data.compute_scaling() 505 506 if write: 507 os.makedirs(EXPERIMENT_DIR, exist_ok=True) 508 with open(minmax_path, 'w') as f: 509 json.dump(self.minmax_dict, f, indent=4) 510 if self.debug or self.verbose: 511 print(f"[INFO] Saved new minmax dict to {minmax_path}") 512 513 ## FIT METHODS 514 @staticmethod 515 def _fit_single_node(node, self_ref, settings, td_train_data, td_val_data, device_str): 516 """ 517 Train a single node model (helper for per-node training). 
518 519 This method is designed to be called either from the main process 520 (sequential training) or from a joblib worker (parallel CPU training). 521 522 Parameters 523 ---------- 524 node : str 525 Name of the target node to train. 526 self_ref : TramDagModel 527 Reference to the TramDagModel instance containing models and config. 528 settings : dict 529 Training settings dictionary, typically derived from ``DEFAULTS_FIT`` 530 plus any user overrides. 531 td_train_data : TramDagDataset 532 Training dataset with node-specific DataLoaders in ``.loaders``. 533 td_val_data : TramDagDataset or None 534 Validation dataset or None. 535 device_str : str 536 Device string, e.g. "cpu" or "cuda". 537 538 Returns 539 ------- 540 tuple 541 A tuple ``(node, history)`` where: 542 node : str 543 Node name. 544 history : dict or Any 545 Training history as returned by ``train_val_loop``. 546 """ 547 torch.set_num_threads(1) # prevent thread oversubscription 548 549 model = self_ref.models[node] 550 551 # Resolve per-node settings 552 def _resolve(key): 553 val = settings[key] 554 return val[node] if isinstance(val, dict) else val 555 556 node_epochs = _resolve("epochs") 557 node_lr = _resolve("learning_rate") 558 node_debug = _resolve("debug") 559 node_save_linear_shifts = _resolve("save_linear_shifts") 560 save_simple_intercepts = _resolve("save_simple_intercepts") 561 node_verbose = _resolve("verbose") 562 563 # Optimizer & scheduler 564 if settings["optimizers"] and node in settings["optimizers"]: 565 optimizer = settings["optimizers"][node] 566 else: 567 optimizer = Adam(model.parameters(), lr=node_lr) 568 569 scheduler = settings["schedulers"].get(node, None) if settings["schedulers"] else None 570 571 # Data loaders 572 train_loader = td_train_data.loaders[node] 573 val_loader = td_val_data.loaders[node] if td_val_data else None 574 575 # Min-max scaling tensors 576 min_vals = torch.tensor(self_ref.minmax_dict[node][0], dtype=torch.float32) 577 max_vals = 
def fit(self, train_data, val_data=None, **kwargs):
    """
    Train TRAM models for all nodes in the DAG.

    Coordinates dataset preparation, min-max scaling, and per-node training,
    optionally in parallel on CPU. GPU training is always sequential.

    Parameters
    ----------
    train_data : pandas.DataFrame or TramDagDataset
        Training data. A DataFrame is converted via `_ensure_dataset`.
    val_data : pandas.DataFrame or TramDagDataset or None, optional
        Validation data; if None, no validation loss is computed.
    **kwargs
        Overrides for ``DEFAULTS_FIT``. All keys must exist in
        ``DEFAULTS_FIT``. Common options: ``epochs``, ``learning_rate``,
        ``train_list``, ``train_mode`` ("sequential"/"parallel"), ``device``
        ("auto"/"cpu"/"cuda"), ``optimizers``, ``schedulers``, ``batch_size``,
        ``num_workers``, ``persistent_workers``, ``prefetch_factor``,
        ``debug``, ``verbose``, ``return_history``.

    Returns
    -------
    dict or None
        If ``return_history=True``, a mapping of node name to training
        history; otherwise None.

    Raises
    ------
    ValueError
        If ``train_mode`` is not "sequential" or "parallel".
    """
    self._validate_kwargs(kwargs, defaults_attr='DEFAULTS_FIT', context="fit")

    # --- merge defaults with user overrides ---
    settings = dict(self.DEFAULTS_FIT)
    settings.update(kwargs)

    self.debug = settings.get("debug", False)
    self.verbose = settings.get("verbose", False)

    # --- resolve device ---
    device_str = self.get_device(settings)
    self.device = torch.device(device_str)

    # --- training mode ---
    train_mode = settings.get("train_mode", "sequential").lower()
    if train_mode not in ("sequential", "parallel"):
        raise ValueError("train_mode must be 'sequential' or 'parallel'")

    # BUGFIX: previously, "parallel" on a CUDA device executed BOTH the
    # (forced) sequential loop AND the joblib parallel loop, training every
    # node twice. Force sequential mode up front so exactly one branch runs.
    if train_mode == "parallel" and "cuda" in device_str:
        print("[WARNING] GPU device detected — forcing sequential mode.")
        train_mode = "sequential"

    # --- DataLoader safety logic ---
    if train_mode == "parallel":
        # Worker processes must not spawn their own DataLoader workers.
        for flag in ("num_workers", "persistent_workers", "prefetch_factor"):
            if flag in kwargs:
                print(f"[WARNING] '{flag}' is ignored in parallel mode "
                      f"(disabled to prevent nested multiprocessing).")
        settings["num_workers"] = 0
        settings["persistent_workers"] = False
        settings["prefetch_factor"] = None
    else:
        if self.debug:
            print("[DEBUG] Sequential mode: using DataLoader kwargs as provided.")

    # --- which nodes to train ---
    train_list = settings.get("train_list") or list(self.models.keys())

    # --- dataset prep (receives adjusted settings) ---
    td_train_data = self._ensure_dataset(train_data, is_val=False, **settings)
    td_val_data = self._ensure_dataset(val_data, is_val=True, **settings)

    # --- normalization ---
    self.load_or_compute_minmax(use_existing=False, write=True, td_train_data=td_train_data)

    if self.verbose or self.debug:
        print(f"[INFO] Training {len(train_list)} nodes ({train_mode}) on {device_str}")

    if train_mode == "sequential":
        # Sequential mode (GPU, or CPU when requested).
        results = {}
        for node in train_list:
            node, history = self._fit_single_node(
                node, self, settings, td_train_data, td_val_data, device_str
            )
            results[node] = history
    else:
        # Parallel mode (CPU only).
        # BUGFIX: os.cpu_count() may return None; fall back gracefully.
        n_jobs = min(len(train_list), (os.cpu_count() or 2) // 2 or 1)
        if self.verbose or self.debug:
            print(f"[INFO] Using {n_jobs} CPU workers for parallel node training")
        parallel_outputs = Parallel(
            n_jobs=n_jobs,
            backend="loky",  # loky, multiprocessing
            verbose=10,
            prefer="processes",
        )(delayed(self._fit_single_node)(node, self, settings, td_train_data, td_val_data, device_str)
          for node in train_list)
        results = {node: hist for node, hist in parallel_outputs}

    if settings.get("return_history", False):
        return results
787 """ 788 try: 789 EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"] 790 except KeyError: 791 raise ValueError( 792 "[ERROR] Missing 'EXPERIMENT_DIR' in cfg.conf_dict['PATHS']. " 793 "History retrieval requires experiment logs." 794 ) 795 796 all_histories = {} 797 for node in self.nodes_dict.keys(): 798 node_dir = os.path.join(EXPERIMENT_DIR, node) 799 train_path = os.path.join(node_dir, "train_loss_hist.json") 800 val_path = os.path.join(node_dir, "val_loss_hist.json") 801 802 node_hist = {} 803 804 # --- load train history --- 805 if os.path.exists(train_path): 806 try: 807 with open(train_path, "r") as f: 808 node_hist["train"] = json.load(f) 809 except Exception as e: 810 print(f"[WARNING] Could not load {train_path}: {e}") 811 node_hist["train"] = None 812 else: 813 node_hist["train"] = None 814 815 # --- load val history --- 816 if os.path.exists(val_path): 817 try: 818 with open(val_path, "r") as f: 819 node_hist["validation"] = json.load(f) 820 except Exception as e: 821 print(f"[WARNING] Could not load {val_path}: {e}") 822 node_hist["validation"] = None 823 else: 824 node_hist["validation"] = None 825 826 all_histories[node] = node_hist 827 828 if self.verbose or self.debug: 829 print(f"[INFO] Loaded training/validation histories for {len(all_histories)} nodes.") 830 831 return all_histories 832 833 def linear_shift_history(self): 834 """ 835 Load linear shift term histories for all nodes. 836 837 Each node history is expected in a JSON file named 838 ``linear_shifts_all_epochs.json`` under the node directory. 839 840 Returns 841 ------- 842 dict 843 A mapping ``{node_name: pandas.DataFrame}``, where each DataFrame 844 contains linear shift weights across epochs. 845 846 Raises 847 ------ 848 ValueError 849 If the experiment directory cannot be resolved from the configuration. 850 851 Notes 852 ----- 853 If a history file is missing for a node, a warning is printed and the 854 node is omitted from the returned dictionary. 
855 """ 856 histories = {} 857 try: 858 EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"] 859 except KeyError: 860 raise ValueError( 861 "[ERROR] Missing 'EXPERIMENT_DIR' in cfg.conf_dict['PATHS']. " 862 "Cannot load histories without experiment directory." 863 ) 864 865 for node in self.nodes_dict.keys(): 866 node_dir = os.path.join(EXPERIMENT_DIR, node) 867 history_path = os.path.join(node_dir, "linear_shifts_all_epochs.json") 868 if os.path.exists(history_path): 869 histories[node] = pd.read_json(history_path) 870 else: 871 print(f"[WARNING] No linear shift history found for node '{node}' at {history_path}") 872 return histories 873 874 def simple_intercept_history(self): 875 """ 876 Load simple intercept histories for all nodes. 877 878 Each node history is expected in a JSON file named 879 ``simple_intercepts_all_epochs.json`` under the node directory. 880 881 Returns 882 ------- 883 dict 884 A mapping ``{node_name: pandas.DataFrame}``, where each DataFrame 885 contains intercept weights across epochs. 886 887 Raises 888 ------ 889 ValueError 890 If the experiment directory cannot be resolved from the configuration. 891 892 Notes 893 ----- 894 If a history file is missing for a node, a warning is printed and the 895 node is omitted from the returned dictionary. 896 """ 897 histories = {} 898 try: 899 EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"] 900 except KeyError: 901 raise ValueError( 902 "[ERROR] Missing 'EXPERIMENT_DIR' in cfg.conf_dict['PATHS']. " 903 "Cannot load histories without experiment directory." 
def simple_intercept_history(self):
    """
    Load the per-epoch simple intercept history of every node.

    Looks for ``simple_intercepts_all_epochs.json`` inside each node
    directory under the experiment directory.

    Returns
    -------
    dict
        ``{node: pandas.DataFrame}`` with intercept weights across epochs.

    Raises
    ------
    ValueError
        If the experiment directory cannot be resolved from the configuration.

    Notes
    -----
    Nodes whose history file is missing are reported with a warning and
    omitted from the result.
    """
    try:
        experiment_dir = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
    except KeyError:
        raise ValueError(
            "[ERROR] Missing 'EXPERIMENT_DIR' in cfg.conf_dict['PATHS']. "
            "Cannot load histories without experiment directory."
        )

    histories = {}
    for node in self.nodes_dict.keys():
        history_path = os.path.join(experiment_dir, node, "simple_intercepts_all_epochs.json")
        if not os.path.exists(history_path):
            print(f"[WARNING] No simple intercept history found for node '{node}' at {history_path}")
            continue
        histories[node] = pd.read_json(history_path)
    return histories
def get_latent(self, df, verbose=False):
    """
    Compute latent U representations for every node in the DAG.

    Parameters
    ----------
    df : pandas.DataFrame
        Input data with one column per DAG node.
    verbose : bool, optional
        If True, print progress information. Default is False.

    Returns
    -------
    pandas.DataFrame
        The input columns plus one latent column per node
        (e.g. columns named ``f"{node}_U"``).

    Raises
    ------
    ValueError
        If the experiment directory is missing from the configuration, or if
        ``self.minmax_dict`` has not been computed yet.
    """
    try:
        experiment_dir = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
    except KeyError:
        raise ValueError(
            "[ERROR] Missing 'EXPERIMENT_DIR' in cfg.conf_dict['PATHS']. "
            "Latent extraction requires trained model checkpoints."
        )

    # Latent extraction needs the min-max scaling computed during fit().
    if not hasattr(self, "minmax_dict"):
        raise ValueError(
            "[ERROR] minmax_dict not found in the TramDagModel instance. "
            "Either call .load_or_compute_minmax(td_train_data=train_df) or .fit() first."
        )

    return create_latent_df_for_full_dag(
        configuration_dict=self.cfg.conf_dict,
        EXPERIMENT_DIR=experiment_dir,
        df=df,
        verbose=verbose,
        min_max_dict=self.minmax_dict,
    )
label=f"{node} - val") 1023 1024 plt.title("Training and Validation NLL - Full History") 1025 plt.xlabel("Epoch") 1026 plt.ylabel("NLL") 1027 plt.legend() 1028 plt.grid(True) 1029 1030 # --- Last 10% of epochs (bottom plot) --- 1031 plt.subplot(2, 1, 2) 1032 for node in nodes_to_plot: 1033 node_hist = histories[node] 1034 train_hist, val_hist = node_hist["train"], node_hist["validation"] 1035 1036 total_epochs = len(train_hist) 1037 start_idx = total_epochs - 1 if total_epochs < 5 else int(total_epochs * 0.9) 1038 1039 epochs = range(start_idx + 1, total_epochs + 1) 1040 plt.plot(epochs, train_hist[start_idx:], label=f"{node} - train", linestyle="--") 1041 plt.plot(epochs, val_hist[start_idx:], label=f"{node} - val") 1042 1043 plt.title("Training and Validation NLL - Last 10% of Epochs (or Last Epoch if <5)") 1044 plt.xlabel("Epoch") 1045 plt.ylabel("NLL") 1046 plt.legend() 1047 plt.grid(True) 1048 1049 plt.tight_layout() 1050 plt.show() 1051 1052 def plot_linear_shift_history(self, data_dict=None, node=None, ref_lines=None): 1053 """ 1054 Plot the evolution of linear shift terms over epochs. 1055 1056 Parameters 1057 ---------- 1058 data_dict : dict or None, optional 1059 Pre-loaded mapping ``{node_name: pandas.DataFrame}`` containing shift 1060 weights across epochs. If None, `linear_shift_history()` is called. 1061 node : str or None, optional 1062 If provided, plot only this node. Otherwise, plot all nodes 1063 present in ``data_dict``. 1064 ref_lines : dict or None, optional 1065 Optional mapping ``{node_name: list of float}``. For each specified 1066 node, horizontal reference lines are drawn at the given values. 1067 1068 Returns 1069 ------- 1070 None 1071 1072 Notes 1073 ----- 1074 The function flattens nested list-like entries in the DataFrames to scalars, 1075 converts epoch labels to numeric, and then draws one line per shift term. 
1076 """ 1077 1078 if data_dict is None: 1079 data_dict = self.linear_shift_history() 1080 if data_dict is None: 1081 raise ValueError("No shift history data provided or stored in the class.") 1082 1083 nodes = [node] if node else list(data_dict.keys()) 1084 1085 for n in nodes: 1086 df = data_dict[n].copy() 1087 1088 # Flatten nested lists or list-like cells 1089 def flatten(x): 1090 if isinstance(x, list): 1091 if len(x) == 0: 1092 return np.nan 1093 if all(isinstance(i, (int, float)) for i in x): 1094 return np.mean(x) # average simple list 1095 if all(isinstance(i, list) for i in x): 1096 # nested list -> flatten inner and average 1097 flat = [v for sub in x for v in (sub if isinstance(sub, list) else [sub])] 1098 return np.mean(flat) if flat else np.nan 1099 return x[0] if len(x) == 1 else np.nan 1100 return x 1101 1102 df = df.applymap(flatten) 1103 1104 # Ensure numeric columns 1105 df = df.apply(pd.to_numeric, errors='coerce') 1106 1107 # Convert epoch labels to numeric 1108 df.columns = [ 1109 int(c.replace("epoch_", "")) if isinstance(c, str) and c.startswith("epoch_") else c 1110 for c in df.columns 1111 ] 1112 df = df.reindex(sorted(df.columns), axis=1) 1113 1114 plt.figure(figsize=(10, 6)) 1115 for idx in df.index: 1116 plt.plot(df.columns, df.loc[idx], lw=1.4, label=f"shift_{idx}") 1117 1118 if ref_lines and n in ref_lines: 1119 for v in ref_lines[n]: 1120 plt.axhline(y=v, color="k", linestyle="--", lw=1.0) 1121 plt.text(df.columns[-1], v, f"{n}: {v}", va="bottom", ha="right", fontsize=8) 1122 1123 plt.xlabel("Epoch") 1124 plt.ylabel("Shift Value") 1125 plt.title(f"Shift Term History — Node: {n}") 1126 plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left", fontsize="small") 1127 plt.tight_layout() 1128 plt.show() 1129 1130 def plot_simple_intercepts_history(self, data_dict=None, node=None,ref_lines=None): 1131 """ 1132 Plot the evolution of simple intercept weights over epochs. 
def plot_simple_intercepts_history(self, data_dict=None, node=None, ref_lines=None):
    """
    Plot how the simple intercept weights evolve across training epochs.

    Parameters
    ----------
    data_dict : dict or None, optional
        ``{node: pandas.DataFrame}`` of intercept weights per epoch; loaded
        via `simple_intercept_history()` when None.
    node : str or None, optional
        Restrict plotting to a single node; all nodes otherwise.
    ref_lines : dict or None, optional
        ``{node: [values]}`` — horizontal reference lines drawn per node.

    Returns
    -------
    None

    Notes
    -----
    Nested list cells are unwrapped to their leading scalar before plotting;
    one curve is drawn per intercept parameter.
    """
    if data_dict is None:
        data_dict = self.simple_intercept_history()
        if data_dict is None:
            raise ValueError("No intercept history data provided or stored in the class.")

    def _first_scalar(cell):
        # Unwrap nested lists down to the leading element; NaN if not numeric.
        while isinstance(cell, list) and len(cell) > 0:
            cell = cell[0]
        return float(cell) if isinstance(cell, (int, float, np.floating)) else np.nan

    for n in ([node] if node else list(data_dict.keys())):
        frame = data_dict[n].copy().applymap(_first_scalar)

        # "epoch_<k>" column labels -> integer k, then sort chronologically.
        frame.columns = [
            int(c.replace("epoch_", "")) if isinstance(c, str) and c.startswith("epoch_") else c
            for c in frame.columns
        ]
        frame = frame.reindex(sorted(frame.columns), axis=1)

        plt.figure(figsize=(10, 6))
        for idx in frame.index:
            plt.plot(frame.columns, frame.loc[idx], lw=1.4, label=f"theta_{idx}")

        if ref_lines and n in ref_lines:
            for v in ref_lines[n]:
                plt.axhline(y=v, color="k", linestyle="--", lw=1.0)
                plt.text(frame.columns[-1], v, f"{n}: {v}", va="bottom", ha="right", fontsize=8)

        plt.xlabel("Epoch")
        plt.ylabel("Intercept Weight")
        plt.title(f"Simple Intercept Evolution — Node: {n}")
        plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left", fontsize="small")
        plt.tight_layout()
        plt.show()
1222 """ 1223 # Compute latent representations 1224 latents_df = self.get_latent(df) 1225 1226 # Select nodes 1227 nodes = [variable] if variable is not None else self.nodes_dict.keys() 1228 1229 for node in nodes: 1230 if f"{node}_U" not in latents_df.columns: 1231 print(f"[WARNING] No latent found for node {node}, skipping.") 1232 continue 1233 1234 sample = latents_df[f"{node}_U"].values 1235 1236 # --- Create plots --- 1237 fig, axs = plt.subplots(1, 2, figsize=(12, 5)) 1238 1239 # Histogram 1240 axs[0].hist(sample, bins=50, color="steelblue", alpha=0.7) 1241 axs[0].set_title(f"Latent Histogram ({node})") 1242 axs[0].set_xlabel("U") 1243 axs[0].set_ylabel("Frequency") 1244 1245 # QQ Plot with confidence bands 1246 probplot(sample, dist="logistic", plot=axs[1]) 1247 self._add_r_style_confidence_bands(axs[1], sample, dist=logistic,confidence=confidence, simulations=simulations) 1248 axs[1].set_title(f"Latent QQ Plot ({node})") 1249 1250 plt.suptitle(f"Latent Diagnostics for Node: {node}", fontsize=14) 1251 plt.tight_layout(rect=[0, 0.03, 1, 0.95]) 1252 plt.show() 1253 1254 def plot_hdag(self,df,variables=None, plot_n_rows=1,**kwargs): 1255 1256 """ 1257 Visualize the transformation function h() for selected DAG nodes. 1258 1259 Parameters 1260 ---------- 1261 df : pandas.DataFrame 1262 Input data containing node values or model predictions. 1263 variables : list of str or None, optional 1264 Names of nodes to visualize. If None, all nodes in ``self.models`` 1265 are considered. 1266 plot_n_rows : int, optional 1267 Maximum number of rows from ``df`` to visualize. Default is 1. 1268 **kwargs 1269 Additional keyword arguments forwarded to the underlying plotting 1270 helpers (`show_hdag_continous` / `show_hdag_ordinal`). 1271 1272 Returns 1273 ------- 1274 None 1275 1276 Notes 1277 ----- 1278 - For continuous outcomes, `show_hdag_continous` is called. 1279 - For ordinal outcomes, `show_hdag_ordinal` is called. 
def plot_hdag(self, df, variables=None, plot_n_rows=1, **kwargs):
    """
    Visualize the transformation function h() for selected DAG nodes.

    Parameters
    ----------
    df : pandas.DataFrame
        Input data containing node values or model predictions.
    variables : list of str or None, optional
        Nodes to visualize. If None, all nodes in ``self.models`` are used.
    plot_n_rows : int, optional
        Maximum number of rows from ``df`` to visualize. Default is 1.
    **kwargs
        Forwarded to the underlying plotting helpers
        (`show_hdag_continous` / `show_hdag_ordinal`).

    Returns
    -------
    None

    Notes
    -----
    - Continuous outcomes are dispatched to `show_hdag_continous`.
    - Ordinal outcomes are dispatched to `show_hdag_ordinal`.
    - Any other outcome type is skipped with a warning.
    """
    if len(df) > 1:
        print("[WARNING] len(df)>1, set: plot_n_rows accordingly")

    variables_list = variables if variables is not None else list(self.models.keys())
    for node in variables_list:
        if is_outcome_modelled_continous(node, self.nodes_dict):
            show_hdag_continous(df, node=node, configuration_dict=self.cfg.conf_dict,
                                minmax_dict=self.minmax_dict, device=self.device,
                                plot_n_rows=plot_n_rows, **kwargs)
        elif is_outcome_modelled_ordinal(node, self.nodes_dict):
            show_hdag_ordinal(df, node=node, configuration_dict=self.cfg.conf_dict,
                              device=self.device, plot_n_rows=plot_n_rows, **kwargs)
            # plot_cutpoints_with_logistic(df,node=node,configuration_dict=self.cfg.conf_dict,device=self.device,plot_n_rows=plot_n_rows,**kwargs)
            # save_cutpoints_with_logistic(df,node=node,configuration_dict=self.cfg.conf_dict,device=self.device,**kwargs)
        else:
            # BUGFIX: previous message was garbled ("is wheter ordinal nor continous").
            print(f"[WARNING] Node {node} is neither ordinal nor continuous, not implemented yet")
1325 - 45-degree reference line. 1326 - Shaded confidence band region. 1327 """ 1328 1329 n = len(sample) 1330 if n == 0: 1331 return 1332 1333 quantiles = np.linspace(0, 1, n, endpoint=False) + 0.5 / n 1334 theo_q = dist.ppf(quantiles) 1335 1336 # Simulate order statistics from the theoretical distribution 1337 sim_data = dist.rvs(size=(simulations, n)) 1338 sim_order_stats = np.sort(sim_data, axis=1) 1339 1340 # Confidence bands 1341 lower = np.percentile(sim_order_stats, 100 * (1 - confidence) / 2, axis=0) 1342 upper = np.percentile(sim_order_stats, 100 * (1 + confidence) / 2, axis=0) 1343 1344 # Sort empirical sample 1345 sample_sorted = np.sort(sample) 1346 1347 # Re-draw points and CI (overwrite probplot defaults) 1348 ax.clear() 1349 ax.plot(theo_q, sample_sorted, 'o', markersize=3, alpha=0.6, label="Empirical Q-Q") 1350 ax.plot(theo_q, theo_q, 'b--', label="y = x") 1351 ax.fill_between(theo_q, lower, upper, color='gray', alpha=0.3, 1352 label=f'{int(confidence*100)}% CI') 1353 ax.legend() 1354 1355 ## SAMPLING METHODS 1356 def sample( 1357 self, 1358 do_interventions: dict = None, 1359 predefined_latent_samples_df: pd.DataFrame = None, 1360 **kwargs, 1361 ): 1362 """ 1363 Sample from the joint DAG using the trained TRAM models. 1364 1365 Allows for: 1366 1367 Oberservational sampling 1368 Interventional sampling via ``do()`` operations 1369 Counterfactial sampling using predefined latent draws and do() 1370 1371 Parameters 1372 ---------- 1373 do_interventions : dict or None, optional 1374 Mapping of node names to intervened (fixed) values. For example: 1375 ``{"x1": 1.0}`` represents ``do(x1 = 1.0)``. Default is None. 1376 predefined_latent_samples_df : pandas.DataFrame or None, optional 1377 DataFrame containing columns ``"{node}_U"`` with predefined latent 1378 draws to be used instead of sampling from the prior. Default is None. 
def sample(
    self,
    do_interventions: dict = None,
    predefined_latent_samples_df: pd.DataFrame = None,
    **kwargs,
):
    """
    Sample from the joint DAG using the trained TRAM models.

    Supports observational sampling, interventional sampling via ``do()``
    operations, and counterfactual sampling from predefined latent draws.

    Parameters
    ----------
    do_interventions : dict or None, optional
        Mapping of node names to intervened (fixed) values, e.g.
        ``{"x1": 1.0}`` represents ``do(x1 = 1.0)``. Default is None.
    predefined_latent_samples_df : pandas.DataFrame or None, optional
        DataFrame with ``"{node}_U"`` columns holding predefined latent
        draws used instead of sampling from the prior. Default is None.
    **kwargs
        Sampling options overriding internal defaults:

        number_of_samples : int, default 10000
        batch_size : int, default 32
        delete_all_previously_sampled : bool, default True
        verbose : bool
        debug : bool
        device : {"auto", "cpu", "cuda"}
        use_initial_weights_for_sampling : bool, default False

    Returns
    -------
    tuple
        ``(sampled_by_node, latents_by_node)``, both mapping node names to
        ``torch.Tensor`` objects (sampled values and latent U values).

    Raises
    ------
    ValueError
        If the experiment directory cannot be resolved.
    RuntimeError
        If min-max scaling (``self.minmax_dict``) has not been computed yet.
    """
    try:
        EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
    except KeyError:
        raise ValueError(
            "[ERROR] Missing 'EXPERIMENT_DIR' in cfg.conf_dict['PATHS']. "
            "Sampling requires trained model checkpoints."
        )

    # ---- defaults (instance flags are picked up when present) ----
    settings = {
        "number_of_samples": 10_000,
        "batch_size": 32,
        "delete_all_previously_sampled": True,
        "verbose": getattr(self, "verbose", False),
        "debug": getattr(self, "debug", False),
        "device": self.device.type if hasattr(self, "device") else "auto",
        "use_initial_weights_for_sampling": False,
    }

    # self._validate_kwargs( kwargs, defaults_attr= "settings", context="sample")

    settings.update(kwargs)

    if not hasattr(self, "minmax_dict"):
        raise RuntimeError(
            "[ERROR] minmax_dict not found. You must call .fit() or .load_or_compute_minmax() "
            "before sampling, so scaling info is available."
        )

    # ---- resolve device ----
    device_str = self.get_device(settings)
    self.device = torch.device(device_str)

    # BUGFIX: `self.debug or settings["debug"]` raised AttributeError when
    # sample() was called before fit() (no `debug` attribute set yet).
    if getattr(self, "debug", False) or settings["debug"]:
        print(f"[DEBUG] sample(): device: {self.device}")

    # ---- perform sampling ----
    sampled_by_node, latents_by_node = sample_full_dag(
        configuration_dict=self.cfg.conf_dict,
        EXPERIMENT_DIR=EXPERIMENT_DIR,
        device=self.device,
        do_interventions=do_interventions or {},
        predefined_latent_samples_df=predefined_latent_samples_df,
        number_of_samples=settings["number_of_samples"],
        batch_size=settings["batch_size"],
        delete_all_previously_sampled=settings["delete_all_previously_sampled"],
        verbose=settings["verbose"],
        debug=settings["debug"],
        minmax_dict=self.minmax_dict,
        use_initial_weights_for_sampling=settings["use_initial_weights_for_sampling"],
    )

    return sampled_by_node, latents_by_node
def load_sampled_and_latents(self, EXPERIMENT_DIR: str = None, nodes: list = None):
    """
    Load previously stored sampled values and latents for each node.

    Parameters
    ----------
    EXPERIMENT_DIR : str or None, optional
        Experiment directory path; resolved from
        ``self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]`` when None.
    nodes : list of str or None, optional
        Nodes to load; defaults to all nodes in ``self.nodes_dict``.

    Returns
    -------
    tuple
        ``(sampled_by_node, latents_by_node)`` — ``{node: torch.Tensor}``
        mappings with all tensors detached and moved to CPU.

    Raises
    ------
    ValueError
        If the experiment directory cannot be resolved, or if no node list
        is available and ``nodes`` is None.

    Notes
    -----
    Nodes missing either ``sampled.pt`` or ``latents.pt`` are skipped with
    a warning; unreadable files are reported and skipped as well.
    """
    # --- resolve paths and node list ---
    if EXPERIMENT_DIR is None:
        try:
            EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
        except (AttributeError, KeyError):
            raise ValueError(
                "[ERROR] Could not resolve EXPERIMENT_DIR from cfg.conf_dict['PATHS']. "
                "Please provide EXPERIMENT_DIR explicitly."
            )

    if nodes is None:
        if not hasattr(self, "nodes_dict"):
            raise ValueError(
                "[ERROR] No node list found. Please provide `nodes` or initialize model with a config."
            )
        nodes = list(self.nodes_dict.keys())

    # --- load tensors ---
    sampled_by_node, latents_by_node = {}, {}

    for node in nodes:
        sampling_dir = os.path.join(EXPERIMENT_DIR, f"{node}", "sampling")
        sampled_path = os.path.join(sampling_dir, "sampled.pt")
        latents_path = os.path.join(sampling_dir, "latents.pt")

        if not (os.path.exists(sampled_path) and os.path.exists(latents_path)):
            print(f"[WARNING] Missing files for node '{node}' — skipping.")
            continue

        try:
            sampled = torch.load(sampled_path, map_location="cpu")
            latent_sample = torch.load(latents_path, map_location="cpu")
        except Exception as e:
            print(f"[ERROR] Could not load sampling files for node '{node}': {e}")
            continue

        sampled_by_node[node] = sampled.detach().cpu()
        latents_by_node[node] = latent_sample.detach().cpu()

    if self.verbose or self.debug:
        print(f"[INFO] Loaded sampled and latent tensors for {len(sampled_by_node)} nodes from {EXPERIMENT_DIR}")

    return sampled_by_node, latents_by_node
1581 bins : int, optional 1582 Number of histogram bins for continuous variables. Default is 100. 1583 hist_true_color : str, optional 1584 Color name for the histogram of true values. Default is "blue". 1585 hist_est_color : str, optional 1586 Color name for the histogram of sampled values. Default is "orange". 1587 figsize : tuple, optional 1588 Figure size for the matplotlib plots. Default is (14, 5). 1589 1590 Returns 1591 ------- 1592 None 1593 1594 Notes 1595 ----- 1596 - Continuous outcomes: histogram overlay + QQ-plot. 1597 - Ordinal outcomes: side-by-side bar plot of relative frequencies. 1598 - Other categorical outcomes: side-by-side bar plot with category labels. 1599 - If samples are probabilistic (2D tensor), the argmax across classes is used. 1600 """ 1601 1602 target_nodes = self.cfg.conf_dict["nodes"] 1603 experiment_dir = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"] 1604 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 1605 1606 plot_list = variable if variable is not None else target_nodes 1607 1608 for node in plot_list: 1609 # Load sampled data 1610 if sampled is not None and node in sampled: 1611 sdata = sampled[node] 1612 if isinstance(sdata, torch.Tensor): 1613 sampled_vals = sdata.detach().cpu().numpy() 1614 else: 1615 sampled_vals = np.asarray(sdata) 1616 else: 1617 sample_path = os.path.join(experiment_dir, f"{node}/sampling/sampled.pt") 1618 if not os.path.isfile(sample_path): 1619 print(f"[WARNING] skip {node}: {sample_path} not found.") 1620 continue 1621 1622 try: 1623 sampled_vals = torch.load(sample_path, map_location=device).cpu().numpy() 1624 except Exception as e: 1625 print(f"[ERROR] Could not load {sample_path}: {e}") 1626 continue 1627 1628 # If logits/probabilities per sample, take argmax 1629 if sampled_vals.ndim == 2: 1630 print(f"[INFO] CAUTION! {node}: samples are probabilistic — each sample follows a probability " 1631 f"distribution based on the valid latent range. 
" 1632 f"Note that this frequency plot reflects only the distribution of the most probable " 1633 f"class per sample.") 1634 sampled_vals = np.argmax(sampled_vals, axis=1) 1635 1636 sampled_vals = sampled_vals[np.isfinite(sampled_vals)] 1637 1638 if node not in df.columns: 1639 print(f"[WARNING] skip {node}: column not found in DataFrame.") 1640 continue 1641 1642 true_vals = df[node].dropna().values 1643 true_vals = true_vals[np.isfinite(true_vals)] 1644 1645 if sampled_vals.size == 0 or true_vals.size == 0: 1646 print(f"[WARNING] skip {node}: empty array after NaN/Inf removal.") 1647 continue 1648 1649 fig, axs = plt.subplots(1, 2, figsize=figsize) 1650 1651 if is_outcome_modelled_continous(node, target_nodes): 1652 axs[0].hist(true_vals, bins=bins, density=True, alpha=0.6, 1653 color=hist_true_color, label=f"True {node}") 1654 axs[0].hist(sampled_vals, bins=bins, density=True, alpha=0.6, 1655 color=hist_est_color, label="Sampled") 1656 axs[0].set_xlabel("Value") 1657 axs[0].set_ylabel("Density") 1658 axs[0].set_title(f"Histogram overlay for {node}") 1659 axs[0].legend() 1660 axs[0].grid(True, ls="--", alpha=0.4) 1661 1662 qqplot_2samples(true_vals, sampled_vals, line="45", ax=axs[1]) 1663 axs[1].set_xlabel("True quantiles") 1664 axs[1].set_ylabel("Sampled quantiles") 1665 axs[1].set_title(f"QQ plot for {node}") 1666 axs[1].grid(True, ls="--", alpha=0.4) 1667 1668 elif is_outcome_modelled_ordinal(node, target_nodes): 1669 unique_vals = np.union1d(np.unique(true_vals), np.unique(sampled_vals)) 1670 unique_vals = np.sort(unique_vals) 1671 true_counts = np.array([(true_vals == val).sum() for val in unique_vals]) 1672 sampled_counts = np.array([(sampled_vals == val).sum() for val in unique_vals]) 1673 1674 axs[0].bar(unique_vals - 0.2, true_counts / true_counts.sum(), 1675 width=0.4, color=hist_true_color, alpha=0.7, label="True") 1676 axs[0].bar(unique_vals + 0.2, sampled_counts / sampled_counts.sum(), 1677 width=0.4, color=hist_est_color, alpha=0.7, 
label="Sampled") 1678 axs[0].set_xticks(unique_vals) 1679 axs[0].set_xlabel("Ordinal Level") 1680 axs[0].set_ylabel("Relative Frequency") 1681 axs[0].set_title(f"Ordinal bar plot for {node}") 1682 axs[0].legend() 1683 axs[0].grid(True, ls="--", alpha=0.4) 1684 axs[1].axis("off") 1685 1686 else: 1687 unique_vals = np.union1d(np.unique(true_vals), np.unique(sampled_vals)) 1688 unique_vals = sorted(unique_vals, key=str) 1689 true_counts = np.array([(true_vals == val).sum() for val in unique_vals]) 1690 sampled_counts = np.array([(sampled_vals == val).sum() for val in unique_vals]) 1691 1692 axs[0].bar(np.arange(len(unique_vals)) - 0.2, true_counts / true_counts.sum(), 1693 width=0.4, color=hist_true_color, alpha=0.7, label="True") 1694 axs[0].bar(np.arange(len(unique_vals)) + 0.2, sampled_counts / sampled_counts.sum(), 1695 width=0.4, color=hist_est_color, alpha=0.7, label="Sampled") 1696 axs[0].set_xticks(np.arange(len(unique_vals))) 1697 axs[0].set_xticklabels(unique_vals, rotation=45) 1698 axs[0].set_xlabel("Category") 1699 axs[0].set_ylabel("Relative Frequency") 1700 axs[0].set_title(f"Categorical bar plot for {node}") 1701 axs[0].legend() 1702 axs[0].grid(True, ls="--", alpha=0.4) 1703 axs[1].axis("off") 1704 1705 plt.tight_layout() 1706 plt.show() 1707 1708 ## SUMMARY METHODS 1709 def nll(self,data,variables=None): 1710 """ 1711 Compute the Negative Log-Likelihood (NLL) for all or selected TRAM nodes. 1712 1713 This function evaluates trained TRAM models for each specified variable (node) 1714 on the provided dataset. It performs forward passes only—no training, no weight 1715 updates—and returns the mean NLL per node. 1716 1717 Parameters 1718 ---------- 1719 data : object 1720 Input dataset or data source compatible with `_ensure_dataset`, containing 1721 both inputs and targets for each node. 1722 variables : list[str], optional 1723 List of variable (node) names to evaluate. If None, all nodes in 1724 `self.models` are evaluated. 
1725 1726 Returns 1727 ------- 1728 dict[str, float] 1729 Dictionary mapping each node name to its average NLL value. 1730 1731 Notes 1732 ----- 1733 - Each model is evaluated independently on its respective DataLoader. 1734 - The normalization values (`min_max`) for each node are retrieved from 1735 `self.minmax_dict[node]`. 1736 - The function uses `evaluate_tramdag_model()` for per-node evaluation. 1737 - Expected directory structure: 1738 `<EXPERIMENT_DIR>/<node>/` 1739 where each node directory contains the trained model. 1740 """ 1741 1742 td_data = self._ensure_dataset(data, is_val=True) 1743 variables_list = variables if variables != None else list(self.models.keys()) 1744 nll_dict = {} 1745 for node in variables_list: 1746 min_vals = torch.tensor(self.minmax_dict[node][0], dtype=torch.float32) 1747 max_vals = torch.tensor(self.minmax_dict[node][1], dtype=torch.float32) 1748 min_max = torch.stack([min_vals, max_vals], dim=0) 1749 data_loader = td_data.loaders[node] 1750 model = self.models[node] 1751 EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"] 1752 NODE_DIR = os.path.join(EXPERIMENT_DIR, f"{node}") 1753 nll= evaluate_tramdag_model(node=node, 1754 target_nodes=self.nodes_dict, 1755 NODE_DIR=NODE_DIR, 1756 tram_model=model, 1757 data_loader=data_loader, 1758 min_max=min_max) 1759 nll_dict[node]=nll 1760 return nll_dict 1761 1762 def get_train_val_nll(self, node: str, mode: str) -> tuple[float, float]: 1763 """ 1764 Retrieve training and validation NLL for a node and a given model state. 1765 1766 Parameters 1767 ---------- 1768 node : str 1769 Node name. 1770 mode : {"best", "last", "init"} 1771 State of interest: 1772 - "best": epoch with lowest validation NLL. 1773 - "last": final epoch. 1774 - "init": first epoch (index 0). 1775 1776 Returns 1777 ------- 1778 tuple of (float or None, float or None) 1779 A tuple ``(train_nll, val_nll)`` for the requested mode. 1780 Returns ``(None, None)`` if loss files are missing or cannot be read. 
1781 1782 Notes 1783 ----- 1784 This method expects per-node JSON files: 1785 1786 - ``train_loss_hist.json`` 1787 - ``val_loss_hist.json`` 1788 1789 in the node directory. 1790 """ 1791 NODE_DIR = os.path.join(self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"], node) 1792 train_path = os.path.join(NODE_DIR, "train_loss_hist.json") 1793 val_path = os.path.join(NODE_DIR, "val_loss_hist.json") 1794 1795 if not os.path.exists(train_path) or not os.path.exists(val_path): 1796 if getattr(self, "debug", False): 1797 print(f"[DEBUG] Missing loss files for node '{node}'. Returning None.") 1798 return None, None 1799 1800 try: 1801 with open(train_path, "r") as f: 1802 train_hist = json.load(f) 1803 with open(val_path, "r") as f: 1804 val_hist = json.load(f) 1805 1806 train_nlls = np.array(train_hist) 1807 val_nlls = np.array(val_hist) 1808 1809 if mode == "init": 1810 idx = 0 1811 elif mode == "last": 1812 idx = len(val_nlls) - 1 1813 elif mode == "best": 1814 idx = int(np.argmin(val_nlls)) 1815 else: 1816 raise ValueError(f"Invalid mode '{mode}' — must be one of 'best', 'last', 'init'.") 1817 1818 train_nll = float(train_nlls[idx]) 1819 val_nll = float(val_nlls[idx]) 1820 return train_nll, val_nll 1821 1822 except Exception as e: 1823 print(f"[ERROR] Failed to load NLLs for node '{node}' ({mode}): {e}") 1824 return None, None 1825 1826 def get_thetas(self, node: str, state: str = "best"): 1827 """ 1828 Return transformed intercept (theta) parameters for a node and state. 1829 1830 Parameters 1831 ---------- 1832 node : str 1833 Node name. 1834 state : {"best", "last", "init"}, optional 1835 Model state for which to return parameters. Default is "best". 1836 1837 Returns 1838 ------- 1839 Any or None 1840 Transformed theta parameters for the requested node and state. 1841 The exact structure (scalar, list, or other) depends on the model. 1842 1843 Raises 1844 ------ 1845 ValueError 1846 If an invalid state is given (not in {"best", "last", "init"}). 
1847 1848 Notes 1849 ----- 1850 Intercept dictionaries are cached on the instance under the attribute 1851 ``intercept_dicts``. If missing or incomplete, they are recomputed using 1852 `get_simple_intercepts_dict`. 1853 """ 1854 1855 state = state.lower() 1856 if state not in ["best", "last", "init"]: 1857 raise ValueError(f"[ERROR] Invalid state '{state}'. Must be one of ['best', 'last', 'init'].") 1858 1859 dict_attr = "intercept_dicts" 1860 1861 # If no cached intercepts exist, compute them 1862 if not hasattr(self, dict_attr): 1863 if getattr(self, "debug", False): 1864 print(f"[DEBUG] '{dict_attr}' not found, computing via get_simple_intercepts_dict().") 1865 setattr(self, dict_attr, self.get_simple_intercepts_dict()) 1866 1867 all_dicts = getattr(self, dict_attr) 1868 1869 # If the requested state isn’t cached, recompute 1870 if state not in all_dicts: 1871 if getattr(self, "debug", False): 1872 print(f"[DEBUG] State '{state}' not found in cached intercepts, recomputing full dict.") 1873 setattr(self, dict_attr, self.get_simple_intercepts_dict()) 1874 all_dicts = getattr(self, dict_attr) 1875 1876 state_dict = all_dicts.get(state, {}) 1877 1878 # Return cached node intercept if present 1879 if node in state_dict: 1880 return state_dict[node] 1881 1882 # If not found, recompute full dict as fallback 1883 if getattr(self, "debug", False): 1884 print(f"[DEBUG] Node '{node}' not found in state '{state}', recomputing full dict.") 1885 setattr(self, dict_attr, self.get_simple_intercepts_dict()) 1886 all_dicts = getattr(self, dict_attr) 1887 return all_dicts.get(state, {}).get(node, None) 1888 1889 def get_linear_shifts(self, node: str, state: str = "best"): 1890 """ 1891 Return learned linear shift terms for a node and a given state. 1892 1893 Parameters 1894 ---------- 1895 node : str 1896 Node name. 1897 state : {"best", "last", "init"}, optional 1898 Model state for which to return linear shift terms. Default is "best". 
1899 1900 Returns 1901 ------- 1902 dict or Any or None 1903 Linear shift terms for the given node and state. Usually a dict 1904 mapping term names to weights. 1905 1906 Raises 1907 ------ 1908 ValueError 1909 If an invalid state is given (not in {"best", "last", "init"}). 1910 1911 Notes 1912 ----- 1913 Linear shift dictionaries are cached on the instance under the attribute 1914 ``linear_shift_dicts``. If missing or incomplete, they are recomputed using 1915 `get_linear_shifts_dict`. 1916 """ 1917 state = state.lower() 1918 if state not in ["best", "last", "init"]: 1919 raise ValueError(f"[ERROR] Invalid state '{state}'. Must be one of ['best', 'last', 'init'].") 1920 1921 dict_attr = "linear_shift_dicts" 1922 1923 # If no global dicts cached, compute once 1924 if not hasattr(self, dict_attr): 1925 if getattr(self, "debug", False): 1926 print(f"[DEBUG] '{dict_attr}' not found, computing via get_linear_shifts_dict().") 1927 setattr(self, dict_attr, self.get_linear_shifts_dict()) 1928 1929 all_dicts = getattr(self, dict_attr) 1930 1931 # If the requested state isn't cached, compute all again (covers fresh runs) 1932 if state not in all_dicts: 1933 if getattr(self, "debug", False): 1934 print(f"[DEBUG] State '{state}' not found in cached linear shifts, recomputing full dict.") 1935 setattr(self, dict_attr, self.get_linear_shifts_dict()) 1936 all_dicts = getattr(self, dict_attr) 1937 1938 # Now fetch the dictionary for this state 1939 state_dict = all_dicts.get(state, {}) 1940 1941 # If the node is available, return its entry 1942 if node in state_dict: 1943 return state_dict[node] 1944 1945 # If missing, try recomputing (fallback) 1946 if getattr(self, "debug", False): 1947 print(f"[DEBUG] Node '{node}' not found in state '{state}', recomputing full dict.") 1948 setattr(self, dict_attr, self.get_linear_shifts_dict()) 1949 all_dicts = getattr(self, dict_attr) 1950 return all_dicts.get(state, {}).get(node, None) 1951 1952 def get_linear_shifts_dict(self): 1953 """ 
1954 Compute linear shift term dictionaries for all nodes and states. 1955 1956 For each node and each available state ("best", "last", "init"), this 1957 method loads the corresponding model checkpoint, extracts linear shift 1958 weights from the TRAM model, and stores them in a nested dictionary. 1959 1960 Returns 1961 ------- 1962 dict 1963 Nested dictionary of the form: 1964 1965 .. code-block:: python 1966 1967 { 1968 "best": {node: {...}}, 1969 "last": {node: {...}}, 1970 "init": {node: {...}}, 1971 } 1972 1973 where the innermost dict maps term labels (e.g. ``"ls(parent_name)"``) 1974 to their weights. 1975 1976 Notes 1977 ----- 1978 - If "best" or "last" checkpoints are unavailable for a node, only 1979 the "init" entry is populated. 1980 - Empty outer states (without any nodes) are removed from the result. 1981 """ 1982 1983 EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"] 1984 nodes_list = list(self.models.keys()) 1985 all_states = ["best", "last", "init"] 1986 all_linear_shift_dicts = {state: {} for state in all_states} 1987 1988 for node in nodes_list: 1989 NODE_DIR = os.path.join(EXPERIMENT_DIR, node) 1990 BEST_MODEL_PATH, LAST_MODEL_PATH, _, _ = model_train_val_paths(NODE_DIR) 1991 INIT_MODEL_PATH = os.path.join(NODE_DIR, "initial_model.pt") 1992 1993 state_paths = { 1994 "best": BEST_MODEL_PATH, 1995 "last": LAST_MODEL_PATH, 1996 "init": INIT_MODEL_PATH, 1997 } 1998 1999 for state, LOAD_PATH in state_paths.items(): 2000 if not os.path.exists(LOAD_PATH): 2001 if state != "init": 2002 # skip best/last if unavailable 2003 continue 2004 else: 2005 print(f"[WARNING] No models found for node '{node}'. Only initial model will be used.") 2006 if not os.path.exists(LOAD_PATH): 2007 if getattr(self, "debug", False): 2008 print(f"[DEBUG] Initial model also missing for node '{node}'. 
Skipping.") 2009 continue 2010 2011 # Load parents and model 2012 _, terms_dict, _ = ordered_parents(node, self.nodes_dict) 2013 state_dict = torch.load(LOAD_PATH, map_location=self.device) 2014 tram_model = self.models[node] 2015 tram_model.load_state_dict(state_dict) 2016 2017 epoch_weights = {} 2018 if hasattr(tram_model, "nn_shift") and tram_model.nn_shift is not None: 2019 for i, shift_layer in enumerate(tram_model.nn_shift): 2020 module_name = shift_layer.__class__.__name__ 2021 if ( 2022 hasattr(shift_layer, "fc") 2023 and hasattr(shift_layer.fc, "weight") 2024 and module_name == "LinearShift" 2025 ): 2026 term_name = list(terms_dict.keys())[i] 2027 epoch_weights[f"ls({term_name})"] = ( 2028 shift_layer.fc.weight.detach().cpu().squeeze().tolist() 2029 ) 2030 elif getattr(self, "debug", False): 2031 term_name = list(terms_dict.keys())[i] 2032 print(f"[DEBUG] ls({term_name}): missing 'fc' or 'weight' in LinearShift.") 2033 else: 2034 if getattr(self, "debug", False): 2035 print(f"[DEBUG] Tram model for node '{node}' has no nn_shift or it is None.") 2036 2037 all_linear_shift_dicts[state][node] = epoch_weights 2038 2039 # Remove empty states (e.g., when best/last not found for all nodes) 2040 all_linear_shift_dicts = {k: v for k, v in all_linear_shift_dicts.items() if v} 2041 2042 return all_linear_shift_dicts 2043 2044 def get_simple_intercepts_dict(self): 2045 """ 2046 Compute transformed simple intercept dictionaries for all nodes and states. 2047 2048 For each node and each available state ("best", "last", "init"), this 2049 method loads the corresponding model checkpoint, extracts simple intercept 2050 weights, transforms them into interpretable theta parameters, and stores 2051 them in a nested dictionary. 2052 2053 Returns 2054 ------- 2055 dict 2056 Nested dictionary of the form: 2057 2058 .. 
code-block:: python 2059 2060 { 2061 "best": {node: [[theta_1], [theta_2], ...]}, 2062 "last": {node: [[theta_1], [theta_2], ...]}, 2063 "init": {node: [[theta_1], [theta_2], ...]}, 2064 } 2065 2066 Notes 2067 ----- 2068 - For ordinal models (``self.is_ontram == True``), `transform_intercepts_ordinal` 2069 is used. 2070 - For continuous models, `transform_intercepts_continous` is used. 2071 - Empty outer states (without any nodes) are removed from the result. 2072 """ 2073 2074 EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"] 2075 nodes_list = list(self.models.keys()) 2076 all_states = ["best", "last", "init"] 2077 all_si_intercept_dicts = {state: {} for state in all_states} 2078 2079 debug = getattr(self, "debug", False) 2080 verbose = getattr(self, "verbose", False) 2081 is_ontram = getattr(self, "is_ontram", False) 2082 2083 for node in nodes_list: 2084 NODE_DIR = os.path.join(EXPERIMENT_DIR, node) 2085 BEST_MODEL_PATH, LAST_MODEL_PATH, _, _ = model_train_val_paths(NODE_DIR) 2086 INIT_MODEL_PATH = os.path.join(NODE_DIR, "initial_model.pt") 2087 2088 state_paths = { 2089 "best": BEST_MODEL_PATH, 2090 "last": LAST_MODEL_PATH, 2091 "init": INIT_MODEL_PATH, 2092 } 2093 2094 for state, LOAD_PATH in state_paths.items(): 2095 if not os.path.exists(LOAD_PATH): 2096 if state != "init": 2097 continue 2098 else: 2099 print(f"[WARNING] No models found for node '{node}'. Only initial model will be used.") 2100 if not os.path.exists(LOAD_PATH): 2101 if debug: 2102 print(f"[DEBUG] Initial model also missing for node '{node}'. 
Skipping.") 2103 continue 2104 2105 # Load model state 2106 state_dict = torch.load(LOAD_PATH, map_location=self.device) 2107 tram_model = self.models[node] 2108 tram_model.load_state_dict(state_dict) 2109 2110 # Extract and transform simple intercept weights 2111 si_weights = None 2112 if hasattr(tram_model, "nn_int") and tram_model.nn_int is not None and isinstance(tram_model.nn_int, SimpleIntercept): 2113 if hasattr(tram_model.nn_int, "fc") and hasattr(tram_model.nn_int.fc, "weight"): 2114 weights = tram_model.nn_int.fc.weight.detach().cpu().tolist() 2115 weights_tensor = torch.Tensor(weights) 2116 2117 if debug: 2118 print(f"[DEBUG] Node '{node}' ({state}) theta tilde shape: {weights_tensor.shape}") 2119 2120 if is_ontram: 2121 si_weights = transform_intercepts_ordinal(weights_tensor.reshape(1, -1))[:, 1:-1].reshape(-1, 1) 2122 else: 2123 si_weights = transform_intercepts_continous(weights_tensor.reshape(1, -1)).reshape(-1, 1) 2124 2125 si_weights = si_weights.tolist() 2126 2127 if debug: 2128 print(f"[DEBUG] Node '{node}' ({state}) theta transformed: {si_weights}") 2129 else: 2130 if debug: 2131 print(f"[DEBUG] Node '{node}' ({state}): missing 'fc' or 'weight' in SimpleIntercept.") 2132 else: 2133 if debug: 2134 print(f"[DEBUG] Tram model for node '{node}' has no nn_int or it is None.") 2135 2136 all_si_intercept_dicts[state][node] = si_weights 2137 2138 # Clean up empty states 2139 all_si_intercept_dicts = {k: v for k, v in all_si_intercept_dicts.items() if v} 2140 return all_si_intercept_dicts 2141 2142 def summary(self, verbose=False): 2143 """ 2144 Print a multi-part textual summary of the TramDagModel. 2145 2146 The summary includes: 2147 1. Training metrics overview per node (best/last NLL, epochs). 2148 2. Node-specific details (thetas, linear shifts, optional architecture). 2149 3. Basic information about the attached training DataFrame, if present. 
2150 2151 Parameters 2152 ---------- 2153 verbose : bool, optional 2154 If True, include extended per-node details such as the model 2155 architecture, parameter count, and availability of checkpoints 2156 and sampling results. Default is False. 2157 2158 Returns 2159 ------- 2160 None 2161 2162 Notes 2163 ----- 2164 This method prints to stdout and does not return structured data. 2165 It is intended for quick, human-readable inspection of the current 2166 training and model state. 2167 """ 2168 2169 # ---------- SETUP ---------- 2170 try: 2171 EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"] 2172 except KeyError: 2173 EXPERIMENT_DIR = None 2174 print("[WARNING] Missing EXPERIMENT_DIR in cfg.conf_dict['PATHS'].") 2175 2176 print("\n" + "=" * 120) 2177 print(f"{'TRAM DAG MODEL SUMMARY':^120}") 2178 print("=" * 120) 2179 2180 # ---------- METRICS OVERVIEW ---------- 2181 summary_data = [] 2182 for node in self.models.keys(): 2183 node_dir = os.path.join(self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"], node) 2184 train_path = os.path.join(node_dir, "train_loss_hist.json") 2185 val_path = os.path.join(node_dir, "val_loss_hist.json") 2186 2187 if os.path.exists(train_path) and os.path.exists(val_path): 2188 best_train_nll, best_val_nll = self.get_train_val_nll(node, "best") 2189 last_train_nll, last_val_nll = self.get_train_val_nll(node, "last") 2190 n_epochs_total = len(json.load(open(train_path))) 2191 else: 2192 best_train_nll = best_val_nll = last_train_nll = last_val_nll = None 2193 n_epochs_total = 0 2194 2195 summary_data.append({ 2196 "Node": node, 2197 "Best Train NLL": best_train_nll, 2198 "Best Val NLL": best_val_nll, 2199 "Last Train NLL": last_train_nll, 2200 "Last Val NLL": last_val_nll, 2201 "Epochs": n_epochs_total, 2202 }) 2203 2204 df_summary = pd.DataFrame(summary_data) 2205 df_summary = df_summary.round(4) 2206 2207 print("\n[1] TRAINING METRICS OVERVIEW") 2208 print("-" * 120) 2209 if not df_summary.empty: 2210 print( 2211 
df_summary.to_string( 2212 index=False, 2213 justify="center", 2214 col_space=14, 2215 float_format=lambda x: f"{x:7.4f}", 2216 ) 2217 ) 2218 else: 2219 print("No training history found for any node.") 2220 print("-" * 120) 2221 2222 # ---------- NODE DETAILS ---------- 2223 print("\n[2] NODE-SPECIFIC DETAILS") 2224 print("-" * 120) 2225 for node in self.models.keys(): 2226 print(f"\n{f'NODE: {node}':^120}") 2227 print("-" * 120) 2228 2229 # THETAS & SHIFTS 2230 for state in ["init", "last", "best"]: 2231 print(f"\n [{state.upper()} STATE]") 2232 2233 # ---- Thetas ---- 2234 try: 2235 thetas = getattr(self, "get_thetas", lambda n, s=None: None)(node, state) 2236 if thetas is not None: 2237 if isinstance(thetas, (list, np.ndarray, pd.Series)): 2238 thetas_flat = np.array(thetas).flatten() 2239 compact = np.round(thetas_flat, 4) 2240 arr_str = np.array2string( 2241 compact, 2242 max_line_width=110, 2243 threshold=np.inf, 2244 separator=", " 2245 ) 2246 lines = arr_str.split("\n") 2247 if len(lines) > 2: 2248 arr_str = "\n".join(lines[:2]) + " ..." 
2249 print(f" Θ ({len(thetas_flat)}): {arr_str}") 2250 elif isinstance(thetas, dict): 2251 for k, v in thetas.items(): 2252 print(f" Θ[{k}]: {v}") 2253 else: 2254 print(f" Θ: {thetas}") 2255 else: 2256 print(" Θ: not available") 2257 except Exception as e: 2258 print(f" [Error loading thetas] {e}") 2259 2260 # ---- Linear Shifts ---- 2261 try: 2262 linear_shifts = getattr(self, "get_linear_shifts", lambda n, s=None: None)(node, state) 2263 if linear_shifts is not None: 2264 if isinstance(linear_shifts, dict): 2265 for k, v in linear_shifts.items(): 2266 print(f" {k}: {np.round(v, 4)}") 2267 elif isinstance(linear_shifts, (list, np.ndarray, pd.Series)): 2268 arr = np.round(linear_shifts, 4) 2269 print(f" Linear shifts ({len(arr)}): {arr}") 2270 else: 2271 print(f" Linear shifts: {linear_shifts}") 2272 else: 2273 print(" Linear shifts: not available") 2274 except Exception as e: 2275 print(f" [Error loading linear shifts] {e}") 2276 2277 # ---- Verbose info directly below node ---- 2278 if verbose: 2279 print("\n [DETAILS]") 2280 node_dir = os.path.join(EXPERIMENT_DIR, node) if EXPERIMENT_DIR else None 2281 model = self.models[node] 2282 2283 print(f" Model Architecture:") 2284 arch_str = str(model).split("\n") 2285 for line in arch_str: 2286 print(f" {line}") 2287 print(f" Parameter count: {sum(p.numel() for p in model.parameters()):,}") 2288 2289 if node_dir and os.path.exists(node_dir): 2290 ckpt_exists = any(f.endswith(('.pt', '.pth')) for f in os.listdir(node_dir)) 2291 print(f" Checkpoints found: {ckpt_exists}") 2292 2293 sampling_dir = os.path.join(node_dir, "sampling") 2294 sampling_exists = os.path.isdir(sampling_dir) and len(os.listdir(sampling_dir)) > 0 2295 print(f" Sampling results found: {sampling_exists}") 2296 2297 for label, filename in [("Train", "train_loss_hist.json"), ("Validation", "val_loss_hist.json")]: 2298 path = os.path.join(node_dir, filename) 2299 if os.path.exists(path): 2300 try: 2301 with open(path, "r") as f: 2302 hist = json.load(f) 
2303 print(f" {label} history: {len(hist)} epochs") 2304 except Exception as e: 2305 print(f" {label} history: failed to load ({e})") 2306 else: 2307 print(f" {label} history: not found") 2308 else: 2309 print(" [INFO] No experiment directory defined or missing for this node.") 2310 print("-" * 120) 2311 2312 # ---------- TRAINING DATAFRAME ---------- 2313 print("\n[3] TRAINING DATAFRAME") 2314 print("-" * 120) 2315 try: 2316 self.train_df.info() 2317 except AttributeError: 2318 print("No training DataFrame attached to this TramDagModel.") 2319 print("=" * 120 + "\n")
Probabilistic DAG model built from node-wise TRAMs (transformation models).
This class manages:
- Configuration and per-node model construction.
- Data scaling (min–max).
- Training (sequential or per-node parallel on CPU).
- Diagnostics (loss history, intercepts, linear shifts, latents).
- Sampling from the joint DAG and loading stored samples.
- High-level summaries and plotting utilities.
111 def __init__(self): 112 """ 113 Initialize an empty TramDagModel shell. 114 115 Notes 116 ----- 117 This constructor does not build any node models and does not attach a 118 configuration. Use `TramDagModel.from_config` or `TramDagModel.from_directory` 119 to obtain a fully configured and ready-to-use instance. 120 """ 121 122 self.debug = False 123 self.verbose = False 124 self.device = 'auto' 125 pass
Initialize an empty TramDagModel shell.
Notes
This constructor does not build any node models and does not attach a
configuration. Use TramDagModel.from_config or TramDagModel.from_directory
to obtain a fully configured and ready-to-use instance.
127 @staticmethod 128 def get_device(settings): 129 """ 130 Resolve the target device string from a settings dictionary. 131 132 Parameters 133 ---------- 134 settings : dict 135 Dictionary containing at least a key ``"device"`` with one of 136 {"auto", "cpu", "cuda"}. If missing, "auto" is assumed. 137 138 Returns 139 ------- 140 str 141 Device string, either "cpu" or "cuda". 142 143 Notes 144 ----- 145 If ``device == "auto"``, CUDA is selected if available, otherwise CPU. 146 """ 147 device_arg = settings.get("device", "auto") 148 if device_arg == "auto": 149 device_str = "cuda" if torch.cuda.is_available() else "cpu" 150 else: 151 device_str = device_arg 152 return device_str
Resolve the target device string from a settings dictionary.
Parameters
settings : dict
Dictionary containing at least a key "device" with one of
{"auto", "cpu", "cuda"}. If missing, "auto" is assumed.
Returns
str Device string, either "cpu" or "cuda".
Notes
If device == "auto", CUDA is selected if available, otherwise CPU.
    @classmethod
    def from_config(cls, cfg, **kwargs):
        """
        Construct a TramDagModel from a TramDagConfig object.

        This builds one TRAM model per node in the DAG and optionally writes
        the initial model parameters to disk.

        Parameters
        ----------
        cfg : TramDagConfig
            Configuration wrapper holding the underlying configuration dictionary,
            including at least:
            - ``conf_dict["nodes"]``: mapping of node names to node configs.
            - ``conf_dict["PATHS"]["EXPERIMENT_DIR"]``: experiment directory.
        **kwargs
            Node-level construction options. Each key must be present in
            ``DEFAULTS_CONFIG``. Values can be:
            - scalar: applied to all nodes.
            - dict: mapping ``{node_name: value}`` for per-node overrides.

            Common keys include:
            device : {"auto", "cpu", "cuda"}, default "auto"
                Device selection (CUDA if available when "auto").
            debug : bool, default False
                If True, print debug messages.
            verbose : bool, default False
                If True, print informational messages.
            set_initial_weights : bool
                Passed to underlying TRAM model constructors.
            overwrite_initial_weights : bool, default True
                If True, overwrite any existing ``initial_model.pt`` files per node.
            initial_data : Any
                Optional object passed down to node constructors.

        Returns
        -------
        TramDagModel
            Fully initialized instance with:
            - ``cfg``
            - ``nodes_dict``
            - ``models`` (per-node TRAMs)
            - ``settings`` (resolved per-node config)

        Raises
        ------
        ValueError
            If any dict-valued kwarg does not provide values for exactly the set
            of nodes in ``cfg.conf_dict["nodes"]``.
        """

        self = cls()
        self.cfg = cfg
        self.cfg.update()  # ensure latest version from disk
        self.cfg._verify_completeness()


        try:
            self.cfg.save()  # persist back to disk
            if getattr(self, "debug", False):
                print("[DEBUG] Configuration updated and saved.")
        except Exception as e:
            # Best-effort persistence: a failed save is reported but not fatal.
            print(f"[WARNING] Could not save configuration after update: {e}")

        self.nodes_dict = self.cfg.conf_dict["nodes"]

        self._validate_kwargs(kwargs, defaults_attr='DEFAULTS_CONFIG', context="from_config")

        # update defaults with kwargs
        settings = dict(cls.DEFAULTS_CONFIG)
        settings.update(kwargs)

        # resolve device
        device_arg = settings.get("device", "auto")
        if device_arg == "auto":
            device_str = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            device_str = device_arg
        self.device = torch.device(device_str)

        # set flags on the instance so they are accessible later
        self.debug = settings.get("debug", False)
        self.verbose = settings.get("verbose", False)

        if self.debug:
            print(f"[DEBUG] TramDagModel using device: {self.device}")

        # initialize settings storage: one {node: value} map per settings key
        self.settings = {k: {} for k in settings.keys()}

        # validate dict-typed args: per-node overrides must cover every node exactly
        for k, v in settings.items():
            if isinstance(v, dict):
                expected = set(self.nodes_dict.keys())
                given = set(v.keys())
                if expected != given:
                    raise ValueError(
                        f"[ERROR] the provided argument '{k}' keys are not same as in cfg.conf_dict['nodes'].keys().\n"
                        f"Expected: {expected}, but got: {given}\n"
                        f"Please provide values for all variables.")

        # build one model per node
        self.models = {}
        for node in self.nodes_dict.keys():
            # Resolve per-node settings: dict values are indexed by node,
            # scalars are broadcast to every node.
            per_node_kwargs = {}
            for k, v in settings.items():
                resolved = v[node] if isinstance(v, dict) else v
                per_node_kwargs[k] = resolved
                self.settings[k][node] = resolved
            if self.debug:
                print(f"\n[INFO] Building model for node '{node}' with settings: {per_node_kwargs}")
            self.models[node] = get_fully_specified_tram_model(
                node=node,
                configuration_dict=self.cfg.conf_dict,
                **per_node_kwargs)

            # Persist the initial (untrained) weights so later runs can
            # reproduce the starting point.
            try:
                EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
                NODE_DIR = os.path.join(EXPERIMENT_DIR, f"{node}")
                os.makedirs(NODE_DIR, exist_ok=True)

                model_path = os.path.join(NODE_DIR, "initial_model.pt")
                overwrite = settings.get("overwrite_initial_weights", True)

                if overwrite or not os.path.exists(model_path):
                    torch.save(self.models[node].state_dict(), model_path)
                    if self.debug:
                        print(f"[DEBUG] Saved initial model state for node '{node}' to {model_path} (overwrite={overwrite})")
                else:
                    if self.debug:
                        print(f"[DEBUG] Skipped saving initial model for node '{node}' (already exists at {model_path})")
            except Exception as e:
                print(f"[ERROR] Could not save initial model state for node '{node}': {e}")

        # Clean up an empty "temp" working directory if model construction
        # left one behind — presumably created by a helper; TODO confirm.
        TEMP_DIR = "temp"
        if os.path.isdir(TEMP_DIR) and not os.listdir(TEMP_DIR):
            os.rmdir(TEMP_DIR)

        return self
Construct a TramDagModel from a TramDagConfig object.
This builds one TRAM model per node in the DAG and optionally writes the initial model parameters to disk.
Parameters
cfg : TramDagConfig
Configuration wrapper holding the underlying configuration dictionary,
including at least:
- conf_dict["nodes"]: mapping of node names to node configs.
- conf_dict["PATHS"]["EXPERIMENT_DIR"]: experiment directory.
**kwargs
Node-level construction options. Each key must be present in
DEFAULTS_CONFIG. Values can be:
- scalar: applied to all nodes.
- dict: mapping {node_name: value} for per-node overrides.
Common keys include:
device : {"auto", "cpu", "cuda"}, default "auto"
Device selection (CUDA if available when "auto").
debug : bool, default False
If True, print debug messages.
verbose : bool, default False
If True, print informational messages.
set_initial_weights : bool
Passed to underlying TRAM model constructors.
overwrite_initial_weights : bool, default True
If True, overwrite any existing ``initial_model.pt`` files per node.
initial_data : Any
Optional object passed down to node constructors.
Returns
TramDagModel
Fully initialized instance with:
- cfg
- nodes_dict
- models (per-node TRAMs)
- settings (resolved per-node config)
Raises
ValueError
If any dict-valued kwarg does not provide values for exactly the set
of nodes in cfg.conf_dict["nodes"].
325 @classmethod 326 def from_directory(cls, EXPERIMENT_DIR: str, device: str = "auto", debug: bool = False, verbose: bool = False): 327 """ 328 Reconstruct a TramDagModel from an experiment directory on disk. 329 330 This method: 331 1. Loads the configuration JSON. 332 2. Wraps it in a TramDagConfig. 333 3. Builds all node models via `from_config`. 334 4. Loads the min–max scaling dictionary. 335 336 Parameters 337 ---------- 338 EXPERIMENT_DIR : str 339 Path to an experiment directory containing: 340 - ``configuration.json`` 341 - ``min_max_scaling.json``. 342 device : {"auto", "cpu", "cuda"}, optional 343 Device selection. Default is "auto". 344 debug : bool, optional 345 If True, enable debug messages. Default is False. 346 verbose : bool, optional 347 If True, enable informational messages. Default is False. 348 349 Returns 350 ------- 351 TramDagModel 352 A TramDagModel instance with models, config, and scaling loaded. 353 354 Raises 355 ------ 356 FileNotFoundError 357 If configuration or min–max files cannot be found. 358 RuntimeError 359 If the min–max file cannot be read or parsed. 
360 """ 361 362 # --- load config file --- 363 config_path = os.path.join(EXPERIMENT_DIR, "configuration.json") 364 if not os.path.exists(config_path): 365 raise FileNotFoundError(f"[ERROR] Config file not found at {config_path}") 366 367 with open(config_path, "r") as f: 368 cfg_dict = json.load(f) 369 370 # Create TramConfig wrapper 371 cfg = TramDagConfig(cfg_dict, CONF_DICT_PATH=config_path) 372 373 # --- build model from config --- 374 self = cls.from_config(cfg, device=device, debug=debug, verbose=verbose, overwrite_initial_weights=False) 375 376 # --- load minmax scaling --- 377 minmax_path = os.path.join(EXPERIMENT_DIR, "min_max_scaling.json") 378 if not os.path.exists(minmax_path): 379 raise FileNotFoundError(f"[ERROR] MinMax file not found at {minmax_path}") 380 381 with open(minmax_path, "r") as f: 382 self.minmax_dict = json.load(f) 383 384 if self.verbose or self.debug: 385 print(f"[INFO] Loaded TramDagModel from {EXPERIMENT_DIR}") 386 print(f"[INFO] Config loaded from {config_path}") 387 print(f"[INFO] MinMax scaling loaded from {minmax_path}") 388 389 return self
Reconstruct a TramDagModel from an experiment directory on disk.
This method:
- Loads the configuration JSON.
- Wraps it in a TramDagConfig.
- Builds all node models via
`from_config`.
- Loads the min–max scaling dictionary.
Parameters
EXPERIMENT_DIR : str
Path to an experiment directory containing:
- configuration.json
- min_max_scaling.json.
device : {"auto", "cpu", "cuda"}, optional
Device selection. Default is "auto".
debug : bool, optional
If True, enable debug messages. Default is False.
verbose : bool, optional
If True, enable informational messages. Default is False.
Returns
TramDagModel A TramDagModel instance with models, config, and scaling loaded.
Raises
FileNotFoundError If configuration or min–max files cannot be found. RuntimeError If the min–max file cannot be read or parsed.
429 def load_or_compute_minmax(self, td_train_data=None,use_existing=False, write=True): 430 """ 431 Load an existing Min–Max scaling dictionary from disk or compute a new one 432 from the provided training dataset. 433 434 Parameters 435 ---------- 436 use_existing : bool, optional (default=False) 437 If True, attempts to load an existing `min_max_scaling.json` file 438 from the experiment directory. Raises an error if the file is missing 439 or unreadable. 440 441 write : bool, optional (default=True) 442 If True, writes the computed Min–Max scaling dictionary to 443 `<EXPERIMENT_DIR>/min_max_scaling.json`. 444 445 td_train_data : object, optional 446 Training dataset used to compute scaling statistics. If not provided, 447 the method will ensure or construct it via `_ensure_dataset(data=..., is_val=False)`. 448 449 Behavior 450 -------- 451 - If `use_existing=True`, loads the JSON file containing previously saved 452 min–max values and stores it in `self.minmax_dict`. 453 - If `use_existing=False`, computes a new scaling dictionary using 454 `td_train_data.compute_scaling()` and stores the result in 455 `self.minmax_dict`. 456 - Optionally writes the computed dictionary to disk. 457 458 Side Effects 459 ------------- 460 - Populates `self.minmax_dict` with scaling values. 461 - Writes or loads the file `min_max_scaling.json` under 462 `<EXPERIMENT_DIR>`. 463 - Prints diagnostic output if `self.debug` or `self.verbose` is True. 464 465 Raises 466 ------ 467 FileNotFoundError 468 If `use_existing=True` but the min–max file does not exist. 469 470 RuntimeError 471 If an existing min–max file cannot be read or parsed. 472 473 Notes 474 ----- 475 The computed min–max dictionary is expected to contain scaling statistics 476 per feature, typically in the form: 477 { 478 "node": {"min": float, "max": float}, 479 ... 
480 } 481 """ 482 EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"] 483 minmax_path = os.path.join(EXPERIMENT_DIR, "min_max_scaling.json") 484 485 # laod exisitng if possible 486 if use_existing: 487 if not os.path.exists(minmax_path): 488 raise FileNotFoundError(f"MinMax file not found: {minmax_path}") 489 try: 490 with open(minmax_path, 'r') as f: 491 self.minmax_dict = json.load(f) 492 if self.debug or self.verbose: 493 print(f"[INFO] Loaded existing minmax dict from {minmax_path}") 494 return 495 except Exception as e: 496 raise RuntimeError(f"Could not load existing minmax dict: {e}") 497 498 # 499 if self.debug or self.verbose: 500 print("[INFO] Computing new minmax dict from training data...") 501 502 td_train_data=self._ensure_dataset( data=td_train_data, is_val=False) 503 504 self.minmax_dict = td_train_data.compute_scaling() 505 506 if write: 507 os.makedirs(EXPERIMENT_DIR, exist_ok=True) 508 with open(minmax_path, 'w') as f: 509 json.dump(self.minmax_dict, f, indent=4) 510 if self.debug or self.verbose: 511 print(f"[INFO] Saved new minmax dict to {minmax_path}")
Load an existing Min–Max scaling dictionary from disk or compute a new one from the provided training dataset.
Parameters
use_existing : bool, optional (default=False)
If True, attempts to load an existing min_max_scaling.json file
from the experiment directory. Raises an error if the file is missing
or unreadable.
write : bool, optional (default=True)
If True, writes the computed Min–Max scaling dictionary to
<EXPERIMENT_DIR>/min_max_scaling.json.
td_train_data : object, optional
Training dataset used to compute scaling statistics. If not provided,
the method will ensure or construct it via _ensure_dataset(data=..., is_val=False).
Behavior
- If `use_existing=True`, loads the JSON file containing previously saved min–max values and stores it in `self.minmax_dict`.
- If `use_existing=False`, computes a new scaling dictionary using `td_train_data.compute_scaling()` and stores the result in `self.minmax_dict`.
- Optionally writes the computed dictionary to disk.
Side Effects
- Populates `self.minmax_dict` with scaling values.
- Writes or loads the file `min_max_scaling.json` under `<EXPERIMENT_DIR>`.
- Prints diagnostic output if `self.debug` or `self.verbose` is True.
Raises
FileNotFoundError
If use_existing=True but the min–max file does not exist.
RuntimeError If an existing min–max file cannot be read or parsed.
Notes
The computed min–max dictionary is expected to contain scaling statistics per feature, typically in the form: { "node": {"min": float, "max": float}, ... }
612 def fit(self, train_data, val_data=None, **kwargs): 613 """ 614 Train TRAM models for all nodes in the DAG. 615 616 Coordinates dataset preparation, min–max scaling, and per-node training, 617 optionally in parallel on CPU. 618 619 Parameters 620 ---------- 621 train_data : pandas.DataFrame or TramDagDataset 622 Training data. If a DataFrame is given, it is converted into a 623 TramDagDataset using `_ensure_dataset`. 624 val_data : pandas.DataFrame or TramDagDataset or None, optional 625 Validation data. If a DataFrame is given, it is converted into a 626 TramDagDataset. If None, no validation loss is computed. 627 **kwargs 628 Overrides for ``DEFAULTS_FIT``. All keys must exist in 629 ``DEFAULTS_FIT``. Common options: 630 631 epochs : int, default 100 632 Number of training epochs per node. 633 learning_rate : float, default 0.01 634 Learning rate for the default Adam optimizer. 635 train_list : list of str or None, optional 636 List of node names to train. If None, all nodes are trained. 637 train_mode : {"sequential", "parallel"}, default "sequential" 638 Training mode. "parallel" uses joblib-based CPU multiprocessing. 639 GPU forces sequential mode. 640 device : {"auto", "cpu", "cuda"}, default "auto" 641 Device selection. 642 optimizers : dict or None 643 Optional mapping ``{node_name: optimizer}``. If provided for a 644 node, that optimizer is used instead of creating a new Adam. 645 schedulers : dict or None 646 Optional mapping ``{node_name: scheduler}``. 647 use_scheduler : bool 648 If True, enable scheduler usage in the training loop. 649 num_workers : int 650 DataLoader workers in sequential mode (ignored in parallel). 651 persistent_workers : bool 652 DataLoader persistence in sequential mode (ignored in parallel). 653 prefetch_factor : int 654 DataLoader prefetch factor (ignored in parallel). 655 batch_size : int 656 Batch size for all node DataLoaders. 657 debug : bool 658 Enable debug output. 659 verbose : bool 660 Enable informational logging. 
661 return_history : bool 662 If True, return a history dict. 663 664 Returns 665 ------- 666 dict or None 667 If ``return_history=True``, a dictionary mapping each node name 668 to its training history. Otherwise, returns None. 669 670 Raises 671 ------ 672 ValueError 673 If ``train_mode`` is not "sequential" or "parallel". 674 """ 675 self._validate_kwargs(kwargs, defaults_attr='DEFAULTS_FIT', context="fit") 676 677 # --- merge defaults --- 678 settings = dict(self.DEFAULTS_FIT) 679 settings.update(kwargs) 680 681 682 self.debug = settings.get("debug", False) 683 self.verbose = settings.get("verbose", False) 684 685 # --- resolve device --- 686 device_str=self.get_device(settings) 687 self.device = torch.device(device_str) 688 689 # --- training mode --- 690 train_mode = settings.get("train_mode", "sequential").lower() 691 if train_mode not in ("sequential", "parallel"): 692 raise ValueError("train_mode must be 'sequential' or 'parallel'") 693 694 # --- DataLoader safety logic --- 695 if train_mode == "parallel": 696 # if user passed loader paralleling params, warn and override 697 for flag in ("num_workers", "persistent_workers", "prefetch_factor"): 698 if flag in kwargs: 699 print(f"[WARNING] '{flag}' is ignored in parallel mode " 700 f"(disabled to prevent nested multiprocessing).") 701 # disable unsafe loader multiprocessing options 702 settings["num_workers"] = 0 703 settings["persistent_workers"] = False 704 settings["prefetch_factor"] = None 705 else: 706 # sequential mode → respect user DataLoader settings 707 if self.debug: 708 print("[DEBUG] Sequential mode: using DataLoader kwargs as provided.") 709 710 # --- which nodes to train --- 711 train_list = settings.get("train_list") or list(self.models.keys()) 712 713 714 # --- dataset prep (receives adjusted settings) --- 715 td_train_data = self._ensure_dataset(train_data, is_val=False, **settings) 716 td_val_data = self._ensure_dataset(val_data, is_val=True, **settings) 717 718 # --- normalization --- 719 
self.load_or_compute_minmax(use_existing=False, write=True, td_train_data=td_train_data) 720 721 # --- print header --- 722 if self.verbose or self.debug: 723 print(f"[INFO] Training {len(train_list)} nodes ({train_mode}) on {device_str}") 724 725 # ====================================================================== 726 # Sequential mode safe for GPU or debugging) 727 # ====================================================================== 728 if train_mode == "sequential" or "cuda" in device_str: 729 if "cuda" in device_str and train_mode == "parallel": 730 print("[WARNING] GPU device detected — forcing sequential mode.") 731 results = {} 732 for node in train_list: 733 node, history = self._fit_single_node( 734 node, self, settings, td_train_data, td_val_data, device_str 735 ) 736 results[node] = history 737 738 739 # ====================================================================== 740 # parallel mode (CPU only) 741 # ====================================================================== 742 if train_mode == "parallel": 743 744 n_jobs = min(len(train_list), os.cpu_count() // 2 or 1) 745 if self.verbose or self.debug: 746 print(f"[INFO] Using {n_jobs} CPU workers for parallel node training") 747 parallel_outputs = Parallel( 748 n_jobs=n_jobs, 749 backend="loky",#loky, multiprocessing 750 verbose=10, 751 prefer="processes" 752 )(delayed(self._fit_single_node)(node, self, settings, td_train_data, td_val_data, device_str) for node in train_list ) 753 754 results = {node: hist for node, hist in parallel_outputs} 755 756 if settings.get("return_history", False): 757 return results
Train TRAM models for all nodes in the DAG.
Coordinates dataset preparation, min–max scaling, and per-node training, optionally in parallel on CPU.
Parameters
train_data : pandas.DataFrame or TramDagDataset
Training data. If a DataFrame is given, it is converted into a
TramDagDataset using _ensure_dataset.
val_data : pandas.DataFrame or TramDagDataset or None, optional
Validation data. If a DataFrame is given, it is converted into a
TramDagDataset. If None, no validation loss is computed.
**kwargs
Overrides for DEFAULTS_FIT. All keys must exist in
DEFAULTS_FIT. Common options:
epochs : int, default 100
Number of training epochs per node.
learning_rate : float, default 0.01
Learning rate for the default Adam optimizer.
train_list : list of str or None, optional
List of node names to train. If None, all nodes are trained.
train_mode : {"sequential", "parallel"}, default "sequential"
Training mode. "parallel" uses joblib-based CPU multiprocessing.
GPU forces sequential mode.
device : {"auto", "cpu", "cuda"}, default "auto"
Device selection.
optimizers : dict or None
Optional mapping ``{node_name: optimizer}``. If provided for a
node, that optimizer is used instead of creating a new Adam.
schedulers : dict or None
Optional mapping ``{node_name: scheduler}``.
use_scheduler : bool
If True, enable scheduler usage in the training loop.
num_workers : int
DataLoader workers in sequential mode (ignored in parallel).
persistent_workers : bool
DataLoader persistence in sequential mode (ignored in parallel).
prefetch_factor : int
DataLoader prefetch factor (ignored in parallel).
batch_size : int
Batch size for all node DataLoaders.
debug : bool
Enable debug output.
verbose : bool
Enable informational logging.
return_history : bool
If True, return a history dict.
Returns
dict or None
If return_history=True, a dictionary mapping each node name
to its training history. Otherwise, returns None.
Raises
ValueError
If train_mode is not "sequential" or "parallel".
760 def loss_history(self): 761 """ 762 Load training and validation loss history for all nodes. 763 764 Looks for per-node JSON files: 765 766 - ``EXPERIMENT_DIR/{node}/train_loss_hist.json`` 767 - ``EXPERIMENT_DIR/{node}/val_loss_hist.json`` 768 769 Returns 770 ------- 771 dict 772 A dictionary mapping node names to: 773 774 .. code-block:: python 775 776 { 777 "train": list or None, 778 "validation": list or None 779 } 780 781 where each list contains NLL values per epoch, or None if not found. 782 783 Raises 784 ------ 785 ValueError 786 If the experiment directory cannot be resolved from the configuration. 787 """ 788 try: 789 EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"] 790 except KeyError: 791 raise ValueError( 792 "[ERROR] Missing 'EXPERIMENT_DIR' in cfg.conf_dict['PATHS']. " 793 "History retrieval requires experiment logs." 794 ) 795 796 all_histories = {} 797 for node in self.nodes_dict.keys(): 798 node_dir = os.path.join(EXPERIMENT_DIR, node) 799 train_path = os.path.join(node_dir, "train_loss_hist.json") 800 val_path = os.path.join(node_dir, "val_loss_hist.json") 801 802 node_hist = {} 803 804 # --- load train history --- 805 if os.path.exists(train_path): 806 try: 807 with open(train_path, "r") as f: 808 node_hist["train"] = json.load(f) 809 except Exception as e: 810 print(f"[WARNING] Could not load {train_path}: {e}") 811 node_hist["train"] = None 812 else: 813 node_hist["train"] = None 814 815 # --- load val history --- 816 if os.path.exists(val_path): 817 try: 818 with open(val_path, "r") as f: 819 node_hist["validation"] = json.load(f) 820 except Exception as e: 821 print(f"[WARNING] Could not load {val_path}: {e}") 822 node_hist["validation"] = None 823 else: 824 node_hist["validation"] = None 825 826 all_histories[node] = node_hist 827 828 if self.verbose or self.debug: 829 print(f"[INFO] Loaded training/validation histories for {len(all_histories)} nodes.") 830 831 return all_histories
Load training and validation loss history for all nodes.
Looks for per-node JSON files:
- `EXPERIMENT_DIR/{node}/train_loss_hist.json`
- `EXPERIMENT_DIR/{node}/val_loss_hist.json`
Returns
dict A dictionary mapping node names to:
```python
{ "train": list or None, "validation": list or None } ```
where each list contains NLL values per epoch, or None if not found.
Raises
ValueError If the experiment directory cannot be resolved from the configuration.
833 def linear_shift_history(self): 834 """ 835 Load linear shift term histories for all nodes. 836 837 Each node history is expected in a JSON file named 838 ``linear_shifts_all_epochs.json`` under the node directory. 839 840 Returns 841 ------- 842 dict 843 A mapping ``{node_name: pandas.DataFrame}``, where each DataFrame 844 contains linear shift weights across epochs. 845 846 Raises 847 ------ 848 ValueError 849 If the experiment directory cannot be resolved from the configuration. 850 851 Notes 852 ----- 853 If a history file is missing for a node, a warning is printed and the 854 node is omitted from the returned dictionary. 855 """ 856 histories = {} 857 try: 858 EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"] 859 except KeyError: 860 raise ValueError( 861 "[ERROR] Missing 'EXPERIMENT_DIR' in cfg.conf_dict['PATHS']. " 862 "Cannot load histories without experiment directory." 863 ) 864 865 for node in self.nodes_dict.keys(): 866 node_dir = os.path.join(EXPERIMENT_DIR, node) 867 history_path = os.path.join(node_dir, "linear_shifts_all_epochs.json") 868 if os.path.exists(history_path): 869 histories[node] = pd.read_json(history_path) 870 else: 871 print(f"[WARNING] No linear shift history found for node '{node}' at {history_path}") 872 return histories
Load linear shift term histories for all nodes.
Each node history is expected in a JSON file named
linear_shifts_all_epochs.json under the node directory.
Returns
dict
A mapping {node_name: pandas.DataFrame}, where each DataFrame
contains linear shift weights across epochs.
Raises
ValueError If the experiment directory cannot be resolved from the configuration.
Notes
If a history file is missing for a node, a warning is printed and the node is omitted from the returned dictionary.
874 def simple_intercept_history(self): 875 """ 876 Load simple intercept histories for all nodes. 877 878 Each node history is expected in a JSON file named 879 ``simple_intercepts_all_epochs.json`` under the node directory. 880 881 Returns 882 ------- 883 dict 884 A mapping ``{node_name: pandas.DataFrame}``, where each DataFrame 885 contains intercept weights across epochs. 886 887 Raises 888 ------ 889 ValueError 890 If the experiment directory cannot be resolved from the configuration. 891 892 Notes 893 ----- 894 If a history file is missing for a node, a warning is printed and the 895 node is omitted from the returned dictionary. 896 """ 897 histories = {} 898 try: 899 EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"] 900 except KeyError: 901 raise ValueError( 902 "[ERROR] Missing 'EXPERIMENT_DIR' in cfg.conf_dict['PATHS']. " 903 "Cannot load histories without experiment directory." 904 ) 905 906 for node in self.nodes_dict.keys(): 907 node_dir = os.path.join(EXPERIMENT_DIR, node) 908 history_path = os.path.join(node_dir, "simple_intercepts_all_epochs.json") 909 if os.path.exists(history_path): 910 histories[node] = pd.read_json(history_path) 911 else: 912 print(f"[WARNING] No simple intercept history found for node '{node}' at {history_path}") 913 return histories
Load simple intercept histories for all nodes.
Each node history is expected in a JSON file named
simple_intercepts_all_epochs.json under the node directory.
Returns
dict
A mapping {node_name: pandas.DataFrame}, where each DataFrame
contains intercept weights across epochs.
Raises
ValueError If the experiment directory cannot be resolved from the configuration.
Notes
If a history file is missing for a node, a warning is printed and the node is omitted from the returned dictionary.
915 def get_latent(self, df, verbose=False): 916 """ 917 Compute latent representations for all nodes in the DAG. 918 919 Parameters 920 ---------- 921 df : pandas.DataFrame 922 Input data frame with columns corresponding to nodes in the DAG. 923 verbose : bool, optional 924 If True, print informational messages during latent computation. 925 Default is False. 926 927 Returns 928 ------- 929 pandas.DataFrame 930 DataFrame containing the original columns plus latent variables 931 for each node (e.g. columns named ``f"{node}_U"``). 932 933 Raises 934 ------ 935 ValueError 936 If the experiment directory is missing from the configuration or 937 if ``self.minmax_dict`` has not been set. 938 """ 939 try: 940 EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"] 941 except KeyError: 942 raise ValueError( 943 "[ERROR] Missing 'EXPERIMENT_DIR' in cfg.conf_dict['PATHS']. " 944 "Latent extraction requires trained model checkpoints." 945 ) 946 947 # ensure minmax_dict is available 948 if not hasattr(self, "minmax_dict"): 949 raise ValueError( 950 "[ERROR] minmax_dict not found in the TramDagModel instance. " 951 "Either call .load_or_compute_minmax(td_train_data=train_df) or .fit() first." 952 ) 953 954 all_latents_df = create_latent_df_for_full_dag( 955 configuration_dict=self.cfg.conf_dict, 956 EXPERIMENT_DIR=EXPERIMENT_DIR, 957 df=df, 958 verbose=verbose, 959 min_max_dict=self.minmax_dict, 960 ) 961 962 return all_latents_df
Compute latent representations for all nodes in the DAG.
Parameters
df : pandas.DataFrame Input data frame with columns corresponding to nodes in the DAG. verbose : bool, optional If True, print informational messages during latent computation. Default is False.
Returns
pandas.DataFrame
DataFrame containing the original columns plus latent variables
for each node (e.g. columns named f"{node}_U").
Raises
ValueError
If the experiment directory is missing from the configuration or
if self.minmax_dict has not been set.
967 def plot_loss_history(self, variable: str = None): 968 """ 969 Plot training and validation loss evolution per node. 970 971 Parameters 972 ---------- 973 variable : str or None, optional 974 If provided, plot loss history for this node only. If None, plot 975 histories for all nodes that have both train and validation logs. 976 977 Returns 978 ------- 979 None 980 981 Notes 982 ----- 983 Two subplots are produced: 984 - Full epoch history. 985 - Last 10% of epochs (or only the last epoch if fewer than 5 epochs). 986 """ 987 988 histories = self.loss_history() 989 if not histories: 990 print("[WARNING] No loss histories found.") 991 return 992 993 # Select which nodes to plot 994 if variable is not None: 995 if variable not in histories: 996 raise ValueError(f"[ERROR] Node '{variable}' not found in histories.") 997 nodes_to_plot = [variable] 998 else: 999 nodes_to_plot = list(histories.keys()) 1000 1001 # Filter out nodes with no valid history 1002 nodes_to_plot = [ 1003 n for n in nodes_to_plot 1004 if histories[n].get("train") is not None and len(histories[n]["train"]) > 0 1005 and histories[n].get("validation") is not None and len(histories[n]["validation"]) > 0 1006 ] 1007 1008 if not nodes_to_plot: 1009 print("[WARNING] No valid histories found to plot.") 1010 return 1011 1012 plt.figure(figsize=(14, 12)) 1013 1014 # --- Full history (top plot) --- 1015 plt.subplot(2, 1, 1) 1016 for node in nodes_to_plot: 1017 node_hist = histories[node] 1018 train_hist, val_hist = node_hist["train"], node_hist["validation"] 1019 1020 epochs = range(1, len(train_hist) + 1) 1021 plt.plot(epochs, train_hist, label=f"{node} - train", linestyle="--") 1022 plt.plot(epochs, val_hist, label=f"{node} - val") 1023 1024 plt.title("Training and Validation NLL - Full History") 1025 plt.xlabel("Epoch") 1026 plt.ylabel("NLL") 1027 plt.legend() 1028 plt.grid(True) 1029 1030 # --- Last 10% of epochs (bottom plot) --- 1031 plt.subplot(2, 1, 2) 1032 for node in nodes_to_plot: 1033 node_hist 
= histories[node] 1034 train_hist, val_hist = node_hist["train"], node_hist["validation"] 1035 1036 total_epochs = len(train_hist) 1037 start_idx = total_epochs - 1 if total_epochs < 5 else int(total_epochs * 0.9) 1038 1039 epochs = range(start_idx + 1, total_epochs + 1) 1040 plt.plot(epochs, train_hist[start_idx:], label=f"{node} - train", linestyle="--") 1041 plt.plot(epochs, val_hist[start_idx:], label=f"{node} - val") 1042 1043 plt.title("Training and Validation NLL - Last 10% of Epochs (or Last Epoch if <5)") 1044 plt.xlabel("Epoch") 1045 plt.ylabel("NLL") 1046 plt.legend() 1047 plt.grid(True) 1048 1049 plt.tight_layout() 1050 plt.show()
Plot training and validation loss evolution per node.
Parameters
variable : str or None, optional If provided, plot loss history for this node only. If None, plot histories for all nodes that have both train and validation logs.
Returns
None
Notes
Two subplots are produced:
- Full epoch history.
- Last 10% of epochs (or only the last epoch if fewer than 5 epochs).
1052 def plot_linear_shift_history(self, data_dict=None, node=None, ref_lines=None): 1053 """ 1054 Plot the evolution of linear shift terms over epochs. 1055 1056 Parameters 1057 ---------- 1058 data_dict : dict or None, optional 1059 Pre-loaded mapping ``{node_name: pandas.DataFrame}`` containing shift 1060 weights across epochs. If None, `linear_shift_history()` is called. 1061 node : str or None, optional 1062 If provided, plot only this node. Otherwise, plot all nodes 1063 present in ``data_dict``. 1064 ref_lines : dict or None, optional 1065 Optional mapping ``{node_name: list of float}``. For each specified 1066 node, horizontal reference lines are drawn at the given values. 1067 1068 Returns 1069 ------- 1070 None 1071 1072 Notes 1073 ----- 1074 The function flattens nested list-like entries in the DataFrames to scalars, 1075 converts epoch labels to numeric, and then draws one line per shift term. 1076 """ 1077 1078 if data_dict is None: 1079 data_dict = self.linear_shift_history() 1080 if data_dict is None: 1081 raise ValueError("No shift history data provided or stored in the class.") 1082 1083 nodes = [node] if node else list(data_dict.keys()) 1084 1085 for n in nodes: 1086 df = data_dict[n].copy() 1087 1088 # Flatten nested lists or list-like cells 1089 def flatten(x): 1090 if isinstance(x, list): 1091 if len(x) == 0: 1092 return np.nan 1093 if all(isinstance(i, (int, float)) for i in x): 1094 return np.mean(x) # average simple list 1095 if all(isinstance(i, list) for i in x): 1096 # nested list -> flatten inner and average 1097 flat = [v for sub in x for v in (sub if isinstance(sub, list) else [sub])] 1098 return np.mean(flat) if flat else np.nan 1099 return x[0] if len(x) == 1 else np.nan 1100 return x 1101 1102 df = df.applymap(flatten) 1103 1104 # Ensure numeric columns 1105 df = df.apply(pd.to_numeric, errors='coerce') 1106 1107 # Convert epoch labels to numeric 1108 df.columns = [ 1109 int(c.replace("epoch_", "")) if isinstance(c, str) and 
c.startswith("epoch_") else c 1110 for c in df.columns 1111 ] 1112 df = df.reindex(sorted(df.columns), axis=1) 1113 1114 plt.figure(figsize=(10, 6)) 1115 for idx in df.index: 1116 plt.plot(df.columns, df.loc[idx], lw=1.4, label=f"shift_{idx}") 1117 1118 if ref_lines and n in ref_lines: 1119 for v in ref_lines[n]: 1120 plt.axhline(y=v, color="k", linestyle="--", lw=1.0) 1121 plt.text(df.columns[-1], v, f"{n}: {v}", va="bottom", ha="right", fontsize=8) 1122 1123 plt.xlabel("Epoch") 1124 plt.ylabel("Shift Value") 1125 plt.title(f"Shift Term History — Node: {n}") 1126 plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left", fontsize="small") 1127 plt.tight_layout() 1128 plt.show()
Plot the evolution of linear shift terms over epochs.
Parameters
data_dict : dict or None, optional
Pre-loaded mapping {node_name: pandas.DataFrame} containing shift
weights across epochs. If None, linear_shift_history() is called.
node : str or None, optional
If provided, plot only this node. Otherwise, plot all nodes
present in data_dict.
ref_lines : dict or None, optional
Optional mapping {node_name: list of float}. For each specified
node, horizontal reference lines are drawn at the given values.
Returns
None
Notes
The function flattens nested list-like entries in the DataFrames to scalars, converts epoch labels to numeric, and then draws one line per shift term.
def plot_simple_intercepts_history(self, data_dict=None, node=None, ref_lines=None):
    """
    Plot the evolution of simple intercept weights over epochs.

    Parameters
    ----------
    data_dict : dict or None, optional
        Pre-loaded mapping ``{node_name: pandas.DataFrame}`` containing intercept
        weights across epochs. If None, `simple_intercept_history()` is called.
    node : str or None, optional
        If provided, plot only this node. Otherwise, plot all nodes present
        in ``data_dict``.
    ref_lines : dict or None, optional
        Optional mapping ``{node_name: list of float}``. For each specified
        node, horizontal reference lines are drawn at the given values.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If no intercept history is provided and none is stored on the instance.

    Notes
    -----
    Nested list-like entries in the DataFrames are reduced to scalars before
    plotting. One line is drawn per intercept parameter.
    """
    if data_dict is None:
        data_dict = self.simple_intercept_history()
    if data_dict is None:
        raise ValueError("No intercept history data provided or stored in the class.")

    nodes = [node] if node else list(data_dict.keys())

    def _to_scalar(x):
        # Unwrap nested list/tuple wrappers (e.g. [[0.3]]) down to a scalar;
        # anything non-numeric becomes NaN so it is simply not plotted.
        while isinstance(x, (list, tuple)) and len(x) > 0:
            x = x[0]
        return float(x) if isinstance(x, (int, float, np.floating)) else np.nan

    for n in nodes:
        df = data_dict[n].copy()

        # DataFrame.applymap is deprecated since pandas 2.1 (renamed to
        # DataFrame.map in 3.0); per-column Series.map is equivalent and
        # works on all supported pandas versions.
        df = df.apply(lambda col: col.map(_to_scalar))

        # Convert "epoch_<k>" column labels to numeric epoch indices so the
        # x-axis sorts and plots numerically.
        df.columns = [
            int(c.replace("epoch_", "")) if isinstance(c, str) and c.startswith("epoch_") else c
            for c in df.columns
        ]
        df = df.reindex(sorted(df.columns), axis=1)

        plt.figure(figsize=(10, 6))
        for idx in df.index:
            plt.plot(df.columns, df.loc[idx], lw=1.4, label=f"theta_{idx}")

        if ref_lines and n in ref_lines:
            for v in ref_lines[n]:
                plt.axhline(y=v, color="k", linestyle="--", lw=1.0)
                plt.text(df.columns[-1], v,
                         f"{n}: {v}", va="bottom", ha="right", fontsize=8)

        plt.xlabel("Epoch")
        plt.ylabel("Intercept Weight")
        plt.title(f"Simple Intercept Evolution — Node: {n}")
        plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left", fontsize="small")
        plt.tight_layout()
        plt.show()
Plot the evolution of simple intercept weights over epochs.
Parameters
data_dict : dict or None, optional
Pre-loaded mapping {node_name: pandas.DataFrame} containing intercept
weights across epochs. If None, simple_intercept_history() is called.
node : str or None, optional
If provided, plot only this node. Otherwise, plot all nodes present
in data_dict.
ref_lines : dict or None, optional
Optional mapping {node_name: list of float}. For each specified
node, horizontal reference lines are drawn at the given values.
Returns
None
Notes
Nested list-like entries in the DataFrames are reduced to scalars before plotting. One line is drawn per intercept parameter.
1196 def plot_latents(self, df, variable: str = None, confidence: float = 0.95, simulations: int = 1000): 1197 """ 1198 Visualize latent U distributions for one or all nodes. 1199 1200 Parameters 1201 ---------- 1202 df : pandas.DataFrame 1203 Input data frame with raw node values. 1204 variable : str or None, optional 1205 If provided, only this node's latents are plotted. If None, all 1206 nodes with latent columns are processed. 1207 confidence : float, optional 1208 Confidence level for QQ-plot bands (0 < confidence < 1). 1209 Default is 0.95. 1210 simulations : int, optional 1211 Number of Monte Carlo simulations for QQ-plot bands. Default is 1000. 1212 1213 Returns 1214 ------- 1215 None 1216 1217 Notes 1218 ----- 1219 For each node, two plots are produced: 1220 - Histogram of the latent U values. 1221 - QQ-plot with simulation-based confidence bands under a logistic reference. 1222 """ 1223 # Compute latent representations 1224 latents_df = self.get_latent(df) 1225 1226 # Select nodes 1227 nodes = [variable] if variable is not None else self.nodes_dict.keys() 1228 1229 for node in nodes: 1230 if f"{node}_U" not in latents_df.columns: 1231 print(f"[WARNING] No latent found for node {node}, skipping.") 1232 continue 1233 1234 sample = latents_df[f"{node}_U"].values 1235 1236 # --- Create plots --- 1237 fig, axs = plt.subplots(1, 2, figsize=(12, 5)) 1238 1239 # Histogram 1240 axs[0].hist(sample, bins=50, color="steelblue", alpha=0.7) 1241 axs[0].set_title(f"Latent Histogram ({node})") 1242 axs[0].set_xlabel("U") 1243 axs[0].set_ylabel("Frequency") 1244 1245 # QQ Plot with confidence bands 1246 probplot(sample, dist="logistic", plot=axs[1]) 1247 self._add_r_style_confidence_bands(axs[1], sample, dist=logistic,confidence=confidence, simulations=simulations) 1248 axs[1].set_title(f"Latent QQ Plot ({node})") 1249 1250 plt.suptitle(f"Latent Diagnostics for Node: {node}", fontsize=14) 1251 plt.tight_layout(rect=[0, 0.03, 1, 0.95]) 1252 plt.show()
Visualize latent U distributions for one or all nodes.
Parameters
df : pandas.DataFrame Input data frame with raw node values. variable : str or None, optional If provided, only this node's latents are plotted. If None, all nodes with latent columns are processed. confidence : float, optional Confidence level for QQ-plot bands (0 < confidence < 1). Default is 0.95. simulations : int, optional Number of Monte Carlo simulations for QQ-plot bands. Default is 1000.
Returns
None
Notes
For each node, two plots are produced:
- Histogram of the latent U values.
- QQ-plot with simulation-based confidence bands under a logistic reference.
1254 def plot_hdag(self,df,variables=None, plot_n_rows=1,**kwargs): 1255 1256 """ 1257 Visualize the transformation function h() for selected DAG nodes. 1258 1259 Parameters 1260 ---------- 1261 df : pandas.DataFrame 1262 Input data containing node values or model predictions. 1263 variables : list of str or None, optional 1264 Names of nodes to visualize. If None, all nodes in ``self.models`` 1265 are considered. 1266 plot_n_rows : int, optional 1267 Maximum number of rows from ``df`` to visualize. Default is 1. 1268 **kwargs 1269 Additional keyword arguments forwarded to the underlying plotting 1270 helpers (`show_hdag_continous` / `show_hdag_ordinal`). 1271 1272 Returns 1273 ------- 1274 None 1275 1276 Notes 1277 ----- 1278 - For continuous outcomes, `show_hdag_continous` is called. 1279 - For ordinal outcomes, `show_hdag_ordinal` is called. 1280 - Nodes that are neither continuous nor ordinal are skipped with a warning. 1281 """ 1282 1283 1284 if len(df)> 1: 1285 print("[WARNING] len(df)>1, set: plot_n_rows accordingly") 1286 1287 variables_list=variables if variables is not None else list(self.models.keys()) 1288 for node in variables_list: 1289 if is_outcome_modelled_continous(node, self.nodes_dict): 1290 show_hdag_continous(df,node=node,configuration_dict=self.cfg.conf_dict,minmax_dict=self.minmax_dict,device=self.device,plot_n_rows=plot_n_rows,**kwargs) 1291 1292 elif is_outcome_modelled_ordinal(node, self.nodes_dict): 1293 show_hdag_ordinal(df,node=node,configuration_dict=self.cfg.conf_dict,device=self.device,plot_n_rows=plot_n_rows,**kwargs) 1294 # plot_cutpoints_with_logistic(df,node=node,configuration_dict=self.cfg.conf_dict,device=self.device,plot_n_rows=plot_n_rows,**kwargs) 1295 # save_cutpoints_with_logistic(df,node=node,configuration_dict=self.cfg.conf_dict,device=self.device,**kwargs) 1296 else: 1297 print(f"[WARNING] Node {node} is wheter ordinal nor continous, not implemented yet")
Visualize the transformation function h() for selected DAG nodes.
Parameters
df : pandas.DataFrame
Input data containing node values or model predictions.
variables : list of str or None, optional
Names of nodes to visualize. If None, all nodes in self.models
are considered.
plot_n_rows : int, optional
Maximum number of rows from df to visualize. Default is 1.
**kwargs
Additional keyword arguments forwarded to the underlying plotting
helpers (show_hdag_continous / show_hdag_ordinal).
Returns
None
Notes
- For continuous outcomes, `show_hdag_continous` is called.
- For ordinal outcomes, `show_hdag_ordinal` is called.
- Nodes that are neither continuous nor ordinal are skipped with a warning.
def sample(
    self,
    do_interventions: dict = None,
    predefined_latent_samples_df: pd.DataFrame = None,
    **kwargs,
):
    """
    Sample from the joint DAG using the trained TRAM models.

    Allows for:

    - Observational sampling
    - Interventional sampling via ``do()`` operations
    - Counterfactual sampling using predefined latent draws and ``do()``

    Parameters
    ----------
    do_interventions : dict or None, optional
        Mapping of node names to intervened (fixed) values. For example:
        ``{"x1": 1.0}`` represents ``do(x1 = 1.0)``. Default is None.
    predefined_latent_samples_df : pandas.DataFrame or None, optional
        DataFrame containing columns ``"{node}_U"`` with predefined latent
        draws to be used instead of sampling from the prior. Default is None.
    **kwargs
        Sampling options overriding internal defaults:

        number_of_samples : int, default 10000
            Total number of samples to draw.
        batch_size : int, default 32
            Batch size for internal sampling loops.
        delete_all_previously_sampled : bool, default True
            If True, delete old sampling files in node-specific sampling
            directories before writing new ones.
        verbose : bool
            If True, print informational messages.
        debug : bool
            If True, print debug output.
        device : {"auto", "cpu", "cuda"}
            Device selection for sampling.
        use_initial_weights_for_sampling : bool, default False
            If True, sample from initial (untrained) model parameters.

    Returns
    -------
    tuple
        A tuple ``(sampled_by_node, latents_by_node)``:

        sampled_by_node : dict
            Mapping ``{node_name: torch.Tensor}`` of sampled node values.
        latents_by_node : dict
            Mapping ``{node_name: torch.Tensor}`` of latent U values used.

    Raises
    ------
    ValueError
        If the experiment directory cannot be resolved.
    RuntimeError
        If min–max scaling has not been computed before calling `sample`.
    """
    try:
        EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
    except KeyError:
        raise ValueError(
            "[ERROR] Missing 'EXPERIMENT_DIR' in cfg.conf_dict['PATHS']. "
            "Sampling requires trained model checkpoints."
        )

    # ---- defaults (overridable via **kwargs) ----
    settings = {
        "number_of_samples": 10_000,
        "batch_size": 32,
        "delete_all_previously_sampled": True,
        "verbose": getattr(self, "verbose", False),
        "debug": getattr(self, "debug", False),
        "device": self.device.type if hasattr(self, "device") else "auto",
        "use_initial_weights_for_sampling": False,
    }
    settings.update(kwargs)

    if not hasattr(self, "minmax_dict"):
        raise RuntimeError(
            "[ERROR] minmax_dict not found. You must call .fit() or .load_or_compute_minmax() "
            "before sampling, so scaling info is available."
        )

    # ---- resolve device ----
    device_str = self.get_device(settings)
    self.device = torch.device(device_str)

    # Guarded attribute access: 'debug' may not be set on the instance
    # (the defaults above use the same guard). The original accessed
    # self.debug directly, which raised AttributeError when unset.
    if getattr(self, "debug", False) or settings["debug"]:
        print(f"[DEBUG] sample(): device: {self.device}")

    # ---- perform sampling ----
    sampled_by_node, latents_by_node = sample_full_dag(
        configuration_dict=self.cfg.conf_dict,
        EXPERIMENT_DIR=EXPERIMENT_DIR,
        device=self.device,
        do_interventions=do_interventions or {},
        predefined_latent_samples_df=predefined_latent_samples_df,
        number_of_samples=settings["number_of_samples"],
        batch_size=settings["batch_size"],
        delete_all_previously_sampled=settings["delete_all_previously_sampled"],
        verbose=settings["verbose"],
        debug=settings["debug"],
        minmax_dict=self.minmax_dict,
        use_initial_weights_for_sampling=settings["use_initial_weights_for_sampling"]
    )

    return sampled_by_node, latents_by_node
Sample from the joint DAG using the trained TRAM models.
Allows for:
Observational sampling
Interventional sampling via do() operations
Counterfactual sampling using predefined latent draws and do()
Parameters
do_interventions : dict or None, optional
Mapping of node names to intervened (fixed) values. For example:
{"x1": 1.0} represents do(x1 = 1.0). Default is None.
predefined_latent_samples_df : pandas.DataFrame or None, optional
DataFrame containing columns "{node}_U" with predefined latent
draws to be used instead of sampling from the prior. Default is None.
**kwargs
Sampling options overriding internal defaults:
number_of_samples : int, default 10000
Total number of samples to draw.
batch_size : int, default 32
Batch size for internal sampling loops.
delete_all_previously_sampled : bool, default True
If True, delete old sampling files in node-specific sampling
directories before writing new ones.
verbose : bool
If True, print informational messages.
debug : bool
If True, print debug output.
device : {"auto", "cpu", "cuda"}
Device selection for sampling.
use_initial_weights_for_sampling : bool, default False
If True, sample from initial (untrained) model parameters.
Returns
tuple
A tuple (sampled_by_node, latents_by_node):
sampled_by_node : dict
Mapping ``{node_name: torch.Tensor}`` of sampled node values.
latents_by_node : dict
Mapping ``{node_name: torch.Tensor}`` of latent U values used.
Raises
ValueError
If the experiment directory cannot be resolved or if scaling
information (self.minmax_dict) is missing.
RuntimeError
If min–max scaling has not been computed before calling sample.
def load_sampled_and_latents(self, EXPERIMENT_DIR: str = None, nodes: list = None):
    """
    Load previously stored sampled values and latents for each node.

    Parameters
    ----------
    EXPERIMENT_DIR : str or None, optional
        Experiment directory path. If None, it is taken from
        ``self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]``.
    nodes : list of str or None, optional
        Nodes for which to load samples. If None, use all nodes from
        ``self.nodes_dict``.

    Returns
    -------
    tuple
        A tuple ``(sampled_by_node, latents_by_node)``:

        sampled_by_node : dict
            Mapping ``{node_name: torch.Tensor}`` of sampled values (on CPU).
        latents_by_node : dict
            Mapping ``{node_name: torch.Tensor}`` of latent values (on CPU).

    Raises
    ------
    ValueError
        If the experiment directory cannot be resolved or if no node list
        is available and ``nodes`` is None.

    Notes
    -----
    Nodes without both ``sampled.pt`` and ``latents.pt`` files are skipped
    with a warning.
    """
    # --- resolve paths and node list ---
    if EXPERIMENT_DIR is None:
        try:
            EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
        except (AttributeError, KeyError):
            raise ValueError(
                "[ERROR] Could not resolve EXPERIMENT_DIR from cfg.conf_dict['PATHS']. "
                "Please provide EXPERIMENT_DIR explicitly."
            )

    if nodes is None:
        if hasattr(self, "nodes_dict"):
            nodes = list(self.nodes_dict.keys())
        else:
            raise ValueError(
                "[ERROR] No node list found. Please provide `nodes` or initialize model with a config."
            )

    # --- load tensors ---
    sampled_by_node = {}
    latents_by_node = {}

    for node in nodes:
        node_dir = os.path.join(EXPERIMENT_DIR, f"{node}")
        sampling_dir = os.path.join(node_dir, "sampling")

        sampled_path = os.path.join(sampling_dir, "sampled.pt")
        latents_path = os.path.join(sampling_dir, "latents.pt")

        if not os.path.exists(sampled_path) or not os.path.exists(latents_path):
            print(f"[WARNING] Missing files for node '{node}' — skipping.")
            continue

        try:
            sampled = torch.load(sampled_path, map_location="cpu")
            latent_sample = torch.load(latents_path, map_location="cpu")
        except Exception as e:
            print(f"[ERROR] Could not load sampling files for node '{node}': {e}")
            continue

        sampled_by_node[node] = sampled.detach().cpu()
        latents_by_node[node] = latent_sample.detach().cpu()

    # Guarded access: 'verbose'/'debug' may not be set on the instance.
    # The original read self.verbose/self.debug directly, which raised
    # AttributeError when the flags were never assigned.
    if getattr(self, "verbose", False) or getattr(self, "debug", False):
        print(f"[INFO] Loaded sampled and latent tensors for {len(sampled_by_node)} nodes from {EXPERIMENT_DIR}")

    return sampled_by_node, latents_by_node
Load previously stored sampled values and latents for each node.
Parameters
EXPERIMENT_DIR : str or None, optional
Experiment directory path. If None, it is taken from
self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"].
nodes : list of str or None, optional
Nodes for which to load samples. If None, use all nodes from
self.nodes_dict.
Returns
tuple
A tuple (sampled_by_node, latents_by_node):
sampled_by_node : dict
Mapping ``{node_name: torch.Tensor}`` of sampled values (on CPU).
latents_by_node : dict
Mapping ``{node_name: torch.Tensor}`` of latent values (on CPU).
Raises
ValueError
If the experiment directory cannot be resolved or if no node list
is available and nodes is None.
Notes
Nodes without both sampled.pt and latents.pt files are skipped
with a warning.
def plot_samples_vs_true(
    self,
    df,
    sampled: dict = None,
    variable: list = None,
    bins: int = 100,
    hist_true_color: str = "blue",
    hist_est_color: str = "orange",
    figsize: tuple = (14, 5),
):


    """
    Compare sampled vs. observed distributions for selected nodes.

    Parameters
    ----------
    df : pandas.DataFrame
        Data frame containing the observed node values.
    sampled : dict or None, optional
        Optional mapping ``{node_name: array-like or torch.Tensor}`` of sampled
        values. If None or if a node is missing, samples are loaded from
        ``EXPERIMENT_DIR/{node}/sampling/sampled.pt``.
    variable : list of str or None, optional
        Subset of nodes to plot. If None, all nodes in the configuration
        are considered.
    bins : int, optional
        Number of histogram bins for continuous variables. Default is 100.
    hist_true_color : str, optional
        Color name for the histogram of true values. Default is "blue".
    hist_est_color : str, optional
        Color name for the histogram of sampled values. Default is "orange".
    figsize : tuple, optional
        Figure size for the matplotlib plots. Default is (14, 5).

    Returns
    -------
    None

    Notes
    -----
    - Continuous outcomes: histogram overlay + QQ-plot.
    - Ordinal outcomes: side-by-side bar plot of relative frequencies.
    - Other categorical outcomes: side-by-side bar plot with category labels.
    - If samples are probabilistic (2D tensor), the argmax across classes is used.
    """

    target_nodes = self.cfg.conf_dict["nodes"]
    experiment_dir = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # If no explicit subset is given, plot every configured node.
    plot_list = variable if variable is not None else target_nodes

    for node in plot_list:
        # Load sampled data: prefer values passed in via `sampled`,
        # otherwise fall back to the on-disk sampling artifacts.
        if sampled is not None and node in sampled:
            sdata = sampled[node]
            if isinstance(sdata, torch.Tensor):
                sampled_vals = sdata.detach().cpu().numpy()
            else:
                sampled_vals = np.asarray(sdata)
        else:
            sample_path = os.path.join(experiment_dir, f"{node}/sampling/sampled.pt")
            if not os.path.isfile(sample_path):
                print(f"[WARNING] skip {node}: {sample_path} not found.")
                continue

            try:
                sampled_vals = torch.load(sample_path, map_location=device).cpu().numpy()
            except Exception as e:
                print(f"[ERROR] Could not load {sample_path}: {e}")
                continue

        # If logits/probabilities per sample, take argmax
        # (2D samples are interpreted as per-class scores; only the most
        # probable class per sample enters the frequency plot).
        if sampled_vals.ndim == 2:
            print(f"[INFO] CAUTION! {node}: samples are probabilistic — each sample follows a probability "
                  f"distribution based on the valid latent range. "
                  f"Note that this frequency plot reflects only the distribution of the most probable "
                  f"class per sample.")
            sampled_vals = np.argmax(sampled_vals, axis=1)

        # Drop NaN/Inf before comparing distributions.
        sampled_vals = sampled_vals[np.isfinite(sampled_vals)]

        if node not in df.columns:
            print(f"[WARNING] skip {node}: column not found in DataFrame.")
            continue

        true_vals = df[node].dropna().values
        true_vals = true_vals[np.isfinite(true_vals)]

        if sampled_vals.size == 0 or true_vals.size == 0:
            print(f"[WARNING] skip {node}: empty array after NaN/Inf removal.")
            continue

        fig, axs = plt.subplots(1, 2, figsize=figsize)

        if is_outcome_modelled_continous(node, target_nodes):
            # Continuous node: density histogram overlay (left) + QQ plot (right).
            axs[0].hist(true_vals, bins=bins, density=True, alpha=0.6,
                        color=hist_true_color, label=f"True {node}")
            axs[0].hist(sampled_vals, bins=bins, density=True, alpha=0.6,
                        color=hist_est_color, label="Sampled")
            axs[0].set_xlabel("Value")
            axs[0].set_ylabel("Density")
            axs[0].set_title(f"Histogram overlay for {node}")
            axs[0].legend()
            axs[0].grid(True, ls="--", alpha=0.4)

            qqplot_2samples(true_vals, sampled_vals, line="45", ax=axs[1])
            axs[1].set_xlabel("True quantiles")
            axs[1].set_ylabel("Sampled quantiles")
            axs[1].set_title(f"QQ plot for {node}")
            axs[1].grid(True, ls="--", alpha=0.4)

        elif is_outcome_modelled_ordinal(node, target_nodes):
            # Ordinal node: paired bars of relative frequencies per level.
            unique_vals = np.union1d(np.unique(true_vals), np.unique(sampled_vals))
            unique_vals = np.sort(unique_vals)
            true_counts = np.array([(true_vals == val).sum() for val in unique_vals])
            sampled_counts = np.array([(sampled_vals == val).sum() for val in unique_vals])

            axs[0].bar(unique_vals - 0.2, true_counts / true_counts.sum(),
                       width=0.4, color=hist_true_color, alpha=0.7, label="True")
            axs[0].bar(unique_vals + 0.2, sampled_counts / sampled_counts.sum(),
                       width=0.4, color=hist_est_color, alpha=0.7,
                       label="Sampled")
            axs[0].set_xticks(unique_vals)
            axs[0].set_xlabel("Ordinal Level")
            axs[0].set_ylabel("Relative Frequency")
            axs[0].set_title(f"Ordinal bar plot for {node}")
            axs[0].legend()
            axs[0].grid(True, ls="--", alpha=0.4)
            axs[1].axis("off")

        else:
            # Generic categorical node: bars indexed by position, labels on ticks.
            unique_vals = np.union1d(np.unique(true_vals), np.unique(sampled_vals))
            unique_vals = sorted(unique_vals, key=str)
            true_counts = np.array([(true_vals == val).sum() for val in unique_vals])
            sampled_counts = np.array([(sampled_vals == val).sum() for val in unique_vals])

            axs[0].bar(np.arange(len(unique_vals)) - 0.2, true_counts / true_counts.sum(),
                       width=0.4, color=hist_true_color, alpha=0.7, label="True")
            axs[0].bar(np.arange(len(unique_vals)) + 0.2, sampled_counts / sampled_counts.sum(),
                       width=0.4, color=hist_est_color, alpha=0.7, label="Sampled")
            axs[0].set_xticks(np.arange(len(unique_vals)))
            axs[0].set_xticklabels(unique_vals, rotation=45)
            axs[0].set_xlabel("Category")
            axs[0].set_ylabel("Relative Frequency")
            axs[0].set_title(f"Categorical bar plot for {node}")
            axs[0].legend()
            axs[0].grid(True, ls="--", alpha=0.4)
            axs[1].axis("off")

        plt.tight_layout()
        plt.show()
Compare sampled vs. observed distributions for selected nodes.
Parameters
df : pandas.DataFrame
Data frame containing the observed node values.
sampled : dict or None, optional
Optional mapping {node_name: array-like or torch.Tensor} of sampled
values. If None or if a node is missing, samples are loaded from
EXPERIMENT_DIR/{node}/sampling/sampled.pt.
variable : list of str or None, optional
Subset of nodes to plot. If None, all nodes in the configuration
are considered.
bins : int, optional
Number of histogram bins for continuous variables. Default is 100.
hist_true_color : str, optional
Color name for the histogram of true values. Default is "blue".
hist_est_color : str, optional
Color name for the histogram of sampled values. Default is "orange".
figsize : tuple, optional
Figure size for the matplotlib plots. Default is (14, 5).
Returns
None
Notes
- Continuous outcomes: histogram overlay + QQ-plot.
- Ordinal outcomes: side-by-side bar plot of relative frequencies.
- Other categorical outcomes: side-by-side bar plot with category labels.
- If samples are probabilistic (2D tensor), the argmax across classes is used.
def nll(self, data, variables=None):
    """
    Compute the Negative Log-Likelihood (NLL) for all or selected TRAM nodes.

    Evaluates the trained TRAM model of each requested node on the provided
    dataset. Performs forward passes only — no training, no weight updates —
    and returns the mean NLL per node.

    Parameters
    ----------
    data : object
        Input dataset or data source compatible with `_ensure_dataset`,
        containing both inputs and targets for each node.
    variables : list[str], optional
        List of variable (node) names to evaluate. If None, all nodes in
        ``self.models`` are evaluated.

    Returns
    -------
    dict[str, float]
        Dictionary mapping each node name to its average NLL value.

    Notes
    -----
    - Each model is evaluated independently on its respective DataLoader.
    - Normalization bounds (``min_max``) for each node come from
      ``self.minmax_dict[node]``.
    - Per-node evaluation is delegated to `evaluate_tramdag_model()`.
    - Expected directory structure: ``<EXPERIMENT_DIR>/<node>/``, where each
      node directory contains the trained model.
    """
    td_data = self._ensure_dataset(data, is_val=True)
    # Identity comparison with None (`variables != None` in the original)
    # replaced with the idiomatic `is not None`.
    variables_list = variables if variables is not None else list(self.models.keys())

    # Loop-invariant: resolve the experiment directory once.
    EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]

    nll_dict = {}
    for node in variables_list:
        # Stack per-node min/max into the (2, ...) tensor the evaluator expects.
        min_vals = torch.tensor(self.minmax_dict[node][0], dtype=torch.float32)
        max_vals = torch.tensor(self.minmax_dict[node][1], dtype=torch.float32)
        min_max = torch.stack([min_vals, max_vals], dim=0)

        data_loader = td_data.loaders[node]
        model = self.models[node]
        NODE_DIR = os.path.join(EXPERIMENT_DIR, f"{node}")

        nll_dict[node] = evaluate_tramdag_model(
            node=node,
            target_nodes=self.nodes_dict,
            NODE_DIR=NODE_DIR,
            tram_model=model,
            data_loader=data_loader,
            min_max=min_max,
        )
    return nll_dict
Compute the Negative Log-Likelihood (NLL) for all or selected TRAM nodes.
This function evaluates trained TRAM models for each specified variable (node) on the provided dataset. It performs forward passes only—no training, no weight updates—and returns the mean NLL per node.
Parameters
data : object
Input dataset or data source compatible with _ensure_dataset, containing
both inputs and targets for each node.
variables : list[str], optional
List of variable (node) names to evaluate. If None, all nodes in
self.models are evaluated.
Returns
dict[str, float] Dictionary mapping each node name to its average NLL value.
Notes
- Each model is evaluated independently on its respective DataLoader.
- The normalization values (`min_max`) for each node are retrieved from `self.minmax_dict[node]`.
- The function uses `evaluate_tramdag_model()` for per-node evaluation.
- Expected directory structure: `<EXPERIMENT_DIR>/<node>/`, where each node directory contains the trained model.
1762 def get_train_val_nll(self, node: str, mode: str) -> tuple[float, float]: 1763 """ 1764 Retrieve training and validation NLL for a node and a given model state. 1765 1766 Parameters 1767 ---------- 1768 node : str 1769 Node name. 1770 mode : {"best", "last", "init"} 1771 State of interest: 1772 - "best": epoch with lowest validation NLL. 1773 - "last": final epoch. 1774 - "init": first epoch (index 0). 1775 1776 Returns 1777 ------- 1778 tuple of (float or None, float or None) 1779 A tuple ``(train_nll, val_nll)`` for the requested mode. 1780 Returns ``(None, None)`` if loss files are missing or cannot be read. 1781 1782 Notes 1783 ----- 1784 This method expects per-node JSON files: 1785 1786 - ``train_loss_hist.json`` 1787 - ``val_loss_hist.json`` 1788 1789 in the node directory. 1790 """ 1791 NODE_DIR = os.path.join(self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"], node) 1792 train_path = os.path.join(NODE_DIR, "train_loss_hist.json") 1793 val_path = os.path.join(NODE_DIR, "val_loss_hist.json") 1794 1795 if not os.path.exists(train_path) or not os.path.exists(val_path): 1796 if getattr(self, "debug", False): 1797 print(f"[DEBUG] Missing loss files for node '{node}'. Returning None.") 1798 return None, None 1799 1800 try: 1801 with open(train_path, "r") as f: 1802 train_hist = json.load(f) 1803 with open(val_path, "r") as f: 1804 val_hist = json.load(f) 1805 1806 train_nlls = np.array(train_hist) 1807 val_nlls = np.array(val_hist) 1808 1809 if mode == "init": 1810 idx = 0 1811 elif mode == "last": 1812 idx = len(val_nlls) - 1 1813 elif mode == "best": 1814 idx = int(np.argmin(val_nlls)) 1815 else: 1816 raise ValueError(f"Invalid mode '{mode}' — must be one of 'best', 'last', 'init'.") 1817 1818 train_nll = float(train_nlls[idx]) 1819 val_nll = float(val_nlls[idx]) 1820 return train_nll, val_nll 1821 1822 except Exception as e: 1823 print(f"[ERROR] Failed to load NLLs for node '{node}' ({mode}): {e}") 1824 return None, None
Retrieve training and validation NLL for a node and a given model state.
Parameters
node : str Node name. mode : {"best", "last", "init"} State of interest: - "best": epoch with lowest validation NLL. - "last": final epoch. - "init": first epoch (index 0).
Returns
tuple of (float or None, float or None)
A tuple (train_nll, val_nll) for the requested mode.
Returns (None, None) if loss files are missing or cannot be read.
Notes
This method expects per-node JSON files:
- `train_loss_hist.json`
- `val_loss_hist.json`
in the node directory.
def get_thetas(self, node: str, state: str = "best"):
    """
    Return transformed intercept (theta) parameters for a node and state.

    Parameters
    ----------
    node : str
        Node name.
    state : {"best", "last", "init"}, optional
        Model state for which to return parameters. Default is "best".

    Returns
    -------
    Any or None
        Transformed theta parameters for the requested node and state, or
        None if the node is unknown even after recomputation. The exact
        structure depends on the model.

    Raises
    ------
    ValueError
        If ``state`` is not one of "best", "last", "init".

    Notes
    -----
    Results of `get_simple_intercepts_dict` are cached on the instance as
    ``intercept_dicts`` and lazily refreshed whenever the requested state
    or node is absent from the cache.
    """
    state = state.lower()
    if state not in ("best", "last", "init"):
        raise ValueError(f"[ERROR] Invalid state '{state}'. Must be one of ['best', 'last', 'init'].")

    cache_name = "intercept_dicts"

    def _refresh(debug_msg):
        # Recompute the full intercept dictionary, update the cache,
        # and return the fresh cache contents.
        if getattr(self, "debug", False):
            print(debug_msg)
        setattr(self, cache_name, self.get_simple_intercepts_dict())
        return getattr(self, cache_name)

    if hasattr(self, cache_name):
        all_dicts = getattr(self, cache_name)
    else:
        all_dicts = _refresh(f"[DEBUG] '{cache_name}' not found, computing via get_simple_intercepts_dict().")

    if state not in all_dicts:
        all_dicts = _refresh(f"[DEBUG] State '{state}' not found in cached intercepts, recomputing full dict.")

    cached_for_state = all_dicts.get(state, {})
    if node in cached_for_state:
        return cached_for_state[node]

    # Node missing from the cache — recompute once more as a fallback.
    all_dicts = _refresh(f"[DEBUG] Node '{node}' not found in state '{state}', recomputing full dict.")
    return all_dicts.get(state, {}).get(node, None)
Return transformed intercept (theta) parameters for a node and state.
Parameters
node : str Node name. state : {"best", "last", "init"}, optional Model state for which to return parameters. Default is "best".
Returns
Any or None Transformed theta parameters for the requested node and state. The exact structure (scalar, list, or other) depends on the model.
Raises
ValueError If an invalid state is given (not in {"best", "last", "init"}).
Notes
Intercept dictionaries are cached on the instance under the attribute
intercept_dicts. If missing or incomplete, they are recomputed using
get_simple_intercepts_dict.
def get_linear_shifts(self, node: str, state: str = "best"):
    """
    Return learned linear shift terms for a node and a given state.

    Parameters
    ----------
    node : str
        Node name.
    state : {"best", "last", "init"}, optional
        Model state for which to return linear shift terms. Default is "best".

    Returns
    -------
    dict or Any or None
        Linear shift terms for the given node and state (usually a dict
        mapping term names to weights), or None if the node is unknown
        even after recomputation.

    Raises
    ------
    ValueError
        If ``state`` is not one of "best", "last", "init".

    Notes
    -----
    Results of `get_linear_shifts_dict` are cached on the instance as
    ``linear_shift_dicts`` and lazily refreshed whenever the requested
    state or node is absent from the cache.
    """
    state = state.lower()
    if state not in ("best", "last", "init"):
        raise ValueError(f"[ERROR] Invalid state '{state}'. Must be one of ['best', 'last', 'init'].")

    cache_name = "linear_shift_dicts"

    def _refresh(debug_msg):
        # Recompute the full linear-shift dictionary, update the cache,
        # and return the fresh cache contents.
        if getattr(self, "debug", False):
            print(debug_msg)
        setattr(self, cache_name, self.get_linear_shifts_dict())
        return getattr(self, cache_name)

    if hasattr(self, cache_name):
        all_dicts = getattr(self, cache_name)
    else:
        all_dicts = _refresh(f"[DEBUG] '{cache_name}' not found, computing via get_linear_shifts_dict().")

    if state not in all_dicts:
        all_dicts = _refresh(f"[DEBUG] State '{state}' not found in cached linear shifts, recomputing full dict.")

    cached_for_state = all_dicts.get(state, {})
    if node in cached_for_state:
        return cached_for_state[node]

    # Node missing from the cache — recompute once more as a fallback.
    all_dicts = _refresh(f"[DEBUG] Node '{node}' not found in state '{state}', recomputing full dict.")
    return all_dicts.get(state, {}).get(node, None)
Return learned linear shift terms for a node and a given state.
Parameters
node : str Node name. state : {"best", "last", "init"}, optional Model state for which to return linear shift terms. Default is "best".
Returns
dict or Any or None Linear shift terms for the given node and state. Usually a dict mapping term names to weights.
Raises
ValueError If an invalid state is given (not in {"best", "last", "init"}).
Notes
Linear shift dictionaries are cached on the instance under the attribute
linear_shift_dicts. If missing or incomplete, they are recomputed using
get_linear_shifts_dict.
def get_linear_shifts_dict(self):
    """
    Compute linear shift term dictionaries for all nodes and states.

    For each node and each available state ("best", "last", "init"), this
    method loads the corresponding model checkpoint, extracts linear shift
    weights from the TRAM model, and stores them in a nested dictionary.

    Returns
    -------
    dict
        Nested dictionary of the form::

            {
                "best": {node: {...}},
                "last": {node: {...}},
                "init": {node: {...}},
            }

        where the innermost dict maps term labels (e.g. ``"ls(parent_name)"``)
        to their weights.

    Notes
    -----
    - If "best" or "last" checkpoints are unavailable for a node, only the
      "init" entry is populated.
    - Empty outer states (without any nodes) are removed from the result.
    """
    EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
    nodes_list = list(self.models.keys())
    all_states = ["best", "last", "init"]
    all_linear_shift_dicts = {state: {} for state in all_states}

    for node in nodes_list:
        NODE_DIR = os.path.join(EXPERIMENT_DIR, node)
        BEST_MODEL_PATH, LAST_MODEL_PATH, _, _ = model_train_val_paths(NODE_DIR)
        INIT_MODEL_PATH = os.path.join(NODE_DIR, "initial_model.pt")

        state_paths = {
            "best": BEST_MODEL_PATH,
            "last": LAST_MODEL_PATH,
            "init": INIT_MODEL_PATH,
        }

        for state, LOAD_PATH in state_paths.items():
            if not os.path.exists(LOAD_PATH):
                if state != "init":
                    # skip best/last if unavailable
                    continue
                # Init checkpoint is missing too: warn, optionally debug-log, and skip.
                # (Fix: removed a redundant second os.path.exists() check that was
                # always true inside this branch.)
                print(f"[WARNING] No models found for node '{node}'. Only initial model will be used.")
                if getattr(self, "debug", False):
                    print(f"[DEBUG] Initial model also missing for node '{node}'. Skipping.")
                continue

            # Load parents and model
            _, terms_dict, _ = ordered_parents(node, self.nodes_dict)
            state_dict = torch.load(LOAD_PATH, map_location=self.device)
            tram_model = self.models[node]
            tram_model.load_state_dict(state_dict)

            epoch_weights = {}
            if hasattr(tram_model, "nn_shift") and tram_model.nn_shift is not None:
                for i, shift_layer in enumerate(tram_model.nn_shift):
                    module_name = shift_layer.__class__.__name__
                    if (
                        hasattr(shift_layer, "fc")
                        and hasattr(shift_layer.fc, "weight")
                        and module_name == "LinearShift"
                    ):
                        # NOTE(review): assumes nn_shift module order matches the
                        # key order of terms_dict -- confirm against ordered_parents.
                        term_name = list(terms_dict.keys())[i]
                        epoch_weights[f"ls({term_name})"] = (
                            shift_layer.fc.weight.detach().cpu().squeeze().tolist()
                        )
                    elif getattr(self, "debug", False):
                        term_name = list(terms_dict.keys())[i]
                        print(f"[DEBUG] ls({term_name}): missing 'fc' or 'weight' in LinearShift.")
            else:
                if getattr(self, "debug", False):
                    print(f"[DEBUG] Tram model for node '{node}' has no nn_shift or it is None.")

            all_linear_shift_dicts[state][node] = epoch_weights

    # Remove empty states (e.g., when best/last not found for all nodes)
    all_linear_shift_dicts = {k: v for k, v in all_linear_shift_dicts.items() if v}

    return all_linear_shift_dicts
Compute linear shift term dictionaries for all nodes and states.
For each node and each available state ("best", "last", "init"), this method loads the corresponding model checkpoint, extracts linear shift weights from the TRAM model, and stores them in a nested dictionary.
Returns
dict Nested dictionary of the form:
```python
{
    "best": {node: {...}},
    "last": {node: {...}},
    "init": {node: {...}},
}
```
where the innermost dict maps term labels (e.g. ``"ls(parent_name)"``)
to their weights.
Notes
- If "best" or "last" checkpoints are unavailable for a node, only the "init" entry is populated.
- Empty outer states (without any nodes) are removed from the result.
def get_simple_intercepts_dict(self):
    """
    Compute transformed simple intercept dictionaries for all nodes and states.

    For each node and each available state ("best", "last", "init"), this
    method loads the corresponding model checkpoint, extracts simple intercept
    weights, transforms them into interpretable theta parameters, and stores
    them in a nested dictionary.

    Returns
    -------
    dict
        Nested dictionary of the form::

            {
                "best": {node: [[theta_1], [theta_2], ...]},
                "last": {node: [[theta_1], [theta_2], ...]},
                "init": {node: [[theta_1], [theta_2], ...]},
            }

    Notes
    -----
    - For ordinal models (``self.is_ontram == True``),
      `transform_intercepts_ordinal` is used.
    - For continuous models, `transform_intercepts_continous` is used.
    - Empty outer states (without any nodes) are removed from the result.
    """
    EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
    nodes_list = list(self.models.keys())
    all_states = ["best", "last", "init"]
    all_si_intercept_dicts = {state: {} for state in all_states}

    debug = getattr(self, "debug", False)
    is_ontram = getattr(self, "is_ontram", False)
    # Fix: removed unused local `verbose = getattr(self, "verbose", False)`.

    for node in nodes_list:
        NODE_DIR = os.path.join(EXPERIMENT_DIR, node)
        BEST_MODEL_PATH, LAST_MODEL_PATH, _, _ = model_train_val_paths(NODE_DIR)
        INIT_MODEL_PATH = os.path.join(NODE_DIR, "initial_model.pt")

        state_paths = {
            "best": BEST_MODEL_PATH,
            "last": LAST_MODEL_PATH,
            "init": INIT_MODEL_PATH,
        }

        for state, LOAD_PATH in state_paths.items():
            if not os.path.exists(LOAD_PATH):
                if state != "init":
                    # skip best/last if unavailable
                    continue
                # Init checkpoint is missing too: warn, optionally debug-log, skip.
                # (Fix: removed a redundant second os.path.exists() check that was
                # always true inside this branch.)
                print(f"[WARNING] No models found for node '{node}'. Only initial model will be used.")
                if debug:
                    print(f"[DEBUG] Initial model also missing for node '{node}'. Skipping.")
                continue

            # Load model state
            state_dict = torch.load(LOAD_PATH, map_location=self.device)
            tram_model = self.models[node]
            tram_model.load_state_dict(state_dict)

            # Extract and transform simple intercept weights
            si_weights = None
            if hasattr(tram_model, "nn_int") and tram_model.nn_int is not None and isinstance(tram_model.nn_int, SimpleIntercept):
                if hasattr(tram_model.nn_int, "fc") and hasattr(tram_model.nn_int.fc, "weight"):
                    weights = tram_model.nn_int.fc.weight.detach().cpu().tolist()
                    weights_tensor = torch.Tensor(weights)

                    if debug:
                        print(f"[DEBUG] Node '{node}' ({state}) theta tilde shape: {weights_tensor.shape}")

                    if is_ontram:
                        # Ordinal: drop the +/- inf boundary cutpoints after transform.
                        si_weights = transform_intercepts_ordinal(weights_tensor.reshape(1, -1))[:, 1:-1].reshape(-1, 1)
                    else:
                        si_weights = transform_intercepts_continous(weights_tensor.reshape(1, -1)).reshape(-1, 1)

                    si_weights = si_weights.tolist()

                    if debug:
                        print(f"[DEBUG] Node '{node}' ({state}) theta transformed: {si_weights}")
                else:
                    if debug:
                        print(f"[DEBUG] Node '{node}' ({state}): missing 'fc' or 'weight' in SimpleIntercept.")
            else:
                if debug:
                    print(f"[DEBUG] Tram model for node '{node}' has no nn_int or it is None.")

            all_si_intercept_dicts[state][node] = si_weights

    # Clean up empty states
    all_si_intercept_dicts = {k: v for k, v in all_si_intercept_dicts.items() if v}
    return all_si_intercept_dicts
Compute transformed simple intercept dictionaries for all nodes and states.
For each node and each available state ("best", "last", "init"), this method loads the corresponding model checkpoint, extracts simple intercept weights, transforms them into interpretable theta parameters, and stores them in a nested dictionary.
Returns
dict Nested dictionary of the form:
```python
{
    "best": {node: [[theta_1], [theta_2], ...]},
    "last": {node: [[theta_1], [theta_2], ...]},
    "init": {node: [[theta_1], [theta_2], ...]},
}
```
Notes
- For ordinal models (``self.is_ontram == True``), `transform_intercepts_ordinal` is used.
- For continuous models, `transform_intercepts_continous` is used.
- Empty outer states (without any nodes) are removed from the result.
def summary(self, verbose=False):
    """
    Print a multi-part textual summary of the TramDagModel.

    The summary includes:
    1. Training metrics overview per node (best/last NLL, epochs).
    2. Node-specific details (thetas, linear shifts, optional architecture).
    3. Basic information about the attached training DataFrame, if present.

    Parameters
    ----------
    verbose : bool, optional
        If True, include extended per-node details such as the model
        architecture, parameter count, and availability of checkpoints
        and sampling results. Default is False.

    Returns
    -------
    None

    Notes
    -----
    This method prints to stdout and does not return structured data.
    It is intended for quick, human-readable inspection of the current
    training and model state.
    """

    # ---------- SETUP ----------
    try:
        EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
    except KeyError:
        EXPERIMENT_DIR = None
        print("[WARNING] Missing EXPERIMENT_DIR in cfg.conf_dict['PATHS'].")

    print("\n" + "=" * 120)
    print(f"{'TRAM DAG MODEL SUMMARY':^120}")
    print("=" * 120)

    # ---------- METRICS OVERVIEW ----------
    summary_data = []
    for node in self.models.keys():
        # Bug fix: reuse the guarded EXPERIMENT_DIR local. The original
        # re-indexed cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"] here, which
        # raised KeyError despite the graceful handling above.
        if EXPERIMENT_DIR is not None:
            node_dir = os.path.join(EXPERIMENT_DIR, node)
            train_path = os.path.join(node_dir, "train_loss_hist.json")
            val_path = os.path.join(node_dir, "val_loss_hist.json")
        else:
            train_path = val_path = None

        if train_path is not None and os.path.exists(train_path) and os.path.exists(val_path):
            best_train_nll, best_val_nll = self.get_train_val_nll(node, "best")
            last_train_nll, last_val_nll = self.get_train_val_nll(node, "last")
            # Bug fix: close the history file (was `json.load(open(train_path))`,
            # leaking the file handle).
            with open(train_path) as f:
                n_epochs_total = len(json.load(f))
        else:
            best_train_nll = best_val_nll = last_train_nll = last_val_nll = None
            n_epochs_total = 0

        summary_data.append({
            "Node": node,
            "Best Train NLL": best_train_nll,
            "Best Val NLL": best_val_nll,
            "Last Train NLL": last_train_nll,
            "Last Val NLL": last_val_nll,
            "Epochs": n_epochs_total,
        })

    df_summary = pd.DataFrame(summary_data)
    df_summary = df_summary.round(4)

    print("\n[1] TRAINING METRICS OVERVIEW")
    print("-" * 120)
    if not df_summary.empty:
        print(
            df_summary.to_string(
                index=False,
                justify="center",
                col_space=14,
                float_format=lambda x: f"{x:7.4f}",
            )
        )
    else:
        print("No training history found for any node.")
    print("-" * 120)

    # ---------- NODE DETAILS ----------
    print("\n[2] NODE-SPECIFIC DETAILS")
    print("-" * 120)
    for node in self.models.keys():
        print(f"\n{f'NODE: {node}':^120}")
        print("-" * 120)

        # THETAS & SHIFTS
        for state in ["init", "last", "best"]:
            print(f"\n  [{state.upper()} STATE]")

            # ---- Thetas ----
            try:
                # getattr fallback keeps the summary usable even if get_thetas
                # is not defined on this instance.
                thetas = getattr(self, "get_thetas", lambda n, s=None: None)(node, state)
                if thetas is not None:
                    if isinstance(thetas, (list, np.ndarray, pd.Series)):
                        thetas_flat = np.array(thetas).flatten()
                        compact = np.round(thetas_flat, 4)
                        arr_str = np.array2string(
                            compact,
                            max_line_width=110,
                            threshold=np.inf,
                            separator=", "
                        )
                        # Truncate very long theta vectors to two display lines.
                        lines = arr_str.split("\n")
                        if len(lines) > 2:
                            arr_str = "\n".join(lines[:2]) + " ..."
                        print(f"    Θ ({len(thetas_flat)}): {arr_str}")
                    elif isinstance(thetas, dict):
                        for k, v in thetas.items():
                            print(f"    Θ[{k}]: {v}")
                    else:
                        print(f"    Θ: {thetas}")
                else:
                    print("    Θ: not available")
            except Exception as e:
                print(f"    [Error loading thetas] {e}")

            # ---- Linear Shifts ----
            try:
                linear_shifts = getattr(self, "get_linear_shifts", lambda n, s=None: None)(node, state)
                if linear_shifts is not None:
                    if isinstance(linear_shifts, dict):
                        for k, v in linear_shifts.items():
                            print(f"    {k}: {np.round(v, 4)}")
                    elif isinstance(linear_shifts, (list, np.ndarray, pd.Series)):
                        arr = np.round(linear_shifts, 4)
                        print(f"    Linear shifts ({len(arr)}): {arr}")
                    else:
                        print(f"    Linear shifts: {linear_shifts}")
                else:
                    print("    Linear shifts: not available")
            except Exception as e:
                print(f"    [Error loading linear shifts] {e}")

        # ---- Verbose info directly below node ----
        if verbose:
            print("\n  [DETAILS]")
            node_dir = os.path.join(EXPERIMENT_DIR, node) if EXPERIMENT_DIR else None
            model = self.models[node]

            print(f"    Model Architecture:")
            arch_str = str(model).split("\n")
            for line in arch_str:
                print(f"      {line}")
            print(f"    Parameter count: {sum(p.numel() for p in model.parameters()):,}")

            if node_dir and os.path.exists(node_dir):
                ckpt_exists = any(f.endswith(('.pt', '.pth')) for f in os.listdir(node_dir))
                print(f"    Checkpoints found: {ckpt_exists}")

                sampling_dir = os.path.join(node_dir, "sampling")
                sampling_exists = os.path.isdir(sampling_dir) and len(os.listdir(sampling_dir)) > 0
                print(f"    Sampling results found: {sampling_exists}")

                for label, filename in [("Train", "train_loss_hist.json"), ("Validation", "val_loss_hist.json")]:
                    path = os.path.join(node_dir, filename)
                    if os.path.exists(path):
                        try:
                            with open(path, "r") as f:
                                hist = json.load(f)
                            print(f"    {label} history: {len(hist)} epochs")
                        except Exception as e:
                            print(f"    {label} history: failed to load ({e})")
                    else:
                        print(f"    {label} history: not found")
            else:
                print("    [INFO] No experiment directory defined or missing for this node.")
        print("-" * 120)

    # ---------- TRAINING DATAFRAME ----------
    print("\n[3] TRAINING DATAFRAME")
    print("-" * 120)
    try:
        self.train_df.info()
    except AttributeError:
        print("No training DataFrame attached to this TramDagModel.")
    print("=" * 120 + "\n")
Print a multi-part textual summary of the TramDagModel.
The summary includes:
- Training metrics overview per node (best/last NLL, epochs).
- Node-specific details (thetas, linear shifts, optional architecture).
- Basic information about the attached training DataFrame, if present.
Parameters
verbose : bool, optional If True, include extended per-node details such as the model architecture, parameter count, and availability of checkpoints and sampling results. Default is False.
Returns
None
Notes
This method prints to stdout and does not return structured data. It is intended for quick, human-readable inspection of the current training and model state.