tramdag.TramDagModel

Copyright 2025 Zurich University of Applied Sciences (ZHAW) Pascal Buehler, Beate Sick, Oliver Duerr

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

   1"""
   2Copyright 2025 Zurich University of Applied Sciences (ZHAW)
   3Pascal Buehler, Beate Sick, Oliver Duerr
   4
   5Licensed under the Apache License, Version 2.0 (the "License");
   6you may not use this file except in compliance with the License.
   7You may obtain a copy of the License at
   8
   9    http://www.apache.org/licenses/LICENSE-2.0
  10
  11Unless required by applicable law or agreed to in writing, software
  12distributed under the License is distributed on an "AS IS" BASIS,
  13WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14See the License for the specific language governing permissions and
  15limitations under the License.
  16"""
  17
import os
# Force synchronous CUDA kernel launches so CUDA errors surface at the exact
# offending call site instead of asynchronously later.
# NOTE(review): this is a debugging aid that slows down GPU execution for
# every importer of this module — confirm it is intended to ship enabled.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
  20
  21import pandas as pd
  22import numpy as np
  23import json
  24import matplotlib.pyplot as plt
  25
  26import torch
  27from torch.optim import Adam
  28from joblib import Parallel, delayed
  29
  30from statsmodels.graphics.gofplots import qqplot_2samples
  31from scipy.stats import logistic, probplot
  32
  33from .utils.model_helpers import train_val_loop,evaluate_tramdag_model, get_fully_specified_tram_model , model_train_val_paths ,ordered_parents
  34from .utils.sampling import create_latent_df_for_full_dag, sample_full_dag, is_outcome_modelled_ordinal,is_outcome_modelled_continous, is_outcome_modelled_ordinal, show_hdag_continous,show_hdag_ordinal
  35from .utils.continous import transform_intercepts_continous
  36from .utils.ordinal import transform_intercepts_ordinal
  37
  38from .models.tram_models import SimpleIntercept
  39
  40from .TramDagConfig import TramDagConfig
  41from .TramDagDataset import TramDagDataset
  42
  43# testserver
  44#%pip install -i https://test.pypi.org/simple --extra-index-url https://pypi.org/simple tramdag
  45
  46
  47# # Remove previous builds
  48# rm -rf build dist *.egg-info
  49
  50# # Build new package
  51# python -m build
  52
  53# # Upload to TestPyPI
  54# python -m twine upload --repository testpypi dist/*
  55
