tramdag
# tramdag/__init__.py
"""Public package interface for tramdag.

Re-exports the three user-facing classes and resolves the installed
package version (falling back to "0.0.0" for source checkouts).
"""
# PEP 8 grouping: stdlib imports first, then local package imports.
from importlib.metadata import PackageNotFoundError, version

from .TramDagConfig import TramDagConfig
from .TramDagDataset import TramDagDataset
from .TramDagModel import TramDagModel

__all__ = ["TramDagConfig", "TramDagDataset", "TramDagModel"]

try:
    __version__ = version("tramdag")
except PackageNotFoundError:
    # Package metadata unavailable (e.g. running from an uninstalled checkout).
    __version__ = "0.0.0"
class TramDagConfig:
    """
    Configuration manager for TRAM-DAG experiments.

    This class encapsulates:

    - The experiment configuration dictionary (``conf_dict``).
    - Its backing file path (``CONF_DICT_PATH``).
    - Utilities to load, validate, modify, and persist configuration.
    - DAG visualization and interactive editing helpers.

    Typical usage
    -------------
    - Load an existing configuration from disk via ``TramDagConfig.load_json``.
    - Or create/reuse an experiment setup via ``TramDagConfig().setup_configuration``.
    - Update sections such as ``data_type``, the adjacency matrix, and neural
      network model names using the provided methods.
    """

    def __init__(self, conf_dict: dict = None, CONF_DICT_PATH: str = None, _verify: bool = False, **kwargs):
        """
        Initialize a TramDagConfig instance.

        Parameters
        ----------
        conf_dict : dict or None, optional
            Configuration dictionary. If None (or empty), an empty dict is
            used and can be populated later. Default is None.
        CONF_DICT_PATH : str or None, optional
            Path to the configuration file on disk. Default is None.
        _verify : bool, optional
            If True, run `_verify_completeness()` after initialization.
            Default is False.
        **kwargs
            Additional attributes to set on the instance (e.g. ``debug=True``).
            Keys "conf_dict" and "CONF_DICT_PATH" are forbidden.

        Raises
        ------
        ValueError
            If any key in `kwargs` is "conf_dict" or "CONF_DICT_PATH".

        Notes
        -----
        `debug` and `verbose` default to False and can be overridden via
        `kwargs`.
        """
        self.debug = False
        self.verbose = False

        for key, value in kwargs.items():
            # Defensive guard: these names are bound by the named parameters
            # above (Python routes them there, so in practice they can never
            # reach **kwargs), but the check documents the intent explicitly.
            if key in ('conf_dict', 'CONF_DICT_PATH'):
                raise ValueError(f"Cannot override '{key}' via kwargs.")
            setattr(self, key, value)

        self.conf_dict = conf_dict or {}
        self.CONF_DICT_PATH = CONF_DICT_PATH

        # optional structural verification
        if _verify:
            self._verify_completeness()

    @classmethod
    def load_json(cls, CONF_DICT_PATH: str, debug: bool = False):
        """
        Load a configuration from a JSON file and construct a TramDagConfig.

        Parameters
        ----------
        CONF_DICT_PATH : str
            Path to the configuration JSON file.
        debug : bool, optional
            If True, initialize the instance with `debug=True`. Default is False.

        Returns
        -------
        TramDagConfig
            New instance with `conf_dict` loaded from `CONF_DICT_PATH` and
            `_verify_completeness()` executed.

        Raises
        ------
        FileNotFoundError
            If the configuration file cannot be found (propagated by
            `load_configuration_dict`).
        """
        conf = load_configuration_dict(CONF_DICT_PATH)
        return cls(conf, CONF_DICT_PATH=CONF_DICT_PATH, debug=debug, _verify=True)

    def update(self):
        """
        Reload the latest configuration from disk into this instance.

        The in-memory `conf_dict` is overwritten by the contents loaded
        from `CONF_DICT_PATH`.

        Raises
        ------
        ValueError
            If `CONF_DICT_PATH` is not set on the instance.
        """
        if getattr(self, "CONF_DICT_PATH", None) is None:
            raise ValueError("CONF_DICT_PATH not set — cannot update configuration.")
        self.conf_dict = load_configuration_dict(self.CONF_DICT_PATH)

    def save(self, CONF_DICT_PATH: str = None):
        """
        Persist the current configuration dictionary to disk.

        Parameters
        ----------
        CONF_DICT_PATH : str or None, optional
            Target path for the configuration file. If None, uses
            `self.CONF_DICT_PATH`. Default is None.

        Raises
        ------
        ValueError
            If neither the argument nor `self.CONF_DICT_PATH` provides a path.

        Notes
        -----
        The resulting file is written via `write_configuration_dict`.
        """
        path = CONF_DICT_PATH or self.CONF_DICT_PATH
        if path is None:
            raise ValueError("No CONF_DICT_PATH provided to save config.")
        write_configuration_dict(self.conf_dict, path)

    def _verify_completeness(self):
        """
        Check that the configuration is structurally complete and consistent.

        Checks performed:

        1. Top-level mandatory keys ("experiment_name", "PATHS", "nodes",
           "data_type", "adj_matrix", "model_names").
        2. Per-node mandatory keys ("data_type", "node_type", "parents",
           "parents_datatype", "transformation_terms_in_h()",
           "transformation_term_nn_models_in_h()").
        3. All ordinal variables have a "levels" entry under
           ``conf_dict["nodes"][var]``.
        4. The experiment name is non-empty.
        5. The adjacency matrix passes `validate_adj_matrix`.

        Notes
        -----
        - Missing or invalid components are reported via printed warnings;
          nothing is raised here.
        - Detailed debug messages are printed when `self.debug` is True.
        - Optional keys ("date_of_creation", "seed" at top level; "levels"
          per node) are intentionally not required.
        """
        # FIX: original listed "nodes" twice, which duplicated it in the
        # missing-keys warning.
        mandatory_keys = ["experiment_name", "PATHS", "nodes", "data_type",
                          "adj_matrix", "model_names"]

        # ---- 1. Check mandatory top-level keys exist
        missing = [k for k in mandatory_keys if k not in self.conf_dict]
        if missing:
            print(f"[WARNING] Missing mandatory keys in configuration: {missing}"
                  "\n Please add them to the configuration dict and reload.")

        # ---- 2. Check mandatory keys in the nodes dict are present
        mandatory_keys_nodes = ['data_type', 'node_type', 'parents',
                                'parents_datatype', 'transformation_terms_in_h()',
                                'transformation_term_nn_models_in_h()']
        for node, node_dict in self.conf_dict.get("nodes", {}).items():
            missing_node_keys = [k for k in mandatory_keys_nodes if k not in node_dict]
            if missing_node_keys:
                print(f"[WARNING] Node '{node}' is missing mandatory keys: {missing_node_keys}")

        # ---- 3. Ordinal variables need levels for model training
        if self._verify_levels_dict():
            if self.debug:
                print("[DEBUG] levels are present for all ordinal variables in configuration dict.")
        else:
            print("[WARNING] levels are missing for some ordinal variables in configuration dict. THIS will FAIL in model training later!\n"
                  " Please provide levels manually to config and reload or compute levels from data using the method compute_levels().\n"
                  " e.g. cfg.compute_levels(train_df) # computes levels from training data and writes to cfg")

        # ---- 4. Experiment name
        if self._verify_experiment_name() and self.debug:
            print("[DEBUG] experiment_name is valid in configuration dict.")

        # ---- 5. Adjacency matrix
        if self._verify_adj_matrix() and self.debug:
            print("[DEBUG] adj_matrix is valid in configuration dict.")

    def _verify_levels_dict(self):
        """
        Verify that all ordinal variables have levels specified in the config.

        Returns
        -------
        bool
            True if every variable whose entry in ``conf_dict["data_type"]``
            contains 'ordinal' has a "levels" entry in
            ``conf_dict["nodes"][var]``; False otherwise.

        Notes
        -----
        Read-only check; the configuration is not modified.
        """
        data_type = self.conf_dict.get('data_type', {})
        nodes = self.conf_dict.get('nodes', {})
        for var, dtype in data_type.items():
            if 'ordinal' in dtype:
                if var not in nodes or 'levels' not in nodes[var]:
                    return False
        return True

    def _verify_experiment_name(self):
        """
        Check whether the experiment name in the configuration is valid.

        Returns
        -------
        bool
            True if ``conf_dict["experiment_name"]`` exists and is non-empty
            after stripping whitespace. False otherwise.
        """
        experiment_name = self.conf_dict.get("experiment_name")
        return not (experiment_name is None or str(experiment_name).strip() == "")

    def _verify_adj_matrix(self):
        """
        Validate the adjacency matrix stored in the configuration.

        Returns
        -------
        bool
            True if the adjacency matrix is present and passes
            `validate_adj_matrix`, False otherwise.

        Notes
        -----
        - FIX: the original indexed ``conf_dict['adj_matrix']`` directly and
          raised ``KeyError`` when the key was absent, defeating the
          warn-only contract of `_verify_completeness`.
        - A list-valued matrix is converted to a NumPy array before validation.
        """
        adj_matrix = self.conf_dict.get('adj_matrix')
        if adj_matrix is None:
            print("[WARNING] 'adj_matrix' missing from configuration dict — cannot validate.")
            return False
        if isinstance(adj_matrix, list):
            adj_matrix = np.array(adj_matrix)
        return bool(validate_adj_matrix(adj_matrix))

    def compute_levels(self, df: pd.DataFrame, write: bool = True):
        """
        Infer and update ordinal/categorical levels from data.

        For each variable in the configuration's `data_type` section, uses
        the provided DataFrame to construct a levels dictionary and injects
        the corresponding "levels" entry into ``conf_dict["nodes"]``.

        Parameters
        ----------
        df : pandas.DataFrame
            DataFrame used to infer levels for configured variables.
        write : bool, optional
            If True and `CONF_DICT_PATH` is set, the updated configuration
            is written back to disk. Default is True.

        Notes
        -----
        - Variables present in `levels_dict` but not in ``conf_dict["nodes"]``
          trigger a warning and are skipped.
        - Save failures are caught and reported, not raised.
        - If `self.verbose` or `self.debug` is True, a success message is
          printed when the configuration is saved.
        """
        # Sync with disk first so levels are added to the latest config.
        self.update()
        levels_dict = create_levels_dict(df, self.conf_dict['data_type'])

        # inject levels into the nodes dict
        for var, levels in levels_dict.items():
            if var in self.conf_dict['nodes']:
                self.conf_dict['nodes'][var]['levels'] = levels
            else:
                print(f"[WARNING] Variable '{var}' not found in nodes dict. Cannot add levels.")

        if write and self.CONF_DICT_PATH is not None:
            try:
                self.save(self.CONF_DICT_PATH)
                if self.verbose or self.debug:
                    print(f'[INFO] Configuration with updated levels saved to {self.CONF_DICT_PATH}')
            except Exception as e:
                print(f'[ERROR] Failed to save configuration: {e}')

    def plot_dag(self, seed: int = 42, causal_order: bool = False):
        """
        Visualize the DAG defined by the configuration.

        Nodes are categorized and colored as:
        - Source nodes (no incoming edges): green.
        - Sink nodes (no outgoing edges): red.
        - Intermediate nodes: light blue.

        Parameters
        ----------
        seed : int, optional
            Random seed for layout stability in the spring-layout fallback.
            Default is 42.
        causal_order : bool, optional
            If True, attempt a Graphviz 'dot' layout via
            `networkx.nx_agraph.graphviz_layout` to preserve causal ordering,
            falling back to `spring_layout` if Graphviz is unavailable.
            Default is False.

        Raises
        ------
        ValueError
            If `adj_matrix` or `data_type` is missing, inconsistent with each
            other, or if the adjacency matrix fails validation.

        Notes
        -----
        Edge labels are colored by prefix: "ci" blue, "ls" red, "cs" green,
        anything else black.
        """
        adj_matrix = self.conf_dict.get("adj_matrix")
        data_type = self.conf_dict.get("data_type")

        if adj_matrix is None or data_type is None:
            raise ValueError("Configuration must include 'adj_matrix' and 'data_type'.")

        if isinstance(adj_matrix, list):
            adj_matrix = np.array(adj_matrix)

        if not validate_adj_matrix(adj_matrix):
            raise ValueError("Invalid adjacency matrix.")
        if len(data_type) != adj_matrix.shape[0]:
            raise ValueError("data_type must match adjacency matrix size.")

        node_labels = list(data_type.keys())
        G, edge_labels = create_nx_graph(adj_matrix, node_labels)

        sources = {n for n in G.nodes if G.in_degree(n) == 0}
        sinks = {n for n in G.nodes if G.out_degree(n) == 0}
        # nodes that are neither source nor sink fall through to lightblue

        node_colors = [
            "green" if n in sources
            else "red" if n in sinks
            else "lightblue"
            for n in G.nodes
        ]

        if causal_order:
            try:
                pos = nx.nx_agraph.graphviz_layout(G, prog="dot")
            except (ImportError, nx.NetworkXException):
                # Graphviz/pygraphviz unavailable — fall back to spring layout.
                pos = nx.spring_layout(G, seed=seed, k=1.5, iterations=100)
        else:
            pos = nx.spring_layout(G, seed=seed, k=1.5, iterations=100)

        plt.figure(figsize=(8, 6))
        nx.draw(
            G, pos,
            with_labels=True,
            node_color=node_colors,
            edge_color="gray",
            node_size=2500,
            arrowsize=20
        )

        # draw each edge label individually so it can carry its own color
        for (u, v), lbl in edge_labels.items():
            color = (
                "blue" if lbl.startswith("ci")
                else "red" if lbl.startswith("ls")
                else "green" if lbl.startswith("cs")
                else "black"
            )
            nx.draw_networkx_edge_labels(
                G, pos,
                edge_labels={(u, v): lbl},
                font_color=color,
                font_size=12
            )

        legend_items = [
            Patch(facecolor="green", edgecolor="black", label="Source"),
            Patch(facecolor="red", edgecolor="black", label="Sink"),
            Patch(facecolor="lightblue", edgecolor="black", label="Intermediate")
        ]
        plt.legend(handles=legend_items, loc="upper right", frameon=True)

        plt.title("TRAM DAG")
        plt.axis("off")
        plt.tight_layout()
        plt.show()

    def setup_configuration(self, experiment_name=None, EXPERIMENT_DIR=None, debug=False, _verify=False):
        """
        Create or reuse a configuration for an experiment.

        Behavior depends on how it is called:

        1. Class call (``TramDagConfig.setup_configuration(...)``):
           creates or loads a configuration at the resolved path and
           returns a new `TramDagConfig` instance.
        2. Instance call (``cfg.setup_configuration(...)``):
           updates `self.conf_dict` and `self.CONF_DICT_PATH` in place,
           optionally verifies completeness, and returns None.

        Parameters
        ----------
        experiment_name : str or None, optional
            Name of the experiment. If None, defaults to "experiment_1".
        EXPERIMENT_DIR : str or None, optional
            Directory for the experiment. If None, defaults to
            ``<cwd>/<experiment_name>``.
        debug : bool, optional
            If True, initialize / update with `debug=True`. Default is False.
        _verify : bool, optional
            If True, call `_verify_completeness()` after loading. Default is False.

        Returns
        -------
        TramDagConfig or None
            A new instance when called on the class; None on an instance.

        Notes
        -----
        A file named "configuration.json" is created via
        `create_and_write_new_configuration_dict` if it does not exist yet.
        """
        is_class_call = isinstance(self, type)
        cls = self if is_class_call else self.__class__

        if experiment_name is None:
            experiment_name = "experiment_1"
        if EXPERIMENT_DIR is None:
            EXPERIMENT_DIR = os.path.join(os.getcwd(), experiment_name)

        CONF_DICT_PATH = os.path.join(EXPERIMENT_DIR, "configuration.json")
        DATA_PATH = EXPERIMENT_DIR

        os.makedirs(EXPERIMENT_DIR, exist_ok=True)

        if os.path.exists(CONF_DICT_PATH):
            print(f"Configuration already exists: {CONF_DICT_PATH}")
        else:
            _ = create_and_write_new_configuration_dict(
                experiment_name, CONF_DICT_PATH, EXPERIMENT_DIR, DATA_PATH, None
            )
            print(f"Created new configuration file at {CONF_DICT_PATH}")

        conf = load_configuration_dict(CONF_DICT_PATH)

        if is_class_call:
            return cls(conf, CONF_DICT_PATH=CONF_DICT_PATH, debug=debug, _verify=_verify)
        self.conf_dict = conf
        self.CONF_DICT_PATH = CONF_DICT_PATH
        if _verify:
            self._verify_completeness()

    def set_data_type(self, data_type: dict, CONF_DICT_PATH: str = None) -> None:
        """
        Update or write the `data_type` section of a configuration file.

        Supports both class-level and instance-level usage:

        - Class call: requires the `CONF_DICT_PATH` argument; reads the file
          if it exists (else starts from an empty dict) and writes back.
        - Instance call: uses `self.CONF_DICT_PATH` if available, otherwise
          defaults to ``<cwd>/configuration.json``; also updates
          `self.conf_dict` and `self.CONF_DICT_PATH` after writing.

        Parameters
        ----------
        data_type : dict
            Mapping ``{variable_name: type_spec}`` of modeling types
            (e.g. continuous, ordinal).
        CONF_DICT_PATH : str or None, optional
            Path to the configuration file. Required for class calls.

        Raises
        ------
        ValueError
            If `CONF_DICT_PATH` is missing on a class call. (Validation
            failures inside the `try` block are caught and printed.)

        Notes
        -----
        - Variable names are validated via `validate_variable_names`;
          values via `validate_data_types`.
        - A modeling-settings summary is printed via
          `print_data_type_modeling_setting` on a best-effort basis.
        """
        is_class_call = isinstance(self, type)

        # resolve path
        if CONF_DICT_PATH is None:
            if not is_class_call and getattr(self, "CONF_DICT_PATH", None):
                CONF_DICT_PATH = self.CONF_DICT_PATH
            elif not is_class_call:
                CONF_DICT_PATH = os.path.join(os.getcwd(), "configuration.json")
            else:
                raise ValueError("CONF_DICT_PATH must be provided when called on the class.")

        try:
            # load existing or start from an empty configuration
            configuration_dict = (
                load_configuration_dict(CONF_DICT_PATH)
                if os.path.exists(CONF_DICT_PATH)
                else {}
            )

            validate_variable_names(data_type.keys())
            if not validate_data_types(data_type):
                raise ValueError("Invalid data types in the provided dictionary.")

            configuration_dict["data_type"] = data_type
            write_configuration_dict(configuration_dict, CONF_DICT_PATH)

            # best-effort summary printing; never fails the update
            try:
                print_data_type_modeling_setting(data_type or {})
            except Exception as e:
                print(f"[WARNING] Could not print data type modeling settings: {e}")

            if not is_class_call:
                self.conf_dict = configuration_dict
                self.CONF_DICT_PATH = CONF_DICT_PATH

        except Exception as e:
            print(f"Failed to update configuration: {e}")
        else:
            print(f"Configuration updated successfully at {CONF_DICT_PATH}.")

    def set_meta_adj_matrix(self, CONF_DICT_PATH: str = None, seed: int = 5):
        """
        Launch the interactive editor to set or modify the adjacency matrix.

        Steps:

        1. Resolve the configuration path from the argument or, for
           instances, from `self.CONF_DICT_PATH`.
        2. Invoke `interactive_adj_matrix` to edit the adjacency matrix.
        3. For instances, reload the updated configuration into
           `self.conf_dict`.

        Parameters
        ----------
        CONF_DICT_PATH : str or None, optional
            Path to the configuration file. Must be provided when called on
            the class; defaults to `self.CONF_DICT_PATH` on instances.
        seed : int, optional
            Random seed for layout/stochastic behavior in the interactive
            editor. Default is 5.

        Raises
        ------
        ValueError
            If `CONF_DICT_PATH` is not provided and cannot be inferred.

        Notes
        -----
        FIX: the original called `self.update()` before resolving the path,
        which crashed on class calls (instead of raising the documented
        ValueError) and on instances without a stored path. The disk sync
        now happens only for instances that have a stored path.
        """
        is_class_call = isinstance(self, type)

        # resolve path
        if CONF_DICT_PATH is None:
            if not is_class_call and getattr(self, "CONF_DICT_PATH", None):
                CONF_DICT_PATH = self.CONF_DICT_PATH
            else:
                raise ValueError("CONF_DICT_PATH must be provided when called on the class.")

        # sync in-memory config with disk before launching the editor
        if not is_class_call and getattr(self, "CONF_DICT_PATH", None):
            self.update()

        # launch interactive editor
        interactive_adj_matrix(CONF_DICT_PATH, seed=seed)

        # reload config if instance
        if not is_class_call:
            self.conf_dict = load_configuration_dict(CONF_DICT_PATH)
            self.CONF_DICT_PATH = CONF_DICT_PATH

    def set_tramdag_nn_models(self, CONF_DICT_PATH: str = None):
        """
        Launch the interactive editor to set TRAM-DAG neural network model names.

        Depending on call context:

        - Class call: requires the `CONF_DICT_PATH` argument; does not
          modify any instance.
        - Instance call: resolves `CONF_DICT_PATH` from the argument or
          `self.CONF_DICT_PATH`, and updates `self.conf_dict` /
          `self.CONF_DICT_PATH` if the editor returns an updated config.

        Parameters
        ----------
        CONF_DICT_PATH : str or None, optional
            Path to the configuration file. Must be provided when called on
            the class; defaults to `self.CONF_DICT_PATH` on instances.

        Raises
        ------
        ValueError
            If `CONF_DICT_PATH` is not provided and cannot be inferred.

        Notes
        -----
        The editor is `interactive_nn_names_matrix`; if it returns None the
        instance configuration is left unchanged.
        """
        is_class_call = isinstance(self, type)
        if CONF_DICT_PATH is None:
            if not is_class_call and getattr(self, "CONF_DICT_PATH", None):
                CONF_DICT_PATH = self.CONF_DICT_PATH
            else:
                raise ValueError("CONF_DICT_PATH must be provided when called on the class.")

        updated_conf = interactive_nn_names_matrix(CONF_DICT_PATH)
        if updated_conf is not None and not is_class_call:
            self.conf_dict = updated_conf
            self.CONF_DICT_PATH = CONF_DICT_PATH
Configuration manager for TRAM-DAG experiments.
This class encapsulates:
- The experiment configuration dictionary (`conf_dict`).
- Its backing file path (`CONF_DICT_PATH`).
- Utilities to load, validate, modify, and persist configuration.
- DAG visualization and interactive editing helpers.
Typical usage
- Load existing configuration from disk via
TramDagConfig.load_json. - Or create/reuse experiment setup via
TramDagConfig().setup_configuration. - Update sections such as
data_type, adjacency matrix, and neural network model names using the provided methods.
49 def __init__(self, conf_dict: dict = None, CONF_DICT_PATH: str = None, _verify: bool = False,**kwargs): 50 """ 51 Initialize a TramDagConfig instance. 52 53 Parameters 54 ---------- 55 conf_dict : dict or None, optional 56 Configuration dictionary. If None, an empty dict is used and can 57 be populated later. Default is None. 58 CONF_DICT_PATH : str or None, optional 59 Path to the configuration file on disk. Default is None. 60 _verify : bool, optional 61 If True, run `_verify_completeness()` after initialization. 62 Default is False. 63 **kwargs 64 Additional attributes to be set on the instance. Keys "conf_dict" 65 and "CONF_DICT_PATH" are forbidden and raise a ValueError. 66 67 Raises 68 ------ 69 ValueError 70 If any key in `kwargs` is "conf_dict" or "CONF_DICT_PATH". 71 72 Notes 73 ----- 74 By default, `debug` and `verbose` are set to False. They can be 75 overridden via `kwargs`. 76 """ 77 78 self.debug = False 79 self.verbose = False 80 81 for key, value in kwargs.items(): 82 if key in ['conf_dict', 'CONF_DICT_PATH']: 83 raise ValueError(f"Cannot override '{key}' via kwargs.") 84 setattr(self, key, value) 85 86 self.conf_dict = conf_dict or {} 87 self.CONF_DICT_PATH = CONF_DICT_PATH 88 89 # verification 90 if _verify: 91 self._verify_completeness()
Initialize a TramDagConfig instance.
Parameters
conf_dict : dict or None, optional
Configuration dictionary. If None, an empty dict is used and can
be populated later. Default is None.
CONF_DICT_PATH : str or None, optional
Path to the configuration file on disk. Default is None.
_verify : bool, optional
If True, run _verify_completeness() after initialization.
Default is False.
**kwargs
Additional attributes to be set on the instance. Keys "conf_dict"
and "CONF_DICT_PATH" are forbidden and raise a ValueError.
Raises
ValueError
If any key in kwargs is "conf_dict" or "CONF_DICT_PATH".
Notes
By default, debug and verbose are set to False. They can be
overridden via kwargs.
93 @classmethod 94 def load_json(cls, CONF_DICT_PATH: str,debug: bool = False): 95 """ 96 Load a configuration from a JSON file and construct a TramDagConfig. 97 98 Parameters 99 ---------- 100 CONF_DICT_PATH : str 101 Path to the configuration JSON file. 102 debug : bool, optional 103 If True, initialize the instance with `debug=True`. Default is False. 104 105 Returns 106 ------- 107 TramDagConfig 108 Newly created configuration instance with `conf_dict` loaded from 109 `CONF_DICT_PATH` and `_verify_completeness()` executed. 110 111 Raises 112 ------ 113 FileNotFoundError 114 If the configuration file cannot be found (propagated by 115 `load_configuration_dict`). 116 """ 117 118 conf = load_configuration_dict(CONF_DICT_PATH) 119 return cls(conf, CONF_DICT_PATH=CONF_DICT_PATH, debug=debug, _verify=True)
Load a configuration from a JSON file and construct a TramDagConfig.
Parameters
CONF_DICT_PATH : str
Path to the configuration JSON file.
debug : bool, optional
If True, initialize the instance with debug=True. Default is False.
Returns
TramDagConfig
Newly created configuration instance with conf_dict loaded from
CONF_DICT_PATH and _verify_completeness() executed.
Raises
FileNotFoundError
If the configuration file cannot be found (propagated by
load_configuration_dict).
121 def update(self): 122 """ 123 Reload the latest configuration from disk into this instance. 124 125 Parameters 126 ---------- 127 None 128 129 Returns 130 ------- 131 None 132 133 Raises 134 ------ 135 ValueError 136 If `CONF_DICT_PATH` is not set on the instance. 137 138 Notes 139 ----- 140 The current in-memory `conf_dict` is overwritten by the contents 141 loaded from `CONF_DICT_PATH`. 142 """ 143 144 if not hasattr(self, "CONF_DICT_PATH") or self.CONF_DICT_PATH is None: 145 raise ValueError("CONF_DICT_PATH not set — cannot update configuration.") 146 147 148 self.conf_dict = load_configuration_dict(self.CONF_DICT_PATH)
Reload the latest configuration from disk into this instance.
Parameters
None
Returns
None
Raises
ValueError
If CONF_DICT_PATH is not set on the instance.
Notes
The current in-memory conf_dict is overwritten by the contents
loaded from CONF_DICT_PATH.
151 def save(self, CONF_DICT_PATH: str = None): 152 """ 153 Persist the current configuration dictionary to disk. 154 155 Parameters 156 ---------- 157 CONF_DICT_PATH : str or None, optional 158 Target path for the configuration file. If None, uses 159 `self.CONF_DICT_PATH`. Default is None. 160 161 Returns 162 ------- 163 None 164 165 Raises 166 ------ 167 ValueError 168 If neither the function argument nor `self.CONF_DICT_PATH` 169 provides a valid path. 170 171 Notes 172 ----- 173 The resulting file is written via `write_configuration_dict`. 174 """ 175 path = CONF_DICT_PATH or self.CONF_DICT_PATH 176 if path is None: 177 raise ValueError("No CONF_DICT_PATH provided to save config.") 178 write_configuration_dict(self.conf_dict, path)
Persist the current configuration dictionary to disk.
Parameters
CONF_DICT_PATH : str or None, optional
Target path for the configuration file. If None, uses
self.CONF_DICT_PATH. Default is None.
Returns
None
Raises
ValueError
If neither the function argument nor self.CONF_DICT_PATH
provides a valid path.
Notes
The resulting file is written via write_configuration_dict.
337 def compute_levels(self, df: pd.DataFrame, write: bool = True): 338 """ 339 Infer and update ordinal/categorical levels from data. 340 341 For each variable in the configuration's `data_type` section, this 342 method uses the provided DataFrame to construct a levels dictionary 343 and injects the corresponding "levels" entry into `conf_dict["nodes"]`. 344 345 Parameters 346 ---------- 347 df : pandas.DataFrame 348 DataFrame used to infer levels for configured variables. 349 write : bool, optional 350 If True and `CONF_DICT_PATH` is set, the updated configuration is 351 written back to disk. Default is True. 352 353 Returns 354 ------- 355 None 356 357 Raises 358 ------ 359 Exception 360 If saving the configuration fails when `write=True`. 361 362 Notes 363 ----- 364 - Variables present in `levels_dict` but not in `conf_dict["nodes"]` 365 trigger a warning and are skipped. 366 - If `self.verbose` or `self.debug` is True, a success message is printed 367 when the configuration is saved. 368 """ 369 self.update() 370 levels_dict = create_levels_dict(df, self.conf_dict['data_type']) 371 372 # update nodes dict with levels 373 for var, levels in levels_dict.items(): 374 if var in self.conf_dict['nodes']: 375 self.conf_dict['nodes'][var]['levels'] = levels 376 else: 377 print(f"[WARNING] Variable '{var}' not found in nodes dict. Cannot add levels.") 378 379 if write and self.CONF_DICT_PATH is not None: 380 try: 381 self.save(self.CONF_DICT_PATH) 382 if self.verbose or self.debug: 383 print(f'[INFO] Configuration with updated levels saved to {self.CONF_DICT_PATH}') 384 except Exception as e: 385 print(f'[ERROR] Failed to save configuration: {e}')
Infer and update ordinal/categorical levels from data.
For each variable in the configuration's data_type section, this
method uses the provided DataFrame to construct a levels dictionary
and injects the corresponding "levels" entry into conf_dict["nodes"].
Parameters
df : pandas.DataFrame
DataFrame used to infer levels for configured variables.
write : bool, optional
If True and CONF_DICT_PATH is set, the updated configuration is
written back to disk. Default is True.
Returns
None
Raises
Exception
If saving the configuration fails when write=True.
Notes
- Variables present in `levels_dict` but not in `conf_dict["nodes"]` trigger a warning and are skipped.
- If `self.verbose` or `self.debug` is True, a success message is printed when the configuration is saved.
def plot_dag(self, seed: int = 42, causal_order: bool = False):
    """
    Visualize the DAG defined by the configuration.

    Nodes are categorized and colored as:
    - Source nodes (no incoming edges): green.
    - Sink nodes (no outgoing edges): red.
    - Intermediate nodes: light blue.

    Parameters
    ----------
    seed : int, optional
        Random seed for layout stability in the spring layout fallback.
        Default is 42.
    causal_order : bool, optional
        If True, attempt to use Graphviz 'dot' layout via
        `networkx.nx_agraph.graphviz_layout` to preserve causal ordering.
        If False or if Graphviz is unavailable, use `spring_layout`.
        Default is False.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If `adj_matrix` or `data_type` is missing or inconsistent with each
        other, or if the adjacency matrix fails validation.

    Notes
    -----
    Edge labels are colored by prefix: "ci" blue, "ls" red, "cs" green,
    anything else black.
    """
    adj_matrix = self.conf_dict.get("adj_matrix")
    data_type = self.conf_dict.get("data_type")

    if adj_matrix is None or data_type is None:
        raise ValueError("Configuration must include 'adj_matrix' and 'data_type'.")

    # JSON round-trips store the matrix as nested lists; normalize to ndarray.
    if isinstance(adj_matrix, list):
        adj_matrix = np.array(adj_matrix)

    if not validate_adj_matrix(adj_matrix):
        raise ValueError("Invalid adjacency matrix.")
    if len(data_type) != adj_matrix.shape[0]:
        raise ValueError("data_type must match adjacency matrix size.")

    node_labels = list(data_type.keys())
    G, edge_labels = create_nx_graph(adj_matrix, node_labels)

    sources = {n for n in G.nodes if G.in_degree(n) == 0}
    sinks = {n for n in G.nodes if G.out_degree(n) == 0}
    # (intermediate nodes are simply "everything else" in the color expression
    # below; no separate set is needed)

    node_colors = [
        "green" if n in sources
        else "red" if n in sinks
        else "lightblue"
        for n in G.nodes
    ]

    if causal_order:
        # Graphviz 'dot' gives a layered, causally-ordered layout; fall back
        # to a deterministic spring layout if pygraphviz is unavailable.
        try:
            pos = nx.nx_agraph.graphviz_layout(G, prog="dot")
        except (ImportError, nx.NetworkXException):
            pos = nx.spring_layout(G, seed=seed, k=1.5, iterations=100)
    else:
        pos = nx.spring_layout(G, seed=seed, k=1.5, iterations=100)

    plt.figure(figsize=(8, 6))
    nx.draw(
        G, pos,
        with_labels=True,
        node_color=node_colors,
        edge_color="gray",
        node_size=2500,
        arrowsize=20
    )

    # Draw each edge label individually so every one can carry its own color.
    for (u, v), lbl in edge_labels.items():
        color = (
            "blue" if lbl.startswith("ci")
            else "red" if lbl.startswith("ls")
            else "green" if lbl.startswith("cs")
            else "black"
        )
        nx.draw_networkx_edge_labels(
            G, pos,
            edge_labels={(u, v): lbl},
            font_color=color,
            font_size=12
        )

    legend_items = [
        Patch(facecolor="green", edgecolor="black", label="Source"),
        Patch(facecolor="red", edgecolor="black", label="Sink"),
        Patch(facecolor="lightblue", edgecolor="black", label="Intermediate")
    ]
    plt.legend(handles=legend_items, loc="upper right", frameon=True)

    plt.title("TRAM DAG")  # plain string: no placeholders, f-prefix removed
    plt.axis("off")
    plt.tight_layout()
    plt.show()
Visualize the DAG defined by the configuration.
Nodes are categorized and colored as:
- Source nodes (no incoming edges): green.
- Sink nodes (no outgoing edges): red.
- Intermediate nodes: light blue.
Parameters
seed : int, optional
Random seed for layout stability in the spring layout fallback.
Default is 42.
causal_order : bool, optional
If True, attempt to use Graphviz 'dot' layout via
networkx.nx_agraph.graphviz_layout to preserve causal ordering.
If False or if Graphviz is unavailable, use spring_layout.
Default is False.
Returns
None
Raises
ValueError
If adj_matrix or data_type is missing or inconsistent with each other,
or if the adjacency matrix fails validation.
Notes
Edge labels are colored by prefix:
- "ci": blue
- "ls": red
- "cs": green
- other: black
def setup_configuration(self, experiment_name=None, EXPERIMENT_DIR=None, debug=False, _verify=False):
    """
    Create or reuse an experiment configuration.

    Works both as a class-level factory and as an instance method:

    - Class call (`TramDagConfig.setup_configuration(...)`): creates or loads
      the configuration file and returns a fresh `TramDagConfig`.
    - Instance call (`cfg.setup_configuration(...)`): refreshes
      `self.conf_dict` / `self.CONF_DICT_PATH` in place and returns None.

    Parameters
    ----------
    experiment_name : str or None, optional
        Experiment name; defaults to "experiment_1".
    EXPERIMENT_DIR : str or None, optional
        Experiment directory; defaults to `<cwd>/<experiment_name>`.
    debug : bool, optional
        Forwarded as `debug` when a new instance is constructed.
    _verify : bool, optional
        If True, run `_verify_completeness()` after loading.

    Returns
    -------
    TramDagConfig or None
        New instance for class calls, None for instance calls.

    Notes
    -----
    A file named "configuration.json" is created inside the experiment
    directory when it does not exist yet.
    """
    called_on_class = isinstance(self, type)
    config_cls = self if called_on_class else type(self)

    if experiment_name is None:
        experiment_name = "experiment_1"
    if EXPERIMENT_DIR is None:
        EXPERIMENT_DIR = os.path.join(os.getcwd(), experiment_name)

    conf_path = os.path.join(EXPERIMENT_DIR, "configuration.json")
    data_path = EXPERIMENT_DIR  # data lives alongside the configuration

    os.makedirs(EXPERIMENT_DIR, exist_ok=True)

    if os.path.exists(conf_path):
        print(f"Configuration already exists: {conf_path}")
    else:
        create_and_write_new_configuration_dict(
            experiment_name, conf_path, EXPERIMENT_DIR, data_path, None
        )
        print(f"Created new configuration file at {conf_path}")

    loaded_conf = load_configuration_dict(conf_path)

    if called_on_class:
        return config_cls(loaded_conf, CONF_DICT_PATH=conf_path, debug=debug, _verify=_verify)

    self.conf_dict = loaded_conf
    self.CONF_DICT_PATH = conf_path
    if _verify:
        self._verify_completeness()
Create or reuse a configuration for an experiment.
This method behaves differently depending on how it is called:
- Class call (e.g.
TramDagConfig.setup_configuration(...)):
- Creates or loads a configuration at the resolved path.
- Returns a new `TramDagConfig` instance.
- Instance call (e.g.
cfg.setup_configuration(...)):
- Updates `self.conf_dict` and `self.CONF_DICT_PATH` in place.
- Optionally verifies completeness.
- Returns None.
Parameters
experiment_name : str or None, optional
Name of the experiment. If None, defaults to "experiment_1".
EXPERIMENT_DIR : str or None, optional
Directory for the experiment. If None, defaults to
<cwd>/<experiment_name>.
debug : bool, optional
If True, initialize / update with debug=True. Default is False.
_verify : bool, optional
If True, call _verify_completeness() after loading. Default is False.
Returns
TramDagConfig or None - A new instance when called on the class. - None when called on an existing instance.
Notes
- A configuration file named "configuration.json" is created if it does not exist yet.
- Underlying creation uses `create_and_write_new_configuration_dict` and `load_configuration_dict`.
def set_data_type(self, data_type: dict, CONF_DICT_PATH: str = None) -> None:
    """
    Update or write the `data_type` section of a configuration file.

    Supports both class-level and instance-level usage:

    - Class call:
        - Requires a `CONF_DICT_PATH` argument.
        - Reads the file if it exists, or starts from an empty dict.
        - Writes the updated configuration to `CONF_DICT_PATH`.

    - Instance call:
        - Uses `self.CONF_DICT_PATH` if available, otherwise defaults to
          `<cwd>/configuration.json` if no path is provided.
        - Updates `self.conf_dict` and `self.CONF_DICT_PATH` after writing.

    Parameters
    ----------
    data_type : dict
        Mapping `{variable_name: type_spec}`, where `type_spec` encodes
        modeling types (e.g. continuous, ordinal, etc.).
    CONF_DICT_PATH : str or None, optional
        Path to the configuration file. Must be provided for class calls.
        For instance calls, defaults as described above.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        Only when `CONF_DICT_PATH` is missing on a class call. Validation
        and I/O failures inside the update step are *not* propagated: they
        are caught by the surrounding `except Exception` and reported via
        `print("Failed to update configuration: ...")`.

    Notes
    -----
    - Variable names are validated via `validate_variable_names`.
    - Data type values are validated via `validate_data_types`.
    - A textual summary of modeling settings is printed via
      `print_data_type_modeling_setting`, if possible; a failure there is
      downgraded to a warning and does not abort the write.
    """
    is_class_call = isinstance(self, type)
    cls = self if is_class_call else self.__class__  # NOTE(review): `cls` is currently unused

    # resolve path: explicit argument > stored instance path > cwd default;
    # a class call with no path has nothing to fall back on.
    if CONF_DICT_PATH is None:
        if not is_class_call and getattr(self, "CONF_DICT_PATH", None):
            CONF_DICT_PATH = self.CONF_DICT_PATH
        elif not is_class_call:
            CONF_DICT_PATH = os.path.join(os.getcwd(), "configuration.json")
        else:
            raise ValueError("CONF_DICT_PATH must be provided when called on the class.")

    try:
        # load existing or create empty configuration
        configuration_dict = (
            load_configuration_dict(CONF_DICT_PATH)
            if os.path.exists(CONF_DICT_PATH)
            else {}
        )

        validate_variable_names(data_type.keys())
        if not validate_data_types(data_type):
            raise ValueError("Invalid data types in the provided dictionary.")

        configuration_dict["data_type"] = data_type
        write_configuration_dict(configuration_dict, CONF_DICT_PATH)

        # safe printing: a summary failure must not undo a successful write
        try:
            print_data_type_modeling_setting(data_type or {})
        except Exception as e:
            print(f"[WARNING] Could not print data type modeling settings: {e}")

        # keep the in-memory instance in sync with what was just persisted
        if not is_class_call:
            self.conf_dict = configuration_dict
            self.CONF_DICT_PATH = CONF_DICT_PATH


    except Exception as e:
        print(f"Failed to update configuration: {e}")
    else:
        print(f"Configuration updated successfully at {CONF_DICT_PATH}.")
Update or write the data_type section of a configuration file.
Supports both class-level and instance-level usage:
- Class call:
- Requires a `CONF_DICT_PATH` argument.
- Reads the file if it exists, or starts from an empty dict.
- Writes the updated configuration to `CONF_DICT_PATH`.
- Instance call:
- Uses `self.CONF_DICT_PATH` if available, otherwise defaults to `<cwd>/configuration.json` if no path is provided.
- Updates `self.conf_dict` and `self.CONF_DICT_PATH` after writing.
Parameters
data_type : dict
Mapping {variable_name: type_spec}, where type_spec encodes
modeling types (e.g. continuous, ordinal, etc.).
CONF_DICT_PATH : str or None, optional
Path to the configuration file. Must be provided for class calls.
For instance calls, defaults as described above.
Returns
None
Raises
ValueError
If CONF_DICT_PATH is missing when called on the class, or if
validation of data types fails.
Notes
- Variable names are validated via `validate_variable_names`.
- Data type values are validated via `validate_data_types`.
- A textual summary of modeling settings is printed via `print_data_type_modeling_setting`, if possible.
def set_meta_adj_matrix(self, CONF_DICT_PATH: str = None, seed: int = 5):
    """
    Launch the interactive editor to set or modify the adjacency matrix.

    This method:

    1. Resolves the configuration path from the argument or, for instances,
       from `self.CONF_DICT_PATH`.
    2. Syncs the in-memory config with disk (instances with a stored path only).
    3. Invokes `interactive_adj_matrix` to edit the adjacency matrix.
    4. For instances, reloads the updated configuration into `self.conf_dict`.

    Parameters
    ----------
    CONF_DICT_PATH : str or None, optional
        Path to the configuration file. Must be provided when called
        on the class. For instance calls, defaults to `self.CONF_DICT_PATH`.
    seed : int, optional
        Random seed for any layout or stochastic behavior in the interactive
        editor. Default is 5.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If `CONF_DICT_PATH` is not provided and cannot be inferred
        (e.g. in a class call without path).

    Notes
    -----
    Bug fix: the original implementation called `self.update()`
    unconditionally as its first statement, which broke class-level calls
    (`update` needs a bound instance) and instance calls that supplied an
    explicit `CONF_DICT_PATH` but had no stored path yet. The sync is now
    performed only when it can actually succeed.
    """
    is_class_call = isinstance(self, type)

    # resolve path: explicit argument wins, then the stored instance path
    if CONF_DICT_PATH is None:
        if not is_class_call and getattr(self, "CONF_DICT_PATH", None):
            CONF_DICT_PATH = self.CONF_DICT_PATH
        else:
            raise ValueError("CONF_DICT_PATH must be provided when called on the class.")

    # sync in-memory config with the file before launching the editor —
    # only possible on an instance that already knows its backing file
    if not is_class_call and getattr(self, "CONF_DICT_PATH", None):
        self.update()

    # launch interactive editor
    interactive_adj_matrix(CONF_DICT_PATH, seed=seed)

    # reload config if instance
    if not is_class_call:
        self.conf_dict = load_configuration_dict(CONF_DICT_PATH)
        self.CONF_DICT_PATH = CONF_DICT_PATH
Launch the interactive editor to set or modify the adjacency matrix.
This method:
1. Resolves the configuration path either from the argument or, for instances, from `self.CONF_DICT_PATH`.
2. Invokes `interactive_adj_matrix` to edit the adjacency matrix.
3. For instances, reloads the updated configuration into `self.conf_dict`.
Parameters
CONF_DICT_PATH : str or None, optional
Path to the configuration file. Must be provided when called
on the class. For instance calls, defaults to self.CONF_DICT_PATH.
seed : int, optional
Random seed for any layout or stochastic behavior in the interactive
editor. Default is 5.
Returns
None
Raises
ValueError
If CONF_DICT_PATH is not provided and cannot be inferred
(e.g. in a class call without path).
Notes
self.update() is called at the start to ensure the in-memory config
is in sync with the file before launching the editor.
def set_tramdag_nn_models(self, CONF_DICT_PATH: str = None):
    """
    Launch the interactive editor for TRAM-DAG neural-network model names.

    Behavior depends on the call context:

    - Class call: requires a `CONF_DICT_PATH` argument and modifies no
      specific instance.
    - Instance call: resolves the path from the argument or
      `self.CONF_DICT_PATH`, and adopts the editor's result into
      `self.conf_dict` / `self.CONF_DICT_PATH` when one is returned.

    Parameters
    ----------
    CONF_DICT_PATH : str or None, optional
        Path to the configuration file. Must be provided when called on
        the class. For instance calls, defaults to `self.CONF_DICT_PATH`.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If `CONF_DICT_PATH` is not provided and cannot be inferred
        (e.g. in a class call without path).

    Notes
    -----
    The editor is `interactive_nn_names_matrix`; when it returns None the
    instance configuration is left untouched.
    """
    called_on_class = isinstance(self, type)

    if CONF_DICT_PATH is None:
        stored = None if called_on_class else getattr(self, "CONF_DICT_PATH", None)
        if not stored:
            raise ValueError("CONF_DICT_PATH must be provided when called on the class.")
        CONF_DICT_PATH = stored

    new_conf = interactive_nn_names_matrix(CONF_DICT_PATH)
    if new_conf is not None and not called_on_class:
        self.conf_dict = new_conf
        self.CONF_DICT_PATH = CONF_DICT_PATH
Launch the interactive editor to set TRAM-DAG neural network model names.
Depending on call context:
- Class call:
- Requires a `CONF_DICT_PATH` argument.
- Returns nothing and does not modify a specific instance.
- Instance call:
- Resolves `CONF_DICT_PATH` from the argument or `self.CONF_DICT_PATH`.
- Updates `self.conf_dict` and `self.CONF_DICT_PATH` if the editor returns an updated configuration.
Parameters
CONF_DICT_PATH : str or None, optional
Path to the configuration file. Must be provided when called on
the class. For instance calls, defaults to self.CONF_DICT_PATH.
Returns
None
Raises
ValueError
If CONF_DICT_PATH is not provided and cannot be inferred
(e.g. in a class call without path).
Notes
The interactive editor is invoked via interactive_nn_names_matrix.
If it returns None, the instance configuration is left unchanged.
class TramDagDataset(Dataset):

    """
    Structured data preparation for TRAM-DAG models.

    Wraps a pandas DataFrame together with a `TramDagConfig` and provides
    utilities for scaling, transformation, and per-node DataLoader
    construction for every node in the DAG configuration.

    Key attributes
    --------------
    df : pandas.DataFrame
        Dataset content used for building loaders and computing scaling.
    cfg : TramDagConfig
        Configuration object defining nodes and variable metadata.
    nodes_dict : dict
        Mapping of variable names to node specifications from the configuration.
    loaders : dict
        Mapping of node names to `torch.utils.data.DataLoader` instances or
        `GenericDataset` objects, depending on `use_dataloader`.
    DEFAULTS : dict
        Default DataLoader / dataset settings, overridable per call.

    Notes
    -----
    - Intended for training data; `compute_scaling()` should be fed only
      training subsets to avoid leakage.
    - NOTE(review): the kwargs validation promised here is currently
      disabled — the `_validate_kwargs` call in `from_dataframe` is
      commented out, so unknown keyword arguments are silently accepted.

    Example
    -------
    >>> dataset = TramDagDataset.from_dataframe(train_df, cfg, batch_size=1024, debug=True)
    >>> dataset.summary()
    >>> minmax = dataset.compute_scaling(train_df)
    >>> loader = dataset.loaders["variable_x"]
    >>> next(iter(loader))
    """

    # Defaults merged with user overrides in `from_dataframe`; each of the
    # core keys may alternatively be given as a per-node dict keyed by node
    # name (see `_check_keys`).
    DEFAULTS = {
        "batch_size": 32_000,
        "shuffle": True,
        "num_workers": 4,
        "pin_memory": True,
        "return_intercept_shift": True,
        "debug": False,
        "transform": None,
        "use_dataloader": True,
        "use_precomputed": False,
        # DataLoader extras (forwarded verbatim to torch.utils.data.DataLoader)
        "sampler": None,
        "batch_sampler": None,
        "collate_fn": None,
        "drop_last": False,
        "timeout": 0,
        "worker_init_fn": None,
        "multiprocessing_context": None,
        "generator": None,
        "prefetch_factor": 2,
        "persistent_workers": True,
        "pin_memory_device": "",
    }

    def __init__(self):
        """
        Initialize an empty TramDagDataset shell.

        Notes
        -----
        This constructor attaches no data or configuration. Use
        `TramDagDataset.from_dataframe` to obtain a ready-to-use instance.
        """
        pass

    @classmethod
    def from_dataframe(cls, df, cfg, **kwargs):
        """
        Create a TramDagDataset instance directly from a pandas DataFrame.

        Merges `kwargs` over `DEFAULTS`, stores and verifies the
        configuration, applies the resolved settings, and builds per-node
        datasets/DataLoaders.

        Parameters
        ----------
        df : pandas.DataFrame
            Input DataFrame containing the dataset.
        cfg : TramDagConfig
            Configuration object defining nodes and variable metadata.
        **kwargs
            Optional overrides for `DEFAULTS` (batch_size, shuffle,
            num_workers, pin_memory, return_intercept_shift, debug,
            transform, use_dataloader, use_precomputed, and the DataLoader
            extras). NOTE(review): unknown keys are currently NOT rejected
            because the `_validate_kwargs` call below is commented out.

        Returns
        -------
        TramDagDataset
            Initialized dataset instance.

        Raises
        ------
        TypeError
            If `df` is not a pandas DataFrame.

        Notes
        -----
        If `shuffle=True` and the caller's variable name for `df` contains
        "val"/"validation"/"test", a warning is printed (best-effort
        caller-frame introspection; purely cosmetic).
        """
        self = cls()
        if not isinstance(df, pd.DataFrame):
            raise TypeError(f"[ERROR] df must be a pandas DataFrame, but got {type(df)}")

        # validate kwargs — NOTE(review): disabled, unknown kwargs pass silently
        #self._validate_kwargs(kwargs, context="from_dataframe")

        # merge defaults with overrides
        settings = dict(cls.DEFAULTS)
        settings.update(kwargs)

        # store config and verify it is complete before building anything
        self.cfg = cfg
        self.cfg._verify_completeness()

        # output all settings if debug
        if settings.get("debug", False):
            print("[DEBUG] TramDagDataset.from_dataframe() settings (after defaults + overrides):")
            for k, v in settings.items():
                print(f" {k}: {v}")

        # infer the caller's variable name for `df` (best effort) so the
        # shuffle warning below can mention something meaningful
        callers_locals = inspect.currentframe().f_back.f_locals
        inferred = None
        for var_name, var_val in callers_locals.items():
            if var_val is df:
                inferred = var_name
                break
        df_name = inferred or "dataframe"

        if settings["shuffle"]:
            if any(x in df_name.lower() for x in ["val", "validation", "test"]):
                print(f"[WARNING] DataFrame '{df_name}' looks like a validation/test set → shuffle=True. Are you sure?")

        # call again to ensure Warning messages if ordinal vars have missing levels
        self.df = df.copy()
        self._apply_settings(settings)
        self._build_dataloaders()
        return self

    def compute_scaling(self, df: pd.DataFrame = None, write: bool = True):
        """
        Compute per-variable scaling ranges from data.

        Despite the min/max naming, the bounds are the 0.05 and 0.95
        quantiles per column — robust approximations of the data range.

        Parameters
        ----------
        df : pandas.DataFrame or None, optional
            DataFrame used to compute scaling. If None, `self.df` is used.
        write : bool, optional
            Unused placeholder kept for interface compatibility.

        Returns
        -------
        dict
            Mapping `{column_name: [q05_value, q95_value]}`.

        Notes
        -----
        Only training data should be passed in, to avoid leakage.
        """
        if self.debug:
            print("[DEBUG] Make sure to provide only training data to compute_scaling!")
        if df is None:
            df = self.df
            if self.debug:
                print("[DEBUG] No DataFrame provided, using internal df.")
        quantiles = df.quantile([0.05, 0.95])
        min_vals = quantiles.loc[0.05]
        max_vals = quantiles.loc[0.95]
        # one [low, high] pair per column
        minmax_dict = pd.concat([min_vals, max_vals], axis=1).T.to_dict('list')
        return minmax_dict

    def summary(self):
        """
        Print a structured overview of the dataset and configuration.

        Covers the DataFrame (shape, columns, head, dtypes, describe), the
        configuration (node count, loader mode, precomputation), resolved
        per-node settings, and the constructed loaders. Purely `print`-based;
        returns None.
        """

        print("\n[TramDagDataset Summary]")
        print("=" * 60)

        print("\n[DataFrame]")
        print("Shape:", self.df.shape)
        print("Columns:", list(self.df.columns))
        print("\nHead:")
        print(self.df.head())

        print("\nDtypes:")
        print(self.df.dtypes)

        print("\nDescribe:")
        print(self.df.describe(include="all"))

        print("\n[Configuration]")
        print(f"Nodes: {len(self.nodes_dict)}")
        print(f"Loader mode: {'DataLoader' if self.use_dataloader else 'Direct dataset'}")
        print(f"Precomputed: {getattr(self, 'use_precomputed', False)}")

        print("\n[Node Settings]")
        for node in self.nodes_dict.keys():
            # every core setting may be a scalar or a per-node dict
            batch_size = self.batch_size[node] if isinstance(self.batch_size, dict) else self.batch_size
            shuffle_flag = self.shuffle[node] if isinstance(self.shuffle, dict) else self.shuffle
            num_workers = self.num_workers[node] if isinstance(self.num_workers, dict) else self.num_workers
            pin_memory = self.pin_memory[node] if isinstance(self.pin_memory, dict) else self.pin_memory
            rshift = self.return_intercept_shift[node] if isinstance(self.return_intercept_shift, dict) else self.return_intercept_shift
            debug_flag = self.debug[node] if isinstance(self.debug, dict) else self.debug
            transform = self.transform[node] if isinstance(self.transform, dict) else self.transform

            print(
                f" Node '{node}': "
                f"batch_size={batch_size}, "
                f"shuffle={shuffle_flag}, "
                f"num_workers={num_workers}, "
                f"pin_memory={pin_memory}, "
                f"return_intercept_shift={rshift}, "
                f"debug={debug_flag}, "
                f"transform={transform}"
            )

        if hasattr(self, "loaders"):
            print("\n[DataLoaders]")
            for node, loader in self.loaders.items():
                try:
                    length = len(loader)
                except Exception:
                    # some datasets/samplers legitimately have no __len__
                    length = "?"
                print(f" {node}: {type(loader).__name__}, len={length}")

        print("=" * 60 + "\n")

    def _validate_kwargs(self, kwargs: dict, defaults_attr: str = "DEFAULTS", context: str = None):
        """
        Validate keyword arguments against a defaults dictionary.

        Parameters
        ----------
        kwargs : dict
            Keyword arguments to validate.
        defaults_attr : str, optional
            Name of the attribute on this class that holds the allowed keys.
            Default is "DEFAULTS".
        context : str or None, optional
            Context label (e.g. caller name) included in error messages.

        Raises
        ------
        AttributeError
            If the attribute named by `defaults_attr` does not exist.
        ValueError
            If any key in `kwargs` is not present in the defaults dictionary.
        """
        defaults = getattr(self, defaults_attr, None)
        if defaults is None:
            raise AttributeError(f"{self.__class__.__name__} has no attribute '{defaults_attr}'")

        unknown = set(kwargs) - set(defaults)
        if unknown:
            prefix = f"[{context}] " if context else ""
            raise ValueError(f"{prefix}Unknown parameter(s): {', '.join(sorted(unknown))}")

    def _apply_settings(self, settings: dict):
        """
        Apply resolved settings to the dataset instance.

        Stores every key/value from `settings` as an attribute, extracts
        `nodes_dict` from the configuration, and validates that dict-valued
        core settings are keyed exactly by the node names.

        Parameters
        ----------
        settings : dict
            Resolved settings (DEFAULTS plus user overrides).

        Raises
        ------
        ValueError
            If a dict-valued core setting's keys do not match
            `cfg.conf_dict["nodes"].keys()` (raised by `_check_keys`).
        """
        for k, v in settings.items():
            setattr(self, k, v)

        self.nodes_dict = self.cfg.conf_dict["nodes"]

        # validate only the most important ones
        for name in ["batch_size", "shuffle", "num_workers", "pin_memory",
                     "return_intercept_shift", "debug", "transform"]:
            self._check_keys(name, getattr(self, name))

    def _build_dataloaders(self):
        """Build node-specific dataloaders or raw datasets depending on settings."""
        self.loaders = {}
        for node in self.nodes_dict:
            ds = GenericDataset(
                self.df,
                target_col=node,
                target_nodes=self.nodes_dict,
                transform=self.transform if not isinstance(self.transform, dict) else self.transform[node],
                return_intercept_shift=self.return_intercept_shift if not isinstance(self.return_intercept_shift, dict) else self.return_intercept_shift[node],
                debug=self.debug if not isinstance(self.debug, dict) else self.debug[node],
            )

            ########## QUICK PATCH
            # NOTE(review): all nodes share the single path temp/precomputed.pt,
            # each iteration overwriting the previous node's file — works only
            # if GenericDatasetPrecomputed loads everything into memory at
            # construction time; confirm before relying on it.
            if hasattr(self, "use_precomputed") and self.use_precomputed:
                os.makedirs("temp", exist_ok=True)
                pth = os.path.join("temp", "precomputed.pt")

                if hasattr(ds, "save_precomputed") and callable(getattr(ds, "save_precomputed")):
                    ds.save_precomputed(pth)
                    ds = GenericDatasetPrecomputed(pth)
                else:
                    print("[WARNING] Dataset has no 'save_precomputed()' method — skipping precomputation.")


            if self.use_dataloader:
                # resolve per-node overrides (scalar or per-node dict)
                kwargs = {
                    "batch_size": self.batch_size[node] if isinstance(self.batch_size, dict) else self.batch_size,
                    "shuffle": self.shuffle[node] if isinstance(self.shuffle, dict) else self.shuffle,
                    "num_workers": self.num_workers[node] if isinstance(self.num_workers, dict) else self.num_workers,
                    "pin_memory": self.pin_memory[node] if isinstance(self.pin_memory, dict) else self.pin_memory,
                    "sampler": self.sampler,
                    "batch_sampler": self.batch_sampler,
                    "collate_fn": self.collate_fn,
                    "drop_last": self.drop_last,
                    "timeout": self.timeout,
                    "worker_init_fn": self.worker_init_fn,
                    "multiprocessing_context": self.multiprocessing_context,
                    "generator": self.generator,
                    "prefetch_factor": self.prefetch_factor,
                    "persistent_workers": self.persistent_workers,
                    "pin_memory_device": self.pin_memory_device,
                }
                self.loaders[node] = DataLoader(ds, **kwargs)
            else:
                self.loaders[node] = ds

        # clean up the shared temp file once all loaders are built
        if hasattr(self, "use_precomputed") and self.use_precomputed:
            if os.path.exists(pth):
                try:
                    os.remove(pth)
                    if self.debug:
                        print(f"[INFO] Removed existing precomputed file: {pth}")
                except Exception as e:
                    print(f"[WARNING] Could not remove {pth}: {e}")

    def _check_keys(self, attr_name, attr_value):
        """
        Check that dict-valued attributes use node names as keys.

        Parameters
        ----------
        attr_name : str
            Name of the attribute being checked (for error messages).
        attr_value : Any
            Attribute value. If it is a dict, its keys are validated;
            scalars pass through untouched.

        Raises
        ------
        ValueError
            If `attr_value` is a dict and its keys do not exactly match
            `cfg.conf_dict["nodes"].keys()`.

        Notes
        -----
        Prevents partial or mismatched per-node settings such as batch
        sizes or shuffle flags.
        """
        if isinstance(attr_value, dict):
            expected_keys = set(self.nodes_dict.keys())
            given_keys = set(attr_value.keys())
            if expected_keys != given_keys:
                raise ValueError(
                    f"[ERROR] the provided attribute '{attr_name}' keys are not same as in cfg.conf_dict['nodes'].keys().\n"
                    f"Expected: {expected_keys}, but got: {given_keys}\n"
                    f"Please provide values for all variables."
                )

    def __getitem__(self, idx):
        # row-as-dict access for the plain Dataset interface
        return self.df.iloc[idx].to_dict()

    def __len__(self):
        return len(self.df)