# documentation
  57# 1. download new version of tramdag in env
  58# pip install pdoc
  59# 2. generate docs
  60# pdoc tramdag -o docs
  61
  62
  63class TramDagModel:
  64    """
  65    Probabilistic DAG model built from node-wise TRAMs (transformation models).
  66
  67    This class manages:
  68    - Configuration and per-node model construction.
  69    - Data scaling (min–max).
  70    - Training (sequential or per-node parallel on CPU).
  71    - Diagnostics (loss history, intercepts, linear shifts, latents).
  72    - Sampling from the joint DAG and loading stored samples.
  73    - High-level summaries and plotting utilities.
  74    """
  75    
    # ---- defaults used at construction time (from_config) ----
    # Scalar values apply to every node; from_config also accepts a per-node
    # dict {node_name: value} for any of these keys.
    DEFAULTS_CONFIG = {
        "set_initial_weights": False,       # forwarded to per-node TRAM constructors
        "debug":False,                      # print debug messages
        "verbose": False,                   # print informational messages
        "device":'auto',                    # "auto" resolves to CUDA when available
        "initial_data":None,                # optional object handed to node constructors
        "overwrite_initial_weights": True,  # overwrite existing initial_model.pt files
    }

    # ---- defaults used at fit() time ----
    DEFAULTS_FIT = {
        "epochs": 100,               # training epochs per node
        "train_list": None,          # subset of node names to train; None = all nodes
        "callbacks": None,           # NOTE(review): accepted but not referenced in fit() — confirm
        "learning_rate": 0.01,       # lr for the default Adam optimizer
        "device": "auto",            # "auto" resolves to CUDA when available
        "optimizers": None,          # optional {node: optimizer} overrides
        "schedulers": None,          # optional {node: scheduler} overrides
        "use_scheduler": False,      # NOTE(review): _fit_single_node derives this from scheduler presence — confirm
        "save_linear_shifts": True,
        "save_simple_intercepts": True,
        "debug":False,
        "verbose": True,
        "train_mode": "sequential",  # or "parallel"
        "return_history": False,     # fit() returns the per-node history dict when True
        # NOTE(review): key below is misspelled ("inital") and appears unused in
        # fit(); the correctly spelled 'overwrite_initial_weights' lives in
        # DEFAULTS_CONFIG — confirm which spelling callers rely on.
        "overwrite_inital_weights": True,
        "num_workers" : 4,           # DataLoader workers (sequential mode only)
        "persistent_workers" : True, # DataLoader persistence (sequential mode only)
        "prefetch_factor" : 4,       # DataLoader prefetch factor (sequential mode only)
        "batch_size":1000,           # batch size for all node DataLoaders
        
    }
 109
 110    def __init__(self):
 111        """
 112        Initialize an empty TramDagModel shell.
 113
 114        Notes
 115        -----
 116        This constructor does not build any node models and does not attach a
 117        configuration. Use `TramDagModel.from_config` or `TramDagModel.from_directory`
 118        to obtain a fully configured and ready-to-use instance.
 119        """
 120        
 121        self.debug = False
 122        self.verbose = False
 123        self.device = 'auto'
 124        pass
 125
 126    @staticmethod
 127    def get_device(settings):
 128        """
 129        Resolve the target device string from a settings dictionary.
 130
 131        Parameters
 132        ----------
 133        settings : dict
 134            Dictionary containing at least a key ``"device"`` with one of
 135            {"auto", "cpu", "cuda"}. If missing, "auto" is assumed.
 136
 137        Returns
 138        -------
 139        str
 140            Device string, either "cpu" or "cuda".
 141
 142        Notes
 143        -----
 144        If ``device == "auto"``, CUDA is selected if available, otherwise CPU.
 145        """
 146        device_arg = settings.get("device", "auto")
 147        if device_arg == "auto":
 148            device_str = "cuda" if torch.cuda.is_available() else "cpu"
 149        else:
 150            device_str = device_arg
 151        return device_str
 152
 153    def _validate_kwargs(self, kwargs: dict, defaults_attr: str = "DEFAULTS_FIT", context: str = None):
 154        """
 155        Validate a kwargs dictionary against a class-level defaults dictionary.
 156
 157        Parameters
 158        ----------
 159        kwargs : dict
 160            Keyword arguments to validate.
 161        defaults_attr : str, optional
 162            Name of the attribute on this class that contains the allowed keys,
 163            e.g. ``"DEFAULTS_CONFIG"`` or ``"DEFAULTS_FIT"``. Default is "DEFAULTS_FIT".
 164        context : str or None, optional
 165            Optional label (e.g. caller name) to prepend in error messages.
 166
 167        Raises
 168        ------
 169        AttributeError
 170            If the attribute named by ``defaults_attr`` does not exist.
 171        ValueError
 172            If any key in ``kwargs`` is not present in the corresponding defaults dict.
 173        """
 174        defaults = getattr(self, defaults_attr, None)
 175        if defaults is None:
 176            raise AttributeError(f"{self.__class__.__name__} has no attribute '{defaults_attr}'")
 177
 178        unknown = set(kwargs) - set(defaults)
 179        if unknown:
 180            prefix = f"[{context}] " if context else ""
 181            raise ValueError(f"{prefix}Unknown parameter(s): {', '.join(sorted(unknown))}")
 182            
    ## CREATE A TRAMDAGMODEL
 184    @classmethod
 185    def from_config(cls, cfg, **kwargs):
 186        """
 187        Construct a TramDagModel from a TramDagConfig object.
 188
 189        This builds one TRAM model per node in the DAG and optionally writes
 190        the initial model parameters to disk.
 191
 192        Parameters
 193        ----------
 194        cfg : TramDagConfig
 195            Configuration wrapper holding the underlying configuration dictionary,
 196            including at least:
 197            - ``conf_dict["nodes"]``: mapping of node names to node configs.
 198            - ``conf_dict["PATHS"]["EXPERIMENT_DIR"]``: experiment directory.
 199        **kwargs
 200            Node-level construction options. Each key must be present in
 201            ``DEFAULTS_CONFIG``. Values can be:
 202            - scalar: applied to all nodes.
 203            - dict: mapping ``{node_name: value}`` for per-node overrides.
 204
 205            Common keys include:
 206            device : {"auto", "cpu", "cuda"}, default "auto"
 207                Device selection (CUDA if available when "auto").
 208            debug : bool, default False
 209                If True, print debug messages.
 210            verbose : bool, default False
 211                If True, print informational messages.
 212            set_initial_weights : bool
 213                Passed to underlying TRAM model constructors.
 214            overwrite_initial_weights : bool, default True
 215                If True, overwrite any existing ``initial_model.pt`` files per node.
 216            initial_data : Any
 217                Optional object passed down to node constructors.
 218
 219        Returns
 220        -------
 221        TramDagModel
 222            Fully initialized instance with:
 223            - ``cfg``
 224            - ``nodes_dict``
 225            - ``models`` (per-node TRAMs)
 226            - ``settings`` (resolved per-node config)
 227
 228        Raises
 229        ------
 230        ValueError
 231            If any dict-valued kwarg does not provide values for exactly the set
 232            of nodes in ``cfg.conf_dict["nodes"]``.
 233        """
 234        
 235        self = cls()
 236        self.cfg = cfg
 237        self.cfg.update()  # ensure latest version from disk
 238        self.cfg._verify_completeness()
 239        
 240        
 241        try:
 242            self.cfg.save()  # persist back to disk
 243            if getattr(self, "debug", False):
 244                print("[DEBUG] Configuration updated and saved.")
 245        except Exception as e:
 246            print(f"[WARNING] Could not save configuration after update: {e}")        
 247            
 248        self.nodes_dict = self.cfg.conf_dict["nodes"] 
 249
 250        self._validate_kwargs(kwargs, defaults_attr='DEFAULTS_CONFIG', context="from_config")
 251
 252        # update defaults with kwargs
 253        settings = dict(cls.DEFAULTS_CONFIG)
 254        settings.update(kwargs)
 255
 256        # resolve device
 257        device_arg = settings.get("device", "auto")
 258        if device_arg == "auto":
 259            device_str = "cuda" if torch.cuda.is_available() else "cpu"
 260        else:
 261            device_str = device_arg
 262        self.device = torch.device(device_str)
 263
 264        # set flags on the instance so they are accessible later
 265        self.debug = settings.get("debug", False)
 266        self.verbose = settings.get("verbose", False)
 267
 268        if  self.debug:
 269            print(f"[DEBUG] TramDagModel using device: {self.device}")
 270            
 271        # initialize settings storage
 272        self.settings = {k: {} for k in settings.keys()}
 273
 274        # validate dict-typed args
 275        for k, v in settings.items():
 276            if isinstance(v, dict):
 277                expected = set(self.nodes_dict.keys())
 278                given = set(v.keys())
 279                if expected != given:
 280                    raise ValueError(
 281                        f"[ERROR] the provided argument '{k}' keys are not same as in cfg.conf_dict['nodes'].keys().\n"
 282                        f"Expected: {expected}, but got: {given}\n"
 283                        f"Please provide values for all variables.")
 284
 285        # build one model per node
 286        self.models = {}
 287        for node in self.nodes_dict.keys():
 288            per_node_kwargs = {}
 289            for k, v in settings.items():
 290                resolved = v[node] if isinstance(v, dict) else v
 291                per_node_kwargs[k] = resolved
 292                self.settings[k][node] = resolved
 293            if self.debug:
 294                print(f"\n[INFO] Building model for node '{node}' with settings: {per_node_kwargs}")
 295            self.models[node] = get_fully_specified_tram_model(
 296                node=node,
 297                configuration_dict=self.cfg.conf_dict,
 298                **per_node_kwargs)
 299            
 300            try:
 301                EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
 302                NODE_DIR = os.path.join(EXPERIMENT_DIR, f"{node}")
 303                os.makedirs(NODE_DIR, exist_ok=True)
 304
 305                model_path = os.path.join(NODE_DIR, "initial_model.pt")
 306                overwrite = settings.get("overwrite_initial_weights", True)
 307
 308                if overwrite or not os.path.exists(model_path):
 309                    torch.save(self.models[node].state_dict(), model_path)
 310                    if self.debug:
 311                        print(f"[DEBUG] Saved initial model state for node '{node}' to {model_path} (overwrite={overwrite})")
 312                else:
 313                    if self.debug:
 314                        print(f"[DEBUG] Skipped saving initial model for node '{node}' (already exists at {model_path})")
 315            except Exception as e:
 316                print(f"[ERROR] Could not save initial model state for node '{node}': {e}")
 317            
 318            TEMP_DIR = "temp"
 319            if os.path.isdir(TEMP_DIR) and not os.listdir(TEMP_DIR):
 320                os.rmdir(TEMP_DIR)
 321                            
 322        return self
 323
 324    @classmethod
 325    def from_directory(cls, EXPERIMENT_DIR: str, device: str = "auto", debug: bool = False, verbose: bool = False):
 326        """
 327        Reconstruct a TramDagModel from an experiment directory on disk.
 328
 329        This method:
 330        1. Loads the configuration JSON.
 331        2. Wraps it in a TramDagConfig.
 332        3. Builds all node models via `from_config`.
 333        4. Loads the min–max scaling dictionary.
 334
 335        Parameters
 336        ----------
 337        EXPERIMENT_DIR : str
 338            Path to an experiment directory containing:
 339            - ``configuration.json``
 340            - ``min_max_scaling.json``.
 341        device : {"auto", "cpu", "cuda"}, optional
 342            Device selection. Default is "auto".
 343        debug : bool, optional
 344            If True, enable debug messages. Default is False.
 345        verbose : bool, optional
 346            If True, enable informational messages. Default is False.
 347
 348        Returns
 349        -------
 350        TramDagModel
 351            A TramDagModel instance with models, config, and scaling loaded.
 352
 353        Raises
 354        ------
 355        FileNotFoundError
 356            If configuration or min–max files cannot be found.
 357        RuntimeError
 358            If the min–max file cannot be read or parsed.
 359        """
 360
 361        # --- load config file ---
 362        config_path = os.path.join(EXPERIMENT_DIR, "configuration.json")
 363        if not os.path.exists(config_path):
 364            raise FileNotFoundError(f"[ERROR] Config file not found at {config_path}")
 365
 366        with open(config_path, "r") as f:
 367            cfg_dict = json.load(f)
 368
 369        # Create TramConfig wrapper 
 370        cfg = TramDagConfig(cfg_dict, CONF_DICT_PATH=config_path)
 371
 372        # --- build model from config ---
 373        self = cls.from_config(cfg, device=device, debug=debug, verbose=verbose, overwrite_initial_weights=False)
 374
 375        # --- load minmax scaling ---
 376        minmax_path = os.path.join(EXPERIMENT_DIR, "min_max_scaling.json")
 377        if not os.path.exists(minmax_path):
 378            raise FileNotFoundError(f"[ERROR] MinMax file not found at {minmax_path}")
 379
 380        with open(minmax_path, "r") as f:
 381            self.minmax_dict = json.load(f)
 382
 383        if self.verbose or self.debug:
 384            print(f"[INFO] Loaded TramDagModel from {EXPERIMENT_DIR}")
 385            print(f"[INFO] Config loaded from {config_path}")
 386            print(f"[INFO] MinMax scaling loaded from {minmax_path}")
 387
 388        return self
 389
 390    def _ensure_dataset(self, data, is_val=False,**kwargs):
 391        """
 392        Ensure that the input data is represented as a TramDagDataset.
 393
 394        Parameters
 395        ----------
 396        data : pandas.DataFrame, TramDagDataset, or None
 397            Input data to be converted or passed through.
 398        is_val : bool, optional
 399            If True, the resulting dataset is treated as validation data
 400            (e.g. no shuffling). Default is False.
 401        **kwargs
 402            Additional keyword arguments passed through to
 403            ``TramDagDataset.from_dataframe``.
 404
 405        Returns
 406        -------
 407        TramDagDataset or None
 408            A TramDagDataset if ``data`` is a DataFrame or TramDagDataset,
 409            otherwise None if ``data`` is None.
 410
 411        Raises
 412        ------
 413        TypeError
 414            If ``data`` is not a DataFrame, TramDagDataset, or None.
 415        """
 416                
 417        if isinstance(data, pd.DataFrame):
 418            return TramDagDataset.from_dataframe(data, self.cfg, shuffle=not is_val,**kwargs)
 419        elif isinstance(data, TramDagDataset):
 420            return data
 421        elif data is None:
 422            return None
 423        else:
 424            raise TypeError(
 425                f"[ERROR] data must be pd.DataFrame, TramDagDataset, or None, got {type(data)}"
 426            )
 427
 428    def load_or_compute_minmax(self, td_train_data=None,use_existing=False, write=True):
 429        """
 430        Load an existing Min–Max scaling dictionary from disk or compute a new one 
 431        from the provided training dataset.
 432
 433        Parameters
 434        ----------
 435        use_existing : bool, optional (default=False)
 436            If True, attempts to load an existing `min_max_scaling.json` file 
 437            from the experiment directory. Raises an error if the file is missing 
 438            or unreadable.
 439
 440        write : bool, optional (default=True)
 441            If True, writes the computed Min–Max scaling dictionary to 
 442            `<EXPERIMENT_DIR>/min_max_scaling.json`.
 443
 444        td_train_data : object, optional
 445            Training dataset used to compute scaling statistics. If not provided,
 446            the method will ensure or construct it via `_ensure_dataset(data=..., is_val=False)`.
 447
 448        Behavior
 449        --------
 450        - If `use_existing=True`, loads the JSON file containing previously saved 
 451          min–max values and stores it in `self.minmax_dict`.
 452        - If `use_existing=False`, computes a new scaling dictionary using 
 453          `td_train_data.compute_scaling()` and stores the result in 
 454          `self.minmax_dict`.
 455        - Optionally writes the computed dictionary to disk.
 456
 457        Side Effects
 458        -------------
 459        - Populates `self.minmax_dict` with scaling values.
 460        - Writes or loads the file `min_max_scaling.json` under 
 461          `<EXPERIMENT_DIR>`.
 462        - Prints diagnostic output if `self.debug` or `self.verbose` is True.
 463
 464        Raises
 465        ------
 466        FileNotFoundError
 467            If `use_existing=True` but the min–max file does not exist.
 468
 469        RuntimeError
 470            If an existing min–max file cannot be read or parsed.
 471
 472        Notes
 473        -----
 474        The computed min–max dictionary is expected to contain scaling statistics 
 475        per feature, typically in the form:
 476            {
 477                "node": {"min": float, "max": float},
 478                ...
 479            }
 480        """
 481        EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
 482        minmax_path = os.path.join(EXPERIMENT_DIR, "min_max_scaling.json")
 483
 484        # laod exisitng if possible
 485        if use_existing:
 486            if not os.path.exists(minmax_path):
 487                raise FileNotFoundError(f"MinMax file not found: {minmax_path}")
 488            try:
 489                with open(minmax_path, 'r') as f:
 490                    self.minmax_dict = json.load(f)
 491                if self.debug or self.verbose:
 492                    print(f"[INFO] Loaded existing minmax dict from {minmax_path}")
 493                return
 494            except Exception as e:
 495                raise RuntimeError(f"Could not load existing minmax dict: {e}")
 496
 497        # 
 498        if self.debug or self.verbose:
 499            print("[INFO] Computing new minmax dict from training data...")
 500            
 501        td_train_data=self._ensure_dataset( data=td_train_data, is_val=False)    
 502            
 503        self.minmax_dict = td_train_data.compute_scaling()
 504
 505        if write:
 506            os.makedirs(EXPERIMENT_DIR, exist_ok=True)
 507            with open(minmax_path, 'w') as f:
 508                json.dump(self.minmax_dict, f, indent=4)
 509            if self.debug or self.verbose:
 510                print(f"[INFO] Saved new minmax dict to {minmax_path}")
 511
 512    ## FIT METHODS
    @staticmethod
    def _fit_single_node(node, self_ref, settings, td_train_data, td_val_data, device_str):
        """
        Train a single node model (helper for per-node training).

        Designed to be called either from the main process (sequential
        training) or from a joblib worker (parallel CPU training), which is
        why it is a staticmethod taking an explicit ``self_ref``.

        Parameters
        ----------
        node : str
            Name of the target node to train.
        self_ref : TramDagModel
            Reference to the TramDagModel instance containing models and config.
        settings : dict
            Training settings, typically ``DEFAULTS_FIT`` plus user overrides.
            Values may be scalars or per-node dicts ``{node_name: value}``.
        td_train_data : TramDagDataset
            Training dataset with node-specific DataLoaders in ``.loaders``.
        td_val_data : TramDagDataset or None
            Validation dataset or None (no validation loss is computed).
        device_str : str
            Device string, e.g. "cpu" or "cuda".

        Returns
        -------
        tuple
            ``(node, history)`` where ``history`` is whatever
            ``train_val_loop`` returns for this node.
        """
        torch.set_num_threads(1)  # prevent thread oversubscription

        model = self_ref.models[node]

        # Resolve per-node settings: a dict-valued setting is indexed by node,
        # a scalar applies to every node.
        def _resolve(key):
            val = settings[key]
            return val[node] if isinstance(val, dict) else val

        node_epochs = _resolve("epochs")
        node_lr = _resolve("learning_rate")
        node_debug = _resolve("debug")
        node_save_linear_shifts = _resolve("save_linear_shifts")
        save_simple_intercepts  = _resolve("save_simple_intercepts")
        node_verbose = _resolve("verbose")

        # Optimizer & scheduler: a user-supplied optimizer for this node wins,
        # otherwise a fresh Adam with the per-node learning rate is created.
        if settings["optimizers"] and node in settings["optimizers"]:
            optimizer = settings["optimizers"][node]
        else:
            optimizer = Adam(model.parameters(), lr=node_lr)

        # NOTE(review): settings["use_scheduler"] is not consulted here; the
        # scheduler flag passed to train_val_loop below is derived purely from
        # whether a scheduler object exists for this node — confirm intended.
        scheduler = settings["schedulers"].get(node, None) if settings["schedulers"] else None

        # Data loaders (validation loader is optional)
        train_loader = td_train_data.loaders[node]
        val_loader = td_val_data.loaders[node] if td_val_data else None

        # Min-max scaling tensors, stacked to shape (2, ...): row 0 = mins,
        # row 1 = maxs. NOTE(review): assumes minmax_dict[node] is an indexable
        # [min, max] pair — confirm against compute_scaling()'s output format.
        min_vals = torch.tensor(self_ref.minmax_dict[node][0], dtype=torch.float32)
        max_vals = torch.tensor(self_ref.minmax_dict[node][1], dtype=torch.float32)
        min_max = torch.stack([min_vals, max_vals], dim=0)

        # Node directory: prefer the configured experiment dir, fall back to
        # ./models/<node> when the config has no PATHS entry.
        try:
            EXPERIMENT_DIR = self_ref.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
            NODE_DIR = os.path.join(EXPERIMENT_DIR, f"{node}")
        except Exception:
            NODE_DIR = os.path.join("models", node)
            print("[WARNING] No log directory specified in config, saving to default location.")
        os.makedirs(NODE_DIR, exist_ok=True)

        if node_verbose:
            print(f"\n[INFO] Training node '{node}' for {node_epochs} epochs on {device_str} (pid={os.getpid()})")

        # --- train ---
        history = train_val_loop(
            node=node,
            target_nodes=self_ref.nodes_dict,
            NODE_DIR=NODE_DIR,
            tram_model=model,
            train_loader=train_loader,
            val_loader=val_loader,
            epochs=node_epochs,
            optimizer=optimizer,
            use_scheduler=(scheduler is not None),
            scheduler=scheduler,
            save_linear_shifts=node_save_linear_shifts,
            save_simple_intercepts=save_simple_intercepts,
            verbose=node_verbose,
            device=torch.device(device_str),
            debug=node_debug,
            min_max=min_max)
        return node, history
 610
 611    def fit(self, train_data, val_data=None, **kwargs):
 612        """
 613        Train TRAM models for all nodes in the DAG.
 614
 615        Coordinates dataset preparation, min–max scaling, and per-node training,
 616        optionally in parallel on CPU.
 617
 618        Parameters
 619        ----------
 620        train_data : pandas.DataFrame or TramDagDataset
 621            Training data. If a DataFrame is given, it is converted into a
 622            TramDagDataset using `_ensure_dataset`.
 623        val_data : pandas.DataFrame or TramDagDataset or None, optional
 624            Validation data. If a DataFrame is given, it is converted into a
 625            TramDagDataset. If None, no validation loss is computed.
 626        **kwargs
 627            Overrides for ``DEFAULTS_FIT``. All keys must exist in
 628            ``DEFAULTS_FIT``. Common options:
 629
 630            epochs : int, default 100
 631                Number of training epochs per node.
 632            learning_rate : float, default 0.01
 633                Learning rate for the default Adam optimizer.
 634            train_list : list of str or None, optional
 635                List of node names to train. If None, all nodes are trained.
 636            train_mode : {"sequential", "parallel"}, default "sequential"
 637                Training mode. "parallel" uses joblib-based CPU multiprocessing.
 638                GPU forces sequential mode.
 639            device : {"auto", "cpu", "cuda"}, default "auto"
 640                Device selection.
 641            optimizers : dict or None
 642                Optional mapping ``{node_name: optimizer}``. If provided for a
 643                node, that optimizer is used instead of creating a new Adam.
 644            schedulers : dict or None
 645                Optional mapping ``{node_name: scheduler}``.
 646            use_scheduler : bool
 647                If True, enable scheduler usage in the training loop.
 648            num_workers : int
 649                DataLoader workers in sequential mode (ignored in parallel).
 650            persistent_workers : bool
 651                DataLoader persistence in sequential mode (ignored in parallel).
 652            prefetch_factor : int
 653                DataLoader prefetch factor (ignored in parallel).
 654            batch_size : int
 655                Batch size for all node DataLoaders.
 656            debug : bool
 657                Enable debug output.
 658            verbose : bool
 659                Enable informational logging.
 660            return_history : bool
 661                If True, return a history dict.
 662
 663        Returns
 664        -------
 665        dict or None
 666            If ``return_history=True``, a dictionary mapping each node name
 667            to its training history. Otherwise, returns None.
 668
 669        Raises
 670        ------
 671        ValueError
 672            If ``train_mode`` is not "sequential" or "parallel".
 673        """
 674        self._validate_kwargs(kwargs, defaults_attr='DEFAULTS_FIT', context="fit")
 675        
 676        # --- merge defaults ---
 677        settings = dict(self.DEFAULTS_FIT)
 678        settings.update(kwargs)
 679        
 680        
 681        self.debug = settings.get("debug", False)
 682        self.verbose = settings.get("verbose", False)
 683
 684        # --- resolve device ---
 685        device_str=self.get_device(settings)
 686        self.device = torch.device(device_str)
 687
 688        # --- training mode ---
 689        train_mode = settings.get("train_mode", "sequential").lower()
 690        if train_mode not in ("sequential", "parallel"):
 691            raise ValueError("train_mode must be 'sequential' or 'parallel'")
 692
 693        # --- DataLoader safety logic ---
 694        if train_mode == "parallel":
 695            # if user passed loader paralleling params, warn and override
 696            for flag in ("num_workers", "persistent_workers", "prefetch_factor"):
 697                if flag in kwargs:
 698                    print(f"[WARNING] '{flag}' is ignored in parallel mode "
 699                        f"(disabled to prevent nested multiprocessing).")
 700            # disable unsafe loader multiprocessing options
 701            settings["num_workers"] = 0
 702            settings["persistent_workers"] = False
 703            settings["prefetch_factor"] = None
 704        else:
 705            # sequential mode → respect user DataLoader settings
 706            if self.debug:
 707                print("[DEBUG] Sequential mode: using DataLoader kwargs as provided.")
 708
 709        # --- which nodes to train ---
 710        train_list = settings.get("train_list") or list(self.models.keys())
 711
 712
 713        # --- dataset prep (receives adjusted settings) ---
 714        td_train_data = self._ensure_dataset(train_data, is_val=False, **settings)
 715        td_val_data = self._ensure_dataset(val_data, is_val=True, **settings)
 716
 717        # --- normalization ---
 718        self.load_or_compute_minmax(use_existing=False, write=True, td_train_data=td_train_data)
 719
 720        # --- print header ---
 721        if self.verbose or self.debug:
 722            print(f"[INFO] Training {len(train_list)} nodes ({train_mode}) on {device_str}")
 723
 724        # ======================================================================
 725        # Sequential mode  safe for GPU or debugging)
 726        # ======================================================================
 727        if train_mode == "sequential" or "cuda" in device_str:
 728            if "cuda" in device_str and train_mode == "parallel":
 729                print("[WARNING] GPU device detected — forcing sequential mode.")
 730            results = {}
 731            for node in train_list:
 732                node, history = self._fit_single_node(
 733                    node, self, settings, td_train_data, td_val_data, device_str
 734                )
 735                results[node] = history
 736        
 737
 738        # ======================================================================
 739        # parallel mode (CPU only)
 740        # ======================================================================
 741        if train_mode == "parallel":
 742
 743            n_jobs = min(len(train_list), os.cpu_count() // 2 or 1)
 744            if self.verbose or self.debug:
 745                print(f"[INFO] Using {n_jobs} CPU workers for parallel node training")
 746            parallel_outputs = Parallel(
 747                n_jobs=n_jobs,
 748                backend="loky",#loky, multiprocessing
 749                verbose=10,
 750                prefer="processes"
 751            )(delayed(self._fit_single_node)(node, self, settings, td_train_data, td_val_data, device_str) for node in train_list )
 752
 753            results = {node: hist for node, hist in parallel_outputs}
 754        
 755        if settings.get("return_history", False):
 756            return results
 757
 758    ## FIT-DIAGNOSTICS
 759    def loss_history(self):
 760        """
 761        Load training and validation loss history for all nodes.
 762
 763        Looks for per-node JSON files:
 764
 765        - ``EXPERIMENT_DIR/{node}/train_loss_hist.json``
 766        - ``EXPERIMENT_DIR/{node}/val_loss_hist.json``
 767
 768        Returns
 769        -------
 770        dict
 771            A dictionary mapping node names to:
 772
 773            .. code-block:: python
 774
 775                {
 776                    "train": list or None,
 777                    "validation": list or None
 778                }
 779
 780            where each list contains NLL values per epoch, or None if not found.
 781
 782        Raises
 783        ------
 784        ValueError
 785            If the experiment directory cannot be resolved from the configuration.
 786        """
 787        try:
 788            EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
 789        except KeyError:
 790            raise ValueError(
 791                "[ERROR] Missing 'EXPERIMENT_DIR' in cfg.conf_dict['PATHS']. "
 792                "History retrieval requires experiment logs."
 793            )
 794
 795        all_histories = {}
 796        for node in self.nodes_dict.keys():
 797            node_dir = os.path.join(EXPERIMENT_DIR, node)
 798            train_path = os.path.join(node_dir, "train_loss_hist.json")
 799            val_path = os.path.join(node_dir, "val_loss_hist.json")
 800
 801            node_hist = {}
 802
 803            # --- load train history ---
 804            if os.path.exists(train_path):
 805                try:
 806                    with open(train_path, "r") as f:
 807                        node_hist["train"] = json.load(f)
 808                except Exception as e:
 809                    print(f"[WARNING] Could not load {train_path}: {e}")
 810                    node_hist["train"] = None
 811            else:
 812                node_hist["train"] = None
 813
 814            # --- load val history ---
 815            if os.path.exists(val_path):
 816                try:
 817                    with open(val_path, "r") as f:
 818                        node_hist["validation"] = json.load(f)
 819                except Exception as e:
 820                    print(f"[WARNING] Could not load {val_path}: {e}")
 821                    node_hist["validation"] = None
 822            else:
 823                node_hist["validation"] = None
 824
 825            all_histories[node] = node_hist
 826
 827        if self.verbose or self.debug:
 828            print(f"[INFO] Loaded training/validation histories for {len(all_histories)} nodes.")
 829
 830        return all_histories
 831
 832    def linear_shift_history(self):
 833        """
 834        Load linear shift term histories for all nodes.
 835
 836        Each node history is expected in a JSON file named
 837        ``linear_shifts_all_epochs.json`` under the node directory.
 838
 839        Returns
 840        -------
 841        dict
 842            A mapping ``{node_name: pandas.DataFrame}``, where each DataFrame
 843            contains linear shift weights across epochs.
 844
 845        Raises
 846        ------
 847        ValueError
 848            If the experiment directory cannot be resolved from the configuration.
 849
 850        Notes
 851        -----
 852        If a history file is missing for a node, a warning is printed and the
 853        node is omitted from the returned dictionary.
 854        """
 855        histories = {}
 856        try:
 857            EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
 858        except KeyError:
 859            raise ValueError(
 860                "[ERROR] Missing 'EXPERIMENT_DIR' in cfg.conf_dict['PATHS']. "
 861                "Cannot load histories without experiment directory."
 862            )
 863
 864        for node in self.nodes_dict.keys():
 865            node_dir = os.path.join(EXPERIMENT_DIR, node)
 866            history_path = os.path.join(node_dir, "linear_shifts_all_epochs.json")
 867            if os.path.exists(history_path):
 868                histories[node] = pd.read_json(history_path)
 869            else:
 870                print(f"[WARNING] No linear shift history found for node '{node}' at {history_path}")
 871        return histories
 872
 873    def simple_intercept_history(self):
 874        """
 875        Load simple intercept histories for all nodes.
 876
 877        Each node history is expected in a JSON file named
 878        ``simple_intercepts_all_epochs.json`` under the node directory.
 879
 880        Returns
 881        -------
 882        dict
 883            A mapping ``{node_name: pandas.DataFrame}``, where each DataFrame
 884            contains intercept weights across epochs.
 885
 886        Raises
 887        ------
 888        ValueError
 889            If the experiment directory cannot be resolved from the configuration.
 890
 891        Notes
 892        -----
 893        If a history file is missing for a node, a warning is printed and the
 894        node is omitted from the returned dictionary.
 895        """
 896        histories = {}
 897        try:
 898            EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
 899        except KeyError:
 900            raise ValueError(
 901                "[ERROR] Missing 'EXPERIMENT_DIR' in cfg.conf_dict['PATHS']. "
 902                "Cannot load histories without experiment directory."
 903            )
 904
 905        for node in self.nodes_dict.keys():
 906            node_dir = os.path.join(EXPERIMENT_DIR, node)
 907            history_path = os.path.join(node_dir, "simple_intercepts_all_epochs.json")
 908            if os.path.exists(history_path):
 909                histories[node] = pd.read_json(history_path)
 910            else:
 911                print(f"[WARNING] No simple intercept history found for node '{node}' at {history_path}")
 912        return histories
 913
 914    def get_latent(self, df, verbose=False):
 915        """
 916        Compute latent representations for all nodes in the DAG.
 917
 918        Parameters
 919        ----------
 920        df : pandas.DataFrame
 921            Input data frame with columns corresponding to nodes in the DAG.
 922        verbose : bool, optional
 923            If True, print informational messages during latent computation.
 924            Default is False.
 925
 926        Returns
 927        -------
 928        pandas.DataFrame
 929            DataFrame containing the original columns plus latent variables
 930            for each node (e.g. columns named ``f"{node}_U"``).
 931
 932        Raises
 933        ------
 934        ValueError
 935            If the experiment directory is missing from the configuration or
 936            if ``self.minmax_dict`` has not been set.
 937        """
 938        try:
 939            EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
 940        except KeyError:
 941            raise ValueError(
 942                "[ERROR] Missing 'EXPERIMENT_DIR' in cfg.conf_dict['PATHS']. "
 943                "Latent extraction requires trained model checkpoints."
 944            )
 945
 946        # ensure minmax_dict is available
 947        if not hasattr(self, "minmax_dict"):
 948            raise ValueError(
 949                "[ERROR] minmax_dict not found in the TramDagModel instance. "
 950                "Either call .load_or_compute_minmax(td_train_data=train_df) or .fit() first."
 951            )
 952
 953        all_latents_df = create_latent_df_for_full_dag(
 954            configuration_dict=self.cfg.conf_dict,
 955            EXPERIMENT_DIR=EXPERIMENT_DIR,
 956            df=df,
 957            verbose=verbose,
 958            min_max_dict=self.minmax_dict,
 959        )
 960
 961        return all_latents_df
 962
 963    
 964    ## PLOTTING FIT-DIAGNOSTICS
 965    
 966    def plot_loss_history(self, variable: str = None):
 967        """
 968        Plot training and validation loss evolution per node.
 969
 970        Parameters
 971        ----------
 972        variable : str or None, optional
 973            If provided, plot loss history for this node only. If None, plot
 974            histories for all nodes that have both train and validation logs.
 975
 976        Returns
 977        -------
 978        None
 979
 980        Notes
 981        -----
 982        Two subplots are produced:
 983        - Full epoch history.
 984        - Last 10% of epochs (or only the last epoch if fewer than 5 epochs).
 985        """
 986
 987        histories = self.loss_history()
 988        if not histories:
 989            print("[WARNING] No loss histories found.")
 990            return
 991
 992        # Select which nodes to plot
 993        if variable is not None:
 994            if variable not in histories:
 995                raise ValueError(f"[ERROR] Node '{variable}' not found in histories.")
 996            nodes_to_plot = [variable]
 997        else:
 998            nodes_to_plot = list(histories.keys())
 999
1000        # Filter out nodes with no valid history
1001        nodes_to_plot = [
1002            n for n in nodes_to_plot
1003            if histories[n].get("train") is not None and len(histories[n]["train"]) > 0
1004            and histories[n].get("validation") is not None and len(histories[n]["validation"]) > 0
1005        ]
1006
1007        if not nodes_to_plot:
1008            print("[WARNING] No valid histories found to plot.")
1009            return
1010
1011        plt.figure(figsize=(14, 12))
1012
1013        # --- Full history (top plot) ---
1014        plt.subplot(2, 1, 1)
1015        for node in nodes_to_plot:
1016            node_hist = histories[node]
1017            train_hist, val_hist = node_hist["train"], node_hist["validation"]
1018
1019            epochs = range(1, len(train_hist) + 1)
1020            plt.plot(epochs, train_hist, label=f"{node} - train", linestyle="--")
1021            plt.plot(epochs, val_hist, label=f"{node} - val")
1022
1023        plt.title("Training and Validation NLL - Full History")
1024        plt.xlabel("Epoch")
1025        plt.ylabel("NLL")
1026        plt.legend()
1027        plt.grid(True)
1028
1029        # --- Last 10% of epochs (bottom plot) ---
1030        plt.subplot(2, 1, 2)
1031        for node in nodes_to_plot:
1032            node_hist = histories[node]
1033            train_hist, val_hist = node_hist["train"], node_hist["validation"]
1034
1035            total_epochs = len(train_hist)
1036            start_idx = total_epochs - 1 if total_epochs < 5 else int(total_epochs * 0.9)
1037
1038            epochs = range(start_idx + 1, total_epochs + 1)
1039            plt.plot(epochs, train_hist[start_idx:], label=f"{node} - train", linestyle="--")
1040            plt.plot(epochs, val_hist[start_idx:], label=f"{node} - val")
1041
1042        plt.title("Training and Validation NLL - Last 10% of Epochs (or Last Epoch if <5)")
1043        plt.xlabel("Epoch")
1044        plt.ylabel("NLL")
1045        plt.legend()
1046        plt.grid(True)
1047
1048        plt.tight_layout()
1049        plt.show()
1050
    def plot_linear_shift_history(self, data_dict=None, node=None, ref_lines=None):
        """
        Plot the evolution of linear shift terms over epochs.

        Parameters
        ----------
        data_dict : dict or None, optional
            Pre-loaded mapping ``{node_name: pandas.DataFrame}`` containing shift
            weights across epochs. If None, `linear_shift_history()` is called.
        node : str or None, optional
            If provided, plot only this node. Otherwise, plot all nodes
            present in ``data_dict``.
        ref_lines : dict or None, optional
            Optional mapping ``{node_name: list of float}``. For each specified
            node, horizontal reference lines are drawn at the given values.

        Returns
        -------
        None

        Notes
        -----
        The function flattens nested list-like entries in the DataFrames to scalars,
        converts epoch labels to numeric, and then draws one line per shift term.
        """

        if data_dict is None:
            data_dict = self.linear_shift_history()
            if data_dict is None:
                raise ValueError("No shift history data provided or stored in the class.")

        nodes = [node] if node else list(data_dict.keys())

        for n in nodes:
            df = data_dict[n].copy()

            # Flatten nested lists or list-like cells
            def flatten(x):
                # Reduce a cell to one plottable scalar:
                #   []                  -> NaN
                #   [1.0, 2.0]          -> mean of the entries
                #   [[...], [...]]      -> mean over all inner entries
                #   [x] (non-numeric)   -> its single element
                #   any other list      -> NaN
                if isinstance(x, list):
                    if len(x) == 0:
                        return np.nan
                    if all(isinstance(i, (int, float)) for i in x):
                        return np.mean(x)  # average simple list
                    if all(isinstance(i, list) for i in x):
                        # nested list -> flatten inner and average
                        flat = [v for sub in x for v in (sub if isinstance(sub, list) else [sub])]
                        return np.mean(flat) if flat else np.nan
                    return x[0] if len(x) == 1 else np.nan
                return x

            # NOTE(review): DataFrame.applymap is deprecated since pandas 2.1
            # (DataFrame.map is the replacement) — confirm the supported pandas range.
            df = df.applymap(flatten)

            # Ensure numeric columns
            df = df.apply(pd.to_numeric, errors='coerce')

            # Convert epoch labels to numeric
            df.columns = [
                int(c.replace("epoch_", "")) if isinstance(c, str) and c.startswith("epoch_") else c
                for c in df.columns
            ]
            # Sort columns so epochs appear in chronological order on the x-axis.
            df = df.reindex(sorted(df.columns), axis=1)

            plt.figure(figsize=(10, 6))
            # One curve per shift parameter (one DataFrame row each).
            for idx in df.index:
                plt.plot(df.columns, df.loc[idx], lw=1.4, label=f"shift_{idx}")

            # Optional horizontal reference lines (e.g. known true values).
            if ref_lines and n in ref_lines:
                for v in ref_lines[n]:
                    plt.axhline(y=v, color="k", linestyle="--", lw=1.0)
                    plt.text(df.columns[-1], v, f"{n}: {v}", va="bottom", ha="right", fontsize=8)

            plt.xlabel("Epoch")
            plt.ylabel("Shift Value")
            plt.title(f"Shift Term History — Node: {n}")
            plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left", fontsize="small")
            plt.tight_layout()
            plt.show()
1128
1129    def plot_simple_intercepts_history(self, data_dict=None, node=None,ref_lines=None):
1130        """
1131        Plot the evolution of simple intercept weights over epochs.
1132
1133        Parameters
1134        ----------
1135        data_dict : dict or None, optional
1136            Pre-loaded mapping ``{node_name: pandas.DataFrame}`` containing intercept
1137            weights across epochs. If None, `simple_intercept_history()` is called.
1138        node : str or None, optional
1139            If provided, plot only this node. Otherwise, plot all nodes present
1140            in ``data_dict``.
1141        ref_lines : dict or None, optional
1142            Optional mapping ``{node_name: list of float}``. For each specified
1143            node, horizontal reference lines are drawn at the given values.
1144
1145        Returns
1146        -------
1147        None
1148
1149        Notes
1150        -----
1151        Nested list-like entries in the DataFrames are reduced to scalars before
1152        plotting. One line is drawn per intercept parameter.
1153        """
1154        if data_dict is None:
1155            data_dict = self.simple_intercept_history()
1156            if data_dict is None:
1157                raise ValueError("No intercept history data provided or stored in the class.")
1158
1159        nodes = [node] if node else list(data_dict.keys())
1160
1161        for n in nodes:
1162            df = data_dict[n].copy()
1163
1164            def extract_scalar(x):
1165                if isinstance(x, list):
1166                    while isinstance(x, list) and len(x) > 0:
1167                        x = x[0]
1168                return float(x) if isinstance(x, (int, float, np.floating)) else np.nan
1169
1170            df = df.applymap(extract_scalar)
1171
1172            # Convert epoch labels → numeric
1173            df.columns = [
1174                int(c.replace("epoch_", "")) if isinstance(c, str) and c.startswith("epoch_") else c
1175                for c in df.columns
1176            ]
1177            df = df.reindex(sorted(df.columns), axis=1)
1178
1179            plt.figure(figsize=(10, 6))
1180            for idx in df.index:
1181                plt.plot(df.columns, df.loc[idx], lw=1.4, label=f"theta_{idx}")
1182            
1183            if ref_lines and n in ref_lines:
1184                for v in ref_lines[n]:
1185                    plt.axhline(y=v, color="k", linestyle="--", lw=1.0)
1186                    plt.text(df.columns[-1], v, f"{n}: {v}", va="bottom", ha="right", fontsize=8)
1187                
1188            plt.xlabel("Epoch")
1189            plt.ylabel("Intercept Weight")
1190            plt.title(f"Simple Intercept Evolution — Node: {n}")
1191            plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left", fontsize="small")
1192            plt.tight_layout()
1193            plt.show()
1194
1195    def plot_latents(self, df, variable: str = None, confidence: float = 0.95, simulations: int = 1000):
1196        """
1197        Visualize latent U distributions for one or all nodes.
1198
1199        Parameters
1200        ----------
1201        df : pandas.DataFrame
1202            Input data frame with raw node values.
1203        variable : str or None, optional
1204            If provided, only this node's latents are plotted. If None, all
1205            nodes with latent columns are processed.
1206        confidence : float, optional
1207            Confidence level for QQ-plot bands (0 < confidence < 1).
1208            Default is 0.95.
1209        simulations : int, optional
1210            Number of Monte Carlo simulations for QQ-plot bands. Default is 1000.
1211
1212        Returns
1213        -------
1214        None
1215
1216        Notes
1217        -----
1218        For each node, two plots are produced:
1219        - Histogram of the latent U values.
1220        - QQ-plot with simulation-based confidence bands under a logistic reference.
1221        """
1222        # Compute latent representations
1223        latents_df = self.get_latent(df)
1224
1225        # Select nodes
1226        nodes = [variable] if variable is not None else self.nodes_dict.keys()
1227
1228        for node in nodes:
1229            if f"{node}_U" not in latents_df.columns:
1230                print(f"[WARNING] No latent found for node {node}, skipping.")
1231                continue
1232
1233            sample = latents_df[f"{node}_U"].values
1234
1235            # --- Create plots ---
1236            fig, axs = plt.subplots(1, 2, figsize=(12, 5))
1237
1238            # Histogram
1239            axs[0].hist(sample, bins=50, color="steelblue", alpha=0.7)
1240            axs[0].set_title(f"Latent Histogram ({node})")
1241            axs[0].set_xlabel("U")
1242            axs[0].set_ylabel("Frequency")
1243
1244            # QQ Plot with confidence bands
1245            probplot(sample, dist="logistic", plot=axs[1])
1246            self._add_r_style_confidence_bands(axs[1], sample, dist=logistic,confidence=confidence, simulations=simulations)
1247            axs[1].set_title(f"Latent QQ Plot ({node})")
1248
1249            plt.suptitle(f"Latent Diagnostics for Node: {node}", fontsize=14)
1250            plt.tight_layout(rect=[0, 0.03, 1, 0.95])
1251            plt.show()
1252
1253    def plot_hdag(self,df,variables=None, plot_n_rows=1,**kwargs):
1254        
1255        """
1256        Visualize the transformation function h() for selected DAG nodes.
1257
1258        Parameters
1259        ----------
1260        df : pandas.DataFrame
1261            Input data containing node values or model predictions.
1262        variables : list of str or None, optional
1263            Names of nodes to visualize. If None, all nodes in ``self.models``
1264            are considered.
1265        plot_n_rows : int, optional
1266            Maximum number of rows from ``df`` to visualize. Default is 1.
1267        **kwargs
1268            Additional keyword arguments forwarded to the underlying plotting
1269            helpers (`show_hdag_continous` / `show_hdag_ordinal`).
1270
1271        Returns
1272        -------
1273        None
1274
1275        Notes
1276        -----
1277        - For continuous outcomes, `show_hdag_continous` is called.
1278        - For ordinal outcomes, `show_hdag_ordinal` is called.
1279        - Nodes that are neither continuous nor ordinal are skipped with a warning.
1280        """
1281                
1282
1283        if len(df)> 1:
1284            print("[WARNING] len(df)>1, set: plot_n_rows accordingly")
1285        
1286        variables_list=variables if variables is not None else list(self.models.keys())
1287        for node in variables_list:
1288            if is_outcome_modelled_continous(node, self.nodes_dict):
1289                show_hdag_continous(df,node=node,configuration_dict=self.cfg.conf_dict,minmax_dict=self.minmax_dict,device=self.device,plot_n_rows=plot_n_rows,**kwargs)
1290            
1291            elif is_outcome_modelled_ordinal(node, self.nodes_dict):
1292                show_hdag_ordinal(df,node=node,configuration_dict=self.cfg.conf_dict,device=self.device,plot_n_rows=plot_n_rows,**kwargs)
1293                # plot_cutpoints_with_logistic(df,node=node,configuration_dict=self.cfg.conf_dict,device=self.device,plot_n_rows=plot_n_rows,**kwargs)
1294                # save_cutpoints_with_logistic(df,node=node,configuration_dict=self.cfg.conf_dict,device=self.device,**kwargs)
1295            else:
1296                print(f"[WARNING] Node {node} is wheter ordinal nor continous, not implemented yet")
1297    
1298    @staticmethod
1299    def _add_r_style_confidence_bands(ax, sample, dist, confidence=0.95, simulations=1000):
1300        """
1301        Add simulation-based confidence bands to a QQ-plot.
1302
1303        Parameters
1304        ----------
1305        ax : matplotlib.axes.Axes
1306            Axes object on which to draw the QQ-plot and bands.
1307        sample : array-like
1308            Empirical sample used in the QQ-plot.
1309        dist : scipy.stats distribution
1310            Distribution object providing ``ppf`` and ``rvs`` methods (e.g. logistic).
1311        confidence : float, optional
1312            Confidence level (0 < confidence < 1) for the bands. Default is 0.95.
1313        simulations : int, optional
1314            Number of Monte Carlo simulations used to estimate the bands. Default is 1000.
1315
1316        Returns
1317        -------
1318        None
1319
1320        Notes
1321        -----
1322        The axes are cleared, and a new QQ-plot is drawn with:
1323        - Empirical vs. theoretical quantiles.
1324        - 45-degree reference line.
1325        - Shaded confidence band region.
1326        """
1327        
1328        n = len(sample)
1329        if n == 0:
1330            return
1331
1332        quantiles = np.linspace(0, 1, n, endpoint=False) + 0.5 / n
1333        theo_q = dist.ppf(quantiles)
1334
1335        # Simulate order statistics from the theoretical distribution
1336        sim_data = dist.rvs(size=(simulations, n))
1337        sim_order_stats = np.sort(sim_data, axis=1)
1338
1339        # Confidence bands
1340        lower = np.percentile(sim_order_stats, 100 * (1 - confidence) / 2, axis=0)
1341        upper = np.percentile(sim_order_stats, 100 * (1 + confidence) / 2, axis=0)
1342
1343        # Sort empirical sample
1344        sample_sorted = np.sort(sample)
1345
1346        # Re-draw points and CI (overwrite probplot defaults)
1347        ax.clear()
1348        ax.plot(theo_q, sample_sorted, 'o', markersize=3, alpha=0.6, label="Empirical Q-Q")
1349        ax.plot(theo_q, theo_q, 'b--', label="y = x")
1350        ax.fill_between(theo_q, lower, upper, color='gray', alpha=0.3,
1351                        label=f'{int(confidence*100)}% CI')
1352        ax.legend()
1353    
1354    ## SAMPLING METHODS
1355    def sample(
1356        self,
1357        do_interventions: dict = None,
1358        predefined_latent_samples_df: pd.DataFrame = None,
1359        **kwargs,
1360    ):
1361        """
1362        Sample from the joint DAG using the trained TRAM models.
1363
1364        Allows for:
1365        
1366        Oberservational sampling
1367        Interventional sampling via ``do()`` operations
1368        Counterfactial sampling using predefined latent draws and do()
1369        
1370        Parameters
1371        ----------
1372        do_interventions : dict or None, optional
1373            Mapping of node names to intervened (fixed) values. For example:
1374            ``{"x1": 1.0}`` represents ``do(x1 = 1.0)``. Default is None.
1375        predefined_latent_samples_df : pandas.DataFrame or None, optional
1376            DataFrame containing columns ``"{node}_U"`` with predefined latent
1377            draws to be used instead of sampling from the prior. Default is None.
1378        **kwargs
1379            Sampling options overriding internal defaults:
1380
1381            number_of_samples : int, default 10000
1382                Total number of samples to draw.
1383            batch_size : int, default 32
1384                Batch size for internal sampling loops.
1385            delete_all_previously_sampled : bool, default True
1386                If True, delete old sampling files in node-specific sampling
1387                directories before writing new ones.
1388            verbose : bool
1389                If True, print informational messages.
1390            debug : bool
1391                If True, print debug output.
1392            device : {"auto", "cpu", "cuda"}
1393                Device selection for sampling.
1394            use_initial_weights_for_sampling : bool, default False
1395                If True, sample from initial (untrained) model parameters.
1396
1397        Returns
1398        -------
1399        tuple
1400            A tuple ``(sampled_by_node, latents_by_node)``:
1401
1402            sampled_by_node : dict
1403                Mapping ``{node_name: torch.Tensor}`` of sampled node values.
1404            latents_by_node : dict
1405                Mapping ``{node_name: torch.Tensor}`` of latent U values used.
1406
1407        Raises
1408        ------
1409        ValueError
1410            If the experiment directory cannot be resolved or if scaling
1411            information (``self.minmax_dict``) is missing.
1412        RuntimeError
1413            If min–max scaling has not been computed before calling `sample`.
1414        """
1415        try:
1416            EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
1417        except KeyError:
1418            raise ValueError(
1419                "[ERROR] Missing 'EXPERIMENT_DIR' in cfg.conf_dict['PATHS']. "
1420                "Sampling requires trained model checkpoints."
1421            )
1422
1423        # ---- defaults ----
1424        settings = {
1425            "number_of_samples": 10_000,
1426            "batch_size": 32,
1427            "delete_all_previously_sampled": True,
1428            "verbose": self.verbose if hasattr(self, "verbose") else False,
1429            "debug": self.debug if hasattr(self, "debug") else False,
1430            "device": self.device.type if hasattr(self, "device") else "auto",
1431            "use_initial_weights_for_sampling": False,
1432            
1433        }
1434        
1435        # self._validate_kwargs( kwargs, defaults_attr= "settings", context="sample")
1436        
1437        settings.update(kwargs)
1438
1439        
1440        if not hasattr(self, "minmax_dict"):
1441            raise RuntimeError(
1442                "[ERROR] minmax_dict not found. You must call .fit() or .load_or_compute_minmax() "
1443                "before sampling, so scaling info is available."
1444                )
1445            
1446        # ---- resolve device ----
1447        device_str=self.get_device(settings)
1448        self.device = torch.device(device_str)
1449
1450
1451        if self.debug or settings["debug"]:
1452            print(f"[DEBUG] sample(): device: {self.device}")
1453
1454        # ---- perform sampling ----
1455        sampled_by_node, latents_by_node = sample_full_dag(
1456            configuration_dict=self.cfg.conf_dict,
1457            EXPERIMENT_DIR=EXPERIMENT_DIR,
1458            device=self.device,
1459            do_interventions=do_interventions or {},
1460            predefined_latent_samples_df=predefined_latent_samples_df,
1461            number_of_samples=settings["number_of_samples"],
1462            batch_size=settings["batch_size"],
1463            delete_all_previously_sampled=settings["delete_all_previously_sampled"],
1464            verbose=settings["verbose"],
1465            debug=settings["debug"],
1466            minmax_dict=self.minmax_dict,
1467            use_initial_weights_for_sampling=settings["use_initial_weights_for_sampling"]
1468        )
1469
1470        return sampled_by_node, latents_by_node
1471
1472    def load_sampled_and_latents(self, EXPERIMENT_DIR: str = None, nodes: list = None):
1473        """
1474        Load previously stored sampled values and latents for each node.
1475
1476        Parameters
1477        ----------
1478        EXPERIMENT_DIR : str or None, optional
1479            Experiment directory path. If None, it is taken from
1480            ``self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]``.
1481        nodes : list of str or None, optional
1482            Nodes for which to load samples. If None, use all nodes from
1483            ``self.nodes_dict``.
1484
1485        Returns
1486        -------
1487        tuple
1488            A tuple ``(sampled_by_node, latents_by_node)``:
1489
1490            sampled_by_node : dict
1491                Mapping ``{node_name: torch.Tensor}`` of sampled values (on CPU).
1492            latents_by_node : dict
1493                Mapping ``{node_name: torch.Tensor}`` of latent values (on CPU).
1494
1495        Raises
1496        ------
1497        ValueError
1498            If the experiment directory cannot be resolved or if no node list
1499            is available and ``nodes`` is None.
1500
1501        Notes
1502        -----
1503        Nodes without both ``sampled.pt`` and ``latents.pt`` files are skipped
1504        with a warning.
1505        """
1506        # --- resolve paths and node list ---
1507        if EXPERIMENT_DIR is None:
1508            try:
1509                EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
1510            except (AttributeError, KeyError):
1511                raise ValueError(
1512                    "[ERROR] Could not resolve EXPERIMENT_DIR from cfg.conf_dict['PATHS']. "
1513                    "Please provide EXPERIMENT_DIR explicitly."
1514                )
1515
1516        if nodes is None:
1517            if hasattr(self, "nodes_dict"):
1518                nodes = list(self.nodes_dict.keys())
1519            else:
1520                raise ValueError(
1521                    "[ERROR] No node list found. Please provide `nodes` or initialize model with a config."
1522                )
1523
1524        # --- load tensors ---
1525        sampled_by_node = {}
1526        latents_by_node = {}
1527
1528        for node in nodes:
1529            node_dir = os.path.join(EXPERIMENT_DIR, f"{node}")
1530            sampling_dir = os.path.join(node_dir, "sampling")
1531
1532            sampled_path = os.path.join(sampling_dir, "sampled.pt")
1533            latents_path = os.path.join(sampling_dir, "latents.pt")
1534
1535            if not os.path.exists(sampled_path) or not os.path.exists(latents_path):
1536                print(f"[WARNING] Missing files for node '{node}' — skipping.")
1537                continue
1538
1539            try:
1540                sampled = torch.load(sampled_path, map_location="cpu")
1541                latent_sample = torch.load(latents_path, map_location="cpu")
1542            except Exception as e:
1543                print(f"[ERROR] Could not load sampling files for node '{node}': {e}")
1544                continue
1545
1546            sampled_by_node[node] = sampled.detach().cpu()
1547            latents_by_node[node] = latent_sample.detach().cpu()
1548
1549        if self.verbose or self.debug:
1550            print(f"[INFO] Loaded sampled and latent tensors for {len(sampled_by_node)} nodes from {EXPERIMENT_DIR}")
1551
1552        return sampled_by_node, latents_by_node
1553
1554    def plot_samples_vs_true(
1555        self,
1556        df,
1557        sampled: dict = None,
1558        variable: list = None,
1559        bins: int = 100,
1560        hist_true_color: str = "blue",
1561        hist_est_color: str = "orange",
1562        figsize: tuple = (14, 5),
1563    ):
1564        
1565        
1566        """
1567        Compare sampled vs. observed distributions for selected nodes.
1568
1569        Parameters
1570        ----------
1571        df : pandas.DataFrame
1572            Data frame containing the observed node values.
1573        sampled : dict or None, optional
1574            Optional mapping ``{node_name: array-like or torch.Tensor}`` of sampled
1575            values. If None or if a node is missing, samples are loaded from
1576            ``EXPERIMENT_DIR/{node}/sampling/sampled.pt``.
1577        variable : list of str or None, optional
1578            Subset of nodes to plot. If None, all nodes in the configuration
1579            are considered.
1580        bins : int, optional
1581            Number of histogram bins for continuous variables. Default is 100.
1582        hist_true_color : str, optional
1583            Color name for the histogram of true values. Default is "blue".
1584        hist_est_color : str, optional
1585            Color name for the histogram of sampled values. Default is "orange".
1586        figsize : tuple, optional
1587            Figure size for the matplotlib plots. Default is (14, 5).
1588
1589        Returns
1590        -------
1591        None
1592
1593        Notes
1594        -----
1595        - Continuous outcomes: histogram overlay + QQ-plot.
1596        - Ordinal outcomes: side-by-side bar plot of relative frequencies.
1597        - Other categorical outcomes: side-by-side bar plot with category labels.
1598        - If samples are probabilistic (2D tensor), the argmax across classes is used.
1599        """
1600        
1601        target_nodes = self.cfg.conf_dict["nodes"]
1602        experiment_dir = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
1603        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
1604
1605        plot_list = variable if variable is not None else target_nodes
1606
1607        for node in plot_list:
1608            # Load sampled data
1609            if sampled is not None and node in sampled:
1610                sdata = sampled[node]
1611                if isinstance(sdata, torch.Tensor):
1612                    sampled_vals = sdata.detach().cpu().numpy()
1613                else:
1614                    sampled_vals = np.asarray(sdata)
1615            else:
1616                sample_path = os.path.join(experiment_dir, f"{node}/sampling/sampled.pt")
1617                if not os.path.isfile(sample_path):
1618                    print(f"[WARNING] skip {node}: {sample_path} not found.")
1619                    continue
1620
1621                try:
1622                    sampled_vals = torch.load(sample_path, map_location=device).cpu().numpy()
1623                except Exception as e:
1624                    print(f"[ERROR] Could not load {sample_path}: {e}")
1625                    continue
1626
1627            # If logits/probabilities per sample, take argmax
1628            if sampled_vals.ndim == 2:
1629                    print(f"[INFO] CAUTION! {node}: samples are probabilistic — each sample follows a probability "
1630                    f"distribution based on the valid latent range. "
1631                    f"Note that this frequency plot reflects only the distribution of the most probable "
1632                    f"class per sample.")
1633                    sampled_vals = np.argmax(sampled_vals, axis=1)
1634
1635            sampled_vals = sampled_vals[np.isfinite(sampled_vals)]
1636
1637            if node not in df.columns:
1638                print(f"[WARNING] skip {node}: column not found in DataFrame.")
1639                continue
1640
1641            true_vals = df[node].dropna().values
1642            true_vals = true_vals[np.isfinite(true_vals)]
1643
1644            if sampled_vals.size == 0 or true_vals.size == 0:
1645                print(f"[WARNING] skip {node}: empty array after NaN/Inf removal.")
1646                continue
1647
1648            fig, axs = plt.subplots(1, 2, figsize=figsize)
1649
1650            if is_outcome_modelled_continous(node, target_nodes):
1651                axs[0].hist(true_vals, bins=bins, density=True, alpha=0.6,
1652                            color=hist_true_color, label=f"True {node}")
1653                axs[0].hist(sampled_vals, bins=bins, density=True, alpha=0.6,
1654                            color=hist_est_color, label="Sampled")
1655                axs[0].set_xlabel("Value")
1656                axs[0].set_ylabel("Density")
1657                axs[0].set_title(f"Histogram overlay for {node}")
1658                axs[0].legend()
1659                axs[0].grid(True, ls="--", alpha=0.4)
1660
1661                qqplot_2samples(true_vals, sampled_vals, line="45", ax=axs[1])
1662                axs[1].set_xlabel("True quantiles")
1663                axs[1].set_ylabel("Sampled quantiles")
1664                axs[1].set_title(f"QQ plot for {node}")
1665                axs[1].grid(True, ls="--", alpha=0.4)
1666
1667            elif is_outcome_modelled_ordinal(node, target_nodes):
1668                unique_vals = np.union1d(np.unique(true_vals), np.unique(sampled_vals))
1669                unique_vals = np.sort(unique_vals)
1670                true_counts = np.array([(true_vals == val).sum() for val in unique_vals])
1671                sampled_counts = np.array([(sampled_vals == val).sum() for val in unique_vals])
1672
1673                axs[0].bar(unique_vals - 0.2, true_counts / true_counts.sum(),
1674                        width=0.4, color=hist_true_color, alpha=0.7, label="True")
1675                axs[0].bar(unique_vals + 0.2, sampled_counts / sampled_counts.sum(),
1676                        width=0.4, color=hist_est_color, alpha=0.7, label="Sampled")
1677                axs[0].set_xticks(unique_vals)
1678                axs[0].set_xlabel("Ordinal Level")
1679                axs[0].set_ylabel("Relative Frequency")
1680                axs[0].set_title(f"Ordinal bar plot for {node}")
1681                axs[0].legend()
1682                axs[0].grid(True, ls="--", alpha=0.4)
1683                axs[1].axis("off")
1684
1685            else:
1686                unique_vals = np.union1d(np.unique(true_vals), np.unique(sampled_vals))
1687                unique_vals = sorted(unique_vals, key=str)
1688                true_counts = np.array([(true_vals == val).sum() for val in unique_vals])
1689                sampled_counts = np.array([(sampled_vals == val).sum() for val in unique_vals])
1690
1691                axs[0].bar(np.arange(len(unique_vals)) - 0.2, true_counts / true_counts.sum(),
1692                        width=0.4, color=hist_true_color, alpha=0.7, label="True")
1693                axs[0].bar(np.arange(len(unique_vals)) + 0.2, sampled_counts / sampled_counts.sum(),
1694                        width=0.4, color=hist_est_color, alpha=0.7, label="Sampled")
1695                axs[0].set_xticks(np.arange(len(unique_vals)))
1696                axs[0].set_xticklabels(unique_vals, rotation=45)
1697                axs[0].set_xlabel("Category")
1698                axs[0].set_ylabel("Relative Frequency")
1699                axs[0].set_title(f"Categorical bar plot for {node}")
1700                axs[0].legend()
1701                axs[0].grid(True, ls="--", alpha=0.4)
1702                axs[1].axis("off")
1703
1704            plt.tight_layout()
1705            plt.show()
1706
1707    ## SUMMARY METHODS
1708    def nll(self,data,variables=None):
1709        """
1710        Compute the Negative Log-Likelihood (NLL) for all or selected TRAM nodes.
1711
1712        This function evaluates trained TRAM models for each specified variable (node) 
1713        on the provided dataset. It performs forward passes only—no training, no weight 
1714        updates—and returns the mean NLL per node.
1715
1716        Parameters
1717        ----------
1718        data : object
1719            Input dataset or data source compatible with `_ensure_dataset`, containing 
1720            both inputs and targets for each node.
1721        variables : list[str], optional
1722            List of variable (node) names to evaluate. If None, all nodes in 
1723            `self.models` are evaluated.
1724
1725        Returns
1726        -------
1727        dict[str, float]
1728            Dictionary mapping each node name to its average NLL value.
1729
1730        Notes
1731        -----
1732        - Each model is evaluated independently on its respective DataLoader.
1733        - The normalization values (`min_max`) for each node are retrieved from 
1734          `self.minmax_dict[node]`.
1735        - The function uses `evaluate_tramdag_model()` for per-node evaluation.
1736        - Expected directory structure:
1737              `<EXPERIMENT_DIR>/<node>/`
1738          where each node directory contains the trained model.
1739        """
1740
1741        td_data = self._ensure_dataset(data, is_val=True)  
1742        variables_list = variables if variables != None else list(self.models.keys())
1743        nll_dict = {}
1744        for node in variables_list:  
1745                min_vals = torch.tensor(self.minmax_dict[node][0], dtype=torch.float32)
1746                max_vals = torch.tensor(self.minmax_dict[node][1], dtype=torch.float32)
1747                min_max = torch.stack([min_vals, max_vals], dim=0)
1748                data_loader = td_data.loaders[node]
1749                model = self.models[node]
1750                EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
1751                NODE_DIR = os.path.join(EXPERIMENT_DIR, f"{node}")
1752                nll= evaluate_tramdag_model(node=node,
1753                                            target_nodes=self.nodes_dict,
1754                                            NODE_DIR=NODE_DIR,
1755                                            tram_model=model,
1756                                            data_loader=data_loader,
1757                                            min_max=min_max)
1758                nll_dict[node]=nll
1759        return nll_dict
1760    
1761    def get_train_val_nll(self, node: str, mode: str) -> tuple[float, float]:
1762        """
1763        Retrieve training and validation NLL for a node and a given model state.
1764
1765        Parameters
1766        ----------
1767        node : str
1768            Node name.
1769        mode : {"best", "last", "init"}
1770            State of interest:
1771            - "best": epoch with lowest validation NLL.
1772            - "last": final epoch.
1773            - "init": first epoch (index 0).
1774
1775        Returns
1776        -------
1777        tuple of (float or None, float or None)
1778            A tuple ``(train_nll, val_nll)`` for the requested mode.
1779            Returns ``(None, None)`` if loss files are missing or cannot be read.
1780
1781        Notes
1782        -----
1783        This method expects per-node JSON files:
1784
1785        - ``train_loss_hist.json``
1786        - ``val_loss_hist.json``
1787
1788        in the node directory.
1789        """
1790        NODE_DIR = os.path.join(self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"], node)
1791        train_path = os.path.join(NODE_DIR, "train_loss_hist.json")
1792        val_path = os.path.join(NODE_DIR, "val_loss_hist.json")
1793
1794        if not os.path.exists(train_path) or not os.path.exists(val_path):
1795            if getattr(self, "debug", False):
1796                print(f"[DEBUG] Missing loss files for node '{node}'. Returning None.")
1797            return None, None
1798
1799        try:
1800            with open(train_path, "r") as f:
1801                train_hist = json.load(f)
1802            with open(val_path, "r") as f:
1803                val_hist = json.load(f)
1804
1805            train_nlls = np.array(train_hist)
1806            val_nlls = np.array(val_hist)
1807
1808            if mode == "init":
1809                idx = 0
1810            elif mode == "last":
1811                idx = len(val_nlls) - 1
1812            elif mode == "best":
1813                idx = int(np.argmin(val_nlls))
1814            else:
1815                raise ValueError(f"Invalid mode '{mode}' — must be one of 'best', 'last', 'init'.")
1816
1817            train_nll = float(train_nlls[idx])
1818            val_nll = float(val_nlls[idx])
1819            return train_nll, val_nll
1820
1821        except Exception as e:
1822            print(f"[ERROR] Failed to load NLLs for node '{node}' ({mode}): {e}")
1823            return None, None
1824
1825    def get_thetas(self, node: str, state: str = "best"):
1826        """
1827        Return transformed intercept (theta) parameters for a node and state.
1828
1829        Parameters
1830        ----------
1831        node : str
1832            Node name.
1833        state : {"best", "last", "init"}, optional
1834            Model state for which to return parameters. Default is "best".
1835
1836        Returns
1837        -------
1838        Any or None
1839            Transformed theta parameters for the requested node and state.
1840            The exact structure (scalar, list, or other) depends on the model.
1841
1842        Raises
1843        ------
1844        ValueError
1845            If an invalid state is given (not in {"best", "last", "init"}).
1846
1847        Notes
1848        -----
1849        Intercept dictionaries are cached on the instance under the attribute
1850        ``intercept_dicts``. If missing or incomplete, they are recomputed using
1851        `get_simple_intercepts_dict`.
1852        """
1853
1854        state = state.lower()
1855        if state not in ["best", "last", "init"]:
1856            raise ValueError(f"[ERROR] Invalid state '{state}'. Must be one of ['best', 'last', 'init'].")
1857
1858        dict_attr = "intercept_dicts"
1859
1860        # If no cached intercepts exist, compute them
1861        if not hasattr(self, dict_attr):
1862            if getattr(self, "debug", False):
1863                print(f"[DEBUG] '{dict_attr}' not found, computing via get_simple_intercepts_dict().")
1864            setattr(self, dict_attr, self.get_simple_intercepts_dict())
1865
1866        all_dicts = getattr(self, dict_attr)
1867
1868        # If the requested state isn’t cached, recompute
1869        if state not in all_dicts:
1870            if getattr(self, "debug", False):
1871                print(f"[DEBUG] State '{state}' not found in cached intercepts, recomputing full dict.")
1872            setattr(self, dict_attr, self.get_simple_intercepts_dict())
1873            all_dicts = getattr(self, dict_attr)
1874
1875        state_dict = all_dicts.get(state, {})
1876
1877        # Return cached node intercept if present
1878        if node in state_dict:
1879            return state_dict[node]
1880
1881        # If not found, recompute full dict as fallback
1882        if getattr(self, "debug", False):
1883            print(f"[DEBUG] Node '{node}' not found in state '{state}', recomputing full dict.")
1884        setattr(self, dict_attr, self.get_simple_intercepts_dict())
1885        all_dicts = getattr(self, dict_attr)
1886        return all_dicts.get(state, {}).get(node, None)
1887        
1888    def get_linear_shifts(self, node: str, state: str = "best"):
1889        """
1890        Return learned linear shift terms for a node and a given state.
1891
1892        Parameters
1893        ----------
1894        node : str
1895            Node name.
1896        state : {"best", "last", "init"}, optional
1897            Model state for which to return linear shift terms. Default is "best".
1898
1899        Returns
1900        -------
1901        dict or Any or None
1902            Linear shift terms for the given node and state. Usually a dict
1903            mapping term names to weights.
1904
1905        Raises
1906        ------
1907        ValueError
1908            If an invalid state is given (not in {"best", "last", "init"}).
1909
1910        Notes
1911        -----
1912        Linear shift dictionaries are cached on the instance under the attribute
1913        ``linear_shift_dicts``. If missing or incomplete, they are recomputed using
1914        `get_linear_shifts_dict`.
1915        """
1916        state = state.lower()
1917        if state not in ["best", "last", "init"]:
1918            raise ValueError(f"[ERROR] Invalid state '{state}'. Must be one of ['best', 'last', 'init'].")
1919
1920        dict_attr = "linear_shift_dicts"
1921
1922        # If no global dicts cached, compute once
1923        if not hasattr(self, dict_attr):
1924            if getattr(self, "debug", False):
1925                print(f"[DEBUG] '{dict_attr}' not found, computing via get_linear_shifts_dict().")
1926            setattr(self, dict_attr, self.get_linear_shifts_dict())
1927
1928        all_dicts = getattr(self, dict_attr)
1929
1930        # If the requested state isn't cached, compute all again (covers fresh runs)
1931        if state not in all_dicts:
1932            if getattr(self, "debug", False):
1933                print(f"[DEBUG] State '{state}' not found in cached linear shifts, recomputing full dict.")
1934            setattr(self, dict_attr, self.get_linear_shifts_dict())
1935            all_dicts = getattr(self, dict_attr)
1936
1937        # Now fetch the dictionary for this state
1938        state_dict = all_dicts.get(state, {})
1939
1940        # If the node is available, return its entry
1941        if node in state_dict:
1942            return state_dict[node]
1943
1944        # If missing, try recomputing (fallback)
1945        if getattr(self, "debug", False):
1946            print(f"[DEBUG] Node '{node}' not found in state '{state}', recomputing full dict.")
1947        setattr(self, dict_attr, self.get_linear_shifts_dict())
1948        all_dicts = getattr(self, dict_attr)
1949        return all_dicts.get(state, {}).get(node, None)
1950
1951    def get_linear_shifts_dict(self):
1952        """
1953        Compute linear shift term dictionaries for all nodes and states.
1954
1955        For each node and each available state ("best", "last", "init"), this
1956        method loads the corresponding model checkpoint, extracts linear shift
1957        weights from the TRAM model, and stores them in a nested dictionary.
1958
1959        Returns
1960        -------
1961        dict
1962            Nested dictionary of the form:
1963
1964            .. code-block:: python
1965
1966                {
1967                    "best": {node: {...}},
1968                    "last": {node: {...}},
1969                    "init": {node: {...}},
1970                }
1971
1972            where the innermost dict maps term labels (e.g. ``"ls(parent_name)"``)
1973            to their weights.
1974
1975        Notes
1976        -----
1977        - If "best" or "last" checkpoints are unavailable for a node, only
1978        the "init" entry is populated.
1979        - Empty outer states (without any nodes) are removed from the result.
1980        """
1981
1982        EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
1983        nodes_list = list(self.models.keys())
1984        all_states = ["best", "last", "init"]
1985        all_linear_shift_dicts = {state: {} for state in all_states}
1986
1987        for node in nodes_list:
1988            NODE_DIR = os.path.join(EXPERIMENT_DIR, node)
1989            BEST_MODEL_PATH, LAST_MODEL_PATH, _, _ = model_train_val_paths(NODE_DIR)
1990            INIT_MODEL_PATH = os.path.join(NODE_DIR, "initial_model.pt")
1991
1992            state_paths = {
1993                "best": BEST_MODEL_PATH,
1994                "last": LAST_MODEL_PATH,
1995                "init": INIT_MODEL_PATH,
1996            }
1997
1998            for state, LOAD_PATH in state_paths.items():
1999                if not os.path.exists(LOAD_PATH):
2000                    if state != "init":
2001                        # skip best/last if unavailable
2002                        continue
2003                    else:
2004                        print(f"[WARNING] No models found for node '{node}'. Only initial model will be used.")
2005                        if not os.path.exists(LOAD_PATH):
2006                            if getattr(self, "debug", False):
2007                                print(f"[DEBUG] Initial model also missing for node '{node}'. Skipping.")
2008                            continue
2009
2010                # Load parents and model
2011                _, terms_dict, _ = ordered_parents(node, self.nodes_dict)
2012                state_dict = torch.load(LOAD_PATH, map_location=self.device)
2013                tram_model = self.models[node]
2014                tram_model.load_state_dict(state_dict)
2015
2016                epoch_weights = {}
2017                if hasattr(tram_model, "nn_shift") and tram_model.nn_shift is not None:
2018                    for i, shift_layer in enumerate(tram_model.nn_shift):
2019                        module_name = shift_layer.__class__.__name__
2020                        if (
2021                            hasattr(shift_layer, "fc")
2022                            and hasattr(shift_layer.fc, "weight")
2023                            and module_name == "LinearShift"
2024                        ):
2025                            term_name = list(terms_dict.keys())[i]
2026                            epoch_weights[f"ls({term_name})"] = (
2027                                shift_layer.fc.weight.detach().cpu().squeeze().tolist()
2028                            )
2029                        elif getattr(self, "debug", False):
2030                            term_name = list(terms_dict.keys())[i]
2031                            print(f"[DEBUG] ls({term_name}): missing 'fc' or 'weight' in LinearShift.")
2032                else:
2033                    if getattr(self, "debug", False):
2034                        print(f"[DEBUG] Tram model for node '{node}' has no nn_shift or it is None.")
2035
2036                all_linear_shift_dicts[state][node] = epoch_weights
2037
2038        # Remove empty states (e.g., when best/last not found for all nodes)
2039        all_linear_shift_dicts = {k: v for k, v in all_linear_shift_dicts.items() if v}
2040
2041        return all_linear_shift_dicts
2042
2043    def get_simple_intercepts_dict(self):
2044        """
2045        Compute transformed simple intercept dictionaries for all nodes and states.
2046
2047        For each node and each available state ("best", "last", "init"), this
2048        method loads the corresponding model checkpoint, extracts simple intercept
2049        weights, transforms them into interpretable theta parameters, and stores
2050        them in a nested dictionary.
2051
2052        Returns
2053        -------
2054        dict
2055            Nested dictionary of the form:
2056
2057            .. code-block:: python
2058
2059                {
2060                    "best": {node: [[theta_1], [theta_2], ...]},
2061                    "last": {node: [[theta_1], [theta_2], ...]},
2062                    "init": {node: [[theta_1], [theta_2], ...]},
2063                }
2064
2065        Notes
2066        -----
2067        - For ordinal models (``self.is_ontram == True``), `transform_intercepts_ordinal`
2068        is used.
2069        - For continuous models, `transform_intercepts_continous` is used.
2070        - Empty outer states (without any nodes) are removed from the result.
2071        """
2072
2073        EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
2074        nodes_list = list(self.models.keys())
2075        all_states = ["best", "last", "init"]
2076        all_si_intercept_dicts = {state: {} for state in all_states}
2077
2078        debug = getattr(self, "debug", False)
2079        verbose = getattr(self, "verbose", False)
2080        is_ontram = getattr(self, "is_ontram", False)
2081
2082        for node in nodes_list:
2083            NODE_DIR = os.path.join(EXPERIMENT_DIR, node)
2084            BEST_MODEL_PATH, LAST_MODEL_PATH, _, _ = model_train_val_paths(NODE_DIR)
2085            INIT_MODEL_PATH = os.path.join(NODE_DIR, "initial_model.pt")
2086
2087            state_paths = {
2088                "best": BEST_MODEL_PATH,
2089                "last": LAST_MODEL_PATH,
2090                "init": INIT_MODEL_PATH,
2091            }
2092
2093            for state, LOAD_PATH in state_paths.items():
2094                if not os.path.exists(LOAD_PATH):
2095                    if state != "init":
2096                        continue
2097                    else:
2098                        print(f"[WARNING] No models found for node '{node}'. Only initial model will be used.")
2099                        if not os.path.exists(LOAD_PATH):
2100                            if debug:
2101                                print(f"[DEBUG] Initial model also missing for node '{node}'. Skipping.")
2102                            continue
2103
2104                # Load model state
2105                state_dict = torch.load(LOAD_PATH, map_location=self.device)
2106                tram_model = self.models[node]
2107                tram_model.load_state_dict(state_dict)
2108
2109                # Extract and transform simple intercept weights
2110                si_weights = None
2111                if hasattr(tram_model, "nn_int") and tram_model.nn_int is not None and isinstance(tram_model.nn_int, SimpleIntercept):
2112                    if hasattr(tram_model.nn_int, "fc") and hasattr(tram_model.nn_int.fc, "weight"):
2113                        weights = tram_model.nn_int.fc.weight.detach().cpu().tolist()
2114                        weights_tensor = torch.Tensor(weights)
2115
2116                        if debug:
2117                            print(f"[DEBUG] Node '{node}' ({state}) theta tilde shape: {weights_tensor.shape}")
2118
2119                        if is_ontram:
2120                            si_weights = transform_intercepts_ordinal(weights_tensor.reshape(1, -1))[:, 1:-1].reshape(-1, 1)
2121                        else:
2122                            si_weights = transform_intercepts_continous(weights_tensor.reshape(1, -1)).reshape(-1, 1)
2123
2124                        si_weights = si_weights.tolist()
2125
2126                        if debug:
2127                            print(f"[DEBUG] Node '{node}' ({state}) theta transformed: {si_weights}")
2128                    else:
2129                        if debug:
2130                            print(f"[DEBUG] Node '{node}' ({state}): missing 'fc' or 'weight' in SimpleIntercept.")
2131                else:
2132                    if debug:
2133                        print(f"[DEBUG] Tram model for node '{node}' has no nn_int or it is None.")
2134
2135                all_si_intercept_dicts[state][node] = si_weights
2136
2137        # Clean up empty states
2138        all_si_intercept_dicts = {k: v for k, v in all_si_intercept_dicts.items() if v}
2139        return all_si_intercept_dicts
2140       
2141    def summary(self, verbose=False):
2142        """
2143        Print a multi-part textual summary of the TramDagModel.
2144
2145        The summary includes:
2146        1. Training metrics overview per node (best/last NLL, epochs).
2147        2. Node-specific details (thetas, linear shifts, optional architecture).
2148        3. Basic information about the attached training DataFrame, if present.
2149
2150        Parameters
2151        ----------
2152        verbose : bool, optional
2153            If True, include extended per-node details such as the model
2154            architecture, parameter count, and availability of checkpoints
2155            and sampling results. Default is False.
2156
2157        Returns
2158        -------
2159        None
2160
2161        Notes
2162        -----
2163        This method prints to stdout and does not return structured data.
2164        It is intended for quick, human-readable inspection of the current
2165        training and model state.
2166        """
2167
2168        # ---------- SETUP ----------
2169        try:
2170            EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
2171        except KeyError:
2172            EXPERIMENT_DIR = None
2173            print("[WARNING] Missing EXPERIMENT_DIR in cfg.conf_dict['PATHS'].")
2174
2175        print("\n" + "=" * 120)
2176        print(f"{'TRAM DAG MODEL SUMMARY':^120}")
2177        print("=" * 120)
2178
2179        # ---------- METRICS OVERVIEW ----------
2180        summary_data = []
2181        for node in self.models.keys():
2182            node_dir = os.path.join(self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"], node)
2183            train_path = os.path.join(node_dir, "train_loss_hist.json")
2184            val_path = os.path.join(node_dir, "val_loss_hist.json")
2185
2186            if os.path.exists(train_path) and os.path.exists(val_path):
2187                best_train_nll, best_val_nll = self.get_train_val_nll(node, "best")
2188                last_train_nll, last_val_nll = self.get_train_val_nll(node, "last")
2189                n_epochs_total = len(json.load(open(train_path)))
2190            else:
2191                best_train_nll = best_val_nll = last_train_nll = last_val_nll = None
2192                n_epochs_total = 0
2193
2194            summary_data.append({
2195                "Node": node,
2196                "Best Train NLL": best_train_nll,
2197                "Best Val NLL": best_val_nll,
2198                "Last Train NLL": last_train_nll,
2199                "Last Val NLL": last_val_nll,
2200                "Epochs": n_epochs_total,
2201            })
2202
2203        df_summary = pd.DataFrame(summary_data)
2204        df_summary = df_summary.round(4)
2205
2206        print("\n[1] TRAINING METRICS OVERVIEW")
2207        print("-" * 120)
2208        if not df_summary.empty:
2209            print(
2210                df_summary.to_string(
2211                    index=False,
2212                    justify="center",
2213                    col_space=14,
2214                    float_format=lambda x: f"{x:7.4f}",
2215                )
2216            )
2217        else:
2218            print("No training history found for any node.")
2219        print("-" * 120)
2220
2221        # ---------- NODE DETAILS ----------
2222        print("\n[2] NODE-SPECIFIC DETAILS")
2223        print("-" * 120)
2224        for node in self.models.keys():
2225            print(f"\n{f'NODE: {node}':^120}")
2226            print("-" * 120)
2227
2228            # THETAS & SHIFTS
2229            for state in ["init", "last", "best"]:
2230                print(f"\n  [{state.upper()} STATE]")
2231
2232                # ---- Thetas ----
2233                try:
2234                    thetas = getattr(self, "get_thetas", lambda n, s=None: None)(node, state)
2235                    if thetas is not None:
2236                        if isinstance(thetas, (list, np.ndarray, pd.Series)):
2237                            thetas_flat = np.array(thetas).flatten()
2238                            compact = np.round(thetas_flat, 4)
2239                            arr_str = np.array2string(
2240                                compact,
2241                                max_line_width=110,
2242                                threshold=np.inf,
2243                                separator=", "
2244                            )
2245                            lines = arr_str.split("\n")
2246                            if len(lines) > 2:
2247                                arr_str = "\n".join(lines[:2]) + " ..."
2248                            print(f"    Θ ({len(thetas_flat)}): {arr_str}")
2249                        elif isinstance(thetas, dict):
2250                            for k, v in thetas.items():
2251                                print(f"     Θ[{k}]: {v}")
2252                        else:
2253                            print(f"    Θ: {thetas}")
2254                    else:
2255                        print("    Θ: not available")
2256                except Exception as e:
2257                    print(f"    [Error loading thetas] {e}")
2258
2259                # ---- Linear Shifts ----
2260                try:
2261                    linear_shifts = getattr(self, "get_linear_shifts", lambda n, s=None: None)(node, state)
2262                    if linear_shifts is not None:
2263                        if isinstance(linear_shifts, dict):
2264                            for k, v in linear_shifts.items():
2265                                print(f"     {k}: {np.round(v, 4)}")
2266                        elif isinstance(linear_shifts, (list, np.ndarray, pd.Series)):
2267                            arr = np.round(linear_shifts, 4)
2268                            print(f"    Linear shifts ({len(arr)}): {arr}")
2269                        else:
2270                            print(f"    Linear shifts: {linear_shifts}")
2271                    else:
2272                        print("    Linear shifts: not available")
2273                except Exception as e:
2274                    print(f"    [Error loading linear shifts] {e}")
2275
2276            # ---- Verbose info directly below node ----
2277            if verbose:
2278                print("\n  [DETAILS]")
2279                node_dir = os.path.join(EXPERIMENT_DIR, node) if EXPERIMENT_DIR else None
2280                model = self.models[node]
2281
2282                print(f"    Model Architecture:")
2283                arch_str = str(model).split("\n")
2284                for line in arch_str:
2285                    print(f"      {line}")
2286                print(f"    Parameter count: {sum(p.numel() for p in model.parameters()):,}")
2287
2288                if node_dir and os.path.exists(node_dir):
2289                    ckpt_exists = any(f.endswith(('.pt', '.pth')) for f in os.listdir(node_dir))
2290                    print(f"    Checkpoints found: {ckpt_exists}")
2291
2292                    sampling_dir = os.path.join(node_dir, "sampling")
2293                    sampling_exists = os.path.isdir(sampling_dir) and len(os.listdir(sampling_dir)) > 0
2294                    print(f"    Sampling results found: {sampling_exists}")
2295
2296                    for label, filename in [("Train", "train_loss_hist.json"), ("Validation", "val_loss_hist.json")]:
2297                        path = os.path.join(node_dir, filename)
2298                        if os.path.exists(path):
2299                            try:
2300                                with open(path, "r") as f:
2301                                    hist = json.load(f)
2302                                print(f"    {label} history: {len(hist)} epochs")
2303                            except Exception as e:
2304                                print(f"    {label} history: failed to load ({e})")
2305                        else:
2306                            print(f"    {label} history: not found")
2307                else:
2308                    print("    [INFO] No experiment directory defined or missing for this node.")
2309            print("-" * 120)
2310
2311        # ---------- TRAINING DATAFRAME ----------
2312        print("\n[3] TRAINING DATAFRAME")
2313        print("-" * 120)
2314        try:
2315            self.train_df.info()
2316        except AttributeError:
2317            print("No training DataFrame attached to this TramDagModel.")
2318        print("=" * 120 + "\n")
class TramDagModel:
  64class TramDagModel:
  65    """
  66    Probabilistic DAG model built from node-wise TRAMs (transformation models).
  67
  68    This class manages:
  69    - Configuration and per-node model construction.
  70    - Data scaling (min–max).
  71    - Training (sequential or per-node parallel on CPU).
  72    - Diagnostics (loss history, intercepts, linear shifts, latents).
  73    - Sampling from the joint DAG and loading stored samples.
  74    - High-level summaries and plotting utilities.
  75    """
  76    
    # ---- defaults used at construction time ----
    # Allowed keyword arguments for `from_config`; unknown keys are rejected
    # by `_validate_kwargs`. Scalar values apply to all nodes; a dict value
    # may map node names to per-node overrides.
    DEFAULTS_CONFIG = {
        "set_initial_weights": False,
        "debug":False,
        "verbose": False,
        "device":'auto',
        "initial_data":None,
        "overwrite_initial_weights": True,
    }

    # ---- defaults used at fit() time ----
    # Allowed keyword arguments for `fit`; unknown keys are rejected by
    # `_validate_kwargs`.
    DEFAULTS_FIT = {
        "epochs": 100,
        "train_list": None,
        "callbacks": None,
        "learning_rate": 0.01,
        "device": "auto",
        "optimizers": None,
        "schedulers": None,
        "use_scheduler": False,
        "save_linear_shifts": True,
        "save_simple_intercepts": True,
        "debug":False,
        "verbose": True,
        "train_mode": "sequential",  # or "parallel"
        "return_history": False,
        # NOTE(review): key is misspelled ("inital" vs the correctly spelled
        # "overwrite_initial_weights" in DEFAULTS_CONFIG) and is not read in
        # the visible fit() code — confirm before renaming, since callers may
        # already pass the misspelled kwarg and renaming would reject them.
        "overwrite_inital_weights": True,
        "num_workers" : 4,
        "persistent_workers" : True,
        "prefetch_factor" : 4,
        "batch_size":1000,

    }
 110
 111    def __init__(self):
 112        """
 113        Initialize an empty TramDagModel shell.
 114
 115        Notes
 116        -----
 117        This constructor does not build any node models and does not attach a
 118        configuration. Use `TramDagModel.from_config` or `TramDagModel.from_directory`
 119        to obtain a fully configured and ready-to-use instance.
 120        """
 121        
 122        self.debug = False
 123        self.verbose = False
 124        self.device = 'auto'
 125        pass
 126
 127    @staticmethod
 128    def get_device(settings):
 129        """
 130        Resolve the target device string from a settings dictionary.
 131
 132        Parameters
 133        ----------
 134        settings : dict
 135            Dictionary containing at least a key ``"device"`` with one of
 136            {"auto", "cpu", "cuda"}. If missing, "auto" is assumed.
 137
 138        Returns
 139        -------
 140        str
 141            Device string, either "cpu" or "cuda".
 142
 143        Notes
 144        -----
 145        If ``device == "auto"``, CUDA is selected if available, otherwise CPU.
 146        """
 147        device_arg = settings.get("device", "auto")
 148        if device_arg == "auto":
 149            device_str = "cuda" if torch.cuda.is_available() else "cpu"
 150        else:
 151            device_str = device_arg
 152        return device_str
 153
 154    def _validate_kwargs(self, kwargs: dict, defaults_attr: str = "DEFAULTS_FIT", context: str = None):
 155        """
 156        Validate a kwargs dictionary against a class-level defaults dictionary.
 157
 158        Parameters
 159        ----------
 160        kwargs : dict
 161            Keyword arguments to validate.
 162        defaults_attr : str, optional
 163            Name of the attribute on this class that contains the allowed keys,
 164            e.g. ``"DEFAULTS_CONFIG"`` or ``"DEFAULTS_FIT"``. Default is "DEFAULTS_FIT".
 165        context : str or None, optional
 166            Optional label (e.g. caller name) to prepend in error messages.
 167
 168        Raises
 169        ------
 170        AttributeError
 171            If the attribute named by ``defaults_attr`` does not exist.
 172        ValueError
 173            If any key in ``kwargs`` is not present in the corresponding defaults dict.
 174        """
 175        defaults = getattr(self, defaults_attr, None)
 176        if defaults is None:
 177            raise AttributeError(f"{self.__class__.__name__} has no attribute '{defaults_attr}'")
 178
 179        unknown = set(kwargs) - set(defaults)
 180        if unknown:
 181            prefix = f"[{context}] " if context else ""
 182            raise ValueError(f"{prefix}Unknown parameter(s): {', '.join(sorted(unknown))}")
 183            
    ## CREATE A TRAMDAGMODEL
    @classmethod
    def from_config(cls, cfg, **kwargs):
        """
        Construct a TramDagModel from a TramDagConfig object.

        This builds one TRAM model per node in the DAG and optionally writes
        the initial model parameters to disk.

        Parameters
        ----------
        cfg : TramDagConfig
            Configuration wrapper holding the underlying configuration dictionary,
            including at least:
            - ``conf_dict["nodes"]``: mapping of node names to node configs.
            - ``conf_dict["PATHS"]["EXPERIMENT_DIR"]``: experiment directory.
        **kwargs
            Node-level construction options. Each key must be present in
            ``DEFAULTS_CONFIG``. Values can be:
            - scalar: applied to all nodes.
            - dict: mapping ``{node_name: value}`` for per-node overrides.

            Common keys include:
            device : {"auto", "cpu", "cuda"}, default "auto"
                Device selection (CUDA if available when "auto").
            debug : bool, default False
                If True, print debug messages.
            verbose : bool, default False
                If True, print informational messages.
            set_initial_weights : bool
                Passed to underlying TRAM model constructors.
            overwrite_initial_weights : bool, default True
                If True, overwrite any existing ``initial_model.pt`` files per node.
            initial_data : Any
                Optional object passed down to node constructors.

        Returns
        -------
        TramDagModel
            Fully initialized instance with:
            - ``cfg``
            - ``nodes_dict``
            - ``models`` (per-node TRAMs)
            - ``settings`` (resolved per-node config)

        Raises
        ------
        ValueError
            If any dict-valued kwarg does not provide values for exactly the set
            of nodes in ``cfg.conf_dict["nodes"]``.
        """

        # The new instance being built; named `self` so the rest of the body
        # reads like ordinary instance code.
        self = cls()
        self.cfg = cfg
        self.cfg.update()  # ensure latest version from disk
        self.cfg._verify_completeness()


        try:
            self.cfg.save()  # persist back to disk
            # NOTE(review): at this point `debug` still holds the __init__
            # default (False) — the kwargs-derived flag is only set further
            # below, so this debug print can never reflect a user-passed flag.
            if getattr(self, "debug", False):
                print("[DEBUG] Configuration updated and saved.")
        except Exception as e:
            # Best-effort persistence: a failed save is reported but not fatal.
            print(f"[WARNING] Could not save configuration after update: {e}")

        self.nodes_dict = self.cfg.conf_dict["nodes"]

        self._validate_kwargs(kwargs, defaults_attr='DEFAULTS_CONFIG', context="from_config")

        # update defaults with kwargs
        settings = dict(cls.DEFAULTS_CONFIG)
        settings.update(kwargs)

        # resolve device ("auto" -> CUDA when available, else CPU)
        device_arg = settings.get("device", "auto")
        if device_arg == "auto":
            device_str = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            device_str = device_arg
        self.device = torch.device(device_str)

        # set flags on the instance so they are accessible later
        self.debug = settings.get("debug", False)
        self.verbose = settings.get("verbose", False)

        if  self.debug:
            print(f"[DEBUG] TramDagModel using device: {self.device}")

        # initialize settings storage: settings[option][node] -> resolved value
        self.settings = {k: {} for k in settings.keys()}

        # validate dict-typed args: per-node overrides must cover exactly the
        # node set of the DAG, no more and no less
        for k, v in settings.items():
            if isinstance(v, dict):
                expected = set(self.nodes_dict.keys())
                given = set(v.keys())
                if expected != given:
                    raise ValueError(
                        f"[ERROR] the provided argument '{k}' keys are not same as in cfg.conf_dict['nodes'].keys().\n"
                        f"Expected: {expected}, but got: {given}\n"
                        f"Please provide values for all variables.")

        # build one model per node
        self.models = {}
        for node in self.nodes_dict.keys():
            # Resolve each option to a scalar for this node (dict values are
            # per-node overrides) and record it in self.settings.
            per_node_kwargs = {}
            for k, v in settings.items():
                resolved = v[node] if isinstance(v, dict) else v
                per_node_kwargs[k] = resolved
                self.settings[k][node] = resolved
            if self.debug:
                print(f"\n[INFO] Building model for node '{node}' with settings: {per_node_kwargs}")
            self.models[node] = get_fully_specified_tram_model(
                node=node,
                configuration_dict=self.cfg.conf_dict,
                **per_node_kwargs)

            # Persist the freshly built weights as initial_model.pt so later
            # runs can restart from the same initialization.
            try:
                EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
                NODE_DIR = os.path.join(EXPERIMENT_DIR, f"{node}")
                os.makedirs(NODE_DIR, exist_ok=True)

                model_path = os.path.join(NODE_DIR, "initial_model.pt")
                overwrite = settings.get("overwrite_initial_weights", True)

                if overwrite or not os.path.exists(model_path):
                    torch.save(self.models[node].state_dict(), model_path)
                    if self.debug:
                        print(f"[DEBUG] Saved initial model state for node '{node}' to {model_path} (overwrite={overwrite})")
                else:
                    if self.debug:
                        print(f"[DEBUG] Skipped saving initial model for node '{node}' (already exists at {model_path})")
            except Exception as e:
                print(f"[ERROR] Could not save initial model state for node '{node}': {e}")

            # Best-effort cleanup of an empty "temp" dir in the CWD —
            # presumably left behind by model construction; TODO confirm.
            TEMP_DIR = "temp"
            if os.path.isdir(TEMP_DIR) and not os.listdir(TEMP_DIR):
                os.rmdir(TEMP_DIR)

        return self
 324
 325    @classmethod
 326    def from_directory(cls, EXPERIMENT_DIR: str, device: str = "auto", debug: bool = False, verbose: bool = False):
 327        """
 328        Reconstruct a TramDagModel from an experiment directory on disk.
 329
 330        This method:
 331        1. Loads the configuration JSON.
 332        2. Wraps it in a TramDagConfig.
 333        3. Builds all node models via `from_config`.
 334        4. Loads the min–max scaling dictionary.
 335
 336        Parameters
 337        ----------
 338        EXPERIMENT_DIR : str
 339            Path to an experiment directory containing:
 340            - ``configuration.json``
 341            - ``min_max_scaling.json``.
 342        device : {"auto", "cpu", "cuda"}, optional
 343            Device selection. Default is "auto".
 344        debug : bool, optional
 345            If True, enable debug messages. Default is False.
 346        verbose : bool, optional
 347            If True, enable informational messages. Default is False.
 348
 349        Returns
 350        -------
 351        TramDagModel
 352            A TramDagModel instance with models, config, and scaling loaded.
 353
 354        Raises
 355        ------
 356        FileNotFoundError
 357            If configuration or min–max files cannot be found.
 358        RuntimeError
 359            If the min–max file cannot be read or parsed.
 360        """
 361
 362        # --- load config file ---
 363        config_path = os.path.join(EXPERIMENT_DIR, "configuration.json")
 364        if not os.path.exists(config_path):
 365            raise FileNotFoundError(f"[ERROR] Config file not found at {config_path}")
 366
 367        with open(config_path, "r") as f:
 368            cfg_dict = json.load(f)
 369
 370        # Create TramConfig wrapper 
 371        cfg = TramDagConfig(cfg_dict, CONF_DICT_PATH=config_path)
 372
 373        # --- build model from config ---
 374        self = cls.from_config(cfg, device=device, debug=debug, verbose=verbose, overwrite_initial_weights=False)
 375
 376        # --- load minmax scaling ---
 377        minmax_path = os.path.join(EXPERIMENT_DIR, "min_max_scaling.json")
 378        if not os.path.exists(minmax_path):
 379            raise FileNotFoundError(f"[ERROR] MinMax file not found at {minmax_path}")
 380
 381        with open(minmax_path, "r") as f:
 382            self.minmax_dict = json.load(f)
 383
 384        if self.verbose or self.debug:
 385            print(f"[INFO] Loaded TramDagModel from {EXPERIMENT_DIR}")
 386            print(f"[INFO] Config loaded from {config_path}")
 387            print(f"[INFO] MinMax scaling loaded from {minmax_path}")
 388
 389        return self
 390
 391    def _ensure_dataset(self, data, is_val=False,**kwargs):
 392        """
 393        Ensure that the input data is represented as a TramDagDataset.
 394
 395        Parameters
 396        ----------
 397        data : pandas.DataFrame, TramDagDataset, or None
 398            Input data to be converted or passed through.
 399        is_val : bool, optional
 400            If True, the resulting dataset is treated as validation data
 401            (e.g. no shuffling). Default is False.
 402        **kwargs
 403            Additional keyword arguments passed through to
 404            ``TramDagDataset.from_dataframe``.
 405
 406        Returns
 407        -------
 408        TramDagDataset or None
 409            A TramDagDataset if ``data`` is a DataFrame or TramDagDataset,
 410            otherwise None if ``data`` is None.
 411
 412        Raises
 413        ------
 414        TypeError
 415            If ``data`` is not a DataFrame, TramDagDataset, or None.
 416        """
 417                
 418        if isinstance(data, pd.DataFrame):
 419            return TramDagDataset.from_dataframe(data, self.cfg, shuffle=not is_val,**kwargs)
 420        elif isinstance(data, TramDagDataset):
 421            return data
 422        elif data is None:
 423            return None
 424        else:
 425            raise TypeError(
 426                f"[ERROR] data must be pd.DataFrame, TramDagDataset, or None, got {type(data)}"
 427            )
 428
 429    def load_or_compute_minmax(self, td_train_data=None,use_existing=False, write=True):
 430        """
 431        Load an existing Min–Max scaling dictionary from disk or compute a new one 
 432        from the provided training dataset.
 433
 434        Parameters
 435        ----------
 436        use_existing : bool, optional (default=False)
 437            If True, attempts to load an existing `min_max_scaling.json` file 
 438            from the experiment directory. Raises an error if the file is missing 
 439            or unreadable.
 440
 441        write : bool, optional (default=True)
 442            If True, writes the computed Min–Max scaling dictionary to 
 443            `<EXPERIMENT_DIR>/min_max_scaling.json`.
 444
 445        td_train_data : object, optional
 446            Training dataset used to compute scaling statistics. If not provided,
 447            the method will ensure or construct it via `_ensure_dataset(data=..., is_val=False)`.
 448
 449        Behavior
 450        --------
 451        - If `use_existing=True`, loads the JSON file containing previously saved 
 452          min–max values and stores it in `self.minmax_dict`.
 453        - If `use_existing=False`, computes a new scaling dictionary using 
 454          `td_train_data.compute_scaling()` and stores the result in 
 455          `self.minmax_dict`.
 456        - Optionally writes the computed dictionary to disk.
 457
 458        Side Effects
 459        -------------
 460        - Populates `self.minmax_dict` with scaling values.
 461        - Writes or loads the file `min_max_scaling.json` under 
 462          `<EXPERIMENT_DIR>`.
 463        - Prints diagnostic output if `self.debug` or `self.verbose` is True.
 464
 465        Raises
 466        ------
 467        FileNotFoundError
 468            If `use_existing=True` but the min–max file does not exist.
 469
 470        RuntimeError
 471            If an existing min–max file cannot be read or parsed.
 472
 473        Notes
 474        -----
 475        The computed min–max dictionary is expected to contain scaling statistics 
 476        per feature, typically in the form:
 477            {
 478                "node": {"min": float, "max": float},
 479                ...
 480            }
 481        """
 482        EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
 483        minmax_path = os.path.join(EXPERIMENT_DIR, "min_max_scaling.json")
 484
 485        # laod exisitng if possible
 486        if use_existing:
 487            if not os.path.exists(minmax_path):
 488                raise FileNotFoundError(f"MinMax file not found: {minmax_path}")
 489            try:
 490                with open(minmax_path, 'r') as f:
 491                    self.minmax_dict = json.load(f)
 492                if self.debug or self.verbose:
 493                    print(f"[INFO] Loaded existing minmax dict from {minmax_path}")
 494                return
 495            except Exception as e:
 496                raise RuntimeError(f"Could not load existing minmax dict: {e}")
 497
 498        # 
 499        if self.debug or self.verbose:
 500            print("[INFO] Computing new minmax dict from training data...")
 501            
 502        td_train_data=self._ensure_dataset( data=td_train_data, is_val=False)    
 503            
 504        self.minmax_dict = td_train_data.compute_scaling()
 505
 506        if write:
 507            os.makedirs(EXPERIMENT_DIR, exist_ok=True)
 508            with open(minmax_path, 'w') as f:
 509                json.dump(self.minmax_dict, f, indent=4)
 510            if self.debug or self.verbose:
 511                print(f"[INFO] Saved new minmax dict to {minmax_path}")
 512
    ## FIT METHODS
    @staticmethod
    def _fit_single_node(node, self_ref, settings, td_train_data, td_val_data, device_str):
        """
        Train a single node model (helper for per-node training).

        This method is designed to be called either from the main process
        (sequential training) or from a joblib worker (parallel CPU training).

        Parameters
        ----------
        node : str
            Name of the target node to train.
        self_ref : TramDagModel
            Reference to the TramDagModel instance containing models and config.
        settings : dict
            Training settings dictionary, typically derived from ``DEFAULTS_FIT``
            plus any user overrides.
        td_train_data : TramDagDataset
            Training dataset with node-specific DataLoaders in ``.loaders``.
        td_val_data : TramDagDataset or None
            Validation dataset or None.
        device_str : str
            Device string, e.g. "cpu" or "cuda".

        Returns
        -------
        tuple
            A tuple ``(node, history)`` where:
            node : str
                Node name.
            history : dict or Any
                Training history as returned by ``train_val_loop``.
        """
        torch.set_num_threads(1)  # prevent thread oversubscription

        model = self_ref.models[node]

        # Resolve per-node settings: dict-valued options carry per-node
        # overrides, scalars apply to every node.
        def _resolve(key):
            val = settings[key]
            return val[node] if isinstance(val, dict) else val

        node_epochs = _resolve("epochs")
        node_lr = _resolve("learning_rate")
        node_debug = _resolve("debug")
        node_save_linear_shifts = _resolve("save_linear_shifts")
        save_simple_intercepts  = _resolve("save_simple_intercepts")
        node_verbose = _resolve("verbose")

        # Optimizer & scheduler: use a user-supplied per-node optimizer when
        # given, otherwise a fresh Adam with the resolved learning rate.
        if settings["optimizers"] and node in settings["optimizers"]:
            optimizer = settings["optimizers"][node]
        else:
            optimizer = Adam(model.parameters(), lr=node_lr)

        scheduler = settings["schedulers"].get(node, None) if settings["schedulers"] else None

        # Data loaders
        train_loader = td_train_data.loaders[node]
        val_loader = td_val_data.loaders[node] if td_val_data else None

        # Min-max scaling tensors, stacked as row 0 = mins, row 1 = maxs.
        # NOTE(review): this indexes minmax_dict[node] positionally ([0]/[1]),
        # i.e. assumes a (min, max) sequence per node — confirm against the
        # structure produced by compute_scaling().
        min_vals = torch.tensor(self_ref.minmax_dict[node][0], dtype=torch.float32)
        max_vals = torch.tensor(self_ref.minmax_dict[node][1], dtype=torch.float32)
        min_max = torch.stack([min_vals, max_vals], dim=0)

        # Node directory; fall back to ./models/<node> when the config has no
        # experiment path.
        try:
            EXPERIMENT_DIR = self_ref.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
            NODE_DIR = os.path.join(EXPERIMENT_DIR, f"{node}")
        except Exception:
            NODE_DIR = os.path.join("models", node)
            print("[WARNING] No log directory specified in config, saving to default location.")
        os.makedirs(NODE_DIR, exist_ok=True)

        if node_verbose:
            print(f"\n[INFO] Training node '{node}' for {node_epochs} epochs on {device_str} (pid={os.getpid()})")

        # --- train ---
        # NOTE(review): use_scheduler is derived from scheduler presence here;
        # the settings["use_scheduler"] flag is not consulted — confirm this
        # is intended.
        history = train_val_loop(
            node=node,
            target_nodes=self_ref.nodes_dict,
            NODE_DIR=NODE_DIR,
            tram_model=model,
            train_loader=train_loader,
            val_loader=val_loader,
            epochs=node_epochs,
            optimizer=optimizer,
            use_scheduler=(scheduler is not None),
            scheduler=scheduler,
            save_linear_shifts=node_save_linear_shifts,
            save_simple_intercepts=save_simple_intercepts,
            verbose=node_verbose,
            device=torch.device(device_str),
            debug=node_debug,
            min_max=min_max)
        return node, history
 611
 612    def fit(self, train_data, val_data=None, **kwargs):
 613        """
 614        Train TRAM models for all nodes in the DAG.
 615
 616        Coordinates dataset preparation, min–max scaling, and per-node training,
 617        optionally in parallel on CPU.
 618
 619        Parameters
 620        ----------
 621        train_data : pandas.DataFrame or TramDagDataset
 622            Training data. If a DataFrame is given, it is converted into a
 623            TramDagDataset using `_ensure_dataset`.
 624        val_data : pandas.DataFrame or TramDagDataset or None, optional
 625            Validation data. If a DataFrame is given, it is converted into a
 626            TramDagDataset. If None, no validation loss is computed.
 627        **kwargs
 628            Overrides for ``DEFAULTS_FIT``. All keys must exist in
 629            ``DEFAULTS_FIT``. Common options:
 630
 631            epochs : int, default 100
 632                Number of training epochs per node.
 633            learning_rate : float, default 0.01
 634                Learning rate for the default Adam optimizer.
 635            train_list : list of str or None, optional
 636                List of node names to train. If None, all nodes are trained.
 637            train_mode : {"sequential", "parallel"}, default "sequential"
 638                Training mode. "parallel" uses joblib-based CPU multiprocessing.
 639                GPU forces sequential mode.
 640            device : {"auto", "cpu", "cuda"}, default "auto"
 641                Device selection.
 642            optimizers : dict or None
 643                Optional mapping ``{node_name: optimizer}``. If provided for a
 644                node, that optimizer is used instead of creating a new Adam.
 645            schedulers : dict or None
 646                Optional mapping ``{node_name: scheduler}``.
 647            use_scheduler : bool
 648                If True, enable scheduler usage in the training loop.
 649            num_workers : int
 650                DataLoader workers in sequential mode (ignored in parallel).
 651            persistent_workers : bool
 652                DataLoader persistence in sequential mode (ignored in parallel).
 653            prefetch_factor : int
 654                DataLoader prefetch factor (ignored in parallel).
 655            batch_size : int
 656                Batch size for all node DataLoaders.
 657            debug : bool
 658                Enable debug output.
 659            verbose : bool
 660                Enable informational logging.
 661            return_history : bool
 662                If True, return a history dict.
 663
 664        Returns
 665        -------
 666        dict or None
 667            If ``return_history=True``, a dictionary mapping each node name
 668            to its training history. Otherwise, returns None.
 669
 670        Raises
 671        ------
 672        ValueError
 673            If ``train_mode`` is not "sequential" or "parallel".
 674        """
 675        self._validate_kwargs(kwargs, defaults_attr='DEFAULTS_FIT', context="fit")
 676        
 677        # --- merge defaults ---
 678        settings = dict(self.DEFAULTS_FIT)
 679        settings.update(kwargs)
 680        
 681        
 682        self.debug = settings.get("debug", False)
 683        self.verbose = settings.get("verbose", False)
 684
 685        # --- resolve device ---
 686        device_str=self.get_device(settings)
 687        self.device = torch.device(device_str)
 688
 689        # --- training mode ---
 690        train_mode = settings.get("train_mode", "sequential").lower()
 691        if train_mode not in ("sequential", "parallel"):
 692            raise ValueError("train_mode must be 'sequential' or 'parallel'")
 693
 694        # --- DataLoader safety logic ---
 695        if train_mode == "parallel":
 696            # if user passed loader paralleling params, warn and override
 697            for flag in ("num_workers", "persistent_workers", "prefetch_factor"):
 698                if flag in kwargs:
 699                    print(f"[WARNING] '{flag}' is ignored in parallel mode "
 700                        f"(disabled to prevent nested multiprocessing).")
 701            # disable unsafe loader multiprocessing options
 702            settings["num_workers"] = 0
 703            settings["persistent_workers"] = False
 704            settings["prefetch_factor"] = None
 705        else:
 706            # sequential mode → respect user DataLoader settings
 707            if self.debug:
 708                print("[DEBUG] Sequential mode: using DataLoader kwargs as provided.")
 709
 710        # --- which nodes to train ---
 711        train_list = settings.get("train_list") or list(self.models.keys())
 712
 713
 714        # --- dataset prep (receives adjusted settings) ---
 715        td_train_data = self._ensure_dataset(train_data, is_val=False, **settings)
 716        td_val_data = self._ensure_dataset(val_data, is_val=True, **settings)
 717
 718        # --- normalization ---
 719        self.load_or_compute_minmax(use_existing=False, write=True, td_train_data=td_train_data)
 720
 721        # --- print header ---
 722        if self.verbose or self.debug:
 723            print(f"[INFO] Training {len(train_list)} nodes ({train_mode}) on {device_str}")
 724
 725        # ======================================================================
 726        # Sequential mode  safe for GPU or debugging)
 727        # ======================================================================
 728        if train_mode == "sequential" or "cuda" in device_str:
 729            if "cuda" in device_str and train_mode == "parallel":
 730                print("[WARNING] GPU device detected — forcing sequential mode.")
 731            results = {}
 732            for node in train_list:
 733                node, history = self._fit_single_node(
 734                    node, self, settings, td_train_data, td_val_data, device_str
 735                )
 736                results[node] = history
 737        
 738
 739        # ======================================================================
 740        # parallel mode (CPU only)
 741        # ======================================================================
 742        if train_mode == "parallel":
 743
 744            n_jobs = min(len(train_list), os.cpu_count() // 2 or 1)
 745            if self.verbose or self.debug:
 746                print(f"[INFO] Using {n_jobs} CPU workers for parallel node training")
 747            parallel_outputs = Parallel(
 748                n_jobs=n_jobs,
 749                backend="loky",#loky, multiprocessing
 750                verbose=10,
 751                prefer="processes"
 752            )(delayed(self._fit_single_node)(node, self, settings, td_train_data, td_val_data, device_str) for node in train_list )
 753
 754            results = {node: hist for node, hist in parallel_outputs}
 755        
 756        if settings.get("return_history", False):
 757            return results
 758
 759    ## FIT-DIAGNOSTICS
 760    def loss_history(self):
 761        """
 762        Load training and validation loss history for all nodes.
 763
 764        Looks for per-node JSON files:
 765
 766        - ``EXPERIMENT_DIR/{node}/train_loss_hist.json``
 767        - ``EXPERIMENT_DIR/{node}/val_loss_hist.json``
 768
 769        Returns
 770        -------
 771        dict
 772            A dictionary mapping node names to:
 773
 774            .. code-block:: python
 775
 776                {
 777                    "train": list or None,
 778                    "validation": list or None
 779                }
 780
 781            where each list contains NLL values per epoch, or None if not found.
 782
 783        Raises
 784        ------
 785        ValueError
 786            If the experiment directory cannot be resolved from the configuration.
 787        """
 788        try:
 789            EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
 790        except KeyError:
 791            raise ValueError(
 792                "[ERROR] Missing 'EXPERIMENT_DIR' in cfg.conf_dict['PATHS']. "
 793                "History retrieval requires experiment logs."
 794            )
 795
 796        all_histories = {}
 797        for node in self.nodes_dict.keys():
 798            node_dir = os.path.join(EXPERIMENT_DIR, node)
 799            train_path = os.path.join(node_dir, "train_loss_hist.json")
 800            val_path = os.path.join(node_dir, "val_loss_hist.json")
 801
 802            node_hist = {}
 803
 804            # --- load train history ---
 805            if os.path.exists(train_path):
 806                try:
 807                    with open(train_path, "r") as f:
 808                        node_hist["train"] = json.load(f)
 809                except Exception as e:
 810                    print(f"[WARNING] Could not load {train_path}: {e}")
 811                    node_hist["train"] = None
 812            else:
 813                node_hist["train"] = None
 814
 815            # --- load val history ---
 816            if os.path.exists(val_path):
 817                try:
 818                    with open(val_path, "r") as f:
 819                        node_hist["validation"] = json.load(f)
 820                except Exception as e:
 821                    print(f"[WARNING] Could not load {val_path}: {e}")
 822                    node_hist["validation"] = None
 823            else:
 824                node_hist["validation"] = None
 825
 826            all_histories[node] = node_hist
 827
 828        if self.verbose or self.debug:
 829            print(f"[INFO] Loaded training/validation histories for {len(all_histories)} nodes.")
 830
 831        return all_histories
 832
 833    def linear_shift_history(self):
 834        """
 835        Load linear shift term histories for all nodes.
 836
 837        Each node history is expected in a JSON file named
 838        ``linear_shifts_all_epochs.json`` under the node directory.
 839
 840        Returns
 841        -------
 842        dict
 843            A mapping ``{node_name: pandas.DataFrame}``, where each DataFrame
 844            contains linear shift weights across epochs.
 845
 846        Raises
 847        ------
 848        ValueError
 849            If the experiment directory cannot be resolved from the configuration.
 850
 851        Notes
 852        -----
 853        If a history file is missing for a node, a warning is printed and the
 854        node is omitted from the returned dictionary.
 855        """
 856        histories = {}
 857        try:
 858            EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
 859        except KeyError:
 860            raise ValueError(
 861                "[ERROR] Missing 'EXPERIMENT_DIR' in cfg.conf_dict['PATHS']. "
 862                "Cannot load histories without experiment directory."
 863            )
 864
 865        for node in self.nodes_dict.keys():
 866            node_dir = os.path.join(EXPERIMENT_DIR, node)
 867            history_path = os.path.join(node_dir, "linear_shifts_all_epochs.json")
 868            if os.path.exists(history_path):
 869                histories[node] = pd.read_json(history_path)
 870            else:
 871                print(f"[WARNING] No linear shift history found for node '{node}' at {history_path}")
 872        return histories
 873
 874    def simple_intercept_history(self):
 875        """
 876        Load simple intercept histories for all nodes.
 877
 878        Each node history is expected in a JSON file named
 879        ``simple_intercepts_all_epochs.json`` under the node directory.
 880
 881        Returns
 882        -------
 883        dict
 884            A mapping ``{node_name: pandas.DataFrame}``, where each DataFrame
 885            contains intercept weights across epochs.
 886
 887        Raises
 888        ------
 889        ValueError
 890            If the experiment directory cannot be resolved from the configuration.
 891
 892        Notes
 893        -----
 894        If a history file is missing for a node, a warning is printed and the
 895        node is omitted from the returned dictionary.
 896        """
 897        histories = {}
 898        try:
 899            EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
 900        except KeyError:
 901            raise ValueError(
 902                "[ERROR] Missing 'EXPERIMENT_DIR' in cfg.conf_dict['PATHS']. "
 903                "Cannot load histories without experiment directory."
 904            )
 905
 906        for node in self.nodes_dict.keys():
 907            node_dir = os.path.join(EXPERIMENT_DIR, node)
 908            history_path = os.path.join(node_dir, "simple_intercepts_all_epochs.json")
 909            if os.path.exists(history_path):
 910                histories[node] = pd.read_json(history_path)
 911            else:
 912                print(f"[WARNING] No simple intercept history found for node '{node}' at {history_path}")
 913        return histories
 914
 915    def get_latent(self, df, verbose=False):
 916        """
 917        Compute latent representations for all nodes in the DAG.
 918
 919        Parameters
 920        ----------
 921        df : pandas.DataFrame
 922            Input data frame with columns corresponding to nodes in the DAG.
 923        verbose : bool, optional
 924            If True, print informational messages during latent computation.
 925            Default is False.
 926
 927        Returns
 928        -------
 929        pandas.DataFrame
 930            DataFrame containing the original columns plus latent variables
 931            for each node (e.g. columns named ``f"{node}_U"``).
 932
 933        Raises
 934        ------
 935        ValueError
 936            If the experiment directory is missing from the configuration or
 937            if ``self.minmax_dict`` has not been set.
 938        """
 939        try:
 940            EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
 941        except KeyError:
 942            raise ValueError(
 943                "[ERROR] Missing 'EXPERIMENT_DIR' in cfg.conf_dict['PATHS']. "
 944                "Latent extraction requires trained model checkpoints."
 945            )
 946
 947        # ensure minmax_dict is available
 948        if not hasattr(self, "minmax_dict"):
 949            raise ValueError(
 950                "[ERROR] minmax_dict not found in the TramDagModel instance. "
 951                "Either call .load_or_compute_minmax(td_train_data=train_df) or .fit() first."
 952            )
 953
 954        all_latents_df = create_latent_df_for_full_dag(
 955            configuration_dict=self.cfg.conf_dict,
 956            EXPERIMENT_DIR=EXPERIMENT_DIR,
 957            df=df,
 958            verbose=verbose,
 959            min_max_dict=self.minmax_dict,
 960        )
 961
 962        return all_latents_df
 963
 964    
 965    ## PLOTTING FIT-DIAGNOSTICS
 966    
 967    def plot_loss_history(self, variable: str = None):
 968        """
 969        Plot training and validation loss evolution per node.
 970
 971        Parameters
 972        ----------
 973        variable : str or None, optional
 974            If provided, plot loss history for this node only. If None, plot
 975            histories for all nodes that have both train and validation logs.
 976
 977        Returns
 978        -------
 979        None
 980
 981        Notes
 982        -----
 983        Two subplots are produced:
 984        - Full epoch history.
 985        - Last 10% of epochs (or only the last epoch if fewer than 5 epochs).
 986        """
 987
 988        histories = self.loss_history()
 989        if not histories:
 990            print("[WARNING] No loss histories found.")
 991            return
 992
 993        # Select which nodes to plot
 994        if variable is not None:
 995            if variable not in histories:
 996                raise ValueError(f"[ERROR] Node '{variable}' not found in histories.")
 997            nodes_to_plot = [variable]
 998        else:
 999            nodes_to_plot = list(histories.keys())
1000
1001        # Filter out nodes with no valid history
1002        nodes_to_plot = [
1003            n for n in nodes_to_plot
1004            if histories[n].get("train") is not None and len(histories[n]["train"]) > 0
1005            and histories[n].get("validation") is not None and len(histories[n]["validation"]) > 0
1006        ]
1007
1008        if not nodes_to_plot:
1009            print("[WARNING] No valid histories found to plot.")
1010            return
1011
1012        plt.figure(figsize=(14, 12))
1013
1014        # --- Full history (top plot) ---
1015        plt.subplot(2, 1, 1)
1016        for node in nodes_to_plot:
1017            node_hist = histories[node]
1018            train_hist, val_hist = node_hist["train"], node_hist["validation"]
1019
1020            epochs = range(1, len(train_hist) + 1)
1021            plt.plot(epochs, train_hist, label=f"{node} - train", linestyle="--")
1022            plt.plot(epochs, val_hist, label=f"{node} - val")
1023
1024        plt.title("Training and Validation NLL - Full History")
1025        plt.xlabel("Epoch")
1026        plt.ylabel("NLL")
1027        plt.legend()
1028        plt.grid(True)
1029
1030        # --- Last 10% of epochs (bottom plot) ---
1031        plt.subplot(2, 1, 2)
1032        for node in nodes_to_plot:
1033            node_hist = histories[node]
1034            train_hist, val_hist = node_hist["train"], node_hist["validation"]
1035
1036            total_epochs = len(train_hist)
1037            start_idx = total_epochs - 1 if total_epochs < 5 else int(total_epochs * 0.9)
1038
1039            epochs = range(start_idx + 1, total_epochs + 1)
1040            plt.plot(epochs, train_hist[start_idx:], label=f"{node} - train", linestyle="--")
1041            plt.plot(epochs, val_hist[start_idx:], label=f"{node} - val")
1042
1043        plt.title("Training and Validation NLL - Last 10% of Epochs (or Last Epoch if <5)")
1044        plt.xlabel("Epoch")
1045        plt.ylabel("NLL")
1046        plt.legend()
1047        plt.grid(True)
1048
1049        plt.tight_layout()
1050        plt.show()
1051
1052    def plot_linear_shift_history(self, data_dict=None, node=None, ref_lines=None):
1053        """
1054        Plot the evolution of linear shift terms over epochs.
1055
1056        Parameters
1057        ----------
1058        data_dict : dict or None, optional
1059            Pre-loaded mapping ``{node_name: pandas.DataFrame}`` containing shift
1060            weights across epochs. If None, `linear_shift_history()` is called.
1061        node : str or None, optional
1062            If provided, plot only this node. Otherwise, plot all nodes
1063            present in ``data_dict``.
1064        ref_lines : dict or None, optional
1065            Optional mapping ``{node_name: list of float}``. For each specified
1066            node, horizontal reference lines are drawn at the given values.
1067
1068        Returns
1069        -------
1070        None
1071
1072        Notes
1073        -----
1074        The function flattens nested list-like entries in the DataFrames to scalars,
1075        converts epoch labels to numeric, and then draws one line per shift term.
1076        """
1077
1078        if data_dict is None:
1079            data_dict = self.linear_shift_history()
1080            if data_dict is None:
1081                raise ValueError("No shift history data provided or stored in the class.")
1082
1083        nodes = [node] if node else list(data_dict.keys())
1084
1085        for n in nodes:
1086            df = data_dict[n].copy()
1087
1088            # Flatten nested lists or list-like cells
1089            def flatten(x):
1090                if isinstance(x, list):
1091                    if len(x) == 0:
1092                        return np.nan
1093                    if all(isinstance(i, (int, float)) for i in x):
1094                        return np.mean(x)  # average simple list
1095                    if all(isinstance(i, list) for i in x):
1096                        # nested list -> flatten inner and average
1097                        flat = [v for sub in x for v in (sub if isinstance(sub, list) else [sub])]
1098                        return np.mean(flat) if flat else np.nan
1099                    return x[0] if len(x) == 1 else np.nan
1100                return x
1101
1102            df = df.applymap(flatten)
1103
1104            # Ensure numeric columns
1105            df = df.apply(pd.to_numeric, errors='coerce')
1106
1107            # Convert epoch labels to numeric
1108            df.columns = [
1109                int(c.replace("epoch_", "")) if isinstance(c, str) and c.startswith("epoch_") else c
1110                for c in df.columns
1111            ]
1112            df = df.reindex(sorted(df.columns), axis=1)
1113
1114            plt.figure(figsize=(10, 6))
1115            for idx in df.index:
1116                plt.plot(df.columns, df.loc[idx], lw=1.4, label=f"shift_{idx}")
1117
1118            if ref_lines and n in ref_lines:
1119                for v in ref_lines[n]:
1120                    plt.axhline(y=v, color="k", linestyle="--", lw=1.0)
1121                    plt.text(df.columns[-1], v, f"{n}: {v}", va="bottom", ha="right", fontsize=8)
1122
1123            plt.xlabel("Epoch")
1124            plt.ylabel("Shift Value")
1125            plt.title(f"Shift Term History — Node: {n}")
1126            plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left", fontsize="small")
1127            plt.tight_layout()
1128            plt.show()
1129
1130    def plot_simple_intercepts_history(self, data_dict=None, node=None,ref_lines=None):
1131        """
1132        Plot the evolution of simple intercept weights over epochs.
1133
1134        Parameters
1135        ----------
1136        data_dict : dict or None, optional
1137            Pre-loaded mapping ``{node_name: pandas.DataFrame}`` containing intercept
1138            weights across epochs. If None, `simple_intercept_history()` is called.
1139        node : str or None, optional
1140            If provided, plot only this node. Otherwise, plot all nodes present
1141            in ``data_dict``.
1142        ref_lines : dict or None, optional
1143            Optional mapping ``{node_name: list of float}``. For each specified
1144            node, horizontal reference lines are drawn at the given values.
1145
1146        Returns
1147        -------
1148        None
1149
1150        Notes
1151        -----
1152        Nested list-like entries in the DataFrames are reduced to scalars before
1153        plotting. One line is drawn per intercept parameter.
1154        """
1155        if data_dict is None:
1156            data_dict = self.simple_intercept_history()
1157            if data_dict is None:
1158                raise ValueError("No intercept history data provided or stored in the class.")
1159
1160        nodes = [node] if node else list(data_dict.keys())
1161
1162        for n in nodes:
1163            df = data_dict[n].copy()
1164
1165            def extract_scalar(x):
1166                if isinstance(x, list):
1167                    while isinstance(x, list) and len(x) > 0:
1168                        x = x[0]
1169                return float(x) if isinstance(x, (int, float, np.floating)) else np.nan
1170
1171            df = df.applymap(extract_scalar)
1172
1173            # Convert epoch labels → numeric
1174            df.columns = [
1175                int(c.replace("epoch_", "")) if isinstance(c, str) and c.startswith("epoch_") else c
1176                for c in df.columns
1177            ]
1178            df = df.reindex(sorted(df.columns), axis=1)
1179
1180            plt.figure(figsize=(10, 6))
1181            for idx in df.index:
1182                plt.plot(df.columns, df.loc[idx], lw=1.4, label=f"theta_{idx}")
1183            
1184            if ref_lines and n in ref_lines:
1185                for v in ref_lines[n]:
1186                    plt.axhline(y=v, color="k", linestyle="--", lw=1.0)
1187                    plt.text(df.columns[-1], v, f"{n}: {v}", va="bottom", ha="right", fontsize=8)
1188                
1189            plt.xlabel("Epoch")
1190            plt.ylabel("Intercept Weight")
1191            plt.title(f"Simple Intercept Evolution — Node: {n}")
1192            plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left", fontsize="small")
1193            plt.tight_layout()
1194            plt.show()
1195
1196    def plot_latents(self, df, variable: str = None, confidence: float = 0.95, simulations: int = 1000):
1197        """
1198        Visualize latent U distributions for one or all nodes.
1199
1200        Parameters
1201        ----------
1202        df : pandas.DataFrame
1203            Input data frame with raw node values.
1204        variable : str or None, optional
1205            If provided, only this node's latents are plotted. If None, all
1206            nodes with latent columns are processed.
1207        confidence : float, optional
1208            Confidence level for QQ-plot bands (0 < confidence < 1).
1209            Default is 0.95.
1210        simulations : int, optional
1211            Number of Monte Carlo simulations for QQ-plot bands. Default is 1000.
1212
1213        Returns
1214        -------
1215        None
1216
1217        Notes
1218        -----
1219        For each node, two plots are produced:
1220        - Histogram of the latent U values.
1221        - QQ-plot with simulation-based confidence bands under a logistic reference.
1222        """
1223        # Compute latent representations
1224        latents_df = self.get_latent(df)
1225
1226        # Select nodes
1227        nodes = [variable] if variable is not None else self.nodes_dict.keys()
1228
1229        for node in nodes:
1230            if f"{node}_U" not in latents_df.columns:
1231                print(f"[WARNING] No latent found for node {node}, skipping.")
1232                continue
1233
1234            sample = latents_df[f"{node}_U"].values
1235
1236            # --- Create plots ---
1237            fig, axs = plt.subplots(1, 2, figsize=(12, 5))
1238
1239            # Histogram
1240            axs[0].hist(sample, bins=50, color="steelblue", alpha=0.7)
1241            axs[0].set_title(f"Latent Histogram ({node})")
1242            axs[0].set_xlabel("U")
1243            axs[0].set_ylabel("Frequency")
1244
1245            # QQ Plot with confidence bands
1246            probplot(sample, dist="logistic", plot=axs[1])
1247            self._add_r_style_confidence_bands(axs[1], sample, dist=logistic,confidence=confidence, simulations=simulations)
1248            axs[1].set_title(f"Latent QQ Plot ({node})")
1249
1250            plt.suptitle(f"Latent Diagnostics for Node: {node}", fontsize=14)
1251            plt.tight_layout(rect=[0, 0.03, 1, 0.95])
1252            plt.show()
1253
1254    def plot_hdag(self,df,variables=None, plot_n_rows=1,**kwargs):
1255        
1256        """
1257        Visualize the transformation function h() for selected DAG nodes.
1258
1259        Parameters
1260        ----------
1261        df : pandas.DataFrame
1262            Input data containing node values or model predictions.
1263        variables : list of str or None, optional
1264            Names of nodes to visualize. If None, all nodes in ``self.models``
1265            are considered.
1266        plot_n_rows : int, optional
1267            Maximum number of rows from ``df`` to visualize. Default is 1.
1268        **kwargs
1269            Additional keyword arguments forwarded to the underlying plotting
1270            helpers (`show_hdag_continous` / `show_hdag_ordinal`).
1271
1272        Returns
1273        -------
1274        None
1275
1276        Notes
1277        -----
1278        - For continuous outcomes, `show_hdag_continous` is called.
1279        - For ordinal outcomes, `show_hdag_ordinal` is called.
1280        - Nodes that are neither continuous nor ordinal are skipped with a warning.
1281        """
1282                
1283
1284        if len(df)> 1:
1285            print("[WARNING] len(df)>1, set: plot_n_rows accordingly")
1286        
1287        variables_list=variables if variables is not None else list(self.models.keys())
1288        for node in variables_list:
1289            if is_outcome_modelled_continous(node, self.nodes_dict):
1290                show_hdag_continous(df,node=node,configuration_dict=self.cfg.conf_dict,minmax_dict=self.minmax_dict,device=self.device,plot_n_rows=plot_n_rows,**kwargs)
1291            
1292            elif is_outcome_modelled_ordinal(node, self.nodes_dict):
1293                show_hdag_ordinal(df,node=node,configuration_dict=self.cfg.conf_dict,device=self.device,plot_n_rows=plot_n_rows,**kwargs)
1294                # plot_cutpoints_with_logistic(df,node=node,configuration_dict=self.cfg.conf_dict,device=self.device,plot_n_rows=plot_n_rows,**kwargs)
1295                # save_cutpoints_with_logistic(df,node=node,configuration_dict=self.cfg.conf_dict,device=self.device,**kwargs)
1296            else:
1297                print(f"[WARNING] Node {node} is wheter ordinal nor continous, not implemented yet")
1298    
1299    @staticmethod
1300    def _add_r_style_confidence_bands(ax, sample, dist, confidence=0.95, simulations=1000):
1301        """
1302        Add simulation-based confidence bands to a QQ-plot.
1303
1304        Parameters
1305        ----------
1306        ax : matplotlib.axes.Axes
1307            Axes object on which to draw the QQ-plot and bands.
1308        sample : array-like
1309            Empirical sample used in the QQ-plot.
1310        dist : scipy.stats distribution
1311            Distribution object providing ``ppf`` and ``rvs`` methods (e.g. logistic).
1312        confidence : float, optional
1313            Confidence level (0 < confidence < 1) for the bands. Default is 0.95.
1314        simulations : int, optional
1315            Number of Monte Carlo simulations used to estimate the bands. Default is 1000.
1316
1317        Returns
1318        -------
1319        None
1320
1321        Notes
1322        -----
1323        The axes are cleared, and a new QQ-plot is drawn with:
1324        - Empirical vs. theoretical quantiles.
1325        - 45-degree reference line.
1326        - Shaded confidence band region.
1327        """
1328        
1329        n = len(sample)
1330        if n == 0:
1331            return
1332
1333        quantiles = np.linspace(0, 1, n, endpoint=False) + 0.5 / n
1334        theo_q = dist.ppf(quantiles)
1335
1336        # Simulate order statistics from the theoretical distribution
1337        sim_data = dist.rvs(size=(simulations, n))
1338        sim_order_stats = np.sort(sim_data, axis=1)
1339
1340        # Confidence bands
1341        lower = np.percentile(sim_order_stats, 100 * (1 - confidence) / 2, axis=0)
1342        upper = np.percentile(sim_order_stats, 100 * (1 + confidence) / 2, axis=0)
1343
1344        # Sort empirical sample
1345        sample_sorted = np.sort(sample)
1346
1347        # Re-draw points and CI (overwrite probplot defaults)
1348        ax.clear()
1349        ax.plot(theo_q, sample_sorted, 'o', markersize=3, alpha=0.6, label="Empirical Q-Q")
1350        ax.plot(theo_q, theo_q, 'b--', label="y = x")
1351        ax.fill_between(theo_q, lower, upper, color='gray', alpha=0.3,
1352                        label=f'{int(confidence*100)}% CI')
1353        ax.legend()
1354    
1355    ## SAMPLING METHODS
1356    def sample(
1357        self,
1358        do_interventions: dict = None,
1359        predefined_latent_samples_df: pd.DataFrame = None,
1360        **kwargs,
1361    ):
1362        """
1363        Sample from the joint DAG using the trained TRAM models.
1364
1365        Allows for:
1366        
1367        Oberservational sampling
1368        Interventional sampling via ``do()`` operations
1369        Counterfactial sampling using predefined latent draws and do()
1370        
1371        Parameters
1372        ----------
1373        do_interventions : dict or None, optional
1374            Mapping of node names to intervened (fixed) values. For example:
1375            ``{"x1": 1.0}`` represents ``do(x1 = 1.0)``. Default is None.
1376        predefined_latent_samples_df : pandas.DataFrame or None, optional
1377            DataFrame containing columns ``"{node}_U"`` with predefined latent
1378            draws to be used instead of sampling from the prior. Default is None.
1379        **kwargs
1380            Sampling options overriding internal defaults:
1381
1382            number_of_samples : int, default 10000
1383                Total number of samples to draw.
1384            batch_size : int, default 32
1385                Batch size for internal sampling loops.
1386            delete_all_previously_sampled : bool, default True
1387                If True, delete old sampling files in node-specific sampling
1388                directories before writing new ones.
1389            verbose : bool
1390                If True, print informational messages.
1391            debug : bool
1392                If True, print debug output.
1393            device : {"auto", "cpu", "cuda"}
1394                Device selection for sampling.
1395            use_initial_weights_for_sampling : bool, default False
1396                If True, sample from initial (untrained) model parameters.
1397
1398        Returns
1399        -------
1400        tuple
1401            A tuple ``(sampled_by_node, latents_by_node)``:
1402
1403            sampled_by_node : dict
1404                Mapping ``{node_name: torch.Tensor}`` of sampled node values.
1405            latents_by_node : dict
1406                Mapping ``{node_name: torch.Tensor}`` of latent U values used.
1407
1408        Raises
1409        ------
1410        ValueError
1411            If the experiment directory cannot be resolved or if scaling
1412            information (``self.minmax_dict``) is missing.
1413        RuntimeError
1414            If min–max scaling has not been computed before calling `sample`.
1415        """
1416        try:
1417            EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
1418        except KeyError:
1419            raise ValueError(
1420                "[ERROR] Missing 'EXPERIMENT_DIR' in cfg.conf_dict['PATHS']. "
1421                "Sampling requires trained model checkpoints."
1422            )
1423
1424        # ---- defaults ----
1425        settings = {
1426            "number_of_samples": 10_000,
1427            "batch_size": 32,
1428            "delete_all_previously_sampled": True,
1429            "verbose": self.verbose if hasattr(self, "verbose") else False,
1430            "debug": self.debug if hasattr(self, "debug") else False,
1431            "device": self.device.type if hasattr(self, "device") else "auto",
1432            "use_initial_weights_for_sampling": False,
1433            
1434        }
1435        
1436        # self._validate_kwargs( kwargs, defaults_attr= "settings", context="sample")
1437        
1438        settings.update(kwargs)
1439
1440        
1441        if not hasattr(self, "minmax_dict"):
1442            raise RuntimeError(
1443                "[ERROR] minmax_dict not found. You must call .fit() or .load_or_compute_minmax() "
1444                "before sampling, so scaling info is available."
1445                )
1446            
1447        # ---- resolve device ----
1448        device_str=self.get_device(settings)
1449        self.device = torch.device(device_str)
1450
1451
1452        if self.debug or settings["debug"]:
1453            print(f"[DEBUG] sample(): device: {self.device}")
1454
1455        # ---- perform sampling ----
1456        sampled_by_node, latents_by_node = sample_full_dag(
1457            configuration_dict=self.cfg.conf_dict,
1458            EXPERIMENT_DIR=EXPERIMENT_DIR,
1459            device=self.device,
1460            do_interventions=do_interventions or {},
1461            predefined_latent_samples_df=predefined_latent_samples_df,
1462            number_of_samples=settings["number_of_samples"],
1463            batch_size=settings["batch_size"],
1464            delete_all_previously_sampled=settings["delete_all_previously_sampled"],
1465            verbose=settings["verbose"],
1466            debug=settings["debug"],
1467            minmax_dict=self.minmax_dict,
1468            use_initial_weights_for_sampling=settings["use_initial_weights_for_sampling"]
1469        )
1470
1471        return sampled_by_node, latents_by_node
1472
1473    def load_sampled_and_latents(self, EXPERIMENT_DIR: str = None, nodes: list = None):
1474        """
1475        Load previously stored sampled values and latents for each node.
1476
1477        Parameters
1478        ----------
1479        EXPERIMENT_DIR : str or None, optional
1480            Experiment directory path. If None, it is taken from
1481            ``self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]``.
1482        nodes : list of str or None, optional
1483            Nodes for which to load samples. If None, use all nodes from
1484            ``self.nodes_dict``.
1485
1486        Returns
1487        -------
1488        tuple
1489            A tuple ``(sampled_by_node, latents_by_node)``:
1490
1491            sampled_by_node : dict
1492                Mapping ``{node_name: torch.Tensor}`` of sampled values (on CPU).
1493            latents_by_node : dict
1494                Mapping ``{node_name: torch.Tensor}`` of latent values (on CPU).
1495
1496        Raises
1497        ------
1498        ValueError
1499            If the experiment directory cannot be resolved or if no node list
1500            is available and ``nodes`` is None.
1501
1502        Notes
1503        -----
1504        Nodes without both ``sampled.pt`` and ``latents.pt`` files are skipped
1505        with a warning.
1506        """
1507        # --- resolve paths and node list ---
1508        if EXPERIMENT_DIR is None:
1509            try:
1510                EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
1511            except (AttributeError, KeyError):
1512                raise ValueError(
1513                    "[ERROR] Could not resolve EXPERIMENT_DIR from cfg.conf_dict['PATHS']. "
1514                    "Please provide EXPERIMENT_DIR explicitly."
1515                )
1516
1517        if nodes is None:
1518            if hasattr(self, "nodes_dict"):
1519                nodes = list(self.nodes_dict.keys())
1520            else:
1521                raise ValueError(
1522                    "[ERROR] No node list found. Please provide `nodes` or initialize model with a config."
1523                )
1524
1525        # --- load tensors ---
1526        sampled_by_node = {}
1527        latents_by_node = {}
1528
1529        for node in nodes:
1530            node_dir = os.path.join(EXPERIMENT_DIR, f"{node}")
1531            sampling_dir = os.path.join(node_dir, "sampling")
1532
1533            sampled_path = os.path.join(sampling_dir, "sampled.pt")
1534            latents_path = os.path.join(sampling_dir, "latents.pt")
1535
1536            if not os.path.exists(sampled_path) or not os.path.exists(latents_path):
1537                print(f"[WARNING] Missing files for node '{node}' — skipping.")
1538                continue
1539
1540            try:
1541                sampled = torch.load(sampled_path, map_location="cpu")
1542                latent_sample = torch.load(latents_path, map_location="cpu")
1543            except Exception as e:
1544                print(f"[ERROR] Could not load sampling files for node '{node}': {e}")
1545                continue
1546
1547            sampled_by_node[node] = sampled.detach().cpu()
1548            latents_by_node[node] = latent_sample.detach().cpu()
1549
1550        if self.verbose or self.debug:
1551            print(f"[INFO] Loaded sampled and latent tensors for {len(sampled_by_node)} nodes from {EXPERIMENT_DIR}")
1552
1553        return sampled_by_node, latents_by_node
1554
1555    def plot_samples_vs_true(
1556        self,
1557        df,
1558        sampled: dict = None,
1559        variable: list = None,
1560        bins: int = 100,
1561        hist_true_color: str = "blue",
1562        hist_est_color: str = "orange",
1563        figsize: tuple = (14, 5),
1564    ):
1565        
1566        
1567        """
1568        Compare sampled vs. observed distributions for selected nodes.
1569
1570        Parameters
1571        ----------
1572        df : pandas.DataFrame
1573            Data frame containing the observed node values.
1574        sampled : dict or None, optional
1575            Optional mapping ``{node_name: array-like or torch.Tensor}`` of sampled
1576            values. If None or if a node is missing, samples are loaded from
1577            ``EXPERIMENT_DIR/{node}/sampling/sampled.pt``.
1578        variable : list of str or None, optional
1579            Subset of nodes to plot. If None, all nodes in the configuration
1580            are considered.
1581        bins : int, optional
1582            Number of histogram bins for continuous variables. Default is 100.
1583        hist_true_color : str, optional
1584            Color name for the histogram of true values. Default is "blue".
1585        hist_est_color : str, optional
1586            Color name for the histogram of sampled values. Default is "orange".
1587        figsize : tuple, optional
1588            Figure size for the matplotlib plots. Default is (14, 5).
1589
1590        Returns
1591        -------
1592        None
1593
1594        Notes
1595        -----
1596        - Continuous outcomes: histogram overlay + QQ-plot.
1597        - Ordinal outcomes: side-by-side bar plot of relative frequencies.
1598        - Other categorical outcomes: side-by-side bar plot with category labels.
1599        - If samples are probabilistic (2D tensor), the argmax across classes is used.
1600        """
1601        
1602        target_nodes = self.cfg.conf_dict["nodes"]
1603        experiment_dir = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
1604        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
1605
1606        plot_list = variable if variable is not None else target_nodes
1607
1608        for node in plot_list:
1609            # Load sampled data
1610            if sampled is not None and node in sampled:
1611                sdata = sampled[node]
1612                if isinstance(sdata, torch.Tensor):
1613                    sampled_vals = sdata.detach().cpu().numpy()
1614                else:
1615                    sampled_vals = np.asarray(sdata)
1616            else:
1617                sample_path = os.path.join(experiment_dir, f"{node}/sampling/sampled.pt")
1618                if not os.path.isfile(sample_path):
1619                    print(f"[WARNING] skip {node}: {sample_path} not found.")
1620                    continue
1621
1622                try:
1623                    sampled_vals = torch.load(sample_path, map_location=device).cpu().numpy()
1624                except Exception as e:
1625                    print(f"[ERROR] Could not load {sample_path}: {e}")
1626                    continue
1627
1628            # If logits/probabilities per sample, take argmax
1629            if sampled_vals.ndim == 2:
1630                    print(f"[INFO] CAUTION! {node}: samples are probabilistic — each sample follows a probability "
1631                    f"distribution based on the valid latent range. "
1632                    f"Note that this frequency plot reflects only the distribution of the most probable "
1633                    f"class per sample.")
1634                    sampled_vals = np.argmax(sampled_vals, axis=1)
1635
1636            sampled_vals = sampled_vals[np.isfinite(sampled_vals)]
1637
1638            if node not in df.columns:
1639                print(f"[WARNING] skip {node}: column not found in DataFrame.")
1640                continue
1641
1642            true_vals = df[node].dropna().values
1643            true_vals = true_vals[np.isfinite(true_vals)]
1644
1645            if sampled_vals.size == 0 or true_vals.size == 0:
1646                print(f"[WARNING] skip {node}: empty array after NaN/Inf removal.")
1647                continue
1648
1649            fig, axs = plt.subplots(1, 2, figsize=figsize)
1650
1651            if is_outcome_modelled_continous(node, target_nodes):
1652                axs[0].hist(true_vals, bins=bins, density=True, alpha=0.6,
1653                            color=hist_true_color, label=f"True {node}")
1654                axs[0].hist(sampled_vals, bins=bins, density=True, alpha=0.6,
1655                            color=hist_est_color, label="Sampled")
1656                axs[0].set_xlabel("Value")
1657                axs[0].set_ylabel("Density")
1658                axs[0].set_title(f"Histogram overlay for {node}")
1659                axs[0].legend()
1660                axs[0].grid(True, ls="--", alpha=0.4)
1661
1662                qqplot_2samples(true_vals, sampled_vals, line="45", ax=axs[1])
1663                axs[1].set_xlabel("True quantiles")
1664                axs[1].set_ylabel("Sampled quantiles")
1665                axs[1].set_title(f"QQ plot for {node}")
1666                axs[1].grid(True, ls="--", alpha=0.4)
1667
1668            elif is_outcome_modelled_ordinal(node, target_nodes):
1669                unique_vals = np.union1d(np.unique(true_vals), np.unique(sampled_vals))
1670                unique_vals = np.sort(unique_vals)
1671                true_counts = np.array([(true_vals == val).sum() for val in unique_vals])
1672                sampled_counts = np.array([(sampled_vals == val).sum() for val in unique_vals])
1673
1674                axs[0].bar(unique_vals - 0.2, true_counts / true_counts.sum(),
1675                        width=0.4, color=hist_true_color, alpha=0.7, label="True")
1676                axs[0].bar(unique_vals + 0.2, sampled_counts / sampled_counts.sum(),
1677                        width=0.4, color=hist_est_color, alpha=0.7, label="Sampled")
1678                axs[0].set_xticks(unique_vals)
1679                axs[0].set_xlabel("Ordinal Level")
1680                axs[0].set_ylabel("Relative Frequency")
1681                axs[0].set_title(f"Ordinal bar plot for {node}")
1682                axs[0].legend()
1683                axs[0].grid(True, ls="--", alpha=0.4)
1684                axs[1].axis("off")
1685
1686            else:
1687                unique_vals = np.union1d(np.unique(true_vals), np.unique(sampled_vals))
1688                unique_vals = sorted(unique_vals, key=str)
1689                true_counts = np.array([(true_vals == val).sum() for val in unique_vals])
1690                sampled_counts = np.array([(sampled_vals == val).sum() for val in unique_vals])
1691
1692                axs[0].bar(np.arange(len(unique_vals)) - 0.2, true_counts / true_counts.sum(),
1693                        width=0.4, color=hist_true_color, alpha=0.7, label="True")
1694                axs[0].bar(np.arange(len(unique_vals)) + 0.2, sampled_counts / sampled_counts.sum(),
1695                        width=0.4, color=hist_est_color, alpha=0.7, label="Sampled")
1696                axs[0].set_xticks(np.arange(len(unique_vals)))
1697                axs[0].set_xticklabels(unique_vals, rotation=45)
1698                axs[0].set_xlabel("Category")
1699                axs[0].set_ylabel("Relative Frequency")
1700                axs[0].set_title(f"Categorical bar plot for {node}")
1701                axs[0].legend()
1702                axs[0].grid(True, ls="--", alpha=0.4)
1703                axs[1].axis("off")
1704
1705            plt.tight_layout()
1706            plt.show()
1707
1708    ## SUMMARY METHODS
1709    def nll(self,data,variables=None):
1710        """
1711        Compute the Negative Log-Likelihood (NLL) for all or selected TRAM nodes.
1712
1713        This function evaluates trained TRAM models for each specified variable (node) 
1714        on the provided dataset. It performs forward passes only—no training, no weight 
1715        updates—and returns the mean NLL per node.
1716
1717        Parameters
1718        ----------
1719        data : object
1720            Input dataset or data source compatible with `_ensure_dataset`, containing 
1721            both inputs and targets for each node.
1722        variables : list[str], optional
1723            List of variable (node) names to evaluate. If None, all nodes in 
1724            `self.models` are evaluated.
1725
1726        Returns
1727        -------
1728        dict[str, float]
1729            Dictionary mapping each node name to its average NLL value.
1730
1731        Notes
1732        -----
1733        - Each model is evaluated independently on its respective DataLoader.
1734        - The normalization values (`min_max`) for each node are retrieved from 
1735          `self.minmax_dict[node]`.
1736        - The function uses `evaluate_tramdag_model()` for per-node evaluation.
1737        - Expected directory structure:
1738              `<EXPERIMENT_DIR>/<node>/`
1739          where each node directory contains the trained model.
1740        """
1741
1742        td_data = self._ensure_dataset(data, is_val=True)  
1743        variables_list = variables if variables != None else list(self.models.keys())
1744        nll_dict = {}
1745        for node in variables_list:  
1746                min_vals = torch.tensor(self.minmax_dict[node][0], dtype=torch.float32)
1747                max_vals = torch.tensor(self.minmax_dict[node][1], dtype=torch.float32)
1748                min_max = torch.stack([min_vals, max_vals], dim=0)
1749                data_loader = td_data.loaders[node]
1750                model = self.models[node]
1751                EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
1752                NODE_DIR = os.path.join(EXPERIMENT_DIR, f"{node}")
1753                nll= evaluate_tramdag_model(node=node,
1754                                            target_nodes=self.nodes_dict,
1755                                            NODE_DIR=NODE_DIR,
1756                                            tram_model=model,
1757                                            data_loader=data_loader,
1758                                            min_max=min_max)
1759                nll_dict[node]=nll
1760        return nll_dict
1761    
1762    def get_train_val_nll(self, node: str, mode: str) -> tuple[float, float]:
1763        """
1764        Retrieve training and validation NLL for a node and a given model state.
1765
1766        Parameters
1767        ----------
1768        node : str
1769            Node name.
1770        mode : {"best", "last", "init"}
1771            State of interest:
1772            - "best": epoch with lowest validation NLL.
1773            - "last": final epoch.
1774            - "init": first epoch (index 0).
1775
1776        Returns
1777        -------
1778        tuple of (float or None, float or None)
1779            A tuple ``(train_nll, val_nll)`` for the requested mode.
1780            Returns ``(None, None)`` if loss files are missing or cannot be read.
1781
1782        Notes
1783        -----
1784        This method expects per-node JSON files:
1785
1786        - ``train_loss_hist.json``
1787        - ``val_loss_hist.json``
1788
1789        in the node directory.
1790        """
1791        NODE_DIR = os.path.join(self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"], node)
1792        train_path = os.path.join(NODE_DIR, "train_loss_hist.json")
1793        val_path = os.path.join(NODE_DIR, "val_loss_hist.json")
1794
1795        if not os.path.exists(train_path) or not os.path.exists(val_path):
1796            if getattr(self, "debug", False):
1797                print(f"[DEBUG] Missing loss files for node '{node}'. Returning None.")
1798            return None, None
1799
1800        try:
1801            with open(train_path, "r") as f:
1802                train_hist = json.load(f)
1803            with open(val_path, "r") as f:
1804                val_hist = json.load(f)
1805
1806            train_nlls = np.array(train_hist)
1807            val_nlls = np.array(val_hist)
1808
1809            if mode == "init":
1810                idx = 0
1811            elif mode == "last":
1812                idx = len(val_nlls) - 1
1813            elif mode == "best":
1814                idx = int(np.argmin(val_nlls))
1815            else:
1816                raise ValueError(f"Invalid mode '{mode}' — must be one of 'best', 'last', 'init'.")
1817
1818            train_nll = float(train_nlls[idx])
1819            val_nll = float(val_nlls[idx])
1820            return train_nll, val_nll
1821
1822        except Exception as e:
1823            print(f"[ERROR] Failed to load NLLs for node '{node}' ({mode}): {e}")
1824            return None, None
1825
1826    def get_thetas(self, node: str, state: str = "best"):
1827        """
1828        Return transformed intercept (theta) parameters for a node and state.
1829
1830        Parameters
1831        ----------
1832        node : str
1833            Node name.
1834        state : {"best", "last", "init"}, optional
1835            Model state for which to return parameters. Default is "best".
1836
1837        Returns
1838        -------
1839        Any or None
1840            Transformed theta parameters for the requested node and state.
1841            The exact structure (scalar, list, or other) depends on the model.
1842
1843        Raises
1844        ------
1845        ValueError
1846            If an invalid state is given (not in {"best", "last", "init"}).
1847
1848        Notes
1849        -----
1850        Intercept dictionaries are cached on the instance under the attribute
1851        ``intercept_dicts``. If missing or incomplete, they are recomputed using
1852        `get_simple_intercepts_dict`.
1853        """
1854
1855        state = state.lower()
1856        if state not in ["best", "last", "init"]:
1857            raise ValueError(f"[ERROR] Invalid state '{state}'. Must be one of ['best', 'last', 'init'].")
1858
1859        dict_attr = "intercept_dicts"
1860
1861        # If no cached intercepts exist, compute them
1862        if not hasattr(self, dict_attr):
1863            if getattr(self, "debug", False):
1864                print(f"[DEBUG] '{dict_attr}' not found, computing via get_simple_intercepts_dict().")
1865            setattr(self, dict_attr, self.get_simple_intercepts_dict())
1866
1867        all_dicts = getattr(self, dict_attr)
1868
1869        # If the requested state isn’t cached, recompute
1870        if state not in all_dicts:
1871            if getattr(self, "debug", False):
1872                print(f"[DEBUG] State '{state}' not found in cached intercepts, recomputing full dict.")
1873            setattr(self, dict_attr, self.get_simple_intercepts_dict())
1874            all_dicts = getattr(self, dict_attr)
1875
1876        state_dict = all_dicts.get(state, {})
1877
1878        # Return cached node intercept if present
1879        if node in state_dict:
1880            return state_dict[node]
1881
1882        # If not found, recompute full dict as fallback
1883        if getattr(self, "debug", False):
1884            print(f"[DEBUG] Node '{node}' not found in state '{state}', recomputing full dict.")
1885        setattr(self, dict_attr, self.get_simple_intercepts_dict())
1886        all_dicts = getattr(self, dict_attr)
1887        return all_dicts.get(state, {}).get(node, None)
1888        
1889    def get_linear_shifts(self, node: str, state: str = "best"):
1890        """
1891        Return learned linear shift terms for a node and a given state.
1892
1893        Parameters
1894        ----------
1895        node : str
1896            Node name.
1897        state : {"best", "last", "init"}, optional
1898            Model state for which to return linear shift terms. Default is "best".
1899
1900        Returns
1901        -------
1902        dict or Any or None
1903            Linear shift terms for the given node and state. Usually a dict
1904            mapping term names to weights.
1905
1906        Raises
1907        ------
1908        ValueError
1909            If an invalid state is given (not in {"best", "last", "init"}).
1910
1911        Notes
1912        -----
1913        Linear shift dictionaries are cached on the instance under the attribute
1914        ``linear_shift_dicts``. If missing or incomplete, they are recomputed using
1915        `get_linear_shifts_dict`.
1916        """
1917        state = state.lower()
1918        if state not in ["best", "last", "init"]:
1919            raise ValueError(f"[ERROR] Invalid state '{state}'. Must be one of ['best', 'last', 'init'].")
1920
1921        dict_attr = "linear_shift_dicts"
1922
1923        # If no global dicts cached, compute once
1924        if not hasattr(self, dict_attr):
1925            if getattr(self, "debug", False):
1926                print(f"[DEBUG] '{dict_attr}' not found, computing via get_linear_shifts_dict().")
1927            setattr(self, dict_attr, self.get_linear_shifts_dict())
1928
1929        all_dicts = getattr(self, dict_attr)
1930
1931        # If the requested state isn't cached, compute all again (covers fresh runs)
1932        if state not in all_dicts:
1933            if getattr(self, "debug", False):
1934                print(f"[DEBUG] State '{state}' not found in cached linear shifts, recomputing full dict.")
1935            setattr(self, dict_attr, self.get_linear_shifts_dict())
1936            all_dicts = getattr(self, dict_attr)
1937
1938        # Now fetch the dictionary for this state
1939        state_dict = all_dicts.get(state, {})
1940
1941        # If the node is available, return its entry
1942        if node in state_dict:
1943            return state_dict[node]
1944
1945        # If missing, try recomputing (fallback)
1946        if getattr(self, "debug", False):
1947            print(f"[DEBUG] Node '{node}' not found in state '{state}', recomputing full dict.")
1948        setattr(self, dict_attr, self.get_linear_shifts_dict())
1949        all_dicts = getattr(self, dict_attr)
1950        return all_dicts.get(state, {}).get(node, None)
1951
1952    def get_linear_shifts_dict(self):
1953        """
1954        Compute linear shift term dictionaries for all nodes and states.
1955
1956        For each node and each available state ("best", "last", "init"), this
1957        method loads the corresponding model checkpoint, extracts linear shift
1958        weights from the TRAM model, and stores them in a nested dictionary.
1959
1960        Returns
1961        -------
1962        dict
1963            Nested dictionary of the form:
1964
1965            .. code-block:: python
1966
1967                {
1968                    "best": {node: {...}},
1969                    "last": {node: {...}},
1970                    "init": {node: {...}},
1971                }
1972
1973            where the innermost dict maps term labels (e.g. ``"ls(parent_name)"``)
1974            to their weights.
1975
1976        Notes
1977        -----
1978        - If "best" or "last" checkpoints are unavailable for a node, only
1979        the "init" entry is populated.
1980        - Empty outer states (without any nodes) are removed from the result.
1981        """
1982
1983        EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
1984        nodes_list = list(self.models.keys())
1985        all_states = ["best", "last", "init"]
1986        all_linear_shift_dicts = {state: {} for state in all_states}
1987
1988        for node in nodes_list:
1989            NODE_DIR = os.path.join(EXPERIMENT_DIR, node)
1990            BEST_MODEL_PATH, LAST_MODEL_PATH, _, _ = model_train_val_paths(NODE_DIR)
1991            INIT_MODEL_PATH = os.path.join(NODE_DIR, "initial_model.pt")
1992
1993            state_paths = {
1994                "best": BEST_MODEL_PATH,
1995                "last": LAST_MODEL_PATH,
1996                "init": INIT_MODEL_PATH,
1997            }
1998
1999            for state, LOAD_PATH in state_paths.items():
2000                if not os.path.exists(LOAD_PATH):
2001                    if state != "init":
2002                        # skip best/last if unavailable
2003                        continue
2004                    else:
2005                        print(f"[WARNING] No models found for node '{node}'. Only initial model will be used.")
2006                        if not os.path.exists(LOAD_PATH):
2007                            if getattr(self, "debug", False):
2008                                print(f"[DEBUG] Initial model also missing for node '{node}'. Skipping.")
2009                            continue
2010
2011                # Load parents and model
2012                _, terms_dict, _ = ordered_parents(node, self.nodes_dict)
2013                state_dict = torch.load(LOAD_PATH, map_location=self.device)
2014                tram_model = self.models[node]
2015                tram_model.load_state_dict(state_dict)
2016
2017                epoch_weights = {}
2018                if hasattr(tram_model, "nn_shift") and tram_model.nn_shift is not None:
2019                    for i, shift_layer in enumerate(tram_model.nn_shift):
2020                        module_name = shift_layer.__class__.__name__
2021                        if (
2022                            hasattr(shift_layer, "fc")
2023                            and hasattr(shift_layer.fc, "weight")
2024                            and module_name == "LinearShift"
2025                        ):
2026                            term_name = list(terms_dict.keys())[i]
2027                            epoch_weights[f"ls({term_name})"] = (
2028                                shift_layer.fc.weight.detach().cpu().squeeze().tolist()
2029                            )
2030                        elif getattr(self, "debug", False):
2031                            term_name = list(terms_dict.keys())[i]
2032                            print(f"[DEBUG] ls({term_name}): missing 'fc' or 'weight' in LinearShift.")
2033                else:
2034                    if getattr(self, "debug", False):
2035                        print(f"[DEBUG] Tram model for node '{node}' has no nn_shift or it is None.")
2036
2037                all_linear_shift_dicts[state][node] = epoch_weights
2038
2039        # Remove empty states (e.g., when best/last not found for all nodes)
2040        all_linear_shift_dicts = {k: v for k, v in all_linear_shift_dicts.items() if v}
2041
2042        return all_linear_shift_dicts
2043
2044    def get_simple_intercepts_dict(self):
2045        """
2046        Compute transformed simple intercept dictionaries for all nodes and states.
2047
2048        For each node and each available state ("best", "last", "init"), this
2049        method loads the corresponding model checkpoint, extracts simple intercept
2050        weights, transforms them into interpretable theta parameters, and stores
2051        them in a nested dictionary.
2052
2053        Returns
2054        -------
2055        dict
2056            Nested dictionary of the form:
2057
2058            .. code-block:: python
2059
2060                {
2061                    "best": {node: [[theta_1], [theta_2], ...]},
2062                    "last": {node: [[theta_1], [theta_2], ...]},
2063                    "init": {node: [[theta_1], [theta_2], ...]},
2064                }
2065
2066        Notes
2067        -----
2068        - For ordinal models (``self.is_ontram == True``), `transform_intercepts_ordinal`
2069        is used.
2070        - For continuous models, `transform_intercepts_continous` is used.
2071        - Empty outer states (without any nodes) are removed from the result.
2072        """
2073
2074        EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
2075        nodes_list = list(self.models.keys())
2076        all_states = ["best", "last", "init"]
2077        all_si_intercept_dicts = {state: {} for state in all_states}
2078
2079        debug = getattr(self, "debug", False)
2080        verbose = getattr(self, "verbose", False)
2081        is_ontram = getattr(self, "is_ontram", False)
2082
2083        for node in nodes_list:
2084            NODE_DIR = os.path.join(EXPERIMENT_DIR, node)
2085            BEST_MODEL_PATH, LAST_MODEL_PATH, _, _ = model_train_val_paths(NODE_DIR)
2086            INIT_MODEL_PATH = os.path.join(NODE_DIR, "initial_model.pt")
2087
2088            state_paths = {
2089                "best": BEST_MODEL_PATH,
2090                "last": LAST_MODEL_PATH,
2091                "init": INIT_MODEL_PATH,
2092            }
2093
2094            for state, LOAD_PATH in state_paths.items():
2095                if not os.path.exists(LOAD_PATH):
2096                    if state != "init":
2097                        continue
2098                    else:
2099                        print(f"[WARNING] No models found for node '{node}'. Only initial model will be used.")
2100                        if not os.path.exists(LOAD_PATH):
2101                            if debug:
2102                                print(f"[DEBUG] Initial model also missing for node '{node}'. Skipping.")
2103                            continue
2104
2105                # Load model state
2106                state_dict = torch.load(LOAD_PATH, map_location=self.device)
2107                tram_model = self.models[node]
2108                tram_model.load_state_dict(state_dict)
2109
2110                # Extract and transform simple intercept weights
2111                si_weights = None
2112                if hasattr(tram_model, "nn_int") and tram_model.nn_int is not None and isinstance(tram_model.nn_int, SimpleIntercept):
2113                    if hasattr(tram_model.nn_int, "fc") and hasattr(tram_model.nn_int.fc, "weight"):
2114                        weights = tram_model.nn_int.fc.weight.detach().cpu().tolist()
2115                        weights_tensor = torch.Tensor(weights)
2116
2117                        if debug:
2118                            print(f"[DEBUG] Node '{node}' ({state}) theta tilde shape: {weights_tensor.shape}")
2119
2120                        if is_ontram:
2121                            si_weights = transform_intercepts_ordinal(weights_tensor.reshape(1, -1))[:, 1:-1].reshape(-1, 1)
2122                        else:
2123                            si_weights = transform_intercepts_continous(weights_tensor.reshape(1, -1)).reshape(-1, 1)
2124
2125                        si_weights = si_weights.tolist()
2126
2127                        if debug:
2128                            print(f"[DEBUG] Node '{node}' ({state}) theta transformed: {si_weights}")
2129                    else:
2130                        if debug:
2131                            print(f"[DEBUG] Node '{node}' ({state}): missing 'fc' or 'weight' in SimpleIntercept.")
2132                else:
2133                    if debug:
2134                        print(f"[DEBUG] Tram model for node '{node}' has no nn_int or it is None.")
2135
2136                all_si_intercept_dicts[state][node] = si_weights
2137
2138        # Clean up empty states
2139        all_si_intercept_dicts = {k: v for k, v in all_si_intercept_dicts.items() if v}
2140        return all_si_intercept_dicts
2141       
2142    def summary(self, verbose=False):
2143        """
2144        Print a multi-part textual summary of the TramDagModel.
2145
2146        The summary includes:
2147        1. Training metrics overview per node (best/last NLL, epochs).
2148        2. Node-specific details (thetas, linear shifts, optional architecture).
2149        3. Basic information about the attached training DataFrame, if present.
2150
2151        Parameters
2152        ----------
2153        verbose : bool, optional
2154            If True, include extended per-node details such as the model
2155            architecture, parameter count, and availability of checkpoints
2156            and sampling results. Default is False.
2157
2158        Returns
2159        -------
2160        None
2161
2162        Notes
2163        -----
2164        This method prints to stdout and does not return structured data.
2165        It is intended for quick, human-readable inspection of the current
2166        training and model state.
2167        """
2168
2169        # ---------- SETUP ----------
2170        try:
2171            EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
2172        except KeyError:
2173            EXPERIMENT_DIR = None
2174            print("[WARNING] Missing EXPERIMENT_DIR in cfg.conf_dict['PATHS'].")
2175
2176        print("\n" + "=" * 120)
2177        print(f"{'TRAM DAG MODEL SUMMARY':^120}")
2178        print("=" * 120)
2179
2180        # ---------- METRICS OVERVIEW ----------
2181        summary_data = []
2182        for node in self.models.keys():
2183            node_dir = os.path.join(self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"], node)
2184            train_path = os.path.join(node_dir, "train_loss_hist.json")
2185            val_path = os.path.join(node_dir, "val_loss_hist.json")
2186
2187            if os.path.exists(train_path) and os.path.exists(val_path):
2188                best_train_nll, best_val_nll = self.get_train_val_nll(node, "best")
2189                last_train_nll, last_val_nll = self.get_train_val_nll(node, "last")
2190                n_epochs_total = len(json.load(open(train_path)))
2191            else:
2192                best_train_nll = best_val_nll = last_train_nll = last_val_nll = None
2193                n_epochs_total = 0
2194
2195            summary_data.append({
2196                "Node": node,
2197                "Best Train NLL": best_train_nll,
2198                "Best Val NLL": best_val_nll,
2199                "Last Train NLL": last_train_nll,
2200                "Last Val NLL": last_val_nll,
2201                "Epochs": n_epochs_total,
2202            })
2203
2204        df_summary = pd.DataFrame(summary_data)
2205        df_summary = df_summary.round(4)
2206
2207        print("\n[1] TRAINING METRICS OVERVIEW")
2208        print("-" * 120)
2209        if not df_summary.empty:
2210            print(
2211                df_summary.to_string(
2212                    index=False,
2213                    justify="center",
2214                    col_space=14,
2215                    float_format=lambda x: f"{x:7.4f}",
2216                )
2217            )
2218        else:
2219            print("No training history found for any node.")
2220        print("-" * 120)
2221
2222        # ---------- NODE DETAILS ----------
2223        print("\n[2] NODE-SPECIFIC DETAILS")
2224        print("-" * 120)
2225        for node in self.models.keys():
2226            print(f"\n{f'NODE: {node}':^120}")
2227            print("-" * 120)
2228
2229            # THETAS & SHIFTS
2230            for state in ["init", "last", "best"]:
2231                print(f"\n  [{state.upper()} STATE]")
2232
2233                # ---- Thetas ----
2234                try:
2235                    thetas = getattr(self, "get_thetas", lambda n, s=None: None)(node, state)
2236                    if thetas is not None:
2237                        if isinstance(thetas, (list, np.ndarray, pd.Series)):
2238                            thetas_flat = np.array(thetas).flatten()
2239                            compact = np.round(thetas_flat, 4)
2240                            arr_str = np.array2string(
2241                                compact,
2242                                max_line_width=110,
2243                                threshold=np.inf,
2244                                separator=", "
2245                            )
2246                            lines = arr_str.split("\n")
2247                            if len(lines) > 2:
2248                                arr_str = "\n".join(lines[:2]) + " ..."
2249                            print(f"    Θ ({len(thetas_flat)}): {arr_str}")
2250                        elif isinstance(thetas, dict):
2251                            for k, v in thetas.items():
2252                                print(f"     Θ[{k}]: {v}")
2253                        else:
2254                            print(f"    Θ: {thetas}")
2255                    else:
2256                        print("    Θ: not available")
2257                except Exception as e:
2258                    print(f"    [Error loading thetas] {e}")
2259
2260                # ---- Linear Shifts ----
2261                try:
2262                    linear_shifts = getattr(self, "get_linear_shifts", lambda n, s=None: None)(node, state)
2263                    if linear_shifts is not None:
2264                        if isinstance(linear_shifts, dict):
2265                            for k, v in linear_shifts.items():
2266                                print(f"     {k}: {np.round(v, 4)}")
2267                        elif isinstance(linear_shifts, (list, np.ndarray, pd.Series)):
2268                            arr = np.round(linear_shifts, 4)
2269                            print(f"    Linear shifts ({len(arr)}): {arr}")
2270                        else:
2271                            print(f"    Linear shifts: {linear_shifts}")
2272                    else:
2273                        print("    Linear shifts: not available")
2274                except Exception as e:
2275                    print(f"    [Error loading linear shifts] {e}")
2276
2277            # ---- Verbose info directly below node ----
2278            if verbose:
2279                print("\n  [DETAILS]")
2280                node_dir = os.path.join(EXPERIMENT_DIR, node) if EXPERIMENT_DIR else None
2281                model = self.models[node]
2282
2283                print(f"    Model Architecture:")
2284                arch_str = str(model).split("\n")
2285                for line in arch_str:
2286                    print(f"      {line}")
2287                print(f"    Parameter count: {sum(p.numel() for p in model.parameters()):,}")
2288
2289                if node_dir and os.path.exists(node_dir):
2290                    ckpt_exists = any(f.endswith(('.pt', '.pth')) for f in os.listdir(node_dir))
2291                    print(f"    Checkpoints found: {ckpt_exists}")
2292
2293                    sampling_dir = os.path.join(node_dir, "sampling")
2294                    sampling_exists = os.path.isdir(sampling_dir) and len(os.listdir(sampling_dir)) > 0
2295                    print(f"    Sampling results found: {sampling_exists}")
2296
2297                    for label, filename in [("Train", "train_loss_hist.json"), ("Validation", "val_loss_hist.json")]:
2298                        path = os.path.join(node_dir, filename)
2299                        if os.path.exists(path):
2300                            try:
2301                                with open(path, "r") as f:
2302                                    hist = json.load(f)
2303                                print(f"    {label} history: {len(hist)} epochs")
2304                            except Exception as e:
2305                                print(f"    {label} history: failed to load ({e})")
2306                        else:
2307                            print(f"    {label} history: not found")
2308                else:
2309                    print("    [INFO] No experiment directory defined or missing for this node.")
2310            print("-" * 120)
2311
2312        # ---------- TRAINING DATAFRAME ----------
2313        print("\n[3] TRAINING DATAFRAME")
2314        print("-" * 120)
2315        try:
2316            self.train_df.info()
2317        except AttributeError:
2318            print("No training DataFrame attached to this TramDagModel.")
2319        print("=" * 120 + "\n")

Probabilistic DAG model built from node-wise TRAMs (transformation models).

This class manages:

  • Configuration and per-node model construction.
  • Data scaling (min–max).
  • Training (sequential or per-node parallel on CPU).
  • Diagnostics (loss history, intercepts, linear shifts, latents).
  • Sampling from the joint DAG and loading stored samples.
  • High-level summaries and plotting utilities.
TramDagModel()
111    def __init__(self):
112        """
113        Initialize an empty TramDagModel shell.
114
115        Notes
116        -----
117        This constructor does not build any node models and does not attach a
118        configuration. Use `TramDagModel.from_config` or `TramDagModel.from_directory`
119        to obtain a fully configured and ready-to-use instance.
120        """
121        
122        self.debug = False
123        self.verbose = False
124        self.device = 'auto'
125        pass

Initialize an empty TramDagModel shell.

Notes

This constructor does not build any node models and does not attach a configuration. Use TramDagModel.from_config or TramDagModel.from_directory to obtain a fully configured and ready-to-use instance.

DEFAULTS_CONFIG = {'set_initial_weights': False, 'debug': False, 'verbose': False, 'device': 'auto', 'initial_data': None, 'overwrite_initial_weights': True}
DEFAULTS_FIT = {'epochs': 100, 'train_list': None, 'callbacks': None, 'learning_rate': 0.01, 'device': 'auto', 'optimizers': None, 'schedulers': None, 'use_scheduler': False, 'save_linear_shifts': True, 'save_simple_intercepts': True, 'debug': False, 'verbose': True, 'train_mode': 'sequential', 'return_history': False, 'overwrite_inital_weights': True, 'num_workers': 4, 'persistent_workers': True, 'prefetch_factor': 4, 'batch_size': 1000}
debug
verbose
device
@staticmethod
def get_device(settings):
127    @staticmethod
128    def get_device(settings):
129        """
130        Resolve the target device string from a settings dictionary.
131
132        Parameters
133        ----------
134        settings : dict
135            Dictionary containing at least a key ``"device"`` with one of
136            {"auto", "cpu", "cuda"}. If missing, "auto" is assumed.
137
138        Returns
139        -------
140        str
141            Device string, either "cpu" or "cuda".
142
143        Notes
144        -----
145        If ``device == "auto"``, CUDA is selected if available, otherwise CPU.
146        """
147        device_arg = settings.get("device", "auto")
148        if device_arg == "auto":
149            device_str = "cuda" if torch.cuda.is_available() else "cpu"
150        else:
151            device_str = device_arg
152        return device_str

Resolve the target device string from a settings dictionary.

Parameters

settings : dict Dictionary containing at least a key "device" with one of {"auto", "cpu", "cuda"}. If missing, "auto" is assumed.

Returns

str Device string, either "cpu" or "cuda".

Notes

If device == "auto", CUDA is selected if available, otherwise CPU.

@classmethod
def from_config(cls, cfg, **kwargs):
185    @classmethod
186    def from_config(cls, cfg, **kwargs):
187        """
188        Construct a TramDagModel from a TramDagConfig object.
189
190        This builds one TRAM model per node in the DAG and optionally writes
191        the initial model parameters to disk.
192
193        Parameters
194        ----------
195        cfg : TramDagConfig
196            Configuration wrapper holding the underlying configuration dictionary,
197            including at least:
198            - ``conf_dict["nodes"]``: mapping of node names to node configs.
199            - ``conf_dict["PATHS"]["EXPERIMENT_DIR"]``: experiment directory.
200        **kwargs
201            Node-level construction options. Each key must be present in
202            ``DEFAULTS_CONFIG``. Values can be:
203            - scalar: applied to all nodes.
204            - dict: mapping ``{node_name: value}`` for per-node overrides.
205
206            Common keys include:
207            device : {"auto", "cpu", "cuda"}, default "auto"
208                Device selection (CUDA if available when "auto").
209            debug : bool, default False
210                If True, print debug messages.
211            verbose : bool, default False
212                If True, print informational messages.
213            set_initial_weights : bool
214                Passed to underlying TRAM model constructors.
215            overwrite_initial_weights : bool, default True
216                If True, overwrite any existing ``initial_model.pt`` files per node.
217            initial_data : Any
218                Optional object passed down to node constructors.
219
220        Returns
221        -------
222        TramDagModel
223            Fully initialized instance with:
224            - ``cfg``
225            - ``nodes_dict``
226            - ``models`` (per-node TRAMs)
227            - ``settings`` (resolved per-node config)
228
229        Raises
230        ------
231        ValueError
232            If any dict-valued kwarg does not provide values for exactly the set
233            of nodes in ``cfg.conf_dict["nodes"]``.
234        """
235        
236        self = cls()
237        self.cfg = cfg
238        self.cfg.update()  # ensure latest version from disk
239        self.cfg._verify_completeness()
240        
241        
242        try:
243            self.cfg.save()  # persist back to disk
244            if getattr(self, "debug", False):
245                print("[DEBUG] Configuration updated and saved.")
246        except Exception as e:
247            print(f"[WARNING] Could not save configuration after update: {e}")        
248            
249        self.nodes_dict = self.cfg.conf_dict["nodes"] 
250
251        self._validate_kwargs(kwargs, defaults_attr='DEFAULTS_CONFIG', context="from_config")
252
253        # update defaults with kwargs
254        settings = dict(cls.DEFAULTS_CONFIG)
255        settings.update(kwargs)
256
257        # resolve device
258        device_arg = settings.get("device", "auto")
259        if device_arg == "auto":
260            device_str = "cuda" if torch.cuda.is_available() else "cpu"
261        else:
262            device_str = device_arg
263        self.device = torch.device(device_str)
264
265        # set flags on the instance so they are accessible later
266        self.debug = settings.get("debug", False)
267        self.verbose = settings.get("verbose", False)
268
269        if  self.debug:
270            print(f"[DEBUG] TramDagModel using device: {self.device}")
271            
272        # initialize settings storage
273        self.settings = {k: {} for k in settings.keys()}
274
275        # validate dict-typed args
276        for k, v in settings.items():
277            if isinstance(v, dict):
278                expected = set(self.nodes_dict.keys())
279                given = set(v.keys())
280                if expected != given:
281                    raise ValueError(
282                        f"[ERROR] the provided argument '{k}' keys are not same as in cfg.conf_dict['nodes'].keys().\n"
283                        f"Expected: {expected}, but got: {given}\n"
284                        f"Please provide values for all variables.")
285
286        # build one model per node
287        self.models = {}
288        for node in self.nodes_dict.keys():
289            per_node_kwargs = {}
290            for k, v in settings.items():
291                resolved = v[node] if isinstance(v, dict) else v
292                per_node_kwargs[k] = resolved
293                self.settings[k][node] = resolved
294            if self.debug:
295                print(f"\n[INFO] Building model for node '{node}' with settings: {per_node_kwargs}")
296            self.models[node] = get_fully_specified_tram_model(
297                node=node,
298                configuration_dict=self.cfg.conf_dict,
299                **per_node_kwargs)
300            
301            try:
302                EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
303                NODE_DIR = os.path.join(EXPERIMENT_DIR, f"{node}")
304                os.makedirs(NODE_DIR, exist_ok=True)
305
306                model_path = os.path.join(NODE_DIR, "initial_model.pt")
307                overwrite = settings.get("overwrite_initial_weights", True)
308
309                if overwrite or not os.path.exists(model_path):
310                    torch.save(self.models[node].state_dict(), model_path)
311                    if self.debug:
312                        print(f"[DEBUG] Saved initial model state for node '{node}' to {model_path} (overwrite={overwrite})")
313                else:
314                    if self.debug:
315                        print(f"[DEBUG] Skipped saving initial model for node '{node}' (already exists at {model_path})")
316            except Exception as e:
317                print(f"[ERROR] Could not save initial model state for node '{node}': {e}")
318            
319            TEMP_DIR = "temp"
320            if os.path.isdir(TEMP_DIR) and not os.listdir(TEMP_DIR):
321                os.rmdir(TEMP_DIR)
322                            
323        return self

Construct a TramDagModel from a TramDagConfig object.

This builds one TRAM model per node in the DAG and optionally writes the initial model parameters to disk.

Parameters

cfg : TramDagConfig Configuration wrapper holding the underlying configuration dictionary, including at least: - conf_dict["nodes"]: mapping of node names to node configs. - conf_dict["PATHS"]["EXPERIMENT_DIR"]: experiment directory. **kwargs Node-level construction options. Each key must be present in DEFAULTS_CONFIG. Values can be: - scalar: applied to all nodes. - dict: mapping {node_name: value} for per-node overrides.

Common keys include:
device : {"auto", "cpu", "cuda"}, default "auto"
    Device selection (CUDA if available when "auto").
debug : bool, default False
    If True, print debug messages.
verbose : bool, default False
    If True, print informational messages.
set_initial_weights : bool
    Passed to underlying TRAM model constructors.
overwrite_initial_weights : bool, default True
    If True, overwrite any existing ``initial_model.pt`` files per node.
initial_data : Any
    Optional object passed down to node constructors.

Returns

TramDagModel Fully initialized instance with: - cfg - nodes_dict - models (per-node TRAMs) - settings (resolved per-node config)

Raises

ValueError If any dict-valued kwarg does not provide values for exactly the set of nodes in cfg.conf_dict["nodes"].

@classmethod
def from_directory( cls, EXPERIMENT_DIR: str, device: str = 'auto', debug: bool = False, verbose: bool = False):
325    @classmethod
326    def from_directory(cls, EXPERIMENT_DIR: str, device: str = "auto", debug: bool = False, verbose: bool = False):
327        """
328        Reconstruct a TramDagModel from an experiment directory on disk.
329
330        This method:
331        1. Loads the configuration JSON.
332        2. Wraps it in a TramDagConfig.
333        3. Builds all node models via `from_config`.
334        4. Loads the min–max scaling dictionary.
335
336        Parameters
337        ----------
338        EXPERIMENT_DIR : str
339            Path to an experiment directory containing:
340            - ``configuration.json``
341            - ``min_max_scaling.json``.
342        device : {"auto", "cpu", "cuda"}, optional
343            Device selection. Default is "auto".
344        debug : bool, optional
345            If True, enable debug messages. Default is False.
346        verbose : bool, optional
347            If True, enable informational messages. Default is False.
348
349        Returns
350        -------
351        TramDagModel
352            A TramDagModel instance with models, config, and scaling loaded.
353
354        Raises
355        ------
356        FileNotFoundError
357            If configuration or min–max files cannot be found.
358        RuntimeError
359            If the min–max file cannot be read or parsed.
360        """
361
362        # --- load config file ---
363        config_path = os.path.join(EXPERIMENT_DIR, "configuration.json")
364        if not os.path.exists(config_path):
365            raise FileNotFoundError(f"[ERROR] Config file not found at {config_path}")
366
367        with open(config_path, "r") as f:
368            cfg_dict = json.load(f)
369
370        # Create TramConfig wrapper 
371        cfg = TramDagConfig(cfg_dict, CONF_DICT_PATH=config_path)
372
373        # --- build model from config ---
374        self = cls.from_config(cfg, device=device, debug=debug, verbose=verbose, overwrite_initial_weights=False)
375
376        # --- load minmax scaling ---
377        minmax_path = os.path.join(EXPERIMENT_DIR, "min_max_scaling.json")
378        if not os.path.exists(minmax_path):
379            raise FileNotFoundError(f"[ERROR] MinMax file not found at {minmax_path}")
380
381        with open(minmax_path, "r") as f:
382            self.minmax_dict = json.load(f)
383
384        if self.verbose or self.debug:
385            print(f"[INFO] Loaded TramDagModel from {EXPERIMENT_DIR}")
386            print(f"[INFO] Config loaded from {config_path}")
387            print(f"[INFO] MinMax scaling loaded from {minmax_path}")
388
389        return self

Reconstruct a TramDagModel from an experiment directory on disk.

This method:

  1. Loads the configuration JSON.
  2. Wraps it in a TramDagConfig.
  3. Builds all node models via from_config.
  4. Loads the min–max scaling dictionary.

Parameters

EXPERIMENT_DIR : str Path to an experiment directory containing: - configuration.json - min_max_scaling.json. device : {"auto", "cpu", "cuda"}, optional Device selection. Default is "auto". debug : bool, optional If True, enable debug messages. Default is False. verbose : bool, optional If True, enable informational messages. Default is False.

Returns

TramDagModel A TramDagModel instance with models, config, and scaling loaded.

Raises

FileNotFoundError If configuration or min–max files cannot be found. RuntimeError If the min–max file cannot be read or parsed.

def load_or_compute_minmax(self, td_train_data=None, use_existing=False, write=True):
429    def load_or_compute_minmax(self, td_train_data=None,use_existing=False, write=True):
430        """
431        Load an existing Min–Max scaling dictionary from disk or compute a new one 
432        from the provided training dataset.
433
434        Parameters
435        ----------
436        use_existing : bool, optional (default=False)
437            If True, attempts to load an existing `min_max_scaling.json` file 
438            from the experiment directory. Raises an error if the file is missing 
439            or unreadable.
440
441        write : bool, optional (default=True)
442            If True, writes the computed Min–Max scaling dictionary to 
443            `<EXPERIMENT_DIR>/min_max_scaling.json`.
444
445        td_train_data : object, optional
446            Training dataset used to compute scaling statistics. If not provided,
447            the method will ensure or construct it via `_ensure_dataset(data=..., is_val=False)`.
448
449        Behavior
450        --------
451        - If `use_existing=True`, loads the JSON file containing previously saved 
452          min–max values and stores it in `self.minmax_dict`.
453        - If `use_existing=False`, computes a new scaling dictionary using 
454          `td_train_data.compute_scaling()` and stores the result in 
455          `self.minmax_dict`.
456        - Optionally writes the computed dictionary to disk.
457
458        Side Effects
459        -------------
460        - Populates `self.minmax_dict` with scaling values.
461        - Writes or loads the file `min_max_scaling.json` under 
462          `<EXPERIMENT_DIR>`.
463        - Prints diagnostic output if `self.debug` or `self.verbose` is True.
464
465        Raises
466        ------
467        FileNotFoundError
468            If `use_existing=True` but the min–max file does not exist.
469
470        RuntimeError
471            If an existing min–max file cannot be read or parsed.
472
473        Notes
474        -----
475        The computed min–max dictionary is expected to contain scaling statistics 
476        per feature, typically in the form:
477            {
478                "node": {"min": float, "max": float},
479                ...
480            }
481        """
482        EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
483        minmax_path = os.path.join(EXPERIMENT_DIR, "min_max_scaling.json")
484
485        # laod exisitng if possible
486        if use_existing:
487            if not os.path.exists(minmax_path):
488                raise FileNotFoundError(f"MinMax file not found: {minmax_path}")
489            try:
490                with open(minmax_path, 'r') as f:
491                    self.minmax_dict = json.load(f)
492                if self.debug or self.verbose:
493                    print(f"[INFO] Loaded existing minmax dict from {minmax_path}")
494                return
495            except Exception as e:
496                raise RuntimeError(f"Could not load existing minmax dict: {e}")
497
498        # 
499        if self.debug or self.verbose:
500            print("[INFO] Computing new minmax dict from training data...")
501            
502        td_train_data=self._ensure_dataset( data=td_train_data, is_val=False)    
503            
504        self.minmax_dict = td_train_data.compute_scaling()
505
506        if write:
507            os.makedirs(EXPERIMENT_DIR, exist_ok=True)
508            with open(minmax_path, 'w') as f:
509                json.dump(self.minmax_dict, f, indent=4)
510            if self.debug or self.verbose:
511                print(f"[INFO] Saved new minmax dict to {minmax_path}")

Load an existing Min–Max scaling dictionary from disk or compute a new one from the provided training dataset.

Parameters

use_existing : bool, optional (default=False) If True, attempts to load an existing min_max_scaling.json file from the experiment directory. Raises an error if the file is missing or unreadable.

write : bool, optional (default=True) If True, writes the computed Min–Max scaling dictionary to <EXPERIMENT_DIR>/min_max_scaling.json.

td_train_data : object, optional Training dataset used to compute scaling statistics. If not provided, the method will ensure or construct it via _ensure_dataset(data=..., is_val=False).

Behavior

  • If use_existing=True, loads the JSON file containing previously saved min–max values and stores it in self.minmax_dict.
  • If use_existing=False, computes a new scaling dictionary using td_train_data.compute_scaling() and stores the result in self.minmax_dict.
  • Optionally writes the computed dictionary to disk.

Side Effects

  • Populates self.minmax_dict with scaling values.
  • Writes or loads the file min_max_scaling.json under <EXPERIMENT_DIR>.
  • Prints diagnostic output if self.debug or self.verbose is True.

Raises

FileNotFoundError If use_existing=True but the min–max file does not exist.

RuntimeError If an existing min–max file cannot be read or parsed.

Notes

The computed min–max dictionary is expected to contain scaling statistics per feature, typically in the form: { "node": {"min": float, "max": float}, ... }

def fit(self, train_data, val_data=None, **kwargs):
612    def fit(self, train_data, val_data=None, **kwargs):
613        """
614        Train TRAM models for all nodes in the DAG.
615
616        Coordinates dataset preparation, min–max scaling, and per-node training,
617        optionally in parallel on CPU.
618
619        Parameters
620        ----------
621        train_data : pandas.DataFrame or TramDagDataset
622            Training data. If a DataFrame is given, it is converted into a
623            TramDagDataset using `_ensure_dataset`.
624        val_data : pandas.DataFrame or TramDagDataset or None, optional
625            Validation data. If a DataFrame is given, it is converted into a
626            TramDagDataset. If None, no validation loss is computed.
627        **kwargs
628            Overrides for ``DEFAULTS_FIT``. All keys must exist in
629            ``DEFAULTS_FIT``. Common options:
630
631            epochs : int, default 100
632                Number of training epochs per node.
633            learning_rate : float, default 0.01
634                Learning rate for the default Adam optimizer.
635            train_list : list of str or None, optional
636                List of node names to train. If None, all nodes are trained.
637            train_mode : {"sequential", "parallel"}, default "sequential"
638                Training mode. "parallel" uses joblib-based CPU multiprocessing.
639                GPU forces sequential mode.
640            device : {"auto", "cpu", "cuda"}, default "auto"
641                Device selection.
642            optimizers : dict or None
643                Optional mapping ``{node_name: optimizer}``. If provided for a
644                node, that optimizer is used instead of creating a new Adam.
645            schedulers : dict or None
646                Optional mapping ``{node_name: scheduler}``.
647            use_scheduler : bool
648                If True, enable scheduler usage in the training loop.
649            num_workers : int
650                DataLoader workers in sequential mode (ignored in parallel).
651            persistent_workers : bool
652                DataLoader persistence in sequential mode (ignored in parallel).
653            prefetch_factor : int
654                DataLoader prefetch factor (ignored in parallel).
655            batch_size : int
656                Batch size for all node DataLoaders.
657            debug : bool
658                Enable debug output.
659            verbose : bool
660                Enable informational logging.
661            return_history : bool
662                If True, return a history dict.
663
664        Returns
665        -------
666        dict or None
667            If ``return_history=True``, a dictionary mapping each node name
668            to its training history. Otherwise, returns None.
669
670        Raises
671        ------
672        ValueError
673            If ``train_mode`` is not "sequential" or "parallel".
674        """
675        self._validate_kwargs(kwargs, defaults_attr='DEFAULTS_FIT', context="fit")
676        
677        # --- merge defaults ---
678        settings = dict(self.DEFAULTS_FIT)
679        settings.update(kwargs)
680        
681        
682        self.debug = settings.get("debug", False)
683        self.verbose = settings.get("verbose", False)
684
685        # --- resolve device ---
686        device_str=self.get_device(settings)
687        self.device = torch.device(device_str)
688
689        # --- training mode ---
690        train_mode = settings.get("train_mode", "sequential").lower()
691        if train_mode not in ("sequential", "parallel"):
692            raise ValueError("train_mode must be 'sequential' or 'parallel'")
693
694        # --- DataLoader safety logic ---
695        if train_mode == "parallel":
696            # if user passed loader paralleling params, warn and override
697            for flag in ("num_workers", "persistent_workers", "prefetch_factor"):
698                if flag in kwargs:
699                    print(f"[WARNING] '{flag}' is ignored in parallel mode "
700                        f"(disabled to prevent nested multiprocessing).")
701            # disable unsafe loader multiprocessing options
702            settings["num_workers"] = 0
703            settings["persistent_workers"] = False
704            settings["prefetch_factor"] = None
705        else:
706            # sequential mode → respect user DataLoader settings
707            if self.debug:
708                print("[DEBUG] Sequential mode: using DataLoader kwargs as provided.")
709
710        # --- which nodes to train ---
711        train_list = settings.get("train_list") or list(self.models.keys())
712
713
714        # --- dataset prep (receives adjusted settings) ---
715        td_train_data = self._ensure_dataset(train_data, is_val=False, **settings)
716        td_val_data = self._ensure_dataset(val_data, is_val=True, **settings)
717
718        # --- normalization ---
719        self.load_or_compute_minmax(use_existing=False, write=True, td_train_data=td_train_data)
720
721        # --- print header ---
722        if self.verbose or self.debug:
723            print(f"[INFO] Training {len(train_list)} nodes ({train_mode}) on {device_str}")
724
725        # ======================================================================
726        # Sequential mode  safe for GPU or debugging)
727        # ======================================================================
728        if train_mode == "sequential" or "cuda" in device_str:
729            if "cuda" in device_str and train_mode == "parallel":
730                print("[WARNING] GPU device detected — forcing sequential mode.")
731            results = {}
732            for node in train_list:
733                node, history = self._fit_single_node(
734                    node, self, settings, td_train_data, td_val_data, device_str
735                )
736                results[node] = history
737        
738
739        # ======================================================================
740        # parallel mode (CPU only)
741        # ======================================================================
742        if train_mode == "parallel":
743
744            n_jobs = min(len(train_list), os.cpu_count() // 2 or 1)
745            if self.verbose or self.debug:
746                print(f"[INFO] Using {n_jobs} CPU workers for parallel node training")
747            parallel_outputs = Parallel(
748                n_jobs=n_jobs,
749                backend="loky",#loky, multiprocessing
750                verbose=10,
751                prefer="processes"
752            )(delayed(self._fit_single_node)(node, self, settings, td_train_data, td_val_data, device_str) for node in train_list )
753
754            results = {node: hist for node, hist in parallel_outputs}
755        
756        if settings.get("return_history", False):
757            return results

Train TRAM models for all nodes in the DAG.

Coordinates dataset preparation, min–max scaling, and per-node training, optionally in parallel on CPU.

Parameters

train_data : pandas.DataFrame or TramDagDataset
    Training data. If a DataFrame is given, it is converted into a TramDagDataset using _ensure_dataset.

val_data : pandas.DataFrame or TramDagDataset or None, optional
    Validation data. If a DataFrame is given, it is converted into a TramDagDataset. If None, no validation loss is computed.

**kwargs
    Overrides for DEFAULTS_FIT. All keys must exist in DEFAULTS_FIT. Common options:

epochs : int, default 100
    Number of training epochs per node.
learning_rate : float, default 0.01
    Learning rate for the default Adam optimizer.
train_list : list of str or None, optional
    List of node names to train. If None, all nodes are trained.
train_mode : {"sequential", "parallel"}, default "sequential"
    Training mode. "parallel" uses joblib-based CPU multiprocessing.
    GPU forces sequential mode.
device : {"auto", "cpu", "cuda"}, default "auto"
    Device selection.
optimizers : dict or None
    Optional mapping ``{node_name: optimizer}``. If provided for a
    node, that optimizer is used instead of creating a new Adam.
schedulers : dict or None
    Optional mapping ``{node_name: scheduler}``.
use_scheduler : bool
    If True, enable scheduler usage in the training loop.
num_workers : int
    DataLoader workers in sequential mode (ignored in parallel).
persistent_workers : bool
    DataLoader persistence in sequential mode (ignored in parallel).
prefetch_factor : int
    DataLoader prefetch factor (ignored in parallel).
batch_size : int
    Batch size for all node DataLoaders.
debug : bool
    Enable debug output.
verbose : bool
    Enable informational logging.
return_history : bool
    If True, return a history dict.

Returns

dict or None If return_history=True, a dictionary mapping each node name to its training history. Otherwise, returns None.

Raises

ValueError If train_mode is not "sequential" or "parallel".

def loss_history(self):
760    def loss_history(self):
761        """
762        Load training and validation loss history for all nodes.
763
764        Looks for per-node JSON files:
765
766        - ``EXPERIMENT_DIR/{node}/train_loss_hist.json``
767        - ``EXPERIMENT_DIR/{node}/val_loss_hist.json``
768
769        Returns
770        -------
771        dict
772            A dictionary mapping node names to:
773
774            .. code-block:: python
775
776                {
777                    "train": list or None,
778                    "validation": list or None
779                }
780
781            where each list contains NLL values per epoch, or None if not found.
782
783        Raises
784        ------
785        ValueError
786            If the experiment directory cannot be resolved from the configuration.
787        """
788        try:
789            EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
790        except KeyError:
791            raise ValueError(
792                "[ERROR] Missing 'EXPERIMENT_DIR' in cfg.conf_dict['PATHS']. "
793                "History retrieval requires experiment logs."
794            )
795
796        all_histories = {}
797        for node in self.nodes_dict.keys():
798            node_dir = os.path.join(EXPERIMENT_DIR, node)
799            train_path = os.path.join(node_dir, "train_loss_hist.json")
800            val_path = os.path.join(node_dir, "val_loss_hist.json")
801
802            node_hist = {}
803
804            # --- load train history ---
805            if os.path.exists(train_path):
806                try:
807                    with open(train_path, "r") as f:
808                        node_hist["train"] = json.load(f)
809                except Exception as e:
810                    print(f"[WARNING] Could not load {train_path}: {e}")
811                    node_hist["train"] = None
812            else:
813                node_hist["train"] = None
814
815            # --- load val history ---
816            if os.path.exists(val_path):
817                try:
818                    with open(val_path, "r") as f:
819                        node_hist["validation"] = json.load(f)
820                except Exception as e:
821                    print(f"[WARNING] Could not load {val_path}: {e}")
822                    node_hist["validation"] = None
823            else:
824                node_hist["validation"] = None
825
826            all_histories[node] = node_hist
827
828        if self.verbose or self.debug:
829            print(f"[INFO] Loaded training/validation histories for {len(all_histories)} nodes.")
830
831        return all_histories

Load training and validation loss history for all nodes.

Looks for per-node JSON files:

  • EXPERIMENT_DIR/{node}/train_loss_hist.json
  • EXPERIMENT_DIR/{node}/val_loss_hist.json

Returns

dict A dictionary mapping node names to:

```python
{
    "train": list or None,
    "validation": list or None
}
```

where each list contains NLL values per epoch, or None if not found.

Raises

ValueError If the experiment directory cannot be resolved from the configuration.

def linear_shift_history(self):
833    def linear_shift_history(self):
834        """
835        Load linear shift term histories for all nodes.
836
837        Each node history is expected in a JSON file named
838        ``linear_shifts_all_epochs.json`` under the node directory.
839
840        Returns
841        -------
842        dict
843            A mapping ``{node_name: pandas.DataFrame}``, where each DataFrame
844            contains linear shift weights across epochs.
845
846        Raises
847        ------
848        ValueError
849            If the experiment directory cannot be resolved from the configuration.
850
851        Notes
852        -----
853        If a history file is missing for a node, a warning is printed and the
854        node is omitted from the returned dictionary.
855        """
856        histories = {}
857        try:
858            EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
859        except KeyError:
860            raise ValueError(
861                "[ERROR] Missing 'EXPERIMENT_DIR' in cfg.conf_dict['PATHS']. "
862                "Cannot load histories without experiment directory."
863            )
864
865        for node in self.nodes_dict.keys():
866            node_dir = os.path.join(EXPERIMENT_DIR, node)
867            history_path = os.path.join(node_dir, "linear_shifts_all_epochs.json")
868            if os.path.exists(history_path):
869                histories[node] = pd.read_json(history_path)
870            else:
871                print(f"[WARNING] No linear shift history found for node '{node}' at {history_path}")
872        return histories

Load linear shift term histories for all nodes.

Each node history is expected in a JSON file named linear_shifts_all_epochs.json under the node directory.

Returns

dict A mapping {node_name: pandas.DataFrame}, where each DataFrame contains linear shift weights across epochs.

Raises

ValueError If the experiment directory cannot be resolved from the configuration.

Notes

If a history file is missing for a node, a warning is printed and the node is omitted from the returned dictionary.

def simple_intercept_history(self):
874    def simple_intercept_history(self):
875        """
876        Load simple intercept histories for all nodes.
877
878        Each node history is expected in a JSON file named
879        ``simple_intercepts_all_epochs.json`` under the node directory.
880
881        Returns
882        -------
883        dict
884            A mapping ``{node_name: pandas.DataFrame}``, where each DataFrame
885            contains intercept weights across epochs.
886
887        Raises
888        ------
889        ValueError
890            If the experiment directory cannot be resolved from the configuration.
891
892        Notes
893        -----
894        If a history file is missing for a node, a warning is printed and the
895        node is omitted from the returned dictionary.
896        """
897        histories = {}
898        try:
899            EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
900        except KeyError:
901            raise ValueError(
902                "[ERROR] Missing 'EXPERIMENT_DIR' in cfg.conf_dict['PATHS']. "
903                "Cannot load histories without experiment directory."
904            )
905
906        for node in self.nodes_dict.keys():
907            node_dir = os.path.join(EXPERIMENT_DIR, node)
908            history_path = os.path.join(node_dir, "simple_intercepts_all_epochs.json")
909            if os.path.exists(history_path):
910                histories[node] = pd.read_json(history_path)
911            else:
912                print(f"[WARNING] No simple intercept history found for node '{node}' at {history_path}")
913        return histories

Load simple intercept histories for all nodes.

Each node history is expected in a JSON file named simple_intercepts_all_epochs.json under the node directory.

Returns

dict A mapping {node_name: pandas.DataFrame}, where each DataFrame contains intercept weights across epochs.

Raises

ValueError If the experiment directory cannot be resolved from the configuration.

Notes

If a history file is missing for a node, a warning is printed and the node is omitted from the returned dictionary.

def get_latent(self, df, verbose=False):
915    def get_latent(self, df, verbose=False):
916        """
917        Compute latent representations for all nodes in the DAG.
918
919        Parameters
920        ----------
921        df : pandas.DataFrame
922            Input data frame with columns corresponding to nodes in the DAG.
923        verbose : bool, optional
924            If True, print informational messages during latent computation.
925            Default is False.
926
927        Returns
928        -------
929        pandas.DataFrame
930            DataFrame containing the original columns plus latent variables
931            for each node (e.g. columns named ``f"{node}_U"``).
932
933        Raises
934        ------
935        ValueError
936            If the experiment directory is missing from the configuration or
937            if ``self.minmax_dict`` has not been set.
938        """
939        try:
940            EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
941        except KeyError:
942            raise ValueError(
943                "[ERROR] Missing 'EXPERIMENT_DIR' in cfg.conf_dict['PATHS']. "
944                "Latent extraction requires trained model checkpoints."
945            )
946
947        # ensure minmax_dict is available
948        if not hasattr(self, "minmax_dict"):
949            raise ValueError(
950                "[ERROR] minmax_dict not found in the TramDagModel instance. "
951                "Either call .load_or_compute_minmax(td_train_data=train_df) or .fit() first."
952            )
953
954        all_latents_df = create_latent_df_for_full_dag(
955            configuration_dict=self.cfg.conf_dict,
956            EXPERIMENT_DIR=EXPERIMENT_DIR,
957            df=df,
958            verbose=verbose,
959            min_max_dict=self.minmax_dict,
960        )
961
962        return all_latents_df

Compute latent representations for all nodes in the DAG.

Parameters

df : pandas.DataFrame
    Input data frame with columns corresponding to nodes in the DAG.

verbose : bool, optional
    If True, print informational messages during latent computation. Default is False.

Returns

pandas.DataFrame DataFrame containing the original columns plus latent variables for each node (e.g. columns named f"{node}_U").

Raises

ValueError If the experiment directory is missing from the configuration or if self.minmax_dict has not been set.

def plot_loss_history(self, variable: str = None):
 967    def plot_loss_history(self, variable: str = None):
 968        """
 969        Plot training and validation loss evolution per node.
 970
 971        Parameters
 972        ----------
 973        variable : str or None, optional
 974            If provided, plot loss history for this node only. If None, plot
 975            histories for all nodes that have both train and validation logs.
 976
 977        Returns
 978        -------
 979        None
 980
 981        Notes
 982        -----
 983        Two subplots are produced:
 984        - Full epoch history.
 985        - Last 10% of epochs (or only the last epoch if fewer than 5 epochs).
 986        """
 987
 988        histories = self.loss_history()
 989        if not histories:
 990            print("[WARNING] No loss histories found.")
 991            return
 992
 993        # Select which nodes to plot
 994        if variable is not None:
 995            if variable not in histories:
 996                raise ValueError(f"[ERROR] Node '{variable}' not found in histories.")
 997            nodes_to_plot = [variable]
 998        else:
 999            nodes_to_plot = list(histories.keys())
1000
1001        # Filter out nodes with no valid history
1002        nodes_to_plot = [
1003            n for n in nodes_to_plot
1004            if histories[n].get("train") is not None and len(histories[n]["train"]) > 0
1005            and histories[n].get("validation") is not None and len(histories[n]["validation"]) > 0
1006        ]
1007
1008        if not nodes_to_plot:
1009            print("[WARNING] No valid histories found to plot.")
1010            return
1011
1012        plt.figure(figsize=(14, 12))
1013
1014        # --- Full history (top plot) ---
1015        plt.subplot(2, 1, 1)
1016        for node in nodes_to_plot:
1017            node_hist = histories[node]
1018            train_hist, val_hist = node_hist["train"], node_hist["validation"]
1019
1020            epochs = range(1, len(train_hist) + 1)
1021            plt.plot(epochs, train_hist, label=f"{node} - train", linestyle="--")
1022            plt.plot(epochs, val_hist, label=f"{node} - val")
1023
1024        plt.title("Training and Validation NLL - Full History")
1025        plt.xlabel("Epoch")
1026        plt.ylabel("NLL")
1027        plt.legend()
1028        plt.grid(True)
1029
1030        # --- Last 10% of epochs (bottom plot) ---
1031        plt.subplot(2, 1, 2)
1032        for node in nodes_to_plot:
1033            node_hist = histories[node]
1034            train_hist, val_hist = node_hist["train"], node_hist["validation"]
1035
1036            total_epochs = len(train_hist)
1037            start_idx = total_epochs - 1 if total_epochs < 5 else int(total_epochs * 0.9)
1038
1039            epochs = range(start_idx + 1, total_epochs + 1)
1040            plt.plot(epochs, train_hist[start_idx:], label=f"{node} - train", linestyle="--")
1041            plt.plot(epochs, val_hist[start_idx:], label=f"{node} - val")
1042
1043        plt.title("Training and Validation NLL - Last 10% of Epochs (or Last Epoch if <5)")
1044        plt.xlabel("Epoch")
1045        plt.ylabel("NLL")
1046        plt.legend()
1047        plt.grid(True)
1048
1049        plt.tight_layout()
1050        plt.show()

Plot training and validation loss evolution per node.

Parameters

variable : str or None, optional If provided, plot loss history for this node only. If None, plot histories for all nodes that have both train and validation logs.

Returns

None

Notes

Two subplots are produced:

  • Full epoch history.
  • Last 10% of epochs (or only the last epoch if fewer than 5 epochs).
def plot_linear_shift_history(self, data_dict=None, node=None, ref_lines=None):
1052    def plot_linear_shift_history(self, data_dict=None, node=None, ref_lines=None):
1053        """
1054        Plot the evolution of linear shift terms over epochs.
1055
1056        Parameters
1057        ----------
1058        data_dict : dict or None, optional
1059            Pre-loaded mapping ``{node_name: pandas.DataFrame}`` containing shift
1060            weights across epochs. If None, `linear_shift_history()` is called.
1061        node : str or None, optional
1062            If provided, plot only this node. Otherwise, plot all nodes
1063            present in ``data_dict``.
1064        ref_lines : dict or None, optional
1065            Optional mapping ``{node_name: list of float}``. For each specified
1066            node, horizontal reference lines are drawn at the given values.
1067
1068        Returns
1069        -------
1070        None
1071
1072        Notes
1073        -----
1074        The function flattens nested list-like entries in the DataFrames to scalars,
1075        converts epoch labels to numeric, and then draws one line per shift term.
1076        """
1077
1078        if data_dict is None:
1079            data_dict = self.linear_shift_history()
1080            if data_dict is None:
1081                raise ValueError("No shift history data provided or stored in the class.")
1082
1083        nodes = [node] if node else list(data_dict.keys())
1084
1085        for n in nodes:
1086            df = data_dict[n].copy()
1087
1088            # Flatten nested lists or list-like cells
1089            def flatten(x):
1090                if isinstance(x, list):
1091                    if len(x) == 0:
1092                        return np.nan
1093                    if all(isinstance(i, (int, float)) for i in x):
1094                        return np.mean(x)  # average simple list
1095                    if all(isinstance(i, list) for i in x):
1096                        # nested list -> flatten inner and average
1097                        flat = [v for sub in x for v in (sub if isinstance(sub, list) else [sub])]
1098                        return np.mean(flat) if flat else np.nan
1099                    return x[0] if len(x) == 1 else np.nan
1100                return x
1101
1102            df = df.applymap(flatten)
1103
1104            # Ensure numeric columns
1105            df = df.apply(pd.to_numeric, errors='coerce')
1106
1107            # Convert epoch labels to numeric
1108            df.columns = [
1109                int(c.replace("epoch_", "")) if isinstance(c, str) and c.startswith("epoch_") else c
1110                for c in df.columns
1111            ]
1112            df = df.reindex(sorted(df.columns), axis=1)
1113
1114            plt.figure(figsize=(10, 6))
1115            for idx in df.index:
1116                plt.plot(df.columns, df.loc[idx], lw=1.4, label=f"shift_{idx}")
1117
1118            if ref_lines and n in ref_lines:
1119                for v in ref_lines[n]:
1120                    plt.axhline(y=v, color="k", linestyle="--", lw=1.0)
1121                    plt.text(df.columns[-1], v, f"{n}: {v}", va="bottom", ha="right", fontsize=8)
1122
1123            plt.xlabel("Epoch")
1124            plt.ylabel("Shift Value")
1125            plt.title(f"Shift Term History — Node: {n}")
1126            plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left", fontsize="small")
1127            plt.tight_layout()
1128            plt.show()

Plot the evolution of linear shift terms over epochs.

Parameters

data_dict : dict or None, optional Pre-loaded mapping {node_name: pandas.DataFrame} containing shift weights across epochs. If None, linear_shift_history() is called. node : str or None, optional If provided, plot only this node. Otherwise, plot all nodes present in data_dict. ref_lines : dict or None, optional Optional mapping {node_name: list of float}. For each specified node, horizontal reference lines are drawn at the given values.

Returns

None

Notes

The function flattens nested list-like entries in the DataFrames to scalars, converts epoch labels to numeric, and then draws one line per shift term.

def plot_simple_intercepts_history(self, data_dict=None, node=None, ref_lines=None):
1130    def plot_simple_intercepts_history(self, data_dict=None, node=None,ref_lines=None):
1131        """
1132        Plot the evolution of simple intercept weights over epochs.
1133
1134        Parameters
1135        ----------
1136        data_dict : dict or None, optional
1137            Pre-loaded mapping ``{node_name: pandas.DataFrame}`` containing intercept
1138            weights across epochs. If None, `simple_intercept_history()` is called.
1139        node : str or None, optional
1140            If provided, plot only this node. Otherwise, plot all nodes present
1141            in ``data_dict``.
1142        ref_lines : dict or None, optional
1143            Optional mapping ``{node_name: list of float}``. For each specified
1144            node, horizontal reference lines are drawn at the given values.
1145
1146        Returns
1147        -------
1148        None
1149
1150        Notes
1151        -----
1152        Nested list-like entries in the DataFrames are reduced to scalars before
1153        plotting. One line is drawn per intercept parameter.
1154        """
1155        if data_dict is None:
1156            data_dict = self.simple_intercept_history()
1157            if data_dict is None:
1158                raise ValueError("No intercept history data provided or stored in the class.")
1159
1160        nodes = [node] if node else list(data_dict.keys())
1161
1162        for n in nodes:
1163            df = data_dict[n].copy()
1164
1165            def extract_scalar(x):
1166                if isinstance(x, list):
1167                    while isinstance(x, list) and len(x) > 0:
1168                        x = x[0]
1169                return float(x) if isinstance(x, (int, float, np.floating)) else np.nan
1170
1171            df = df.applymap(extract_scalar)
1172
1173            # Convert epoch labels → numeric
1174            df.columns = [
1175                int(c.replace("epoch_", "")) if isinstance(c, str) and c.startswith("epoch_") else c
1176                for c in df.columns
1177            ]
1178            df = df.reindex(sorted(df.columns), axis=1)
1179
1180            plt.figure(figsize=(10, 6))
1181            for idx in df.index:
1182                plt.plot(df.columns, df.loc[idx], lw=1.4, label=f"theta_{idx}")
1183            
1184            if ref_lines and n in ref_lines:
1185                for v in ref_lines[n]:
1186                    plt.axhline(y=v, color="k", linestyle="--", lw=1.0)
1187                    plt.text(df.columns[-1], v, f"{n}: {v}", va="bottom", ha="right", fontsize=8)
1188                
1189            plt.xlabel("Epoch")
1190            plt.ylabel("Intercept Weight")
1191            plt.title(f"Simple Intercept Evolution — Node: {n}")
1192            plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left", fontsize="small")
1193            plt.tight_layout()
1194            plt.show()

Plot the evolution of simple intercept weights over epochs.

Parameters

data_dict : dict or None, optional Pre-loaded mapping {node_name: pandas.DataFrame} containing intercept weights across epochs. If None, simple_intercept_history() is called. node : str or None, optional If provided, plot only this node. Otherwise, plot all nodes present in data_dict. ref_lines : dict or None, optional Optional mapping {node_name: list of float}. For each specified node, horizontal reference lines are drawn at the given values.

Returns

None

Notes

Nested list-like entries in the DataFrames are reduced to scalars before plotting. One line is drawn per intercept parameter.

def plot_latents( self, df, variable: str = None, confidence: float = 0.95, simulations: int = 1000):
1196    def plot_latents(self, df, variable: str = None, confidence: float = 0.95, simulations: int = 1000):
1197        """
1198        Visualize latent U distributions for one or all nodes.
1199
1200        Parameters
1201        ----------
1202        df : pandas.DataFrame
1203            Input data frame with raw node values.
1204        variable : str or None, optional
1205            If provided, only this node's latents are plotted. If None, all
1206            nodes with latent columns are processed.
1207        confidence : float, optional
1208            Confidence level for QQ-plot bands (0 < confidence < 1).
1209            Default is 0.95.
1210        simulations : int, optional
1211            Number of Monte Carlo simulations for QQ-plot bands. Default is 1000.
1212
1213        Returns
1214        -------
1215        None
1216
1217        Notes
1218        -----
1219        For each node, two plots are produced:
1220        - Histogram of the latent U values.
1221        - QQ-plot with simulation-based confidence bands under a logistic reference.
1222        """
1223        # Compute latent representations
1224        latents_df = self.get_latent(df)
1225
1226        # Select nodes
1227        nodes = [variable] if variable is not None else self.nodes_dict.keys()
1228
1229        for node in nodes:
1230            if f"{node}_U" not in latents_df.columns:
1231                print(f"[WARNING] No latent found for node {node}, skipping.")
1232                continue
1233
1234            sample = latents_df[f"{node}_U"].values
1235
1236            # --- Create plots ---
1237            fig, axs = plt.subplots(1, 2, figsize=(12, 5))
1238
1239            # Histogram
1240            axs[0].hist(sample, bins=50, color="steelblue", alpha=0.7)
1241            axs[0].set_title(f"Latent Histogram ({node})")
1242            axs[0].set_xlabel("U")
1243            axs[0].set_ylabel("Frequency")
1244
1245            # QQ Plot with confidence bands
1246            probplot(sample, dist="logistic", plot=axs[1])
1247            self._add_r_style_confidence_bands(axs[1], sample, dist=logistic,confidence=confidence, simulations=simulations)
1248            axs[1].set_title(f"Latent QQ Plot ({node})")
1249
1250            plt.suptitle(f"Latent Diagnostics for Node: {node}", fontsize=14)
1251            plt.tight_layout(rect=[0, 0.03, 1, 0.95])
1252            plt.show()

Visualize latent U distributions for one or all nodes.

Parameters

df : pandas.DataFrame Input data frame with raw node values. variable : str or None, optional If provided, only this node's latents are plotted. If None, all nodes with latent columns are processed. confidence : float, optional Confidence level for QQ-plot bands (0 < confidence < 1). Default is 0.95. simulations : int, optional Number of Monte Carlo simulations for QQ-plot bands. Default is 1000.

Returns

None

Notes

For each node, two plots are produced:

  • Histogram of the latent U values.
  • QQ-plot with simulation-based confidence bands under a logistic reference.
def plot_hdag(self, df, variables=None, plot_n_rows=1, **kwargs):
1254    def plot_hdag(self,df,variables=None, plot_n_rows=1,**kwargs):
1255        
1256        """
1257        Visualize the transformation function h() for selected DAG nodes.
1258
1259        Parameters
1260        ----------
1261        df : pandas.DataFrame
1262            Input data containing node values or model predictions.
1263        variables : list of str or None, optional
1264            Names of nodes to visualize. If None, all nodes in ``self.models``
1265            are considered.
1266        plot_n_rows : int, optional
1267            Maximum number of rows from ``df`` to visualize. Default is 1.
1268        **kwargs
1269            Additional keyword arguments forwarded to the underlying plotting
1270            helpers (`show_hdag_continous` / `show_hdag_ordinal`).
1271
1272        Returns
1273        -------
1274        None
1275
1276        Notes
1277        -----
1278        - For continuous outcomes, `show_hdag_continous` is called.
1279        - For ordinal outcomes, `show_hdag_ordinal` is called.
1280        - Nodes that are neither continuous nor ordinal are skipped with a warning.
1281        """
1282                
1283
1284        if len(df)> 1:
1285            print("[WARNING] len(df)>1, set: plot_n_rows accordingly")
1286        
1287        variables_list=variables if variables is not None else list(self.models.keys())
1288        for node in variables_list:
1289            if is_outcome_modelled_continous(node, self.nodes_dict):
1290                show_hdag_continous(df,node=node,configuration_dict=self.cfg.conf_dict,minmax_dict=self.minmax_dict,device=self.device,plot_n_rows=plot_n_rows,**kwargs)
1291            
1292            elif is_outcome_modelled_ordinal(node, self.nodes_dict):
1293                show_hdag_ordinal(df,node=node,configuration_dict=self.cfg.conf_dict,device=self.device,plot_n_rows=plot_n_rows,**kwargs)
1294                # plot_cutpoints_with_logistic(df,node=node,configuration_dict=self.cfg.conf_dict,device=self.device,plot_n_rows=plot_n_rows,**kwargs)
1295                # save_cutpoints_with_logistic(df,node=node,configuration_dict=self.cfg.conf_dict,device=self.device,**kwargs)
1296            else:
1297                print(f"[WARNING] Node {node} is wheter ordinal nor continous, not implemented yet")

Visualize the transformation function h() for selected DAG nodes.

Parameters

df : pandas.DataFrame Input data containing node values or model predictions. variables : list of str or None, optional Names of nodes to visualize. If None, all nodes in self.models are considered. plot_n_rows : int, optional Maximum number of rows from df to visualize. Default is 1. **kwargs Additional keyword arguments forwarded to the underlying plotting helpers (show_hdag_continous / show_hdag_ordinal).

Returns

None

Notes

  • For continuous outcomes, show_hdag_continous is called.
  • For ordinal outcomes, show_hdag_ordinal is called.
  • Nodes that are neither continuous nor ordinal are skipped with a warning.
def sample( self, do_interventions: dict = None, predefined_latent_samples_df: pandas.core.frame.DataFrame = None, **kwargs):
1356    def sample(
1357        self,
1358        do_interventions: dict = None,
1359        predefined_latent_samples_df: pd.DataFrame = None,
1360        **kwargs,
1361    ):
1362        """
1363        Sample from the joint DAG using the trained TRAM models.
1364
1365        Allows for:
1366        
1367        Oberservational sampling
1368        Interventional sampling via ``do()`` operations
1369        Counterfactial sampling using predefined latent draws and do()
1370        
1371        Parameters
1372        ----------
1373        do_interventions : dict or None, optional
1374            Mapping of node names to intervened (fixed) values. For example:
1375            ``{"x1": 1.0}`` represents ``do(x1 = 1.0)``. Default is None.
1376        predefined_latent_samples_df : pandas.DataFrame or None, optional
1377            DataFrame containing columns ``"{node}_U"`` with predefined latent
1378            draws to be used instead of sampling from the prior. Default is None.
1379        **kwargs
1380            Sampling options overriding internal defaults:
1381
1382            number_of_samples : int, default 10000
1383                Total number of samples to draw.
1384            batch_size : int, default 32
1385                Batch size for internal sampling loops.
1386            delete_all_previously_sampled : bool, default True
1387                If True, delete old sampling files in node-specific sampling
1388                directories before writing new ones.
1389            verbose : bool
1390                If True, print informational messages.
1391            debug : bool
1392                If True, print debug output.
1393            device : {"auto", "cpu", "cuda"}
1394                Device selection for sampling.
1395            use_initial_weights_for_sampling : bool, default False
1396                If True, sample from initial (untrained) model parameters.
1397
1398        Returns
1399        -------
1400        tuple
1401            A tuple ``(sampled_by_node, latents_by_node)``:
1402
1403            sampled_by_node : dict
1404                Mapping ``{node_name: torch.Tensor}`` of sampled node values.
1405            latents_by_node : dict
1406                Mapping ``{node_name: torch.Tensor}`` of latent U values used.
1407
1408        Raises
1409        ------
1410        ValueError
1411            If the experiment directory cannot be resolved or if scaling
1412            information (``self.minmax_dict``) is missing.
1413        RuntimeError
1414            If min–max scaling has not been computed before calling `sample`.
1415        """
1416        try:
1417            EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
1418        except KeyError:
1419            raise ValueError(
1420                "[ERROR] Missing 'EXPERIMENT_DIR' in cfg.conf_dict['PATHS']. "
1421                "Sampling requires trained model checkpoints."
1422            )
1423
1424        # ---- defaults ----
1425        settings = {
1426            "number_of_samples": 10_000,
1427            "batch_size": 32,
1428            "delete_all_previously_sampled": True,
1429            "verbose": self.verbose if hasattr(self, "verbose") else False,
1430            "debug": self.debug if hasattr(self, "debug") else False,
1431            "device": self.device.type if hasattr(self, "device") else "auto",
1432            "use_initial_weights_for_sampling": False,
1433            
1434        }
1435        
1436        # self._validate_kwargs( kwargs, defaults_attr= "settings", context="sample")
1437        
1438        settings.update(kwargs)
1439
1440        
1441        if not hasattr(self, "minmax_dict"):
1442            raise RuntimeError(
1443                "[ERROR] minmax_dict not found. You must call .fit() or .load_or_compute_minmax() "
1444                "before sampling, so scaling info is available."
1445                )
1446            
1447        # ---- resolve device ----
1448        device_str=self.get_device(settings)
1449        self.device = torch.device(device_str)
1450
1451
1452        if self.debug or settings["debug"]:
1453            print(f"[DEBUG] sample(): device: {self.device}")
1454
1455        # ---- perform sampling ----
1456        sampled_by_node, latents_by_node = sample_full_dag(
1457            configuration_dict=self.cfg.conf_dict,
1458            EXPERIMENT_DIR=EXPERIMENT_DIR,
1459            device=self.device,
1460            do_interventions=do_interventions or {},
1461            predefined_latent_samples_df=predefined_latent_samples_df,
1462            number_of_samples=settings["number_of_samples"],
1463            batch_size=settings["batch_size"],
1464            delete_all_previously_sampled=settings["delete_all_previously_sampled"],
1465            verbose=settings["verbose"],
1466            debug=settings["debug"],
1467            minmax_dict=self.minmax_dict,
1468            use_initial_weights_for_sampling=settings["use_initial_weights_for_sampling"]
1469        )
1470
1471        return sampled_by_node, latents_by_node

Sample from the joint DAG using the trained TRAM models.

Allows for:

Observational sampling; interventional sampling via do() operations; counterfactual sampling using predefined latent draws and do()

Parameters

do_interventions : dict or None, optional Mapping of node names to intervened (fixed) values. For example: {"x1": 1.0} represents do(x1 = 1.0). Default is None. predefined_latent_samples_df : pandas.DataFrame or None, optional DataFrame containing columns "{node}_U" with predefined latent draws to be used instead of sampling from the prior. Default is None. **kwargs Sampling options overriding internal defaults:

number_of_samples : int, default 10000
    Total number of samples to draw.
batch_size : int, default 32
    Batch size for internal sampling loops.
delete_all_previously_sampled : bool, default True
    If True, delete old sampling files in node-specific sampling
    directories before writing new ones.
verbose : bool
    If True, print informational messages.
debug : bool
    If True, print debug output.
device : {"auto", "cpu", "cuda"}
    Device selection for sampling.
use_initial_weights_for_sampling : bool, default False
    If True, sample from initial (untrained) model parameters.

Returns

tuple A tuple (sampled_by_node, latents_by_node):

sampled_by_node : dict
    Mapping ``{node_name: torch.Tensor}`` of sampled node values.
latents_by_node : dict
    Mapping ``{node_name: torch.Tensor}`` of latent U values used.

Raises

ValueError If the experiment directory cannot be resolved or if scaling information (self.minmax_dict) is missing. RuntimeError If min–max scaling has not been computed before calling sample.

def load_sampled_and_latents(self, EXPERIMENT_DIR: str = None, nodes: list = None):
1473    def load_sampled_and_latents(self, EXPERIMENT_DIR: str = None, nodes: list = None):
1474        """
1475        Load previously stored sampled values and latents for each node.
1476
1477        Parameters
1478        ----------
1479        EXPERIMENT_DIR : str or None, optional
1480            Experiment directory path. If None, it is taken from
1481            ``self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]``.
1482        nodes : list of str or None, optional
1483            Nodes for which to load samples. If None, use all nodes from
1484            ``self.nodes_dict``.
1485
1486        Returns
1487        -------
1488        tuple
1489            A tuple ``(sampled_by_node, latents_by_node)``:
1490
1491            sampled_by_node : dict
1492                Mapping ``{node_name: torch.Tensor}`` of sampled values (on CPU).
1493            latents_by_node : dict
1494                Mapping ``{node_name: torch.Tensor}`` of latent values (on CPU).
1495
1496        Raises
1497        ------
1498        ValueError
1499            If the experiment directory cannot be resolved or if no node list
1500            is available and ``nodes`` is None.
1501
1502        Notes
1503        -----
1504        Nodes without both ``sampled.pt`` and ``latents.pt`` files are skipped
1505        with a warning.
1506        """
1507        # --- resolve paths and node list ---
1508        if EXPERIMENT_DIR is None:
1509            try:
1510                EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
1511            except (AttributeError, KeyError):
1512                raise ValueError(
1513                    "[ERROR] Could not resolve EXPERIMENT_DIR from cfg.conf_dict['PATHS']. "
1514                    "Please provide EXPERIMENT_DIR explicitly."
1515                )
1516
1517        if nodes is None:
1518            if hasattr(self, "nodes_dict"):
1519                nodes = list(self.nodes_dict.keys())
1520            else:
1521                raise ValueError(
1522                    "[ERROR] No node list found. Please provide `nodes` or initialize model with a config."
1523                )
1524
1525        # --- load tensors ---
1526        sampled_by_node = {}
1527        latents_by_node = {}
1528
1529        for node in nodes:
1530            node_dir = os.path.join(EXPERIMENT_DIR, f"{node}")
1531            sampling_dir = os.path.join(node_dir, "sampling")
1532
1533            sampled_path = os.path.join(sampling_dir, "sampled.pt")
1534            latents_path = os.path.join(sampling_dir, "latents.pt")
1535
1536            if not os.path.exists(sampled_path) or not os.path.exists(latents_path):
1537                print(f"[WARNING] Missing files for node '{node}' — skipping.")
1538                continue
1539
1540            try:
1541                sampled = torch.load(sampled_path, map_location="cpu")
1542                latent_sample = torch.load(latents_path, map_location="cpu")
1543            except Exception as e:
1544                print(f"[ERROR] Could not load sampling files for node '{node}': {e}")
1545                continue
1546
1547            sampled_by_node[node] = sampled.detach().cpu()
1548            latents_by_node[node] = latent_sample.detach().cpu()
1549
1550        if self.verbose or self.debug:
1551            print(f"[INFO] Loaded sampled and latent tensors for {len(sampled_by_node)} nodes from {EXPERIMENT_DIR}")
1552
1553        return sampled_by_node, latents_by_node

Load previously stored sampled values and latents for each node.

Parameters

EXPERIMENT_DIR : str or None, optional Experiment directory path. If None, it is taken from self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]. nodes : list of str or None, optional Nodes for which to load samples. If None, use all nodes from self.nodes_dict.

Returns

tuple A tuple (sampled_by_node, latents_by_node):

sampled_by_node : dict
    Mapping ``{node_name: torch.Tensor}`` of sampled values (on CPU).
latents_by_node : dict
    Mapping ``{node_name: torch.Tensor}`` of latent values (on CPU).

Raises

ValueError If the experiment directory cannot be resolved or if no node list is available and nodes is None.

Notes

Nodes without both sampled.pt and latents.pt files are skipped with a warning.

def plot_samples_vs_true( self, df, sampled: dict = None, variable: list = None, bins: int = 100, hist_true_color: str = 'blue', hist_est_color: str = 'orange', figsize: tuple = (14, 5)):
    def plot_samples_vs_true(
        self,
        df,
        sampled: dict = None,
        variable: list = None,
        bins: int = 100,
        hist_true_color: str = "blue",
        hist_est_color: str = "orange",
        figsize: tuple = (14, 5),
    ):


        """
        Compare sampled vs. observed distributions for selected nodes.

        Parameters
        ----------
        df : pandas.DataFrame
            Data frame containing the observed node values.
        sampled : dict or None, optional
            Optional mapping ``{node_name: array-like or torch.Tensor}`` of sampled
            values. If None or if a node is missing, samples are loaded from
            ``EXPERIMENT_DIR/{node}/sampling/sampled.pt``.
        variable : list of str or None, optional
            Subset of nodes to plot. If None, all nodes in the configuration
            are considered.
        bins : int, optional
            Number of histogram bins for continuous variables. Default is 100.
        hist_true_color : str, optional
            Color name for the histogram of true values. Default is "blue".
        hist_est_color : str, optional
            Color name for the histogram of sampled values. Default is "orange".
        figsize : tuple, optional
            Figure size for the matplotlib plots. Default is (14, 5).

        Returns
        -------
        None

        Notes
        -----
        - Continuous outcomes: histogram overlay + QQ-plot.
        - Ordinal outcomes: side-by-side bar plot of relative frequencies.
        - Other categorical outcomes: side-by-side bar plot with category labels.
        - If samples are probabilistic (2D tensor), the argmax across classes is used.
        """

        # NOTE(review): node-type checks below receive conf_dict["nodes"],
        # while other methods (e.g. nll) pass self.nodes_dict — presumably
        # equivalent structures; confirm against the config builder.
        target_nodes = self.cfg.conf_dict["nodes"]
        experiment_dir = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
        # Device only matters for torch.load below; tensors are moved to CPU anyway.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        plot_list = variable if variable is not None else target_nodes

        for node in plot_list:
            # Load sampled data: prefer in-memory `sampled`, else read from disk.
            if sampled is not None and node in sampled:
                sdata = sampled[node]
                if isinstance(sdata, torch.Tensor):
                    sampled_vals = sdata.detach().cpu().numpy()
                else:
                    sampled_vals = np.asarray(sdata)
            else:
                sample_path = os.path.join(experiment_dir, f"{node}/sampling/sampled.pt")
                if not os.path.isfile(sample_path):
                    print(f"[WARNING] skip {node}: {sample_path} not found.")
                    continue

                try:
                    sampled_vals = torch.load(sample_path, map_location=device).cpu().numpy()
                except Exception as e:
                    print(f"[ERROR] Could not load {sample_path}: {e}")
                    continue

            # If logits/probabilities per sample (2D array), take argmax per row.
            if sampled_vals.ndim == 2:
                    print(f"[INFO] CAUTION! {node}: samples are probabilistic — each sample follows a probability "
                    f"distribution based on the valid latent range. "
                    f"Note that this frequency plot reflects only the distribution of the most probable "
                    f"class per sample.")
                    sampled_vals = np.argmax(sampled_vals, axis=1)

            # Drop NaN/Inf from both samples and observations before plotting.
            sampled_vals = sampled_vals[np.isfinite(sampled_vals)]

            if node not in df.columns:
                print(f"[WARNING] skip {node}: column not found in DataFrame.")
                continue

            true_vals = df[node].dropna().values
            true_vals = true_vals[np.isfinite(true_vals)]

            if sampled_vals.size == 0 or true_vals.size == 0:
                print(f"[WARNING] skip {node}: empty array after NaN/Inf removal.")
                continue

            fig, axs = plt.subplots(1, 2, figsize=figsize)

            if is_outcome_modelled_continous(node, target_nodes):
                # Continuous: overlaid density histograms + two-sample QQ plot.
                axs[0].hist(true_vals, bins=bins, density=True, alpha=0.6,
                            color=hist_true_color, label=f"True {node}")
                axs[0].hist(sampled_vals, bins=bins, density=True, alpha=0.6,
                            color=hist_est_color, label="Sampled")
                axs[0].set_xlabel("Value")
                axs[0].set_ylabel("Density")
                axs[0].set_title(f"Histogram overlay for {node}")
                axs[0].legend()
                axs[0].grid(True, ls="--", alpha=0.4)

                qqplot_2samples(true_vals, sampled_vals, line="45", ax=axs[1])
                axs[1].set_xlabel("True quantiles")
                axs[1].set_ylabel("Sampled quantiles")
                axs[1].set_title(f"QQ plot for {node}")
                axs[1].grid(True, ls="--", alpha=0.4)

            elif is_outcome_modelled_ordinal(node, target_nodes):
                # Ordinal: paired bars of relative frequencies over the union of levels.
                unique_vals = np.union1d(np.unique(true_vals), np.unique(sampled_vals))
                unique_vals = np.sort(unique_vals)
                true_counts = np.array([(true_vals == val).sum() for val in unique_vals])
                sampled_counts = np.array([(sampled_vals == val).sum() for val in unique_vals])

                axs[0].bar(unique_vals - 0.2, true_counts / true_counts.sum(),
                        width=0.4, color=hist_true_color, alpha=0.7, label="True")
                axs[0].bar(unique_vals + 0.2, sampled_counts / sampled_counts.sum(),
                        width=0.4, color=hist_est_color, alpha=0.7, label="Sampled")
                axs[0].set_xticks(unique_vals)
                axs[0].set_xlabel("Ordinal Level")
                axs[0].set_ylabel("Relative Frequency")
                axs[0].set_title(f"Ordinal bar plot for {node}")
                axs[0].legend()
                axs[0].grid(True, ls="--", alpha=0.4)
                axs[1].axis("off")

            else:
                # Other categorical: same paired bars, but indexed positions with
                # string-sorted category labels on the x-axis.
                unique_vals = np.union1d(np.unique(true_vals), np.unique(sampled_vals))
                unique_vals = sorted(unique_vals, key=str)
                true_counts = np.array([(true_vals == val).sum() for val in unique_vals])
                sampled_counts = np.array([(sampled_vals == val).sum() for val in unique_vals])

                axs[0].bar(np.arange(len(unique_vals)) - 0.2, true_counts / true_counts.sum(),
                        width=0.4, color=hist_true_color, alpha=0.7, label="True")
                axs[0].bar(np.arange(len(unique_vals)) + 0.2, sampled_counts / sampled_counts.sum(),
                        width=0.4, color=hist_est_color, alpha=0.7, label="Sampled")
                axs[0].set_xticks(np.arange(len(unique_vals)))
                axs[0].set_xticklabels(unique_vals, rotation=45)
                axs[0].set_xlabel("Category")
                axs[0].set_ylabel("Relative Frequency")
                axs[0].set_title(f"Categorical bar plot for {node}")
                axs[0].legend()
                axs[0].grid(True, ls="--", alpha=0.4)
                axs[1].axis("off")

            plt.tight_layout()
            plt.show()

Compare sampled vs. observed distributions for selected nodes.

Parameters

df : pandas.DataFrame Data frame containing the observed node values. sampled : dict or None, optional Optional mapping {node_name: array-like or torch.Tensor} of sampled values. If None or if a node is missing, samples are loaded from EXPERIMENT_DIR/{node}/sampling/sampled.pt. variable : list of str or None, optional Subset of nodes to plot. If None, all nodes in the configuration are considered. bins : int, optional Number of histogram bins for continuous variables. Default is 100. hist_true_color : str, optional Color name for the histogram of true values. Default is "blue". hist_est_color : str, optional Color name for the histogram of sampled values. Default is "orange". figsize : tuple, optional Figure size for the matplotlib plots. Default is (14, 5).

Returns

None

Notes

  • Continuous outcomes: histogram overlay + QQ-plot.
  • Ordinal outcomes: side-by-side bar plot of relative frequencies.
  • Other categorical outcomes: side-by-side bar plot with category labels.
  • If samples are probabilistic (2D tensor), the argmax across classes is used.
def nll(self, data, variables=None):
1709    def nll(self,data,variables=None):
1710        """
1711        Compute the Negative Log-Likelihood (NLL) for all or selected TRAM nodes.
1712
1713        This function evaluates trained TRAM models for each specified variable (node) 
1714        on the provided dataset. It performs forward passes only—no training, no weight 
1715        updates—and returns the mean NLL per node.
1716
1717        Parameters
1718        ----------
1719        data : object
1720            Input dataset or data source compatible with `_ensure_dataset`, containing 
1721            both inputs and targets for each node.
1722        variables : list[str], optional
1723            List of variable (node) names to evaluate. If None, all nodes in 
1724            `self.models` are evaluated.
1725
1726        Returns
1727        -------
1728        dict[str, float]
1729            Dictionary mapping each node name to its average NLL value.
1730
1731        Notes
1732        -----
1733        - Each model is evaluated independently on its respective DataLoader.
1734        - The normalization values (`min_max`) for each node are retrieved from 
1735          `self.minmax_dict[node]`.
1736        - The function uses `evaluate_tramdag_model()` for per-node evaluation.
1737        - Expected directory structure:
1738              `<EXPERIMENT_DIR>/<node>/`
1739          where each node directory contains the trained model.
1740        """
1741
1742        td_data = self._ensure_dataset(data, is_val=True)  
1743        variables_list = variables if variables != None else list(self.models.keys())
1744        nll_dict = {}
1745        for node in variables_list:  
1746                min_vals = torch.tensor(self.minmax_dict[node][0], dtype=torch.float32)
1747                max_vals = torch.tensor(self.minmax_dict[node][1], dtype=torch.float32)
1748                min_max = torch.stack([min_vals, max_vals], dim=0)
1749                data_loader = td_data.loaders[node]
1750                model = self.models[node]
1751                EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
1752                NODE_DIR = os.path.join(EXPERIMENT_DIR, f"{node}")
1753                nll= evaluate_tramdag_model(node=node,
1754                                            target_nodes=self.nodes_dict,
1755                                            NODE_DIR=NODE_DIR,
1756                                            tram_model=model,
1757                                            data_loader=data_loader,
1758                                            min_max=min_max)
1759                nll_dict[node]=nll
1760        return nll_dict

Compute the Negative Log-Likelihood (NLL) for all or selected TRAM nodes.

This function evaluates trained TRAM models for each specified variable (node) on the provided dataset. It performs forward passes only—no training, no weight updates—and returns the mean NLL per node.

Parameters

data : object Input dataset or data source compatible with _ensure_dataset, containing both inputs and targets for each node. variables : list[str], optional List of variable (node) names to evaluate. If None, all nodes in self.models are evaluated.

Returns

dict[str, float] Dictionary mapping each node name to its average NLL value.

Notes

  • Each model is evaluated independently on its respective DataLoader.
  • The normalization values (min_max) for each node are retrieved from self.minmax_dict[node].
  • The function uses evaluate_tramdag_model() for per-node evaluation.
  • Expected directory structure: <EXPERIMENT_DIR>/<node>/ where each node directory contains the trained model.
def get_train_val_nll(self, node: str, mode: str) -> tuple[float, float]:
1762    def get_train_val_nll(self, node: str, mode: str) -> tuple[float, float]:
1763        """
1764        Retrieve training and validation NLL for a node and a given model state.
1765
1766        Parameters
1767        ----------
1768        node : str
1769            Node name.
1770        mode : {"best", "last", "init"}
1771            State of interest:
1772            - "best": epoch with lowest validation NLL.
1773            - "last": final epoch.
1774            - "init": first epoch (index 0).
1775
1776        Returns
1777        -------
1778        tuple of (float or None, float or None)
1779            A tuple ``(train_nll, val_nll)`` for the requested mode.
1780            Returns ``(None, None)`` if loss files are missing or cannot be read.
1781
1782        Notes
1783        -----
1784        This method expects per-node JSON files:
1785
1786        - ``train_loss_hist.json``
1787        - ``val_loss_hist.json``
1788
1789        in the node directory.
1790        """
1791        NODE_DIR = os.path.join(self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"], node)
1792        train_path = os.path.join(NODE_DIR, "train_loss_hist.json")
1793        val_path = os.path.join(NODE_DIR, "val_loss_hist.json")
1794
1795        if not os.path.exists(train_path) or not os.path.exists(val_path):
1796            if getattr(self, "debug", False):
1797                print(f"[DEBUG] Missing loss files for node '{node}'. Returning None.")
1798            return None, None
1799
1800        try:
1801            with open(train_path, "r") as f:
1802                train_hist = json.load(f)
1803            with open(val_path, "r") as f:
1804                val_hist = json.load(f)
1805
1806            train_nlls = np.array(train_hist)
1807            val_nlls = np.array(val_hist)
1808
1809            if mode == "init":
1810                idx = 0
1811            elif mode == "last":
1812                idx = len(val_nlls) - 1
1813            elif mode == "best":
1814                idx = int(np.argmin(val_nlls))
1815            else:
1816                raise ValueError(f"Invalid mode '{mode}' — must be one of 'best', 'last', 'init'.")
1817
1818            train_nll = float(train_nlls[idx])
1819            val_nll = float(val_nlls[idx])
1820            return train_nll, val_nll
1821
1822        except Exception as e:
1823            print(f"[ERROR] Failed to load NLLs for node '{node}' ({mode}): {e}")
1824            return None, None

Retrieve training and validation NLL for a node and a given model state.

Parameters

node : str Node name. mode : {"best", "last", "init"} State of interest: - "best": epoch with lowest validation NLL. - "last": final epoch. - "init": first epoch (index 0).

Returns

tuple of (float or None, float or None) A tuple (train_nll, val_nll) for the requested mode. Returns (None, None) if loss files are missing or cannot be read.

Notes

This method expects per-node JSON files:

  • train_loss_hist.json
  • val_loss_hist.json

in the node directory.

def get_thetas(self, node: str, state: str = 'best'):
1826    def get_thetas(self, node: str, state: str = "best"):
1827        """
1828        Return transformed intercept (theta) parameters for a node and state.
1829
1830        Parameters
1831        ----------
1832        node : str
1833            Node name.
1834        state : {"best", "last", "init"}, optional
1835            Model state for which to return parameters. Default is "best".
1836
1837        Returns
1838        -------
1839        Any or None
1840            Transformed theta parameters for the requested node and state.
1841            The exact structure (scalar, list, or other) depends on the model.
1842
1843        Raises
1844        ------
1845        ValueError
1846            If an invalid state is given (not in {"best", "last", "init"}).
1847
1848        Notes
1849        -----
1850        Intercept dictionaries are cached on the instance under the attribute
1851        ``intercept_dicts``. If missing or incomplete, they are recomputed using
1852        `get_simple_intercepts_dict`.
1853        """
1854
1855        state = state.lower()
1856        if state not in ["best", "last", "init"]:
1857            raise ValueError(f"[ERROR] Invalid state '{state}'. Must be one of ['best', 'last', 'init'].")
1858
1859        dict_attr = "intercept_dicts"
1860
1861        # If no cached intercepts exist, compute them
1862        if not hasattr(self, dict_attr):
1863            if getattr(self, "debug", False):
1864                print(f"[DEBUG] '{dict_attr}' not found, computing via get_simple_intercepts_dict().")
1865            setattr(self, dict_attr, self.get_simple_intercepts_dict())
1866
1867        all_dicts = getattr(self, dict_attr)
1868
1869        # If the requested state isn’t cached, recompute
1870        if state not in all_dicts:
1871            if getattr(self, "debug", False):
1872                print(f"[DEBUG] State '{state}' not found in cached intercepts, recomputing full dict.")
1873            setattr(self, dict_attr, self.get_simple_intercepts_dict())
1874            all_dicts = getattr(self, dict_attr)
1875
1876        state_dict = all_dicts.get(state, {})
1877
1878        # Return cached node intercept if present
1879        if node in state_dict:
1880            return state_dict[node]
1881
1882        # If not found, recompute full dict as fallback
1883        if getattr(self, "debug", False):
1884            print(f"[DEBUG] Node '{node}' not found in state '{state}', recomputing full dict.")
1885        setattr(self, dict_attr, self.get_simple_intercepts_dict())
1886        all_dicts = getattr(self, dict_attr)
1887        return all_dicts.get(state, {}).get(node, None)

Return transformed intercept (theta) parameters for a node and state.

Parameters

node : str Node name. state : {"best", "last", "init"}, optional Model state for which to return parameters. Default is "best".

Returns

Any or None Transformed theta parameters for the requested node and state. The exact structure (scalar, list, or other) depends on the model.

Raises

ValueError If an invalid state is given (not in {"best", "last", "init"}).

Notes

Intercept dictionaries are cached on the instance under the attribute intercept_dicts. If missing or incomplete, they are recomputed using get_simple_intercepts_dict.

def get_linear_shifts(self, node: str, state: str = 'best'):
1889    def get_linear_shifts(self, node: str, state: str = "best"):
1890        """
1891        Return learned linear shift terms for a node and a given state.
1892
1893        Parameters
1894        ----------
1895        node : str
1896            Node name.
1897        state : {"best", "last", "init"}, optional
1898            Model state for which to return linear shift terms. Default is "best".
1899
1900        Returns
1901        -------
1902        dict or Any or None
1903            Linear shift terms for the given node and state. Usually a dict
1904            mapping term names to weights.
1905
1906        Raises
1907        ------
1908        ValueError
1909            If an invalid state is given (not in {"best", "last", "init"}).
1910
1911        Notes
1912        -----
1913        Linear shift dictionaries are cached on the instance under the attribute
1914        ``linear_shift_dicts``. If missing or incomplete, they are recomputed using
1915        `get_linear_shifts_dict`.
1916        """
1917        state = state.lower()
1918        if state not in ["best", "last", "init"]:
1919            raise ValueError(f"[ERROR] Invalid state '{state}'. Must be one of ['best', 'last', 'init'].")
1920
1921        dict_attr = "linear_shift_dicts"
1922
1923        # If no global dicts cached, compute once
1924        if not hasattr(self, dict_attr):
1925            if getattr(self, "debug", False):
1926                print(f"[DEBUG] '{dict_attr}' not found, computing via get_linear_shifts_dict().")
1927            setattr(self, dict_attr, self.get_linear_shifts_dict())
1928
1929        all_dicts = getattr(self, dict_attr)
1930
1931        # If the requested state isn't cached, compute all again (covers fresh runs)
1932        if state not in all_dicts:
1933            if getattr(self, "debug", False):
1934                print(f"[DEBUG] State '{state}' not found in cached linear shifts, recomputing full dict.")
1935            setattr(self, dict_attr, self.get_linear_shifts_dict())
1936            all_dicts = getattr(self, dict_attr)
1937
1938        # Now fetch the dictionary for this state
1939        state_dict = all_dicts.get(state, {})
1940
1941        # If the node is available, return its entry
1942        if node in state_dict:
1943            return state_dict[node]
1944
1945        # If missing, try recomputing (fallback)
1946        if getattr(self, "debug", False):
1947            print(f"[DEBUG] Node '{node}' not found in state '{state}', recomputing full dict.")
1948        setattr(self, dict_attr, self.get_linear_shifts_dict())
1949        all_dicts = getattr(self, dict_attr)
1950        return all_dicts.get(state, {}).get(node, None)

Return learned linear shift terms for a node and a given state.

Parameters

node : str Node name. state : {"best", "last", "init"}, optional Model state for which to return linear shift terms. Default is "best".

Returns

dict or Any or None Linear shift terms for the given node and state. Usually a dict mapping term names to weights.

Raises

ValueError If an invalid state is given (not in {"best", "last", "init"}).

Notes

Linear shift dictionaries are cached on the instance under the attribute linear_shift_dicts. If missing or incomplete, they are recomputed using get_linear_shifts_dict.

def get_linear_shifts_dict(self):
1952    def get_linear_shifts_dict(self):
1953        """
1954        Compute linear shift term dictionaries for all nodes and states.
1955
1956        For each node and each available state ("best", "last", "init"), this
1957        method loads the corresponding model checkpoint, extracts linear shift
1958        weights from the TRAM model, and stores them in a nested dictionary.
1959
1960        Returns
1961        -------
1962        dict
1963            Nested dictionary of the form:
1964
1965            .. code-block:: python
1966
1967                {
1968                    "best": {node: {...}},
1969                    "last": {node: {...}},
1970                    "init": {node: {...}},
1971                }
1972
1973            where the innermost dict maps term labels (e.g. ``"ls(parent_name)"``)
1974            to their weights.
1975
1976        Notes
1977        -----
1978        - If "best" or "last" checkpoints are unavailable for a node, only
1979        the "init" entry is populated.
1980        - Empty outer states (without any nodes) are removed from the result.
1981        """
1982
1983        EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
1984        nodes_list = list(self.models.keys())
1985        all_states = ["best", "last", "init"]
1986        all_linear_shift_dicts = {state: {} for state in all_states}
1987
1988        for node in nodes_list:
1989            NODE_DIR = os.path.join(EXPERIMENT_DIR, node)
1990            BEST_MODEL_PATH, LAST_MODEL_PATH, _, _ = model_train_val_paths(NODE_DIR)
1991            INIT_MODEL_PATH = os.path.join(NODE_DIR, "initial_model.pt")
1992
1993            state_paths = {
1994                "best": BEST_MODEL_PATH,
1995                "last": LAST_MODEL_PATH,
1996                "init": INIT_MODEL_PATH,
1997            }
1998
1999            for state, LOAD_PATH in state_paths.items():
2000                if not os.path.exists(LOAD_PATH):
2001                    if state != "init":
2002                        # skip best/last if unavailable
2003                        continue
2004                    else:
2005                        print(f"[WARNING] No models found for node '{node}'. Only initial model will be used.")
2006                        if not os.path.exists(LOAD_PATH):
2007                            if getattr(self, "debug", False):
2008                                print(f"[DEBUG] Initial model also missing for node '{node}'. Skipping.")
2009                            continue
2010
2011                # Load parents and model
2012                _, terms_dict, _ = ordered_parents(node, self.nodes_dict)
2013                state_dict = torch.load(LOAD_PATH, map_location=self.device)
2014                tram_model = self.models[node]
2015                tram_model.load_state_dict(state_dict)
2016
2017                epoch_weights = {}
2018                if hasattr(tram_model, "nn_shift") and tram_model.nn_shift is not None:
2019                    for i, shift_layer in enumerate(tram_model.nn_shift):
2020                        module_name = shift_layer.__class__.__name__
2021                        if (
2022                            hasattr(shift_layer, "fc")
2023                            and hasattr(shift_layer.fc, "weight")
2024                            and module_name == "LinearShift"
2025                        ):
2026                            term_name = list(terms_dict.keys())[i]
2027                            epoch_weights[f"ls({term_name})"] = (
2028                                shift_layer.fc.weight.detach().cpu().squeeze().tolist()
2029                            )
2030                        elif getattr(self, "debug", False):
2031                            term_name = list(terms_dict.keys())[i]
2032                            print(f"[DEBUG] ls({term_name}): missing 'fc' or 'weight' in LinearShift.")
2033                else:
2034                    if getattr(self, "debug", False):
2035                        print(f"[DEBUG] Tram model for node '{node}' has no nn_shift or it is None.")
2036
2037                all_linear_shift_dicts[state][node] = epoch_weights
2038
2039        # Remove empty states (e.g., when best/last not found for all nodes)
2040        all_linear_shift_dicts = {k: v for k, v in all_linear_shift_dicts.items() if v}
2041
2042        return all_linear_shift_dicts

Compute linear shift term dictionaries for all nodes and states.

For each node and each available state ("best", "last", "init"), this method loads the corresponding model checkpoint, extracts linear shift weights from the TRAM model, and stores them in a nested dictionary.

Returns

dict Nested dictionary of the form:

```python
{
    "best": {node: {...}},
    "last": {node: {...}},
    "init": {node: {...}},
}
```

where the innermost dict maps term labels (e.g. ``"ls(parent_name)"``)
to their weights.

Notes

  • If "best" or "last" checkpoints are unavailable for a node, only the "init" entry is populated.
  • Empty outer states (without any nodes) are removed from the result.
def get_simple_intercepts_dict(self):
2044    def get_simple_intercepts_dict(self):
2045        """
2046        Compute transformed simple intercept dictionaries for all nodes and states.
2047
2048        For each node and each available state ("best", "last", "init"), this
2049        method loads the corresponding model checkpoint, extracts simple intercept
2050        weights, transforms them into interpretable theta parameters, and stores
2051        them in a nested dictionary.
2052
2053        Returns
2054        -------
2055        dict
2056            Nested dictionary of the form:
2057
2058            .. code-block:: python
2059
2060                {
2061                    "best": {node: [[theta_1], [theta_2], ...]},
2062                    "last": {node: [[theta_1], [theta_2], ...]},
2063                    "init": {node: [[theta_1], [theta_2], ...]},
2064                }
2065
2066        Notes
2067        -----
2068        - For ordinal models (``self.is_ontram == True``), `transform_intercepts_ordinal`
2069        is used.
2070        - For continuous models, `transform_intercepts_continous` is used.
2071        - Empty outer states (without any nodes) are removed from the result.
2072        """
2073
2074        EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
2075        nodes_list = list(self.models.keys())
2076        all_states = ["best", "last", "init"]
2077        all_si_intercept_dicts = {state: {} for state in all_states}
2078
2079        debug = getattr(self, "debug", False)
2080        verbose = getattr(self, "verbose", False)
2081        is_ontram = getattr(self, "is_ontram", False)
2082
2083        for node in nodes_list:
2084            NODE_DIR = os.path.join(EXPERIMENT_DIR, node)
2085            BEST_MODEL_PATH, LAST_MODEL_PATH, _, _ = model_train_val_paths(NODE_DIR)
2086            INIT_MODEL_PATH = os.path.join(NODE_DIR, "initial_model.pt")
2087
2088            state_paths = {
2089                "best": BEST_MODEL_PATH,
2090                "last": LAST_MODEL_PATH,
2091                "init": INIT_MODEL_PATH,
2092            }
2093
2094            for state, LOAD_PATH in state_paths.items():
2095                if not os.path.exists(LOAD_PATH):
2096                    if state != "init":
2097                        continue
2098                    else:
2099                        print(f"[WARNING] No models found for node '{node}'. Only initial model will be used.")
2100                        if not os.path.exists(LOAD_PATH):
2101                            if debug:
2102                                print(f"[DEBUG] Initial model also missing for node '{node}'. Skipping.")
2103                            continue
2104
2105                # Load model state
2106                state_dict = torch.load(LOAD_PATH, map_location=self.device)
2107                tram_model = self.models[node]
2108                tram_model.load_state_dict(state_dict)
2109
2110                # Extract and transform simple intercept weights
2111                si_weights = None
2112                if hasattr(tram_model, "nn_int") and tram_model.nn_int is not None and isinstance(tram_model.nn_int, SimpleIntercept):
2113                    if hasattr(tram_model.nn_int, "fc") and hasattr(tram_model.nn_int.fc, "weight"):
2114                        weights = tram_model.nn_int.fc.weight.detach().cpu().tolist()
2115                        weights_tensor = torch.Tensor(weights)
2116
2117                        if debug:
2118                            print(f"[DEBUG] Node '{node}' ({state}) theta tilde shape: {weights_tensor.shape}")
2119
2120                        if is_ontram:
2121                            si_weights = transform_intercepts_ordinal(weights_tensor.reshape(1, -1))[:, 1:-1].reshape(-1, 1)
2122                        else:
2123                            si_weights = transform_intercepts_continous(weights_tensor.reshape(1, -1)).reshape(-1, 1)
2124
2125                        si_weights = si_weights.tolist()
2126
2127                        if debug:
2128                            print(f"[DEBUG] Node '{node}' ({state}) theta transformed: {si_weights}")
2129                    else:
2130                        if debug:
2131                            print(f"[DEBUG] Node '{node}' ({state}): missing 'fc' or 'weight' in SimpleIntercept.")
2132                else:
2133                    if debug:
2134                        print(f"[DEBUG] Tram model for node '{node}' has no nn_int or it is None.")
2135
2136                all_si_intercept_dicts[state][node] = si_weights
2137
2138        # Clean up empty states
2139        all_si_intercept_dicts = {k: v for k, v in all_si_intercept_dicts.items() if v}
2140        return all_si_intercept_dicts

Compute transformed simple intercept dictionaries for all nodes and states.

For each node and each available state ("best", "last", "init"), this method loads the corresponding model checkpoint, extracts simple intercept weights, transforms them into interpretable theta parameters, and stores them in a nested dictionary.

Returns

dict Nested dictionary of the form:

```python
{
    "best": {node: [[theta_1], [theta_2], ...]},
    "last": {node: [[theta_1], [theta_2], ...]},
    "init": {node: [[theta_1], [theta_2], ...]},
}
```

Notes

  • For ordinal models (self.is_ontram == True), transform_intercepts_ordinal is used.
  • For continuous models, transform_intercepts_continous is used.
  • Empty outer states (without any nodes) are removed from the result.
def summary(self, verbose=False):
2142    def summary(self, verbose=False):
2143        """
2144        Print a multi-part textual summary of the TramDagModel.
2145
2146        The summary includes:
2147        1. Training metrics overview per node (best/last NLL, epochs).
2148        2. Node-specific details (thetas, linear shifts, optional architecture).
2149        3. Basic information about the attached training DataFrame, if present.
2150
2151        Parameters
2152        ----------
2153        verbose : bool, optional
2154            If True, include extended per-node details such as the model
2155            architecture, parameter count, and availability of checkpoints
2156            and sampling results. Default is False.
2157
2158        Returns
2159        -------
2160        None
2161
2162        Notes
2163        -----
2164        This method prints to stdout and does not return structured data.
2165        It is intended for quick, human-readable inspection of the current
2166        training and model state.
2167        """
2168
2169        # ---------- SETUP ----------
2170        try:
2171            EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
2172        except KeyError:
2173            EXPERIMENT_DIR = None
2174            print("[WARNING] Missing EXPERIMENT_DIR in cfg.conf_dict['PATHS'].")
2175
2176        print("\n" + "=" * 120)
2177        print(f"{'TRAM DAG MODEL SUMMARY':^120}")
2178        print("=" * 120)
2179
2180        # ---------- METRICS OVERVIEW ----------
2181        summary_data = []
2182        for node in self.models.keys():
2183            node_dir = os.path.join(self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"], node)
2184            train_path = os.path.join(node_dir, "train_loss_hist.json")
2185            val_path = os.path.join(node_dir, "val_loss_hist.json")
2186
2187            if os.path.exists(train_path) and os.path.exists(val_path):
2188                best_train_nll, best_val_nll = self.get_train_val_nll(node, "best")
2189                last_train_nll, last_val_nll = self.get_train_val_nll(node, "last")
2190                n_epochs_total = len(json.load(open(train_path)))
2191            else:
2192                best_train_nll = best_val_nll = last_train_nll = last_val_nll = None
2193                n_epochs_total = 0
2194
2195            summary_data.append({
2196                "Node": node,
2197                "Best Train NLL": best_train_nll,
2198                "Best Val NLL": best_val_nll,
2199                "Last Train NLL": last_train_nll,
2200                "Last Val NLL": last_val_nll,
2201                "Epochs": n_epochs_total,
2202            })
2203
2204        df_summary = pd.DataFrame(summary_data)
2205        df_summary = df_summary.round(4)
2206
2207        print("\n[1] TRAINING METRICS OVERVIEW")
2208        print("-" * 120)
2209        if not df_summary.empty:
2210            print(
2211                df_summary.to_string(
2212                    index=False,
2213                    justify="center",
2214                    col_space=14,
2215                    float_format=lambda x: f"{x:7.4f}",
2216                )
2217            )
2218        else:
2219            print("No training history found for any node.")
2220        print("-" * 120)
2221
2222        # ---------- NODE DETAILS ----------
2223        print("\n[2] NODE-SPECIFIC DETAILS")
2224        print("-" * 120)
2225        for node in self.models.keys():
2226            print(f"\n{f'NODE: {node}':^120}")
2227            print("-" * 120)
2228
2229            # THETAS & SHIFTS
2230            for state in ["init", "last", "best"]:
2231                print(f"\n  [{state.upper()} STATE]")
2232
2233                # ---- Thetas ----
2234                try:
2235                    thetas = getattr(self, "get_thetas", lambda n, s=None: None)(node, state)
2236                    if thetas is not None:
2237                        if isinstance(thetas, (list, np.ndarray, pd.Series)):
2238                            thetas_flat = np.array(thetas).flatten()
2239                            compact = np.round(thetas_flat, 4)
2240                            arr_str = np.array2string(
2241                                compact,
2242                                max_line_width=110,
2243                                threshold=np.inf,
2244                                separator=", "
2245                            )
2246                            lines = arr_str.split("\n")
2247                            if len(lines) > 2:
2248                                arr_str = "\n".join(lines[:2]) + " ..."
2249                            print(f"    Θ ({len(thetas_flat)}): {arr_str}")
2250                        elif isinstance(thetas, dict):
2251                            for k, v in thetas.items():
2252                                print(f"     Θ[{k}]: {v}")
2253                        else:
2254                            print(f"    Θ: {thetas}")
2255                    else:
2256                        print("    Θ: not available")
2257                except Exception as e:
2258                    print(f"    [Error loading thetas] {e}")
2259
2260                # ---- Linear Shifts ----
2261                try:
2262                    linear_shifts = getattr(self, "get_linear_shifts", lambda n, s=None: None)(node, state)
2263                    if linear_shifts is not None:
2264                        if isinstance(linear_shifts, dict):
2265                            for k, v in linear_shifts.items():
2266                                print(f"     {k}: {np.round(v, 4)}")
2267                        elif isinstance(linear_shifts, (list, np.ndarray, pd.Series)):
2268                            arr = np.round(linear_shifts, 4)
2269                            print(f"    Linear shifts ({len(arr)}): {arr}")
2270                        else:
2271                            print(f"    Linear shifts: {linear_shifts}")
2272                    else:
2273                        print("    Linear shifts: not available")
2274                except Exception as e:
2275                    print(f"    [Error loading linear shifts] {e}")
2276
2277            # ---- Verbose info directly below node ----
2278            if verbose:
2279                print("\n  [DETAILS]")
2280                node_dir = os.path.join(EXPERIMENT_DIR, node) if EXPERIMENT_DIR else None
2281                model = self.models[node]
2282
2283                print(f"    Model Architecture:")
2284                arch_str = str(model).split("\n")
2285                for line in arch_str:
2286                    print(f"      {line}")
2287                print(f"    Parameter count: {sum(p.numel() for p in model.parameters()):,}")
2288
2289                if node_dir and os.path.exists(node_dir):
2290                    ckpt_exists = any(f.endswith(('.pt', '.pth')) for f in os.listdir(node_dir))
2291                    print(f"    Checkpoints found: {ckpt_exists}")
2292
2293                    sampling_dir = os.path.join(node_dir, "sampling")
2294                    sampling_exists = os.path.isdir(sampling_dir) and len(os.listdir(sampling_dir)) > 0
2295                    print(f"    Sampling results found: {sampling_exists}")
2296
2297                    for label, filename in [("Train", "train_loss_hist.json"), ("Validation", "val_loss_hist.json")]:
2298                        path = os.path.join(node_dir, filename)
2299                        if os.path.exists(path):
2300                            try:
2301                                with open(path, "r") as f:
2302                                    hist = json.load(f)
2303                                print(f"    {label} history: {len(hist)} epochs")
2304                            except Exception as e:
2305                                print(f"    {label} history: failed to load ({e})")
2306                        else:
2307                            print(f"    {label} history: not found")
2308                else:
2309                    print("    [INFO] No experiment directory defined or missing for this node.")
2310            print("-" * 120)
2311
2312        # ---------- TRAINING DATAFRAME ----------
2313        print("\n[3] TRAINING DATAFRAME")
2314        print("-" * 120)
2315        try:
2316            self.train_df.info()
2317        except AttributeError:
2318            print("No training DataFrame attached to this TramDagModel.")
2319        print("=" * 120 + "\n")

Print a multi-part textual summary of the TramDagModel.

The summary includes:

  1. Training metrics overview per node (best/last NLL, epochs).
  2. Node-specific details (thetas, linear shifts, optional architecture).
  3. Basic information about the attached training DataFrame, if present.

Parameters

verbose : bool, optional If True, include extended per-node details such as the model architecture, parameter count, and availability of checkpoints and sampling results. Default is False.

Returns

None

Notes

This method prints to stdout and does not return structured data. It is intended for quick, human-readable inspection of the current training and model state.