TramDagDataset
The TramDagDataset class handles structured data preparation for TRAM-DAG
models. It wraps a pandas DataFrame together with its configuration and provides
utilities for scaling, transformation, and efficient DataLoader construction
for each node in a DAG-based configuration.
Core Responsibilities
- Validate and store configuration metadata (`TramDagConfig`).
- Manage per-node settings for DataLoader creation (batch size, shuffling, workers).
- Compute scaling information (quantile-based min/max).
- Optionally precompute and cache dataset representations.
- Expose PyTorch Dataset and DataLoader interfaces for model training.
Key Attributes
df : pandas.DataFrame
    The dataset content used for building loaders and computing scaling.
cfg : TramDagConfig
    Configuration object defining nodes and variable metadata.
nodes_dict : dict
    Mapping of variable names to node specifications from the configuration.
loaders : dict
    Mapping of node names to `torch.utils.data.DataLoader` instances or `GenericDataset` objects.
DEFAULTS : dict
    Default DataLoader and dataset-related settings (e.g., batch_size, shuffle, num_workers, etc.).
Main Methods
from_dataframe(df, cfg, **kwargs)
    Construct the dataset directly from a pandas DataFrame.
compute_scaling(df=None, write=True)
    Compute per-variable min/max scaling values from data.
summary()
    Print dataset overview including shape, dtypes, statistics, and node settings.
Notes
- Intended for training data; compute_scaling() should use only training subsets.
- Compatible with both CPU and GPU DataLoader options.
- Strict validation of keyword arguments against DEFAULTS prevents silent misconfiguration.
Example
>>> cfg = TramDagConfig.load_json("config.json")
>>> dataset = TramDagDataset.from_dataframe(train_df, cfg, batch_size=1024, debug=True)
>>> dataset.summary()
>>> minmax = dataset.compute_scaling(train_df)
>>> loader = dataset.loaders["variable_x"]
>>> next(iter(loader))
def __init__(self):
    """
    Create a bare TramDagDataset without data or configuration attached.

    Notes
    -----
    Nothing is initialized here; build usable instances through the
    `TramDagDataset.from_dataframe` classmethod instead.
    """
Initialize an empty TramDagDataset shell.
Notes
This constructor does not attach data or configuration. Use
TramDagDataset.from_dataframe to obtain a ready-to-use instance.
@classmethod
def from_dataframe(cls, df, cfg, **kwargs):
    """
    Build a TramDagDataset directly from a pandas DataFrame.

    Merges `kwargs` over `cls.DEFAULTS` into a resolved settings dict,
    stores and verifies the configuration, then constructs the per-node
    datasets and DataLoaders.

    Parameters
    ----------
    df : pandas.DataFrame
        Input DataFrame containing the dataset.
    cfg : TramDagConfig
        Configuration object defining nodes and variable metadata.
    **kwargs
        Optional overrides for `DEFAULTS` (e.g. batch_size, shuffle,
        num_workers, pin_memory, return_intercept_shift, debug,
        transform, use_dataloader, use_precomputed).

    Returns
    -------
    TramDagDataset
        Initialized dataset instance.

    Raises
    ------
    TypeError
        If `df` is not a pandas DataFrame.

    Notes
    -----
    When `shuffle=True` and the caller's variable name for `df` looks
    like validation/test data (contains "val" or "test"), a warning is
    printed.
    """
    if not isinstance(df, pd.DataFrame):
        raise TypeError(f"[ERROR] df must be a pandas DataFrame, but got {type(df)}")

    inst = cls()

    # NOTE(review): kwargs validation stays disabled here — callers such
    # as TramDagModel.fit() forward settings keys that are not part of
    # TramDagDataset.DEFAULTS, so enabling it would break them.
    # inst._validate_kwargs(kwargs, context="from_dataframe")

    # resolved settings: class defaults overridden by caller kwargs
    settings = {**cls.DEFAULTS, **kwargs}

    inst.cfg = cfg
    inst.cfg._verify_completeness()

    # dump the resolved settings when debugging
    if settings.get("debug", False):
        print("[DEBUG] TramDagDataset.from_dataframe() settings (after defaults + overrides):")
        for k, v in settings.items():
            print(f" {k}: {v}")

    # Infer the caller's variable name for df — used only to warn about
    # shuffling what looks like a validation/test split. Must stay inline
    # so that f_back refers to the caller's frame.
    caller_locals = inspect.currentframe().f_back.f_locals
    inferred = next((name for name, value in caller_locals.items() if value is df), None)
    df_name = inferred or "dataframe"

    if settings["shuffle"] and any(tag in df_name.lower() for tag in ["val", "validation", "test"]):
        print(f"[WARNING] DataFrame '{df_name}' looks like a validation/test set → shuffle=True. Are you sure?")

    inst.df = df.copy()
    inst._apply_settings(settings)
    inst._build_dataloaders()
    return inst
Create a TramDagDataset instance directly from a pandas DataFrame.
This classmethod:
- Validates keyword arguments against DEFAULTS.
- Merges user overrides with defaults into a resolved settings dict.
- Stores the configuration and verifies its completeness.
- Applies settings to the instance.
- Builds per-node datasets and DataLoaders.
Parameters
df : pandas.DataFrame
Input DataFrame containing the dataset.
cfg : TramDagConfig
Configuration object defining nodes and variable metadata.
**kwargs
Optional overrides for DEFAULTS. All keys must exist in
TramDagDataset.DEFAULTS. Common keys include:
batch_size : int
Batch size for DataLoaders.
shuffle : bool
Whether to shuffle samples per epoch.
num_workers : int
Number of DataLoader workers.
pin_memory : bool
Whether to pin memory for faster host-to-device transfers.
return_intercept_shift : bool
Whether datasets should return intercept/shift information.
debug : bool
Enable debug printing.
transform : callable or dict or None
Optional transform(s) applied to samples.
use_dataloader : bool
If True, construct DataLoaders; else store raw Dataset objects.
use_precomputed : bool
If True, precompute dataset representation to disk and reload it.
Returns
TramDagDataset Initialized dataset instance.
Raises
TypeError
If df is not a pandas DataFrame.
ValueError
If unknown keyword arguments are provided (when validation is enabled).
Notes
If shuffle=True and the inferred variable name of df suggests
validation/test data (e.g. "val", "test"), a warning is printed.
def compute_scaling(self, df: pd.DataFrame = None, write: bool = True):
    """
    Derive per-variable scaling bounds from data.

    For every column, the 5th and 95th percentiles are used as robust
    minimum/maximum values, typically for normalization or clipping
    ranges derived from training data.

    Parameters
    ----------
    df : pandas.DataFrame or None, optional
        Data to compute scaling from. Falls back to `self.df` when None.
    write : bool, optional
        Unused placeholder kept for interface compatibility with other
        components. Default is True.

    Returns
    -------
    dict
        Mapping `{column_name: [q05_value, q95_value]}` per column.

    Notes
    -----
    Supply training data only, to avoid leakage from validation/test
    subsets. Debug messages are printed when `self.debug` is True.
    """
    if self.debug:
        print("[DEBUG] Make sure to provide only training data to compute_scaling!")
    if df is None:
        if self.debug:
            print("[DEBUG] No DataFrame provided, using internal df.")
        df = self.df
    # One row per requested quantile, one column per variable;
    # orient='list' yields {column: [q05, q95]} with native Python floats.
    limits = df.quantile([0.05, 0.95])
    return limits.to_dict('list')
Compute variable-wise scaling parameters from data.
Per variable, this method computes approximate minimum and maximum values using the 5th and 95th percentiles. This is typically used to derive robust normalization/clipping ranges from training data.
Parameters
df : pandas.DataFrame or None, optional
DataFrame used to compute scaling. If None, self.df is used.
write : bool, optional
Unused placeholder for interface compatibility with other components.
Kept for potential future extensions. Default is True.
Returns
dict
Mapping {column_name: [min_value, max_value]}, where values
are derived from the 0.05 and 0.95 quantiles.
Notes
If self.debug is True, the method emits debug messages about the
data source. Only training data should be used to avoid leakage.
def summary(self):
    """
    Print a human-readable overview of the dataset and configuration.

    Sections emitted:

    1. DataFrame information: shape, columns, head, dtypes, describe.
    2. Configuration overview: node count, loader mode, precompute flag.
    3. Per-node settings: batch_size, shuffle, num_workers, pin_memory,
       return_intercept_shift, debug, transform.
    4. DataLoader overview: type and length of each loader (if built).

    Returns
    -------
    None

    Notes
    -----
    Purely diagnostic — everything goes through `print`; no state is
    modified and no structured metadata is returned.
    """
    # Settings may be scalars (shared) or dicts keyed by node name.
    def per_node(value, node):
        return value[node] if isinstance(value, dict) else value

    divider = "=" * 60

    print("\n[TramDagDataset Summary]")
    print(divider)

    print("\n[DataFrame]")
    print("Shape:", self.df.shape)
    print("Columns:", list(self.df.columns))
    print("\nHead:")
    print(self.df.head())

    print("\nDtypes:")
    print(self.df.dtypes)

    print("\nDescribe:")
    print(self.df.describe(include="all"))

    print("\n[Configuration]")
    print(f"Nodes: {len(self.nodes_dict)}")
    print(f"Loader mode: {'DataLoader' if self.use_dataloader else 'Direct dataset'}")
    print(f"Precomputed: {getattr(self, 'use_precomputed', False)}")

    print("\n[Node Settings]")
    for node in self.nodes_dict.keys():
        print(
            f" Node '{node}': "
            f"batch_size={per_node(self.batch_size, node)}, "
            f"shuffle={per_node(self.shuffle, node)}, "
            f"num_workers={per_node(self.num_workers, node)}, "
            f"pin_memory={per_node(self.pin_memory, node)}, "
            f"return_intercept_shift={per_node(self.return_intercept_shift, node)}, "
            f"debug={per_node(self.debug, node)}, "
            f"transform={per_node(self.transform, node)}"
        )

    # Loaders exist only after _build_dataloaders() has run.
    if hasattr(self, "loaders"):
        print("\n[DataLoaders]")
        for node, loader in self.loaders.items():
            try:
                length = len(loader)
            except Exception:
                # some loader/dataset types do not define __len__
                length = "?"
            print(f" {node}: {type(loader).__name__}, len={length}")

    print(divider + "\n")
Print a structured overview of the dataset and configuration.
The summary includes:
- DataFrame information:
- Shape
- Columns
- Head (first rows)
- Dtypes
- Descriptive statistics
- Configuration overview:
- Number of nodes
- Loader mode (DataLoader vs. raw Dataset)
- Precomputation status
- Node settings:
- Batch size
- Shuffle flag
- num_workers
- pin_memory
- return_intercept_shift
- debug
- transform
- DataLoader overview:
- Type and length of each loader.
Parameters
None
Returns
None
Notes
Intended for quick inspection and debugging. Uses print statements
and does not return structured metadata.
64class TramDagModel: 65 """ 66 Probabilistic DAG model built from node-wise TRAMs (transformation models). 67 68 This class manages: 69 - Configuration and per-node model construction. 70 - Data scaling (min–max). 71 - Training (sequential or per-node parallel on CPU). 72 - Diagnostics (loss history, intercepts, linear shifts, latents). 73 - Sampling from the joint DAG and loading stored samples. 74 - High-level summaries and plotting utilities. 75 """ 76 77 # ---- defaults used at construction time ---- 78 DEFAULTS_CONFIG = { 79 "set_initial_weights": False, 80 "debug":False, 81 "verbose": False, 82 "device":'auto', 83 "initial_data":None, 84 "overwrite_initial_weights": True, 85 } 86 87 # ---- defaults used at fit() time ---- 88 DEFAULTS_FIT = { 89 "epochs": 100, 90 "train_list": None, 91 "callbacks": None, 92 "learning_rate": 0.01, 93 "device": "auto", 94 "optimizers": None, 95 "schedulers": None, 96 "use_scheduler": False, 97 "save_linear_shifts": True, 98 "save_simple_intercepts": True, 99 "debug":False, 100 "verbose": True, 101 "train_mode": "sequential", # or "parallel" 102 "return_history": False, 103 "overwrite_inital_weights": True, 104 "num_workers" : 4, 105 "persistent_workers" : True, 106 "prefetch_factor" : 4, 107 "batch_size":1000, 108 109 } 110 111 def __init__(self): 112 """ 113 Initialize an empty TramDagModel shell. 114 115 Notes 116 ----- 117 This constructor does not build any node models and does not attach a 118 configuration. Use `TramDagModel.from_config` or `TramDagModel.from_directory` 119 to obtain a fully configured and ready-to-use instance. 120 """ 121 122 self.debug = False 123 self.verbose = False 124 self.device = 'auto' 125 pass 126 127 @staticmethod 128 def get_device(settings): 129 """ 130 Resolve the target device string from a settings dictionary. 131 132 Parameters 133 ---------- 134 settings : dict 135 Dictionary containing at least a key ``"device"`` with one of 136 {"auto", "cpu", "cuda"}. If missing, "auto" is assumed. 
137 138 Returns 139 ------- 140 str 141 Device string, either "cpu" or "cuda". 142 143 Notes 144 ----- 145 If ``device == "auto"``, CUDA is selected if available, otherwise CPU. 146 """ 147 device_arg = settings.get("device", "auto") 148 if device_arg == "auto": 149 device_str = "cuda" if torch.cuda.is_available() else "cpu" 150 else: 151 device_str = device_arg 152 return device_str 153 154 def _validate_kwargs(self, kwargs: dict, defaults_attr: str = "DEFAULTS_FIT", context: str = None): 155 """ 156 Validate a kwargs dictionary against a class-level defaults dictionary. 157 158 Parameters 159 ---------- 160 kwargs : dict 161 Keyword arguments to validate. 162 defaults_attr : str, optional 163 Name of the attribute on this class that contains the allowed keys, 164 e.g. ``"DEFAULTS_CONFIG"`` or ``"DEFAULTS_FIT"``. Default is "DEFAULTS_FIT". 165 context : str or None, optional 166 Optional label (e.g. caller name) to prepend in error messages. 167 168 Raises 169 ------ 170 AttributeError 171 If the attribute named by ``defaults_attr`` does not exist. 172 ValueError 173 If any key in ``kwargs`` is not present in the corresponding defaults dict. 174 """ 175 defaults = getattr(self, defaults_attr, None) 176 if defaults is None: 177 raise AttributeError(f"{self.__class__.__name__} has no attribute '{defaults_attr}'") 178 179 unknown = set(kwargs) - set(defaults) 180 if unknown: 181 prefix = f"[{context}] " if context else "" 182 raise ValueError(f"{prefix}Unknown parameter(s): {', '.join(sorted(unknown))}") 183 184 ## CREATE A TRAMDADMODEL 185 @classmethod 186 def from_config(cls, cfg, **kwargs): 187 """ 188 Construct a TramDagModel from a TramDagConfig object. 189 190 This builds one TRAM model per node in the DAG and optionally writes 191 the initial model parameters to disk. 
192 193 Parameters 194 ---------- 195 cfg : TramDagConfig 196 Configuration wrapper holding the underlying configuration dictionary, 197 including at least: 198 - ``conf_dict["nodes"]``: mapping of node names to node configs. 199 - ``conf_dict["PATHS"]["EXPERIMENT_DIR"]``: experiment directory. 200 **kwargs 201 Node-level construction options. Each key must be present in 202 ``DEFAULTS_CONFIG``. Values can be: 203 - scalar: applied to all nodes. 204 - dict: mapping ``{node_name: value}`` for per-node overrides. 205 206 Common keys include: 207 device : {"auto", "cpu", "cuda"}, default "auto" 208 Device selection (CUDA if available when "auto"). 209 debug : bool, default False 210 If True, print debug messages. 211 verbose : bool, default False 212 If True, print informational messages. 213 set_initial_weights : bool 214 Passed to underlying TRAM model constructors. 215 overwrite_initial_weights : bool, default True 216 If True, overwrite any existing ``initial_model.pt`` files per node. 217 initial_data : Any 218 Optional object passed down to node constructors. 219 220 Returns 221 ------- 222 TramDagModel 223 Fully initialized instance with: 224 - ``cfg`` 225 - ``nodes_dict`` 226 - ``models`` (per-node TRAMs) 227 - ``settings`` (resolved per-node config) 228 229 Raises 230 ------ 231 ValueError 232 If any dict-valued kwarg does not provide values for exactly the set 233 of nodes in ``cfg.conf_dict["nodes"]``. 
234 """ 235 236 self = cls() 237 self.cfg = cfg 238 self.cfg.update() # ensure latest version from disk 239 self.cfg._verify_completeness() 240 241 242 try: 243 self.cfg.save() # persist back to disk 244 if getattr(self, "debug", False): 245 print("[DEBUG] Configuration updated and saved.") 246 except Exception as e: 247 print(f"[WARNING] Could not save configuration after update: {e}") 248 249 self.nodes_dict = self.cfg.conf_dict["nodes"] 250 251 self._validate_kwargs(kwargs, defaults_attr='DEFAULTS_CONFIG', context="from_config") 252 253 # update defaults with kwargs 254 settings = dict(cls.DEFAULTS_CONFIG) 255 settings.update(kwargs) 256 257 # resolve device 258 device_arg = settings.get("device", "auto") 259 if device_arg == "auto": 260 device_str = "cuda" if torch.cuda.is_available() else "cpu" 261 else: 262 device_str = device_arg 263 self.device = torch.device(device_str) 264 265 # set flags on the instance so they are accessible later 266 self.debug = settings.get("debug", False) 267 self.verbose = settings.get("verbose", False) 268 269 if self.debug: 270 print(f"[DEBUG] TramDagModel using device: {self.device}") 271 272 # initialize settings storage 273 self.settings = {k: {} for k in settings.keys()} 274 275 # validate dict-typed args 276 for k, v in settings.items(): 277 if isinstance(v, dict): 278 expected = set(self.nodes_dict.keys()) 279 given = set(v.keys()) 280 if expected != given: 281 raise ValueError( 282 f"[ERROR] the provided argument '{k}' keys are not same as in cfg.conf_dict['nodes'].keys().\n" 283 f"Expected: {expected}, but got: {given}\n" 284 f"Please provide values for all variables.") 285 286 # build one model per node 287 self.models = {} 288 for node in self.nodes_dict.keys(): 289 per_node_kwargs = {} 290 for k, v in settings.items(): 291 resolved = v[node] if isinstance(v, dict) else v 292 per_node_kwargs[k] = resolved 293 self.settings[k][node] = resolved 294 if self.debug: 295 print(f"\n[INFO] Building model for node '{node}' with 
settings: {per_node_kwargs}") 296 self.models[node] = get_fully_specified_tram_model( 297 node=node, 298 configuration_dict=self.cfg.conf_dict, 299 **per_node_kwargs) 300 301 try: 302 EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"] 303 NODE_DIR = os.path.join(EXPERIMENT_DIR, f"{node}") 304 os.makedirs(NODE_DIR, exist_ok=True) 305 306 model_path = os.path.join(NODE_DIR, "initial_model.pt") 307 overwrite = settings.get("overwrite_initial_weights", True) 308 309 if overwrite or not os.path.exists(model_path): 310 torch.save(self.models[node].state_dict(), model_path) 311 if self.debug: 312 print(f"[DEBUG] Saved initial model state for node '{node}' to {model_path} (overwrite={overwrite})") 313 else: 314 if self.debug: 315 print(f"[DEBUG] Skipped saving initial model for node '{node}' (already exists at {model_path})") 316 except Exception as e: 317 print(f"[ERROR] Could not save initial model state for node '{node}': {e}") 318 319 TEMP_DIR = "temp" 320 if os.path.isdir(TEMP_DIR) and not os.listdir(TEMP_DIR): 321 os.rmdir(TEMP_DIR) 322 323 return self 324 325 @classmethod 326 def from_directory(cls, EXPERIMENT_DIR: str, device: str = "auto", debug: bool = False, verbose: bool = False): 327 """ 328 Reconstruct a TramDagModel from an experiment directory on disk. 329 330 This method: 331 1. Loads the configuration JSON. 332 2. Wraps it in a TramDagConfig. 333 3. Builds all node models via `from_config`. 334 4. Loads the min–max scaling dictionary. 335 336 Parameters 337 ---------- 338 EXPERIMENT_DIR : str 339 Path to an experiment directory containing: 340 - ``configuration.json`` 341 - ``min_max_scaling.json``. 342 device : {"auto", "cpu", "cuda"}, optional 343 Device selection. Default is "auto". 344 debug : bool, optional 345 If True, enable debug messages. Default is False. 346 verbose : bool, optional 347 If True, enable informational messages. Default is False. 
348 349 Returns 350 ------- 351 TramDagModel 352 A TramDagModel instance with models, config, and scaling loaded. 353 354 Raises 355 ------ 356 FileNotFoundError 357 If configuration or min–max files cannot be found. 358 RuntimeError 359 If the min–max file cannot be read or parsed. 360 """ 361 362 # --- load config file --- 363 config_path = os.path.join(EXPERIMENT_DIR, "configuration.json") 364 if not os.path.exists(config_path): 365 raise FileNotFoundError(f"[ERROR] Config file not found at {config_path}") 366 367 with open(config_path, "r") as f: 368 cfg_dict = json.load(f) 369 370 # Create TramConfig wrapper 371 cfg = TramDagConfig(cfg_dict, CONF_DICT_PATH=config_path) 372 373 # --- build model from config --- 374 self = cls.from_config(cfg, device=device, debug=debug, verbose=verbose, overwrite_initial_weights=False) 375 376 # --- load minmax scaling --- 377 minmax_path = os.path.join(EXPERIMENT_DIR, "min_max_scaling.json") 378 if not os.path.exists(minmax_path): 379 raise FileNotFoundError(f"[ERROR] MinMax file not found at {minmax_path}") 380 381 with open(minmax_path, "r") as f: 382 self.minmax_dict = json.load(f) 383 384 if self.verbose or self.debug: 385 print(f"[INFO] Loaded TramDagModel from {EXPERIMENT_DIR}") 386 print(f"[INFO] Config loaded from {config_path}") 387 print(f"[INFO] MinMax scaling loaded from {minmax_path}") 388 389 return self 390 391 def _ensure_dataset(self, data, is_val=False,**kwargs): 392 """ 393 Ensure that the input data is represented as a TramDagDataset. 394 395 Parameters 396 ---------- 397 data : pandas.DataFrame, TramDagDataset, or None 398 Input data to be converted or passed through. 399 is_val : bool, optional 400 If True, the resulting dataset is treated as validation data 401 (e.g. no shuffling). Default is False. 402 **kwargs 403 Additional keyword arguments passed through to 404 ``TramDagDataset.from_dataframe``. 
405 406 Returns 407 ------- 408 TramDagDataset or None 409 A TramDagDataset if ``data`` is a DataFrame or TramDagDataset, 410 otherwise None if ``data`` is None. 411 412 Raises 413 ------ 414 TypeError 415 If ``data`` is not a DataFrame, TramDagDataset, or None. 416 """ 417 418 if isinstance(data, pd.DataFrame): 419 return TramDagDataset.from_dataframe(data, self.cfg, shuffle=not is_val,**kwargs) 420 elif isinstance(data, TramDagDataset): 421 return data 422 elif data is None: 423 return None 424 else: 425 raise TypeError( 426 f"[ERROR] data must be pd.DataFrame, TramDagDataset, or None, got {type(data)}" 427 ) 428 429 def load_or_compute_minmax(self, td_train_data=None,use_existing=False, write=True): 430 """ 431 Load an existing Min–Max scaling dictionary from disk or compute a new one 432 from the provided training dataset. 433 434 Parameters 435 ---------- 436 use_existing : bool, optional (default=False) 437 If True, attempts to load an existing `min_max_scaling.json` file 438 from the experiment directory. Raises an error if the file is missing 439 or unreadable. 440 441 write : bool, optional (default=True) 442 If True, writes the computed Min–Max scaling dictionary to 443 `<EXPERIMENT_DIR>/min_max_scaling.json`. 444 445 td_train_data : object, optional 446 Training dataset used to compute scaling statistics. If not provided, 447 the method will ensure or construct it via `_ensure_dataset(data=..., is_val=False)`. 448 449 Behavior 450 -------- 451 - If `use_existing=True`, loads the JSON file containing previously saved 452 min–max values and stores it in `self.minmax_dict`. 453 - If `use_existing=False`, computes a new scaling dictionary using 454 `td_train_data.compute_scaling()` and stores the result in 455 `self.minmax_dict`. 456 - Optionally writes the computed dictionary to disk. 457 458 Side Effects 459 ------------- 460 - Populates `self.minmax_dict` with scaling values. 
461 - Writes or loads the file `min_max_scaling.json` under 462 `<EXPERIMENT_DIR>`. 463 - Prints diagnostic output if `self.debug` or `self.verbose` is True. 464 465 Raises 466 ------ 467 FileNotFoundError 468 If `use_existing=True` but the min–max file does not exist. 469 470 RuntimeError 471 If an existing min–max file cannot be read or parsed. 472 473 Notes 474 ----- 475 The computed min–max dictionary is expected to contain scaling statistics 476 per feature, typically in the form: 477 { 478 "node": {"min": float, "max": float}, 479 ... 480 } 481 """ 482 EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"] 483 minmax_path = os.path.join(EXPERIMENT_DIR, "min_max_scaling.json") 484 485 # laod exisitng if possible 486 if use_existing: 487 if not os.path.exists(minmax_path): 488 raise FileNotFoundError(f"MinMax file not found: {minmax_path}") 489 try: 490 with open(minmax_path, 'r') as f: 491 self.minmax_dict = json.load(f) 492 if self.debug or self.verbose: 493 print(f"[INFO] Loaded existing minmax dict from {minmax_path}") 494 return 495 except Exception as e: 496 raise RuntimeError(f"Could not load existing minmax dict: {e}") 497 498 # 499 if self.debug or self.verbose: 500 print("[INFO] Computing new minmax dict from training data...") 501 502 td_train_data=self._ensure_dataset( data=td_train_data, is_val=False) 503 504 self.minmax_dict = td_train_data.compute_scaling() 505 506 if write: 507 os.makedirs(EXPERIMENT_DIR, exist_ok=True) 508 with open(minmax_path, 'w') as f: 509 json.dump(self.minmax_dict, f, indent=4) 510 if self.debug or self.verbose: 511 print(f"[INFO] Saved new minmax dict to {minmax_path}") 512 513 ## FIT METHODS 514 @staticmethod 515 def _fit_single_node(node, self_ref, settings, td_train_data, td_val_data, device_str): 516 """ 517 Train a single node model (helper for per-node training). 
518 519 This method is designed to be called either from the main process 520 (sequential training) or from a joblib worker (parallel CPU training). 521 522 Parameters 523 ---------- 524 node : str 525 Name of the target node to train. 526 self_ref : TramDagModel 527 Reference to the TramDagModel instance containing models and config. 528 settings : dict 529 Training settings dictionary, typically derived from ``DEFAULTS_FIT`` 530 plus any user overrides. 531 td_train_data : TramDagDataset 532 Training dataset with node-specific DataLoaders in ``.loaders``. 533 td_val_data : TramDagDataset or None 534 Validation dataset or None. 535 device_str : str 536 Device string, e.g. "cpu" or "cuda". 537 538 Returns 539 ------- 540 tuple 541 A tuple ``(node, history)`` where: 542 node : str 543 Node name. 544 history : dict or Any 545 Training history as returned by ``train_val_loop``. 546 """ 547 torch.set_num_threads(1) # prevent thread oversubscription 548 549 model = self_ref.models[node] 550 551 # Resolve per-node settings 552 def _resolve(key): 553 val = settings[key] 554 return val[node] if isinstance(val, dict) else val 555 556 node_epochs = _resolve("epochs") 557 node_lr = _resolve("learning_rate") 558 node_debug = _resolve("debug") 559 node_save_linear_shifts = _resolve("save_linear_shifts") 560 save_simple_intercepts = _resolve("save_simple_intercepts") 561 node_verbose = _resolve("verbose") 562 563 # Optimizer & scheduler 564 if settings["optimizers"] and node in settings["optimizers"]: 565 optimizer = settings["optimizers"][node] 566 else: 567 optimizer = Adam(model.parameters(), lr=node_lr) 568 569 scheduler = settings["schedulers"].get(node, None) if settings["schedulers"] else None 570 571 # Data loaders 572 train_loader = td_train_data.loaders[node] 573 val_loader = td_val_data.loaders[node] if td_val_data else None 574 575 # Min-max scaling tensors 576 min_vals = torch.tensor(self_ref.minmax_dict[node][0], dtype=torch.float32) 577 max_vals = 
torch.tensor(self_ref.minmax_dict[node][1], dtype=torch.float32) 578 min_max = torch.stack([min_vals, max_vals], dim=0) 579 580 # Node directory 581 try: 582 EXPERIMENT_DIR = self_ref.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"] 583 NODE_DIR = os.path.join(EXPERIMENT_DIR, f"{node}") 584 except Exception: 585 NODE_DIR = os.path.join("models", node) 586 print("[WARNING] No log directory specified in config, saving to default location.") 587 os.makedirs(NODE_DIR, exist_ok=True) 588 589 if node_verbose: 590 print(f"\n[INFO] Training node '{node}' for {node_epochs} epochs on {device_str} (pid={os.getpid()})") 591 592 # --- train --- 593 history = train_val_loop( 594 node=node, 595 target_nodes=self_ref.nodes_dict, 596 NODE_DIR=NODE_DIR, 597 tram_model=model, 598 train_loader=train_loader, 599 val_loader=val_loader, 600 epochs=node_epochs, 601 optimizer=optimizer, 602 use_scheduler=(scheduler is not None), 603 scheduler=scheduler, 604 save_linear_shifts=node_save_linear_shifts, 605 save_simple_intercepts=save_simple_intercepts, 606 verbose=node_verbose, 607 device=torch.device(device_str), 608 debug=node_debug, 609 min_max=min_max) 610 return node, history 611 612 def fit(self, train_data, val_data=None, **kwargs): 613 """ 614 Train TRAM models for all nodes in the DAG. 615 616 Coordinates dataset preparation, min–max scaling, and per-node training, 617 optionally in parallel on CPU. 618 619 Parameters 620 ---------- 621 train_data : pandas.DataFrame or TramDagDataset 622 Training data. If a DataFrame is given, it is converted into a 623 TramDagDataset using `_ensure_dataset`. 624 val_data : pandas.DataFrame or TramDagDataset or None, optional 625 Validation data. If a DataFrame is given, it is converted into a 626 TramDagDataset. If None, no validation loss is computed. 627 **kwargs 628 Overrides for ``DEFAULTS_FIT``. All keys must exist in 629 ``DEFAULTS_FIT``. Common options: 630 631 epochs : int, default 100 632 Number of training epochs per node. 
633 learning_rate : float, default 0.01 634 Learning rate for the default Adam optimizer. 635 train_list : list of str or None, optional 636 List of node names to train. If None, all nodes are trained. 637 train_mode : {"sequential", "parallel"}, default "sequential" 638 Training mode. "parallel" uses joblib-based CPU multiprocessing. 639 GPU forces sequential mode. 640 device : {"auto", "cpu", "cuda"}, default "auto" 641 Device selection. 642 optimizers : dict or None 643 Optional mapping ``{node_name: optimizer}``. If provided for a 644 node, that optimizer is used instead of creating a new Adam. 645 schedulers : dict or None 646 Optional mapping ``{node_name: scheduler}``. 647 use_scheduler : bool 648 If True, enable scheduler usage in the training loop. 649 num_workers : int 650 DataLoader workers in sequential mode (ignored in parallel). 651 persistent_workers : bool 652 DataLoader persistence in sequential mode (ignored in parallel). 653 prefetch_factor : int 654 DataLoader prefetch factor (ignored in parallel). 655 batch_size : int 656 Batch size for all node DataLoaders. 657 debug : bool 658 Enable debug output. 659 verbose : bool 660 Enable informational logging. 661 return_history : bool 662 If True, return a history dict. 663 664 Returns 665 ------- 666 dict or None 667 If ``return_history=True``, a dictionary mapping each node name 668 to its training history. Otherwise, returns None. 669 670 Raises 671 ------ 672 ValueError 673 If ``train_mode`` is not "sequential" or "parallel". 
674 """ 675 self._validate_kwargs(kwargs, defaults_attr='DEFAULTS_FIT', context="fit") 676 677 # --- merge defaults --- 678 settings = dict(self.DEFAULTS_FIT) 679 settings.update(kwargs) 680 681 682 self.debug = settings.get("debug", False) 683 self.verbose = settings.get("verbose", False) 684 685 # --- resolve device --- 686 device_str=self.get_device(settings) 687 self.device = torch.device(device_str) 688 689 # --- training mode --- 690 train_mode = settings.get("train_mode", "sequential").lower() 691 if train_mode not in ("sequential", "parallel"): 692 raise ValueError("train_mode must be 'sequential' or 'parallel'") 693 694 # --- DataLoader safety logic --- 695 if train_mode == "parallel": 696 # if user passed loader paralleling params, warn and override 697 for flag in ("num_workers", "persistent_workers", "prefetch_factor"): 698 if flag in kwargs: 699 print(f"[WARNING] '{flag}' is ignored in parallel mode " 700 f"(disabled to prevent nested multiprocessing).") 701 # disable unsafe loader multiprocessing options 702 settings["num_workers"] = 0 703 settings["persistent_workers"] = False 704 settings["prefetch_factor"] = None 705 else: 706 # sequential mode → respect user DataLoader settings 707 if self.debug: 708 print("[DEBUG] Sequential mode: using DataLoader kwargs as provided.") 709 710 # --- which nodes to train --- 711 train_list = settings.get("train_list") or list(self.models.keys()) 712 713 714 # --- dataset prep (receives adjusted settings) --- 715 td_train_data = self._ensure_dataset(train_data, is_val=False, **settings) 716 td_val_data = self._ensure_dataset(val_data, is_val=True, **settings) 717 718 # --- normalization --- 719 self.load_or_compute_minmax(use_existing=False, write=True, td_train_data=td_train_data) 720 721 # --- print header --- 722 if self.verbose or self.debug: 723 print(f"[INFO] Training {len(train_list)} nodes ({train_mode}) on {device_str}") 724 725 # ====================================================================== 726 # 
Sequential mode safe for GPU or debugging) 727 # ====================================================================== 728 if train_mode == "sequential" or "cuda" in device_str: 729 if "cuda" in device_str and train_mode == "parallel": 730 print("[WARNING] GPU device detected — forcing sequential mode.") 731 results = {} 732 for node in train_list: 733 node, history = self._fit_single_node( 734 node, self, settings, td_train_data, td_val_data, device_str 735 ) 736 results[node] = history 737 738 739 # ====================================================================== 740 # parallel mode (CPU only) 741 # ====================================================================== 742 if train_mode == "parallel": 743 744 n_jobs = min(len(train_list), os.cpu_count() // 2 or 1) 745 if self.verbose or self.debug: 746 print(f"[INFO] Using {n_jobs} CPU workers for parallel node training") 747 parallel_outputs = Parallel( 748 n_jobs=n_jobs, 749 backend="loky",#loky, multiprocessing 750 verbose=10, 751 prefer="processes" 752 )(delayed(self._fit_single_node)(node, self, settings, td_train_data, td_val_data, device_str) for node in train_list ) 753 754 results = {node: hist for node, hist in parallel_outputs} 755 756 if settings.get("return_history", False): 757 return results 758 759 ## FIT-DIAGNOSTICS 760 def loss_history(self): 761 """ 762 Load training and validation loss history for all nodes. 763 764 Looks for per-node JSON files: 765 766 - ``EXPERIMENT_DIR/{node}/train_loss_hist.json`` 767 - ``EXPERIMENT_DIR/{node}/val_loss_hist.json`` 768 769 Returns 770 ------- 771 dict 772 A dictionary mapping node names to: 773 774 .. code-block:: python 775 776 { 777 "train": list or None, 778 "validation": list or None 779 } 780 781 where each list contains NLL values per epoch, or None if not found. 782 783 Raises 784 ------ 785 ValueError 786 If the experiment directory cannot be resolved from the configuration. 
787 """ 788 try: 789 EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"] 790 except KeyError: 791 raise ValueError( 792 "[ERROR] Missing 'EXPERIMENT_DIR' in cfg.conf_dict['PATHS']. " 793 "History retrieval requires experiment logs." 794 ) 795 796 all_histories = {} 797 for node in self.nodes_dict.keys(): 798 node_dir = os.path.join(EXPERIMENT_DIR, node) 799 train_path = os.path.join(node_dir, "train_loss_hist.json") 800 val_path = os.path.join(node_dir, "val_loss_hist.json") 801 802 node_hist = {} 803 804 # --- load train history --- 805 if os.path.exists(train_path): 806 try: 807 with open(train_path, "r") as f: 808 node_hist["train"] = json.load(f) 809 except Exception as e: 810 print(f"[WARNING] Could not load {train_path}: {e}") 811 node_hist["train"] = None 812 else: 813 node_hist["train"] = None 814 815 # --- load val history --- 816 if os.path.exists(val_path): 817 try: 818 with open(val_path, "r") as f: 819 node_hist["validation"] = json.load(f) 820 except Exception as e: 821 print(f"[WARNING] Could not load {val_path}: {e}") 822 node_hist["validation"] = None 823 else: 824 node_hist["validation"] = None 825 826 all_histories[node] = node_hist 827 828 if self.verbose or self.debug: 829 print(f"[INFO] Loaded training/validation histories for {len(all_histories)} nodes.") 830 831 return all_histories 832 833 def linear_shift_history(self): 834 """ 835 Load linear shift term histories for all nodes. 836 837 Each node history is expected in a JSON file named 838 ``linear_shifts_all_epochs.json`` under the node directory. 839 840 Returns 841 ------- 842 dict 843 A mapping ``{node_name: pandas.DataFrame}``, where each DataFrame 844 contains linear shift weights across epochs. 845 846 Raises 847 ------ 848 ValueError 849 If the experiment directory cannot be resolved from the configuration. 850 851 Notes 852 ----- 853 If a history file is missing for a node, a warning is printed and the 854 node is omitted from the returned dictionary. 
855 """ 856 histories = {} 857 try: 858 EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"] 859 except KeyError: 860 raise ValueError( 861 "[ERROR] Missing 'EXPERIMENT_DIR' in cfg.conf_dict['PATHS']. " 862 "Cannot load histories without experiment directory." 863 ) 864 865 for node in self.nodes_dict.keys(): 866 node_dir = os.path.join(EXPERIMENT_DIR, node) 867 history_path = os.path.join(node_dir, "linear_shifts_all_epochs.json") 868 if os.path.exists(history_path): 869 histories[node] = pd.read_json(history_path) 870 else: 871 print(f"[WARNING] No linear shift history found for node '{node}' at {history_path}") 872 return histories 873 874 def simple_intercept_history(self): 875 """ 876 Load simple intercept histories for all nodes. 877 878 Each node history is expected in a JSON file named 879 ``simple_intercepts_all_epochs.json`` under the node directory. 880 881 Returns 882 ------- 883 dict 884 A mapping ``{node_name: pandas.DataFrame}``, where each DataFrame 885 contains intercept weights across epochs. 886 887 Raises 888 ------ 889 ValueError 890 If the experiment directory cannot be resolved from the configuration. 891 892 Notes 893 ----- 894 If a history file is missing for a node, a warning is printed and the 895 node is omitted from the returned dictionary. 896 """ 897 histories = {} 898 try: 899 EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"] 900 except KeyError: 901 raise ValueError( 902 "[ERROR] Missing 'EXPERIMENT_DIR' in cfg.conf_dict['PATHS']. " 903 "Cannot load histories without experiment directory." 
904 ) 905 906 for node in self.nodes_dict.keys(): 907 node_dir = os.path.join(EXPERIMENT_DIR, node) 908 history_path = os.path.join(node_dir, "simple_intercepts_all_epochs.json") 909 if os.path.exists(history_path): 910 histories[node] = pd.read_json(history_path) 911 else: 912 print(f"[WARNING] No simple intercept history found for node '{node}' at {history_path}") 913 return histories 914 915 def get_latent(self, df, verbose=False): 916 """ 917 Compute latent representations for all nodes in the DAG. 918 919 Parameters 920 ---------- 921 df : pandas.DataFrame 922 Input data frame with columns corresponding to nodes in the DAG. 923 verbose : bool, optional 924 If True, print informational messages during latent computation. 925 Default is False. 926 927 Returns 928 ------- 929 pandas.DataFrame 930 DataFrame containing the original columns plus latent variables 931 for each node (e.g. columns named ``f"{node}_U"``). 932 933 Raises 934 ------ 935 ValueError 936 If the experiment directory is missing from the configuration or 937 if ``self.minmax_dict`` has not been set. 938 """ 939 try: 940 EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"] 941 except KeyError: 942 raise ValueError( 943 "[ERROR] Missing 'EXPERIMENT_DIR' in cfg.conf_dict['PATHS']. " 944 "Latent extraction requires trained model checkpoints." 945 ) 946 947 # ensure minmax_dict is available 948 if not hasattr(self, "minmax_dict"): 949 raise ValueError( 950 "[ERROR] minmax_dict not found in the TramDagModel instance. " 951 "Either call .load_or_compute_minmax(td_train_data=train_df) or .fit() first." 
952 ) 953 954 all_latents_df = create_latent_df_for_full_dag( 955 configuration_dict=self.cfg.conf_dict, 956 EXPERIMENT_DIR=EXPERIMENT_DIR, 957 df=df, 958 verbose=verbose, 959 min_max_dict=self.minmax_dict, 960 ) 961 962 return all_latents_df 963 964 965 ## PLOTTING FIT-DIAGNOSTICS 966 967 def plot_loss_history(self, variable: str = None): 968 """ 969 Plot training and validation loss evolution per node. 970 971 Parameters 972 ---------- 973 variable : str or None, optional 974 If provided, plot loss history for this node only. If None, plot 975 histories for all nodes that have both train and validation logs. 976 977 Returns 978 ------- 979 None 980 981 Notes 982 ----- 983 Two subplots are produced: 984 - Full epoch history. 985 - Last 10% of epochs (or only the last epoch if fewer than 5 epochs). 986 """ 987 988 histories = self.loss_history() 989 if not histories: 990 print("[WARNING] No loss histories found.") 991 return 992 993 # Select which nodes to plot 994 if variable is not None: 995 if variable not in histories: 996 raise ValueError(f"[ERROR] Node '{variable}' not found in histories.") 997 nodes_to_plot = [variable] 998 else: 999 nodes_to_plot = list(histories.keys()) 1000 1001 # Filter out nodes with no valid history 1002 nodes_to_plot = [ 1003 n for n in nodes_to_plot 1004 if histories[n].get("train") is not None and len(histories[n]["train"]) > 0 1005 and histories[n].get("validation") is not None and len(histories[n]["validation"]) > 0 1006 ] 1007 1008 if not nodes_to_plot: 1009 print("[WARNING] No valid histories found to plot.") 1010 return 1011 1012 plt.figure(figsize=(14, 12)) 1013 1014 # --- Full history (top plot) --- 1015 plt.subplot(2, 1, 1) 1016 for node in nodes_to_plot: 1017 node_hist = histories[node] 1018 train_hist, val_hist = node_hist["train"], node_hist["validation"] 1019 1020 epochs = range(1, len(train_hist) + 1) 1021 plt.plot(epochs, train_hist, label=f"{node} - train", linestyle="--") 1022 plt.plot(epochs, val_hist, 
label=f"{node} - val") 1023 1024 plt.title("Training and Validation NLL - Full History") 1025 plt.xlabel("Epoch") 1026 plt.ylabel("NLL") 1027 plt.legend() 1028 plt.grid(True) 1029 1030 # --- Last 10% of epochs (bottom plot) --- 1031 plt.subplot(2, 1, 2) 1032 for node in nodes_to_plot: 1033 node_hist = histories[node] 1034 train_hist, val_hist = node_hist["train"], node_hist["validation"] 1035 1036 total_epochs = len(train_hist) 1037 start_idx = total_epochs - 1 if total_epochs < 5 else int(total_epochs * 0.9) 1038 1039 epochs = range(start_idx + 1, total_epochs + 1) 1040 plt.plot(epochs, train_hist[start_idx:], label=f"{node} - train", linestyle="--") 1041 plt.plot(epochs, val_hist[start_idx:], label=f"{node} - val") 1042 1043 plt.title("Training and Validation NLL - Last 10% of Epochs (or Last Epoch if <5)") 1044 plt.xlabel("Epoch") 1045 plt.ylabel("NLL") 1046 plt.legend() 1047 plt.grid(True) 1048 1049 plt.tight_layout() 1050 plt.show() 1051 1052 def plot_linear_shift_history(self, data_dict=None, node=None, ref_lines=None): 1053 """ 1054 Plot the evolution of linear shift terms over epochs. 1055 1056 Parameters 1057 ---------- 1058 data_dict : dict or None, optional 1059 Pre-loaded mapping ``{node_name: pandas.DataFrame}`` containing shift 1060 weights across epochs. If None, `linear_shift_history()` is called. 1061 node : str or None, optional 1062 If provided, plot only this node. Otherwise, plot all nodes 1063 present in ``data_dict``. 1064 ref_lines : dict or None, optional 1065 Optional mapping ``{node_name: list of float}``. For each specified 1066 node, horizontal reference lines are drawn at the given values. 1067 1068 Returns 1069 ------- 1070 None 1071 1072 Notes 1073 ----- 1074 The function flattens nested list-like entries in the DataFrames to scalars, 1075 converts epoch labels to numeric, and then draws one line per shift term. 
1076 """ 1077 1078 if data_dict is None: 1079 data_dict = self.linear_shift_history() 1080 if data_dict is None: 1081 raise ValueError("No shift history data provided or stored in the class.") 1082 1083 nodes = [node] if node else list(data_dict.keys()) 1084 1085 for n in nodes: 1086 df = data_dict[n].copy() 1087 1088 # Flatten nested lists or list-like cells 1089 def flatten(x): 1090 if isinstance(x, list): 1091 if len(x) == 0: 1092 return np.nan 1093 if all(isinstance(i, (int, float)) for i in x): 1094 return np.mean(x) # average simple list 1095 if all(isinstance(i, list) for i in x): 1096 # nested list -> flatten inner and average 1097 flat = [v for sub in x for v in (sub if isinstance(sub, list) else [sub])] 1098 return np.mean(flat) if flat else np.nan 1099 return x[0] if len(x) == 1 else np.nan 1100 return x 1101 1102 df = df.applymap(flatten) 1103 1104 # Ensure numeric columns 1105 df = df.apply(pd.to_numeric, errors='coerce') 1106 1107 # Convert epoch labels to numeric 1108 df.columns = [ 1109 int(c.replace("epoch_", "")) if isinstance(c, str) and c.startswith("epoch_") else c 1110 for c in df.columns 1111 ] 1112 df = df.reindex(sorted(df.columns), axis=1) 1113 1114 plt.figure(figsize=(10, 6)) 1115 for idx in df.index: 1116 plt.plot(df.columns, df.loc[idx], lw=1.4, label=f"shift_{idx}") 1117 1118 if ref_lines and n in ref_lines: 1119 for v in ref_lines[n]: 1120 plt.axhline(y=v, color="k", linestyle="--", lw=1.0) 1121 plt.text(df.columns[-1], v, f"{n}: {v}", va="bottom", ha="right", fontsize=8) 1122 1123 plt.xlabel("Epoch") 1124 plt.ylabel("Shift Value") 1125 plt.title(f"Shift Term History — Node: {n}") 1126 plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left", fontsize="small") 1127 plt.tight_layout() 1128 plt.show() 1129 1130 def plot_simple_intercepts_history(self, data_dict=None, node=None,ref_lines=None): 1131 """ 1132 Plot the evolution of simple intercept weights over epochs. 
1133 1134 Parameters 1135 ---------- 1136 data_dict : dict or None, optional 1137 Pre-loaded mapping ``{node_name: pandas.DataFrame}`` containing intercept 1138 weights across epochs. If None, `simple_intercept_history()` is called. 1139 node : str or None, optional 1140 If provided, plot only this node. Otherwise, plot all nodes present 1141 in ``data_dict``. 1142 ref_lines : dict or None, optional 1143 Optional mapping ``{node_name: list of float}``. For each specified 1144 node, horizontal reference lines are drawn at the given values. 1145 1146 Returns 1147 ------- 1148 None 1149 1150 Notes 1151 ----- 1152 Nested list-like entries in the DataFrames are reduced to scalars before 1153 plotting. One line is drawn per intercept parameter. 1154 """ 1155 if data_dict is None: 1156 data_dict = self.simple_intercept_history() 1157 if data_dict is None: 1158 raise ValueError("No intercept history data provided or stored in the class.") 1159 1160 nodes = [node] if node else list(data_dict.keys()) 1161 1162 for n in nodes: 1163 df = data_dict[n].copy() 1164 1165 def extract_scalar(x): 1166 if isinstance(x, list): 1167 while isinstance(x, list) and len(x) > 0: 1168 x = x[0] 1169 return float(x) if isinstance(x, (int, float, np.floating)) else np.nan 1170 1171 df = df.applymap(extract_scalar) 1172 1173 # Convert epoch labels → numeric 1174 df.columns = [ 1175 int(c.replace("epoch_", "")) if isinstance(c, str) and c.startswith("epoch_") else c 1176 for c in df.columns 1177 ] 1178 df = df.reindex(sorted(df.columns), axis=1) 1179 1180 plt.figure(figsize=(10, 6)) 1181 for idx in df.index: 1182 plt.plot(df.columns, df.loc[idx], lw=1.4, label=f"theta_{idx}") 1183 1184 if ref_lines and n in ref_lines: 1185 for v in ref_lines[n]: 1186 plt.axhline(y=v, color="k", linestyle="--", lw=1.0) 1187 plt.text(df.columns[-1], v, f"{n}: {v}", va="bottom", ha="right", fontsize=8) 1188 1189 plt.xlabel("Epoch") 1190 plt.ylabel("Intercept Weight") 1191 plt.title(f"Simple Intercept Evolution — 
Node: {n}") 1192 plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left", fontsize="small") 1193 plt.tight_layout() 1194 plt.show() 1195 1196 def plot_latents(self, df, variable: str = None, confidence: float = 0.95, simulations: int = 1000): 1197 """ 1198 Visualize latent U distributions for one or all nodes. 1199 1200 Parameters 1201 ---------- 1202 df : pandas.DataFrame 1203 Input data frame with raw node values. 1204 variable : str or None, optional 1205 If provided, only this node's latents are plotted. If None, all 1206 nodes with latent columns are processed. 1207 confidence : float, optional 1208 Confidence level for QQ-plot bands (0 < confidence < 1). 1209 Default is 0.95. 1210 simulations : int, optional 1211 Number of Monte Carlo simulations for QQ-plot bands. Default is 1000. 1212 1213 Returns 1214 ------- 1215 None 1216 1217 Notes 1218 ----- 1219 For each node, two plots are produced: 1220 - Histogram of the latent U values. 1221 - QQ-plot with simulation-based confidence bands under a logistic reference. 
1222 """ 1223 # Compute latent representations 1224 latents_df = self.get_latent(df) 1225 1226 # Select nodes 1227 nodes = [variable] if variable is not None else self.nodes_dict.keys() 1228 1229 for node in nodes: 1230 if f"{node}_U" not in latents_df.columns: 1231 print(f"[WARNING] No latent found for node {node}, skipping.") 1232 continue 1233 1234 sample = latents_df[f"{node}_U"].values 1235 1236 # --- Create plots --- 1237 fig, axs = plt.subplots(1, 2, figsize=(12, 5)) 1238 1239 # Histogram 1240 axs[0].hist(sample, bins=50, color="steelblue", alpha=0.7) 1241 axs[0].set_title(f"Latent Histogram ({node})") 1242 axs[0].set_xlabel("U") 1243 axs[0].set_ylabel("Frequency") 1244 1245 # QQ Plot with confidence bands 1246 probplot(sample, dist="logistic", plot=axs[1]) 1247 self._add_r_style_confidence_bands(axs[1], sample, dist=logistic,confidence=confidence, simulations=simulations) 1248 axs[1].set_title(f"Latent QQ Plot ({node})") 1249 1250 plt.suptitle(f"Latent Diagnostics for Node: {node}", fontsize=14) 1251 plt.tight_layout(rect=[0, 0.03, 1, 0.95]) 1252 plt.show() 1253 1254 def plot_hdag(self,df,variables=None, plot_n_rows=1,**kwargs): 1255 1256 """ 1257 Visualize the transformation function h() for selected DAG nodes. 1258 1259 Parameters 1260 ---------- 1261 df : pandas.DataFrame 1262 Input data containing node values or model predictions. 1263 variables : list of str or None, optional 1264 Names of nodes to visualize. If None, all nodes in ``self.models`` 1265 are considered. 1266 plot_n_rows : int, optional 1267 Maximum number of rows from ``df`` to visualize. Default is 1. 1268 **kwargs 1269 Additional keyword arguments forwarded to the underlying plotting 1270 helpers (`show_hdag_continous` / `show_hdag_ordinal`). 1271 1272 Returns 1273 ------- 1274 None 1275 1276 Notes 1277 ----- 1278 - For continuous outcomes, `show_hdag_continous` is called. 1279 - For ordinal outcomes, `show_hdag_ordinal` is called. 
1280 - Nodes that are neither continuous nor ordinal are skipped with a warning. 1281 """ 1282 1283 1284 if len(df)> 1: 1285 print("[WARNING] len(df)>1, set: plot_n_rows accordingly") 1286 1287 variables_list=variables if variables is not None else list(self.models.keys()) 1288 for node in variables_list: 1289 if is_outcome_modelled_continous(node, self.nodes_dict): 1290 show_hdag_continous(df,node=node,configuration_dict=self.cfg.conf_dict,minmax_dict=self.minmax_dict,device=self.device,plot_n_rows=plot_n_rows,**kwargs) 1291 1292 elif is_outcome_modelled_ordinal(node, self.nodes_dict): 1293 show_hdag_ordinal(df,node=node,configuration_dict=self.cfg.conf_dict,device=self.device,plot_n_rows=plot_n_rows,**kwargs) 1294 # plot_cutpoints_with_logistic(df,node=node,configuration_dict=self.cfg.conf_dict,device=self.device,plot_n_rows=plot_n_rows,**kwargs) 1295 # save_cutpoints_with_logistic(df,node=node,configuration_dict=self.cfg.conf_dict,device=self.device,**kwargs) 1296 else: 1297 print(f"[WARNING] Node {node} is wheter ordinal nor continous, not implemented yet") 1298 1299 @staticmethod 1300 def _add_r_style_confidence_bands(ax, sample, dist, confidence=0.95, simulations=1000): 1301 """ 1302 Add simulation-based confidence bands to a QQ-plot. 1303 1304 Parameters 1305 ---------- 1306 ax : matplotlib.axes.Axes 1307 Axes object on which to draw the QQ-plot and bands. 1308 sample : array-like 1309 Empirical sample used in the QQ-plot. 1310 dist : scipy.stats distribution 1311 Distribution object providing ``ppf`` and ``rvs`` methods (e.g. logistic). 1312 confidence : float, optional 1313 Confidence level (0 < confidence < 1) for the bands. Default is 0.95. 1314 simulations : int, optional 1315 Number of Monte Carlo simulations used to estimate the bands. Default is 1000. 1316 1317 Returns 1318 ------- 1319 None 1320 1321 Notes 1322 ----- 1323 The axes are cleared, and a new QQ-plot is drawn with: 1324 - Empirical vs. theoretical quantiles. 
1325 - 45-degree reference line. 1326 - Shaded confidence band region. 1327 """ 1328 1329 n = len(sample) 1330 if n == 0: 1331 return 1332 1333 quantiles = np.linspace(0, 1, n, endpoint=False) + 0.5 / n 1334 theo_q = dist.ppf(quantiles) 1335 1336 # Simulate order statistics from the theoretical distribution 1337 sim_data = dist.rvs(size=(simulations, n)) 1338 sim_order_stats = np.sort(sim_data, axis=1) 1339 1340 # Confidence bands 1341 lower = np.percentile(sim_order_stats, 100 * (1 - confidence) / 2, axis=0) 1342 upper = np.percentile(sim_order_stats, 100 * (1 + confidence) / 2, axis=0) 1343 1344 # Sort empirical sample 1345 sample_sorted = np.sort(sample) 1346 1347 # Re-draw points and CI (overwrite probplot defaults) 1348 ax.clear() 1349 ax.plot(theo_q, sample_sorted, 'o', markersize=3, alpha=0.6, label="Empirical Q-Q") 1350 ax.plot(theo_q, theo_q, 'b--', label="y = x") 1351 ax.fill_between(theo_q, lower, upper, color='gray', alpha=0.3, 1352 label=f'{int(confidence*100)}% CI') 1353 ax.legend() 1354 1355 ## SAMPLING METHODS 1356 def sample( 1357 self, 1358 do_interventions: dict = None, 1359 predefined_latent_samples_df: pd.DataFrame = None, 1360 **kwargs, 1361 ): 1362 """ 1363 Sample from the joint DAG using the trained TRAM models. 1364 1365 Allows for: 1366 1367 Oberservational sampling 1368 Interventional sampling via ``do()`` operations 1369 Counterfactial sampling using predefined latent draws and do() 1370 1371 Parameters 1372 ---------- 1373 do_interventions : dict or None, optional 1374 Mapping of node names to intervened (fixed) values. For example: 1375 ``{"x1": 1.0}`` represents ``do(x1 = 1.0)``. Default is None. 1376 predefined_latent_samples_df : pandas.DataFrame or None, optional 1377 DataFrame containing columns ``"{node}_U"`` with predefined latent 1378 draws to be used instead of sampling from the prior. Default is None. 
1379 **kwargs 1380 Sampling options overriding internal defaults: 1381 1382 number_of_samples : int, default 10000 1383 Total number of samples to draw. 1384 batch_size : int, default 32 1385 Batch size for internal sampling loops. 1386 delete_all_previously_sampled : bool, default True 1387 If True, delete old sampling files in node-specific sampling 1388 directories before writing new ones. 1389 verbose : bool 1390 If True, print informational messages. 1391 debug : bool 1392 If True, print debug output. 1393 device : {"auto", "cpu", "cuda"} 1394 Device selection for sampling. 1395 use_initial_weights_for_sampling : bool, default False 1396 If True, sample from initial (untrained) model parameters. 1397 1398 Returns 1399 ------- 1400 tuple 1401 A tuple ``(sampled_by_node, latents_by_node)``: 1402 1403 sampled_by_node : dict 1404 Mapping ``{node_name: torch.Tensor}`` of sampled node values. 1405 latents_by_node : dict 1406 Mapping ``{node_name: torch.Tensor}`` of latent U values used. 1407 1408 Raises 1409 ------ 1410 ValueError 1411 If the experiment directory cannot be resolved or if scaling 1412 information (``self.minmax_dict``) is missing. 1413 RuntimeError 1414 If min–max scaling has not been computed before calling `sample`. 1415 """ 1416 try: 1417 EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"] 1418 except KeyError: 1419 raise ValueError( 1420 "[ERROR] Missing 'EXPERIMENT_DIR' in cfg.conf_dict['PATHS']. " 1421 "Sampling requires trained model checkpoints." 
1422 ) 1423 1424 # ---- defaults ---- 1425 settings = { 1426 "number_of_samples": 10_000, 1427 "batch_size": 32, 1428 "delete_all_previously_sampled": True, 1429 "verbose": self.verbose if hasattr(self, "verbose") else False, 1430 "debug": self.debug if hasattr(self, "debug") else False, 1431 "device": self.device.type if hasattr(self, "device") else "auto", 1432 "use_initial_weights_for_sampling": False, 1433 1434 } 1435 1436 # self._validate_kwargs( kwargs, defaults_attr= "settings", context="sample") 1437 1438 settings.update(kwargs) 1439 1440 1441 if not hasattr(self, "minmax_dict"): 1442 raise RuntimeError( 1443 "[ERROR] minmax_dict not found. You must call .fit() or .load_or_compute_minmax() " 1444 "before sampling, so scaling info is available." 1445 ) 1446 1447 # ---- resolve device ---- 1448 device_str=self.get_device(settings) 1449 self.device = torch.device(device_str) 1450 1451 1452 if self.debug or settings["debug"]: 1453 print(f"[DEBUG] sample(): device: {self.device}") 1454 1455 # ---- perform sampling ---- 1456 sampled_by_node, latents_by_node = sample_full_dag( 1457 configuration_dict=self.cfg.conf_dict, 1458 EXPERIMENT_DIR=EXPERIMENT_DIR, 1459 device=self.device, 1460 do_interventions=do_interventions or {}, 1461 predefined_latent_samples_df=predefined_latent_samples_df, 1462 number_of_samples=settings["number_of_samples"], 1463 batch_size=settings["batch_size"], 1464 delete_all_previously_sampled=settings["delete_all_previously_sampled"], 1465 verbose=settings["verbose"], 1466 debug=settings["debug"], 1467 minmax_dict=self.minmax_dict, 1468 use_initial_weights_for_sampling=settings["use_initial_weights_for_sampling"] 1469 ) 1470 1471 return sampled_by_node, latents_by_node 1472 1473 def load_sampled_and_latents(self, EXPERIMENT_DIR: str = None, nodes: list = None): 1474 """ 1475 Load previously stored sampled values and latents for each node. 
1476 1477 Parameters 1478 ---------- 1479 EXPERIMENT_DIR : str or None, optional 1480 Experiment directory path. If None, it is taken from 1481 ``self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]``. 1482 nodes : list of str or None, optional 1483 Nodes for which to load samples. If None, use all nodes from 1484 ``self.nodes_dict``. 1485 1486 Returns 1487 ------- 1488 tuple 1489 A tuple ``(sampled_by_node, latents_by_node)``: 1490 1491 sampled_by_node : dict 1492 Mapping ``{node_name: torch.Tensor}`` of sampled values (on CPU). 1493 latents_by_node : dict 1494 Mapping ``{node_name: torch.Tensor}`` of latent values (on CPU). 1495 1496 Raises 1497 ------ 1498 ValueError 1499 If the experiment directory cannot be resolved or if no node list 1500 is available and ``nodes`` is None. 1501 1502 Notes 1503 ----- 1504 Nodes without both ``sampled.pt`` and ``latents.pt`` files are skipped 1505 with a warning. 1506 """ 1507 # --- resolve paths and node list --- 1508 if EXPERIMENT_DIR is None: 1509 try: 1510 EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"] 1511 except (AttributeError, KeyError): 1512 raise ValueError( 1513 "[ERROR] Could not resolve EXPERIMENT_DIR from cfg.conf_dict['PATHS']. " 1514 "Please provide EXPERIMENT_DIR explicitly." 1515 ) 1516 1517 if nodes is None: 1518 if hasattr(self, "nodes_dict"): 1519 nodes = list(self.nodes_dict.keys()) 1520 else: 1521 raise ValueError( 1522 "[ERROR] No node list found. Please provide `nodes` or initialize model with a config." 
    def plot_samples_vs_true(
        self,
        df,
        sampled: dict = None,
        variable: list = None,
        bins: int = 100,
        hist_true_color: str = "blue",
        hist_est_color: str = "orange",
        figsize: tuple = (14, 5),
    ):


        """
        Compare sampled vs. observed distributions for selected nodes.

        Parameters
        ----------
        df : pandas.DataFrame
            Data frame containing the observed node values.
        sampled : dict or None, optional
            Optional mapping ``{node_name: array-like or torch.Tensor}`` of sampled
            values. If None or if a node is missing, samples are loaded from
            ``EXPERIMENT_DIR/{node}/sampling/sampled.pt``.
        variable : list of str or None, optional
            Subset of nodes to plot. If None, all nodes in the configuration
            are considered.
        bins : int, optional
            Number of histogram bins for continuous variables. Default is 100.
        hist_true_color : str, optional
            Color name for the histogram of true values. Default is "blue".
        hist_est_color : str, optional
            Color name for the histogram of sampled values. Default is "orange".
        figsize : tuple, optional
            Figure size for the matplotlib plots. Default is (14, 5).

        Returns
        -------
        None

        Notes
        -----
        - Continuous outcomes: histogram overlay + QQ-plot.
        - Ordinal outcomes: side-by-side bar plot of relative frequencies.
        - Other categorical outcomes: side-by-side bar plot with category labels.
        - If samples are probabilistic (2D tensor), the argmax across classes is used.
        """

        # NOTE(review): here `target_nodes` comes from conf_dict["nodes"] and is
        # passed to is_outcome_modelled_* below, while plot_hdag passes
        # self.nodes_dict for the same role — presumably equivalent; confirm.
        target_nodes = self.cfg.conf_dict["nodes"]
        experiment_dir = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        plot_list = variable if variable is not None else target_nodes

        for node in plot_list:
            # Load sampled data: prefer the in-memory `sampled` mapping, fall
            # back to the on-disk tensor written by `sample()`.
            if sampled is not None and node in sampled:
                sdata = sampled[node]
                if isinstance(sdata, torch.Tensor):
                    sampled_vals = sdata.detach().cpu().numpy()
                else:
                    sampled_vals = np.asarray(sdata)
            else:
                sample_path = os.path.join(experiment_dir, f"{node}/sampling/sampled.pt")
                if not os.path.isfile(sample_path):
                    print(f"[WARNING] skip {node}: {sample_path} not found.")
                    continue

                try:
                    sampled_vals = torch.load(sample_path, map_location=device).cpu().numpy()
                except Exception as e:
                    print(f"[ERROR] Could not load {sample_path}: {e}")
                    continue

            # If logits/probabilities per sample, take argmax
            if sampled_vals.ndim == 2:
                print(f"[INFO] CAUTION! {node}: samples are probabilistic — each sample follows a probability "
                      f"distribution based on the valid latent range. "
                      f"Note that this frequency plot reflects only the distribution of the most probable "
                      f"class per sample.")
                sampled_vals = np.argmax(sampled_vals, axis=1)

            # Drop NaN/Inf before any statistics are computed.
            sampled_vals = sampled_vals[np.isfinite(sampled_vals)]

            if node not in df.columns:
                print(f"[WARNING] skip {node}: column not found in DataFrame.")
                continue

            true_vals = df[node].dropna().values
            true_vals = true_vals[np.isfinite(true_vals)]

            if sampled_vals.size == 0 or true_vals.size == 0:
                print(f"[WARNING] skip {node}: empty array after NaN/Inf removal.")
                continue

            fig, axs = plt.subplots(1, 2, figsize=figsize)

            if is_outcome_modelled_continous(node, target_nodes):
                # Continuous: density histogram overlay (left) + QQ plot (right).
                axs[0].hist(true_vals, bins=bins, density=True, alpha=0.6,
                            color=hist_true_color, label=f"True {node}")
                axs[0].hist(sampled_vals, bins=bins, density=True, alpha=0.6,
                            color=hist_est_color, label="Sampled")
                axs[0].set_xlabel("Value")
                axs[0].set_ylabel("Density")
                axs[0].set_title(f"Histogram overlay for {node}")
                axs[0].legend()
                axs[0].grid(True, ls="--", alpha=0.4)

                qqplot_2samples(true_vals, sampled_vals, line="45", ax=axs[1])
                axs[1].set_xlabel("True quantiles")
                axs[1].set_ylabel("Sampled quantiles")
                axs[1].set_title(f"QQ plot for {node}")
                axs[1].grid(True, ls="--", alpha=0.4)

            elif is_outcome_modelled_ordinal(node, target_nodes):
                # Ordinal: paired bars of relative frequencies per level.
                unique_vals = np.union1d(np.unique(true_vals), np.unique(sampled_vals))
                unique_vals = np.sort(unique_vals)
                true_counts = np.array([(true_vals == val).sum() for val in unique_vals])
                sampled_counts = np.array([(sampled_vals == val).sum() for val in unique_vals])

                axs[0].bar(unique_vals - 0.2, true_counts / true_counts.sum(),
                           width=0.4, color=hist_true_color, alpha=0.7, label="True")
                axs[0].bar(unique_vals + 0.2, sampled_counts / sampled_counts.sum(),
                           width=0.4, color=hist_est_color, alpha=0.7, label="Sampled")
                axs[0].set_xticks(unique_vals)
                axs[0].set_xlabel("Ordinal Level")
                axs[0].set_ylabel("Relative Frequency")
                axs[0].set_title(f"Ordinal bar plot for {node}")
                axs[0].legend()
                axs[0].grid(True, ls="--", alpha=0.4)
                axs[1].axis("off")

            else:
                # Other categorical: bars indexed by position, labeled by
                # category (sorted by string representation).
                unique_vals = np.union1d(np.unique(true_vals), np.unique(sampled_vals))
                unique_vals = sorted(unique_vals, key=str)
                true_counts = np.array([(true_vals == val).sum() for val in unique_vals])
                sampled_counts = np.array([(sampled_vals == val).sum() for val in unique_vals])

                axs[0].bar(np.arange(len(unique_vals)) - 0.2, true_counts / true_counts.sum(),
                           width=0.4, color=hist_true_color, alpha=0.7, label="True")
                axs[0].bar(np.arange(len(unique_vals)) + 0.2, sampled_counts / sampled_counts.sum(),
                           width=0.4, color=hist_est_color, alpha=0.7, label="Sampled")
                axs[0].set_xticks(np.arange(len(unique_vals)))
                axs[0].set_xticklabels(unique_vals, rotation=45)
                axs[0].set_xlabel("Category")
                axs[0].set_ylabel("Relative Frequency")
                axs[0].set_title(f"Categorical bar plot for {node}")
                axs[0].legend()
                axs[0].grid(True, ls="--", alpha=0.4)
                axs[1].axis("off")

            plt.tight_layout()
            plt.show()

    ## SUMMARY METHODS
    def nll(self,data,variables=None):
        """
        Compute the Negative Log-Likelihood (NLL) for all or selected TRAM nodes.

        This function evaluates trained TRAM models for each specified variable (node)
        on the provided dataset. It performs forward passes only—no training, no weight
        updates—and returns the mean NLL per node.

        Parameters
        ----------
        data : object
            Input dataset or data source compatible with `_ensure_dataset`, containing
            both inputs and targets for each node.
        variables : list[str], optional
            List of variable (node) names to evaluate. If None, all nodes in
            `self.models` are evaluated.
1725 1726 Returns 1727 ------- 1728 dict[str, float] 1729 Dictionary mapping each node name to its average NLL value. 1730 1731 Notes 1732 ----- 1733 - Each model is evaluated independently on its respective DataLoader. 1734 - The normalization values (`min_max`) for each node are retrieved from 1735 `self.minmax_dict[node]`. 1736 - The function uses `evaluate_tramdag_model()` for per-node evaluation. 1737 - Expected directory structure: 1738 `<EXPERIMENT_DIR>/<node>/` 1739 where each node directory contains the trained model. 1740 """ 1741 1742 td_data = self._ensure_dataset(data, is_val=True) 1743 variables_list = variables if variables != None else list(self.models.keys()) 1744 nll_dict = {} 1745 for node in variables_list: 1746 min_vals = torch.tensor(self.minmax_dict[node][0], dtype=torch.float32) 1747 max_vals = torch.tensor(self.minmax_dict[node][1], dtype=torch.float32) 1748 min_max = torch.stack([min_vals, max_vals], dim=0) 1749 data_loader = td_data.loaders[node] 1750 model = self.models[node] 1751 EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"] 1752 NODE_DIR = os.path.join(EXPERIMENT_DIR, f"{node}") 1753 nll= evaluate_tramdag_model(node=node, 1754 target_nodes=self.nodes_dict, 1755 NODE_DIR=NODE_DIR, 1756 tram_model=model, 1757 data_loader=data_loader, 1758 min_max=min_max) 1759 nll_dict[node]=nll 1760 return nll_dict 1761 1762 def get_train_val_nll(self, node: str, mode: str) -> tuple[float, float]: 1763 """ 1764 Retrieve training and validation NLL for a node and a given model state. 1765 1766 Parameters 1767 ---------- 1768 node : str 1769 Node name. 1770 mode : {"best", "last", "init"} 1771 State of interest: 1772 - "best": epoch with lowest validation NLL. 1773 - "last": final epoch. 1774 - "init": first epoch (index 0). 1775 1776 Returns 1777 ------- 1778 tuple of (float or None, float or None) 1779 A tuple ``(train_nll, val_nll)`` for the requested mode. 1780 Returns ``(None, None)`` if loss files are missing or cannot be read. 
1781 1782 Notes 1783 ----- 1784 This method expects per-node JSON files: 1785 1786 - ``train_loss_hist.json`` 1787 - ``val_loss_hist.json`` 1788 1789 in the node directory. 1790 """ 1791 NODE_DIR = os.path.join(self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"], node) 1792 train_path = os.path.join(NODE_DIR, "train_loss_hist.json") 1793 val_path = os.path.join(NODE_DIR, "val_loss_hist.json") 1794 1795 if not os.path.exists(train_path) or not os.path.exists(val_path): 1796 if getattr(self, "debug", False): 1797 print(f"[DEBUG] Missing loss files for node '{node}'. Returning None.") 1798 return None, None 1799 1800 try: 1801 with open(train_path, "r") as f: 1802 train_hist = json.load(f) 1803 with open(val_path, "r") as f: 1804 val_hist = json.load(f) 1805 1806 train_nlls = np.array(train_hist) 1807 val_nlls = np.array(val_hist) 1808 1809 if mode == "init": 1810 idx = 0 1811 elif mode == "last": 1812 idx = len(val_nlls) - 1 1813 elif mode == "best": 1814 idx = int(np.argmin(val_nlls)) 1815 else: 1816 raise ValueError(f"Invalid mode '{mode}' — must be one of 'best', 'last', 'init'.") 1817 1818 train_nll = float(train_nlls[idx]) 1819 val_nll = float(val_nlls[idx]) 1820 return train_nll, val_nll 1821 1822 except Exception as e: 1823 print(f"[ERROR] Failed to load NLLs for node '{node}' ({mode}): {e}") 1824 return None, None 1825 1826 def get_thetas(self, node: str, state: str = "best"): 1827 """ 1828 Return transformed intercept (theta) parameters for a node and state. 1829 1830 Parameters 1831 ---------- 1832 node : str 1833 Node name. 1834 state : {"best", "last", "init"}, optional 1835 Model state for which to return parameters. Default is "best". 1836 1837 Returns 1838 ------- 1839 Any or None 1840 Transformed theta parameters for the requested node and state. 1841 The exact structure (scalar, list, or other) depends on the model. 1842 1843 Raises 1844 ------ 1845 ValueError 1846 If an invalid state is given (not in {"best", "last", "init"}). 
1847 1848 Notes 1849 ----- 1850 Intercept dictionaries are cached on the instance under the attribute 1851 ``intercept_dicts``. If missing or incomplete, they are recomputed using 1852 `get_simple_intercepts_dict`. 1853 """ 1854 1855 state = state.lower() 1856 if state not in ["best", "last", "init"]: 1857 raise ValueError(f"[ERROR] Invalid state '{state}'. Must be one of ['best', 'last', 'init'].") 1858 1859 dict_attr = "intercept_dicts" 1860 1861 # If no cached intercepts exist, compute them 1862 if not hasattr(self, dict_attr): 1863 if getattr(self, "debug", False): 1864 print(f"[DEBUG] '{dict_attr}' not found, computing via get_simple_intercepts_dict().") 1865 setattr(self, dict_attr, self.get_simple_intercepts_dict()) 1866 1867 all_dicts = getattr(self, dict_attr) 1868 1869 # If the requested state isn’t cached, recompute 1870 if state not in all_dicts: 1871 if getattr(self, "debug", False): 1872 print(f"[DEBUG] State '{state}' not found in cached intercepts, recomputing full dict.") 1873 setattr(self, dict_attr, self.get_simple_intercepts_dict()) 1874 all_dicts = getattr(self, dict_attr) 1875 1876 state_dict = all_dicts.get(state, {}) 1877 1878 # Return cached node intercept if present 1879 if node in state_dict: 1880 return state_dict[node] 1881 1882 # If not found, recompute full dict as fallback 1883 if getattr(self, "debug", False): 1884 print(f"[DEBUG] Node '{node}' not found in state '{state}', recomputing full dict.") 1885 setattr(self, dict_attr, self.get_simple_intercepts_dict()) 1886 all_dicts = getattr(self, dict_attr) 1887 return all_dicts.get(state, {}).get(node, None) 1888 1889 def get_linear_shifts(self, node: str, state: str = "best"): 1890 """ 1891 Return learned linear shift terms for a node and a given state. 1892 1893 Parameters 1894 ---------- 1895 node : str 1896 Node name. 1897 state : {"best", "last", "init"}, optional 1898 Model state for which to return linear shift terms. Default is "best". 
1899 1900 Returns 1901 ------- 1902 dict or Any or None 1903 Linear shift terms for the given node and state. Usually a dict 1904 mapping term names to weights. 1905 1906 Raises 1907 ------ 1908 ValueError 1909 If an invalid state is given (not in {"best", "last", "init"}). 1910 1911 Notes 1912 ----- 1913 Linear shift dictionaries are cached on the instance under the attribute 1914 ``linear_shift_dicts``. If missing or incomplete, they are recomputed using 1915 `get_linear_shifts_dict`. 1916 """ 1917 state = state.lower() 1918 if state not in ["best", "last", "init"]: 1919 raise ValueError(f"[ERROR] Invalid state '{state}'. Must be one of ['best', 'last', 'init'].") 1920 1921 dict_attr = "linear_shift_dicts" 1922 1923 # If no global dicts cached, compute once 1924 if not hasattr(self, dict_attr): 1925 if getattr(self, "debug", False): 1926 print(f"[DEBUG] '{dict_attr}' not found, computing via get_linear_shifts_dict().") 1927 setattr(self, dict_attr, self.get_linear_shifts_dict()) 1928 1929 all_dicts = getattr(self, dict_attr) 1930 1931 # If the requested state isn't cached, compute all again (covers fresh runs) 1932 if state not in all_dicts: 1933 if getattr(self, "debug", False): 1934 print(f"[DEBUG] State '{state}' not found in cached linear shifts, recomputing full dict.") 1935 setattr(self, dict_attr, self.get_linear_shifts_dict()) 1936 all_dicts = getattr(self, dict_attr) 1937 1938 # Now fetch the dictionary for this state 1939 state_dict = all_dicts.get(state, {}) 1940 1941 # If the node is available, return its entry 1942 if node in state_dict: 1943 return state_dict[node] 1944 1945 # If missing, try recomputing (fallback) 1946 if getattr(self, "debug", False): 1947 print(f"[DEBUG] Node '{node}' not found in state '{state}', recomputing full dict.") 1948 setattr(self, dict_attr, self.get_linear_shifts_dict()) 1949 all_dicts = getattr(self, dict_attr) 1950 return all_dicts.get(state, {}).get(node, None) 1951 1952 def get_linear_shifts_dict(self): 1953 """ 
1954 Compute linear shift term dictionaries for all nodes and states. 1955 1956 For each node and each available state ("best", "last", "init"), this 1957 method loads the corresponding model checkpoint, extracts linear shift 1958 weights from the TRAM model, and stores them in a nested dictionary. 1959 1960 Returns 1961 ------- 1962 dict 1963 Nested dictionary of the form: 1964 1965 .. code-block:: python 1966 1967 { 1968 "best": {node: {...}}, 1969 "last": {node: {...}}, 1970 "init": {node: {...}}, 1971 } 1972 1973 where the innermost dict maps term labels (e.g. ``"ls(parent_name)"``) 1974 to their weights. 1975 1976 Notes 1977 ----- 1978 - If "best" or "last" checkpoints are unavailable for a node, only 1979 the "init" entry is populated. 1980 - Empty outer states (without any nodes) are removed from the result. 1981 """ 1982 1983 EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"] 1984 nodes_list = list(self.models.keys()) 1985 all_states = ["best", "last", "init"] 1986 all_linear_shift_dicts = {state: {} for state in all_states} 1987 1988 for node in nodes_list: 1989 NODE_DIR = os.path.join(EXPERIMENT_DIR, node) 1990 BEST_MODEL_PATH, LAST_MODEL_PATH, _, _ = model_train_val_paths(NODE_DIR) 1991 INIT_MODEL_PATH = os.path.join(NODE_DIR, "initial_model.pt") 1992 1993 state_paths = { 1994 "best": BEST_MODEL_PATH, 1995 "last": LAST_MODEL_PATH, 1996 "init": INIT_MODEL_PATH, 1997 } 1998 1999 for state, LOAD_PATH in state_paths.items(): 2000 if not os.path.exists(LOAD_PATH): 2001 if state != "init": 2002 # skip best/last if unavailable 2003 continue 2004 else: 2005 print(f"[WARNING] No models found for node '{node}'. Only initial model will be used.") 2006 if not os.path.exists(LOAD_PATH): 2007 if getattr(self, "debug", False): 2008 print(f"[DEBUG] Initial model also missing for node '{node}'. 
Skipping.") 2009 continue 2010 2011 # Load parents and model 2012 _, terms_dict, _ = ordered_parents(node, self.nodes_dict) 2013 state_dict = torch.load(LOAD_PATH, map_location=self.device) 2014 tram_model = self.models[node] 2015 tram_model.load_state_dict(state_dict) 2016 2017 epoch_weights = {} 2018 if hasattr(tram_model, "nn_shift") and tram_model.nn_shift is not None: 2019 for i, shift_layer in enumerate(tram_model.nn_shift): 2020 module_name = shift_layer.__class__.__name__ 2021 if ( 2022 hasattr(shift_layer, "fc") 2023 and hasattr(shift_layer.fc, "weight") 2024 and module_name == "LinearShift" 2025 ): 2026 term_name = list(terms_dict.keys())[i] 2027 epoch_weights[f"ls({term_name})"] = ( 2028 shift_layer.fc.weight.detach().cpu().squeeze().tolist() 2029 ) 2030 elif getattr(self, "debug", False): 2031 term_name = list(terms_dict.keys())[i] 2032 print(f"[DEBUG] ls({term_name}): missing 'fc' or 'weight' in LinearShift.") 2033 else: 2034 if getattr(self, "debug", False): 2035 print(f"[DEBUG] Tram model for node '{node}' has no nn_shift or it is None.") 2036 2037 all_linear_shift_dicts[state][node] = epoch_weights 2038 2039 # Remove empty states (e.g., when best/last not found for all nodes) 2040 all_linear_shift_dicts = {k: v for k, v in all_linear_shift_dicts.items() if v} 2041 2042 return all_linear_shift_dicts 2043 2044 def get_simple_intercepts_dict(self): 2045 """ 2046 Compute transformed simple intercept dictionaries for all nodes and states. 2047 2048 For each node and each available state ("best", "last", "init"), this 2049 method loads the corresponding model checkpoint, extracts simple intercept 2050 weights, transforms them into interpretable theta parameters, and stores 2051 them in a nested dictionary. 2052 2053 Returns 2054 ------- 2055 dict 2056 Nested dictionary of the form: 2057 2058 .. 
code-block:: python 2059 2060 { 2061 "best": {node: [[theta_1], [theta_2], ...]}, 2062 "last": {node: [[theta_1], [theta_2], ...]}, 2063 "init": {node: [[theta_1], [theta_2], ...]}, 2064 } 2065 2066 Notes 2067 ----- 2068 - For ordinal models (``self.is_ontram == True``), `transform_intercepts_ordinal` 2069 is used. 2070 - For continuous models, `transform_intercepts_continous` is used. 2071 - Empty outer states (without any nodes) are removed from the result. 2072 """ 2073 2074 EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"] 2075 nodes_list = list(self.models.keys()) 2076 all_states = ["best", "last", "init"] 2077 all_si_intercept_dicts = {state: {} for state in all_states} 2078 2079 debug = getattr(self, "debug", False) 2080 verbose = getattr(self, "verbose", False) 2081 is_ontram = getattr(self, "is_ontram", False) 2082 2083 for node in nodes_list: 2084 NODE_DIR = os.path.join(EXPERIMENT_DIR, node) 2085 BEST_MODEL_PATH, LAST_MODEL_PATH, _, _ = model_train_val_paths(NODE_DIR) 2086 INIT_MODEL_PATH = os.path.join(NODE_DIR, "initial_model.pt") 2087 2088 state_paths = { 2089 "best": BEST_MODEL_PATH, 2090 "last": LAST_MODEL_PATH, 2091 "init": INIT_MODEL_PATH, 2092 } 2093 2094 for state, LOAD_PATH in state_paths.items(): 2095 if not os.path.exists(LOAD_PATH): 2096 if state != "init": 2097 continue 2098 else: 2099 print(f"[WARNING] No models found for node '{node}'. Only initial model will be used.") 2100 if not os.path.exists(LOAD_PATH): 2101 if debug: 2102 print(f"[DEBUG] Initial model also missing for node '{node}'. 
Skipping.") 2103 continue 2104 2105 # Load model state 2106 state_dict = torch.load(LOAD_PATH, map_location=self.device) 2107 tram_model = self.models[node] 2108 tram_model.load_state_dict(state_dict) 2109 2110 # Extract and transform simple intercept weights 2111 si_weights = None 2112 if hasattr(tram_model, "nn_int") and tram_model.nn_int is not None and isinstance(tram_model.nn_int, SimpleIntercept): 2113 if hasattr(tram_model.nn_int, "fc") and hasattr(tram_model.nn_int.fc, "weight"): 2114 weights = tram_model.nn_int.fc.weight.detach().cpu().tolist() 2115 weights_tensor = torch.Tensor(weights) 2116 2117 if debug: 2118 print(f"[DEBUG] Node '{node}' ({state}) theta tilde shape: {weights_tensor.shape}") 2119 2120 if is_ontram: 2121 si_weights = transform_intercepts_ordinal(weights_tensor.reshape(1, -1))[:, 1:-1].reshape(-1, 1) 2122 else: 2123 si_weights = transform_intercepts_continous(weights_tensor.reshape(1, -1)).reshape(-1, 1) 2124 2125 si_weights = si_weights.tolist() 2126 2127 if debug: 2128 print(f"[DEBUG] Node '{node}' ({state}) theta transformed: {si_weights}") 2129 else: 2130 if debug: 2131 print(f"[DEBUG] Node '{node}' ({state}): missing 'fc' or 'weight' in SimpleIntercept.") 2132 else: 2133 if debug: 2134 print(f"[DEBUG] Tram model for node '{node}' has no nn_int or it is None.") 2135 2136 all_si_intercept_dicts[state][node] = si_weights 2137 2138 # Clean up empty states 2139 all_si_intercept_dicts = {k: v for k, v in all_si_intercept_dicts.items() if v} 2140 return all_si_intercept_dicts 2141 2142 def summary(self, verbose=False): 2143 """ 2144 Print a multi-part textual summary of the TramDagModel. 2145 2146 The summary includes: 2147 1. Training metrics overview per node (best/last NLL, epochs). 2148 2. Node-specific details (thetas, linear shifts, optional architecture). 2149 3. Basic information about the attached training DataFrame, if present. 
2150 2151 Parameters 2152 ---------- 2153 verbose : bool, optional 2154 If True, include extended per-node details such as the model 2155 architecture, parameter count, and availability of checkpoints 2156 and sampling results. Default is False. 2157 2158 Returns 2159 ------- 2160 None 2161 2162 Notes 2163 ----- 2164 This method prints to stdout and does not return structured data. 2165 It is intended for quick, human-readable inspection of the current 2166 training and model state. 2167 """ 2168 2169 # ---------- SETUP ---------- 2170 try: 2171 EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"] 2172 except KeyError: 2173 EXPERIMENT_DIR = None 2174 print("[WARNING] Missing EXPERIMENT_DIR in cfg.conf_dict['PATHS'].") 2175 2176 print("\n" + "=" * 120) 2177 print(f"{'TRAM DAG MODEL SUMMARY':^120}") 2178 print("=" * 120) 2179 2180 # ---------- METRICS OVERVIEW ---------- 2181 summary_data = [] 2182 for node in self.models.keys(): 2183 node_dir = os.path.join(self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"], node) 2184 train_path = os.path.join(node_dir, "train_loss_hist.json") 2185 val_path = os.path.join(node_dir, "val_loss_hist.json") 2186 2187 if os.path.exists(train_path) and os.path.exists(val_path): 2188 best_train_nll, best_val_nll = self.get_train_val_nll(node, "best") 2189 last_train_nll, last_val_nll = self.get_train_val_nll(node, "last") 2190 n_epochs_total = len(json.load(open(train_path))) 2191 else: 2192 best_train_nll = best_val_nll = last_train_nll = last_val_nll = None 2193 n_epochs_total = 0 2194 2195 summary_data.append({ 2196 "Node": node, 2197 "Best Train NLL": best_train_nll, 2198 "Best Val NLL": best_val_nll, 2199 "Last Train NLL": last_train_nll, 2200 "Last Val NLL": last_val_nll, 2201 "Epochs": n_epochs_total, 2202 }) 2203 2204 df_summary = pd.DataFrame(summary_data) 2205 df_summary = df_summary.round(4) 2206 2207 print("\n[1] TRAINING METRICS OVERVIEW") 2208 print("-" * 120) 2209 if not df_summary.empty: 2210 print( 2211 
df_summary.to_string( 2212 index=False, 2213 justify="center", 2214 col_space=14, 2215 float_format=lambda x: f"{x:7.4f}", 2216 ) 2217 ) 2218 else: 2219 print("No training history found for any node.") 2220 print("-" * 120) 2221 2222 # ---------- NODE DETAILS ---------- 2223 print("\n[2] NODE-SPECIFIC DETAILS") 2224 print("-" * 120) 2225 for node in self.models.keys(): 2226 print(f"\n{f'NODE: {node}':^120}") 2227 print("-" * 120) 2228 2229 # THETAS & SHIFTS 2230 for state in ["init", "last", "best"]: 2231 print(f"\n [{state.upper()} STATE]") 2232 2233 # ---- Thetas ---- 2234 try: 2235 thetas = getattr(self, "get_thetas", lambda n, s=None: None)(node, state) 2236 if thetas is not None: 2237 if isinstance(thetas, (list, np.ndarray, pd.Series)): 2238 thetas_flat = np.array(thetas).flatten() 2239 compact = np.round(thetas_flat, 4) 2240 arr_str = np.array2string( 2241 compact, 2242 max_line_width=110, 2243 threshold=np.inf, 2244 separator=", " 2245 ) 2246 lines = arr_str.split("\n") 2247 if len(lines) > 2: 2248 arr_str = "\n".join(lines[:2]) + " ..." 
2249 print(f" Θ ({len(thetas_flat)}): {arr_str}") 2250 elif isinstance(thetas, dict): 2251 for k, v in thetas.items(): 2252 print(f" Θ[{k}]: {v}") 2253 else: 2254 print(f" Θ: {thetas}") 2255 else: 2256 print(" Θ: not available") 2257 except Exception as e: 2258 print(f" [Error loading thetas] {e}") 2259 2260 # ---- Linear Shifts ---- 2261 try: 2262 linear_shifts = getattr(self, "get_linear_shifts", lambda n, s=None: None)(node, state) 2263 if linear_shifts is not None: 2264 if isinstance(linear_shifts, dict): 2265 for k, v in linear_shifts.items(): 2266 print(f" {k}: {np.round(v, 4)}") 2267 elif isinstance(linear_shifts, (list, np.ndarray, pd.Series)): 2268 arr = np.round(linear_shifts, 4) 2269 print(f" Linear shifts ({len(arr)}): {arr}") 2270 else: 2271 print(f" Linear shifts: {linear_shifts}") 2272 else: 2273 print(" Linear shifts: not available") 2274 except Exception as e: 2275 print(f" [Error loading linear shifts] {e}") 2276 2277 # ---- Verbose info directly below node ---- 2278 if verbose: 2279 print("\n [DETAILS]") 2280 node_dir = os.path.join(EXPERIMENT_DIR, node) if EXPERIMENT_DIR else None 2281 model = self.models[node] 2282 2283 print(f" Model Architecture:") 2284 arch_str = str(model).split("\n") 2285 for line in arch_str: 2286 print(f" {line}") 2287 print(f" Parameter count: {sum(p.numel() for p in model.parameters()):,}") 2288 2289 if node_dir and os.path.exists(node_dir): 2290 ckpt_exists = any(f.endswith(('.pt', '.pth')) for f in os.listdir(node_dir)) 2291 print(f" Checkpoints found: {ckpt_exists}") 2292 2293 sampling_dir = os.path.join(node_dir, "sampling") 2294 sampling_exists = os.path.isdir(sampling_dir) and len(os.listdir(sampling_dir)) > 0 2295 print(f" Sampling results found: {sampling_exists}") 2296 2297 for label, filename in [("Train", "train_loss_hist.json"), ("Validation", "val_loss_hist.json")]: 2298 path = os.path.join(node_dir, filename) 2299 if os.path.exists(path): 2300 try: 2301 with open(path, "r") as f: 2302 hist = json.load(f) 
2303 print(f" {label} history: {len(hist)} epochs") 2304 except Exception as e: 2305 print(f" {label} history: failed to load ({e})") 2306 else: 2307 print(f" {label} history: not found") 2308 else: 2309 print(" [INFO] No experiment directory defined or missing for this node.") 2310 print("-" * 120) 2311 2312 # ---------- TRAINING DATAFRAME ---------- 2313 print("\n[3] TRAINING DATAFRAME") 2314 print("-" * 120) 2315 try: 2316 self.train_df.info() 2317 except AttributeError: 2318 print("No training DataFrame attached to this TramDagModel.") 2319 print("=" * 120 + "\n")
Probabilistic DAG model built from node-wise TRAMs (transformation models).
This class manages:
- Configuration and per-node model construction.
- Data scaling (min–max).
- Training (sequential or per-node parallel on CPU).
- Diagnostics (loss history, intercepts, linear shifts, latents).
- Sampling from the joint DAG and loading stored samples.
- High-level summaries and plotting utilities.
111 def __init__(self): 112 """ 113 Initialize an empty TramDagModel shell. 114 115 Notes 116 ----- 117 This constructor does not build any node models and does not attach a 118 configuration. Use `TramDagModel.from_config` or `TramDagModel.from_directory` 119 to obtain a fully configured and ready-to-use instance. 120 """ 121 122 self.debug = False 123 self.verbose = False 124 self.device = 'auto' 125 pass
Initialize an empty TramDagModel shell.
Notes
This constructor does not build any node models and does not attach a
configuration. Use TramDagModel.from_config or TramDagModel.from_directory
to obtain a fully configured and ready-to-use instance.
127 @staticmethod 128 def get_device(settings): 129 """ 130 Resolve the target device string from a settings dictionary. 131 132 Parameters 133 ---------- 134 settings : dict 135 Dictionary containing at least a key ``"device"`` with one of 136 {"auto", "cpu", "cuda"}. If missing, "auto" is assumed. 137 138 Returns 139 ------- 140 str 141 Device string, either "cpu" or "cuda". 142 143 Notes 144 ----- 145 If ``device == "auto"``, CUDA is selected if available, otherwise CPU. 146 """ 147 device_arg = settings.get("device", "auto") 148 if device_arg == "auto": 149 device_str = "cuda" if torch.cuda.is_available() else "cpu" 150 else: 151 device_str = device_arg 152 return device_str
Resolve the target device string from a settings dictionary.
Parameters
settings : dict
Dictionary containing at least a key "device" with one of
{"auto", "cpu", "cuda"}. If missing, "auto" is assumed.
Returns
str — Device string, either "cpu" or "cuda".
Notes
If device == "auto", CUDA is selected if available, otherwise CPU.
185 @classmethod 186 def from_config(cls, cfg, **kwargs): 187 """ 188 Construct a TramDagModel from a TramDagConfig object. 189 190 This builds one TRAM model per node in the DAG and optionally writes 191 the initial model parameters to disk. 192 193 Parameters 194 ---------- 195 cfg : TramDagConfig 196 Configuration wrapper holding the underlying configuration dictionary, 197 including at least: 198 - ``conf_dict["nodes"]``: mapping of node names to node configs. 199 - ``conf_dict["PATHS"]["EXPERIMENT_DIR"]``: experiment directory. 200 **kwargs 201 Node-level construction options. Each key must be present in 202 ``DEFAULTS_CONFIG``. Values can be: 203 - scalar: applied to all nodes. 204 - dict: mapping ``{node_name: value}`` for per-node overrides. 205 206 Common keys include: 207 device : {"auto", "cpu", "cuda"}, default "auto" 208 Device selection (CUDA if available when "auto"). 209 debug : bool, default False 210 If True, print debug messages. 211 verbose : bool, default False 212 If True, print informational messages. 213 set_initial_weights : bool 214 Passed to underlying TRAM model constructors. 215 overwrite_initial_weights : bool, default True 216 If True, overwrite any existing ``initial_model.pt`` files per node. 217 initial_data : Any 218 Optional object passed down to node constructors. 219 220 Returns 221 ------- 222 TramDagModel 223 Fully initialized instance with: 224 - ``cfg`` 225 - ``nodes_dict`` 226 - ``models`` (per-node TRAMs) 227 - ``settings`` (resolved per-node config) 228 229 Raises 230 ------ 231 ValueError 232 If any dict-valued kwarg does not provide values for exactly the set 233 of nodes in ``cfg.conf_dict["nodes"]``. 
234 """ 235 236 self = cls() 237 self.cfg = cfg 238 self.cfg.update() # ensure latest version from disk 239 self.cfg._verify_completeness() 240 241 242 try: 243 self.cfg.save() # persist back to disk 244 if getattr(self, "debug", False): 245 print("[DEBUG] Configuration updated and saved.") 246 except Exception as e: 247 print(f"[WARNING] Could not save configuration after update: {e}") 248 249 self.nodes_dict = self.cfg.conf_dict["nodes"] 250 251 self._validate_kwargs(kwargs, defaults_attr='DEFAULTS_CONFIG', context="from_config") 252 253 # update defaults with kwargs 254 settings = dict(cls.DEFAULTS_CONFIG) 255 settings.update(kwargs) 256 257 # resolve device 258 device_arg = settings.get("device", "auto") 259 if device_arg == "auto": 260 device_str = "cuda" if torch.cuda.is_available() else "cpu" 261 else: 262 device_str = device_arg 263 self.device = torch.device(device_str) 264 265 # set flags on the instance so they are accessible later 266 self.debug = settings.get("debug", False) 267 self.verbose = settings.get("verbose", False) 268 269 if self.debug: 270 print(f"[DEBUG] TramDagModel using device: {self.device}") 271 272 # initialize settings storage 273 self.settings = {k: {} for k in settings.keys()} 274 275 # validate dict-typed args 276 for k, v in settings.items(): 277 if isinstance(v, dict): 278 expected = set(self.nodes_dict.keys()) 279 given = set(v.keys()) 280 if expected != given: 281 raise ValueError( 282 f"[ERROR] the provided argument '{k}' keys are not same as in cfg.conf_dict['nodes'].keys().\n" 283 f"Expected: {expected}, but got: {given}\n" 284 f"Please provide values for all variables.") 285 286 # build one model per node 287 self.models = {} 288 for node in self.nodes_dict.keys(): 289 per_node_kwargs = {} 290 for k, v in settings.items(): 291 resolved = v[node] if isinstance(v, dict) else v 292 per_node_kwargs[k] = resolved 293 self.settings[k][node] = resolved 294 if self.debug: 295 print(f"\n[INFO] Building model for node '{node}' with 
settings: {per_node_kwargs}") 296 self.models[node] = get_fully_specified_tram_model( 297 node=node, 298 configuration_dict=self.cfg.conf_dict, 299 **per_node_kwargs) 300 301 try: 302 EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"] 303 NODE_DIR = os.path.join(EXPERIMENT_DIR, f"{node}") 304 os.makedirs(NODE_DIR, exist_ok=True) 305 306 model_path = os.path.join(NODE_DIR, "initial_model.pt") 307 overwrite = settings.get("overwrite_initial_weights", True) 308 309 if overwrite or not os.path.exists(model_path): 310 torch.save(self.models[node].state_dict(), model_path) 311 if self.debug: 312 print(f"[DEBUG] Saved initial model state for node '{node}' to {model_path} (overwrite={overwrite})") 313 else: 314 if self.debug: 315 print(f"[DEBUG] Skipped saving initial model for node '{node}' (already exists at {model_path})") 316 except Exception as e: 317 print(f"[ERROR] Could not save initial model state for node '{node}': {e}") 318 319 TEMP_DIR = "temp" 320 if os.path.isdir(TEMP_DIR) and not os.listdir(TEMP_DIR): 321 os.rmdir(TEMP_DIR) 322 323 return self
Construct a TramDagModel from a TramDagConfig object.
This builds one TRAM model per node in the DAG and optionally writes the initial model parameters to disk.
Parameters
cfg : TramDagConfig
Configuration wrapper holding the underlying configuration dictionary,
including at least:
- conf_dict["nodes"]: mapping of node names to node configs.
- conf_dict["PATHS"]["EXPERIMENT_DIR"]: experiment directory.
**kwargs
Node-level construction options. Each key must be present in
DEFAULTS_CONFIG. Values can be:
- scalar: applied to all nodes.
- dict: mapping {node_name: value} for per-node overrides.
Common keys include:
device : {"auto", "cpu", "cuda"}, default "auto"
Device selection (CUDA if available when "auto").
debug : bool, default False
If True, print debug messages.
verbose : bool, default False
If True, print informational messages.
set_initial_weights : bool
Passed to underlying TRAM model constructors.
overwrite_initial_weights : bool, default True
If True, overwrite any existing ``initial_model.pt`` files per node.
initial_data : Any
Optional object passed down to node constructors.
Returns
TramDagModel
Fully initialized instance with:
- cfg
- nodes_dict
- models (per-node TRAMs)
- settings (resolved per-node config)
Raises
ValueError
If any dict-valued kwarg does not provide values for exactly the set
of nodes in cfg.conf_dict["nodes"].
325 @classmethod 326 def from_directory(cls, EXPERIMENT_DIR: str, device: str = "auto", debug: bool = False, verbose: bool = False): 327 """ 328 Reconstruct a TramDagModel from an experiment directory on disk. 329 330 This method: 331 1. Loads the configuration JSON. 332 2. Wraps it in a TramDagConfig. 333 3. Builds all node models via `from_config`. 334 4. Loads the min–max scaling dictionary. 335 336 Parameters 337 ---------- 338 EXPERIMENT_DIR : str 339 Path to an experiment directory containing: 340 - ``configuration.json`` 341 - ``min_max_scaling.json``. 342 device : {"auto", "cpu", "cuda"}, optional 343 Device selection. Default is "auto". 344 debug : bool, optional 345 If True, enable debug messages. Default is False. 346 verbose : bool, optional 347 If True, enable informational messages. Default is False. 348 349 Returns 350 ------- 351 TramDagModel 352 A TramDagModel instance with models, config, and scaling loaded. 353 354 Raises 355 ------ 356 FileNotFoundError 357 If configuration or min–max files cannot be found. 358 RuntimeError 359 If the min–max file cannot be read or parsed. 
360 """ 361 362 # --- load config file --- 363 config_path = os.path.join(EXPERIMENT_DIR, "configuration.json") 364 if not os.path.exists(config_path): 365 raise FileNotFoundError(f"[ERROR] Config file not found at {config_path}") 366 367 with open(config_path, "r") as f: 368 cfg_dict = json.load(f) 369 370 # Create TramConfig wrapper 371 cfg = TramDagConfig(cfg_dict, CONF_DICT_PATH=config_path) 372 373 # --- build model from config --- 374 self = cls.from_config(cfg, device=device, debug=debug, verbose=verbose, overwrite_initial_weights=False) 375 376 # --- load minmax scaling --- 377 minmax_path = os.path.join(EXPERIMENT_DIR, "min_max_scaling.json") 378 if not os.path.exists(minmax_path): 379 raise FileNotFoundError(f"[ERROR] MinMax file not found at {minmax_path}") 380 381 with open(minmax_path, "r") as f: 382 self.minmax_dict = json.load(f) 383 384 if self.verbose or self.debug: 385 print(f"[INFO] Loaded TramDagModel from {EXPERIMENT_DIR}") 386 print(f"[INFO] Config loaded from {config_path}") 387 print(f"[INFO] MinMax scaling loaded from {minmax_path}") 388 389 return self
Reconstruct a TramDagModel from an experiment directory on disk.
This method:
- Loads the configuration JSON.
- Wraps it in a TramDagConfig.
- Builds all node models via `from_config`.
- Loads the min–max scaling dictionary.
Parameters
EXPERIMENT_DIR : str
Path to an experiment directory containing:
- configuration.json
- min_max_scaling.json.
device : {"auto", "cpu", "cuda"}, optional
Device selection. Default is "auto".
debug : bool, optional
If True, enable debug messages. Default is False.
verbose : bool, optional
If True, enable informational messages. Default is False.
Returns
TramDagModel A TramDagModel instance with models, config, and scaling loaded.
Raises
FileNotFoundError If configuration or min–max files cannot be found. RuntimeError If the min–max file cannot be read or parsed.
429 def load_or_compute_minmax(self, td_train_data=None,use_existing=False, write=True): 430 """ 431 Load an existing Min–Max scaling dictionary from disk or compute a new one 432 from the provided training dataset. 433 434 Parameters 435 ---------- 436 use_existing : bool, optional (default=False) 437 If True, attempts to load an existing `min_max_scaling.json` file 438 from the experiment directory. Raises an error if the file is missing 439 or unreadable. 440 441 write : bool, optional (default=True) 442 If True, writes the computed Min–Max scaling dictionary to 443 `<EXPERIMENT_DIR>/min_max_scaling.json`. 444 445 td_train_data : object, optional 446 Training dataset used to compute scaling statistics. If not provided, 447 the method will ensure or construct it via `_ensure_dataset(data=..., is_val=False)`. 448 449 Behavior 450 -------- 451 - If `use_existing=True`, loads the JSON file containing previously saved 452 min–max values and stores it in `self.minmax_dict`. 453 - If `use_existing=False`, computes a new scaling dictionary using 454 `td_train_data.compute_scaling()` and stores the result in 455 `self.minmax_dict`. 456 - Optionally writes the computed dictionary to disk. 457 458 Side Effects 459 ------------- 460 - Populates `self.minmax_dict` with scaling values. 461 - Writes or loads the file `min_max_scaling.json` under 462 `<EXPERIMENT_DIR>`. 463 - Prints diagnostic output if `self.debug` or `self.verbose` is True. 464 465 Raises 466 ------ 467 FileNotFoundError 468 If `use_existing=True` but the min–max file does not exist. 469 470 RuntimeError 471 If an existing min–max file cannot be read or parsed. 472 473 Notes 474 ----- 475 The computed min–max dictionary is expected to contain scaling statistics 476 per feature, typically in the form: 477 { 478 "node": {"min": float, "max": float}, 479 ... 
480 } 481 """ 482 EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"] 483 minmax_path = os.path.join(EXPERIMENT_DIR, "min_max_scaling.json") 484 485 # laod exisitng if possible 486 if use_existing: 487 if not os.path.exists(minmax_path): 488 raise FileNotFoundError(f"MinMax file not found: {minmax_path}") 489 try: 490 with open(minmax_path, 'r') as f: 491 self.minmax_dict = json.load(f) 492 if self.debug or self.verbose: 493 print(f"[INFO] Loaded existing minmax dict from {minmax_path}") 494 return 495 except Exception as e: 496 raise RuntimeError(f"Could not load existing minmax dict: {e}") 497 498 # 499 if self.debug or self.verbose: 500 print("[INFO] Computing new minmax dict from training data...") 501 502 td_train_data=self._ensure_dataset( data=td_train_data, is_val=False) 503 504 self.minmax_dict = td_train_data.compute_scaling() 505 506 if write: 507 os.makedirs(EXPERIMENT_DIR, exist_ok=True) 508 with open(minmax_path, 'w') as f: 509 json.dump(self.minmax_dict, f, indent=4) 510 if self.debug or self.verbose: 511 print(f"[INFO] Saved new minmax dict to {minmax_path}")
Load an existing Min–Max scaling dictionary from disk or compute a new one from the provided training dataset.
Parameters
use_existing : bool, optional (default=False)
If True, attempts to load an existing min_max_scaling.json file
from the experiment directory. Raises an error if the file is missing
or unreadable.
write : bool, optional (default=True)
If True, writes the computed Min–Max scaling dictionary to
<EXPERIMENT_DIR>/min_max_scaling.json.
td_train_data : object, optional
Training dataset used to compute scaling statistics. If not provided,
the method will ensure or construct it via _ensure_dataset(data=..., is_val=False).
Behavior
- If `use_existing=True`, loads the JSON file containing previously saved min–max values and stores it in `self.minmax_dict`.
- If `use_existing=False`, computes a new scaling dictionary using `td_train_data.compute_scaling()` and stores the result in `self.minmax_dict`.
- Optionally writes the computed dictionary to disk.
Side Effects
- Populates `self.minmax_dict` with scaling values.
- Writes or loads the file `min_max_scaling.json` under `<EXPERIMENT_DIR>`.
- Prints diagnostic output if `self.debug` or `self.verbose` is True.
Raises
FileNotFoundError
If use_existing=True but the min–max file does not exist.
RuntimeError If an existing min–max file cannot be read or parsed.
Notes
The computed min–max dictionary is expected to contain scaling statistics per feature, typically in the form: { "node": {"min": float, "max": float}, ... }
612 def fit(self, train_data, val_data=None, **kwargs): 613 """ 614 Train TRAM models for all nodes in the DAG. 615 616 Coordinates dataset preparation, min–max scaling, and per-node training, 617 optionally in parallel on CPU. 618 619 Parameters 620 ---------- 621 train_data : pandas.DataFrame or TramDagDataset 622 Training data. If a DataFrame is given, it is converted into a 623 TramDagDataset using `_ensure_dataset`. 624 val_data : pandas.DataFrame or TramDagDataset or None, optional 625 Validation data. If a DataFrame is given, it is converted into a 626 TramDagDataset. If None, no validation loss is computed. 627 **kwargs 628 Overrides for ``DEFAULTS_FIT``. All keys must exist in 629 ``DEFAULTS_FIT``. Common options: 630 631 epochs : int, default 100 632 Number of training epochs per node. 633 learning_rate : float, default 0.01 634 Learning rate for the default Adam optimizer. 635 train_list : list of str or None, optional 636 List of node names to train. If None, all nodes are trained. 637 train_mode : {"sequential", "parallel"}, default "sequential" 638 Training mode. "parallel" uses joblib-based CPU multiprocessing. 639 GPU forces sequential mode. 640 device : {"auto", "cpu", "cuda"}, default "auto" 641 Device selection. 642 optimizers : dict or None 643 Optional mapping ``{node_name: optimizer}``. If provided for a 644 node, that optimizer is used instead of creating a new Adam. 645 schedulers : dict or None 646 Optional mapping ``{node_name: scheduler}``. 647 use_scheduler : bool 648 If True, enable scheduler usage in the training loop. 649 num_workers : int 650 DataLoader workers in sequential mode (ignored in parallel). 651 persistent_workers : bool 652 DataLoader persistence in sequential mode (ignored in parallel). 653 prefetch_factor : int 654 DataLoader prefetch factor (ignored in parallel). 655 batch_size : int 656 Batch size for all node DataLoaders. 657 debug : bool 658 Enable debug output. 659 verbose : bool 660 Enable informational logging. 
661 return_history : bool 662 If True, return a history dict. 663 664 Returns 665 ------- 666 dict or None 667 If ``return_history=True``, a dictionary mapping each node name 668 to its training history. Otherwise, returns None. 669 670 Raises 671 ------ 672 ValueError 673 If ``train_mode`` is not "sequential" or "parallel". 674 """ 675 self._validate_kwargs(kwargs, defaults_attr='DEFAULTS_FIT', context="fit") 676 677 # --- merge defaults --- 678 settings = dict(self.DEFAULTS_FIT) 679 settings.update(kwargs) 680 681 682 self.debug = settings.get("debug", False) 683 self.verbose = settings.get("verbose", False) 684 685 # --- resolve device --- 686 device_str=self.get_device(settings) 687 self.device = torch.device(device_str) 688 689 # --- training mode --- 690 train_mode = settings.get("train_mode", "sequential").lower() 691 if train_mode not in ("sequential", "parallel"): 692 raise ValueError("train_mode must be 'sequential' or 'parallel'") 693 694 # --- DataLoader safety logic --- 695 if train_mode == "parallel": 696 # if user passed loader paralleling params, warn and override 697 for flag in ("num_workers", "persistent_workers", "prefetch_factor"): 698 if flag in kwargs: 699 print(f"[WARNING] '{flag}' is ignored in parallel mode " 700 f"(disabled to prevent nested multiprocessing).") 701 # disable unsafe loader multiprocessing options 702 settings["num_workers"] = 0 703 settings["persistent_workers"] = False 704 settings["prefetch_factor"] = None 705 else: 706 # sequential mode → respect user DataLoader settings 707 if self.debug: 708 print("[DEBUG] Sequential mode: using DataLoader kwargs as provided.") 709 710 # --- which nodes to train --- 711 train_list = settings.get("train_list") or list(self.models.keys()) 712 713 714 # --- dataset prep (receives adjusted settings) --- 715 td_train_data = self._ensure_dataset(train_data, is_val=False, **settings) 716 td_val_data = self._ensure_dataset(val_data, is_val=True, **settings) 717 718 # --- normalization --- 719 
self.load_or_compute_minmax(use_existing=False, write=True, td_train_data=td_train_data) 720 721 # --- print header --- 722 if self.verbose or self.debug: 723 print(f"[INFO] Training {len(train_list)} nodes ({train_mode}) on {device_str}") 724 725 # ====================================================================== 726 # Sequential mode safe for GPU or debugging) 727 # ====================================================================== 728 if train_mode == "sequential" or "cuda" in device_str: 729 if "cuda" in device_str and train_mode == "parallel": 730 print("[WARNING] GPU device detected — forcing sequential mode.") 731 results = {} 732 for node in train_list: 733 node, history = self._fit_single_node( 734 node, self, settings, td_train_data, td_val_data, device_str 735 ) 736 results[node] = history 737 738 739 # ====================================================================== 740 # parallel mode (CPU only) 741 # ====================================================================== 742 if train_mode == "parallel": 743 744 n_jobs = min(len(train_list), os.cpu_count() // 2 or 1) 745 if self.verbose or self.debug: 746 print(f"[INFO] Using {n_jobs} CPU workers for parallel node training") 747 parallel_outputs = Parallel( 748 n_jobs=n_jobs, 749 backend="loky",#loky, multiprocessing 750 verbose=10, 751 prefer="processes" 752 )(delayed(self._fit_single_node)(node, self, settings, td_train_data, td_val_data, device_str) for node in train_list ) 753 754 results = {node: hist for node, hist in parallel_outputs} 755 756 if settings.get("return_history", False): 757 return results
Train TRAM models for all nodes in the DAG.
Coordinates dataset preparation, min–max scaling, and per-node training, optionally in parallel on CPU.
Parameters
train_data : pandas.DataFrame or TramDagDataset
Training data. If a DataFrame is given, it is converted into a
TramDagDataset using _ensure_dataset.
val_data : pandas.DataFrame or TramDagDataset or None, optional
Validation data. If a DataFrame is given, it is converted into a
TramDagDataset. If None, no validation loss is computed.
**kwargs
Overrides for DEFAULTS_FIT. All keys must exist in
DEFAULTS_FIT. Common options:
epochs : int, default 100
Number of training epochs per node.
learning_rate : float, default 0.01
Learning rate for the default Adam optimizer.
train_list : list of str or None, optional
List of node names to train. If None, all nodes are trained.
train_mode : {"sequential", "parallel"}, default "sequential"
Training mode. "parallel" uses joblib-based CPU multiprocessing.
GPU forces sequential mode.
device : {"auto", "cpu", "cuda"}, default "auto"
Device selection.
optimizers : dict or None
Optional mapping ``{node_name: optimizer}``. If provided for a
node, that optimizer is used instead of creating a new Adam.
schedulers : dict or None
Optional mapping ``{node_name: scheduler}``.
use_scheduler : bool
If True, enable scheduler usage in the training loop.
num_workers : int
DataLoader workers in sequential mode (ignored in parallel).
persistent_workers : bool
DataLoader persistence in sequential mode (ignored in parallel).
prefetch_factor : int
DataLoader prefetch factor (ignored in parallel).
batch_size : int
Batch size for all node DataLoaders.
debug : bool
Enable debug output.
verbose : bool
Enable informational logging.
return_history : bool
If True, return a history dict.
Returns
dict or None
If return_history=True, a dictionary mapping each node name
to its training history. Otherwise, returns None.
Raises
ValueError
If train_mode is not "sequential" or "parallel".
760 def loss_history(self): 761 """ 762 Load training and validation loss history for all nodes. 763 764 Looks for per-node JSON files: 765 766 - ``EXPERIMENT_DIR/{node}/train_loss_hist.json`` 767 - ``EXPERIMENT_DIR/{node}/val_loss_hist.json`` 768 769 Returns 770 ------- 771 dict 772 A dictionary mapping node names to: 773 774 .. code-block:: python 775 776 { 777 "train": list or None, 778 "validation": list or None 779 } 780 781 where each list contains NLL values per epoch, or None if not found. 782 783 Raises 784 ------ 785 ValueError 786 If the experiment directory cannot be resolved from the configuration. 787 """ 788 try: 789 EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"] 790 except KeyError: 791 raise ValueError( 792 "[ERROR] Missing 'EXPERIMENT_DIR' in cfg.conf_dict['PATHS']. " 793 "History retrieval requires experiment logs." 794 ) 795 796 all_histories = {} 797 for node in self.nodes_dict.keys(): 798 node_dir = os.path.join(EXPERIMENT_DIR, node) 799 train_path = os.path.join(node_dir, "train_loss_hist.json") 800 val_path = os.path.join(node_dir, "val_loss_hist.json") 801 802 node_hist = {} 803 804 # --- load train history --- 805 if os.path.exists(train_path): 806 try: 807 with open(train_path, "r") as f: 808 node_hist["train"] = json.load(f) 809 except Exception as e: 810 print(f"[WARNING] Could not load {train_path}: {e}") 811 node_hist["train"] = None 812 else: 813 node_hist["train"] = None 814 815 # --- load val history --- 816 if os.path.exists(val_path): 817 try: 818 with open(val_path, "r") as f: 819 node_hist["validation"] = json.load(f) 820 except Exception as e: 821 print(f"[WARNING] Could not load {val_path}: {e}") 822 node_hist["validation"] = None 823 else: 824 node_hist["validation"] = None 825 826 all_histories[node] = node_hist 827 828 if self.verbose or self.debug: 829 print(f"[INFO] Loaded training/validation histories for {len(all_histories)} nodes.") 830 831 return all_histories
Load training and validation loss history for all nodes.
Looks for per-node JSON files:
- `EXPERIMENT_DIR/{node}/train_loss_hist.json`
- `EXPERIMENT_DIR/{node}/val_loss_hist.json`
Returns
dict A dictionary mapping node names to:
```python
{ "train": list or None, "validation": list or None } ```
where each list contains NLL values per epoch, or None if not found.
Raises
ValueError If the experiment directory cannot be resolved from the configuration.
833 def linear_shift_history(self): 834 """ 835 Load linear shift term histories for all nodes. 836 837 Each node history is expected in a JSON file named 838 ``linear_shifts_all_epochs.json`` under the node directory. 839 840 Returns 841 ------- 842 dict 843 A mapping ``{node_name: pandas.DataFrame}``, where each DataFrame 844 contains linear shift weights across epochs. 845 846 Raises 847 ------ 848 ValueError 849 If the experiment directory cannot be resolved from the configuration. 850 851 Notes 852 ----- 853 If a history file is missing for a node, a warning is printed and the 854 node is omitted from the returned dictionary. 855 """ 856 histories = {} 857 try: 858 EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"] 859 except KeyError: 860 raise ValueError( 861 "[ERROR] Missing 'EXPERIMENT_DIR' in cfg.conf_dict['PATHS']. " 862 "Cannot load histories without experiment directory." 863 ) 864 865 for node in self.nodes_dict.keys(): 866 node_dir = os.path.join(EXPERIMENT_DIR, node) 867 history_path = os.path.join(node_dir, "linear_shifts_all_epochs.json") 868 if os.path.exists(history_path): 869 histories[node] = pd.read_json(history_path) 870 else: 871 print(f"[WARNING] No linear shift history found for node '{node}' at {history_path}") 872 return histories
Load linear shift term histories for all nodes.
Each node history is expected in a JSON file named
linear_shifts_all_epochs.json under the node directory.
Returns
dict
A mapping {node_name: pandas.DataFrame}, where each DataFrame
contains linear shift weights across epochs.
Raises
ValueError If the experiment directory cannot be resolved from the configuration.
Notes
If a history file is missing for a node, a warning is printed and the node is omitted from the returned dictionary.
874 def simple_intercept_history(self): 875 """ 876 Load simple intercept histories for all nodes. 877 878 Each node history is expected in a JSON file named 879 ``simple_intercepts_all_epochs.json`` under the node directory. 880 881 Returns 882 ------- 883 dict 884 A mapping ``{node_name: pandas.DataFrame}``, where each DataFrame 885 contains intercept weights across epochs. 886 887 Raises 888 ------ 889 ValueError 890 If the experiment directory cannot be resolved from the configuration. 891 892 Notes 893 ----- 894 If a history file is missing for a node, a warning is printed and the 895 node is omitted from the returned dictionary. 896 """ 897 histories = {} 898 try: 899 EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"] 900 except KeyError: 901 raise ValueError( 902 "[ERROR] Missing 'EXPERIMENT_DIR' in cfg.conf_dict['PATHS']. " 903 "Cannot load histories without experiment directory." 904 ) 905 906 for node in self.nodes_dict.keys(): 907 node_dir = os.path.join(EXPERIMENT_DIR, node) 908 history_path = os.path.join(node_dir, "simple_intercepts_all_epochs.json") 909 if os.path.exists(history_path): 910 histories[node] = pd.read_json(history_path) 911 else: 912 print(f"[WARNING] No simple intercept history found for node '{node}' at {history_path}") 913 return histories
Load simple intercept histories for all nodes.
Each node history is expected in a JSON file named
simple_intercepts_all_epochs.json under the node directory.
Returns
dict
A mapping {node_name: pandas.DataFrame}, where each DataFrame
contains intercept weights across epochs.
Raises
ValueError If the experiment directory cannot be resolved from the configuration.
Notes
If a history file is missing for a node, a warning is printed and the node is omitted from the returned dictionary.
915 def get_latent(self, df, verbose=False): 916 """ 917 Compute latent representations for all nodes in the DAG. 918 919 Parameters 920 ---------- 921 df : pandas.DataFrame 922 Input data frame with columns corresponding to nodes in the DAG. 923 verbose : bool, optional 924 If True, print informational messages during latent computation. 925 Default is False. 926 927 Returns 928 ------- 929 pandas.DataFrame 930 DataFrame containing the original columns plus latent variables 931 for each node (e.g. columns named ``f"{node}_U"``). 932 933 Raises 934 ------ 935 ValueError 936 If the experiment directory is missing from the configuration or 937 if ``self.minmax_dict`` has not been set. 938 """ 939 try: 940 EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"] 941 except KeyError: 942 raise ValueError( 943 "[ERROR] Missing 'EXPERIMENT_DIR' in cfg.conf_dict['PATHS']. " 944 "Latent extraction requires trained model checkpoints." 945 ) 946 947 # ensure minmax_dict is available 948 if not hasattr(self, "minmax_dict"): 949 raise ValueError( 950 "[ERROR] minmax_dict not found in the TramDagModel instance. " 951 "Either call .load_or_compute_minmax(td_train_data=train_df) or .fit() first." 952 ) 953 954 all_latents_df = create_latent_df_for_full_dag( 955 configuration_dict=self.cfg.conf_dict, 956 EXPERIMENT_DIR=EXPERIMENT_DIR, 957 df=df, 958 verbose=verbose, 959 min_max_dict=self.minmax_dict, 960 ) 961 962 return all_latents_df
Compute latent representations for all nodes in the DAG.
Parameters
df : pandas.DataFrame Input data frame with columns corresponding to nodes in the DAG. verbose : bool, optional If True, print informational messages during latent computation. Default is False.
Returns
pandas.DataFrame
DataFrame containing the original columns plus latent variables
for each node (e.g. columns named f"{node}_U").
Raises
ValueError
If the experiment directory is missing from the configuration or
if self.minmax_dict has not been set.
967 def plot_loss_history(self, variable: str = None): 968 """ 969 Plot training and validation loss evolution per node. 970 971 Parameters 972 ---------- 973 variable : str or None, optional 974 If provided, plot loss history for this node only. If None, plot 975 histories for all nodes that have both train and validation logs. 976 977 Returns 978 ------- 979 None 980 981 Notes 982 ----- 983 Two subplots are produced: 984 - Full epoch history. 985 - Last 10% of epochs (or only the last epoch if fewer than 5 epochs). 986 """ 987 988 histories = self.loss_history() 989 if not histories: 990 print("[WARNING] No loss histories found.") 991 return 992 993 # Select which nodes to plot 994 if variable is not None: 995 if variable not in histories: 996 raise ValueError(f"[ERROR] Node '{variable}' not found in histories.") 997 nodes_to_plot = [variable] 998 else: 999 nodes_to_plot = list(histories.keys()) 1000 1001 # Filter out nodes with no valid history 1002 nodes_to_plot = [ 1003 n for n in nodes_to_plot 1004 if histories[n].get("train") is not None and len(histories[n]["train"]) > 0 1005 and histories[n].get("validation") is not None and len(histories[n]["validation"]) > 0 1006 ] 1007 1008 if not nodes_to_plot: 1009 print("[WARNING] No valid histories found to plot.") 1010 return 1011 1012 plt.figure(figsize=(14, 12)) 1013 1014 # --- Full history (top plot) --- 1015 plt.subplot(2, 1, 1) 1016 for node in nodes_to_plot: 1017 node_hist = histories[node] 1018 train_hist, val_hist = node_hist["train"], node_hist["validation"] 1019 1020 epochs = range(1, len(train_hist) + 1) 1021 plt.plot(epochs, train_hist, label=f"{node} - train", linestyle="--") 1022 plt.plot(epochs, val_hist, label=f"{node} - val") 1023 1024 plt.title("Training and Validation NLL - Full History") 1025 plt.xlabel("Epoch") 1026 plt.ylabel("NLL") 1027 plt.legend() 1028 plt.grid(True) 1029 1030 # --- Last 10% of epochs (bottom plot) --- 1031 plt.subplot(2, 1, 2) 1032 for node in nodes_to_plot: 1033 node_hist 
= histories[node] 1034 train_hist, val_hist = node_hist["train"], node_hist["validation"] 1035 1036 total_epochs = len(train_hist) 1037 start_idx = total_epochs - 1 if total_epochs < 5 else int(total_epochs * 0.9) 1038 1039 epochs = range(start_idx + 1, total_epochs + 1) 1040 plt.plot(epochs, train_hist[start_idx:], label=f"{node} - train", linestyle="--") 1041 plt.plot(epochs, val_hist[start_idx:], label=f"{node} - val") 1042 1043 plt.title("Training and Validation NLL - Last 10% of Epochs (or Last Epoch if <5)") 1044 plt.xlabel("Epoch") 1045 plt.ylabel("NLL") 1046 plt.legend() 1047 plt.grid(True) 1048 1049 plt.tight_layout() 1050 plt.show()
Plot training and validation loss evolution per node.
Parameters
variable : str or None, optional If provided, plot loss history for this node only. If None, plot histories for all nodes that have both train and validation logs.
Returns
None
Notes
Two subplots are produced:
- Full epoch history.
- Last 10% of epochs (or only the last epoch if fewer than 5 epochs).
1052 def plot_linear_shift_history(self, data_dict=None, node=None, ref_lines=None): 1053 """ 1054 Plot the evolution of linear shift terms over epochs. 1055 1056 Parameters 1057 ---------- 1058 data_dict : dict or None, optional 1059 Pre-loaded mapping ``{node_name: pandas.DataFrame}`` containing shift 1060 weights across epochs. If None, `linear_shift_history()` is called. 1061 node : str or None, optional 1062 If provided, plot only this node. Otherwise, plot all nodes 1063 present in ``data_dict``. 1064 ref_lines : dict or None, optional 1065 Optional mapping ``{node_name: list of float}``. For each specified 1066 node, horizontal reference lines are drawn at the given values. 1067 1068 Returns 1069 ------- 1070 None 1071 1072 Notes 1073 ----- 1074 The function flattens nested list-like entries in the DataFrames to scalars, 1075 converts epoch labels to numeric, and then draws one line per shift term. 1076 """ 1077 1078 if data_dict is None: 1079 data_dict = self.linear_shift_history() 1080 if data_dict is None: 1081 raise ValueError("No shift history data provided or stored in the class.") 1082 1083 nodes = [node] if node else list(data_dict.keys()) 1084 1085 for n in nodes: 1086 df = data_dict[n].copy() 1087 1088 # Flatten nested lists or list-like cells 1089 def flatten(x): 1090 if isinstance(x, list): 1091 if len(x) == 0: 1092 return np.nan 1093 if all(isinstance(i, (int, float)) for i in x): 1094 return np.mean(x) # average simple list 1095 if all(isinstance(i, list) for i in x): 1096 # nested list -> flatten inner and average 1097 flat = [v for sub in x for v in (sub if isinstance(sub, list) else [sub])] 1098 return np.mean(flat) if flat else np.nan 1099 return x[0] if len(x) == 1 else np.nan 1100 return x 1101 1102 df = df.applymap(flatten) 1103 1104 # Ensure numeric columns 1105 df = df.apply(pd.to_numeric, errors='coerce') 1106 1107 # Convert epoch labels to numeric 1108 df.columns = [ 1109 int(c.replace("epoch_", "")) if isinstance(c, str) and 
c.startswith("epoch_") else c 1110 for c in df.columns 1111 ] 1112 df = df.reindex(sorted(df.columns), axis=1) 1113 1114 plt.figure(figsize=(10, 6)) 1115 for idx in df.index: 1116 plt.plot(df.columns, df.loc[idx], lw=1.4, label=f"shift_{idx}") 1117 1118 if ref_lines and n in ref_lines: 1119 for v in ref_lines[n]: 1120 plt.axhline(y=v, color="k", linestyle="--", lw=1.0) 1121 plt.text(df.columns[-1], v, f"{n}: {v}", va="bottom", ha="right", fontsize=8) 1122 1123 plt.xlabel("Epoch") 1124 plt.ylabel("Shift Value") 1125 plt.title(f"Shift Term History — Node: {n}") 1126 plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left", fontsize="small") 1127 plt.tight_layout() 1128 plt.show()
Plot the evolution of linear shift terms over epochs.
Parameters
data_dict : dict or None, optional
Pre-loaded mapping {node_name: pandas.DataFrame} containing shift
weights across epochs. If None, linear_shift_history() is called.
node : str or None, optional
If provided, plot only this node. Otherwise, plot all nodes
present in data_dict.
ref_lines : dict or None, optional
Optional mapping {node_name: list of float}. For each specified
node, horizontal reference lines are drawn at the given values.
Returns
None
Notes
The function flattens nested list-like entries in the DataFrames to scalars, converts epoch labels to numeric, and then draws one line per shift term.
def plot_simple_intercepts_history(self, data_dict=None, node=None,ref_lines=None):
    """
    Plot the evolution of simple intercept weights over epochs.

    Parameters
    ----------
    data_dict : dict or None, optional
        Pre-loaded mapping ``{node_name: pandas.DataFrame}`` of intercept
        weights across epochs. If None, `simple_intercept_history()` is
        called to obtain it.
    node : str or None, optional
        Restrict plotting to a single node; otherwise every node present
        in ``data_dict`` is drawn.
    ref_lines : dict or None, optional
        Mapping ``{node_name: list of float}``; for each listed node,
        dashed horizontal reference lines are added at the given values.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If no history data is supplied and none is stored on the instance.

    Notes
    -----
    Nested list-like cells are unwrapped to scalars before plotting; one
    line is drawn per intercept parameter (theta).
    """
    if data_dict is None:
        data_dict = self.simple_intercept_history()
    if data_dict is None:
        raise ValueError("No intercept history data provided or stored in the class.")

    selected = [node] if node else list(data_dict.keys())

    def _to_scalar(value):
        # Unwrap arbitrarily nested singleton lists; non-numeric -> NaN.
        if isinstance(value, list):
            while isinstance(value, list) and len(value) > 0:
                value = value[0]
        return float(value) if isinstance(value, (int, float, np.floating)) else np.nan

    for name in selected:
        history = data_dict[name].copy().applymap(_to_scalar)

        # Turn "epoch_<k>" column labels into integers for a numeric x-axis.
        history.columns = [
            int(col.replace("epoch_", "")) if isinstance(col, str) and col.startswith("epoch_") else col
            for col in history.columns
        ]
        history = history.reindex(sorted(history.columns), axis=1)

        plt.figure(figsize=(10, 6))
        for row in history.index:
            plt.plot(history.columns, history.loc[row], lw=1.4, label=f"theta_{row}")

        if ref_lines and name in ref_lines:
            for level in ref_lines[name]:
                plt.axhline(y=level, color="k", linestyle="--", lw=1.0)
                plt.text(history.columns[-1], level,
                         f"{name}: {level}", va="bottom", ha="right", fontsize=8)

        plt.xlabel("Epoch")
        plt.ylabel("Intercept Weight")
        plt.title(f"Simple Intercept Evolution — Node: {name}")
        plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left", fontsize="small")
        plt.tight_layout()
        plt.show()
Plot the evolution of simple intercept weights over epochs.
Parameters
data_dict : dict or None, optional
Pre-loaded mapping {node_name: pandas.DataFrame} containing intercept
weights across epochs. If None, simple_intercept_history() is called.
node : str or None, optional
If provided, plot only this node. Otherwise, plot all nodes present
in data_dict.
ref_lines : dict or None, optional
Optional mapping {node_name: list of float}. For each specified
node, horizontal reference lines are drawn at the given values.
Returns
None
Notes
Nested list-like entries in the DataFrames are reduced to scalars before plotting. One line is drawn per intercept parameter.
def plot_latents(self, df, variable: str = None, confidence: float = 0.95, simulations: int = 1000):
    """
    Visualize latent U distributions for one or all nodes.

    Parameters
    ----------
    df : pandas.DataFrame
        Input data frame with raw node values.
    variable : str or None, optional
        Restrict plotting to a single node; if None, every node with a
        latent column is processed.
    confidence : float, optional
        Confidence level for the QQ-plot bands (0 < confidence < 1).
        Default is 0.95.
    simulations : int, optional
        Number of Monte Carlo simulations for the QQ-plot bands.
        Default is 1000.

    Returns
    -------
    None

    Notes
    -----
    Produces two panels per node: a histogram of the latent U values and a
    QQ-plot with simulation-based confidence bands against a logistic
    reference distribution.
    """
    # Latent representations for every node present in df.
    latents_df = self.get_latent(df)

    selected = [variable] if variable is not None else self.nodes_dict.keys()

    for node in selected:
        column = f"{node}_U"
        if column not in latents_df.columns:
            print(f"[WARNING] No latent found for node {node}, skipping.")
            continue

        values = latents_df[column].values

        fig, axs = plt.subplots(1, 2, figsize=(12, 5))

        # Left panel: raw histogram of latent draws.
        axs[0].hist(values, bins=50, color="steelblue", alpha=0.7)
        axs[0].set_title(f"Latent Histogram ({node})")
        axs[0].set_xlabel("U")
        axs[0].set_ylabel("Frequency")

        # Right panel: QQ-plot against the logistic distribution.
        probplot(values, dist="logistic", plot=axs[1])
        self._add_r_style_confidence_bands(axs[1], values, dist=logistic,
                                           confidence=confidence, simulations=simulations)
        axs[1].set_title(f"Latent QQ Plot ({node})")

        plt.suptitle(f"Latent Diagnostics for Node: {node}", fontsize=14)
        plt.tight_layout(rect=[0, 0.03, 1, 0.95])
        plt.show()
Visualize latent U distributions for one or all nodes.
Parameters
df : pandas.DataFrame Input data frame with raw node values. variable : str or None, optional If provided, only this node's latents are plotted. If None, all nodes with latent columns are processed. confidence : float, optional Confidence level for QQ-plot bands (0 < confidence < 1). Default is 0.95. simulations : int, optional Number of Monte Carlo simulations for QQ-plot bands. Default is 1000.
Returns
None
Notes
For each node, two plots are produced:
- Histogram of the latent U values.
- QQ-plot with simulation-based confidence bands under a logistic reference.
def plot_hdag(self, df, variables=None, plot_n_rows=1, **kwargs):
    """
    Visualize the transformation function h() for selected DAG nodes.

    Parameters
    ----------
    df : pandas.DataFrame
        Input data containing node values or model predictions.
    variables : list of str or None, optional
        Names of nodes to visualize. If None, all nodes in ``self.models``
        are considered.
    plot_n_rows : int, optional
        Maximum number of rows from ``df`` to visualize. Default is 1.
    **kwargs
        Additional keyword arguments forwarded to the underlying plotting
        helpers (`show_hdag_continous` / `show_hdag_ordinal`).

    Returns
    -------
    None

    Notes
    -----
    - For continuous outcomes, `show_hdag_continous` is called.
    - For ordinal outcomes, `show_hdag_ordinal` is called.
    - Nodes that are neither continuous nor ordinal are skipped with a warning.
    """
    if len(df) > 1:
        print("[WARNING] len(df)>1, set: plot_n_rows accordingly")

    variables_list = variables if variables is not None else list(self.models.keys())
    for node in variables_list:
        if is_outcome_modelled_continous(node, self.nodes_dict):
            show_hdag_continous(
                df,
                node=node,
                configuration_dict=self.cfg.conf_dict,
                minmax_dict=self.minmax_dict,
                device=self.device,
                plot_n_rows=plot_n_rows,
                **kwargs,
            )
        elif is_outcome_modelled_ordinal(node, self.nodes_dict):
            show_hdag_ordinal(
                df,
                node=node,
                configuration_dict=self.cfg.conf_dict,
                device=self.device,
                plot_n_rows=plot_n_rows,
                **kwargs,
            )
        else:
            # BUG FIX: warning text was garbled ("is wheter ordinal nor continous").
            print(f"[WARNING] Node {node} is neither ordinal nor continuous, not implemented yet")
Visualize the transformation function h() for selected DAG nodes.
Parameters
df : pandas.DataFrame
Input data containing node values or model predictions.
variables : list of str or None, optional
Names of nodes to visualize. If None, all nodes in self.models
are considered.
plot_n_rows : int, optional
Maximum number of rows from df to visualize. Default is 1.
**kwargs
Additional keyword arguments forwarded to the underlying plotting
helpers (show_hdag_continous / show_hdag_ordinal).
Returns
None
Notes
- For continuous outcomes, show_hdag_continous is called.
- For ordinal outcomes, show_hdag_ordinal is called.
- Nodes that are neither continuous nor ordinal are skipped with a warning.
def sample(
    self,
    do_interventions: dict = None,
    predefined_latent_samples_df: pd.DataFrame = None,
    **kwargs,
):
    """
    Sample from the joint DAG using the trained TRAM models.

    Allows for:

    - Observational sampling
    - Interventional sampling via ``do()`` operations
    - Counterfactual sampling using predefined latent draws and ``do()``

    Parameters
    ----------
    do_interventions : dict or None, optional
        Mapping of node names to intervened (fixed) values. For example:
        ``{"x1": 1.0}`` represents ``do(x1 = 1.0)``. Default is None.
    predefined_latent_samples_df : pandas.DataFrame or None, optional
        DataFrame containing columns ``"{node}_U"`` with predefined latent
        draws to be used instead of sampling from the prior. Default is None.
    **kwargs
        Sampling options overriding internal defaults:

        number_of_samples : int, default 10000
            Total number of samples to draw.
        batch_size : int, default 32
            Batch size for internal sampling loops.
        delete_all_previously_sampled : bool, default True
            If True, delete old sampling files in node-specific sampling
            directories before writing new ones.
        verbose : bool
            If True, print informational messages.
        debug : bool
            If True, print debug output.
        device : {"auto", "cpu", "cuda"}
            Device selection for sampling.
        use_initial_weights_for_sampling : bool, default False
            If True, sample from initial (untrained) model parameters.

    Returns
    -------
    tuple
        ``(sampled_by_node, latents_by_node)`` where both are dicts mapping
        node names to ``torch.Tensor`` objects: sampled node values and the
        latent U values used, respectively.

    Raises
    ------
    ValueError
        If the experiment directory cannot be resolved from the configuration.
    RuntimeError
        If min–max scaling (``self.minmax_dict``) has not been computed
        before calling `sample`.
    """
    try:
        EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
    except KeyError:
        raise ValueError(
            "[ERROR] Missing 'EXPERIMENT_DIR' in cfg.conf_dict['PATHS']. "
            "Sampling requires trained model checkpoints."
        )

    # ---- defaults (any of these may be overridden via kwargs) ----
    settings = {
        "number_of_samples": 10_000,
        "batch_size": 32,
        "delete_all_previously_sampled": True,
        "verbose": getattr(self, "verbose", False),
        "debug": getattr(self, "debug", False),
        "device": self.device.type if hasattr(self, "device") else "auto",
        "use_initial_weights_for_sampling": False,
    }
    settings.update(kwargs)

    if not hasattr(self, "minmax_dict"):
        raise RuntimeError(
            "[ERROR] minmax_dict not found. You must call .fit() or .load_or_compute_minmax() "
            "before sampling, so scaling info is available."
        )

    # ---- resolve device ----
    device_str = self.get_device(settings)
    self.device = torch.device(device_str)

    # BUG FIX: the original accessed `self.debug` directly here, which raised
    # AttributeError when the attribute was never set (everywhere else in this
    # method the attribute is guarded). Use getattr with a default instead.
    if getattr(self, "debug", False) or settings["debug"]:
        print(f"[DEBUG] sample(): device: {self.device}")

    # ---- perform sampling ----
    sampled_by_node, latents_by_node = sample_full_dag(
        configuration_dict=self.cfg.conf_dict,
        EXPERIMENT_DIR=EXPERIMENT_DIR,
        device=self.device,
        do_interventions=do_interventions or {},
        predefined_latent_samples_df=predefined_latent_samples_df,
        number_of_samples=settings["number_of_samples"],
        batch_size=settings["batch_size"],
        delete_all_previously_sampled=settings["delete_all_previously_sampled"],
        verbose=settings["verbose"],
        debug=settings["debug"],
        minmax_dict=self.minmax_dict,
        use_initial_weights_for_sampling=settings["use_initial_weights_for_sampling"],
    )

    return sampled_by_node, latents_by_node
Sample from the joint DAG using the trained TRAM models.
Allows for:
Observational sampling
Interventional sampling via do() operations
Counterfactual sampling using predefined latent draws and do()
Parameters
do_interventions : dict or None, optional
Mapping of node names to intervened (fixed) values. For example:
{"x1": 1.0} represents do(x1 = 1.0). Default is None.
predefined_latent_samples_df : pandas.DataFrame or None, optional
DataFrame containing columns "{node}_U" with predefined latent
draws to be used instead of sampling from the prior. Default is None.
**kwargs
Sampling options overriding internal defaults:
number_of_samples : int, default 10000
Total number of samples to draw.
batch_size : int, default 32
Batch size for internal sampling loops.
delete_all_previously_sampled : bool, default True
If True, delete old sampling files in node-specific sampling
directories before writing new ones.
verbose : bool
If True, print informational messages.
debug : bool
If True, print debug output.
device : {"auto", "cpu", "cuda"}
Device selection for sampling.
use_initial_weights_for_sampling : bool, default False
If True, sample from initial (untrained) model parameters.
Returns
tuple
A tuple (sampled_by_node, latents_by_node):
sampled_by_node : dict
Mapping ``{node_name: torch.Tensor}`` of sampled node values.
latents_by_node : dict
Mapping ``{node_name: torch.Tensor}`` of latent U values used.
Raises
ValueError
If the experiment directory cannot be resolved or if scaling
information (self.minmax_dict) is missing.
RuntimeError
If min–max scaling has not been computed before calling sample.
def load_sampled_and_latents(self, EXPERIMENT_DIR: str = None, nodes: list = None):
    """
    Load previously stored sampled values and latents for each node.

    Parameters
    ----------
    EXPERIMENT_DIR : str or None, optional
        Experiment directory path. If None, it is resolved from
        ``self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]``.
    nodes : list of str or None, optional
        Nodes to load samples for. If None, all nodes from
        ``self.nodes_dict`` are used.

    Returns
    -------
    tuple
        ``(sampled_by_node, latents_by_node)``: two dicts mapping node
        names to CPU ``torch.Tensor`` objects (sampled values and latent
        values, respectively).

    Raises
    ------
    ValueError
        If the experiment directory cannot be resolved, or if no node list
        is available and ``nodes`` is None.

    Notes
    -----
    Nodes lacking either ``sampled.pt`` or ``latents.pt`` are skipped with
    a warning; load failures are reported and skipped as well.
    """
    # --- resolve the experiment directory and node list ---
    if EXPERIMENT_DIR is None:
        try:
            EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
        except (AttributeError, KeyError):
            raise ValueError(
                "[ERROR] Could not resolve EXPERIMENT_DIR from cfg.conf_dict['PATHS']. "
                "Please provide EXPERIMENT_DIR explicitly."
            )

    if nodes is None:
        if not hasattr(self, "nodes_dict"):
            raise ValueError(
                "[ERROR] No node list found. Please provide `nodes` or initialize model with a config."
            )
        nodes = list(self.nodes_dict.keys())

    # --- load tensors per node (always onto CPU) ---
    sampled_by_node, latents_by_node = {}, {}

    for name in nodes:
        sampling_dir = os.path.join(EXPERIMENT_DIR, f"{name}", "sampling")
        sampled_file = os.path.join(sampling_dir, "sampled.pt")
        latents_file = os.path.join(sampling_dir, "latents.pt")

        if not (os.path.exists(sampled_file) and os.path.exists(latents_file)):
            print(f"[WARNING] Missing files for node '{name}' — skipping.")
            continue

        try:
            sampled_tensor = torch.load(sampled_file, map_location="cpu")
            latent_tensor = torch.load(latents_file, map_location="cpu")
        except Exception as e:
            print(f"[ERROR] Could not load sampling files for node '{name}': {e}")
            continue

        sampled_by_node[name] = sampled_tensor.detach().cpu()
        latents_by_node[name] = latent_tensor.detach().cpu()

    if self.verbose or self.debug:
        print(f"[INFO] Loaded sampled and latent tensors for {len(sampled_by_node)} nodes from {EXPERIMENT_DIR}")

    return sampled_by_node, latents_by_node
Load previously stored sampled values and latents for each node.
Parameters
EXPERIMENT_DIR : str or None, optional
Experiment directory path. If None, it is taken from
self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"].
nodes : list of str or None, optional
Nodes for which to load samples. If None, use all nodes from
self.nodes_dict.
Returns
tuple
A tuple (sampled_by_node, latents_by_node):
sampled_by_node : dict
Mapping ``{node_name: torch.Tensor}`` of sampled values (on CPU).
latents_by_node : dict
Mapping ``{node_name: torch.Tensor}`` of latent values (on CPU).
Raises
ValueError
If the experiment directory cannot be resolved or if no node list
is available and nodes is None.
Notes
Nodes without both sampled.pt and latents.pt files are skipped
with a warning.
def plot_samples_vs_true(
    self,
    df,
    sampled: dict = None,
    variable: list = None,
    bins: int = 100,
    hist_true_color: str = "blue",
    hist_est_color: str = "orange",
    figsize: tuple = (14, 5),
):
    """
    Compare sampled vs. observed distributions for selected nodes.

    Parameters
    ----------
    df : pandas.DataFrame
        Data frame containing the observed node values.
    sampled : dict or None, optional
        Optional mapping ``{node_name: array-like or torch.Tensor}`` of sampled
        values. If None or if a node is missing, samples are loaded from
        ``EXPERIMENT_DIR/{node}/sampling/sampled.pt``.
    variable : list of str or None, optional
        Subset of nodes to plot. If None, all nodes in the configuration
        are considered.
    bins : int, optional
        Number of histogram bins for continuous variables. Default is 100.
    hist_true_color : str, optional
        Color name for the histogram of true values. Default is "blue".
    hist_est_color : str, optional
        Color name for the histogram of sampled values. Default is "orange".
    figsize : tuple, optional
        Figure size for the matplotlib plots. Default is (14, 5).

    Returns
    -------
    None

    Notes
    -----
    - Continuous outcomes: histogram overlay + QQ-plot.
    - Ordinal outcomes: side-by-side bar plot of relative frequencies.
    - Other categorical outcomes: side-by-side bar plot with category labels.
    - If samples are probabilistic (2D tensor), the argmax across classes is used.
    """

    target_nodes = self.cfg.conf_dict["nodes"]
    experiment_dir = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    plot_list = variable if variable is not None else target_nodes

    for node in plot_list:
        # Load sampled data: prefer the in-memory dict, fall back to files on disk.
        if sampled is not None and node in sampled:
            sdata = sampled[node]
            if isinstance(sdata, torch.Tensor):
                sampled_vals = sdata.detach().cpu().numpy()
            else:
                sampled_vals = np.asarray(sdata)
        else:
            sample_path = os.path.join(experiment_dir, f"{node}/sampling/sampled.pt")
            if not os.path.isfile(sample_path):
                print(f"[WARNING] skip {node}: {sample_path} not found.")
                continue

            try:
                sampled_vals = torch.load(sample_path, map_location=device).cpu().numpy()
            except Exception as e:
                print(f"[ERROR] Could not load {sample_path}: {e}")
                continue

        # A 2D array is treated as per-class probabilities/logits per sample:
        # collapse to the most probable class via argmax.
        if sampled_vals.ndim == 2:
            print(f"[INFO] CAUTION! {node}: samples are probabilistic — each sample follows a probability "
                  f"distribution based on the valid latent range. "
                  f"Note that this frequency plot reflects only the distribution of the most probable "
                  f"class per sample.")
            sampled_vals = np.argmax(sampled_vals, axis=1)

        # Drop NaN/Inf before comparison so histograms and QQ-plots stay valid.
        sampled_vals = sampled_vals[np.isfinite(sampled_vals)]

        if node not in df.columns:
            print(f"[WARNING] skip {node}: column not found in DataFrame.")
            continue

        true_vals = df[node].dropna().values
        true_vals = true_vals[np.isfinite(true_vals)]

        if sampled_vals.size == 0 or true_vals.size == 0:
            print(f"[WARNING] skip {node}: empty array after NaN/Inf removal.")
            continue

        fig, axs = plt.subplots(1, 2, figsize=figsize)

        # NOTE(review): here the outcome-type helpers receive `target_nodes`
        # (from conf_dict["nodes"]), while `nll()` passes `self.nodes_dict` —
        # presumably equivalent structures; confirm against the config schema.
        if is_outcome_modelled_continous(node, target_nodes):
            # Continuous: density-normalized histogram overlay + two-sample QQ-plot.
            axs[0].hist(true_vals, bins=bins, density=True, alpha=0.6,
                        color=hist_true_color, label=f"True {node}")
            axs[0].hist(sampled_vals, bins=bins, density=True, alpha=0.6,
                        color=hist_est_color, label="Sampled")
            axs[0].set_xlabel("Value")
            axs[0].set_ylabel("Density")
            axs[0].set_title(f"Histogram overlay for {node}")
            axs[0].legend()
            axs[0].grid(True, ls="--", alpha=0.4)

            qqplot_2samples(true_vals, sampled_vals, line="45", ax=axs[1])
            axs[1].set_xlabel("True quantiles")
            axs[1].set_ylabel("Sampled quantiles")
            axs[1].set_title(f"QQ plot for {node}")
            axs[1].grid(True, ls="--", alpha=0.4)

        elif is_outcome_modelled_ordinal(node, target_nodes):
            # Ordinal: relative-frequency bars side by side over the union of levels.
            unique_vals = np.union1d(np.unique(true_vals), np.unique(sampled_vals))
            unique_vals = np.sort(unique_vals)
            true_counts = np.array([(true_vals == val).sum() for val in unique_vals])
            sampled_counts = np.array([(sampled_vals == val).sum() for val in unique_vals])

            axs[0].bar(unique_vals - 0.2, true_counts / true_counts.sum(),
                       width=0.4, color=hist_true_color, alpha=0.7, label="True")
            axs[0].bar(unique_vals + 0.2, sampled_counts / sampled_counts.sum(),
                       width=0.4, color=hist_est_color, alpha=0.7, label="Sampled")
            axs[0].set_xticks(unique_vals)
            axs[0].set_xlabel("Ordinal Level")
            axs[0].set_ylabel("Relative Frequency")
            axs[0].set_title(f"Ordinal bar plot for {node}")
            axs[0].legend()
            axs[0].grid(True, ls="--", alpha=0.4)
            axs[1].axis("off")

        else:
            # Other categoricals: same bar layout, but categories sorted by string
            # representation and drawn on integer positions with rotated labels.
            unique_vals = np.union1d(np.unique(true_vals), np.unique(sampled_vals))
            unique_vals = sorted(unique_vals, key=str)
            true_counts = np.array([(true_vals == val).sum() for val in unique_vals])
            sampled_counts = np.array([(sampled_vals == val).sum() for val in unique_vals])

            axs[0].bar(np.arange(len(unique_vals)) - 0.2, true_counts / true_counts.sum(),
                       width=0.4, color=hist_true_color, alpha=0.7, label="True")
            axs[0].bar(np.arange(len(unique_vals)) + 0.2, sampled_counts / sampled_counts.sum(),
                       width=0.4, color=hist_est_color, alpha=0.7, label="Sampled")
            axs[0].set_xticks(np.arange(len(unique_vals)))
            axs[0].set_xticklabels(unique_vals, rotation=45)
            axs[0].set_xlabel("Category")
            axs[0].set_ylabel("Relative Frequency")
            axs[0].set_title(f"Categorical bar plot for {node}")
            axs[0].legend()
            axs[0].grid(True, ls="--", alpha=0.4)
            axs[1].axis("off")

        plt.tight_layout()
        plt.show()
Compare sampled vs. observed distributions for selected nodes.
Parameters
df : pandas.DataFrame
Data frame containing the observed node values.
sampled : dict or None, optional
Optional mapping {node_name: array-like or torch.Tensor} of sampled
values. If None or if a node is missing, samples are loaded from
EXPERIMENT_DIR/{node}/sampling/sampled.pt.
variable : list of str or None, optional
Subset of nodes to plot. If None, all nodes in the configuration
are considered.
bins : int, optional
Number of histogram bins for continuous variables. Default is 100.
hist_true_color : str, optional
Color name for the histogram of true values. Default is "blue".
hist_est_color : str, optional
Color name for the histogram of sampled values. Default is "orange".
figsize : tuple, optional
Figure size for the matplotlib plots. Default is (14, 5).
Returns
None
Notes
- Continuous outcomes: histogram overlay + QQ-plot.
- Ordinal outcomes: side-by-side bar plot of relative frequencies.
- Other categorical outcomes: side-by-side bar plot with category labels.
- If samples are probabilistic (2D tensor), the argmax across classes is used.
def nll(self, data, variables=None):
    """
    Compute the Negative Log-Likelihood (NLL) for all or selected TRAM nodes.

    Evaluates the trained TRAM model of each requested node on the provided
    dataset. Only forward passes are performed — no training, no weight
    updates — and the mean NLL per node is returned.

    Parameters
    ----------
    data : object
        Input dataset or data source compatible with `_ensure_dataset`,
        containing both inputs and targets for each node.
    variables : list[str], optional
        Variable (node) names to evaluate. If None, all nodes in
        `self.models` are evaluated.

    Returns
    -------
    dict[str, float]
        Mapping from node name to its average NLL value.

    Notes
    -----
    - Each model is evaluated independently on its respective DataLoader.
    - Normalization bounds for each node come from `self.minmax_dict[node]`.
    - Per-node evaluation is delegated to `evaluate_tramdag_model()`.
    - Expected directory structure: ``<EXPERIMENT_DIR>/<node>/``, where each
      node directory contains the trained model.
    """
    td_data = self._ensure_dataset(data, is_val=True)
    # FIX: identity comparison with None (was `variables != None`).
    variables_list = variables if variables is not None else list(self.models.keys())
    # Hoisted out of the loop: the experiment directory is loop-invariant.
    EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]

    nll_dict = {}
    for node in variables_list:
        # Stack per-node min/max into a (2, ...) tensor expected downstream.
        min_vals = torch.tensor(self.minmax_dict[node][0], dtype=torch.float32)
        max_vals = torch.tensor(self.minmax_dict[node][1], dtype=torch.float32)
        min_max = torch.stack([min_vals, max_vals], dim=0)

        NODE_DIR = os.path.join(EXPERIMENT_DIR, f"{node}")
        nll_dict[node] = evaluate_tramdag_model(
            node=node,
            target_nodes=self.nodes_dict,
            NODE_DIR=NODE_DIR,
            tram_model=self.models[node],
            data_loader=td_data.loaders[node],
            min_max=min_max,
        )
    return nll_dict
Compute the Negative Log-Likelihood (NLL) for all or selected TRAM nodes.
This function evaluates trained TRAM models for each specified variable (node) on the provided dataset. It performs forward passes only—no training, no weight updates—and returns the mean NLL per node.
Parameters
data : object
Input dataset or data source compatible with _ensure_dataset, containing
both inputs and targets for each node.
variables : list[str], optional
List of variable (node) names to evaluate. If None, all nodes in
self.models are evaluated.
Returns
dict[str, float] Dictionary mapping each node name to its average NLL value.
Notes
- Each model is evaluated independently on its respective DataLoader.
- The normalization values (min_max) for each node are retrieved from self.minmax_dict[node].
- The function uses evaluate_tramdag_model() for per-node evaluation.
- Expected directory structure: <EXPERIMENT_DIR>/<node>/ where each node directory contains the trained model.
def get_train_val_nll(self, node: str, mode: str) -> tuple[float, float]:
    """
    Retrieve training and validation NLL for a node and a given model state.

    Parameters
    ----------
    node : str
        Node name.
    mode : {"best", "last", "init"}
        State of interest:
        - "best": epoch with lowest validation NLL.
        - "last": final epoch.
        - "init": first epoch (index 0).

    Returns
    -------
    tuple of (float or None, float or None)
        ``(train_nll, val_nll)`` for the requested mode, or ``(None, None)``
        if loss files are missing or cannot be read.

    Raises
    ------
    ValueError
        If `mode` is not one of "best", "last", "init".

    Notes
    -----
    Expects per-node JSON files ``train_loss_hist.json`` and
    ``val_loss_hist.json`` in the node directory.
    """
    # BUG FIX: validate `mode` *before* the try block. Previously the
    # ValueError was raised inside the broad `except Exception` scope and
    # silently converted into (None, None) instead of propagating.
    if mode not in ("best", "last", "init"):
        raise ValueError(f"Invalid mode '{mode}' — must be one of 'best', 'last', 'init'.")

    NODE_DIR = os.path.join(self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"], node)
    train_path = os.path.join(NODE_DIR, "train_loss_hist.json")
    val_path = os.path.join(NODE_DIR, "val_loss_hist.json")

    if not os.path.exists(train_path) or not os.path.exists(val_path):
        if getattr(self, "debug", False):
            print(f"[DEBUG] Missing loss files for node '{node}'. Returning None.")
        return None, None

    try:
        with open(train_path, "r") as f:
            train_hist = json.load(f)
        with open(val_path, "r") as f:
            val_hist = json.load(f)

        train_nlls = np.array(train_hist)
        val_nlls = np.array(val_hist)

        if mode == "init":
            idx = 0
        elif mode == "last":
            idx = len(val_nlls) - 1
        else:  # "best": epoch minimizing the validation NLL
            idx = int(np.argmin(val_nlls))

        return float(train_nlls[idx]), float(val_nlls[idx])

    except Exception as e:
        # Best-effort accessor: report and signal absence rather than crash.
        print(f"[ERROR] Failed to load NLLs for node '{node}' ({mode}): {e}")
        return None, None
Retrieve training and validation NLL for a node and a given model state.
Parameters
node : str Node name. mode : {"best", "last", "init"} State of interest: - "best": epoch with lowest validation NLL. - "last": final epoch. - "init": first epoch (index 0).
Returns
tuple of (float or None, float or None)
A tuple (train_nll, val_nll) for the requested mode.
Returns (None, None) if loss files are missing or cannot be read.
Notes
This method expects per-node JSON files:
- train_loss_hist.json
- val_loss_hist.json
in the node directory.
def get_thetas(self, node: str, state: str = "best"):
    """
    Return transformed intercept (theta) parameters for a node and state.

    Parameters
    ----------
    node : str
        Node name.
    state : {"best", "last", "init"}, optional
        Model state for which to return parameters. Default is "best".

    Returns
    -------
    Any or None
        Transformed theta parameters for the requested node and state; the
        exact structure depends on the model. None if unavailable.

    Raises
    ------
    ValueError
        If `state` is not one of {"best", "last", "init"}.

    Notes
    -----
    Results are cached on the instance under ``intercept_dicts``; missing or
    incomplete caches are refreshed via `get_simple_intercepts_dict`.
    """
    state = state.lower()
    if state not in ["best", "last", "init"]:
        raise ValueError(f"[ERROR] Invalid state '{state}'. Must be one of ['best', 'last', 'init'].")

    dict_attr = "intercept_dicts"
    debug_on = getattr(self, "debug", False)

    # Populate the cache on first use.
    if not hasattr(self, dict_attr):
        if debug_on:
            print(f"[DEBUG] '{dict_attr}' not found, computing via get_simple_intercepts_dict().")
        setattr(self, dict_attr, self.get_simple_intercepts_dict())

    cached = getattr(self, dict_attr)

    # Refresh when the requested state is absent (covers fresh runs).
    if state not in cached:
        if debug_on:
            print(f"[DEBUG] State '{state}' not found in cached intercepts, recomputing full dict.")
        setattr(self, dict_attr, self.get_simple_intercepts_dict())
        cached = getattr(self, dict_attr)

    per_node = cached.get(state, {})
    if node in per_node:
        return per_node[node]

    # Last resort: recompute once more, then give up with None.
    if debug_on:
        print(f"[DEBUG] Node '{node}' not found in state '{state}', recomputing full dict.")
    setattr(self, dict_attr, self.get_simple_intercepts_dict())
    return getattr(self, dict_attr).get(state, {}).get(node, None)
Return transformed intercept (theta) parameters for a node and state.
Parameters
node : str Node name. state : {"best", "last", "init"}, optional Model state for which to return parameters. Default is "best".
Returns
Any or None Transformed theta parameters for the requested node and state. The exact structure (scalar, list, or other) depends on the model.
Raises
ValueError If an invalid state is given (not in {"best", "last", "init"}).
Notes
Intercept dictionaries are cached on the instance under the attribute
intercept_dicts. If missing or incomplete, they are recomputed using
get_simple_intercepts_dict.
def get_linear_shifts(self, node: str, state: str = "best"):
    """
    Return learned linear shift terms for a node and a given state.

    Parameters
    ----------
    node : str
        Node name.
    state : {"best", "last", "init"}, optional
        Model state for which to return linear shift terms. Default is "best".

    Returns
    -------
    dict or Any or None
        Linear shift terms for the given node and state, usually a dict
        mapping term names to weights. None if unavailable.

    Raises
    ------
    ValueError
        If `state` is not one of {"best", "last", "init"}.

    Notes
    -----
    Results are cached on the instance under ``linear_shift_dicts``; missing
    or incomplete caches are refreshed via `get_linear_shifts_dict`.
    """
    state = state.lower()
    if state not in ["best", "last", "init"]:
        raise ValueError(f"[ERROR] Invalid state '{state}'. Must be one of ['best', 'last', 'init'].")

    dict_attr = "linear_shift_dicts"

    def _refresh():
        # Recompute the full nested dict and store it back on the instance.
        setattr(self, dict_attr, self.get_linear_shifts_dict())
        return getattr(self, dict_attr)

    if hasattr(self, dict_attr):
        all_dicts = getattr(self, dict_attr)
    else:
        if getattr(self, "debug", False):
            print(f"[DEBUG] '{dict_attr}' not found, computing via get_linear_shifts_dict().")
        all_dicts = _refresh()

    if state not in all_dicts:
        if getattr(self, "debug", False):
            print(f"[DEBUG] State '{state}' not found in cached linear shifts, recomputing full dict.")
        all_dicts = _refresh()

    entry = all_dicts.get(state, {})
    if node in entry:
        return entry[node]

    # Fallback: one more recompute before reporting absence.
    if getattr(self, "debug", False):
        print(f"[DEBUG] Node '{node}' not found in state '{state}', recomputing full dict.")
    return _refresh().get(state, {}).get(node, None)
Return learned linear shift terms for a node and a given state.
Parameters
node : str Node name. state : {"best", "last", "init"}, optional Model state for which to return linear shift terms. Default is "best".
Returns
dict or Any or None Linear shift terms for the given node and state. Usually a dict mapping term names to weights.
Raises
ValueError If an invalid state is given (not in {"best", "last", "init"}).
Notes
Linear shift dictionaries are cached on the instance under the attribute
linear_shift_dicts. If missing or incomplete, they are recomputed using
get_linear_shifts_dict.
    def get_linear_shifts_dict(self):
        """
        Compute linear shift term dictionaries for all nodes and states.

        For each node and each available state ("best", "last", "init"), this
        method loads the corresponding model checkpoint, extracts linear shift
        weights from the TRAM model, and stores them in a nested dictionary.

        Returns
        -------
        dict
            Nested dictionary of the form::

                {
                    "best": {node: {...}},
                    "last": {node: {...}},
                    "init": {node: {...}},
                }

            where the innermost dict maps term labels (e.g. ``"ls(parent_name)"``)
            to their weights.

        Notes
        -----
        - If "best" or "last" checkpoints are unavailable for a node, only
          the "init" entry is populated.
        - Empty outer states (without any nodes) are removed from the result.
        - Side effect: `load_state_dict` mutates the shared model objects in
          ``self.models``; each model is left holding whatever state was
          loaded last for it.
        """

        EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
        nodes_list = list(self.models.keys())
        all_states = ["best", "last", "init"]
        # Pre-seed one inner dict per state; empty states are stripped at the end.
        all_linear_shift_dicts = {state: {} for state in all_states}

        for node in nodes_list:
            NODE_DIR = os.path.join(EXPERIMENT_DIR, node)
            # model_train_val_paths presumably also returns train/val history
            # paths — only the two checkpoint paths are used here.
            BEST_MODEL_PATH, LAST_MODEL_PATH, _, _ = model_train_val_paths(NODE_DIR)
            INIT_MODEL_PATH = os.path.join(NODE_DIR, "initial_model.pt")

            state_paths = {
                "best": BEST_MODEL_PATH,
                "last": LAST_MODEL_PATH,
                "init": INIT_MODEL_PATH,
            }

            for state, LOAD_PATH in state_paths.items():
                if not os.path.exists(LOAD_PATH):
                    if state != "init":
                        # skip best/last if unavailable
                        continue
                    else:
                        print(f"[WARNING] No models found for node '{node}'. Only initial model will be used.")
                        # NOTE(review): this re-check is redundant — we are
                        # already on the not-exists branch, so it is always
                        # True and the init state is skipped when missing.
                        if not os.path.exists(LOAD_PATH):
                            if getattr(self, "debug", False):
                                print(f"[DEBUG] Initial model also missing for node '{node}'. Skipping.")
                            continue

                # Load parents and model. terms_dict order is assumed to match
                # the order of layers in tram_model.nn_shift (index i below) —
                # TODO confirm against ordered_parents' contract.
                _, terms_dict, _ = ordered_parents(node, self.nodes_dict)
                state_dict = torch.load(LOAD_PATH, map_location=self.device)
                tram_model = self.models[node]
                tram_model.load_state_dict(state_dict)

                epoch_weights = {}
                if hasattr(tram_model, "nn_shift") and tram_model.nn_shift is not None:
                    for i, shift_layer in enumerate(tram_model.nn_shift):
                        module_name = shift_layer.__class__.__name__
                        # Only simple linear shift layers expose a single fc
                        # weight we can report; other shift modules are skipped.
                        if (
                            hasattr(shift_layer, "fc")
                            and hasattr(shift_layer.fc, "weight")
                            and module_name == "LinearShift"
                        ):
                            term_name = list(terms_dict.keys())[i]
                            epoch_weights[f"ls({term_name})"] = (
                                shift_layer.fc.weight.detach().cpu().squeeze().tolist()
                            )
                        elif getattr(self, "debug", False):
                            term_name = list(terms_dict.keys())[i]
                            print(f"[DEBUG] ls({term_name}): missing 'fc' or 'weight' in LinearShift.")
                else:
                    if getattr(self, "debug", False):
                        print(f"[DEBUG] Tram model for node '{node}' has no nn_shift or it is None.")

                all_linear_shift_dicts[state][node] = epoch_weights

        # Remove empty states (e.g., when best/last not found for all nodes)
        all_linear_shift_dicts = {k: v for k, v in all_linear_shift_dicts.items() if v}

        return all_linear_shift_dicts
Compute linear shift term dictionaries for all nodes and states.
For each node and each available state ("best", "last", "init"), this method loads the corresponding model checkpoint, extracts linear shift weights from the TRAM model, and stores them in a nested dictionary.
Returns
dict Nested dictionary of the form:
```python
{ "best": {node: {...}}, "last": {node: {...}}, "init": {node: {...}}, } ```
where the innermost dict maps term labels (e.g. ``"ls(parent_name)"``)
to their weights.
Notes
- If "best" or "last" checkpoints are unavailable for a node, only the "init" entry is populated.
- Empty outer states (without any nodes) are removed from the result.
    def get_simple_intercepts_dict(self):
        """
        Compute transformed simple intercept dictionaries for all nodes and states.

        For each node and each available state ("best", "last", "init"), this
        method loads the corresponding model checkpoint, extracts simple intercept
        weights, transforms them into interpretable theta parameters, and stores
        them in a nested dictionary.

        Returns
        -------
        dict
            Nested dictionary of the form::

                {
                    "best": {node: [[theta_1], [theta_2], ...]},
                    "last": {node: [[theta_1], [theta_2], ...]},
                    "init": {node: [[theta_1], [theta_2], ...]},
                }

            A node's entry is None when its model has no usable
            ``nn_int`` SimpleIntercept layer.

        Notes
        -----
        - For ordinal models (``self.is_ontram == True``),
          `transform_intercepts_ordinal` is used.
        - For continuous models, `transform_intercepts_continous` is used.
        - Empty outer states (without any nodes) are removed from the result.
        - Side effect: `load_state_dict` mutates the shared model objects in
          ``self.models``.
        """

        EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
        nodes_list = list(self.models.keys())
        all_states = ["best", "last", "init"]
        # Pre-seed one inner dict per state; empty states are stripped at the end.
        all_si_intercept_dicts = {state: {} for state in all_states}

        debug = getattr(self, "debug", False)
        verbose = getattr(self, "verbose", False)  # NOTE(review): read but never used below
        is_ontram = getattr(self, "is_ontram", False)

        for node in nodes_list:
            NODE_DIR = os.path.join(EXPERIMENT_DIR, node)
            # Only the best/last checkpoint paths are used from this helper.
            BEST_MODEL_PATH, LAST_MODEL_PATH, _, _ = model_train_val_paths(NODE_DIR)
            INIT_MODEL_PATH = os.path.join(NODE_DIR, "initial_model.pt")

            state_paths = {
                "best": BEST_MODEL_PATH,
                "last": LAST_MODEL_PATH,
                "init": INIT_MODEL_PATH,
            }

            for state, LOAD_PATH in state_paths.items():
                if not os.path.exists(LOAD_PATH):
                    if state != "init":
                        # skip best/last if unavailable
                        continue
                    else:
                        print(f"[WARNING] No models found for node '{node}'. Only initial model will be used.")
                        # NOTE(review): redundant re-check (always True on this
                        # branch); net effect is init is skipped when missing.
                        if not os.path.exists(LOAD_PATH):
                            if debug:
                                print(f"[DEBUG] Initial model also missing for node '{node}'. Skipping.")
                            continue

                # Load model state
                state_dict = torch.load(LOAD_PATH, map_location=self.device)
                tram_model = self.models[node]
                tram_model.load_state_dict(state_dict)

                # Extract and transform simple intercept weights
                si_weights = None
                if hasattr(tram_model, "nn_int") and tram_model.nn_int is not None and isinstance(tram_model.nn_int, SimpleIntercept):
                    if hasattr(tram_model.nn_int, "fc") and hasattr(tram_model.nn_int.fc, "weight"):
                        weights = tram_model.nn_int.fc.weight.detach().cpu().tolist()
                        weights_tensor = torch.Tensor(weights)

                        if debug:
                            print(f"[DEBUG] Node '{node}' ({state}) theta tilde shape: {weights_tensor.shape}")

                        # Raw "theta tilde" weights are mapped to interpretable
                        # thetas; the ordinal transform drops the first/last
                        # (boundary) entries via [:, 1:-1].
                        if is_ontram:
                            si_weights = transform_intercepts_ordinal(weights_tensor.reshape(1, -1))[:, 1:-1].reshape(-1, 1)
                        else:
                            si_weights = transform_intercepts_continous(weights_tensor.reshape(1, -1)).reshape(-1, 1)

                        si_weights = si_weights.tolist()

                        if debug:
                            print(f"[DEBUG] Node '{node}' ({state}) theta transformed: {si_weights}")
                    else:
                        if debug:
                            print(f"[DEBUG] Node '{node}' ({state}): missing 'fc' or 'weight' in SimpleIntercept.")
                else:
                    if debug:
                        print(f"[DEBUG] Tram model for node '{node}' has no nn_int or it is None.")

                # Stored even when None, so callers can distinguish "model
                # loaded but no intercepts" from "state not present at all".
                all_si_intercept_dicts[state][node] = si_weights

        # Clean up empty states
        all_si_intercept_dicts = {k: v for k, v in all_si_intercept_dicts.items() if v}
        return all_si_intercept_dicts
Compute transformed simple intercept dictionaries for all nodes and states.
For each node and each available state ("best", "last", "init"), this method loads the corresponding model checkpoint, extracts simple intercept weights, transforms them into interpretable theta parameters, and stores them in a nested dictionary.
Returns
dict Nested dictionary of the form:
```python
{ "best": {node: [[theta_1], [theta_2], ...]}, "last": {node: [[theta_1], [theta_2], ...]}, "init": {node: [[theta_1], [theta_2], ...]}, } ```
Notes
- For ordinal models (``self.is_ontram == True``), ``transform_intercepts_ordinal`` is used.
- For continuous models, ``transform_intercepts_continous`` is used.
- Empty outer states (without any nodes) are removed from the result.
2142 def summary(self, verbose=False): 2143 """ 2144 Print a multi-part textual summary of the TramDagModel. 2145 2146 The summary includes: 2147 1. Training metrics overview per node (best/last NLL, epochs). 2148 2. Node-specific details (thetas, linear shifts, optional architecture). 2149 3. Basic information about the attached training DataFrame, if present. 2150 2151 Parameters 2152 ---------- 2153 verbose : bool, optional 2154 If True, include extended per-node details such as the model 2155 architecture, parameter count, and availability of checkpoints 2156 and sampling results. Default is False. 2157 2158 Returns 2159 ------- 2160 None 2161 2162 Notes 2163 ----- 2164 This method prints to stdout and does not return structured data. 2165 It is intended for quick, human-readable inspection of the current 2166 training and model state. 2167 """ 2168 2169 # ---------- SETUP ---------- 2170 try: 2171 EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"] 2172 except KeyError: 2173 EXPERIMENT_DIR = None 2174 print("[WARNING] Missing EXPERIMENT_DIR in cfg.conf_dict['PATHS'].") 2175 2176 print("\n" + "=" * 120) 2177 print(f"{'TRAM DAG MODEL SUMMARY':^120}") 2178 print("=" * 120) 2179 2180 # ---------- METRICS OVERVIEW ---------- 2181 summary_data = [] 2182 for node in self.models.keys(): 2183 node_dir = os.path.join(self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"], node) 2184 train_path = os.path.join(node_dir, "train_loss_hist.json") 2185 val_path = os.path.join(node_dir, "val_loss_hist.json") 2186 2187 if os.path.exists(train_path) and os.path.exists(val_path): 2188 best_train_nll, best_val_nll = self.get_train_val_nll(node, "best") 2189 last_train_nll, last_val_nll = self.get_train_val_nll(node, "last") 2190 n_epochs_total = len(json.load(open(train_path))) 2191 else: 2192 best_train_nll = best_val_nll = last_train_nll = last_val_nll = None 2193 n_epochs_total = 0 2194 2195 summary_data.append({ 2196 "Node": node, 2197 "Best Train NLL": best_train_nll, 2198 
"Best Val NLL": best_val_nll, 2199 "Last Train NLL": last_train_nll, 2200 "Last Val NLL": last_val_nll, 2201 "Epochs": n_epochs_total, 2202 }) 2203 2204 df_summary = pd.DataFrame(summary_data) 2205 df_summary = df_summary.round(4) 2206 2207 print("\n[1] TRAINING METRICS OVERVIEW") 2208 print("-" * 120) 2209 if not df_summary.empty: 2210 print( 2211 df_summary.to_string( 2212 index=False, 2213 justify="center", 2214 col_space=14, 2215 float_format=lambda x: f"{x:7.4f}", 2216 ) 2217 ) 2218 else: 2219 print("No training history found for any node.") 2220 print("-" * 120) 2221 2222 # ---------- NODE DETAILS ---------- 2223 print("\n[2] NODE-SPECIFIC DETAILS") 2224 print("-" * 120) 2225 for node in self.models.keys(): 2226 print(f"\n{f'NODE: {node}':^120}") 2227 print("-" * 120) 2228 2229 # THETAS & SHIFTS 2230 for state in ["init", "last", "best"]: 2231 print(f"\n [{state.upper()} STATE]") 2232 2233 # ---- Thetas ---- 2234 try: 2235 thetas = getattr(self, "get_thetas", lambda n, s=None: None)(node, state) 2236 if thetas is not None: 2237 if isinstance(thetas, (list, np.ndarray, pd.Series)): 2238 thetas_flat = np.array(thetas).flatten() 2239 compact = np.round(thetas_flat, 4) 2240 arr_str = np.array2string( 2241 compact, 2242 max_line_width=110, 2243 threshold=np.inf, 2244 separator=", " 2245 ) 2246 lines = arr_str.split("\n") 2247 if len(lines) > 2: 2248 arr_str = "\n".join(lines[:2]) + " ..." 
2249 print(f" Θ ({len(thetas_flat)}): {arr_str}") 2250 elif isinstance(thetas, dict): 2251 for k, v in thetas.items(): 2252 print(f" Θ[{k}]: {v}") 2253 else: 2254 print(f" Θ: {thetas}") 2255 else: 2256 print(" Θ: not available") 2257 except Exception as e: 2258 print(f" [Error loading thetas] {e}") 2259 2260 # ---- Linear Shifts ---- 2261 try: 2262 linear_shifts = getattr(self, "get_linear_shifts", lambda n, s=None: None)(node, state) 2263 if linear_shifts is not None: 2264 if isinstance(linear_shifts, dict): 2265 for k, v in linear_shifts.items(): 2266 print(f" {k}: {np.round(v, 4)}") 2267 elif isinstance(linear_shifts, (list, np.ndarray, pd.Series)): 2268 arr = np.round(linear_shifts, 4) 2269 print(f" Linear shifts ({len(arr)}): {arr}") 2270 else: 2271 print(f" Linear shifts: {linear_shifts}") 2272 else: 2273 print(" Linear shifts: not available") 2274 except Exception as e: 2275 print(f" [Error loading linear shifts] {e}") 2276 2277 # ---- Verbose info directly below node ---- 2278 if verbose: 2279 print("\n [DETAILS]") 2280 node_dir = os.path.join(EXPERIMENT_DIR, node) if EXPERIMENT_DIR else None 2281 model = self.models[node] 2282 2283 print(f" Model Architecture:") 2284 arch_str = str(model).split("\n") 2285 for line in arch_str: 2286 print(f" {line}") 2287 print(f" Parameter count: {sum(p.numel() for p in model.parameters()):,}") 2288 2289 if node_dir and os.path.exists(node_dir): 2290 ckpt_exists = any(f.endswith(('.pt', '.pth')) for f in os.listdir(node_dir)) 2291 print(f" Checkpoints found: {ckpt_exists}") 2292 2293 sampling_dir = os.path.join(node_dir, "sampling") 2294 sampling_exists = os.path.isdir(sampling_dir) and len(os.listdir(sampling_dir)) > 0 2295 print(f" Sampling results found: {sampling_exists}") 2296 2297 for label, filename in [("Train", "train_loss_hist.json"), ("Validation", "val_loss_hist.json")]: 2298 path = os.path.join(node_dir, filename) 2299 if os.path.exists(path): 2300 try: 2301 with open(path, "r") as f: 2302 hist = json.load(f) 
2303 print(f" {label} history: {len(hist)} epochs") 2304 except Exception as e: 2305 print(f" {label} history: failed to load ({e})") 2306 else: 2307 print(f" {label} history: not found") 2308 else: 2309 print(" [INFO] No experiment directory defined or missing for this node.") 2310 print("-" * 120) 2311 2312 # ---------- TRAINING DATAFRAME ---------- 2313 print("\n[3] TRAINING DATAFRAME") 2314 print("-" * 120) 2315 try: 2316 self.train_df.info() 2317 except AttributeError: 2318 print("No training DataFrame attached to this TramDagModel.") 2319 print("=" * 120 + "\n")
Print a multi-part textual summary of the TramDagModel.
The summary includes:
- Training metrics overview per node (best/last NLL, epochs).
- Node-specific details (thetas, linear shifts, optional architecture).
- Basic information about the attached training DataFrame, if present.
Parameters
verbose : bool, optional If True, include extended per-node details such as the model architecture, parameter count, and availability of checkpoints and sampling results. Default is False.
Returns
None
Notes
This method prints to stdout and does not return structured data. It is intended for quick, human-readable inspection of the current training and model state.