tramdag

 1# tramdag/__init__.py
 2from .TramDagConfig import TramDagConfig
 3from .TramDagDataset import TramDagDataset
 4from .TramDagModel import TramDagModel
 5from importlib.metadata import version, PackageNotFoundError
 6
 7
 8__all__ = ["TramDagConfig", "TramDagDataset", "TramDagModel"]
 9
10try:
11    __version__ = version("tramdag")
12except PackageNotFoundError:
13    __version__ = "0.0.0"
class TramDagConfig:
    """
    Configuration manager for TRAM-DAG experiments.

    This class encapsulates:

    - The experiment configuration dictionary (`conf_dict`).
    - Its backing file path (`CONF_DICT_PATH`).
    - Utilities to load, validate, modify, and persist configuration.
    - DAG visualization and interactive editing helpers.

    Typical usage
    -------------
    - Load existing configuration from disk via `TramDagConfig.load_json`.
    - Or create/reuse experiment setup via `TramDagConfig().setup_configuration`.
    - Update sections such as `data_type`, adjacency matrix, and neural network
      model names using the provided methods.
    """

    def __init__(self, conf_dict: dict = None, CONF_DICT_PATH: str = None, _verify: bool = False, **kwargs):
        """
        Initialize a TramDagConfig instance.

        Parameters
        ----------
        conf_dict : dict or None, optional
            Configuration dictionary. If None, a new empty dict is used and
            can be populated later. Default is None.
        CONF_DICT_PATH : str or None, optional
            Path to the configuration file on disk. Default is None.
        _verify : bool, optional
            If True, run `_verify_completeness()` after initialization.
            Default is False.
        **kwargs
            Additional attributes to be set on the instance (e.g. `debug`,
            `verbose`). Keys "conf_dict" and "CONF_DICT_PATH" are forbidden
            and raise a ValueError.

        Raises
        ------
        ValueError
            If any key in `kwargs` is "conf_dict" or "CONF_DICT_PATH".

        Notes
        -----
        By default, `debug` and `verbose` are set to False. They can be
        overridden via `kwargs`.
        """
        # Defaults first, so kwargs may override them below.
        self.debug = False
        self.verbose = False

        for key, value in kwargs.items():
            if key in ('conf_dict', 'CONF_DICT_PATH'):
                raise ValueError(f"Cannot override '{key}' via kwargs.")
            setattr(self, key, value)

        # Keep the caller's dict object when one is given (even an empty one),
        # so external references observe in-place updates.  The original
        # `conf_dict or {}` silently replaced an empty caller dict.
        self.conf_dict = conf_dict if conf_dict is not None else {}
        self.CONF_DICT_PATH = CONF_DICT_PATH

        # Optional structural verification (warning-only, never raises).
        if _verify:
            self._verify_completeness()

    @classmethod
    def load_json(cls, CONF_DICT_PATH: str, debug: bool = False):
        """
        Load a configuration from a JSON file and construct a TramDagConfig.

        Parameters
        ----------
        CONF_DICT_PATH : str
            Path to the configuration JSON file.
        debug : bool, optional
            If True, initialize the instance with `debug=True`. Default is False.

        Returns
        -------
        TramDagConfig
            Newly created configuration instance with `conf_dict` loaded from
            `CONF_DICT_PATH` and `_verify_completeness()` executed.

        Raises
        ------
        FileNotFoundError
            If the configuration file cannot be found (propagated by
            `load_configuration_dict`).
        """
        conf = load_configuration_dict(CONF_DICT_PATH)
        return cls(conf, CONF_DICT_PATH=CONF_DICT_PATH, debug=debug, _verify=True)

    def update(self):
        """
        Reload the latest configuration from disk into this instance.

        The current in-memory `conf_dict` is overwritten by the contents
        loaded from `CONF_DICT_PATH`.

        Raises
        ------
        ValueError
            If `CONF_DICT_PATH` is not set on the instance.
        """
        if getattr(self, "CONF_DICT_PATH", None) is None:
            raise ValueError("CONF_DICT_PATH not set — cannot update configuration.")

        self.conf_dict = load_configuration_dict(self.CONF_DICT_PATH)

    def save(self, CONF_DICT_PATH: str = None):
        """
        Persist the current configuration dictionary to disk.

        Parameters
        ----------
        CONF_DICT_PATH : str or None, optional
            Target path for the configuration file. If None, uses
            `self.CONF_DICT_PATH`. Default is None.

        Raises
        ------
        ValueError
            If neither the function argument nor `self.CONF_DICT_PATH`
            provides a valid path.
        """
        path = CONF_DICT_PATH or self.CONF_DICT_PATH
        if path is None:
            raise ValueError("No CONF_DICT_PATH provided to save config.")
        write_configuration_dict(self.conf_dict, path)

    def _verify_completeness(self):
        """
        Check that the configuration is structurally complete and consistent.

        All checks are warning-only (printed, never raised):

        1. Presence of top-level mandatory keys.
        2. Presence of per-node mandatory keys.
        3. "levels" entries for all ordinal variables.
        4. Non-empty experiment name.
        5. A valid adjacency matrix.

        Detailed debug messages are printed when `self.debug=True`.
        """
        # NOTE: "nodes" was listed twice in the original list; deduplicated.
        mandatory_keys = ["experiment_name", "PATHS", "nodes", "data_type", "adj_matrix", "model_names"]
        # Optional top-level keys (not checked): "date_of_creation", "seed".

        # ---- 1. Check mandatory keys exist
        missing = [k for k in mandatory_keys if k not in self.conf_dict]
        if missing:
            print(f"[WARNING] Missing mandatory keys in configuration: {missing}"
                "\n Please add them to the configuration dict and reload.")

        # ---- 2. Check if mandatory keys in nodes dict are present
        mandatory_keys_nodes = ['data_type', 'node_type', 'parents', 'parents_datatype',
                                'transformation_terms_in_h()', 'transformation_term_nn_models_in_h()']
        # Optional per-node key (not checked): "levels".
        for node, node_dict in self.conf_dict.get("nodes", {}).items():
            missing_node_keys = [k for k in mandatory_keys_nodes if k not in node_dict]
            if missing_node_keys:
                print(f"[WARNING] Node '{node}' is missing mandatory keys: {missing_node_keys}")

        # ---- 3. Ordinal variables need precomputed levels for training
        if self._verify_levels_dict():
            if self.debug:
                print("[DEBUG] levels are present for all ordinal variables in configuration dict.")
        else:
            print("[WARNING]  levels are missing for some ordinal variables in configuration dict. THIS will FAIL in model training later!\n"
                " Please provide levels manually to config and reload or compute levels from data using the method compute_levels().\n"
                " e.g. cfg.compute_levels(train_df) # computes levels from training data and writes to cfg")

        # ---- 4. Experiment name (the original printed nothing on failure)
        if self._verify_experiment_name():
            if self.debug:
                print("[DEBUG] experiment_name is valid in configuration dict.")
        else:
            print("[WARNING] experiment_name is missing or empty in configuration dict.")

        # ---- 5. Adjacency matrix (the original printed nothing on failure)
        if self._verify_adj_matrix():
            if self.debug:
                print("[DEBUG] adj_matrix is valid in configuration dict.")
        else:
            print("[WARNING] adj_matrix is missing or invalid in configuration dict.")

    def _verify_levels_dict(self):
        """
        Verify that all ordinal variables have levels specified in the config.

        Returns
        -------
        bool
            True if every variable whose entry in ``conf_dict["data_type"]``
            contains "ordinal" has a "levels" entry in
            ``conf_dict["nodes"][var]``; False otherwise.

        Notes
        -----
        This method does not modify the configuration; it only checks presence
        of level information.
        """
        data_type = self.conf_dict.get('data_type', {})
        nodes = self.conf_dict.get('nodes', {})
        for var, dtype in data_type.items():
            if 'ordinal' in dtype:
                if var not in nodes or 'levels' not in nodes[var]:
                    return False
        return True

    def _verify_experiment_name(self):
        """
        Check whether the experiment name in the configuration is valid.

        Returns
        -------
        bool
            True if ``conf_dict["experiment_name"]`` exists and is non-empty
            after stripping whitespace. False otherwise.
        """
        experiment_name = self.conf_dict.get("experiment_name")
        return not (experiment_name is None or str(experiment_name).strip() == "")

    def _verify_adj_matrix(self):
        """
        Validate the adjacency matrix stored in the configuration.

        Returns
        -------
        bool
            True if an adjacency matrix is present and passes
            `validate_adj_matrix`; False if it is missing or invalid.

        Notes
        -----
        If the adjacency matrix is stored as a list, it is converted to a
        NumPy array before validation.  A missing "adj_matrix" key returns
        False instead of raising KeyError (the original indexed the key
        directly, which broke the warning-only verification flow).
        """
        adj_matrix = self.conf_dict.get('adj_matrix')
        if adj_matrix is None:
            return False
        if isinstance(adj_matrix, list):
            adj_matrix = np.array(adj_matrix)
        return bool(validate_adj_matrix(adj_matrix))

    def compute_levels(self, df: pd.DataFrame, write: bool = True):
        """
        Infer and update ordinal/categorical levels from data.

        For each variable in the configuration's `data_type` section, this
        method uses the provided DataFrame to construct a levels dictionary
        and injects the corresponding "levels" entry into `conf_dict["nodes"]`.

        Parameters
        ----------
        df : pandas.DataFrame
            DataFrame used to infer levels for configured variables.
        write : bool, optional
            If True and `CONF_DICT_PATH` is set, the updated configuration is
            written back to disk. Default is True.

        Notes
        -----
        - Variables present in `levels_dict` but not in `conf_dict["nodes"]`
          trigger a warning and are skipped.
        - Save failures are reported via a printed error, not raised.
        """
        # Sync with the file first so we never overwrite newer on-disk state.
        self.update()
        levels_dict = create_levels_dict(df, self.conf_dict['data_type'])

        # Inject levels into the nodes dict.
        for var, levels in levels_dict.items():
            if var in self.conf_dict['nodes']:
                self.conf_dict['nodes'][var]['levels'] = levels
            else:
                print(f"[WARNING] Variable '{var}' not found in nodes dict. Cannot add levels.")

        if write and self.CONF_DICT_PATH is not None:
            try:
                self.save(self.CONF_DICT_PATH)
                if self.verbose or self.debug:
                    print(f'[INFO] Configuration with updated levels saved to {self.CONF_DICT_PATH}')
            except Exception as e:
                print(f'[ERROR] Failed to save configuration: {e}')

    def plot_dag(self, seed: int = 42, causal_order: bool = False):
        """
        Visualize the DAG defined by the configuration.

        Nodes are categorized and colored as:
        - Source nodes (no incoming edges): green.
        - Sink nodes (no outgoing edges): red.
        - Intermediate nodes: light blue.

        Parameters
        ----------
        seed : int, optional
            Random seed for layout stability in the spring layout fallback.
            Default is 42.
        causal_order : bool, optional
            If True, attempt to use Graphviz 'dot' layout via
            `networkx.nx_agraph.graphviz_layout` to preserve causal ordering.
            If False or if Graphviz is unavailable, use `spring_layout`.
            Default is False.

        Raises
        ------
        ValueError
            If `adj_matrix` or `data_type` is missing or inconsistent with
            each other, or if the adjacency matrix fails validation.

        Notes
        -----
        Edge labels are colored by prefix: "ci" blue, "ls" red, "cs" green,
        anything else black.
        """
        adj_matrix = self.conf_dict.get("adj_matrix")
        data_type = self.conf_dict.get("data_type")

        if adj_matrix is None or data_type is None:
            raise ValueError("Configuration must include 'adj_matrix' and 'data_type'.")

        if isinstance(adj_matrix, list):
            adj_matrix = np.array(adj_matrix)

        if not validate_adj_matrix(adj_matrix):
            raise ValueError("Invalid adjacency matrix.")
        if len(data_type) != adj_matrix.shape[0]:
            raise ValueError("data_type must match adjacency matrix size.")

        node_labels = list(data_type.keys())
        G, edge_labels = create_nx_graph(adj_matrix, node_labels)

        sources = {n for n in G.nodes if G.in_degree(n) == 0}
        sinks = {n for n in G.nodes if G.out_degree(n) == 0}
        # Any node that is neither source nor sink renders light blue.

        node_colors = [
            "green" if n in sources
            else "red" if n in sinks
            else "lightblue"
            for n in G.nodes
        ]

        if causal_order:
            try:
                # Graphviz 'dot' places parents above children (causal order).
                pos = nx.nx_agraph.graphviz_layout(G, prog="dot")
            except (ImportError, nx.NetworkXException):
                pos = nx.spring_layout(G, seed=seed, k=1.5, iterations=100)
        else:
            pos = nx.spring_layout(G, seed=seed, k=1.5, iterations=100)

        plt.figure(figsize=(8, 6))
        nx.draw(
            G, pos,
            with_labels=True,
            node_color=node_colors,
            edge_color="gray",
            node_size=2500,
            arrowsize=20
        )

        # Draw each edge label in a color keyed to its transformation prefix.
        for (u, v), lbl in edge_labels.items():
            color = (
                "blue" if lbl.startswith("ci")
                else "red" if lbl.startswith("ls")
                else "green" if lbl.startswith("cs")
                else "black"
            )
            nx.draw_networkx_edge_labels(
                G, pos,
                edge_labels={(u, v): lbl},
                font_color=color,
                font_size=12
            )

        legend_items = [
            Patch(facecolor="green", edgecolor="black", label="Source"),
            Patch(facecolor="red", edgecolor="black", label="Sink"),
            Patch(facecolor="lightblue", edgecolor="black", label="Intermediate")
        ]
        plt.legend(handles=legend_items, loc="upper right", frameon=True)

        plt.title("TRAM DAG")
        plt.axis("off")
        plt.tight_layout()
        plt.show()

    def setup_configuration(self, experiment_name=None, EXPERIMENT_DIR=None, debug=False, _verify=False):
        """
        Create or reuse a configuration for an experiment.

        Behaves differently depending on how it is called:

        1. Class call (e.g. `TramDagConfig.setup_configuration(...)`):
           creates or loads a configuration at the resolved path and returns
           a new `TramDagConfig` instance.
        2. Instance call (e.g. `cfg.setup_configuration(...)`): updates
           `self.conf_dict` and `self.CONF_DICT_PATH` in place, optionally
           verifies completeness, and returns None.

        Parameters
        ----------
        experiment_name : str or None, optional
            Name of the experiment. If None, defaults to "experiment_1".
        EXPERIMENT_DIR : str or None, optional
            Directory for the experiment. If None, defaults to
            `<cwd>/<experiment_name>`.
        debug : bool, optional
            If True, initialize / update with `debug=True`. Default is False.
        _verify : bool, optional
            If True, call `_verify_completeness()` after loading. Default is False.

        Returns
        -------
        TramDagConfig or None
            A new instance when called on the class; None for instance calls.

        Notes
        -----
        A file named "configuration.json" is created in `EXPERIMENT_DIR` if it
        does not exist yet, via `create_and_write_new_configuration_dict`.
        """
        is_class_call = isinstance(self, type)
        cls = self if is_class_call else self.__class__

        if experiment_name is None:
            experiment_name = "experiment_1"
        if EXPERIMENT_DIR is None:
            EXPERIMENT_DIR = os.path.join(os.getcwd(), experiment_name)

        CONF_DICT_PATH = os.path.join(EXPERIMENT_DIR, "configuration.json")
        DATA_PATH = EXPERIMENT_DIR

        os.makedirs(EXPERIMENT_DIR, exist_ok=True)

        if os.path.exists(CONF_DICT_PATH):
            print(f"Configuration already exists: {CONF_DICT_PATH}")
        else:
            _ = create_and_write_new_configuration_dict(
                experiment_name, CONF_DICT_PATH, EXPERIMENT_DIR, DATA_PATH, None
            )
            print(f"Created new configuration file at {CONF_DICT_PATH}")

        conf = load_configuration_dict(CONF_DICT_PATH)

        if is_class_call:
            return cls(conf, CONF_DICT_PATH=CONF_DICT_PATH, debug=debug, _verify=_verify)
        else:
            self.conf_dict = conf
            self.CONF_DICT_PATH = CONF_DICT_PATH
            if _verify:
                self._verify_completeness()

    def set_data_type(self, data_type: dict, CONF_DICT_PATH: str = None) -> None:
        """
        Update or write the `data_type` section of a configuration file.

        Supports both class-level and instance-level usage:

        - Class call: requires the `CONF_DICT_PATH` argument; reads the file
          if it exists (otherwise starts from an empty dict) and writes the
          updated configuration back.
        - Instance call: uses `self.CONF_DICT_PATH` if available, otherwise
          defaults to `<cwd>/configuration.json`; updates `self.conf_dict`
          and `self.CONF_DICT_PATH` after writing.

        Parameters
        ----------
        data_type : dict
            Mapping `{variable_name: type_spec}`, where `type_spec` encodes
            modeling types (e.g. continuous, ordinal, etc.).
        CONF_DICT_PATH : str or None, optional
            Path to the configuration file. Must be provided for class calls.

        Raises
        ------
        ValueError
            If `CONF_DICT_PATH` is missing when called on the class.

        Notes
        -----
        - Variable names are validated via `validate_variable_names`; data
          type values via `validate_data_types`.
        - Failures inside the try-block (including validation errors) are
          printed, not propagated.
        """
        is_class_call = isinstance(self, type)

        # Resolve the target path depending on the call style.
        if CONF_DICT_PATH is None:
            if not is_class_call and getattr(self, "CONF_DICT_PATH", None):
                CONF_DICT_PATH = self.CONF_DICT_PATH
            elif not is_class_call:
                CONF_DICT_PATH = os.path.join(os.getcwd(), "configuration.json")
            else:
                raise ValueError("CONF_DICT_PATH must be provided when called on the class.")

        try:
            # Load existing or create empty configuration.
            configuration_dict = (
                load_configuration_dict(CONF_DICT_PATH)
                if os.path.exists(CONF_DICT_PATH)
                else {}
            )

            validate_variable_names(data_type.keys())
            if not validate_data_types(data_type):
                raise ValueError("Invalid data types in the provided dictionary.")

            configuration_dict["data_type"] = data_type
            write_configuration_dict(configuration_dict, CONF_DICT_PATH)

            # Best-effort summary printout; never fail the update over it.
            try:
                print_data_type_modeling_setting(data_type or {})
            except Exception as e:
                print(f"[WARNING] Could not print data type modeling settings: {e}")

            if not is_class_call:
                self.conf_dict = configuration_dict
                self.CONF_DICT_PATH = CONF_DICT_PATH

        except Exception as e:
            print(f"Failed to update configuration: {e}")
        else:
            print(f"Configuration updated successfully at {CONF_DICT_PATH}.")

    def set_meta_adj_matrix(self, CONF_DICT_PATH: str = None, seed: int = 5):
        """
        Launch the interactive editor to set or modify the adjacency matrix.

        This method:

        1. Resolves the configuration path from the argument or, for
           instances, from `self.CONF_DICT_PATH`.
        2. Invokes `interactive_adj_matrix` to edit the adjacency matrix.
        3. For instances, reloads the updated configuration into `self.conf_dict`.

        Parameters
        ----------
        CONF_DICT_PATH : str or None, optional
            Path to the configuration file. Must be provided when called
            on the class. For instance calls, defaults to `self.CONF_DICT_PATH`.
        seed : int, optional
            Random seed for any layout or stochastic behavior in the
            interactive editor. Default is 5.

        Raises
        ------
        ValueError
            If `CONF_DICT_PATH` is not provided and cannot be inferred
            (e.g. in a class call without path).
        """
        is_class_call = isinstance(self, type)

        # Sync the in-memory config with the file before launching the editor.
        # (The original called self.update() unconditionally, which crashed
        # class calls and instances without a CONF_DICT_PATH.)
        if not is_class_call and getattr(self, "CONF_DICT_PATH", None):
            self.update()

        # Resolve path.
        if CONF_DICT_PATH is None:
            if not is_class_call and getattr(self, "CONF_DICT_PATH", None):
                CONF_DICT_PATH = self.CONF_DICT_PATH
            else:
                raise ValueError("CONF_DICT_PATH must be provided when called on the class.")

        # Launch interactive editor.
        interactive_adj_matrix(CONF_DICT_PATH, seed=seed)

        # Reload config if instance.
        if not is_class_call:
            self.conf_dict = load_configuration_dict(CONF_DICT_PATH)
            self.CONF_DICT_PATH = CONF_DICT_PATH

    def set_tramdag_nn_models(self, CONF_DICT_PATH: str = None):
        """
        Launch the interactive editor to set TRAM-DAG neural network model names.

        - Class call: requires the `CONF_DICT_PATH` argument; does not modify
          a specific instance.
        - Instance call: resolves the path from the argument or
          `self.CONF_DICT_PATH`, and updates `self.conf_dict` /
          `self.CONF_DICT_PATH` if the editor returns an updated configuration.

        Parameters
        ----------
        CONF_DICT_PATH : str or None, optional
            Path to the configuration file. Must be provided when called on
            the class. For instance calls, defaults to `self.CONF_DICT_PATH`.

        Raises
        ------
        ValueError
            If `CONF_DICT_PATH` is not provided and cannot be inferred
            (e.g. in a class call without path).

        Notes
        -----
        The interactive editor is invoked via `interactive_nn_names_matrix`.
        If it returns None, the instance configuration is left unchanged.
        """
        is_class_call = isinstance(self, type)
        if CONF_DICT_PATH is None:
            if not is_class_call and getattr(self, "CONF_DICT_PATH", None):
                CONF_DICT_PATH = self.CONF_DICT_PATH
            else:
                raise ValueError("CONF_DICT_PATH must be provided when called on the class.")

        updated_conf = interactive_nn_names_matrix(CONF_DICT_PATH)
        if updated_conf is not None and not is_class_call:
            self.conf_dict = updated_conf
            self.CONF_DICT_PATH = CONF_DICT_PATH

Configuration manager for TRAM-DAG experiments.

This class encapsulates:

  • The experiment configuration dictionary (conf_dict).
  • Its backing file path (CONF_DICT_PATH).
  • Utilities to load, validate, modify, and persist configuration.
  • DAG visualization and interactive editing helpers.

Typical usage

  • Load existing configuration from disk via TramDagConfig.load_json.
  • Or create/reuse experiment setup via TramDagConfig().setup_configuration.
  • Update sections such as data_type, adjacency matrix, and neural network model names using the provided methods.
TramDagConfig( conf_dict: dict = None, CONF_DICT_PATH: str = None, _verify: bool = False, **kwargs)
49    def __init__(self, conf_dict: dict = None, CONF_DICT_PATH: str = None,  _verify: bool = False,**kwargs):
50        """
51        Initialize a TramDagConfig instance.
52
53        Parameters
54        ----------
55        conf_dict : dict or None, optional
56            Configuration dictionary. If None, an empty dict is used and can
57            be populated later. Default is None.
58        CONF_DICT_PATH : str or None, optional
59            Path to the configuration file on disk. Default is None.
60        _verify : bool, optional
61            If True, run `_verify_completeness()` after initialization.
62            Default is False.
63        **kwargs
64            Additional attributes to be set on the instance. Keys "conf_dict"
65            and "CONF_DICT_PATH" are forbidden and raise a ValueError.
66
67        Raises
68        ------
69        ValueError
70            If any key in `kwargs` is "conf_dict" or "CONF_DICT_PATH".
71
72        Notes
73        -----
74        By default, `debug` and `verbose` are set to False. They can be
75        overridden via `kwargs`.
76        """
77
78        self.debug = False
79        self.verbose = False
80        
81        for key, value in kwargs.items():
82            if key in ['conf_dict', 'CONF_DICT_PATH']:
83                raise ValueError(f"Cannot override '{key}' via kwargs.")
84            setattr(self, key, value)
85        
86        self.conf_dict = conf_dict or {}
87        self.CONF_DICT_PATH = CONF_DICT_PATH
88        
89        # verification 
90        if _verify:
91            self._verify_completeness()

Initialize a TramDagConfig instance.

Parameters

conf_dict : dict or None, optional — Configuration dictionary. If None, an empty dict is used and can be populated later. Default is None.
CONF_DICT_PATH : str or None, optional — Path to the configuration file on disk. Default is None.
_verify : bool, optional — If True, run _verify_completeness() after initialization. Default is False.
**kwargs — Additional attributes to be set on the instance. Keys "conf_dict" and "CONF_DICT_PATH" are forbidden and raise a ValueError.

Raises

ValueError If any key in kwargs is "conf_dict" or "CONF_DICT_PATH".

Notes

By default, debug and verbose are set to False. They can be overridden via kwargs.

debug
verbose
conf_dict
CONF_DICT_PATH
@classmethod
def load_json(cls, CONF_DICT_PATH: str, debug: bool = False):
 93    @classmethod
 94    def load_json(cls, CONF_DICT_PATH: str,debug: bool = False):
 95        """
 96        Load a configuration from a JSON file and construct a TramDagConfig.
 97
 98        Parameters
 99        ----------
100        CONF_DICT_PATH : str
101            Path to the configuration JSON file.
102        debug : bool, optional
103            If True, initialize the instance with `debug=True`. Default is False.
104
105        Returns
106        -------
107        TramDagConfig
108            Newly created configuration instance with `conf_dict` loaded from
109            `CONF_DICT_PATH` and `_verify_completeness()` executed.
110
111        Raises
112        ------
113        FileNotFoundError
114            If the configuration file cannot be found (propagated by
115            `load_configuration_dict`).
116        """
117
118        conf = load_configuration_dict(CONF_DICT_PATH)
119        return cls(conf, CONF_DICT_PATH=CONF_DICT_PATH, debug=debug, _verify=True)

Load a configuration from a JSON file and construct a TramDagConfig.

Parameters

CONF_DICT_PATH : str — Path to the configuration JSON file.
debug : bool, optional — If True, initialize the instance with debug=True. Default is False.

Returns

TramDagConfig Newly created configuration instance with conf_dict loaded from CONF_DICT_PATH and _verify_completeness() executed.

Raises

FileNotFoundError If the configuration file cannot be found (propagated by load_configuration_dict).

def update(self):
121    def update(self):
122        """
123        Reload the latest configuration from disk into this instance.
124
125        Parameters
126        ----------
127        None
128
129        Returns
130        -------
131        None
132
133        Raises
134        ------
135        ValueError
136            If `CONF_DICT_PATH` is not set on the instance.
137
138        Notes
139        -----
140        The current in-memory `conf_dict` is overwritten by the contents
141        loaded from `CONF_DICT_PATH`.
142        """
143
144        if not hasattr(self, "CONF_DICT_PATH") or self.CONF_DICT_PATH is None:
145            raise ValueError("CONF_DICT_PATH not set — cannot update configuration.")
146        
147        
148        self.conf_dict = load_configuration_dict(self.CONF_DICT_PATH)

Reload the latest configuration from disk into this instance.

Parameters

None

Returns

None

Raises

ValueError If CONF_DICT_PATH is not set on the instance.

Notes

The current in-memory conf_dict is overwritten by the contents loaded from CONF_DICT_PATH.

def save(self, CONF_DICT_PATH: str = None):
151    def save(self, CONF_DICT_PATH: str = None):
152        """
153        Persist the current configuration dictionary to disk.
154
155        Parameters
156        ----------
157        CONF_DICT_PATH : str or None, optional
158            Target path for the configuration file. If None, uses
159            `self.CONF_DICT_PATH`. Default is None.
160
161        Returns
162        -------
163        None
164
165        Raises
166        ------
167        ValueError
168            If neither the function argument nor `self.CONF_DICT_PATH`
169            provides a valid path.
170
171        Notes
172        -----
173        The resulting file is written via `write_configuration_dict`.
174        """
175        path = CONF_DICT_PATH or self.CONF_DICT_PATH
176        if path is None:
177            raise ValueError("No CONF_DICT_PATH provided to save config.")
178        write_configuration_dict(self.conf_dict, path)

Persist the current configuration dictionary to disk.

Parameters

CONF_DICT_PATH : str or None, optional Target path for the configuration file. If None, uses self.CONF_DICT_PATH. Default is None.

Returns

None

Raises

ValueError If neither the function argument nor self.CONF_DICT_PATH provides a valid path.

Notes

The resulting file is written via write_configuration_dict.

def compute_levels(self, df: pandas.core.frame.DataFrame, write: bool = True):
337    def compute_levels(self, df: pd.DataFrame, write: bool = True):
338        """
339        Infer and update ordinal/categorical levels from data.
340
341        For each variable in the configuration's `data_type` section, this
342        method uses the provided DataFrame to construct a levels dictionary
343        and injects the corresponding "levels" entry into `conf_dict["nodes"]`.
344
345        Parameters
346        ----------
347        df : pandas.DataFrame
348            DataFrame used to infer levels for configured variables.
349        write : bool, optional
350            If True and `CONF_DICT_PATH` is set, the updated configuration is
351            written back to disk. Default is True.
352
353        Returns
354        -------
355        None
356
357        Raises
358        ------
359        Exception
360            If saving the configuration fails when `write=True`.
361
362        Notes
363        -----
364        - Variables present in `levels_dict` but not in `conf_dict["nodes"]`
365        trigger a warning and are skipped.
366        - If `self.verbose` or `self.debug` is True, a success message is printed
367        when the configuration is saved.
368        """
369        self.update()
370        levels_dict = create_levels_dict(df, self.conf_dict['data_type'])
371        
372        # update nodes dict with levels
373        for var, levels in levels_dict.items():
374            if var in self.conf_dict['nodes']:
375                self.conf_dict['nodes'][var]['levels'] = levels
376            else:
377                print(f"[WARNING] Variable '{var}' not found in nodes dict. Cannot add levels.")
378        
379        if write and self.CONF_DICT_PATH is not None:
380            try:
381                self.save(self.CONF_DICT_PATH)
382                if self.verbose or self.debug:
383                    print(f'[INFO] Configuration with updated levels saved to {self.CONF_DICT_PATH}')
384            except Exception as e:
385                print(f'[ERROR] Failed to save configuration: {e}')

Infer and update ordinal/categorical levels from data.

For each variable in the configuration's data_type section, this method uses the provided DataFrame to construct a levels dictionary and injects the corresponding "levels" entry into conf_dict["nodes"].

Parameters

df : pandas.DataFrame DataFrame used to infer levels for configured variables. write : bool, optional If True and CONF_DICT_PATH is set, the updated configuration is written back to disk. Default is True.

Returns

None

Raises

Exception If saving the configuration fails when write=True.

Notes

  • Variables present in levels_dict but not in conf_dict["nodes"] trigger a warning and are skipped.
  • If self.verbose or self.debug is True, a success message is printed when the configuration is saved.
def plot_dag(self, seed: int = 42, causal_order: bool = False):
387    def plot_dag(self, seed: int = 42, causal_order: bool = False):
388        """
389        Visualize the DAG defined by the configuration.
390
391        Nodes are categorized and colored as:
392        - Source nodes (no incoming edges): green.
393        - Sink nodes (no outgoing edges): red.
394        - Intermediate nodes: light blue.
395
396        Parameters
397        ----------
398        seed : int, optional
399            Random seed for layout stability in the spring layout fallback.
400            Default is 42.
401        causal_order : bool, optional
402            If True, attempt to use Graphviz 'dot' layout via
403            `networkx.nx_agraph.graphviz_layout` to preserve causal ordering.
404            If False or if Graphviz is unavailable, use `spring_layout`.
405            Default is False.
406
407        Returns
408        -------
409        None
410
411        Raises
412        ------
413        ValueError
414            If `adj_matrix` or `data_type` is missing or inconsistent with each other,
415            or if the adjacency matrix fails validation.
416
417        Notes
418        -----
419        Edge labels are colored by prefix:
420        - "ci": blue
421        - "ls": red
422        - "cs": green
423        - other: black
424        """
425        adj_matrix = self.conf_dict.get("adj_matrix")
426        data_type  = self.conf_dict.get("data_type")
427
428        if adj_matrix is None or data_type is None:
429            raise ValueError("Configuration must include 'adj_matrix' and 'data_type'.")
430
431        if isinstance(adj_matrix, list):
432            adj_matrix = np.array(adj_matrix)
433
434        if not validate_adj_matrix(adj_matrix):
435            raise ValueError("Invalid adjacency matrix.")
436        if len(data_type) != adj_matrix.shape[0]:
437            raise ValueError("data_type must match adjacency matrix size.")
438
439        node_labels = list(data_type.keys())
440        G, edge_labels = create_nx_graph(adj_matrix, node_labels)
441
442        sources       = {n for n in G.nodes if G.in_degree(n) == 0}
443        sinks         = {n for n in G.nodes if G.out_degree(n) == 0}
444        intermediates = set(G.nodes) - sources - sinks
445
446        node_colors = [
447            "green" if n in sources
448            else "red" if n in sinks
449            else "lightblue"
450            for n in G.nodes
451        ]
452
453        if causal_order:
454            try:
455                pos = nx.nx_agraph.graphviz_layout(G, prog="dot")
456            except (ImportError, nx.NetworkXException):
457                pos = nx.spring_layout(G, seed=seed, k=1.5, iterations=100)
458        else:
459            pos = nx.spring_layout(G, seed=seed, k=1.5, iterations=100)
460
461        plt.figure(figsize=(8, 6))
462        nx.draw(
463            G, pos,
464            with_labels=True,
465            node_color=node_colors,
466            edge_color="gray",
467            node_size=2500,
468            arrowsize=20
469        )
470
471        for (u, v), lbl in edge_labels.items():
472            color = (
473                "blue"  if lbl.startswith("ci")
474                else "red"   if lbl.startswith("ls")
475                else "green" if lbl.startswith("cs")
476                else "black"
477            )
478            nx.draw_networkx_edge_labels(
479                G, pos,
480                edge_labels={(u, v): lbl},
481                font_color=color,
482                font_size=12
483            )
484
485        legend_items = [
486            Patch(facecolor="green",     edgecolor="black", label="Source"),
487            Patch(facecolor="red",       edgecolor="black", label="Sink"),
488            Patch(facecolor="lightblue", edgecolor="black", label="Intermediate")
489        ]
490        plt.legend(handles=legend_items, loc="upper right", frameon=True)
491
492        plt.title(f"TRAM DAG")
493        plt.axis("off")
494        plt.tight_layout()
495        plt.show()

Visualize the DAG defined by the configuration.

Nodes are categorized and colored as:

  • Source nodes (no incoming edges): green.
  • Sink nodes (no outgoing edges): red.
  • Intermediate nodes: light blue.

Parameters

seed : int, optional Random seed for layout stability in the spring layout fallback. Default is 42. causal_order : bool, optional If True, attempt to use Graphviz 'dot' layout via networkx.nx_agraph.graphviz_layout to preserve causal ordering. If False or if Graphviz is unavailable, use spring_layout. Default is False.

Returns

None

Raises

ValueError If adj_matrix or data_type is missing or inconsistent with each other, or if the adjacency matrix fails validation.

Notes

Edge labels are colored by prefix:

  • "ci": blue
  • "ls": red
  • "cs": green
  • other: black
def setup_configuration( self, experiment_name=None, EXPERIMENT_DIR=None, debug=False, _verify=False):
498    def setup_configuration(self, experiment_name=None, EXPERIMENT_DIR=None, debug=False, _verify=False):
499        """
500        Create or reuse a configuration for an experiment.
501
502        This method behaves differently depending on how it is called:
503
504        1. Class call (e.g. `TramDagConfig.setup_configuration(...)`):
505        - Creates or loads a configuration at the resolved path.
506        - Returns a new `TramDagConfig` instance.
507
508        2. Instance call (e.g. `cfg.setup_configuration(...)`):
509        - Updates `self.conf_dict` and `self.CONF_DICT_PATH` in place.
510        - Optionally verifies completeness.
511        - Returns None.
512
513        Parameters
514        ----------
515        experiment_name : str or None, optional
516            Name of the experiment. If None, defaults to "experiment_1".
517        EXPERIMENT_DIR : str or None, optional
518            Directory for the experiment. If None, defaults to
519            `<cwd>/<experiment_name>`.
520        debug : bool, optional
521            If True, initialize / update with `debug=True`. Default is False.
522        _verify : bool, optional
523            If True, call `_verify_completeness()` after loading. Default is False.
524
525        Returns
526        -------
527        TramDagConfig or None
528            - A new instance when called on the class.
529            - None when called on an existing instance.
530
531        Notes
532        -----
533        - A configuration file named "configuration.json" is created if it does
534        not exist yet.
535        - Underlying creation uses `create_and_write_new_configuration_dict`
536        and `load_configuration_dict`.
537        """
538        is_class_call = isinstance(self, type)
539        cls = self if is_class_call else self.__class__
540
541        if experiment_name is None:
542            experiment_name = "experiment_1"
543        if EXPERIMENT_DIR is None:
544            EXPERIMENT_DIR = os.path.join(os.getcwd(), experiment_name)
545
546        CONF_DICT_PATH = os.path.join(EXPERIMENT_DIR, "configuration.json")
547        DATA_PATH = EXPERIMENT_DIR
548
549        os.makedirs(EXPERIMENT_DIR, exist_ok=True)
550
551        if os.path.exists(CONF_DICT_PATH):
552            print(f"Configuration already exists: {CONF_DICT_PATH}")
553        else:
554            _ = create_and_write_new_configuration_dict(
555                experiment_name, CONF_DICT_PATH, EXPERIMENT_DIR, DATA_PATH, None
556            )
557            print(f"Created new configuration file at {CONF_DICT_PATH}")
558
559        conf = load_configuration_dict(CONF_DICT_PATH)
560
561        if is_class_call:
562            return cls(conf, CONF_DICT_PATH=CONF_DICT_PATH, debug=debug, _verify=_verify)
563        else:
564            self.conf_dict = conf
565            self.CONF_DICT_PATH = CONF_DICT_PATH
566            if _verify:
567                self._verify_completeness()

Create or reuse a configuration for an experiment.

This method behaves differently depending on how it is called:

  1. Class call (e.g. TramDagConfig.setup_configuration(...)):
  • Creates or loads a configuration at the resolved path.
  • Returns a new TramDagConfig instance.
  2. Instance call (e.g. cfg.setup_configuration(...)):
  • Updates self.conf_dict and self.CONF_DICT_PATH in place.
  • Optionally verifies completeness.
  • Returns None.

Parameters

experiment_name : str or None, optional Name of the experiment. If None, defaults to "experiment_1". EXPERIMENT_DIR : str or None, optional Directory for the experiment. If None, defaults to <cwd>/<experiment_name>. debug : bool, optional If True, initialize / update with debug=True. Default is False. _verify : bool, optional If True, call _verify_completeness() after loading. Default is False.

Returns

TramDagConfig or None - A new instance when called on the class. - None when called on an existing instance.

Notes

  • A configuration file named "configuration.json" is created if it does not exist yet.
  • Underlying creation uses create_and_write_new_configuration_dict and load_configuration_dict.
def set_data_type(self, data_type: dict, CONF_DICT_PATH: str = None) -> None:
569    def set_data_type(self, data_type: dict, CONF_DICT_PATH: str = None) -> None:
570        """
571        Update or write the `data_type` section of a configuration file.
572
573        Supports both class-level and instance-level usage:
574
575        - Class call:
576        - Requires `CONF_DICT_PATH` argument.
577        - Reads the file if it exists, or starts from an empty dict.
578        - Writes updated configuration to `CONF_DICT_PATH`.
579
580        - Instance call:
581        - Uses `self.CONF_DICT_PATH` if available, otherwise defaults to
582            `<cwd>/configuration.json` if no path is provided.
583        - Updates `self.conf_dict` and `self.CONF_DICT_PATH` after writing.
584
585        Parameters
586        ----------
587        data_type : dict
588            Mapping `{variable_name: type_spec}`, where `type_spec` encodes
589            modeling types (e.g. continuous, ordinal, etc.).
590        CONF_DICT_PATH : str or None, optional
591            Path to the configuration file. Must be provided for class calls.
592            For instance calls, defaults as described above.
593
594        Returns
595        -------
596        None
597
598        Raises
599        ------
600        ValueError
601            If `CONF_DICT_PATH` is missing when called on the class, or if
602            validation of data types fails.
603
604        Notes
605        -----
606        - Variable names are validated via `validate_variable_names`.
607        - Data type values are validated via `validate_data_types`.
608        - A textual summary of modeling settings is printed via
609        `print_data_type_modeling_setting`, if possible.
610        """
611        is_class_call = isinstance(self, type)
612        cls = self if is_class_call else self.__class__
613
614        # resolve path
615        if CONF_DICT_PATH is None:
616            if not is_class_call and getattr(self, "CONF_DICT_PATH", None):
617                CONF_DICT_PATH = self.CONF_DICT_PATH
618            elif not is_class_call:
619                CONF_DICT_PATH = os.path.join(os.getcwd(), "configuration.json")
620            else:
621                raise ValueError("CONF_DICT_PATH must be provided when called on the class.")
622
623        try:
624            # load existing or create empty configuration
625            configuration_dict = (
626                load_configuration_dict(CONF_DICT_PATH)
627                if os.path.exists(CONF_DICT_PATH)
628                else {}
629            )
630
631            validate_variable_names(data_type.keys())
632            if not validate_data_types(data_type):
633                raise ValueError("Invalid data types in the provided dictionary.")
634
635            configuration_dict["data_type"] = data_type
636            write_configuration_dict(configuration_dict, CONF_DICT_PATH)
637
638            # safe printing
639            try:
640                print_data_type_modeling_setting(data_type or {})
641            except Exception as e:
642                print(f"[WARNING] Could not print data type modeling settings: {e}")
643
644            if not is_class_call:
645                self.conf_dict = configuration_dict
646                self.CONF_DICT_PATH = CONF_DICT_PATH
647
648
649        except Exception as e:
650            print(f"Failed to update configuration: {e}")
651        else:
652            print(f"Configuration updated successfully at {CONF_DICT_PATH}.")

Update or write the data_type section of a configuration file.

Supports both class-level and instance-level usage:

  • Class call:
  • Requires CONF_DICT_PATH argument.
  • Reads the file if it exists, or starts from an empty dict.
  • Writes updated configuration to CONF_DICT_PATH.

  • Instance call:

  • Uses self.CONF_DICT_PATH if available, otherwise defaults to <cwd>/configuration.json if no path is provided.
  • Updates self.conf_dict and self.CONF_DICT_PATH after writing.

Parameters

data_type : dict Mapping {variable_name: type_spec}, where type_spec encodes modeling types (e.g. continuous, ordinal, etc.). CONF_DICT_PATH : str or None, optional Path to the configuration file. Must be provided for class calls. For instance calls, defaults as described above.

Returns

None

Raises

ValueError If CONF_DICT_PATH is missing when called on the class, or if validation of data types fails.

Notes

  • Variable names are validated via validate_variable_names.
  • Data type values are validated via validate_data_types.
  • A textual summary of modeling settings is printed via print_data_type_modeling_setting, if possible.
def set_meta_adj_matrix(self, CONF_DICT_PATH: str = None, seed: int = 5):
654    def set_meta_adj_matrix(self, CONF_DICT_PATH: str = None, seed: int = 5):
655        """
656        Launch the interactive editor to set or modify the adjacency matrix.
657
658        This method:
659
660        1. Resolves the configuration path either from the argument or, for
661        instances, from `self.CONF_DICT_PATH`.
662        2. Invokes `interactive_adj_matrix` to edit the adjacency matrix.
663        3. For instances, reloads the updated configuration into `self.conf_dict`.
664
665        Parameters
666        ----------
667        CONF_DICT_PATH : str or None, optional
668            Path to the configuration file. Must be provided when called
669            on the class. For instance calls, defaults to `self.CONF_DICT_PATH`.
670        seed : int, optional
671            Random seed for any layout or stochastic behavior in the interactive
672            editor. Default is 5.
673
674        Returns
675        -------
676        None
677
678        Raises
679        ------
680        ValueError
681            If `CONF_DICT_PATH` is not provided and cannot be inferred
682            (e.g. in a class call without path).
683
684        Notes
685        -----
686        `self.update()` is called at the start to ensure the in-memory config
687        is in sync with the file before launching the editor.
688        """
689        self.update()
690        is_class_call = isinstance(self, type)
691        # resolve path
692        if CONF_DICT_PATH is None:
693            if not is_class_call and getattr(self, "CONF_DICT_PATH", None):
694                CONF_DICT_PATH = self.CONF_DICT_PATH
695            else:
696                raise ValueError("CONF_DICT_PATH must be provided when called on the class.")
697
698        # launch interactive editor
699        
700        interactive_adj_matrix(CONF_DICT_PATH, seed=seed)
701
702        # reload config if instance
703        if not is_class_call:
704            self.conf_dict = load_configuration_dict(CONF_DICT_PATH)
705            self.CONF_DICT_PATH = CONF_DICT_PATH

Launch the interactive editor to set or modify the adjacency matrix.

This method:

  1. Resolves the configuration path either from the argument or, for instances, from self.CONF_DICT_PATH.
  2. Invokes interactive_adj_matrix to edit the adjacency matrix.
  3. For instances, reloads the updated configuration into self.conf_dict.

Parameters

CONF_DICT_PATH : str or None, optional Path to the configuration file. Must be provided when called on the class. For instance calls, defaults to self.CONF_DICT_PATH. seed : int, optional Random seed for any layout or stochastic behavior in the interactive editor. Default is 5.

Returns

None

Raises

ValueError If CONF_DICT_PATH is not provided and cannot be inferred (e.g. in a class call without path).

Notes

self.update() is called at the start to ensure the in-memory config is in sync with the file before launching the editor.

def set_tramdag_nn_models(self, CONF_DICT_PATH: str = None):
708    def set_tramdag_nn_models(self, CONF_DICT_PATH: str = None):
709        """
710        Launch the interactive editor to set TRAM-DAG neural network model names.
711
712        Depending on call context:
713
714        - Class call:
715        - Requires `CONF_DICT_PATH` argument.
716        - Returns nothing and does not modify a specific instance.
717
718        - Instance call:
719        - Resolves `CONF_DICT_PATH` from the argument or `self.CONF_DICT_PATH`.
720        - Updates `self.conf_dict` and `self.CONF_DICT_PATH` if the editor
721            returns an updated configuration.
722
723        Parameters
724        ----------
725        CONF_DICT_PATH : str or None, optional
726            Path to the configuration file. Must be provided when called on
727            the class. For instance calls, defaults to `self.CONF_DICT_PATH`.
728
729        Returns
730        -------
731        None
732
733        Raises
734        ------
735        ValueError
736            If `CONF_DICT_PATH` is not provided and cannot be inferred
737            (e.g. in a class call without path).
738
739        Notes
740        -----
741        The interactive editor is invoked via `interactive_nn_names_matrix`.
742        If it returns `None`, the instance configuration is left unchanged.
743        """
744        is_class_call = isinstance(self, type)
745        if CONF_DICT_PATH is None:
746            if not is_class_call and getattr(self, "CONF_DICT_PATH", None):
747                CONF_DICT_PATH = self.CONF_DICT_PATH
748            else:
749                raise ValueError("CONF_DICT_PATH must be provided when called on the class.")
750
751        updated_conf = interactive_nn_names_matrix(CONF_DICT_PATH)
752        if updated_conf is not None and not is_class_call:
753            self.conf_dict = updated_conf
754            self.CONF_DICT_PATH = CONF_DICT_PATH

Launch the interactive editor to set TRAM-DAG neural network model names.

Depending on call context:

  • Class call:
  • Requires CONF_DICT_PATH argument.
  • Returns nothing and does not modify a specific instance.

  • Instance call:

  • Resolves CONF_DICT_PATH from the argument or self.CONF_DICT_PATH.
  • Updates self.conf_dict and self.CONF_DICT_PATH if the editor returns an updated configuration.

Parameters

CONF_DICT_PATH : str or None, optional Path to the configuration file. Must be provided when called on the class. For instance calls, defaults to self.CONF_DICT_PATH.

Returns

None

Raises

ValueError If CONF_DICT_PATH is not provided and cannot be inferred (e.g. in a class call without path).

Notes

The interactive editor is invoked via interactive_nn_names_matrix. If it returns None, the instance configuration is left unchanged.

class TramDagDataset(typing.Generic[+_T_co]):
 29class TramDagDataset(Dataset):
 30    
 31    """
 32    TramDagDataset
 33    ==============
 34
 35    The `TramDagDataset` class handles structured data preparation for TRAM-DAG
 36    models. It wraps a pandas DataFrame together with its configuration and provides
 37    utilities for scaling, transformation, and efficient DataLoader construction
 38    for each node in a DAG-based configuration.
 39
 40    ---------------------------------------------------------------------
 41    Core Responsibilities
 42    ---------------------------------------------------------------------
 43    - Validate and store configuration metadata (`TramDagConfig`).
 44    - Manage per-node settings for DataLoader creation (batch size, shuffling, workers).
 45    - Compute scaling information (quantile-based min/max).
 46    - Optionally precompute and cache dataset representations.
 47    - Expose PyTorch Dataset and DataLoader interfaces for model training.
 48
 49    ---------------------------------------------------------------------
 50    Key Attributes
 51    ---------------------------------------------------------------------
 52    - **df** : pandas.DataFrame  
 53      The dataset content used for building loaders and computing scaling.
 54
 55    - **cfg** : TramDagConfig  
 56      Configuration object defining nodes and variable metadata.
 57
 58    - **nodes_dict** : dict  
 59      Mapping of variable names to node specifications from the configuration.
 60
 61    - **loaders** : dict  
 62      Mapping of node names to `torch.utils.data.DataLoader` instances or `GenericDataset` objects.
 63
 64    - **DEFAULTS** : dict  
 65      Default DataLoader and dataset-related settings (e.g., batch_size, shuffle, num_workers, etc.).
 66
 67    ---------------------------------------------------------------------
 68    Main Methods
 69    ---------------------------------------------------------------------
 70    - **from_dataframe(df, cfg, **kwargs)**  
 71      Construct the dataset directly from a pandas DataFrame.
 72
 73    - **compute_scaling(df=None, write=True)**  
 74      Compute per-variable min/max scaling values from data.
 75
 76    - **summary()**  
 77      Print dataset overview including shape, dtypes, statistics, and node settings.
 78
 79    ---------------------------------------------------------------------
 80    Notes
 81    ---------------------------------------------------------------------
 82    - Intended for training data; `compute_scaling()` should use only training subsets.
 83    - Compatible with both CPU and GPU DataLoader options.
 84    - Strict validation of keyword arguments against `DEFAULTS` prevents silent misconfiguration.
 85
 86    ---------------------------------------------------------------------
 87    Example
 88    ---------------------------------------------------------------------
 89    >>> cfg = TramDagConfig.from_json("config.json")
 90    >>> dataset = TramDagDataset.from_dataframe(train_df, cfg, batch_size=1024, debug=True)
 91    >>> dataset.summary()
 92    >>> minmax = dataset.compute_scaling(train_df)
 93    >>> loader = dataset.loaders["variable_x"]
 94    >>> next(iter(loader))
 95    """
 96
    # Default settings for dataset construction and DataLoader creation.
    # Every key here is a valid override in from_dataframe(**kwargs).
    DEFAULTS = {
        "batch_size": 32_000,            # samples per batch
        "shuffle": True,                 # reshuffle each epoch; disable for val/test data
        "num_workers": 4,                # DataLoader worker processes
        "pin_memory": True,              # pin host memory for faster H2D copies
        "return_intercept_shift": True,  # datasets also yield intercept/shift info
        "debug": False,                  # verbose debug printing
        "transform": None,               # optional per-sample transform(s)
        "use_dataloader": True,          # wrap datasets in DataLoaders (else keep raw Datasets)
        "use_precomputed": False,        # precompute dataset representation to disk and reload
        # DataLoader extras (forwarded to torch.utils.data.DataLoader)
        "sampler": None,
        "batch_sampler": None,
        "collate_fn": None,
        "drop_last": False,
        "timeout": 0,
        "worker_init_fn": None,
        "multiprocessing_context": None,
        "generator": None,
        "prefetch_factor": 2,
        "persistent_workers": True,
        "pin_memory_device": "",
    }
120
121    def __init__(self):
122        """
123        Initialize an empty TramDagDataset shell.
124
125        Notes
126        -----
127        This constructor does not attach data or configuration. Use
128        `TramDagDataset.from_dataframe` to obtain a ready-to-use instance.
129        """
130        pass
131
132    @classmethod
133    def from_dataframe(cls, df, cfg, **kwargs):
134        """
135        Create a TramDagDataset instance directly from a pandas DataFrame.
136
137        This classmethod:
138
139        1. Validates keyword arguments against `DEFAULTS`.
140        2. Merges user overrides with defaults into a resolved settings dict.
141        3. Stores the configuration and verifies its completeness.
142        4. Applies settings to the instance.
143        5. Builds per-node datasets and DataLoaders.
144
145        Parameters
146        ----------
147        df : pandas.DataFrame
148            Input DataFrame containing the dataset.
149        cfg : TramDagConfig
150            Configuration object defining nodes and variable metadata.
151        **kwargs
152            Optional overrides for `DEFAULTS`. All keys must exist in
153            `TramDagDataset.DEFAULTS`. Common keys include:
154
155            batch_size : int
156                Batch size for DataLoaders.
157            shuffle : bool
158                Whether to shuffle samples per epoch.
159            num_workers : int
160                Number of DataLoader workers.
161            pin_memory : bool
162                Whether to pin memory for faster host-to-device transfers.
163            return_intercept_shift : bool
164                Whether datasets should return intercept/shift information.
165            debug : bool
166                Enable debug printing.
167            transform : callable or dict or None
168                Optional transform(s) applied to samples.
169            use_dataloader : bool
170                If True, construct DataLoaders; else store raw Dataset objects.
171            use_precomputed : bool
172                If True, precompute dataset representation to disk and reload it.
173
174        Returns
175        -------
176        TramDagDataset
177            Initialized dataset instance.
178
179        Raises
180        ------
181        TypeError
182            If `df` is not a pandas DataFrame.
183        ValueError
184            If unknown keyword arguments are provided (when validation is enabled).
185
186        Notes
187        -----
188        If `shuffle=True` and the inferred variable name of `df` suggests
189        validation/test data (e.g. "val", "test"), a warning is printed.
190        """
191        self = cls()
192        if not isinstance(df, pd.DataFrame):
193            raise TypeError(f"[ERROR] df must be a pandas DataFrame, but got {type(df)}")
194
195        # validate kwargs
196        #self._validate_kwargs(kwargs, context="from_dataframe")
197        
198        # merge defaults with overrides
199        settings = dict(cls.DEFAULTS)
200        settings.update(kwargs)
201
202        # store config and verify
203        self.cfg = cfg
204        self.cfg._verify_completeness()
205
206        # ouptu all setttings if debug
207        if settings.get("debug", False):
208            print("[DEBUG] TramDagDataset.from_dataframe() settings (after defaults + overrides):")
209            for k, v in settings.items():
210                print(f"    {k}: {v}")
211
212        # infer variable name automatically
213        callers_locals = inspect.currentframe().f_back.f_locals
214        inferred = None
215        for var_name, var_val in callers_locals.items():
216            if var_val is df:
217                inferred = var_name
218                break
219        df_name = inferred or "dataframe"
220
221        if settings["shuffle"]:
222            if any(x in df_name.lower() for x in ["val", "validation", "test"]):
223                print(f"[WARNING] DataFrame '{df_name}' looks like a validation/test set → shuffle=True. Are you sure?")
224
225        # call again to ensure Warning messages if ordinal vars have missing levels
226        self.df = df.copy()
227        self._apply_settings(settings)
228        self._build_dataloaders()
229        return self
230
def compute_scaling(self, df: pd.DataFrame = None, write: bool = True):
    """
    Compute variable-wise scaling parameters from data.

    For each column, derive an approximate [min, max] range from the
    5th and 95th percentiles — a robust basis for normalization or
    clipping that is insensitive to outliers.

    Parameters
    ----------
    df : pandas.DataFrame or None, optional
        DataFrame used to compute scaling; falls back to `self.df`
        when None.
    write : bool, optional
        Unused placeholder kept for interface compatibility with other
        components. Default is True.

    Returns
    -------
    dict
        Mapping `{column_name: [min_value, max_value]}` where values
        come from the 0.05 and 0.95 quantiles.

    Notes
    -----
    Only training data should be passed here to avoid leakage.
    """
    if self.debug:
        print("[DEBUG] Make sure to provide only training data to compute_scaling!")
    if df is None:
        if self.debug:
            print("[DEBUG] No DataFrame provided, using internal df.")
        df = self.df
    bounds = df.quantile([0.05, 0.95])
    lower = bounds.loc[0.05]
    upper = bounds.loc[0.95]
    # One row per quantile, one column per variable -> {col: [q05, q95]}.
    return pd.concat([lower, upper], axis=1).T.to_dict('list')
269
def summary(self):
    """
    Print a structured overview of the dataset and configuration.

    Covers four areas: the raw DataFrame (shape, columns, head, dtypes,
    describe), the configuration (node count, loader mode, precomputation
    status), the resolved per-node settings, and — when loaders have been
    built — the type and length of each loader.

    Returns
    -------
    None
        Output goes to stdout via `print`; no structured metadata is
        returned. Intended for quick inspection and debugging.
    """
    def resolve(setting, node):
        # Settings may be scalars (shared by all nodes) or dicts keyed
        # by node name; pick the per-node value when applicable.
        return setting[node] if isinstance(setting, dict) else setting

    print("\n[TramDagDataset Summary]")
    print("=" * 60)

    print("\n[DataFrame]")
    print("Shape:", self.df.shape)
    print("Columns:", list(self.df.columns))
    print("\nHead:")
    print(self.df.head())

    print("\nDtypes:")
    print(self.df.dtypes)

    print("\nDescribe:")
    print(self.df.describe(include="all"))

    print("\n[Configuration]")
    print(f"Nodes: {len(self.nodes_dict)}")
    print(f"Loader mode: {'DataLoader' if self.use_dataloader else 'Direct dataset'}")
    print(f"Precomputed: {getattr(self, 'use_precomputed', False)}")

    print("\n[Node Settings]")
    for node in self.nodes_dict.keys():
        print(
            f" Node '{node}': "
            f"batch_size={resolve(self.batch_size, node)}, "
            f"shuffle={resolve(self.shuffle, node)}, "
            f"num_workers={resolve(self.num_workers, node)}, "
            f"pin_memory={resolve(self.pin_memory, node)}, "
            f"return_intercept_shift={resolve(self.return_intercept_shift, node)}, "
            f"debug={resolve(self.debug, node)}, "
            f"transform={resolve(self.transform, node)}"
        )

    if hasattr(self, "loaders"):
        print("\n[DataLoaders]")
        for node, loader in self.loaders.items():
            try:
                length = len(loader)
            except Exception:
                # Some loader types (e.g. iterable-style) have no length.
                length = "?"
            print(f"  {node}: {type(loader).__name__}, len={length}")

    print("=" * 60 + "\n")
365
366    def _validate_kwargs(self, kwargs: dict, defaults_attr: str = "DEFAULTS", context: str = None):
367        """
368        Validate keyword arguments against a defaults dictionary.
369
370        Parameters
371        ----------
372        kwargs : dict
373            Keyword arguments to validate.
374        defaults_attr : str, optional
375            Name of the attribute on this class containing the default keys
376            (e.g. "DEFAULTS"). Default is "DEFAULTS".
377        context : str or None, optional
378            Optional context label (e.g. caller name) included in error messages.
379
380        Raises
381        ------
382        AttributeError
383            If the attribute named by `defaults_attr` does not exist.
384        ValueError
385            If any keys in `kwargs` are not present in the defaults dictionary.
386        """
387        defaults = getattr(self, defaults_attr, None)
388        if defaults is None:
389            raise AttributeError(f"{self.__class__.__name__} has no attribute '{defaults_attr}'")
390
391        unknown = set(kwargs) - set(defaults)
392        if unknown:
393            prefix = f"[{context}] " if context else ""
394            raise ValueError(f"{prefix}Unknown parameter(s): {', '.join(sorted(unknown))}")
395
396    def _apply_settings(self, settings: dict):
397        """
398        Apply resolved settings to the dataset instance.
399
400        This method:
401
402        1. Stores all key–value pairs from `settings` as attributes on `self`.
403        2. Extracts `nodes_dict` from the configuration.
404        3. Validates that dict-valued core attributes (batch_size, shuffle, etc.)
405        have keys matching the node set.
406
407        Parameters
408        ----------
409        settings : dict
410            Resolved settings dictionary, usually built from `DEFAULTS` plus
411            user overrides.
412
413        Returns
414        -------
415        None
416
417        Raises
418        ------
419        ValueError
420            If any dict-valued core attribute has keys that do not match
421            `cfg.conf_dict["nodes"].keys()`.
422        """
423        for k, v in settings.items():
424            setattr(self, k, v)
425
426        self.nodes_dict = self.cfg.conf_dict["nodes"]
427
428        # validate only the most important ones
429        for name in ["batch_size", "shuffle", "num_workers", "pin_memory",
430                     "return_intercept_shift", "debug", "transform"]:
431            self._check_keys(name, getattr(self, name))
432
def _build_dataloaders(self):
    """
    Build node-specific dataloaders or raw datasets depending on settings.

    For every node in `self.nodes_dict` a `GenericDataset` is constructed
    over `self.df`. Dict-valued settings (transform,
    return_intercept_shift, debug, batch_size, shuffle, num_workers,
    pin_memory) are resolved per node; scalar settings apply to all nodes.

    When `use_precomputed` is truthy, each dataset is serialized to a
    temporary file and reloaded as a `GenericDatasetPrecomputed`; the
    temporary file is removed after all loaders are built.

    When `use_dataloader` is truthy, each dataset is wrapped in a
    `DataLoader`; otherwise the raw dataset object is stored. Results are
    placed in `self.loaders`, keyed by node name.
    """
    self.loaders = {}
    precompute = bool(getattr(self, "use_precomputed", False))
    # Path of the temporary precomputed file. Initializing it here fixes a
    # NameError in the original code: the cleanup below referenced `pth`,
    # which was only bound inside the loop (undefined for an empty node set
    # or when no node ever produced a file).
    pth = None

    for node in self.nodes_dict:
        ds = GenericDataset(
            self.df,
            target_col=node,
            target_nodes=self.nodes_dict,
            transform=self.transform if not isinstance(self.transform, dict) else self.transform[node],
            return_intercept_shift=self.return_intercept_shift if not isinstance(self.return_intercept_shift, dict) else self.return_intercept_shift[node],
            debug=self.debug if not isinstance(self.debug, dict) else self.debug[node],
        )

        ########## QUICK PATCH
        # NOTE(review): every node reuses the same temp path, so each
        # iteration overwrites the previous node's file — confirm this is
        # intended before relying on the precomputed file across nodes.
        if precompute:
            os.makedirs("temp", exist_ok=True)
            pth = os.path.join("temp", "precomputed.pt")

            if callable(getattr(ds, "save_precomputed", None)):
                ds.save_precomputed(pth)
                ds = GenericDatasetPrecomputed(pth)
            else:
                print("[WARNING] Dataset has no 'save_precomputed()' method — skipping precomputation.")

        if self.use_dataloader:
            # resolve per-node overrides
            kwargs = {
                "batch_size": self.batch_size[node] if isinstance(self.batch_size, dict) else self.batch_size,
                "shuffle": self.shuffle[node] if isinstance(self.shuffle, dict) else self.shuffle,
                "num_workers": self.num_workers[node] if isinstance(self.num_workers, dict) else self.num_workers,
                "pin_memory": self.pin_memory[node] if isinstance(self.pin_memory, dict) else self.pin_memory,
                "sampler": self.sampler,
                "batch_sampler": self.batch_sampler,
                "collate_fn": self.collate_fn,
                "drop_last": self.drop_last,
                "timeout": self.timeout,
                "worker_init_fn": self.worker_init_fn,
                "multiprocessing_context": self.multiprocessing_context,
                "generator": self.generator,
                "prefetch_factor": self.prefetch_factor,
                "persistent_workers": self.persistent_workers,
                "pin_memory_device": self.pin_memory_device,
            }
            self.loaders[node] = DataLoader(ds, **kwargs)
        else:
            self.loaders[node] = ds

    # Best-effort cleanup of the temporary precomputed file, if one was
    # written; failures are reported but not raised.
    if precompute and pth is not None and os.path.exists(pth):
        try:
            os.remove(pth)
            if self.debug:
                print(f"[INFO] Removed existing precomputed file: {pth}")
        except Exception as e:
            print(f"[WARNING] Could not remove {pth}: {e}")
489
490    def _check_keys(self, attr_name, attr_value):
491        """
492        Check that dict-valued attributes use node names as keys.
493
494        Parameters
495        ----------
496        attr_name : str
497            Name of the attribute being checked (for error messages).
498        attr_value : Any
499            Attribute value. If it is a dict, its keys are validated.
500
501        Returns
502        -------
503        None
504
505        Raises
506        ------
507        ValueError
508            If `attr_value` is a dict and its keys do not exactly match
509            `cfg.conf_dict["nodes"].keys()`.
510
511        Notes
512        -----
513        This check prevents partial or mismatched per-node settings such as
514        batch sizes or shuffle flags.
515        """
516        if isinstance(attr_value, dict):
517            expected_keys = set(self.nodes_dict.keys())
518            given_keys = set(attr_value.keys())
519            if expected_keys != given_keys:
520                raise ValueError(
521                    f"[ERROR] the provided attribute '{attr_name}' keys are not same as in cfg.conf_dict['nodes'].keys().\n"
522                    f"Expected: {expected_keys}, but got: {given_keys}\n"
523                    f"Please provide values for all variables."
524                )
525
526    def __getitem__(self, idx):
527        return self.df.iloc[idx].to_dict()
528
529    def __len__(self):
530        return len(self.df)

TramDagDataset

The TramDagDataset class handles structured data preparation for TRAM-DAG models. It wraps a pandas DataFrame together with its configuration and provides utilities for scaling, transformation, and efficient DataLoader construction for each node in a DAG-based configuration.


Core Responsibilities

  • Validate and store configuration metadata (TramDagConfig).
  • Manage per-node settings for DataLoader creation (batch size, shuffling, workers).
  • Compute scaling information (quantile-based min/max).
  • Optionally precompute and cache dataset representations.
  • Expose PyTorch Dataset and DataLoader interfaces for model training.

Key Attributes

  • df : pandas.DataFrame
    The dataset content used for building loaders and computing scaling.

  • cfg : TramDagConfig
    Configuration object defining nodes and variable metadata.

  • nodes_dict : dict
    Mapping of variable names to node specifications from the configuration.

  • loaders : dict
    Mapping of node names to torch.utils.data.DataLoader instances or GenericDataset objects.

  • DEFAULTS : dict
    Default DataLoader and dataset-related settings (e.g., batch_size, shuffle, num_workers, etc.).


Main Methods

  • from_dataframe(df, cfg, **kwargs)
    Construct the dataset directly from a pandas DataFrame.

  • compute_scaling(df=None, write=True)
    Compute per-variable min/max scaling values from data.

  • summary()
    Print dataset overview including shape, dtypes, statistics, and node settings.


Notes

  • Intended for training data; compute_scaling() should use only training subsets.
  • Compatible with both CPU and GPU DataLoader options.
  • Strict validation of keyword arguments against DEFAULTS prevents silent misconfiguration.

Example

>>> cfg = TramDagConfig.from_json("config.json")
>>> dataset = TramDagDataset.from_dataframe(train_df, cfg, batch_size=1024, debug=True)
>>> dataset.summary()
>>> minmax = dataset.compute_scaling(train_df)
>>> loader = dataset.loaders["variable_x"]
>>> next(iter(loader))
TramDagDataset()
def __init__(self):
    """
    Initialize an empty TramDagDataset shell.

    Notes
    -----
    This constructor intentionally does nothing: no data, configuration,
    or loaders are attached here. Use the `TramDagDataset.from_dataframe`
    classmethod to obtain a ready-to-use instance; it creates the shell
    via `cls()` and then populates and wires it up.
    """
    pass

Initialize an empty TramDagDataset shell.

Notes

This constructor does not attach data or configuration. Use TramDagDataset.from_dataframe to obtain a ready-to-use instance.

DEFAULTS = {'batch_size': 32000, 'shuffle': True, 'num_workers': 4, 'pin_memory': True, 'return_intercept_shift': True, 'debug': False, 'transform': None, 'use_dataloader': True, 'use_precomputed': False, 'sampler': None, 'batch_sampler': None, 'collate_fn': None, 'drop_last': False, 'timeout': 0, 'worker_init_fn': None, 'multiprocessing_context': None, 'generator': None, 'prefetch_factor': 2, 'persistent_workers': True, 'pin_memory_device': ''}
@classmethod
def from_dataframe(cls, df, cfg, **kwargs):
    """
    Create a TramDagDataset instance directly from a pandas DataFrame.

    This classmethod:

    1. Validates keyword arguments against `DEFAULTS`.
    2. Merges user overrides with defaults into a resolved settings dict.
    3. Stores the configuration and verifies its completeness.
    4. Applies settings to the instance.
    5. Builds per-node datasets and DataLoaders.

    Parameters
    ----------
    df : pandas.DataFrame
        Input DataFrame containing the dataset.
    cfg : TramDagConfig
        Configuration object defining nodes and variable metadata.
    **kwargs
        Optional overrides for `DEFAULTS`. All keys must exist in
        `TramDagDataset.DEFAULTS`. Common keys include `batch_size`,
        `shuffle`, `num_workers`, `pin_memory`, `return_intercept_shift`,
        `debug`, `transform`, `use_dataloader` and `use_precomputed`.

    Returns
    -------
    TramDagDataset
        Initialized dataset instance.

    Raises
    ------
    TypeError
        If `df` is not a pandas DataFrame.
    ValueError
        If unknown keyword arguments are provided.

    Notes
    -----
    If `shuffle=True` and the inferred variable name of `df` suggests
    validation/test data (e.g. "val", "test"), a warning is printed.
    """
    self = cls()
    if not isinstance(df, pd.DataFrame):
        raise TypeError(f"[ERROR] df must be a pandas DataFrame, but got {type(df)}")

    # Reject unknown kwargs up front: `settings.update(kwargs)` below would
    # otherwise silently absorb typos (e.g. `batchsize=...`). The docstring
    # and the class docs promise this strict validation, but the call was
    # previously commented out.
    self._validate_kwargs(kwargs, context="from_dataframe")

    # merge defaults with overrides
    settings = dict(cls.DEFAULTS)
    settings.update(kwargs)

    # store config and verify
    self.cfg = cfg
    self.cfg._verify_completeness()

    # output all settings if debug
    if settings.get("debug", False):
        print("[DEBUG] TramDagDataset.from_dataframe() settings (after defaults + overrides):")
        for k, v in settings.items():
            print(f"    {k}: {v}")

    # infer the caller's variable name for `df` (best-effort, warning aid only)
    callers_locals = inspect.currentframe().f_back.f_locals
    inferred = None
    for var_name, var_val in callers_locals.items():
        if var_val is df:
            inferred = var_name
            break
    df_name = inferred or "dataframe"

    if settings["shuffle"]:
        if any(x in df_name.lower() for x in ["val", "validation", "test"]):
            print(f"[WARNING] DataFrame '{df_name}' looks like a validation/test set → shuffle=True. Are you sure?")

    # copy to decouple from the caller's object, then wire everything up
    self.df = df.copy()
    self._apply_settings(settings)
    self._build_dataloaders()
    return self

Create a TramDagDataset instance directly from a pandas DataFrame.

This classmethod:

  1. Validates keyword arguments against DEFAULTS.
  2. Merges user overrides with defaults into a resolved settings dict.
  3. Stores the configuration and verifies its completeness.
  4. Applies settings to the instance.
  5. Builds per-node datasets and DataLoaders.

Parameters

df : pandas.DataFrame
    Input DataFrame containing the dataset.
cfg : TramDagConfig
    Configuration object defining nodes and variable metadata.
**kwargs
    Optional overrides for DEFAULTS. All keys must exist in
    TramDagDataset.DEFAULTS. Common keys include:

batch_size : int
    Batch size for DataLoaders.
shuffle : bool
    Whether to shuffle samples per epoch.
num_workers : int
    Number of DataLoader workers.
pin_memory : bool
    Whether to pin memory for faster host-to-device transfers.
return_intercept_shift : bool
    Whether datasets should return intercept/shift information.
debug : bool
    Enable debug printing.
transform : callable or dict or None
    Optional transform(s) applied to samples.
use_dataloader : bool
    If True, construct DataLoaders; else store raw Dataset objects.
use_precomputed : bool
    If True, precompute dataset representation to disk and reload it.

Returns

TramDagDataset Initialized dataset instance.

Raises

TypeError
    If df is not a pandas DataFrame.
ValueError
    If unknown keyword arguments are provided (when validation is enabled).

Notes

If shuffle=True and the inferred variable name of df suggests validation/test data (e.g. "val", "test"), a warning is printed.

def compute_scaling(self, df: pandas.core.frame.DataFrame = None, write: bool = True):
231    def compute_scaling(self, df: pd.DataFrame = None, write: bool = True):
232        """
233        Compute variable-wise scaling parameters from data.
234
235        Per variable, this method computes approximate minimum and maximum
236        values using the 5th and 95th percentiles. This is typically used
237        to derive robust normalization/clipping ranges from training data.
238
239        Parameters
240        ----------
241        df : pandas.DataFrame or None, optional
242            DataFrame used to compute scaling. If None, `self.df` is used.
243        write : bool, optional
244            Unused placeholder for interface compatibility with other components.
245            Kept for potential future extensions. Default is True.
246
247        Returns
248        -------
249        dict
250            Mapping `{column_name: [min_value, max_value]}`, where values
251            are derived from the 0.05 and 0.95 quantiles.
252
253        Notes
254        -----
255        If `self.debug` is True, the method emits debug messages about the
256        data source. Only training data should be used to avoid leakage.
257        """
258        if self.debug:
259            print("[DEBUG] Make sure to provide only training data to compute_scaling!")     
260        if df is None:
261            df = self.df
262            if self.debug:
263                print("[DEBUG] No DataFrame provided, using internal df.")
264        quantiles = df.quantile([0.05, 0.95])
265        min_vals = quantiles.loc[0.05]
266        max_vals = quantiles.loc[0.95]
267        minmax_dict = pd.concat([min_vals, max_vals], axis=1).T.to_dict('list')
268        return minmax_dict

Compute variable-wise scaling parameters from data.

Per variable, this method computes approximate minimum and maximum values using the 5th and 95th percentiles. This is typically used to derive robust normalization/clipping ranges from training data.

Parameters

df : pandas.DataFrame or None, optional
    DataFrame used to compute scaling. If None, self.df is used.
write : bool, optional
    Unused placeholder for interface compatibility with other components.
    Kept for potential future extensions. Default is True.

Returns

dict Mapping {column_name: [min_value, max_value]}, where values are derived from the 0.05 and 0.95 quantiles.

Notes

If self.debug is True, the method emits debug messages about the data source. Only training data should be used to avoid leakage.

def summary(self):
270    def summary(self):
271        """
272        Print a structured overview of the dataset and configuration.
273
274        The summary includes:
275
276        1. DataFrame information:
277        - Shape
278        - Columns
279        - Head (first rows)
280        - Dtypes
281        - Descriptive statistics
282
283        2. Configuration overview:
284        - Number of nodes
285        - Loader mode (DataLoader vs. raw Dataset)
286        - Precomputation status
287
288        3. Node settings:
289        - Batch size
290        - Shuffle flag
291        - num_workers
292        - pin_memory
293        - return_intercept_shift
294        - debug
295        - transform
296
297        4. DataLoader overview:
298        - Type and length of each loader.
299
300        Parameters
301        ----------
302        None
303
304        Returns
305        -------
306        None
307
308        Notes
309        -----
310        Intended for quick inspection and debugging. Uses `print` statements
311        and does not return structured metadata.
312        """
313        
314        print("\n[TramDagDataset Summary]")
315        print("=" * 60)
316
317        print("\n[DataFrame]")
318        print("Shape:", self.df.shape)
319        print("Columns:", list(self.df.columns))
320        print("\nHead:")
321        print(self.df.head())
322
323        print("\nDtypes:")
324        print(self.df.dtypes)
325
326        print("\nDescribe:")
327        print(self.df.describe(include="all"))
328
329        print("\n[Configuration]")
330        print(f"Nodes: {len(self.nodes_dict)}")
331        print(f"Loader mode: {'DataLoader' if self.use_dataloader else 'Direct dataset'}")
332        print(f"Precomputed: {getattr(self, 'use_precomputed', False)}")
333
334        print("\n[Node Settings]")
335        for node in self.nodes_dict.keys():
336            batch_size = self.batch_size[node] if isinstance(self.batch_size, dict) else self.batch_size
337            shuffle_flag = self.shuffle[node] if isinstance(self.shuffle, dict) else self.shuffle
338            num_workers = self.num_workers[node] if isinstance(self.num_workers, dict) else self.num_workers
339            pin_memory = self.pin_memory[node] if isinstance(self.pin_memory, dict) else self.pin_memory
340            rshift = self.return_intercept_shift[node] if isinstance(self.return_intercept_shift, dict) else self.return_intercept_shift
341            debug_flag = self.debug[node] if isinstance(self.debug, dict) else self.debug
342            transform = self.transform[node] if isinstance(self.transform, dict) else self.transform
343
344            print(
345                f" Node '{node}': "
346                f"batch_size={batch_size}, "
347                f"shuffle={shuffle_flag}, "
348                f"num_workers={num_workers}, "
349                f"pin_memory={pin_memory}, "
350                f"return_intercept_shift={rshift}, "
351                f"debug={debug_flag}, "
352                f"transform={transform}"
353            )
354
355        if hasattr(self, "loaders"):
356            print("\n[DataLoaders]")
357            for node, loader in self.loaders.items():
358                try:
359                    length = len(loader)
360                except Exception:
361                    length = "?"
362                print(f"  {node}: {type(loader).__name__}, len={length}")
363
364        print("=" * 60 + "\n")

Print a structured overview of the dataset and configuration.

The summary includes:

  1. DataFrame information:
  • Shape
  • Columns
  • Head (first rows)
  • Dtypes
  • Descriptive statistics
  2. Configuration overview:
  • Number of nodes
  • Loader mode (DataLoader vs. raw Dataset)
  • Precomputation status
  3. Node settings:
  • Batch size
  • Shuffle flag
  • num_workers
  • pin_memory
  • return_intercept_shift
  • debug
  • transform
  4. DataLoader overview:
  • Type and length of each loader.

Parameters

None

Returns

None

Notes

Intended for quick inspection and debugging. Uses print statements and does not return structured metadata.

class TramDagModel:
  64class TramDagModel:
  65    """
  66    Probabilistic DAG model built from node-wise TRAMs (transformation models).
  67
  68    This class manages:
  69    - Configuration and per-node model construction.
  70    - Data scaling (min–max).
  71    - Training (sequential or per-node parallel on CPU).
  72    - Diagnostics (loss history, intercepts, linear shifts, latents).
  73    - Sampling from the joint DAG and loading stored samples.
  74    - High-level summaries and plotting utilities.
  75    """
  76    
  77    # ---- defaults used at construction time ----
  78    DEFAULTS_CONFIG = {
  79        "set_initial_weights": False,
  80        "debug":False,
  81        "verbose": False,
  82        "device":'auto',
  83        "initial_data":None,
  84        "overwrite_initial_weights": True,
  85    }
  86
  87    # ---- defaults used at fit() time ----
  88    DEFAULTS_FIT = {
  89        "epochs": 100,
  90        "train_list": None,
  91        "callbacks": None,
  92        "learning_rate": 0.01,
  93        "device": "auto",
  94        "optimizers": None,
  95        "schedulers": None,
  96        "use_scheduler": False,
  97        "save_linear_shifts": True,
  98        "save_simple_intercepts": True,
  99        "debug":False,
 100        "verbose": True,
 101        "train_mode": "sequential",  # or "parallel"
 102        "return_history": False,
 103        "overwrite_inital_weights": True,
 104        "num_workers" : 4,
 105        "persistent_workers" : True,
 106        "prefetch_factor" : 4,
 107        "batch_size":1000,
 108        
 109    }
 110
 111    def __init__(self):
 112        """
 113        Initialize an empty TramDagModel shell.
 114
 115        Notes
 116        -----
 117        This constructor does not build any node models and does not attach a
 118        configuration. Use `TramDagModel.from_config` or `TramDagModel.from_directory`
 119        to obtain a fully configured and ready-to-use instance.
 120        """
 121        
 122        self.debug = False
 123        self.verbose = False
 124        self.device = 'auto'
 125        pass
 126
 127    @staticmethod
 128    def get_device(settings):
 129        """
 130        Resolve the target device string from a settings dictionary.
 131
 132        Parameters
 133        ----------
 134        settings : dict
 135            Dictionary containing at least a key ``"device"`` with one of
 136            {"auto", "cpu", "cuda"}. If missing, "auto" is assumed.
 137
 138        Returns
 139        -------
 140        str
 141            Device string, either "cpu" or "cuda".
 142
 143        Notes
 144        -----
 145        If ``device == "auto"``, CUDA is selected if available, otherwise CPU.
 146        """
 147        device_arg = settings.get("device", "auto")
 148        if device_arg == "auto":
 149            device_str = "cuda" if torch.cuda.is_available() else "cpu"
 150        else:
 151            device_str = device_arg
 152        return device_str
 153
 154    def _validate_kwargs(self, kwargs: dict, defaults_attr: str = "DEFAULTS_FIT", context: str = None):
 155        """
 156        Validate a kwargs dictionary against a class-level defaults dictionary.
 157
 158        Parameters
 159        ----------
 160        kwargs : dict
 161            Keyword arguments to validate.
 162        defaults_attr : str, optional
 163            Name of the attribute on this class that contains the allowed keys,
 164            e.g. ``"DEFAULTS_CONFIG"`` or ``"DEFAULTS_FIT"``. Default is "DEFAULTS_FIT".
 165        context : str or None, optional
 166            Optional label (e.g. caller name) to prepend in error messages.
 167
 168        Raises
 169        ------
 170        AttributeError
 171            If the attribute named by ``defaults_attr`` does not exist.
 172        ValueError
 173            If any key in ``kwargs`` is not present in the corresponding defaults dict.
 174        """
 175        defaults = getattr(self, defaults_attr, None)
 176        if defaults is None:
 177            raise AttributeError(f"{self.__class__.__name__} has no attribute '{defaults_attr}'")
 178
 179        unknown = set(kwargs) - set(defaults)
 180        if unknown:
 181            prefix = f"[{context}] " if context else ""
 182            raise ValueError(f"{prefix}Unknown parameter(s): {', '.join(sorted(unknown))}")
 183            
 184    ## CREATE A TRAMDADMODEL
 185    @classmethod
 186    def from_config(cls, cfg, **kwargs):
 187        """
 188        Construct a TramDagModel from a TramDagConfig object.
 189
 190        This builds one TRAM model per node in the DAG and optionally writes
 191        the initial model parameters to disk.
 192
 193        Parameters
 194        ----------
 195        cfg : TramDagConfig
 196            Configuration wrapper holding the underlying configuration dictionary,
 197            including at least:
 198            - ``conf_dict["nodes"]``: mapping of node names to node configs.
 199            - ``conf_dict["PATHS"]["EXPERIMENT_DIR"]``: experiment directory.
 200        **kwargs
 201            Node-level construction options. Each key must be present in
 202            ``DEFAULTS_CONFIG``. Values can be:
 203            - scalar: applied to all nodes.
 204            - dict: mapping ``{node_name: value}`` for per-node overrides.
 205
 206            Common keys include:
 207            device : {"auto", "cpu", "cuda"}, default "auto"
 208                Device selection (CUDA if available when "auto").
 209            debug : bool, default False
 210                If True, print debug messages.
 211            verbose : bool, default False
 212                If True, print informational messages.
 213            set_initial_weights : bool
 214                Passed to underlying TRAM model constructors.
 215            overwrite_initial_weights : bool, default True
 216                If True, overwrite any existing ``initial_model.pt`` files per node.
 217            initial_data : Any
 218                Optional object passed down to node constructors.
 219
 220        Returns
 221        -------
 222        TramDagModel
 223            Fully initialized instance with:
 224            - ``cfg``
 225            - ``nodes_dict``
 226            - ``models`` (per-node TRAMs)
 227            - ``settings`` (resolved per-node config)
 228
 229        Raises
 230        ------
 231        ValueError
 232            If any dict-valued kwarg does not provide values for exactly the set
 233            of nodes in ``cfg.conf_dict["nodes"]``.
 234        """
 235        
 236        self = cls()
 237        self.cfg = cfg
 238        self.cfg.update()  # ensure latest version from disk
 239        self.cfg._verify_completeness()
 240        
 241        
 242        try:
 243            self.cfg.save()  # persist back to disk
 244            if getattr(self, "debug", False):
 245                print("[DEBUG] Configuration updated and saved.")
 246        except Exception as e:
 247            print(f"[WARNING] Could not save configuration after update: {e}")        
 248            
 249        self.nodes_dict = self.cfg.conf_dict["nodes"] 
 250
 251        self._validate_kwargs(kwargs, defaults_attr='DEFAULTS_CONFIG', context="from_config")
 252
 253        # update defaults with kwargs
 254        settings = dict(cls.DEFAULTS_CONFIG)
 255        settings.update(kwargs)
 256
 257        # resolve device
 258        device_arg = settings.get("device", "auto")
 259        if device_arg == "auto":
 260            device_str = "cuda" if torch.cuda.is_available() else "cpu"
 261        else:
 262            device_str = device_arg
 263        self.device = torch.device(device_str)
 264
 265        # set flags on the instance so they are accessible later
 266        self.debug = settings.get("debug", False)
 267        self.verbose = settings.get("verbose", False)
 268
 269        if  self.debug:
 270            print(f"[DEBUG] TramDagModel using device: {self.device}")
 271            
 272        # initialize settings storage
 273        self.settings = {k: {} for k in settings.keys()}
 274
 275        # validate dict-typed args
 276        for k, v in settings.items():
 277            if isinstance(v, dict):
 278                expected = set(self.nodes_dict.keys())
 279                given = set(v.keys())
 280                if expected != given:
 281                    raise ValueError(
 282                        f"[ERROR] the provided argument '{k}' keys are not same as in cfg.conf_dict['nodes'].keys().\n"
 283                        f"Expected: {expected}, but got: {given}\n"
 284                        f"Please provide values for all variables.")
 285
 286        # build one model per node
 287        self.models = {}
 288        for node in self.nodes_dict.keys():
 289            per_node_kwargs = {}
 290            for k, v in settings.items():
 291                resolved = v[node] if isinstance(v, dict) else v
 292                per_node_kwargs[k] = resolved
 293                self.settings[k][node] = resolved
 294            if self.debug:
 295                print(f"\n[INFO] Building model for node '{node}' with settings: {per_node_kwargs}")
 296            self.models[node] = get_fully_specified_tram_model(
 297                node=node,
 298                configuration_dict=self.cfg.conf_dict,
 299                **per_node_kwargs)
 300            
 301            try:
 302                EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
 303                NODE_DIR = os.path.join(EXPERIMENT_DIR, f"{node}")
 304                os.makedirs(NODE_DIR, exist_ok=True)
 305
 306                model_path = os.path.join(NODE_DIR, "initial_model.pt")
 307                overwrite = settings.get("overwrite_initial_weights", True)
 308
 309                if overwrite or not os.path.exists(model_path):
 310                    torch.save(self.models[node].state_dict(), model_path)
 311                    if self.debug:
 312                        print(f"[DEBUG] Saved initial model state for node '{node}' to {model_path} (overwrite={overwrite})")
 313                else:
 314                    if self.debug:
 315                        print(f"[DEBUG] Skipped saving initial model for node '{node}' (already exists at {model_path})")
 316            except Exception as e:
 317                print(f"[ERROR] Could not save initial model state for node '{node}': {e}")
 318            
 319            TEMP_DIR = "temp"
 320            if os.path.isdir(TEMP_DIR) and not os.listdir(TEMP_DIR):
 321                os.rmdir(TEMP_DIR)
 322                            
 323        return self
 324
 325    @classmethod
 326    def from_directory(cls, EXPERIMENT_DIR: str, device: str = "auto", debug: bool = False, verbose: bool = False):
 327        """
 328        Reconstruct a TramDagModel from an experiment directory on disk.
 329
 330        This method:
 331        1. Loads the configuration JSON.
 332        2. Wraps it in a TramDagConfig.
 333        3. Builds all node models via `from_config`.
 334        4. Loads the min–max scaling dictionary.
 335
 336        Parameters
 337        ----------
 338        EXPERIMENT_DIR : str
 339            Path to an experiment directory containing:
 340            - ``configuration.json``
 341            - ``min_max_scaling.json``.
 342        device : {"auto", "cpu", "cuda"}, optional
 343            Device selection. Default is "auto".
 344        debug : bool, optional
 345            If True, enable debug messages. Default is False.
 346        verbose : bool, optional
 347            If True, enable informational messages. Default is False.
 348
 349        Returns
 350        -------
 351        TramDagModel
 352            A TramDagModel instance with models, config, and scaling loaded.
 353
 354        Raises
 355        ------
 356        FileNotFoundError
 357            If configuration or min–max files cannot be found.
 358        RuntimeError
 359            If the min–max file cannot be read or parsed.
 360        """
 361
 362        # --- load config file ---
 363        config_path = os.path.join(EXPERIMENT_DIR, "configuration.json")
 364        if not os.path.exists(config_path):
 365            raise FileNotFoundError(f"[ERROR] Config file not found at {config_path}")
 366
 367        with open(config_path, "r") as f:
 368            cfg_dict = json.load(f)
 369
 370        # Create TramConfig wrapper 
 371        cfg = TramDagConfig(cfg_dict, CONF_DICT_PATH=config_path)
 372
 373        # --- build model from config ---
 374        self = cls.from_config(cfg, device=device, debug=debug, verbose=verbose, overwrite_initial_weights=False)
 375
 376        # --- load minmax scaling ---
 377        minmax_path = os.path.join(EXPERIMENT_DIR, "min_max_scaling.json")
 378        if not os.path.exists(minmax_path):
 379            raise FileNotFoundError(f"[ERROR] MinMax file not found at {minmax_path}")
 380
 381        with open(minmax_path, "r") as f:
 382            self.minmax_dict = json.load(f)
 383
 384        if self.verbose or self.debug:
 385            print(f"[INFO] Loaded TramDagModel from {EXPERIMENT_DIR}")
 386            print(f"[INFO] Config loaded from {config_path}")
 387            print(f"[INFO] MinMax scaling loaded from {minmax_path}")
 388
 389        return self
 390
 391    def _ensure_dataset(self, data, is_val=False,**kwargs):
 392        """
 393        Ensure that the input data is represented as a TramDagDataset.
 394
 395        Parameters
 396        ----------
 397        data : pandas.DataFrame, TramDagDataset, or None
 398            Input data to be converted or passed through.
 399        is_val : bool, optional
 400            If True, the resulting dataset is treated as validation data
 401            (e.g. no shuffling). Default is False.
 402        **kwargs
 403            Additional keyword arguments passed through to
 404            ``TramDagDataset.from_dataframe``.
 405
 406        Returns
 407        -------
 408        TramDagDataset or None
 409            A TramDagDataset if ``data`` is a DataFrame or TramDagDataset,
 410            otherwise None if ``data`` is None.
 411
 412        Raises
 413        ------
 414        TypeError
 415            If ``data`` is not a DataFrame, TramDagDataset, or None.
 416        """
 417                
 418        if isinstance(data, pd.DataFrame):
 419            return TramDagDataset.from_dataframe(data, self.cfg, shuffle=not is_val,**kwargs)
 420        elif isinstance(data, TramDagDataset):
 421            return data
 422        elif data is None:
 423            return None
 424        else:
 425            raise TypeError(
 426                f"[ERROR] data must be pd.DataFrame, TramDagDataset, or None, got {type(data)}"
 427            )
 428
 429    def load_or_compute_minmax(self, td_train_data=None,use_existing=False, write=True):
 430        """
 431        Load an existing Min–Max scaling dictionary from disk or compute a new one 
 432        from the provided training dataset.
 433
 434        Parameters
 435        ----------
 436        use_existing : bool, optional (default=False)
 437            If True, attempts to load an existing `min_max_scaling.json` file 
 438            from the experiment directory. Raises an error if the file is missing 
 439            or unreadable.
 440
 441        write : bool, optional (default=True)
 442            If True, writes the computed Min–Max scaling dictionary to 
 443            `<EXPERIMENT_DIR>/min_max_scaling.json`.
 444
 445        td_train_data : object, optional
 446            Training dataset used to compute scaling statistics. If not provided,
 447            the method will ensure or construct it via `_ensure_dataset(data=..., is_val=False)`.
 448
 449        Behavior
 450        --------
 451        - If `use_existing=True`, loads the JSON file containing previously saved 
 452          min–max values and stores it in `self.minmax_dict`.
 453        - If `use_existing=False`, computes a new scaling dictionary using 
 454          `td_train_data.compute_scaling()` and stores the result in 
 455          `self.minmax_dict`.
 456        - Optionally writes the computed dictionary to disk.
 457
 458        Side Effects
 459        -------------
 460        - Populates `self.minmax_dict` with scaling values.
 461        - Writes or loads the file `min_max_scaling.json` under 
 462          `<EXPERIMENT_DIR>`.
 463        - Prints diagnostic output if `self.debug` or `self.verbose` is True.
 464
 465        Raises
 466        ------
 467        FileNotFoundError
 468            If `use_existing=True` but the min–max file does not exist.
 469
 470        RuntimeError
 471            If an existing min–max file cannot be read or parsed.
 472
 473        Notes
 474        -----
 475        The computed min–max dictionary is expected to contain scaling statistics 
 476        per feature, typically in the form:
 477            {
 478                "node": {"min": float, "max": float},
 479                ...
 480            }
 481        """
 482        EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
 483        minmax_path = os.path.join(EXPERIMENT_DIR, "min_max_scaling.json")
 484
 485        # laod exisitng if possible
 486        if use_existing:
 487            if not os.path.exists(minmax_path):
 488                raise FileNotFoundError(f"MinMax file not found: {minmax_path}")
 489            try:
 490                with open(minmax_path, 'r') as f:
 491                    self.minmax_dict = json.load(f)
 492                if self.debug or self.verbose:
 493                    print(f"[INFO] Loaded existing minmax dict from {minmax_path}")
 494                return
 495            except Exception as e:
 496                raise RuntimeError(f"Could not load existing minmax dict: {e}")
 497
 498        # 
 499        if self.debug or self.verbose:
 500            print("[INFO] Computing new minmax dict from training data...")
 501            
 502        td_train_data=self._ensure_dataset( data=td_train_data, is_val=False)    
 503            
 504        self.minmax_dict = td_train_data.compute_scaling()
 505
 506        if write:
 507            os.makedirs(EXPERIMENT_DIR, exist_ok=True)
 508            with open(minmax_path, 'w') as f:
 509                json.dump(self.minmax_dict, f, indent=4)
 510            if self.debug or self.verbose:
 511                print(f"[INFO] Saved new minmax dict to {minmax_path}")
 512
    ## FIT METHODS
    @staticmethod
    def _fit_single_node(node, self_ref, settings, td_train_data, td_val_data, device_str):
        """
        Train a single node model (helper for per-node training).

        Designed to be called either from the main process (sequential
        training) or from a joblib worker (parallel CPU training), which is
        why it is a staticmethod taking an explicit ``self_ref``.

        Parameters
        ----------
        node : str
            Name of the target node to train.
        self_ref : TramDagModel
            Reference to the TramDagModel instance containing models and config.
        settings : dict
            Training settings dictionary, typically derived from ``DEFAULTS_FIT``
            plus any user overrides. Dict-valued entries are per-node maps.
        td_train_data : TramDagDataset
            Training dataset with node-specific DataLoaders in ``.loaders``.
        td_val_data : TramDagDataset or None
            Validation dataset or None.
        device_str : str
            Device string, e.g. "cpu" or "cuda".

        Returns
        -------
        tuple
            ``(node, history)`` where ``history`` is whatever
            ``train_val_loop`` returns for this node.
        """
        torch.set_num_threads(1)  # prevent thread oversubscription

        model = self_ref.models[node]

        # Resolve per-node settings: dict-valued options are indexed by node,
        # scalar options apply to every node.
        def _resolve(key):
            val = settings[key]
            return val[node] if isinstance(val, dict) else val

        node_epochs = _resolve("epochs")
        node_lr = _resolve("learning_rate")
        node_debug = _resolve("debug")
        node_save_linear_shifts = _resolve("save_linear_shifts")
        save_simple_intercepts  = _resolve("save_simple_intercepts")
        node_verbose = _resolve("verbose")

        # Optimizer & scheduler: a user-supplied optimizer for this node wins,
        # otherwise a fresh Adam is created with the per-node learning rate.
        if settings["optimizers"] and node in settings["optimizers"]:
            optimizer = settings["optimizers"][node]
        else:
            optimizer = Adam(model.parameters(), lr=node_lr)

        scheduler = settings["schedulers"].get(node, None) if settings["schedulers"] else None

        # Data loaders
        train_loader = td_train_data.loaders[node]
        val_loader = td_val_data.loaders[node] if td_val_data else None

        # Min-max scaling tensors stacked as row 0 = mins, row 1 = maxs.
        # NOTE(review): assumes minmax_dict[node] is a (min, max) pair of
        # array-likes — confirm against compute_scaling()'s output format.
        min_vals = torch.tensor(self_ref.minmax_dict[node][0], dtype=torch.float32)
        max_vals = torch.tensor(self_ref.minmax_dict[node][1], dtype=torch.float32)
        min_max = torch.stack([min_vals, max_vals], dim=0)

        # Node directory: fall back to ./models/<node> if the config lacks paths
        try:
            EXPERIMENT_DIR = self_ref.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
            NODE_DIR = os.path.join(EXPERIMENT_DIR, f"{node}")
        except Exception:
            NODE_DIR = os.path.join("models", node)
            print("[WARNING] No log directory specified in config, saving to default location.")
        os.makedirs(NODE_DIR, exist_ok=True)

        if node_verbose:
            print(f"\n[INFO] Training node '{node}' for {node_epochs} epochs on {device_str} (pid={os.getpid()})")

        # --- train ---
        # NOTE(review): use_scheduler is derived solely from scheduler presence;
        # settings["use_scheduler"] is ignored here — confirm this is intended.
        history = train_val_loop(
            node=node,
            target_nodes=self_ref.nodes_dict,
            NODE_DIR=NODE_DIR,
            tram_model=model,
            train_loader=train_loader,
            val_loader=val_loader,
            epochs=node_epochs,
            optimizer=optimizer,
            use_scheduler=(scheduler is not None),
            scheduler=scheduler,
            save_linear_shifts=node_save_linear_shifts,
            save_simple_intercepts=save_simple_intercepts,
            verbose=node_verbose,
            device=torch.device(device_str),
            debug=node_debug,
            min_max=min_max)
        return node, history
 611
 612    def fit(self, train_data, val_data=None, **kwargs):
 613        """
 614        Train TRAM models for all nodes in the DAG.
 615
 616        Coordinates dataset preparation, min–max scaling, and per-node training,
 617        optionally in parallel on CPU.
 618
 619        Parameters
 620        ----------
 621        train_data : pandas.DataFrame or TramDagDataset
 622            Training data. If a DataFrame is given, it is converted into a
 623            TramDagDataset using `_ensure_dataset`.
 624        val_data : pandas.DataFrame or TramDagDataset or None, optional
 625            Validation data. If a DataFrame is given, it is converted into a
 626            TramDagDataset. If None, no validation loss is computed.
 627        **kwargs
 628            Overrides for ``DEFAULTS_FIT``. All keys must exist in
 629            ``DEFAULTS_FIT``. Common options:
 630
 631            epochs : int, default 100
 632                Number of training epochs per node.
 633            learning_rate : float, default 0.01
 634                Learning rate for the default Adam optimizer.
 635            train_list : list of str or None, optional
 636                List of node names to train. If None, all nodes are trained.
 637            train_mode : {"sequential", "parallel"}, default "sequential"
 638                Training mode. "parallel" uses joblib-based CPU multiprocessing.
 639                GPU forces sequential mode.
 640            device : {"auto", "cpu", "cuda"}, default "auto"
 641                Device selection.
 642            optimizers : dict or None
 643                Optional mapping ``{node_name: optimizer}``. If provided for a
 644                node, that optimizer is used instead of creating a new Adam.
 645            schedulers : dict or None
 646                Optional mapping ``{node_name: scheduler}``.
 647            use_scheduler : bool
 648                If True, enable scheduler usage in the training loop.
 649            num_workers : int
 650                DataLoader workers in sequential mode (ignored in parallel).
 651            persistent_workers : bool
 652                DataLoader persistence in sequential mode (ignored in parallel).
 653            prefetch_factor : int
 654                DataLoader prefetch factor (ignored in parallel).
 655            batch_size : int
 656                Batch size for all node DataLoaders.
 657            debug : bool
 658                Enable debug output.
 659            verbose : bool
 660                Enable informational logging.
 661            return_history : bool
 662                If True, return a history dict.
 663
 664        Returns
 665        -------
 666        dict or None
 667            If ``return_history=True``, a dictionary mapping each node name
 668            to its training history. Otherwise, returns None.
 669
 670        Raises
 671        ------
 672        ValueError
 673            If ``train_mode`` is not "sequential" or "parallel".
 674        """
 675        self._validate_kwargs(kwargs, defaults_attr='DEFAULTS_FIT', context="fit")
 676        
 677        # --- merge defaults ---
 678        settings = dict(self.DEFAULTS_FIT)
 679        settings.update(kwargs)
 680        
 681        
 682        self.debug = settings.get("debug", False)
 683        self.verbose = settings.get("verbose", False)
 684
 685        # --- resolve device ---
 686        device_str=self.get_device(settings)
 687        self.device = torch.device(device_str)
 688
 689        # --- training mode ---
 690        train_mode = settings.get("train_mode", "sequential").lower()
 691        if train_mode not in ("sequential", "parallel"):
 692            raise ValueError("train_mode must be 'sequential' or 'parallel'")
 693
 694        # --- DataLoader safety logic ---
 695        if train_mode == "parallel":
 696            # if user passed loader paralleling params, warn and override
 697            for flag in ("num_workers", "persistent_workers", "prefetch_factor"):
 698                if flag in kwargs:
 699                    print(f"[WARNING] '{flag}' is ignored in parallel mode "
 700                        f"(disabled to prevent nested multiprocessing).")
 701            # disable unsafe loader multiprocessing options
 702            settings["num_workers"] = 0
 703            settings["persistent_workers"] = False
 704            settings["prefetch_factor"] = None
 705        else:
 706            # sequential mode → respect user DataLoader settings
 707            if self.debug:
 708                print("[DEBUG] Sequential mode: using DataLoader kwargs as provided.")
 709
 710        # --- which nodes to train ---
 711        train_list = settings.get("train_list") or list(self.models.keys())
 712
 713
 714        # --- dataset prep (receives adjusted settings) ---
 715        td_train_data = self._ensure_dataset(train_data, is_val=False, **settings)
 716        td_val_data = self._ensure_dataset(val_data, is_val=True, **settings)
 717
 718        # --- normalization ---
 719        self.load_or_compute_minmax(use_existing=False, write=True, td_train_data=td_train_data)
 720
 721        # --- print header ---
 722        if self.verbose or self.debug:
 723            print(f"[INFO] Training {len(train_list)} nodes ({train_mode}) on {device_str}")
 724
 725        # ======================================================================
 726        # Sequential mode  safe for GPU or debugging)
 727        # ======================================================================
 728        if train_mode == "sequential" or "cuda" in device_str:
 729            if "cuda" in device_str and train_mode == "parallel":
 730                print("[WARNING] GPU device detected — forcing sequential mode.")
 731            results = {}
 732            for node in train_list:
 733                node, history = self._fit_single_node(
 734                    node, self, settings, td_train_data, td_val_data, device_str
 735                )
 736                results[node] = history
 737        
 738
 739        # ======================================================================
 740        # parallel mode (CPU only)
 741        # ======================================================================
 742        if train_mode == "parallel":
 743
 744            n_jobs = min(len(train_list), os.cpu_count() // 2 or 1)
 745            if self.verbose or self.debug:
 746                print(f"[INFO] Using {n_jobs} CPU workers for parallel node training")
 747            parallel_outputs = Parallel(
 748                n_jobs=n_jobs,
 749                backend="loky",#loky, multiprocessing
 750                verbose=10,
 751                prefer="processes"
 752            )(delayed(self._fit_single_node)(node, self, settings, td_train_data, td_val_data, device_str) for node in train_list )
 753
 754            results = {node: hist for node, hist in parallel_outputs}
 755        
 756        if settings.get("return_history", False):
 757            return results
 758
 759    ## FIT-DIAGNOSTICS
 760    def loss_history(self):
 761        """
 762        Load training and validation loss history for all nodes.
 763
 764        Looks for per-node JSON files:
 765
 766        - ``EXPERIMENT_DIR/{node}/train_loss_hist.json``
 767        - ``EXPERIMENT_DIR/{node}/val_loss_hist.json``
 768
 769        Returns
 770        -------
 771        dict
 772            A dictionary mapping node names to:
 773
 774            .. code-block:: python
 775
 776                {
 777                    "train": list or None,
 778                    "validation": list or None
 779                }
 780
 781            where each list contains NLL values per epoch, or None if not found.
 782
 783        Raises
 784        ------
 785        ValueError
 786            If the experiment directory cannot be resolved from the configuration.
 787        """
 788        try:
 789            EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
 790        except KeyError:
 791            raise ValueError(
 792                "[ERROR] Missing 'EXPERIMENT_DIR' in cfg.conf_dict['PATHS']. "
 793                "History retrieval requires experiment logs."
 794            )
 795
 796        all_histories = {}
 797        for node in self.nodes_dict.keys():
 798            node_dir = os.path.join(EXPERIMENT_DIR, node)
 799            train_path = os.path.join(node_dir, "train_loss_hist.json")
 800            val_path = os.path.join(node_dir, "val_loss_hist.json")
 801
 802            node_hist = {}
 803
 804            # --- load train history ---
 805            if os.path.exists(train_path):
 806                try:
 807                    with open(train_path, "r") as f:
 808                        node_hist["train"] = json.load(f)
 809                except Exception as e:
 810                    print(f"[WARNING] Could not load {train_path}: {e}")
 811                    node_hist["train"] = None
 812            else:
 813                node_hist["train"] = None
 814
 815            # --- load val history ---
 816            if os.path.exists(val_path):
 817                try:
 818                    with open(val_path, "r") as f:
 819                        node_hist["validation"] = json.load(f)
 820                except Exception as e:
 821                    print(f"[WARNING] Could not load {val_path}: {e}")
 822                    node_hist["validation"] = None
 823            else:
 824                node_hist["validation"] = None
 825
 826            all_histories[node] = node_hist
 827
 828        if self.verbose or self.debug:
 829            print(f"[INFO] Loaded training/validation histories for {len(all_histories)} nodes.")
 830
 831        return all_histories
 832
 833    def linear_shift_history(self):
 834        """
 835        Load linear shift term histories for all nodes.
 836
 837        Each node history is expected in a JSON file named
 838        ``linear_shifts_all_epochs.json`` under the node directory.
 839
 840        Returns
 841        -------
 842        dict
 843            A mapping ``{node_name: pandas.DataFrame}``, where each DataFrame
 844            contains linear shift weights across epochs.
 845
 846        Raises
 847        ------
 848        ValueError
 849            If the experiment directory cannot be resolved from the configuration.
 850
 851        Notes
 852        -----
 853        If a history file is missing for a node, a warning is printed and the
 854        node is omitted from the returned dictionary.
 855        """
 856        histories = {}
 857        try:
 858            EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
 859        except KeyError:
 860            raise ValueError(
 861                "[ERROR] Missing 'EXPERIMENT_DIR' in cfg.conf_dict['PATHS']. "
 862                "Cannot load histories without experiment directory."
 863            )
 864
 865        for node in self.nodes_dict.keys():
 866            node_dir = os.path.join(EXPERIMENT_DIR, node)
 867            history_path = os.path.join(node_dir, "linear_shifts_all_epochs.json")
 868            if os.path.exists(history_path):
 869                histories[node] = pd.read_json(history_path)
 870            else:
 871                print(f"[WARNING] No linear shift history found for node '{node}' at {history_path}")
 872        return histories
 873
 874    def simple_intercept_history(self):
 875        """
 876        Load simple intercept histories for all nodes.
 877
 878        Each node history is expected in a JSON file named
 879        ``simple_intercepts_all_epochs.json`` under the node directory.
 880
 881        Returns
 882        -------
 883        dict
 884            A mapping ``{node_name: pandas.DataFrame}``, where each DataFrame
 885            contains intercept weights across epochs.
 886
 887        Raises
 888        ------
 889        ValueError
 890            If the experiment directory cannot be resolved from the configuration.
 891
 892        Notes
 893        -----
 894        If a history file is missing for a node, a warning is printed and the
 895        node is omitted from the returned dictionary.
 896        """
 897        histories = {}
 898        try:
 899            EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
 900        except KeyError:
 901            raise ValueError(
 902                "[ERROR] Missing 'EXPERIMENT_DIR' in cfg.conf_dict['PATHS']. "
 903                "Cannot load histories without experiment directory."
 904            )
 905
 906        for node in self.nodes_dict.keys():
 907            node_dir = os.path.join(EXPERIMENT_DIR, node)
 908            history_path = os.path.join(node_dir, "simple_intercepts_all_epochs.json")
 909            if os.path.exists(history_path):
 910                histories[node] = pd.read_json(history_path)
 911            else:
 912                print(f"[WARNING] No simple intercept history found for node '{node}' at {history_path}")
 913        return histories
 914
 915    def get_latent(self, df, verbose=False):
 916        """
 917        Compute latent representations for all nodes in the DAG.
 918
 919        Parameters
 920        ----------
 921        df : pandas.DataFrame
 922            Input data frame with columns corresponding to nodes in the DAG.
 923        verbose : bool, optional
 924            If True, print informational messages during latent computation.
 925            Default is False.
 926
 927        Returns
 928        -------
 929        pandas.DataFrame
 930            DataFrame containing the original columns plus latent variables
 931            for each node (e.g. columns named ``f"{node}_U"``).
 932
 933        Raises
 934        ------
 935        ValueError
 936            If the experiment directory is missing from the configuration or
 937            if ``self.minmax_dict`` has not been set.
 938        """
 939        try:
 940            EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
 941        except KeyError:
 942            raise ValueError(
 943                "[ERROR] Missing 'EXPERIMENT_DIR' in cfg.conf_dict['PATHS']. "
 944                "Latent extraction requires trained model checkpoints."
 945            )
 946
 947        # ensure minmax_dict is available
 948        if not hasattr(self, "minmax_dict"):
 949            raise ValueError(
 950                "[ERROR] minmax_dict not found in the TramDagModel instance. "
 951                "Either call .load_or_compute_minmax(td_train_data=train_df) or .fit() first."
 952            )
 953
 954        all_latents_df = create_latent_df_for_full_dag(
 955            configuration_dict=self.cfg.conf_dict,
 956            EXPERIMENT_DIR=EXPERIMENT_DIR,
 957            df=df,
 958            verbose=verbose,
 959            min_max_dict=self.minmax_dict,
 960        )
 961
 962        return all_latents_df
 963
 964    
 965    ## PLOTTING FIT-DIAGNOSTICS
 966    
 967    def plot_loss_history(self, variable: str = None):
 968        """
 969        Plot training and validation loss evolution per node.
 970
 971        Parameters
 972        ----------
 973        variable : str or None, optional
 974            If provided, plot loss history for this node only. If None, plot
 975            histories for all nodes that have both train and validation logs.
 976
 977        Returns
 978        -------
 979        None
 980
 981        Notes
 982        -----
 983        Two subplots are produced:
 984        - Full epoch history.
 985        - Last 10% of epochs (or only the last epoch if fewer than 5 epochs).
 986        """
 987
 988        histories = self.loss_history()
 989        if not histories:
 990            print("[WARNING] No loss histories found.")
 991            return
 992
 993        # Select which nodes to plot
 994        if variable is not None:
 995            if variable not in histories:
 996                raise ValueError(f"[ERROR] Node '{variable}' not found in histories.")
 997            nodes_to_plot = [variable]
 998        else:
 999            nodes_to_plot = list(histories.keys())
1000
1001        # Filter out nodes with no valid history
1002        nodes_to_plot = [
1003            n for n in nodes_to_plot
1004            if histories[n].get("train") is not None and len(histories[n]["train"]) > 0
1005            and histories[n].get("validation") is not None and len(histories[n]["validation"]) > 0
1006        ]
1007
1008        if not nodes_to_plot:
1009            print("[WARNING] No valid histories found to plot.")
1010            return
1011
1012        plt.figure(figsize=(14, 12))
1013
1014        # --- Full history (top plot) ---
1015        plt.subplot(2, 1, 1)
1016        for node in nodes_to_plot:
1017            node_hist = histories[node]
1018            train_hist, val_hist = node_hist["train"], node_hist["validation"]
1019
1020            epochs = range(1, len(train_hist) + 1)
1021            plt.plot(epochs, train_hist, label=f"{node} - train", linestyle="--")
1022            plt.plot(epochs, val_hist, label=f"{node} - val")
1023
1024        plt.title("Training and Validation NLL - Full History")
1025        plt.xlabel("Epoch")
1026        plt.ylabel("NLL")
1027        plt.legend()
1028        plt.grid(True)
1029
1030        # --- Last 10% of epochs (bottom plot) ---
1031        plt.subplot(2, 1, 2)
1032        for node in nodes_to_plot:
1033            node_hist = histories[node]
1034            train_hist, val_hist = node_hist["train"], node_hist["validation"]
1035
1036            total_epochs = len(train_hist)
1037            start_idx = total_epochs - 1 if total_epochs < 5 else int(total_epochs * 0.9)
1038
1039            epochs = range(start_idx + 1, total_epochs + 1)
1040            plt.plot(epochs, train_hist[start_idx:], label=f"{node} - train", linestyle="--")
1041            plt.plot(epochs, val_hist[start_idx:], label=f"{node} - val")
1042
1043        plt.title("Training and Validation NLL - Last 10% of Epochs (or Last Epoch if <5)")
1044        plt.xlabel("Epoch")
1045        plt.ylabel("NLL")
1046        plt.legend()
1047        plt.grid(True)
1048
1049        plt.tight_layout()
1050        plt.show()
1051
1052    def plot_linear_shift_history(self, data_dict=None, node=None, ref_lines=None):
1053        """
1054        Plot the evolution of linear shift terms over epochs.
1055
1056        Parameters
1057        ----------
1058        data_dict : dict or None, optional
1059            Pre-loaded mapping ``{node_name: pandas.DataFrame}`` containing shift
1060            weights across epochs. If None, `linear_shift_history()` is called.
1061        node : str or None, optional
1062            If provided, plot only this node. Otherwise, plot all nodes
1063            present in ``data_dict``.
1064        ref_lines : dict or None, optional
1065            Optional mapping ``{node_name: list of float}``. For each specified
1066            node, horizontal reference lines are drawn at the given values.
1067
1068        Returns
1069        -------
1070        None
1071
1072        Notes
1073        -----
1074        The function flattens nested list-like entries in the DataFrames to scalars,
1075        converts epoch labels to numeric, and then draws one line per shift term.
1076        """
1077
1078        if data_dict is None:
1079            data_dict = self.linear_shift_history()
1080            if data_dict is None:
1081                raise ValueError("No shift history data provided or stored in the class.")
1082
1083        nodes = [node] if node else list(data_dict.keys())
1084
1085        for n in nodes:
1086            df = data_dict[n].copy()
1087
1088            # Flatten nested lists or list-like cells
1089            def flatten(x):
1090                if isinstance(x, list):
1091                    if len(x) == 0:
1092                        return np.nan
1093                    if all(isinstance(i, (int, float)) for i in x):
1094                        return np.mean(x)  # average simple list
1095                    if all(isinstance(i, list) for i in x):
1096                        # nested list -> flatten inner and average
1097                        flat = [v for sub in x for v in (sub if isinstance(sub, list) else [sub])]
1098                        return np.mean(flat) if flat else np.nan
1099                    return x[0] if len(x) == 1 else np.nan
1100                return x
1101
1102            df = df.applymap(flatten)
1103
1104            # Ensure numeric columns
1105            df = df.apply(pd.to_numeric, errors='coerce')
1106
1107            # Convert epoch labels to numeric
1108            df.columns = [
1109                int(c.replace("epoch_", "")) if isinstance(c, str) and c.startswith("epoch_") else c
1110                for c in df.columns
1111            ]
1112            df = df.reindex(sorted(df.columns), axis=1)
1113
1114            plt.figure(figsize=(10, 6))
1115            for idx in df.index:
1116                plt.plot(df.columns, df.loc[idx], lw=1.4, label=f"shift_{idx}")
1117
1118            if ref_lines and n in ref_lines:
1119                for v in ref_lines[n]:
1120                    plt.axhline(y=v, color="k", linestyle="--", lw=1.0)
1121                    plt.text(df.columns[-1], v, f"{n}: {v}", va="bottom", ha="right", fontsize=8)
1122
1123            plt.xlabel("Epoch")
1124            plt.ylabel("Shift Value")
1125            plt.title(f"Shift Term History — Node: {n}")
1126            plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left", fontsize="small")
1127            plt.tight_layout()
1128            plt.show()
1129
1130    def plot_simple_intercepts_history(self, data_dict=None, node=None,ref_lines=None):
1131        """
1132        Plot the evolution of simple intercept weights over epochs.
1133
1134        Parameters
1135        ----------
1136        data_dict : dict or None, optional
1137            Pre-loaded mapping ``{node_name: pandas.DataFrame}`` containing intercept
1138            weights across epochs. If None, `simple_intercept_history()` is called.
1139        node : str or None, optional
1140            If provided, plot only this node. Otherwise, plot all nodes present
1141            in ``data_dict``.
1142        ref_lines : dict or None, optional
1143            Optional mapping ``{node_name: list of float}``. For each specified
1144            node, horizontal reference lines are drawn at the given values.
1145
1146        Returns
1147        -------
1148        None
1149
1150        Notes
1151        -----
1152        Nested list-like entries in the DataFrames are reduced to scalars before
1153        plotting. One line is drawn per intercept parameter.
1154        """
1155        if data_dict is None:
1156            data_dict = self.simple_intercept_history()
1157            if data_dict is None:
1158                raise ValueError("No intercept history data provided or stored in the class.")
1159
1160        nodes = [node] if node else list(data_dict.keys())
1161
1162        for n in nodes:
1163            df = data_dict[n].copy()
1164
1165            def extract_scalar(x):
1166                if isinstance(x, list):
1167                    while isinstance(x, list) and len(x) > 0:
1168                        x = x[0]
1169                return float(x) if isinstance(x, (int, float, np.floating)) else np.nan
1170
1171            df = df.applymap(extract_scalar)
1172
1173            # Convert epoch labels → numeric
1174            df.columns = [
1175                int(c.replace("epoch_", "")) if isinstance(c, str) and c.startswith("epoch_") else c
1176                for c in df.columns
1177            ]
1178            df = df.reindex(sorted(df.columns), axis=1)
1179
1180            plt.figure(figsize=(10, 6))
1181            for idx in df.index:
1182                plt.plot(df.columns, df.loc[idx], lw=1.4, label=f"theta_{idx}")
1183            
1184            if ref_lines and n in ref_lines:
1185                for v in ref_lines[n]:
1186                    plt.axhline(y=v, color="k", linestyle="--", lw=1.0)
1187                    plt.text(df.columns[-1], v, f"{n}: {v}", va="bottom", ha="right", fontsize=8)
1188                
1189            plt.xlabel("Epoch")
1190            plt.ylabel("Intercept Weight")
1191            plt.title(f"Simple Intercept Evolution — Node: {n}")
1192            plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left", fontsize="small")
1193            plt.tight_layout()
1194            plt.show()
1195
1196    def plot_latents(self, df, variable: str = None, confidence: float = 0.95, simulations: int = 1000):
1197        """
1198        Visualize latent U distributions for one or all nodes.
1199
1200        Parameters
1201        ----------
1202        df : pandas.DataFrame
1203            Input data frame with raw node values.
1204        variable : str or None, optional
1205            If provided, only this node's latents are plotted. If None, all
1206            nodes with latent columns are processed.
1207        confidence : float, optional
1208            Confidence level for QQ-plot bands (0 < confidence < 1).
1209            Default is 0.95.
1210        simulations : int, optional
1211            Number of Monte Carlo simulations for QQ-plot bands. Default is 1000.
1212
1213        Returns
1214        -------
1215        None
1216
1217        Notes
1218        -----
1219        For each node, two plots are produced:
1220        - Histogram of the latent U values.
1221        - QQ-plot with simulation-based confidence bands under a logistic reference.
1222        """
1223        # Compute latent representations
1224        latents_df = self.get_latent(df)
1225
1226        # Select nodes
1227        nodes = [variable] if variable is not None else self.nodes_dict.keys()
1228
1229        for node in nodes:
1230            if f"{node}_U" not in latents_df.columns:
1231                print(f"[WARNING] No latent found for node {node}, skipping.")
1232                continue
1233
1234            sample = latents_df[f"{node}_U"].values
1235
1236            # --- Create plots ---
1237            fig, axs = plt.subplots(1, 2, figsize=(12, 5))
1238
1239            # Histogram
1240            axs[0].hist(sample, bins=50, color="steelblue", alpha=0.7)
1241            axs[0].set_title(f"Latent Histogram ({node})")
1242            axs[0].set_xlabel("U")
1243            axs[0].set_ylabel("Frequency")
1244
1245            # QQ Plot with confidence bands
1246            probplot(sample, dist="logistic", plot=axs[1])
1247            self._add_r_style_confidence_bands(axs[1], sample, dist=logistic,confidence=confidence, simulations=simulations)
1248            axs[1].set_title(f"Latent QQ Plot ({node})")
1249
1250            plt.suptitle(f"Latent Diagnostics for Node: {node}", fontsize=14)
1251            plt.tight_layout(rect=[0, 0.03, 1, 0.95])
1252            plt.show()
1253
1254    def plot_hdag(self,df,variables=None, plot_n_rows=1,**kwargs):
1255        
1256        """
1257        Visualize the transformation function h() for selected DAG nodes.
1258
1259        Parameters
1260        ----------
1261        df : pandas.DataFrame
1262            Input data containing node values or model predictions.
1263        variables : list of str or None, optional
1264            Names of nodes to visualize. If None, all nodes in ``self.models``
1265            are considered.
1266        plot_n_rows : int, optional
1267            Maximum number of rows from ``df`` to visualize. Default is 1.
1268        **kwargs
1269            Additional keyword arguments forwarded to the underlying plotting
1270            helpers (`show_hdag_continous` / `show_hdag_ordinal`).
1271
1272        Returns
1273        -------
1274        None
1275
1276        Notes
1277        -----
1278        - For continuous outcomes, `show_hdag_continous` is called.
1279        - For ordinal outcomes, `show_hdag_ordinal` is called.
1280        - Nodes that are neither continuous nor ordinal are skipped with a warning.
1281        """
1282                
1283
1284        if len(df)> 1:
1285            print("[WARNING] len(df)>1, set: plot_n_rows accordingly")
1286        
1287        variables_list=variables if variables is not None else list(self.models.keys())
1288        for node in variables_list:
1289            if is_outcome_modelled_continous(node, self.nodes_dict):
1290                show_hdag_continous(df,node=node,configuration_dict=self.cfg.conf_dict,minmax_dict=self.minmax_dict,device=self.device,plot_n_rows=plot_n_rows,**kwargs)
1291            
1292            elif is_outcome_modelled_ordinal(node, self.nodes_dict):
1293                show_hdag_ordinal(df,node=node,configuration_dict=self.cfg.conf_dict,device=self.device,plot_n_rows=plot_n_rows,**kwargs)
1294                # plot_cutpoints_with_logistic(df,node=node,configuration_dict=self.cfg.conf_dict,device=self.device,plot_n_rows=plot_n_rows,**kwargs)
1295                # save_cutpoints_with_logistic(df,node=node,configuration_dict=self.cfg.conf_dict,device=self.device,**kwargs)
1296            else:
1297                print(f"[WARNING] Node {node} is wheter ordinal nor continous, not implemented yet")
1298    
1299    @staticmethod
1300    def _add_r_style_confidence_bands(ax, sample, dist, confidence=0.95, simulations=1000):
1301        """
1302        Add simulation-based confidence bands to a QQ-plot.
1303
1304        Parameters
1305        ----------
1306        ax : matplotlib.axes.Axes
1307            Axes object on which to draw the QQ-plot and bands.
1308        sample : array-like
1309            Empirical sample used in the QQ-plot.
1310        dist : scipy.stats distribution
1311            Distribution object providing ``ppf`` and ``rvs`` methods (e.g. logistic).
1312        confidence : float, optional
1313            Confidence level (0 < confidence < 1) for the bands. Default is 0.95.
1314        simulations : int, optional
1315            Number of Monte Carlo simulations used to estimate the bands. Default is 1000.
1316
1317        Returns
1318        -------
1319        None
1320
1321        Notes
1322        -----
1323        The axes are cleared, and a new QQ-plot is drawn with:
1324        - Empirical vs. theoretical quantiles.
1325        - 45-degree reference line.
1326        - Shaded confidence band region.
1327        """
1328        
1329        n = len(sample)
1330        if n == 0:
1331            return
1332
1333        quantiles = np.linspace(0, 1, n, endpoint=False) + 0.5 / n
1334        theo_q = dist.ppf(quantiles)
1335
1336        # Simulate order statistics from the theoretical distribution
1337        sim_data = dist.rvs(size=(simulations, n))
1338        sim_order_stats = np.sort(sim_data, axis=1)
1339
1340        # Confidence bands
1341        lower = np.percentile(sim_order_stats, 100 * (1 - confidence) / 2, axis=0)
1342        upper = np.percentile(sim_order_stats, 100 * (1 + confidence) / 2, axis=0)
1343
1344        # Sort empirical sample
1345        sample_sorted = np.sort(sample)
1346
1347        # Re-draw points and CI (overwrite probplot defaults)
1348        ax.clear()
1349        ax.plot(theo_q, sample_sorted, 'o', markersize=3, alpha=0.6, label="Empirical Q-Q")
1350        ax.plot(theo_q, theo_q, 'b--', label="y = x")
1351        ax.fill_between(theo_q, lower, upper, color='gray', alpha=0.3,
1352                        label=f'{int(confidence*100)}% CI')
1353        ax.legend()
1354    
1355    ## SAMPLING METHODS
1356    def sample(
1357        self,
1358        do_interventions: dict = None,
1359        predefined_latent_samples_df: pd.DataFrame = None,
1360        **kwargs,
1361    ):
1362        """
1363        Sample from the joint DAG using the trained TRAM models.
1364
1365        Allows for:
1366        
1367        Oberservational sampling
1368        Interventional sampling via ``do()`` operations
1369        Counterfactial sampling using predefined latent draws and do()
1370        
1371        Parameters
1372        ----------
1373        do_interventions : dict or None, optional
1374            Mapping of node names to intervened (fixed) values. For example:
1375            ``{"x1": 1.0}`` represents ``do(x1 = 1.0)``. Default is None.
1376        predefined_latent_samples_df : pandas.DataFrame or None, optional
1377            DataFrame containing columns ``"{node}_U"`` with predefined latent
1378            draws to be used instead of sampling from the prior. Default is None.
1379        **kwargs
1380            Sampling options overriding internal defaults:
1381
1382            number_of_samples : int, default 10000
1383                Total number of samples to draw.
1384            batch_size : int, default 32
1385                Batch size for internal sampling loops.
1386            delete_all_previously_sampled : bool, default True
1387                If True, delete old sampling files in node-specific sampling
1388                directories before writing new ones.
1389            verbose : bool
1390                If True, print informational messages.
1391            debug : bool
1392                If True, print debug output.
1393            device : {"auto", "cpu", "cuda"}
1394                Device selection for sampling.
1395            use_initial_weights_for_sampling : bool, default False
1396                If True, sample from initial (untrained) model parameters.
1397
1398        Returns
1399        -------
1400        tuple
1401            A tuple ``(sampled_by_node, latents_by_node)``:
1402
1403            sampled_by_node : dict
1404                Mapping ``{node_name: torch.Tensor}`` of sampled node values.
1405            latents_by_node : dict
1406                Mapping ``{node_name: torch.Tensor}`` of latent U values used.
1407
1408        Raises
1409        ------
1410        ValueError
1411            If the experiment directory cannot be resolved or if scaling
1412            information (``self.minmax_dict``) is missing.
1413        RuntimeError
1414            If min–max scaling has not been computed before calling `sample`.
1415        """
1416        try:
1417            EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
1418        except KeyError:
1419            raise ValueError(
1420                "[ERROR] Missing 'EXPERIMENT_DIR' in cfg.conf_dict['PATHS']. "
1421                "Sampling requires trained model checkpoints."
1422            )
1423
1424        # ---- defaults ----
1425        settings = {
1426            "number_of_samples": 10_000,
1427            "batch_size": 32,
1428            "delete_all_previously_sampled": True,
1429            "verbose": self.verbose if hasattr(self, "verbose") else False,
1430            "debug": self.debug if hasattr(self, "debug") else False,
1431            "device": self.device.type if hasattr(self, "device") else "auto",
1432            "use_initial_weights_for_sampling": False,
1433            
1434        }
1435        
1436        # self._validate_kwargs( kwargs, defaults_attr= "settings", context="sample")
1437        
1438        settings.update(kwargs)
1439
1440        
1441        if not hasattr(self, "minmax_dict"):
1442            raise RuntimeError(
1443                "[ERROR] minmax_dict not found. You must call .fit() or .load_or_compute_minmax() "
1444                "before sampling, so scaling info is available."
1445                )
1446            
1447        # ---- resolve device ----
1448        device_str=self.get_device(settings)
1449        self.device = torch.device(device_str)
1450
1451
1452        if self.debug or settings["debug"]:
1453            print(f"[DEBUG] sample(): device: {self.device}")
1454
1455        # ---- perform sampling ----
1456        sampled_by_node, latents_by_node = sample_full_dag(
1457            configuration_dict=self.cfg.conf_dict,
1458            EXPERIMENT_DIR=EXPERIMENT_DIR,
1459            device=self.device,
1460            do_interventions=do_interventions or {},
1461            predefined_latent_samples_df=predefined_latent_samples_df,
1462            number_of_samples=settings["number_of_samples"],
1463            batch_size=settings["batch_size"],
1464            delete_all_previously_sampled=settings["delete_all_previously_sampled"],
1465            verbose=settings["verbose"],
1466            debug=settings["debug"],
1467            minmax_dict=self.minmax_dict,
1468            use_initial_weights_for_sampling=settings["use_initial_weights_for_sampling"]
1469        )
1470
1471        return sampled_by_node, latents_by_node
1472
1473    def load_sampled_and_latents(self, EXPERIMENT_DIR: str = None, nodes: list = None):
1474        """
1475        Load previously stored sampled values and latents for each node.
1476
1477        Parameters
1478        ----------
1479        EXPERIMENT_DIR : str or None, optional
1480            Experiment directory path. If None, it is taken from
1481            ``self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]``.
1482        nodes : list of str or None, optional
1483            Nodes for which to load samples. If None, use all nodes from
1484            ``self.nodes_dict``.
1485
1486        Returns
1487        -------
1488        tuple
1489            A tuple ``(sampled_by_node, latents_by_node)``:
1490
1491            sampled_by_node : dict
1492                Mapping ``{node_name: torch.Tensor}`` of sampled values (on CPU).
1493            latents_by_node : dict
1494                Mapping ``{node_name: torch.Tensor}`` of latent values (on CPU).
1495
1496        Raises
1497        ------
1498        ValueError
1499            If the experiment directory cannot be resolved or if no node list
1500            is available and ``nodes`` is None.
1501
1502        Notes
1503        -----
1504        Nodes without both ``sampled.pt`` and ``latents.pt`` files are skipped
1505        with a warning.
1506        """
1507        # --- resolve paths and node list ---
1508        if EXPERIMENT_DIR is None:
1509            try:
1510                EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
1511            except (AttributeError, KeyError):
1512                raise ValueError(
1513                    "[ERROR] Could not resolve EXPERIMENT_DIR from cfg.conf_dict['PATHS']. "
1514                    "Please provide EXPERIMENT_DIR explicitly."
1515                )
1516
1517        if nodes is None:
1518            if hasattr(self, "nodes_dict"):
1519                nodes = list(self.nodes_dict.keys())
1520            else:
1521                raise ValueError(
1522                    "[ERROR] No node list found. Please provide `nodes` or initialize model with a config."
1523                )
1524
1525        # --- load tensors ---
1526        sampled_by_node = {}
1527        latents_by_node = {}
1528
1529        for node in nodes:
1530            node_dir = os.path.join(EXPERIMENT_DIR, f"{node}")
1531            sampling_dir = os.path.join(node_dir, "sampling")
1532
1533            sampled_path = os.path.join(sampling_dir, "sampled.pt")
1534            latents_path = os.path.join(sampling_dir, "latents.pt")
1535
1536            if not os.path.exists(sampled_path) or not os.path.exists(latents_path):
1537                print(f"[WARNING] Missing files for node '{node}' — skipping.")
1538                continue
1539
1540            try:
1541                sampled = torch.load(sampled_path, map_location="cpu")
1542                latent_sample = torch.load(latents_path, map_location="cpu")
1543            except Exception as e:
1544                print(f"[ERROR] Could not load sampling files for node '{node}': {e}")
1545                continue
1546
1547            sampled_by_node[node] = sampled.detach().cpu()
1548            latents_by_node[node] = latent_sample.detach().cpu()
1549
1550        if self.verbose or self.debug:
1551            print(f"[INFO] Loaded sampled and latent tensors for {len(sampled_by_node)} nodes from {EXPERIMENT_DIR}")
1552
1553        return sampled_by_node, latents_by_node
1554
    def plot_samples_vs_true(
        self,
        df,
        sampled: dict = None,
        variable: list = None,
        bins: int = 100,
        hist_true_color: str = "blue",
        hist_est_color: str = "orange",
        figsize: tuple = (14, 5),
    ):


        """
        Compare sampled vs. observed distributions for selected nodes.

        Parameters
        ----------
        df : pandas.DataFrame
            Data frame containing the observed node values.
        sampled : dict or None, optional
            Optional mapping ``{node_name: array-like or torch.Tensor}`` of sampled
            values. If None or if a node is missing, samples are loaded from
            ``EXPERIMENT_DIR/{node}/sampling/sampled.pt``.
        variable : list of str or None, optional
            Subset of nodes to plot. If None, all nodes in the configuration
            are considered.
        bins : int, optional
            Number of histogram bins for continuous variables. Default is 100.
        hist_true_color : str, optional
            Color name for the histogram of true values. Default is "blue".
        hist_est_color : str, optional
            Color name for the histogram of sampled values. Default is "orange".
        figsize : tuple, optional
            Figure size for the matplotlib plots. Default is (14, 5).

        Returns
        -------
        None

        Notes
        -----
        - Continuous outcomes: histogram overlay + QQ-plot.
        - Ordinal outcomes: side-by-side bar plot of relative frequencies.
        - Other categorical outcomes: side-by-side bar plot with category labels.
        - If samples are probabilistic (2D tensor), the argmax across classes is used.
        """

        target_nodes = self.cfg.conf_dict["nodes"]
        experiment_dir = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
        # NOTE(review): loaded tensors are immediately moved to CPU below, so
        # map_location="cpu" would avoid a GPU round-trip — confirm intent.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Plot either the caller-requested subset or every configured node.
        plot_list = variable if variable is not None else target_nodes

        for node in plot_list:
            # Load sampled data: prefer the caller-supplied mapping, fall back
            # to the persisted sampling results on disk.
            if sampled is not None and node in sampled:
                sdata = sampled[node]
                if isinstance(sdata, torch.Tensor):
                    sampled_vals = sdata.detach().cpu().numpy()
                else:
                    sampled_vals = np.asarray(sdata)
            else:
                sample_path = os.path.join(experiment_dir, f"{node}/sampling/sampled.pt")
                if not os.path.isfile(sample_path):
                    print(f"[WARNING] skip {node}: {sample_path} not found.")
                    continue

                try:
                    sampled_vals = torch.load(sample_path, map_location=device).cpu().numpy()
                except Exception as e:
                    print(f"[ERROR] Could not load {sample_path}: {e}")
                    continue

            # If logits/probabilities per sample, take argmax
            # (a 2D array is interpreted as one probability row per sample).
            if sampled_vals.ndim == 2:
                    print(f"[INFO] CAUTION! {node}: samples are probabilistic — each sample follows a probability "
                    f"distribution based on the valid latent range. "
                    f"Note that this frequency plot reflects only the distribution of the most probable "
                    f"class per sample.")
                    sampled_vals = np.argmax(sampled_vals, axis=1)

            # Drop NaN/Inf before plotting.
            sampled_vals = sampled_vals[np.isfinite(sampled_vals)]

            if node not in df.columns:
                print(f"[WARNING] skip {node}: column not found in DataFrame.")
                continue

            # NOTE(review): np.isfinite assumes a numeric column; an
            # object-dtype categorical column would raise TypeError — confirm.
            true_vals = df[node].dropna().values
            true_vals = true_vals[np.isfinite(true_vals)]

            if sampled_vals.size == 0 or true_vals.size == 0:
                print(f"[WARNING] skip {node}: empty array after NaN/Inf removal.")
                continue

            # One figure per node: left panel = distribution comparison,
            # right panel = QQ plot (continuous only; blank otherwise).
            fig, axs = plt.subplots(1, 2, figsize=figsize)

            if is_outcome_modelled_continous(node, target_nodes):
                # Continuous node: overlaid density histograms + QQ plot.
                axs[0].hist(true_vals, bins=bins, density=True, alpha=0.6,
                            color=hist_true_color, label=f"True {node}")
                axs[0].hist(sampled_vals, bins=bins, density=True, alpha=0.6,
                            color=hist_est_color, label="Sampled")
                axs[0].set_xlabel("Value")
                axs[0].set_ylabel("Density")
                axs[0].set_title(f"Histogram overlay for {node}")
                axs[0].legend()
                axs[0].grid(True, ls="--", alpha=0.4)

                qqplot_2samples(true_vals, sampled_vals, line="45", ax=axs[1])
                axs[1].set_xlabel("True quantiles")
                axs[1].set_ylabel("Sampled quantiles")
                axs[1].set_title(f"QQ plot for {node}")
                axs[1].grid(True, ls="--", alpha=0.4)

            elif is_outcome_modelled_ordinal(node, target_nodes):
                # Ordinal node: side-by-side relative-frequency bars over the
                # union of levels seen in either the true or sampled values.
                unique_vals = np.union1d(np.unique(true_vals), np.unique(sampled_vals))
                unique_vals = np.sort(unique_vals)
                true_counts = np.array([(true_vals == val).sum() for val in unique_vals])
                sampled_counts = np.array([(sampled_vals == val).sum() for val in unique_vals])

                # Bars are offset by ±0.2 so true/sampled sit next to each other.
                axs[0].bar(unique_vals - 0.2, true_counts / true_counts.sum(),
                        width=0.4, color=hist_true_color, alpha=0.7, label="True")
                axs[0].bar(unique_vals + 0.2, sampled_counts / sampled_counts.sum(),
                        width=0.4, color=hist_est_color, alpha=0.7, label="Sampled")
                axs[0].set_xticks(unique_vals)
                axs[0].set_xlabel("Ordinal Level")
                axs[0].set_ylabel("Relative Frequency")
                axs[0].set_title(f"Ordinal bar plot for {node}")
                axs[0].legend()
                axs[0].grid(True, ls="--", alpha=0.4)
                axs[1].axis("off")

            else:
                # Other categorical node: same bar layout, but categories are
                # placed at integer positions and labelled (sorted as strings).
                unique_vals = np.union1d(np.unique(true_vals), np.unique(sampled_vals))
                unique_vals = sorted(unique_vals, key=str)
                true_counts = np.array([(true_vals == val).sum() for val in unique_vals])
                sampled_counts = np.array([(sampled_vals == val).sum() for val in unique_vals])

                axs[0].bar(np.arange(len(unique_vals)) - 0.2, true_counts / true_counts.sum(),
                        width=0.4, color=hist_true_color, alpha=0.7, label="True")
                axs[0].bar(np.arange(len(unique_vals)) + 0.2, sampled_counts / sampled_counts.sum(),
                        width=0.4, color=hist_est_color, alpha=0.7, label="Sampled")
                axs[0].set_xticks(np.arange(len(unique_vals)))
                axs[0].set_xticklabels(unique_vals, rotation=45)
                axs[0].set_xlabel("Category")
                axs[0].set_ylabel("Relative Frequency")
                axs[0].set_title(f"Categorical bar plot for {node}")
                axs[0].legend()
                axs[0].grid(True, ls="--", alpha=0.4)
                axs[1].axis("off")

            plt.tight_layout()
            plt.show()
1707
1708    ## SUMMARY METHODS
1709    def nll(self,data,variables=None):
1710        """
1711        Compute the Negative Log-Likelihood (NLL) for all or selected TRAM nodes.
1712
1713        This function evaluates trained TRAM models for each specified variable (node) 
1714        on the provided dataset. It performs forward passes only—no training, no weight 
1715        updates—and returns the mean NLL per node.
1716
1717        Parameters
1718        ----------
1719        data : object
1720            Input dataset or data source compatible with `_ensure_dataset`, containing 
1721            both inputs and targets for each node.
1722        variables : list[str], optional
1723            List of variable (node) names to evaluate. If None, all nodes in 
1724            `self.models` are evaluated.
1725
1726        Returns
1727        -------
1728        dict[str, float]
1729            Dictionary mapping each node name to its average NLL value.
1730
1731        Notes
1732        -----
1733        - Each model is evaluated independently on its respective DataLoader.
1734        - The normalization values (`min_max`) for each node are retrieved from 
1735          `self.minmax_dict[node]`.
1736        - The function uses `evaluate_tramdag_model()` for per-node evaluation.
1737        - Expected directory structure:
1738              `<EXPERIMENT_DIR>/<node>/`
1739          where each node directory contains the trained model.
1740        """
1741
1742        td_data = self._ensure_dataset(data, is_val=True)  
1743        variables_list = variables if variables != None else list(self.models.keys())
1744        nll_dict = {}
1745        for node in variables_list:  
1746                min_vals = torch.tensor(self.minmax_dict[node][0], dtype=torch.float32)
1747                max_vals = torch.tensor(self.minmax_dict[node][1], dtype=torch.float32)
1748                min_max = torch.stack([min_vals, max_vals], dim=0)
1749                data_loader = td_data.loaders[node]
1750                model = self.models[node]
1751                EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
1752                NODE_DIR = os.path.join(EXPERIMENT_DIR, f"{node}")
1753                nll= evaluate_tramdag_model(node=node,
1754                                            target_nodes=self.nodes_dict,
1755                                            NODE_DIR=NODE_DIR,
1756                                            tram_model=model,
1757                                            data_loader=data_loader,
1758                                            min_max=min_max)
1759                nll_dict[node]=nll
1760        return nll_dict
1761    
1762    def get_train_val_nll(self, node: str, mode: str) -> tuple[float, float]:
1763        """
1764        Retrieve training and validation NLL for a node and a given model state.
1765
1766        Parameters
1767        ----------
1768        node : str
1769            Node name.
1770        mode : {"best", "last", "init"}
1771            State of interest:
1772            - "best": epoch with lowest validation NLL.
1773            - "last": final epoch.
1774            - "init": first epoch (index 0).
1775
1776        Returns
1777        -------
1778        tuple of (float or None, float or None)
1779            A tuple ``(train_nll, val_nll)`` for the requested mode.
1780            Returns ``(None, None)`` if loss files are missing or cannot be read.
1781
1782        Notes
1783        -----
1784        This method expects per-node JSON files:
1785
1786        - ``train_loss_hist.json``
1787        - ``val_loss_hist.json``
1788
1789        in the node directory.
1790        """
1791        NODE_DIR = os.path.join(self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"], node)
1792        train_path = os.path.join(NODE_DIR, "train_loss_hist.json")
1793        val_path = os.path.join(NODE_DIR, "val_loss_hist.json")
1794
1795        if not os.path.exists(train_path) or not os.path.exists(val_path):
1796            if getattr(self, "debug", False):
1797                print(f"[DEBUG] Missing loss files for node '{node}'. Returning None.")
1798            return None, None
1799
1800        try:
1801            with open(train_path, "r") as f:
1802                train_hist = json.load(f)
1803            with open(val_path, "r") as f:
1804                val_hist = json.load(f)
1805
1806            train_nlls = np.array(train_hist)
1807            val_nlls = np.array(val_hist)
1808
1809            if mode == "init":
1810                idx = 0
1811            elif mode == "last":
1812                idx = len(val_nlls) - 1
1813            elif mode == "best":
1814                idx = int(np.argmin(val_nlls))
1815            else:
1816                raise ValueError(f"Invalid mode '{mode}' — must be one of 'best', 'last', 'init'.")
1817
1818            train_nll = float(train_nlls[idx])
1819            val_nll = float(val_nlls[idx])
1820            return train_nll, val_nll
1821
1822        except Exception as e:
1823            print(f"[ERROR] Failed to load NLLs for node '{node}' ({mode}): {e}")
1824            return None, None
1825
1826    def get_thetas(self, node: str, state: str = "best"):
1827        """
1828        Return transformed intercept (theta) parameters for a node and state.
1829
1830        Parameters
1831        ----------
1832        node : str
1833            Node name.
1834        state : {"best", "last", "init"}, optional
1835            Model state for which to return parameters. Default is "best".
1836
1837        Returns
1838        -------
1839        Any or None
1840            Transformed theta parameters for the requested node and state.
1841            The exact structure (scalar, list, or other) depends on the model.
1842
1843        Raises
1844        ------
1845        ValueError
1846            If an invalid state is given (not in {"best", "last", "init"}).
1847
1848        Notes
1849        -----
1850        Intercept dictionaries are cached on the instance under the attribute
1851        ``intercept_dicts``. If missing or incomplete, they are recomputed using
1852        `get_simple_intercepts_dict`.
1853        """
1854
1855        state = state.lower()
1856        if state not in ["best", "last", "init"]:
1857            raise ValueError(f"[ERROR] Invalid state '{state}'. Must be one of ['best', 'last', 'init'].")
1858
1859        dict_attr = "intercept_dicts"
1860
1861        # If no cached intercepts exist, compute them
1862        if not hasattr(self, dict_attr):
1863            if getattr(self, "debug", False):
1864                print(f"[DEBUG] '{dict_attr}' not found, computing via get_simple_intercepts_dict().")
1865            setattr(self, dict_attr, self.get_simple_intercepts_dict())
1866
1867        all_dicts = getattr(self, dict_attr)
1868
1869        # If the requested state isn’t cached, recompute
1870        if state not in all_dicts:
1871            if getattr(self, "debug", False):
1872                print(f"[DEBUG] State '{state}' not found in cached intercepts, recomputing full dict.")
1873            setattr(self, dict_attr, self.get_simple_intercepts_dict())
1874            all_dicts = getattr(self, dict_attr)
1875
1876        state_dict = all_dicts.get(state, {})
1877
1878        # Return cached node intercept if present
1879        if node in state_dict:
1880            return state_dict[node]
1881
1882        # If not found, recompute full dict as fallback
1883        if getattr(self, "debug", False):
1884            print(f"[DEBUG] Node '{node}' not found in state '{state}', recomputing full dict.")
1885        setattr(self, dict_attr, self.get_simple_intercepts_dict())
1886        all_dicts = getattr(self, dict_attr)
1887        return all_dicts.get(state, {}).get(node, None)
1888        
1889    def get_linear_shifts(self, node: str, state: str = "best"):
1890        """
1891        Return learned linear shift terms for a node and a given state.
1892
1893        Parameters
1894        ----------
1895        node : str
1896            Node name.
1897        state : {"best", "last", "init"}, optional
1898            Model state for which to return linear shift terms. Default is "best".
1899
1900        Returns
1901        -------
1902        dict or Any or None
1903            Linear shift terms for the given node and state. Usually a dict
1904            mapping term names to weights.
1905
1906        Raises
1907        ------
1908        ValueError
1909            If an invalid state is given (not in {"best", "last", "init"}).
1910
1911        Notes
1912        -----
1913        Linear shift dictionaries are cached on the instance under the attribute
1914        ``linear_shift_dicts``. If missing or incomplete, they are recomputed using
1915        `get_linear_shifts_dict`.
1916        """
1917        state = state.lower()
1918        if state not in ["best", "last", "init"]:
1919            raise ValueError(f"[ERROR] Invalid state '{state}'. Must be one of ['best', 'last', 'init'].")
1920
1921        dict_attr = "linear_shift_dicts"
1922
1923        # If no global dicts cached, compute once
1924        if not hasattr(self, dict_attr):
1925            if getattr(self, "debug", False):
1926                print(f"[DEBUG] '{dict_attr}' not found, computing via get_linear_shifts_dict().")
1927            setattr(self, dict_attr, self.get_linear_shifts_dict())
1928
1929        all_dicts = getattr(self, dict_attr)
1930
1931        # If the requested state isn't cached, compute all again (covers fresh runs)
1932        if state not in all_dicts:
1933            if getattr(self, "debug", False):
1934                print(f"[DEBUG] State '{state}' not found in cached linear shifts, recomputing full dict.")
1935            setattr(self, dict_attr, self.get_linear_shifts_dict())
1936            all_dicts = getattr(self, dict_attr)
1937
1938        # Now fetch the dictionary for this state
1939        state_dict = all_dicts.get(state, {})
1940
1941        # If the node is available, return its entry
1942        if node in state_dict:
1943            return state_dict[node]
1944
1945        # If missing, try recomputing (fallback)
1946        if getattr(self, "debug", False):
1947            print(f"[DEBUG] Node '{node}' not found in state '{state}', recomputing full dict.")
1948        setattr(self, dict_attr, self.get_linear_shifts_dict())
1949        all_dicts = getattr(self, dict_attr)
1950        return all_dicts.get(state, {}).get(node, None)
1951
1952    def get_linear_shifts_dict(self):
1953        """
1954        Compute linear shift term dictionaries for all nodes and states.
1955
1956        For each node and each available state ("best", "last", "init"), this
1957        method loads the corresponding model checkpoint, extracts linear shift
1958        weights from the TRAM model, and stores them in a nested dictionary.
1959
1960        Returns
1961        -------
1962        dict
1963            Nested dictionary of the form:
1964
1965            .. code-block:: python
1966
1967                {
1968                    "best": {node: {...}},
1969                    "last": {node: {...}},
1970                    "init": {node: {...}},
1971                }
1972
1973            where the innermost dict maps term labels (e.g. ``"ls(parent_name)"``)
1974            to their weights.
1975
1976        Notes
1977        -----
1978        - If "best" or "last" checkpoints are unavailable for a node, only
1979        the "init" entry is populated.
1980        - Empty outer states (without any nodes) are removed from the result.
1981        """
1982
1983        EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
1984        nodes_list = list(self.models.keys())
1985        all_states = ["best", "last", "init"]
1986        all_linear_shift_dicts = {state: {} for state in all_states}
1987
1988        for node in nodes_list:
1989            NODE_DIR = os.path.join(EXPERIMENT_DIR, node)
1990            BEST_MODEL_PATH, LAST_MODEL_PATH, _, _ = model_train_val_paths(NODE_DIR)
1991            INIT_MODEL_PATH = os.path.join(NODE_DIR, "initial_model.pt")
1992
1993            state_paths = {
1994                "best": BEST_MODEL_PATH,
1995                "last": LAST_MODEL_PATH,
1996                "init": INIT_MODEL_PATH,
1997            }
1998
1999            for state, LOAD_PATH in state_paths.items():
2000                if not os.path.exists(LOAD_PATH):
2001                    if state != "init":
2002                        # skip best/last if unavailable
2003                        continue
2004                    else:
2005                        print(f"[WARNING] No models found for node '{node}'. Only initial model will be used.")
2006                        if not os.path.exists(LOAD_PATH):
2007                            if getattr(self, "debug", False):
2008                                print(f"[DEBUG] Initial model also missing for node '{node}'. Skipping.")
2009                            continue
2010
2011                # Load parents and model
2012                _, terms_dict, _ = ordered_parents(node, self.nodes_dict)
2013                state_dict = torch.load(LOAD_PATH, map_location=self.device)
2014                tram_model = self.models[node]
2015                tram_model.load_state_dict(state_dict)
2016
2017                epoch_weights = {}
2018                if hasattr(tram_model, "nn_shift") and tram_model.nn_shift is not None:
2019                    for i, shift_layer in enumerate(tram_model.nn_shift):
2020                        module_name = shift_layer.__class__.__name__
2021                        if (
2022                            hasattr(shift_layer, "fc")
2023                            and hasattr(shift_layer.fc, "weight")
2024                            and module_name == "LinearShift"
2025                        ):
2026                            term_name = list(terms_dict.keys())[i]
2027                            epoch_weights[f"ls({term_name})"] = (
2028                                shift_layer.fc.weight.detach().cpu().squeeze().tolist()
2029                            )
2030                        elif getattr(self, "debug", False):
2031                            term_name = list(terms_dict.keys())[i]
2032                            print(f"[DEBUG] ls({term_name}): missing 'fc' or 'weight' in LinearShift.")
2033                else:
2034                    if getattr(self, "debug", False):
2035                        print(f"[DEBUG] Tram model for node '{node}' has no nn_shift or it is None.")
2036
2037                all_linear_shift_dicts[state][node] = epoch_weights
2038
2039        # Remove empty states (e.g., when best/last not found for all nodes)
2040        all_linear_shift_dicts = {k: v for k, v in all_linear_shift_dicts.items() if v}
2041
2042        return all_linear_shift_dicts
2043
2044    def get_simple_intercepts_dict(self):
2045        """
2046        Compute transformed simple intercept dictionaries for all nodes and states.
2047
2048        For each node and each available state ("best", "last", "init"), this
2049        method loads the corresponding model checkpoint, extracts simple intercept
2050        weights, transforms them into interpretable theta parameters, and stores
2051        them in a nested dictionary.
2052
2053        Returns
2054        -------
2055        dict
2056            Nested dictionary of the form:
2057
2058            .. code-block:: python
2059
2060                {
2061                    "best": {node: [[theta_1], [theta_2], ...]},
2062                    "last": {node: [[theta_1], [theta_2], ...]},
2063                    "init": {node: [[theta_1], [theta_2], ...]},
2064                }
2065
2066        Notes
2067        -----
2068        - For ordinal models (``self.is_ontram == True``), `transform_intercepts_ordinal`
2069        is used.
2070        - For continuous models, `transform_intercepts_continous` is used.
2071        - Empty outer states (without any nodes) are removed from the result.
2072        """
2073
2074        EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
2075        nodes_list = list(self.models.keys())
2076        all_states = ["best", "last", "init"]
2077        all_si_intercept_dicts = {state: {} for state in all_states}
2078
2079        debug = getattr(self, "debug", False)
2080        verbose = getattr(self, "verbose", False)
2081        is_ontram = getattr(self, "is_ontram", False)
2082
2083        for node in nodes_list:
2084            NODE_DIR = os.path.join(EXPERIMENT_DIR, node)
2085            BEST_MODEL_PATH, LAST_MODEL_PATH, _, _ = model_train_val_paths(NODE_DIR)
2086            INIT_MODEL_PATH = os.path.join(NODE_DIR, "initial_model.pt")
2087
2088            state_paths = {
2089                "best": BEST_MODEL_PATH,
2090                "last": LAST_MODEL_PATH,
2091                "init": INIT_MODEL_PATH,
2092            }
2093
2094            for state, LOAD_PATH in state_paths.items():
2095                if not os.path.exists(LOAD_PATH):
2096                    if state != "init":
2097                        continue
2098                    else:
2099                        print(f"[WARNING] No models found for node '{node}'. Only initial model will be used.")
2100                        if not os.path.exists(LOAD_PATH):
2101                            if debug:
2102                                print(f"[DEBUG] Initial model also missing for node '{node}'. Skipping.")
2103                            continue
2104
2105                # Load model state
2106                state_dict = torch.load(LOAD_PATH, map_location=self.device)
2107                tram_model = self.models[node]
2108                tram_model.load_state_dict(state_dict)
2109
2110                # Extract and transform simple intercept weights
2111                si_weights = None
2112                if hasattr(tram_model, "nn_int") and tram_model.nn_int is not None and isinstance(tram_model.nn_int, SimpleIntercept):
2113                    if hasattr(tram_model.nn_int, "fc") and hasattr(tram_model.nn_int.fc, "weight"):
2114                        weights = tram_model.nn_int.fc.weight.detach().cpu().tolist()
2115                        weights_tensor = torch.Tensor(weights)
2116
2117                        if debug:
2118                            print(f"[DEBUG] Node '{node}' ({state}) theta tilde shape: {weights_tensor.shape}")
2119
2120                        if is_ontram:
2121                            si_weights = transform_intercepts_ordinal(weights_tensor.reshape(1, -1))[:, 1:-1].reshape(-1, 1)
2122                        else:
2123                            si_weights = transform_intercepts_continous(weights_tensor.reshape(1, -1)).reshape(-1, 1)
2124
2125                        si_weights = si_weights.tolist()
2126
2127                        if debug:
2128                            print(f"[DEBUG] Node '{node}' ({state}) theta transformed: {si_weights}")
2129                    else:
2130                        if debug:
2131                            print(f"[DEBUG] Node '{node}' ({state}): missing 'fc' or 'weight' in SimpleIntercept.")
2132                else:
2133                    if debug:
2134                        print(f"[DEBUG] Tram model for node '{node}' has no nn_int or it is None.")
2135
2136                all_si_intercept_dicts[state][node] = si_weights
2137
2138        # Clean up empty states
2139        all_si_intercept_dicts = {k: v for k, v in all_si_intercept_dicts.items() if v}
2140        return all_si_intercept_dicts
2141       
2142    def summary(self, verbose=False):
2143        """
2144        Print a multi-part textual summary of the TramDagModel.
2145
2146        The summary includes:
2147        1. Training metrics overview per node (best/last NLL, epochs).
2148        2. Node-specific details (thetas, linear shifts, optional architecture).
2149        3. Basic information about the attached training DataFrame, if present.
2150
2151        Parameters
2152        ----------
2153        verbose : bool, optional
2154            If True, include extended per-node details such as the model
2155            architecture, parameter count, and availability of checkpoints
2156            and sampling results. Default is False.
2157
2158        Returns
2159        -------
2160        None
2161
2162        Notes
2163        -----
2164        This method prints to stdout and does not return structured data.
2165        It is intended for quick, human-readable inspection of the current
2166        training and model state.
2167        """
2168
2169        # ---------- SETUP ----------
2170        try:
2171            EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
2172        except KeyError:
2173            EXPERIMENT_DIR = None
2174            print("[WARNING] Missing EXPERIMENT_DIR in cfg.conf_dict['PATHS'].")
2175
2176        print("\n" + "=" * 120)
2177        print(f"{'TRAM DAG MODEL SUMMARY':^120}")
2178        print("=" * 120)
2179
2180        # ---------- METRICS OVERVIEW ----------
2181        summary_data = []
2182        for node in self.models.keys():
2183            node_dir = os.path.join(self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"], node)
2184            train_path = os.path.join(node_dir, "train_loss_hist.json")
2185            val_path = os.path.join(node_dir, "val_loss_hist.json")
2186
2187            if os.path.exists(train_path) and os.path.exists(val_path):
2188                best_train_nll, best_val_nll = self.get_train_val_nll(node, "best")
2189                last_train_nll, last_val_nll = self.get_train_val_nll(node, "last")
2190                n_epochs_total = len(json.load(open(train_path)))
2191            else:
2192                best_train_nll = best_val_nll = last_train_nll = last_val_nll = None
2193                n_epochs_total = 0
2194
2195            summary_data.append({
2196                "Node": node,
2197                "Best Train NLL": best_train_nll,
2198                "Best Val NLL": best_val_nll,
2199                "Last Train NLL": last_train_nll,
2200                "Last Val NLL": last_val_nll,
2201                "Epochs": n_epochs_total,
2202            })
2203
2204        df_summary = pd.DataFrame(summary_data)
2205        df_summary = df_summary.round(4)
2206
2207        print("\n[1] TRAINING METRICS OVERVIEW")
2208        print("-" * 120)
2209        if not df_summary.empty:
2210            print(
2211                df_summary.to_string(
2212                    index=False,
2213                    justify="center",
2214                    col_space=14,
2215                    float_format=lambda x: f"{x:7.4f}",
2216                )
2217            )
2218        else:
2219            print("No training history found for any node.")
2220        print("-" * 120)
2221
2222        # ---------- NODE DETAILS ----------
2223        print("\n[2] NODE-SPECIFIC DETAILS")
2224        print("-" * 120)
2225        for node in self.models.keys():
2226            print(f"\n{f'NODE: {node}':^120}")
2227            print("-" * 120)
2228
2229            # THETAS & SHIFTS
2230            for state in ["init", "last", "best"]:
2231                print(f"\n  [{state.upper()} STATE]")
2232
2233                # ---- Thetas ----
2234                try:
2235                    thetas = getattr(self, "get_thetas", lambda n, s=None: None)(node, state)
2236                    if thetas is not None:
2237                        if isinstance(thetas, (list, np.ndarray, pd.Series)):
2238                            thetas_flat = np.array(thetas).flatten()
2239                            compact = np.round(thetas_flat, 4)
2240                            arr_str = np.array2string(
2241                                compact,
2242                                max_line_width=110,
2243                                threshold=np.inf,
2244                                separator=", "
2245                            )
2246                            lines = arr_str.split("\n")
2247                            if len(lines) > 2:
2248                                arr_str = "\n".join(lines[:2]) + " ..."
2249                            print(f"    Θ ({len(thetas_flat)}): {arr_str}")
2250                        elif isinstance(thetas, dict):
2251                            for k, v in thetas.items():
2252                                print(f"     Θ[{k}]: {v}")
2253                        else:
2254                            print(f"    Θ: {thetas}")
2255                    else:
2256                        print("    Θ: not available")
2257                except Exception as e:
2258                    print(f"    [Error loading thetas] {e}")
2259
2260                # ---- Linear Shifts ----
2261                try:
2262                    linear_shifts = getattr(self, "get_linear_shifts", lambda n, s=None: None)(node, state)
2263                    if linear_shifts is not None:
2264                        if isinstance(linear_shifts, dict):
2265                            for k, v in linear_shifts.items():
2266                                print(f"     {k}: {np.round(v, 4)}")
2267                        elif isinstance(linear_shifts, (list, np.ndarray, pd.Series)):
2268                            arr = np.round(linear_shifts, 4)
2269                            print(f"    Linear shifts ({len(arr)}): {arr}")
2270                        else:
2271                            print(f"    Linear shifts: {linear_shifts}")
2272                    else:
2273                        print("    Linear shifts: not available")
2274                except Exception as e:
2275                    print(f"    [Error loading linear shifts] {e}")
2276
2277            # ---- Verbose info directly below node ----
2278            if verbose:
2279                print("\n  [DETAILS]")
2280                node_dir = os.path.join(EXPERIMENT_DIR, node) if EXPERIMENT_DIR else None
2281                model = self.models[node]
2282
2283                print(f"    Model Architecture:")
2284                arch_str = str(model).split("\n")
2285                for line in arch_str:
2286                    print(f"      {line}")
2287                print(f"    Parameter count: {sum(p.numel() for p in model.parameters()):,}")
2288
2289                if node_dir and os.path.exists(node_dir):
2290                    ckpt_exists = any(f.endswith(('.pt', '.pth')) for f in os.listdir(node_dir))
2291                    print(f"    Checkpoints found: {ckpt_exists}")
2292
2293                    sampling_dir = os.path.join(node_dir, "sampling")
2294                    sampling_exists = os.path.isdir(sampling_dir) and len(os.listdir(sampling_dir)) > 0
2295                    print(f"    Sampling results found: {sampling_exists}")
2296
2297                    for label, filename in [("Train", "train_loss_hist.json"), ("Validation", "val_loss_hist.json")]:
2298                        path = os.path.join(node_dir, filename)
2299                        if os.path.exists(path):
2300                            try:
2301                                with open(path, "r") as f:
2302                                    hist = json.load(f)
2303                                print(f"    {label} history: {len(hist)} epochs")
2304                            except Exception as e:
2305                                print(f"    {label} history: failed to load ({e})")
2306                        else:
2307                            print(f"    {label} history: not found")
2308                else:
2309                    print("    [INFO] No experiment directory defined or missing for this node.")
2310            print("-" * 120)
2311
2312        # ---------- TRAINING DATAFRAME ----------
2313        print("\n[3] TRAINING DATAFRAME")
2314        print("-" * 120)
2315        try:
2316            self.train_df.info()
2317        except AttributeError:
2318            print("No training DataFrame attached to this TramDagModel.")
2319        print("=" * 120 + "\n")

Probabilistic DAG model built from node-wise TRAMs (transformation models).

This class manages:

  • Configuration and per-node model construction.
  • Data scaling (min–max).
  • Training (sequential or per-node parallel on CPU).
  • Diagnostics (loss history, intercepts, linear shifts, latents).
  • Sampling from the joint DAG and loading stored samples.
  • High-level summaries and plotting utilities.
TramDagModel()
111    def __init__(self):
112        """
113        Initialize an empty TramDagModel shell.
114
115        Notes
116        -----
117        This constructor does not build any node models and does not attach a
118        configuration. Use `TramDagModel.from_config` or `TramDagModel.from_directory`
119        to obtain a fully configured and ready-to-use instance.
120        """
121        
122        self.debug = False
123        self.verbose = False
124        self.device = 'auto'
125        pass

Initialize an empty TramDagModel shell.

Notes

This constructor does not build any node models and does not attach a configuration. Use TramDagModel.from_config or TramDagModel.from_directory to obtain a fully configured and ready-to-use instance.

DEFAULTS_CONFIG = {'set_initial_weights': False, 'debug': False, 'verbose': False, 'device': 'auto', 'initial_data': None, 'overwrite_initial_weights': True}
DEFAULTS_FIT = {'epochs': 100, 'train_list': None, 'callbacks': None, 'learning_rate': 0.01, 'device': 'auto', 'optimizers': None, 'schedulers': None, 'use_scheduler': False, 'save_linear_shifts': True, 'save_simple_intercepts': True, 'debug': False, 'verbose': True, 'train_mode': 'sequential', 'return_history': False, 'overwrite_inital_weights': True, 'num_workers': 4, 'persistent_workers': True, 'prefetch_factor': 4, 'batch_size': 1000}
debug
verbose
device
@staticmethod
def get_device(settings):
127    @staticmethod
128    def get_device(settings):
129        """
130        Resolve the target device string from a settings dictionary.
131
132        Parameters
133        ----------
134        settings : dict
135            Dictionary containing at least a key ``"device"`` with one of
136            {"auto", "cpu", "cuda"}. If missing, "auto" is assumed.
137
138        Returns
139        -------
140        str
141            Device string, either "cpu" or "cuda".
142
143        Notes
144        -----
145        If ``device == "auto"``, CUDA is selected if available, otherwise CPU.
146        """
147        device_arg = settings.get("device", "auto")
148        if device_arg == "auto":
149            device_str = "cuda" if torch.cuda.is_available() else "cpu"
150        else:
151            device_str = device_arg
152        return device_str

Resolve the target device string from a settings dictionary.

Parameters

settings : dict Dictionary containing at least a key "device" with one of {"auto", "cpu", "cuda"}. If missing, "auto" is assumed.

Returns

str Device string, either "cpu" or "cuda".

Notes

If device == "auto", CUDA is selected if available, otherwise CPU.

@classmethod
def from_config(cls, cfg, **kwargs):
185    @classmethod
186    def from_config(cls, cfg, **kwargs):
187        """
188        Construct a TramDagModel from a TramDagConfig object.
189
190        This builds one TRAM model per node in the DAG and optionally writes
191        the initial model parameters to disk.
192
193        Parameters
194        ----------
195        cfg : TramDagConfig
196            Configuration wrapper holding the underlying configuration dictionary,
197            including at least:
198            - ``conf_dict["nodes"]``: mapping of node names to node configs.
199            - ``conf_dict["PATHS"]["EXPERIMENT_DIR"]``: experiment directory.
200        **kwargs
201            Node-level construction options. Each key must be present in
202            ``DEFAULTS_CONFIG``. Values can be:
203            - scalar: applied to all nodes.
204            - dict: mapping ``{node_name: value}`` for per-node overrides.
205
206            Common keys include:
207            device : {"auto", "cpu", "cuda"}, default "auto"
208                Device selection (CUDA if available when "auto").
209            debug : bool, default False
210                If True, print debug messages.
211            verbose : bool, default False
212                If True, print informational messages.
213            set_initial_weights : bool
214                Passed to underlying TRAM model constructors.
215            overwrite_initial_weights : bool, default True
216                If True, overwrite any existing ``initial_model.pt`` files per node.
217            initial_data : Any
218                Optional object passed down to node constructors.
219
220        Returns
221        -------
222        TramDagModel
223            Fully initialized instance with:
224            - ``cfg``
225            - ``nodes_dict``
226            - ``models`` (per-node TRAMs)
227            - ``settings`` (resolved per-node config)
228
229        Raises
230        ------
231        ValueError
232            If any dict-valued kwarg does not provide values for exactly the set
233            of nodes in ``cfg.conf_dict["nodes"]``.
234        """
235        
236        self = cls()
237        self.cfg = cfg
238        self.cfg.update()  # ensure latest version from disk
239        self.cfg._verify_completeness()
240        
241        
242        try:
243            self.cfg.save()  # persist back to disk
244            if getattr(self, "debug", False):
245                print("[DEBUG] Configuration updated and saved.")
246        except Exception as e:
247            print(f"[WARNING] Could not save configuration after update: {e}")        
248            
249        self.nodes_dict = self.cfg.conf_dict["nodes"] 
250
251        self._validate_kwargs(kwargs, defaults_attr='DEFAULTS_CONFIG', context="from_config")
252
253        # update defaults with kwargs
254        settings = dict(cls.DEFAULTS_CONFIG)
255        settings.update(kwargs)
256
257        # resolve device
258        device_arg = settings.get("device", "auto")
259        if device_arg == "auto":
260            device_str = "cuda" if torch.cuda.is_available() else "cpu"
261        else:
262            device_str = device_arg
263        self.device = torch.device(device_str)
264
265        # set flags on the instance so they are accessible later
266        self.debug = settings.get("debug", False)
267        self.verbose = settings.get("verbose", False)
268
269        if  self.debug:
270            print(f"[DEBUG] TramDagModel using device: {self.device}")
271            
272        # initialize settings storage
273        self.settings = {k: {} for k in settings.keys()}
274
275        # validate dict-typed args
276        for k, v in settings.items():
277            if isinstance(v, dict):
278                expected = set(self.nodes_dict.keys())
279                given = set(v.keys())
280                if expected != given:
281                    raise ValueError(
282                        f"[ERROR] the provided argument '{k}' keys are not same as in cfg.conf_dict['nodes'].keys().\n"
283                        f"Expected: {expected}, but got: {given}\n"
284                        f"Please provide values for all variables.")
285
286        # build one model per node
287        self.models = {}
288        for node in self.nodes_dict.keys():
289            per_node_kwargs = {}
290            for k, v in settings.items():
291                resolved = v[node] if isinstance(v, dict) else v
292                per_node_kwargs[k] = resolved
293                self.settings[k][node] = resolved
294            if self.debug:
295                print(f"\n[INFO] Building model for node '{node}' with settings: {per_node_kwargs}")
296            self.models[node] = get_fully_specified_tram_model(
297                node=node,
298                configuration_dict=self.cfg.conf_dict,
299                **per_node_kwargs)
300            
301            try:
302                EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
303                NODE_DIR = os.path.join(EXPERIMENT_DIR, f"{node}")
304                os.makedirs(NODE_DIR, exist_ok=True)
305
306                model_path = os.path.join(NODE_DIR, "initial_model.pt")
307                overwrite = settings.get("overwrite_initial_weights", True)
308
309                if overwrite or not os.path.exists(model_path):
310                    torch.save(self.models[node].state_dict(), model_path)
311                    if self.debug:
312                        print(f"[DEBUG] Saved initial model state for node '{node}' to {model_path} (overwrite={overwrite})")
313                else:
314                    if self.debug:
315                        print(f"[DEBUG] Skipped saving initial model for node '{node}' (already exists at {model_path})")
316            except Exception as e:
317                print(f"[ERROR] Could not save initial model state for node '{node}': {e}")
318            
319            TEMP_DIR = "temp"
320            if os.path.isdir(TEMP_DIR) and not os.listdir(TEMP_DIR):
321                os.rmdir(TEMP_DIR)
322                            
323        return self

Construct a TramDagModel from a TramDagConfig object.

This builds one TRAM model per node in the DAG and optionally writes the initial model parameters to disk.

Parameters

cfg : TramDagConfig Configuration wrapper holding the underlying configuration dictionary, including at least: - conf_dict["nodes"]: mapping of node names to node configs. - conf_dict["PATHS"]["EXPERIMENT_DIR"]: experiment directory. **kwargs Node-level construction options. Each key must be present in DEFAULTS_CONFIG. Values can be: - scalar: applied to all nodes. - dict: mapping {node_name: value} for per-node overrides.

Common keys include:
device : {"auto", "cpu", "cuda"}, default "auto"
    Device selection (CUDA if available when "auto").
debug : bool, default False
    If True, print debug messages.
verbose : bool, default False
    If True, print informational messages.
set_initial_weights : bool
    Passed to underlying TRAM model constructors.
overwrite_initial_weights : bool, default True
    If True, overwrite any existing ``initial_model.pt`` files per node.
initial_data : Any
    Optional object passed down to node constructors.

Returns

TramDagModel Fully initialized instance with: - cfg - nodes_dict - models (per-node TRAMs) - settings (resolved per-node config)

Raises

ValueError If any dict-valued kwarg does not provide values for exactly the set of nodes in cfg.conf_dict["nodes"].

@classmethod
def from_directory( cls, EXPERIMENT_DIR: str, device: str = 'auto', debug: bool = False, verbose: bool = False):
325    @classmethod
326    def from_directory(cls, EXPERIMENT_DIR: str, device: str = "auto", debug: bool = False, verbose: bool = False):
327        """
328        Reconstruct a TramDagModel from an experiment directory on disk.
329
330        This method:
331        1. Loads the configuration JSON.
332        2. Wraps it in a TramDagConfig.
333        3. Builds all node models via `from_config`.
334        4. Loads the min–max scaling dictionary.
335
336        Parameters
337        ----------
338        EXPERIMENT_DIR : str
339            Path to an experiment directory containing:
340            - ``configuration.json``
341            - ``min_max_scaling.json``.
342        device : {"auto", "cpu", "cuda"}, optional
343            Device selection. Default is "auto".
344        debug : bool, optional
345            If True, enable debug messages. Default is False.
346        verbose : bool, optional
347            If True, enable informational messages. Default is False.
348
349        Returns
350        -------
351        TramDagModel
352            A TramDagModel instance with models, config, and scaling loaded.
353
354        Raises
355        ------
356        FileNotFoundError
357            If configuration or min–max files cannot be found.
358        RuntimeError
359            If the min–max file cannot be read or parsed.
360        """
361
362        # --- load config file ---
363        config_path = os.path.join(EXPERIMENT_DIR, "configuration.json")
364        if not os.path.exists(config_path):
365            raise FileNotFoundError(f"[ERROR] Config file not found at {config_path}")
366
367        with open(config_path, "r") as f:
368            cfg_dict = json.load(f)
369
370        # Create TramConfig wrapper 
371        cfg = TramDagConfig(cfg_dict, CONF_DICT_PATH=config_path)
372
373        # --- build model from config ---
374        self = cls.from_config(cfg, device=device, debug=debug, verbose=verbose, overwrite_initial_weights=False)
375
376        # --- load minmax scaling ---
377        minmax_path = os.path.join(EXPERIMENT_DIR, "min_max_scaling.json")
378        if not os.path.exists(minmax_path):
379            raise FileNotFoundError(f"[ERROR] MinMax file not found at {minmax_path}")
380
381        with open(minmax_path, "r") as f:
382            self.minmax_dict = json.load(f)
383
384        if self.verbose or self.debug:
385            print(f"[INFO] Loaded TramDagModel from {EXPERIMENT_DIR}")
386            print(f"[INFO] Config loaded from {config_path}")
387            print(f"[INFO] MinMax scaling loaded from {minmax_path}")
388
389        return self

Reconstruct a TramDagModel from an experiment directory on disk.

This method:

  1. Loads the configuration JSON.
  2. Wraps it in a TramDagConfig.
  3. Builds all node models via from_config.
  4. Loads the min–max scaling dictionary.

Parameters

EXPERIMENT_DIR : str Path to an experiment directory containing: - configuration.json - min_max_scaling.json. device : {"auto", "cpu", "cuda"}, optional Device selection. Default is "auto". debug : bool, optional If True, enable debug messages. Default is False. verbose : bool, optional If True, enable informational messages. Default is False.

Returns

TramDagModel A TramDagModel instance with models, config, and scaling loaded.

Raises

FileNotFoundError If configuration or min–max files cannot be found. RuntimeError If the min–max file cannot be read or parsed.

def load_or_compute_minmax(self, td_train_data=None, use_existing=False, write=True):
429    def load_or_compute_minmax(self, td_train_data=None,use_existing=False, write=True):
430        """
431        Load an existing Min–Max scaling dictionary from disk or compute a new one 
432        from the provided training dataset.
433
434        Parameters
435        ----------
436        use_existing : bool, optional (default=False)
437            If True, attempts to load an existing `min_max_scaling.json` file 
438            from the experiment directory. Raises an error if the file is missing 
439            or unreadable.
440
441        write : bool, optional (default=True)
442            If True, writes the computed Min–Max scaling dictionary to 
443            `<EXPERIMENT_DIR>/min_max_scaling.json`.
444
445        td_train_data : object, optional
446            Training dataset used to compute scaling statistics. If not provided,
447            the method will ensure or construct it via `_ensure_dataset(data=..., is_val=False)`.
448
449        Behavior
450        --------
451        - If `use_existing=True`, loads the JSON file containing previously saved 
452          min–max values and stores it in `self.minmax_dict`.
453        - If `use_existing=False`, computes a new scaling dictionary using 
454          `td_train_data.compute_scaling()` and stores the result in 
455          `self.minmax_dict`.
456        - Optionally writes the computed dictionary to disk.
457
458        Side Effects
459        -------------
460        - Populates `self.minmax_dict` with scaling values.
461        - Writes or loads the file `min_max_scaling.json` under 
462          `<EXPERIMENT_DIR>`.
463        - Prints diagnostic output if `self.debug` or `self.verbose` is True.
464
465        Raises
466        ------
467        FileNotFoundError
468            If `use_existing=True` but the min–max file does not exist.
469
470        RuntimeError
471            If an existing min–max file cannot be read or parsed.
472
473        Notes
474        -----
475        The computed min–max dictionary is expected to contain scaling statistics 
476        per feature, typically in the form:
477            {
478                "node": {"min": float, "max": float},
479                ...
480            }
481        """
482        EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
483        minmax_path = os.path.join(EXPERIMENT_DIR, "min_max_scaling.json")
484
485        # laod exisitng if possible
486        if use_existing:
487            if not os.path.exists(minmax_path):
488                raise FileNotFoundError(f"MinMax file not found: {minmax_path}")
489            try:
490                with open(minmax_path, 'r') as f:
491                    self.minmax_dict = json.load(f)
492                if self.debug or self.verbose:
493                    print(f"[INFO] Loaded existing minmax dict from {minmax_path}")
494                return
495            except Exception as e:
496                raise RuntimeError(f"Could not load existing minmax dict: {e}")
497
498        # 
499        if self.debug or self.verbose:
500            print("[INFO] Computing new minmax dict from training data...")
501            
502        td_train_data=self._ensure_dataset( data=td_train_data, is_val=False)    
503            
504        self.minmax_dict = td_train_data.compute_scaling()
505
506        if write:
507            os.makedirs(EXPERIMENT_DIR, exist_ok=True)
508            with open(minmax_path, 'w') as f:
509                json.dump(self.minmax_dict, f, indent=4)
510            if self.debug or self.verbose:
511                print(f"[INFO] Saved new minmax dict to {minmax_path}")

Load an existing Min–Max scaling dictionary from disk or compute a new one from the provided training dataset.

Parameters

use_existing : bool, optional (default=False) If True, attempts to load an existing min_max_scaling.json file from the experiment directory. Raises an error if the file is missing or unreadable.

write : bool, optional (default=True) If True, writes the computed Min–Max scaling dictionary to <EXPERIMENT_DIR>/min_max_scaling.json.

td_train_data : object, optional Training dataset used to compute scaling statistics. If not provided, the method will ensure or construct it via _ensure_dataset(data=..., is_val=False).

Behavior

  • If use_existing=True, loads the JSON file containing previously saved min–max values and stores it in self.minmax_dict.
  • If use_existing=False, computes a new scaling dictionary using td_train_data.compute_scaling() and stores the result in self.minmax_dict.
  • Optionally writes the computed dictionary to disk.

Side Effects

  • Populates self.minmax_dict with scaling values.
  • Writes or loads the file min_max_scaling.json under <EXPERIMENT_DIR>.
  • Prints diagnostic output if self.debug or self.verbose is True.

Raises

FileNotFoundError If use_existing=True but the min–max file does not exist.

RuntimeError If an existing min–max file cannot be read or parsed.

Notes

The computed min–max dictionary is expected to contain scaling statistics per feature, typically in the form: { "node": {"min": float, "max": float}, ... }

def fit(self, train_data, val_data=None, **kwargs):
612    def fit(self, train_data, val_data=None, **kwargs):
613        """
614        Train TRAM models for all nodes in the DAG.
615
616        Coordinates dataset preparation, min–max scaling, and per-node training,
617        optionally in parallel on CPU.
618
619        Parameters
620        ----------
621        train_data : pandas.DataFrame or TramDagDataset
622            Training data. If a DataFrame is given, it is converted into a
623            TramDagDataset using `_ensure_dataset`.
624        val_data : pandas.DataFrame or TramDagDataset or None, optional
625            Validation data. If a DataFrame is given, it is converted into a
626            TramDagDataset. If None, no validation loss is computed.
627        **kwargs
628            Overrides for ``DEFAULTS_FIT``. All keys must exist in
629            ``DEFAULTS_FIT``. Common options:
630
631            epochs : int, default 100
632                Number of training epochs per node.
633            learning_rate : float, default 0.01
634                Learning rate for the default Adam optimizer.
635            train_list : list of str or None, optional
636                List of node names to train. If None, all nodes are trained.
637            train_mode : {"sequential", "parallel"}, default "sequential"
638                Training mode. "parallel" uses joblib-based CPU multiprocessing.
639                GPU forces sequential mode.
640            device : {"auto", "cpu", "cuda"}, default "auto"
641                Device selection.
642            optimizers : dict or None
643                Optional mapping ``{node_name: optimizer}``. If provided for a
644                node, that optimizer is used instead of creating a new Adam.
645            schedulers : dict or None
646                Optional mapping ``{node_name: scheduler}``.
647            use_scheduler : bool
648                If True, enable scheduler usage in the training loop.
649            num_workers : int
650                DataLoader workers in sequential mode (ignored in parallel).
651            persistent_workers : bool
652                DataLoader persistence in sequential mode (ignored in parallel).
653            prefetch_factor : int
654                DataLoader prefetch factor (ignored in parallel).
655            batch_size : int
656                Batch size for all node DataLoaders.
657            debug : bool
658                Enable debug output.
659            verbose : bool
660                Enable informational logging.
661            return_history : bool
662                If True, return a history dict.
663
664        Returns
665        -------
666        dict or None
667            If ``return_history=True``, a dictionary mapping each node name
668            to its training history. Otherwise, returns None.
669
670        Raises
671        ------
672        ValueError
673            If ``train_mode`` is not "sequential" or "parallel".
674        """
675        self._validate_kwargs(kwargs, defaults_attr='DEFAULTS_FIT', context="fit")
676        
677        # --- merge defaults ---
678        settings = dict(self.DEFAULTS_FIT)
679        settings.update(kwargs)
680        
681        
682        self.debug = settings.get("debug", False)
683        self.verbose = settings.get("verbose", False)
684
685        # --- resolve device ---
686        device_str=self.get_device(settings)
687        self.device = torch.device(device_str)
688
689        # --- training mode ---
690        train_mode = settings.get("train_mode", "sequential").lower()
691        if train_mode not in ("sequential", "parallel"):
692            raise ValueError("train_mode must be 'sequential' or 'parallel'")
693
694        # --- DataLoader safety logic ---
695        if train_mode == "parallel":
696            # if user passed loader paralleling params, warn and override
697            for flag in ("num_workers", "persistent_workers", "prefetch_factor"):
698                if flag in kwargs:
699                    print(f"[WARNING] '{flag}' is ignored in parallel mode "
700                        f"(disabled to prevent nested multiprocessing).")
701            # disable unsafe loader multiprocessing options
702            settings["num_workers"] = 0
703            settings["persistent_workers"] = False
704            settings["prefetch_factor"] = None
705        else:
706            # sequential mode → respect user DataLoader settings
707            if self.debug:
708                print("[DEBUG] Sequential mode: using DataLoader kwargs as provided.")
709
710        # --- which nodes to train ---
711        train_list = settings.get("train_list") or list(self.models.keys())
712
713
714        # --- dataset prep (receives adjusted settings) ---
715        td_train_data = self._ensure_dataset(train_data, is_val=False, **settings)
716        td_val_data = self._ensure_dataset(val_data, is_val=True, **settings)
717
718        # --- normalization ---
719        self.load_or_compute_minmax(use_existing=False, write=True, td_train_data=td_train_data)
720
721        # --- print header ---
722        if self.verbose or self.debug:
723            print(f"[INFO] Training {len(train_list)} nodes ({train_mode}) on {device_str}")
724
725        # ======================================================================
726        # Sequential mode  safe for GPU or debugging)
727        # ======================================================================
728        if train_mode == "sequential" or "cuda" in device_str:
729            if "cuda" in device_str and train_mode == "parallel":
730                print("[WARNING] GPU device detected — forcing sequential mode.")
731            results = {}
732            for node in train_list:
733                node, history = self._fit_single_node(
734                    node, self, settings, td_train_data, td_val_data, device_str
735                )
736                results[node] = history
737        
738
739        # ======================================================================
740        # parallel mode (CPU only)
741        # ======================================================================
742        if train_mode == "parallel":
743
744            n_jobs = min(len(train_list), os.cpu_count() // 2 or 1)
745            if self.verbose or self.debug:
746                print(f"[INFO] Using {n_jobs} CPU workers for parallel node training")
747            parallel_outputs = Parallel(
748                n_jobs=n_jobs,
749                backend="loky",#loky, multiprocessing
750                verbose=10,
751                prefer="processes"
752            )(delayed(self._fit_single_node)(node, self, settings, td_train_data, td_val_data, device_str) for node in train_list )
753
754            results = {node: hist for node, hist in parallel_outputs}
755        
756        if settings.get("return_history", False):
757            return results

Train TRAM models for all nodes in the DAG.

Coordinates dataset preparation, min–max scaling, and per-node training, optionally in parallel on CPU.

Parameters

train_data : pandas.DataFrame or TramDagDataset Training data. If a DataFrame is given, it is converted into a TramDagDataset using _ensure_dataset. val_data : pandas.DataFrame or TramDagDataset or None, optional Validation data. If a DataFrame is given, it is converted into a TramDagDataset. If None, no validation loss is computed. **kwargs Overrides for DEFAULTS_FIT. All keys must exist in DEFAULTS_FIT. Common options:

epochs : int, default 100
    Number of training epochs per node.
learning_rate : float, default 0.01
    Learning rate for the default Adam optimizer.
train_list : list of str or None, optional
    List of node names to train. If None, all nodes are trained.
train_mode : {"sequential", "parallel"}, default "sequential"
    Training mode. "parallel" uses joblib-based CPU multiprocessing.
    GPU forces sequential mode.
device : {"auto", "cpu", "cuda"}, default "auto"
    Device selection.
optimizers : dict or None
    Optional mapping ``{node_name: optimizer}``. If provided for a
    node, that optimizer is used instead of creating a new Adam.
schedulers : dict or None
    Optional mapping ``{node_name: scheduler}``.
use_scheduler : bool
    If True, enable scheduler usage in the training loop.
num_workers : int
    DataLoader workers in sequential mode (ignored in parallel).
persistent_workers : bool
    DataLoader persistence in sequential mode (ignored in parallel).
prefetch_factor : int
    DataLoader prefetch factor (ignored in parallel).
batch_size : int
    Batch size for all node DataLoaders.
debug : bool
    Enable debug output.
verbose : bool
    Enable informational logging.
return_history : bool
    If True, return a history dict.

Returns

dict or None If return_history=True, a dictionary mapping each node name to its training history. Otherwise, returns None.

Raises

ValueError If train_mode is not "sequential" or "parallel".

def loss_history(self):
def loss_history(self):
    """
    Load training and validation loss history for all nodes.

    Looks for per-node JSON files:

    - ``EXPERIMENT_DIR/{node}/train_loss_hist.json``
    - ``EXPERIMENT_DIR/{node}/val_loss_hist.json``

    Returns
    -------
    dict
        A dictionary mapping node names to
        ``{"train": list or None, "validation": list or None}``, where each
        list contains NLL values per epoch, or None if the file is missing
        or unreadable.

    Raises
    ------
    ValueError
        If the experiment directory cannot be resolved from the configuration.
    """
    try:
        EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
    except KeyError:
        raise ValueError(
            "[ERROR] Missing 'EXPERIMENT_DIR' in cfg.conf_dict['PATHS']. "
            "History retrieval requires experiment logs."
        )

    def _load_json_or_none(path):
        # Shared loader: the original duplicated this try/except for the
        # train and validation branches. Missing or unreadable files map
        # to None so callers can distinguish "no history" from data.
        if not os.path.exists(path):
            return None
        try:
            with open(path, "r") as f:
                return json.load(f)
        except Exception as e:
            print(f"[WARNING] Could not load {path}: {e}")
            return None

    all_histories = {}
    for node in self.nodes_dict.keys():
        node_dir = os.path.join(EXPERIMENT_DIR, node)
        all_histories[node] = {
            "train": _load_json_or_none(os.path.join(node_dir, "train_loss_hist.json")),
            "validation": _load_json_or_none(os.path.join(node_dir, "val_loss_hist.json")),
        }

    if self.verbose or self.debug:
        print(f"[INFO] Loaded training/validation histories for {len(all_histories)} nodes.")

    return all_histories

Load training and validation loss history for all nodes.

Looks for per-node JSON files:

  • EXPERIMENT_DIR/{node}/train_loss_hist.json
  • EXPERIMENT_DIR/{node}/val_loss_hist.json

Returns

dict A dictionary mapping node names to:

```python

{"train": list or None, "validation": list or None}
```

where each list contains NLL values per epoch, or None if not found.

Raises

ValueError If the experiment directory cannot be resolved from the configuration.

def linear_shift_history(self):
def linear_shift_history(self):
    """
    Load linear shift term histories for all nodes.

    Each node's history is read from ``linear_shifts_all_epochs.json``
    inside that node's directory under the experiment folder.

    Returns
    -------
    dict
        A mapping ``{node_name: pandas.DataFrame}``, where each DataFrame
        contains linear shift weights across epochs.

    Raises
    ------
    ValueError
        If the experiment directory cannot be resolved from the configuration.

    Notes
    -----
    Nodes whose history file is absent are skipped with a printed warning.
    """
    try:
        experiment_dir = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
    except KeyError:
        raise ValueError(
            "[ERROR] Missing 'EXPERIMENT_DIR' in cfg.conf_dict['PATHS']. "
            "Cannot load histories without experiment directory."
        )

    collected = {}
    for node_name in self.nodes_dict:
        history_path = os.path.join(experiment_dir, node_name, "linear_shifts_all_epochs.json")
        if not os.path.exists(history_path):
            print(f"[WARNING] No linear shift history found for node '{node_name}' at {history_path}")
            continue
        collected[node_name] = pd.read_json(history_path)
    return collected

Load linear shift term histories for all nodes.

Each node history is expected in a JSON file named linear_shifts_all_epochs.json under the node directory.

Returns

dict A mapping {node_name: pandas.DataFrame}, where each DataFrame contains linear shift weights across epochs.

Raises

ValueError If the experiment directory cannot be resolved from the configuration.

Notes

If a history file is missing for a node, a warning is printed and the node is omitted from the returned dictionary.

def simple_intercept_history(self):
def simple_intercept_history(self):
    """
    Load simple intercept histories for all nodes.

    Each node's history is read from ``simple_intercepts_all_epochs.json``
    inside that node's directory under the experiment folder.

    Returns
    -------
    dict
        A mapping ``{node_name: pandas.DataFrame}``, where each DataFrame
        contains intercept weights across epochs.

    Raises
    ------
    ValueError
        If the experiment directory cannot be resolved from the configuration.

    Notes
    -----
    Nodes whose history file is absent are skipped with a printed warning.
    """
    try:
        experiment_dir = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
    except KeyError:
        raise ValueError(
            "[ERROR] Missing 'EXPERIMENT_DIR' in cfg.conf_dict['PATHS']. "
            "Cannot load histories without experiment directory."
        )

    collected = {}
    for node_name in self.nodes_dict:
        history_path = os.path.join(experiment_dir, node_name, "simple_intercepts_all_epochs.json")
        if not os.path.exists(history_path):
            print(f"[WARNING] No simple intercept history found for node '{node_name}' at {history_path}")
            continue
        collected[node_name] = pd.read_json(history_path)
    return collected

Load simple intercept histories for all nodes.

Each node history is expected in a JSON file named simple_intercepts_all_epochs.json under the node directory.

Returns

dict A mapping {node_name: pandas.DataFrame}, where each DataFrame contains intercept weights across epochs.

Raises

ValueError If the experiment directory cannot be resolved from the configuration.

Notes

If a history file is missing for a node, a warning is printed and the node is omitted from the returned dictionary.

def get_latent(self, df, verbose=False):
def get_latent(self, df, verbose=False):
    """
    Compute latent representations for all nodes in the DAG.

    Parameters
    ----------
    df : pandas.DataFrame
        Input data frame with columns corresponding to nodes in the DAG.
    verbose : bool, optional
        If True, print informational messages during latent computation.
        Default is False.

    Returns
    -------
    pandas.DataFrame
        The original columns plus latent variables per node
        (e.g. columns named ``f"{node}_U"``).

    Raises
    ------
    ValueError
        If the experiment directory is missing from the configuration or
        if ``self.minmax_dict`` has not been set.
    """
    try:
        experiment_dir = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
    except KeyError:
        raise ValueError(
            "[ERROR] Missing 'EXPERIMENT_DIR' in cfg.conf_dict['PATHS']. "
            "Latent extraction requires trained model checkpoints."
        )

    # Scaling info must exist before latents can be computed.
    if not hasattr(self, "minmax_dict"):
        raise ValueError(
            "[ERROR] minmax_dict not found in the TramDagModel instance. "
            "Either call .load_or_compute_minmax(td_train_data=train_df) or .fit() first."
        )

    return create_latent_df_for_full_dag(
        configuration_dict=self.cfg.conf_dict,
        EXPERIMENT_DIR=experiment_dir,
        df=df,
        verbose=verbose,
        min_max_dict=self.minmax_dict,
    )

Compute latent representations for all nodes in the DAG.

Parameters

df : pandas.DataFrame Input data frame with columns corresponding to nodes in the DAG. verbose : bool, optional If True, print informational messages during latent computation. Default is False.

Returns

pandas.DataFrame DataFrame containing the original columns plus latent variables for each node (e.g. columns named f"{node}_U").

Raises

ValueError If the experiment directory is missing from the configuration or if self.minmax_dict has not been set.

def plot_loss_history(self, variable: str = None):
 967    def plot_loss_history(self, variable: str = None):
 968        """
 969        Plot training and validation loss evolution per node.
 970
 971        Parameters
 972        ----------
 973        variable : str or None, optional
 974            If provided, plot loss history for this node only. If None, plot
 975            histories for all nodes that have both train and validation logs.
 976
 977        Returns
 978        -------
 979        None
 980
 981        Notes
 982        -----
 983        Two subplots are produced:
 984        - Full epoch history.
 985        - Last 10% of epochs (or only the last epoch if fewer than 5 epochs).
 986        """
 987
 988        histories = self.loss_history()
 989        if not histories:
 990            print("[WARNING] No loss histories found.")
 991            return
 992
 993        # Select which nodes to plot
 994        if variable is not None:
 995            if variable not in histories:
 996                raise ValueError(f"[ERROR] Node '{variable}' not found in histories.")
 997            nodes_to_plot = [variable]
 998        else:
 999            nodes_to_plot = list(histories.keys())
1000
1001        # Filter out nodes with no valid history
1002        nodes_to_plot = [
1003            n for n in nodes_to_plot
1004            if histories[n].get("train") is not None and len(histories[n]["train"]) > 0
1005            and histories[n].get("validation") is not None and len(histories[n]["validation"]) > 0
1006        ]
1007
1008        if not nodes_to_plot:
1009            print("[WARNING] No valid histories found to plot.")
1010            return
1011
1012        plt.figure(figsize=(14, 12))
1013
1014        # --- Full history (top plot) ---
1015        plt.subplot(2, 1, 1)
1016        for node in nodes_to_plot:
1017            node_hist = histories[node]
1018            train_hist, val_hist = node_hist["train"], node_hist["validation"]
1019
1020            epochs = range(1, len(train_hist) + 1)
1021            plt.plot(epochs, train_hist, label=f"{node} - train", linestyle="--")
1022            plt.plot(epochs, val_hist, label=f"{node} - val")
1023
1024        plt.title("Training and Validation NLL - Full History")
1025        plt.xlabel("Epoch")
1026        plt.ylabel("NLL")
1027        plt.legend()
1028        plt.grid(True)
1029
1030        # --- Last 10% of epochs (bottom plot) ---
1031        plt.subplot(2, 1, 2)
1032        for node in nodes_to_plot:
1033            node_hist = histories[node]
1034            train_hist, val_hist = node_hist["train"], node_hist["validation"]
1035
1036            total_epochs = len(train_hist)
1037            start_idx = total_epochs - 1 if total_epochs < 5 else int(total_epochs * 0.9)
1038
1039            epochs = range(start_idx + 1, total_epochs + 1)
1040            plt.plot(epochs, train_hist[start_idx:], label=f"{node} - train", linestyle="--")
1041            plt.plot(epochs, val_hist[start_idx:], label=f"{node} - val")
1042
1043        plt.title("Training and Validation NLL - Last 10% of Epochs (or Last Epoch if <5)")
1044        plt.xlabel("Epoch")
1045        plt.ylabel("NLL")
1046        plt.legend()
1047        plt.grid(True)
1048
1049        plt.tight_layout()
1050        plt.show()

Plot training and validation loss evolution per node.

Parameters

variable : str or None, optional If provided, plot loss history for this node only. If None, plot histories for all nodes that have both train and validation logs.

Returns

None

Notes

Two subplots are produced:

  • Full epoch history.
  • Last 10% of epochs (or only the last epoch if fewer than 5 epochs).
def plot_linear_shift_history(self, data_dict=None, node=None, ref_lines=None):
def plot_linear_shift_history(self, data_dict=None, node=None, ref_lines=None):
    """
    Plot the evolution of linear shift terms over epochs.

    Parameters
    ----------
    data_dict : dict or None, optional
        Pre-loaded mapping ``{node_name: pandas.DataFrame}`` containing shift
        weights across epochs. If None, `linear_shift_history()` is called.
    node : str or None, optional
        If provided, plot only this node. Otherwise, plot all nodes
        present in ``data_dict``.
    ref_lines : dict or None, optional
        Optional mapping ``{node_name: list of float}``. For each specified
        node, horizontal reference lines are drawn at the given values.

    Returns
    -------
    None

    Notes
    -----
    The function flattens nested list-like entries in the DataFrames to scalars,
    converts epoch labels to numeric, and then draws one line per shift term.
    """
    if data_dict is None:
        data_dict = self.linear_shift_history()
        if data_dict is None:
            raise ValueError("No shift history data provided or stored in the class.")

    nodes = [node] if node else list(data_dict.keys())

    for n in nodes:
        df = data_dict[n].copy()

        # Flatten nested lists or list-like cells to a single scalar.
        def flatten(x):
            if isinstance(x, list):
                if len(x) == 0:
                    return np.nan
                if all(isinstance(i, (int, float)) for i in x):
                    return np.mean(x)  # average simple list
                if all(isinstance(i, list) for i in x):
                    # nested list -> flatten inner and average
                    flat = [v for sub in x for v in (sub if isinstance(sub, list) else [sub])]
                    return np.mean(flat) if flat else np.nan
                return x[0] if len(x) == 1 else np.nan
            return x

        # DataFrame.applymap is deprecated since pandas 2.1; applying
        # Series.map column-wise gives the same element-wise behavior
        # on all pandas versions.
        df = df.apply(lambda col: col.map(flatten))

        # Ensure numeric columns
        df = df.apply(pd.to_numeric, errors='coerce')

        # Convert epoch labels ("epoch_3") to numeric and sort chronologically.
        df.columns = [
            int(c.replace("epoch_", "")) if isinstance(c, str) and c.startswith("epoch_") else c
            for c in df.columns
        ]
        df = df.reindex(sorted(df.columns), axis=1)

        plt.figure(figsize=(10, 6))
        for idx in df.index:
            plt.plot(df.columns, df.loc[idx], lw=1.4, label=f"shift_{idx}")

        if ref_lines and n in ref_lines:
            for v in ref_lines[n]:
                plt.axhline(y=v, color="k", linestyle="--", lw=1.0)
                plt.text(df.columns[-1], v, f"{n}: {v}", va="bottom", ha="right", fontsize=8)

        plt.xlabel("Epoch")
        plt.ylabel("Shift Value")
        plt.title(f"Shift Term History — Node: {n}")
        plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left", fontsize="small")
        plt.tight_layout()
        plt.show()

Plot the evolution of linear shift terms over epochs.

Parameters

data_dict : dict or None, optional Pre-loaded mapping {node_name: pandas.DataFrame} containing shift weights across epochs. If None, linear_shift_history() is called. node : str or None, optional If provided, plot only this node. Otherwise, plot all nodes present in data_dict. ref_lines : dict or None, optional Optional mapping {node_name: list of float}. For each specified node, horizontal reference lines are drawn at the given values.

Returns

None

Notes

The function flattens nested list-like entries in the DataFrames to scalars, converts epoch labels to numeric, and then draws one line per shift term.

def plot_simple_intercepts_history(self, data_dict=None, node=None, ref_lines=None):
def plot_simple_intercepts_history(self, data_dict=None, node=None, ref_lines=None):
    """
    Plot the evolution of simple intercept weights over epochs.

    Parameters
    ----------
    data_dict : dict or None, optional
        Pre-loaded mapping ``{node_name: pandas.DataFrame}`` containing intercept
        weights across epochs. If None, `simple_intercept_history()` is called.
    node : str or None, optional
        If provided, plot only this node. Otherwise, plot all nodes present
        in ``data_dict``.
    ref_lines : dict or None, optional
        Optional mapping ``{node_name: list of float}``. For each specified
        node, horizontal reference lines are drawn at the given values.

    Returns
    -------
    None

    Notes
    -----
    Nested list-like entries in the DataFrames are reduced to scalars before
    plotting. One line is drawn per intercept parameter.
    """
    if data_dict is None:
        data_dict = self.simple_intercept_history()
        if data_dict is None:
            raise ValueError("No intercept history data provided or stored in the class.")

    nodes = [node] if node else list(data_dict.keys())

    for n in nodes:
        df = data_dict[n].copy()

        def extract_scalar(x):
            # Drill into arbitrarily nested single-element lists; anything
            # non-numeric at the bottom becomes NaN.
            if isinstance(x, list):
                while isinstance(x, list) and len(x) > 0:
                    x = x[0]
            return float(x) if isinstance(x, (int, float, np.floating)) else np.nan

        # DataFrame.applymap is deprecated since pandas 2.1; applying
        # Series.map column-wise gives the same element-wise behavior
        # on all pandas versions.
        df = df.apply(lambda col: col.map(extract_scalar))

        # Convert epoch labels ("epoch_3") to numeric and sort chronologically.
        df.columns = [
            int(c.replace("epoch_", "")) if isinstance(c, str) and c.startswith("epoch_") else c
            for c in df.columns
        ]
        df = df.reindex(sorted(df.columns), axis=1)

        plt.figure(figsize=(10, 6))
        for idx in df.index:
            plt.plot(df.columns, df.loc[idx], lw=1.4, label=f"theta_{idx}")

        if ref_lines and n in ref_lines:
            for v in ref_lines[n]:
                plt.axhline(y=v, color="k", linestyle="--", lw=1.0)
                plt.text(df.columns[-1], v, f"{n}: {v}", va="bottom", ha="right", fontsize=8)

        plt.xlabel("Epoch")
        plt.ylabel("Intercept Weight")
        plt.title(f"Simple Intercept Evolution — Node: {n}")
        plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left", fontsize="small")
        plt.tight_layout()
        plt.show()

Plot the evolution of simple intercept weights over epochs.

Parameters

data_dict : dict or None, optional Pre-loaded mapping {node_name: pandas.DataFrame} containing intercept weights across epochs. If None, simple_intercept_history() is called. node : str or None, optional If provided, plot only this node. Otherwise, plot all nodes present in data_dict. ref_lines : dict or None, optional Optional mapping {node_name: list of float}. For each specified node, horizontal reference lines are drawn at the given values.

Returns

None

Notes

Nested list-like entries in the DataFrames are reduced to scalars before plotting. One line is drawn per intercept parameter.

def plot_latents( self, df, variable: str = None, confidence: float = 0.95, simulations: int = 1000):
def plot_latents(self, df, variable: str = None, confidence: float = 0.95, simulations: int = 1000):
    """
    Visualize latent U distributions for one or all nodes.

    Parameters
    ----------
    df : pandas.DataFrame
        Input data frame with raw node values.
    variable : str or None, optional
        If provided, only this node's latents are plotted. If None, all
        nodes with latent columns are processed.
    confidence : float, optional
        Confidence level for QQ-plot bands (0 < confidence < 1).
        Default is 0.95.
    simulations : int, optional
        Number of Monte Carlo simulations for QQ-plot bands. Default is 1000.

    Returns
    -------
    None

    Notes
    -----
    Per node, a histogram of the latent U values and a QQ-plot with
    simulation-based confidence bands (logistic reference) are drawn.
    """
    latents_df = self.get_latent(df)

    target_nodes = self.nodes_dict.keys() if variable is None else [variable]

    for node in target_nodes:
        latent_col = f"{node}_U"
        if latent_col not in latents_df.columns:
            print(f"[WARNING] No latent found for node {node}, skipping.")
            continue

        sample = latents_df[latent_col].values

        fig, axs = plt.subplots(1, 2, figsize=(12, 5))

        # Left: histogram of latent values.
        axs[0].hist(sample, bins=50, color="steelblue", alpha=0.7)
        axs[0].set_title(f"Latent Histogram ({node})")
        axs[0].set_xlabel("U")
        axs[0].set_ylabel("Frequency")

        # Right: QQ plot against a logistic reference with confidence bands.
        probplot(sample, dist="logistic", plot=axs[1])
        self._add_r_style_confidence_bands(
            axs[1], sample, dist=logistic, confidence=confidence, simulations=simulations
        )
        axs[1].set_title(f"Latent QQ Plot ({node})")

        plt.suptitle(f"Latent Diagnostics for Node: {node}", fontsize=14)
        plt.tight_layout(rect=[0, 0.03, 1, 0.95])
        plt.show()

Visualize latent U distributions for one or all nodes.

Parameters

df : pandas.DataFrame Input data frame with raw node values. variable : str or None, optional If provided, only this node's latents are plotted. If None, all nodes with latent columns are processed. confidence : float, optional Confidence level for QQ-plot bands (0 < confidence < 1). Default is 0.95. simulations : int, optional Number of Monte Carlo simulations for QQ-plot bands. Default is 1000.

Returns

None

Notes

For each node, two plots are produced:

  • Histogram of the latent U values.
  • QQ-plot with simulation-based confidence bands under a logistic reference.
def plot_hdag(self, df, variables=None, plot_n_rows=1, **kwargs):
def plot_hdag(self, df, variables=None, plot_n_rows=1, **kwargs):
    """
    Visualize the transformation function h() for selected DAG nodes.

    Parameters
    ----------
    df : pandas.DataFrame
        Input data containing node values or model predictions.
    variables : list of str or None, optional
        Names of nodes to visualize. If None, all nodes in ``self.models``
        are considered.
    plot_n_rows : int, optional
        Maximum number of rows from ``df`` to visualize. Default is 1.
    **kwargs
        Additional keyword arguments forwarded to the underlying plotting
        helpers (`show_hdag_continous` / `show_hdag_ordinal`).

    Returns
    -------
    None

    Notes
    -----
    - For continuous outcomes, `show_hdag_continous` is called.
    - For ordinal outcomes, `show_hdag_ordinal` is called.
    - Nodes that are neither continuous nor ordinal are skipped with a warning.
    """
    if len(df) > 1:
        print("[WARNING] len(df)>1, set: plot_n_rows accordingly")

    variables_list = variables if variables is not None else list(self.models.keys())
    for node in variables_list:
        if is_outcome_modelled_continous(node, self.nodes_dict):
            show_hdag_continous(
                df,
                node=node,
                configuration_dict=self.cfg.conf_dict,
                minmax_dict=self.minmax_dict,
                device=self.device,
                plot_n_rows=plot_n_rows,
                **kwargs,
            )
        elif is_outcome_modelled_ordinal(node, self.nodes_dict):
            show_hdag_ordinal(
                df,
                node=node,
                configuration_dict=self.cfg.conf_dict,
                device=self.device,
                plot_n_rows=plot_n_rows,
                **kwargs,
            )
        else:
            # Fixed garbled warning text ("is wheter ordinal nor continous").
            print(f"[WARNING] Node {node} is neither ordinal nor continuous, not implemented yet")

Visualize the transformation function h() for selected DAG nodes.

Parameters

df : pandas.DataFrame Input data containing node values or model predictions. variables : list of str or None, optional Names of nodes to visualize. If None, all nodes in self.models are considered. plot_n_rows : int, optional Maximum number of rows from df to visualize. Default is 1. **kwargs Additional keyword arguments forwarded to the underlying plotting helpers (show_hdag_continous / show_hdag_ordinal).

Returns

None

Notes

  • For continuous outcomes, show_hdag_continous is called.
  • For ordinal outcomes, show_hdag_ordinal is called.
  • Nodes that are neither continuous nor ordinal are skipped with a warning.
def sample( self, do_interventions: dict = None, predefined_latent_samples_df: pandas.core.frame.DataFrame = None, **kwargs):
def sample(
    self,
    do_interventions: dict = None,
    predefined_latent_samples_df: pd.DataFrame = None,
    **kwargs,
):
    """
    Sample from the joint DAG using the trained TRAM models.

    Supports observational sampling, interventional sampling via ``do()``
    operations, and counterfactual sampling using predefined latent draws
    combined with ``do()``.

    Parameters
    ----------
    do_interventions : dict or None, optional
        Mapping of node names to intervened (fixed) values. For example,
        ``{"x1": 1.0}`` represents ``do(x1 = 1.0)``. Default is None.
    predefined_latent_samples_df : pandas.DataFrame or None, optional
        DataFrame containing columns ``"{node}_U"`` with predefined latent
        draws to be used instead of sampling from the prior. Default is None.
    **kwargs
        Sampling options overriding internal defaults:
        ``number_of_samples`` (int, default 10000), ``batch_size`` (int,
        default 32), ``delete_all_previously_sampled`` (bool, default True),
        ``verbose`` (bool), ``debug`` (bool), ``device`` ({"auto", "cpu",
        "cuda"}), and ``use_initial_weights_for_sampling`` (bool,
        default False).

    Returns
    -------
    tuple
        ``(sampled_by_node, latents_by_node)`` — two mappings
        ``{node_name: torch.Tensor}`` of sampled node values and of the
        latent U values used.

    Raises
    ------
    ValueError
        If the experiment directory cannot be resolved.
    RuntimeError
        If min–max scaling has not been computed before calling `sample`.
    """
    try:
        experiment_dir = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
    except KeyError:
        raise ValueError(
            "[ERROR] Missing 'EXPERIMENT_DIR' in cfg.conf_dict['PATHS']. "
            "Sampling requires trained model checkpoints."
        )

    # Per-call settings: instance-level verbose/debug/device act as
    # defaults; any **kwargs entry overrides them.
    opts = {
        "number_of_samples": 10_000,
        "batch_size": 32,
        "delete_all_previously_sampled": True,
        "verbose": self.verbose if hasattr(self, "verbose") else False,
        "debug": self.debug if hasattr(self, "debug") else False,
        "device": self.device.type if hasattr(self, "device") else "auto",
        "use_initial_weights_for_sampling": False,
    }
    opts.update(kwargs)

    # Scaling info must exist before sampling can run.
    if not hasattr(self, "minmax_dict"):
        raise RuntimeError(
            "[ERROR] minmax_dict not found. You must call .fit() or .load_or_compute_minmax() "
            "before sampling, so scaling info is available."
            )

    # Resolve the torch device from the (possibly overridden) settings.
    self.device = torch.device(self.get_device(opts))

    if self.debug or opts["debug"]:
        print(f"[DEBUG] sample(): device: {self.device}")

    # Delegate the actual DAG traversal and node-wise sampling.
    sampled_by_node, latents_by_node = sample_full_dag(
        configuration_dict=self.cfg.conf_dict,
        EXPERIMENT_DIR=experiment_dir,
        device=self.device,
        do_interventions=do_interventions or {},
        predefined_latent_samples_df=predefined_latent_samples_df,
        number_of_samples=opts["number_of_samples"],
        batch_size=opts["batch_size"],
        delete_all_previously_sampled=opts["delete_all_previously_sampled"],
        verbose=opts["verbose"],
        debug=opts["debug"],
        minmax_dict=self.minmax_dict,
        use_initial_weights_for_sampling=opts["use_initial_weights_for_sampling"],
    )

    return sampled_by_node, latents_by_node

Sample from the joint DAG using the trained TRAM models.

Allows for:

Observational sampling; interventional sampling via do() operations; counterfactual sampling using predefined latent draws and do()

Parameters

do_interventions : dict or None, optional Mapping of node names to intervened (fixed) values. For example: {"x1": 1.0} represents do(x1 = 1.0). Default is None. predefined_latent_samples_df : pandas.DataFrame or None, optional DataFrame containing columns "{node}_U" with predefined latent draws to be used instead of sampling from the prior. Default is None. **kwargs Sampling options overriding internal defaults:

number_of_samples : int, default 10000
    Total number of samples to draw.
batch_size : int, default 32
    Batch size for internal sampling loops.
delete_all_previously_sampled : bool, default True
    If True, delete old sampling files in node-specific sampling
    directories before writing new ones.
verbose : bool
    If True, print informational messages.
debug : bool
    If True, print debug output.
device : {"auto", "cpu", "cuda"}
    Device selection for sampling.
use_initial_weights_for_sampling : bool, default False
    If True, sample from initial (untrained) model parameters.

Returns

tuple A tuple (sampled_by_node, latents_by_node):

sampled_by_node : dict
    Mapping ``{node_name: torch.Tensor}`` of sampled node values.
latents_by_node : dict
    Mapping ``{node_name: torch.Tensor}`` of latent U values used.

Raises

ValueError If the experiment directory cannot be resolved or if scaling information (self.minmax_dict) is missing. RuntimeError If min–max scaling has not been computed before calling sample.

def load_sampled_and_latents(self, EXPERIMENT_DIR: str = None, nodes: list = None):
1473    def load_sampled_and_latents(self, EXPERIMENT_DIR: str = None, nodes: list = None):
1474        """
1475        Load previously stored sampled values and latents for each node.
1476
1477        Parameters
1478        ----------
1479        EXPERIMENT_DIR : str or None, optional
1480            Experiment directory path. If None, it is taken from
1481            ``self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]``.
1482        nodes : list of str or None, optional
1483            Nodes for which to load samples. If None, use all nodes from
1484            ``self.nodes_dict``.
1485
1486        Returns
1487        -------
1488        tuple
1489            A tuple ``(sampled_by_node, latents_by_node)``:
1490
1491            sampled_by_node : dict
1492                Mapping ``{node_name: torch.Tensor}`` of sampled values (on CPU).
1493            latents_by_node : dict
1494                Mapping ``{node_name: torch.Tensor}`` of latent values (on CPU).
1495
1496        Raises
1497        ------
1498        ValueError
1499            If the experiment directory cannot be resolved or if no node list
1500            is available and ``nodes`` is None.
1501
1502        Notes
1503        -----
1504        Nodes without both ``sampled.pt`` and ``latents.pt`` files are skipped
1505        with a warning.
1506        """
1507        # --- resolve paths and node list ---
1508        if EXPERIMENT_DIR is None:
1509            try:
1510                EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
1511            except (AttributeError, KeyError):
1512                raise ValueError(
1513                    "[ERROR] Could not resolve EXPERIMENT_DIR from cfg.conf_dict['PATHS']. "
1514                    "Please provide EXPERIMENT_DIR explicitly."
1515                )
1516
1517        if nodes is None:
1518            if hasattr(self, "nodes_dict"):
1519                nodes = list(self.nodes_dict.keys())
1520            else:
1521                raise ValueError(
1522                    "[ERROR] No node list found. Please provide `nodes` or initialize model with a config."
1523                )
1524
1525        # --- load tensors ---
1526        sampled_by_node = {}
1527        latents_by_node = {}
1528
1529        for node in nodes:
1530            node_dir = os.path.join(EXPERIMENT_DIR, f"{node}")
1531            sampling_dir = os.path.join(node_dir, "sampling")
1532
1533            sampled_path = os.path.join(sampling_dir, "sampled.pt")
1534            latents_path = os.path.join(sampling_dir, "latents.pt")
1535
1536            if not os.path.exists(sampled_path) or not os.path.exists(latents_path):
1537                print(f"[WARNING] Missing files for node '{node}' — skipping.")
1538                continue
1539
1540            try:
1541                sampled = torch.load(sampled_path, map_location="cpu")
1542                latent_sample = torch.load(latents_path, map_location="cpu")
1543            except Exception as e:
1544                print(f"[ERROR] Could not load sampling files for node '{node}': {e}")
1545                continue
1546
1547            sampled_by_node[node] = sampled.detach().cpu()
1548            latents_by_node[node] = latent_sample.detach().cpu()
1549
1550        if self.verbose or self.debug:
1551            print(f"[INFO] Loaded sampled and latent tensors for {len(sampled_by_node)} nodes from {EXPERIMENT_DIR}")
1552
1553        return sampled_by_node, latents_by_node

Load previously stored sampled values and latents for each node.

Parameters

EXPERIMENT_DIR : str or None, optional Experiment directory path. If None, it is taken from self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]. nodes : list of str or None, optional Nodes for which to load samples. If None, use all nodes from self.nodes_dict.

Returns

tuple A tuple (sampled_by_node, latents_by_node):

sampled_by_node : dict
    Mapping ``{node_name: torch.Tensor}`` of sampled values (on CPU).
latents_by_node : dict
    Mapping ``{node_name: torch.Tensor}`` of latent values (on CPU).

Raises

ValueError If the experiment directory cannot be resolved or if no node list is available and nodes is None.

Notes

Nodes without both sampled.pt and latents.pt files are skipped with a warning.

def plot_samples_vs_true( self, df, sampled: dict = None, variable: list = None, bins: int = 100, hist_true_color: str = 'blue', hist_est_color: str = 'orange', figsize: tuple = (14, 5)):
1555    def plot_samples_vs_true(
1556        self,
1557        df,
1558        sampled: dict = None,
1559        variable: list = None,
1560        bins: int = 100,
1561        hist_true_color: str = "blue",
1562        hist_est_color: str = "orange",
1563        figsize: tuple = (14, 5),
1564    ):
1565        
1566        
1567        """
1568        Compare sampled vs. observed distributions for selected nodes.
1569
1570        Parameters
1571        ----------
1572        df : pandas.DataFrame
1573            Data frame containing the observed node values.
1574        sampled : dict or None, optional
1575            Optional mapping ``{node_name: array-like or torch.Tensor}`` of sampled
1576            values. If None or if a node is missing, samples are loaded from
1577            ``EXPERIMENT_DIR/{node}/sampling/sampled.pt``.
1578        variable : list of str or None, optional
1579            Subset of nodes to plot. If None, all nodes in the configuration
1580            are considered.
1581        bins : int, optional
1582            Number of histogram bins for continuous variables. Default is 100.
1583        hist_true_color : str, optional
1584            Color name for the histogram of true values. Default is "blue".
1585        hist_est_color : str, optional
1586            Color name for the histogram of sampled values. Default is "orange".
1587        figsize : tuple, optional
1588            Figure size for the matplotlib plots. Default is (14, 5).
1589
1590        Returns
1591        -------
1592        None
1593
1594        Notes
1595        -----
1596        - Continuous outcomes: histogram overlay + QQ-plot.
1597        - Ordinal outcomes: side-by-side bar plot of relative frequencies.
1598        - Other categorical outcomes: side-by-side bar plot with category labels.
1599        - If samples are probabilistic (2D tensor), the argmax across classes is used.
1600        """
1601        
1602        target_nodes = self.cfg.conf_dict["nodes"]
1603        experiment_dir = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
1604        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
1605
1606        plot_list = variable if variable is not None else target_nodes
1607
1608        for node in plot_list:
1609            # Load sampled data
1610            if sampled is not None and node in sampled:
1611                sdata = sampled[node]
1612                if isinstance(sdata, torch.Tensor):
1613                    sampled_vals = sdata.detach().cpu().numpy()
1614                else:
1615                    sampled_vals = np.asarray(sdata)
1616            else:
1617                sample_path = os.path.join(experiment_dir, f"{node}/sampling/sampled.pt")
1618                if not os.path.isfile(sample_path):
1619                    print(f"[WARNING] skip {node}: {sample_path} not found.")
1620                    continue
1621
1622                try:
1623                    sampled_vals = torch.load(sample_path, map_location=device).cpu().numpy()
1624                except Exception as e:
1625                    print(f"[ERROR] Could not load {sample_path}: {e}")
1626                    continue
1627
1628            # If logits/probabilities per sample, take argmax
1629            if sampled_vals.ndim == 2:
1630                    print(f"[INFO] CAUTION! {node}: samples are probabilistic — each sample follows a probability "
1631                    f"distribution based on the valid latent range. "
1632                    f"Note that this frequency plot reflects only the distribution of the most probable "
1633                    f"class per sample.")
1634                    sampled_vals = np.argmax(sampled_vals, axis=1)
1635
1636            sampled_vals = sampled_vals[np.isfinite(sampled_vals)]
1637
1638            if node not in df.columns:
1639                print(f"[WARNING] skip {node}: column not found in DataFrame.")
1640                continue
1641
1642            true_vals = df[node].dropna().values
1643            true_vals = true_vals[np.isfinite(true_vals)]
1644
1645            if sampled_vals.size == 0 or true_vals.size == 0:
1646                print(f"[WARNING] skip {node}: empty array after NaN/Inf removal.")
1647                continue
1648
1649            fig, axs = plt.subplots(1, 2, figsize=figsize)
1650
1651            if is_outcome_modelled_continous(node, target_nodes):
1652                axs[0].hist(true_vals, bins=bins, density=True, alpha=0.6,
1653                            color=hist_true_color, label=f"True {node}")
1654                axs[0].hist(sampled_vals, bins=bins, density=True, alpha=0.6,
1655                            color=hist_est_color, label="Sampled")
1656                axs[0].set_xlabel("Value")
1657                axs[0].set_ylabel("Density")
1658                axs[0].set_title(f"Histogram overlay for {node}")
1659                axs[0].legend()
1660                axs[0].grid(True, ls="--", alpha=0.4)
1661
1662                qqplot_2samples(true_vals, sampled_vals, line="45", ax=axs[1])
1663                axs[1].set_xlabel("True quantiles")
1664                axs[1].set_ylabel("Sampled quantiles")
1665                axs[1].set_title(f"QQ plot for {node}")
1666                axs[1].grid(True, ls="--", alpha=0.4)
1667
1668            elif is_outcome_modelled_ordinal(node, target_nodes):
1669                unique_vals = np.union1d(np.unique(true_vals), np.unique(sampled_vals))
1670                unique_vals = np.sort(unique_vals)
1671                true_counts = np.array([(true_vals == val).sum() for val in unique_vals])
1672                sampled_counts = np.array([(sampled_vals == val).sum() for val in unique_vals])
1673
1674                axs[0].bar(unique_vals - 0.2, true_counts / true_counts.sum(),
1675                        width=0.4, color=hist_true_color, alpha=0.7, label="True")
1676                axs[0].bar(unique_vals + 0.2, sampled_counts / sampled_counts.sum(),
1677                        width=0.4, color=hist_est_color, alpha=0.7, label="Sampled")
1678                axs[0].set_xticks(unique_vals)
1679                axs[0].set_xlabel("Ordinal Level")
1680                axs[0].set_ylabel("Relative Frequency")
1681                axs[0].set_title(f"Ordinal bar plot for {node}")
1682                axs[0].legend()
1683                axs[0].grid(True, ls="--", alpha=0.4)
1684                axs[1].axis("off")
1685
1686            else:
1687                unique_vals = np.union1d(np.unique(true_vals), np.unique(sampled_vals))
1688                unique_vals = sorted(unique_vals, key=str)
1689                true_counts = np.array([(true_vals == val).sum() for val in unique_vals])
1690                sampled_counts = np.array([(sampled_vals == val).sum() for val in unique_vals])
1691
1692                axs[0].bar(np.arange(len(unique_vals)) - 0.2, true_counts / true_counts.sum(),
1693                        width=0.4, color=hist_true_color, alpha=0.7, label="True")
1694                axs[0].bar(np.arange(len(unique_vals)) + 0.2, sampled_counts / sampled_counts.sum(),
1695                        width=0.4, color=hist_est_color, alpha=0.7, label="Sampled")
1696                axs[0].set_xticks(np.arange(len(unique_vals)))
1697                axs[0].set_xticklabels(unique_vals, rotation=45)
1698                axs[0].set_xlabel("Category")
1699                axs[0].set_ylabel("Relative Frequency")
1700                axs[0].set_title(f"Categorical bar plot for {node}")
1701                axs[0].legend()
1702                axs[0].grid(True, ls="--", alpha=0.4)
1703                axs[1].axis("off")
1704
1705            plt.tight_layout()
1706            plt.show()

Compare sampled vs. observed distributions for selected nodes.

Parameters

df : pandas.DataFrame Data frame containing the observed node values. sampled : dict or None, optional Optional mapping {node_name: array-like or torch.Tensor} of sampled values. If None or if a node is missing, samples are loaded from EXPERIMENT_DIR/{node}/sampling/sampled.pt. variable : list of str or None, optional Subset of nodes to plot. If None, all nodes in the configuration are considered. bins : int, optional Number of histogram bins for continuous variables. Default is 100. hist_true_color : str, optional Color name for the histogram of true values. Default is "blue". hist_est_color : str, optional Color name for the histogram of sampled values. Default is "orange". figsize : tuple, optional Figure size for the matplotlib plots. Default is (14, 5).

Returns

None

Notes

  • Continuous outcomes: histogram overlay + QQ-plot.
  • Ordinal outcomes: side-by-side bar plot of relative frequencies.
  • Other categorical outcomes: side-by-side bar plot with category labels.
  • If samples are probabilistic (2D tensor), the argmax across classes is used.
def nll(self, data, variables=None):
1709    def nll(self,data,variables=None):
1710        """
1711        Compute the Negative Log-Likelihood (NLL) for all or selected TRAM nodes.
1712
1713        This function evaluates trained TRAM models for each specified variable (node) 
1714        on the provided dataset. It performs forward passes only—no training, no weight 
1715        updates—and returns the mean NLL per node.
1716
1717        Parameters
1718        ----------
1719        data : object
1720            Input dataset or data source compatible with `_ensure_dataset`, containing 
1721            both inputs and targets for each node.
1722        variables : list[str], optional
1723            List of variable (node) names to evaluate. If None, all nodes in 
1724            `self.models` are evaluated.
1725
1726        Returns
1727        -------
1728        dict[str, float]
1729            Dictionary mapping each node name to its average NLL value.
1730
1731        Notes
1732        -----
1733        - Each model is evaluated independently on its respective DataLoader.
1734        - The normalization values (`min_max`) for each node are retrieved from 
1735          `self.minmax_dict[node]`.
1736        - The function uses `evaluate_tramdag_model()` for per-node evaluation.
1737        - Expected directory structure:
1738              `<EXPERIMENT_DIR>/<node>/`
1739          where each node directory contains the trained model.
1740        """
1741
1742        td_data = self._ensure_dataset(data, is_val=True)  
1743        variables_list = variables if variables != None else list(self.models.keys())
1744        nll_dict = {}
1745        for node in variables_list:  
1746                min_vals = torch.tensor(self.minmax_dict[node][0], dtype=torch.float32)
1747                max_vals = torch.tensor(self.minmax_dict[node][1], dtype=torch.float32)
1748                min_max = torch.stack([min_vals, max_vals], dim=0)
1749                data_loader = td_data.loaders[node]
1750                model = self.models[node]
1751                EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
1752                NODE_DIR = os.path.join(EXPERIMENT_DIR, f"{node}")
1753                nll= evaluate_tramdag_model(node=node,
1754                                            target_nodes=self.nodes_dict,
1755                                            NODE_DIR=NODE_DIR,
1756                                            tram_model=model,
1757                                            data_loader=data_loader,
1758                                            min_max=min_max)
1759                nll_dict[node]=nll
1760        return nll_dict

Compute the Negative Log-Likelihood (NLL) for all or selected TRAM nodes.

This function evaluates trained TRAM models for each specified variable (node) on the provided dataset. It performs forward passes only—no training, no weight updates—and returns the mean NLL per node.

Parameters

data : object Input dataset or data source compatible with _ensure_dataset, containing both inputs and targets for each node. variables : list[str], optional List of variable (node) names to evaluate. If None, all nodes in self.models are evaluated.

Returns

dict[str, float] Dictionary mapping each node name to its average NLL value.

Notes

  • Each model is evaluated independently on its respective DataLoader.
  • The normalization values (min_max) for each node are retrieved from self.minmax_dict[node].
  • The function uses evaluate_tramdag_model() for per-node evaluation.
  • Expected directory structure: <EXPERIMENT_DIR>/<node>/ where each node directory contains the trained model.
def get_train_val_nll(self, node: str, mode: str) -> tuple[float, float]:
1762    def get_train_val_nll(self, node: str, mode: str) -> tuple[float, float]:
1763        """
1764        Retrieve training and validation NLL for a node and a given model state.
1765
1766        Parameters
1767        ----------
1768        node : str
1769            Node name.
1770        mode : {"best", "last", "init"}
1771            State of interest:
1772            - "best": epoch with lowest validation NLL.
1773            - "last": final epoch.
1774            - "init": first epoch (index 0).
1775
1776        Returns
1777        -------
1778        tuple of (float or None, float or None)
1779            A tuple ``(train_nll, val_nll)`` for the requested mode.
1780            Returns ``(None, None)`` if loss files are missing or cannot be read.
1781
1782        Notes
1783        -----
1784        This method expects per-node JSON files:
1785
1786        - ``train_loss_hist.json``
1787        - ``val_loss_hist.json``
1788
1789        in the node directory.
1790        """
1791        NODE_DIR = os.path.join(self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"], node)
1792        train_path = os.path.join(NODE_DIR, "train_loss_hist.json")
1793        val_path = os.path.join(NODE_DIR, "val_loss_hist.json")
1794
1795        if not os.path.exists(train_path) or not os.path.exists(val_path):
1796            if getattr(self, "debug", False):
1797                print(f"[DEBUG] Missing loss files for node '{node}'. Returning None.")
1798            return None, None
1799
1800        try:
1801            with open(train_path, "r") as f:
1802                train_hist = json.load(f)
1803            with open(val_path, "r") as f:
1804                val_hist = json.load(f)
1805
1806            train_nlls = np.array(train_hist)
1807            val_nlls = np.array(val_hist)
1808
1809            if mode == "init":
1810                idx = 0
1811            elif mode == "last":
1812                idx = len(val_nlls) - 1
1813            elif mode == "best":
1814                idx = int(np.argmin(val_nlls))
1815            else:
1816                raise ValueError(f"Invalid mode '{mode}' — must be one of 'best', 'last', 'init'.")
1817
1818            train_nll = float(train_nlls[idx])
1819            val_nll = float(val_nlls[idx])
1820            return train_nll, val_nll
1821
1822        except Exception as e:
1823            print(f"[ERROR] Failed to load NLLs for node '{node}' ({mode}): {e}")
1824            return None, None

Retrieve training and validation NLL for a node and a given model state.

Parameters

node : str Node name. mode : {"best", "last", "init"} State of interest: - "best": epoch with lowest validation NLL. - "last": final epoch. - "init": first epoch (index 0).

Returns

tuple of (float or None, float or None) A tuple (train_nll, val_nll) for the requested mode. Returns (None, None) if loss files are missing or cannot be read.

Notes

This method expects per-node JSON files:

  • train_loss_hist.json
  • val_loss_hist.json

in the node directory.

def get_thetas(self, node: str, state: str = 'best'):
1826    def get_thetas(self, node: str, state: str = "best"):
1827        """
1828        Return transformed intercept (theta) parameters for a node and state.
1829
1830        Parameters
1831        ----------
1832        node : str
1833            Node name.
1834        state : {"best", "last", "init"}, optional
1835            Model state for which to return parameters. Default is "best".
1836
1837        Returns
1838        -------
1839        Any or None
1840            Transformed theta parameters for the requested node and state.
1841            The exact structure (scalar, list, or other) depends on the model.
1842
1843        Raises
1844        ------
1845        ValueError
1846            If an invalid state is given (not in {"best", "last", "init"}).
1847
1848        Notes
1849        -----
1850        Intercept dictionaries are cached on the instance under the attribute
1851        ``intercept_dicts``. If missing or incomplete, they are recomputed using
1852        `get_simple_intercepts_dict`.
1853        """
1854
1855        state = state.lower()
1856        if state not in ["best", "last", "init"]:
1857            raise ValueError(f"[ERROR] Invalid state '{state}'. Must be one of ['best', 'last', 'init'].")
1858
1859        dict_attr = "intercept_dicts"
1860
1861        # If no cached intercepts exist, compute them
1862        if not hasattr(self, dict_attr):
1863            if getattr(self, "debug", False):
1864                print(f"[DEBUG] '{dict_attr}' not found, computing via get_simple_intercepts_dict().")
1865            setattr(self, dict_attr, self.get_simple_intercepts_dict())
1866
1867        all_dicts = getattr(self, dict_attr)
1868
1869        # If the requested state isn’t cached, recompute
1870        if state not in all_dicts:
1871            if getattr(self, "debug", False):
1872                print(f"[DEBUG] State '{state}' not found in cached intercepts, recomputing full dict.")
1873            setattr(self, dict_attr, self.get_simple_intercepts_dict())
1874            all_dicts = getattr(self, dict_attr)
1875
1876        state_dict = all_dicts.get(state, {})
1877
1878        # Return cached node intercept if present
1879        if node in state_dict:
1880            return state_dict[node]
1881
1882        # If not found, recompute full dict as fallback
1883        if getattr(self, "debug", False):
1884            print(f"[DEBUG] Node '{node}' not found in state '{state}', recomputing full dict.")
1885        setattr(self, dict_attr, self.get_simple_intercepts_dict())
1886        all_dicts = getattr(self, dict_attr)
1887        return all_dicts.get(state, {}).get(node, None)

Return transformed intercept (theta) parameters for a node and state.

Parameters

node : str Node name. state : {"best", "last", "init"}, optional Model state for which to return parameters. Default is "best".

Returns

Any or None Transformed theta parameters for the requested node and state. The exact structure (scalar, list, or other) depends on the model.

Raises

ValueError If an invalid state is given (not in {"best", "last", "init"}).

Notes

Intercept dictionaries are cached on the instance under the attribute intercept_dicts. If missing or incomplete, they are recomputed using get_simple_intercepts_dict.

def get_linear_shifts(self, node: str, state: str = 'best'):
1889    def get_linear_shifts(self, node: str, state: str = "best"):
1890        """
1891        Return learned linear shift terms for a node and a given state.
1892
1893        Parameters
1894        ----------
1895        node : str
1896            Node name.
1897        state : {"best", "last", "init"}, optional
1898            Model state for which to return linear shift terms. Default is "best".
1899
1900        Returns
1901        -------
1902        dict or Any or None
1903            Linear shift terms for the given node and state. Usually a dict
1904            mapping term names to weights.
1905
1906        Raises
1907        ------
1908        ValueError
1909            If an invalid state is given (not in {"best", "last", "init"}).
1910
1911        Notes
1912        -----
1913        Linear shift dictionaries are cached on the instance under the attribute
1914        ``linear_shift_dicts``. If missing or incomplete, they are recomputed using
1915        `get_linear_shifts_dict`.
1916        """
1917        state = state.lower()
1918        if state not in ["best", "last", "init"]:
1919            raise ValueError(f"[ERROR] Invalid state '{state}'. Must be one of ['best', 'last', 'init'].")
1920
1921        dict_attr = "linear_shift_dicts"
1922
1923        # If no global dicts cached, compute once
1924        if not hasattr(self, dict_attr):
1925            if getattr(self, "debug", False):
1926                print(f"[DEBUG] '{dict_attr}' not found, computing via get_linear_shifts_dict().")
1927            setattr(self, dict_attr, self.get_linear_shifts_dict())
1928
1929        all_dicts = getattr(self, dict_attr)
1930
1931        # If the requested state isn't cached, compute all again (covers fresh runs)
1932        if state not in all_dicts:
1933            if getattr(self, "debug", False):
1934                print(f"[DEBUG] State '{state}' not found in cached linear shifts, recomputing full dict.")
1935            setattr(self, dict_attr, self.get_linear_shifts_dict())
1936            all_dicts = getattr(self, dict_attr)
1937
1938        # Now fetch the dictionary for this state
1939        state_dict = all_dicts.get(state, {})
1940
1941        # If the node is available, return its entry
1942        if node in state_dict:
1943            return state_dict[node]
1944
1945        # If missing, try recomputing (fallback)
1946        if getattr(self, "debug", False):
1947            print(f"[DEBUG] Node '{node}' not found in state '{state}', recomputing full dict.")
1948        setattr(self, dict_attr, self.get_linear_shifts_dict())
1949        all_dicts = getattr(self, dict_attr)
1950        return all_dicts.get(state, {}).get(node, None)

Return learned linear shift terms for a node and a given state.

Parameters

node : str Node name. state : {"best", "last", "init"}, optional Model state for which to return linear shift terms. Default is "best".

Returns

dict or Any or None Linear shift terms for the given node and state. Usually a dict mapping term names to weights.

Raises

ValueError If an invalid state is given (not in {"best", "last", "init"}).

Notes

Linear shift dictionaries are cached on the instance under the attribute linear_shift_dicts. If missing or incomplete, they are recomputed using get_linear_shifts_dict.

def get_linear_shifts_dict(self):
1952    def get_linear_shifts_dict(self):
1953        """
1954        Compute linear shift term dictionaries for all nodes and states.
1955
1956        For each node and each available state ("best", "last", "init"), this
1957        method loads the corresponding model checkpoint, extracts linear shift
1958        weights from the TRAM model, and stores them in a nested dictionary.
1959
1960        Returns
1961        -------
1962        dict
1963            Nested dictionary of the form:
1964
1965            .. code-block:: python
1966
1967                {
1968                    "best": {node: {...}},
1969                    "last": {node: {...}},
1970                    "init": {node: {...}},
1971                }
1972
1973            where the innermost dict maps term labels (e.g. ``"ls(parent_name)"``)
1974            to their weights.
1975
1976        Notes
1977        -----
1978        - If "best" or "last" checkpoints are unavailable for a node, only
1979        the "init" entry is populated.
1980        - Empty outer states (without any nodes) are removed from the result.
1981        """
1982
1983        EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
1984        nodes_list = list(self.models.keys())
1985        all_states = ["best", "last", "init"]
1986        all_linear_shift_dicts = {state: {} for state in all_states}
1987
1988        for node in nodes_list:
1989            NODE_DIR = os.path.join(EXPERIMENT_DIR, node)
1990            BEST_MODEL_PATH, LAST_MODEL_PATH, _, _ = model_train_val_paths(NODE_DIR)
1991            INIT_MODEL_PATH = os.path.join(NODE_DIR, "initial_model.pt")
1992
1993            state_paths = {
1994                "best": BEST_MODEL_PATH,
1995                "last": LAST_MODEL_PATH,
1996                "init": INIT_MODEL_PATH,
1997            }
1998
1999            for state, LOAD_PATH in state_paths.items():
2000                if not os.path.exists(LOAD_PATH):
2001                    if state != "init":
2002                        # skip best/last if unavailable
2003                        continue
2004                    else:
2005                        print(f"[WARNING] No models found for node '{node}'. Only initial model will be used.")
2006                        if not os.path.exists(LOAD_PATH):
2007                            if getattr(self, "debug", False):
2008                                print(f"[DEBUG] Initial model also missing for node '{node}'. Skipping.")
2009                            continue
2010
2011                # Load parents and model
2012                _, terms_dict, _ = ordered_parents(node, self.nodes_dict)
2013                state_dict = torch.load(LOAD_PATH, map_location=self.device)
2014                tram_model = self.models[node]
2015                tram_model.load_state_dict(state_dict)
2016
2017                epoch_weights = {}
2018                if hasattr(tram_model, "nn_shift") and tram_model.nn_shift is not None:
2019                    for i, shift_layer in enumerate(tram_model.nn_shift):
2020                        module_name = shift_layer.__class__.__name__
2021                        if (
2022                            hasattr(shift_layer, "fc")
2023                            and hasattr(shift_layer.fc, "weight")
2024                            and module_name == "LinearShift"
2025                        ):
2026                            term_name = list(terms_dict.keys())[i]
2027                            epoch_weights[f"ls({term_name})"] = (
2028                                shift_layer.fc.weight.detach().cpu().squeeze().tolist()
2029                            )
2030                        elif getattr(self, "debug", False):
2031                            term_name = list(terms_dict.keys())[i]
2032                            print(f"[DEBUG] ls({term_name}): missing 'fc' or 'weight' in LinearShift.")
2033                else:
2034                    if getattr(self, "debug", False):
2035                        print(f"[DEBUG] Tram model for node '{node}' has no nn_shift or it is None.")
2036
2037                all_linear_shift_dicts[state][node] = epoch_weights
2038
2039        # Remove empty states (e.g., when best/last not found for all nodes)
2040        all_linear_shift_dicts = {k: v for k, v in all_linear_shift_dicts.items() if v}
2041
2042        return all_linear_shift_dicts

Compute linear shift term dictionaries for all nodes and states.

For each node and each available state ("best", "last", "init"), this method loads the corresponding model checkpoint, extracts linear shift weights from the TRAM model, and stores them in a nested dictionary.

Returns

dict Nested dictionary of the form:

```python
{
    "best": {node: {...}},
    "last": {node: {...}},
    "init": {node: {...}},
}
```

where the innermost dict maps term labels (e.g. ``"ls(parent_name)"``)
to their weights.

Notes

  • If "best" or "last" checkpoints are unavailable for a node, only the "init" entry is populated.
  • Empty outer states (without any nodes) are removed from the result.
def get_simple_intercepts_dict(self):
    """
    Compute transformed simple intercept dictionaries for all nodes and states.

    For each node and each available state ("best", "last", "init"), this
    method loads the corresponding model checkpoint, extracts simple intercept
    weights, transforms them into interpretable theta parameters, and stores
    them in a nested dictionary.

    Returns
    -------
    dict
        Nested dictionary of the form::

            {
                "best": {node: [[theta_1], [theta_2], ...]},
                "last": {node: [[theta_1], [theta_2], ...]},
                "init": {node: [[theta_1], [theta_2], ...]},
            }

    Notes
    -----
    - For ordinal models (``self.is_ontram == True``), `transform_intercepts_ordinal`
      is used.
    - For continuous models, `transform_intercepts_continous` is used.
    - Empty outer states (without any nodes) are removed from the result.
    """
    EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
    nodes_list = list(self.models.keys())
    all_states = ["best", "last", "init"]
    all_si_intercept_dicts = {state: {} for state in all_states}

    debug = getattr(self, "debug", False)
    is_ontram = getattr(self, "is_ontram", False)

    for node in nodes_list:
        NODE_DIR = os.path.join(EXPERIMENT_DIR, node)
        BEST_MODEL_PATH, LAST_MODEL_PATH, _, _ = model_train_val_paths(NODE_DIR)
        INIT_MODEL_PATH = os.path.join(NODE_DIR, "initial_model.pt")

        state_paths = {
            "best": BEST_MODEL_PATH,
            "last": LAST_MODEL_PATH,
            "init": INIT_MODEL_PATH,
        }

        for state, LOAD_PATH in state_paths.items():
            if not os.path.exists(LOAD_PATH):
                # "best"/"last" checkpoints may legitimately be absent (e.g. an
                # untrained node) and are skipped silently. A missing "init"
                # checkpoint means nothing can be extracted for this node at all.
                # (The original code re-tested os.path.exists here, but that inner
                # check was always true inside this branch — removed as dead code.)
                if state == "init":
                    print(f"[WARNING] No models found for node '{node}'. Only initial model will be used.")
                    if debug:
                        print(f"[DEBUG] Initial model also missing for node '{node}'. Skipping.")
                continue

            # Load the checkpoint weights into the node's TRAM model.
            state_dict = torch.load(LOAD_PATH, map_location=self.device)
            tram_model = self.models[node]
            tram_model.load_state_dict(state_dict)

            # Extract and transform simple intercept weights.
            si_weights = None
            nn_int = getattr(tram_model, "nn_int", None)
            if nn_int is not None and isinstance(nn_int, SimpleIntercept):
                if hasattr(nn_int, "fc") and hasattr(nn_int.fc, "weight"):
                    weights = nn_int.fc.weight.detach().cpu().tolist()
                    weights_tensor = torch.Tensor(weights)

                    if debug:
                        print(f"[DEBUG] Node '{node}' ({state}) theta tilde shape: {weights_tensor.shape}")

                    if is_ontram:
                        # [:, 1:-1] drops the first and last transformed cutpoints
                        # (presumably the +/-inf boundary cutpoints — TODO confirm).
                        si_weights = transform_intercepts_ordinal(weights_tensor.reshape(1, -1))[:, 1:-1].reshape(-1, 1)
                    else:
                        si_weights = transform_intercepts_continous(weights_tensor.reshape(1, -1)).reshape(-1, 1)

                    si_weights = si_weights.tolist()

                    if debug:
                        print(f"[DEBUG] Node '{node}' ({state}) theta transformed: {si_weights}")
                elif debug:
                    print(f"[DEBUG] Node '{node}' ({state}): missing 'fc' or 'weight' in SimpleIntercept.")
            elif debug:
                print(f"[DEBUG] Tram model for node '{node}' has no nn_int or it is None.")

            all_si_intercept_dicts[state][node] = si_weights

    # Clean up empty states (e.g., when best/last were not found for any node).
    all_si_intercept_dicts = {k: v for k, v in all_si_intercept_dicts.items() if v}
    return all_si_intercept_dicts

Compute transformed simple intercept dictionaries for all nodes and states.

For each node and each available state ("best", "last", "init"), this method loads the corresponding model checkpoint, extracts simple intercept weights, transforms them into interpretable theta parameters, and stores them in a nested dictionary.

Returns

dict Nested dictionary of the form:

```python
{
    "best": {node: [[theta_1], [theta_2], ...]},
    "last": {node: [[theta_1], [theta_2], ...]},
    "init": {node: [[theta_1], [theta_2], ...]},
}
```

Notes

  • For ordinal models (self.is_ontram == True), transform_intercepts_ordinal is used.
  • For continuous models, transform_intercepts_continous is used.
  • Empty outer states (without any nodes) are removed from the result.
def summary(self, verbose=False):
    """
    Print a multi-part textual summary of the TramDagModel.

    The summary includes:
    1. Training metrics overview per node (best/last NLL, epochs).
    2. Node-specific details (thetas, linear shifts, optional architecture).
    3. Basic information about the attached training DataFrame, if present.

    Parameters
    ----------
    verbose : bool, optional
        If True, include extended per-node details such as the model
        architecture, parameter count, and availability of checkpoints
        and sampling results. Default is False.

    Returns
    -------
    None

    Notes
    -----
    This method prints to stdout and does not return structured data.
    It is intended for quick, human-readable inspection of the current
    training and model state.
    """

    # ---------- SETUP ----------
    try:
        EXPERIMENT_DIR = self.cfg.conf_dict["PATHS"]["EXPERIMENT_DIR"]
    except KeyError:
        EXPERIMENT_DIR = None
        print("[WARNING] Missing EXPERIMENT_DIR in cfg.conf_dict['PATHS'].")

    print("\n" + "=" * 120)
    print(f"{'TRAM DAG MODEL SUMMARY':^120}")
    print("=" * 120)

    # ---------- METRICS OVERVIEW ----------
    summary_data = []
    for node in self.models.keys():
        # Use the guarded EXPERIMENT_DIR from above; re-reading the config dict
        # here would raise KeyError even after the warning was printed.
        node_dir = os.path.join(EXPERIMENT_DIR, node) if EXPERIMENT_DIR else None
        train_path = os.path.join(node_dir, "train_loss_hist.json") if node_dir else None
        val_path = os.path.join(node_dir, "val_loss_hist.json") if node_dir else None

        if train_path and val_path and os.path.exists(train_path) and os.path.exists(val_path):
            best_train_nll, best_val_nll = self.get_train_val_nll(node, "best")
            last_train_nll, last_val_nll = self.get_train_val_nll(node, "last")
            # Context manager avoids leaking the file handle (was json.load(open(...))).
            with open(train_path) as f:
                n_epochs_total = len(json.load(f))
        else:
            best_train_nll = best_val_nll = last_train_nll = last_val_nll = None
            n_epochs_total = 0

        summary_data.append({
            "Node": node,
            "Best Train NLL": best_train_nll,
            "Best Val NLL": best_val_nll,
            "Last Train NLL": last_train_nll,
            "Last Val NLL": last_val_nll,
            "Epochs": n_epochs_total,
        })

    df_summary = pd.DataFrame(summary_data)
    df_summary = df_summary.round(4)

    print("\n[1] TRAINING METRICS OVERVIEW")
    print("-" * 120)
    if not df_summary.empty:
        print(
            df_summary.to_string(
                index=False,
                justify="center",
                col_space=14,
                float_format=lambda x: f"{x:7.4f}",
            )
        )
    else:
        print("No training history found for any node.")
    print("-" * 120)

    # ---------- NODE DETAILS ----------
    print("\n[2] NODE-SPECIFIC DETAILS")
    print("-" * 120)
    for node in self.models.keys():
        print(f"\n{f'NODE: {node}':^120}")
        print("-" * 120)

        # THETAS & SHIFTS for each checkpoint state.
        for state in ["init", "last", "best"]:
            print(f"\n  [{state.upper()} STATE]")

            # ---- Thetas ----
            try:
                # getattr fallback keeps the summary usable on instances that
                # do not expose get_thetas.
                thetas = getattr(self, "get_thetas", lambda n, s=None: None)(node, state)
                if thetas is not None:
                    if isinstance(thetas, (list, np.ndarray, pd.Series)):
                        thetas_flat = np.array(thetas).flatten()
                        compact = np.round(thetas_flat, 4)
                        arr_str = np.array2string(
                            compact,
                            max_line_width=110,
                            threshold=np.inf,
                            separator=", "
                        )
                        # Truncate long theta vectors to two display lines.
                        lines = arr_str.split("\n")
                        if len(lines) > 2:
                            arr_str = "\n".join(lines[:2]) + " ..."
                        print(f"    Θ ({len(thetas_flat)}): {arr_str}")
                    elif isinstance(thetas, dict):
                        for k, v in thetas.items():
                            print(f"     Θ[{k}]: {v}")
                    else:
                        print(f"    Θ: {thetas}")
                else:
                    print("    Θ: not available")
            except Exception as e:
                print(f"    [Error loading thetas] {e}")

            # ---- Linear Shifts ----
            try:
                linear_shifts = getattr(self, "get_linear_shifts", lambda n, s=None: None)(node, state)
                if linear_shifts is not None:
                    if isinstance(linear_shifts, dict):
                        for k, v in linear_shifts.items():
                            print(f"     {k}: {np.round(v, 4)}")
                    elif isinstance(linear_shifts, (list, np.ndarray, pd.Series)):
                        arr = np.round(linear_shifts, 4)
                        print(f"    Linear shifts ({len(arr)}): {arr}")
                    else:
                        print(f"    Linear shifts: {linear_shifts}")
                else:
                    print("    Linear shifts: not available")
            except Exception as e:
                print(f"    [Error loading linear shifts] {e}")

        # ---- Verbose info directly below node ----
        if verbose:
            print("\n  [DETAILS]")
            node_dir = os.path.join(EXPERIMENT_DIR, node) if EXPERIMENT_DIR else None
            model = self.models[node]

            print(f"    Model Architecture:")
            arch_str = str(model).split("\n")
            for line in arch_str:
                print(f"      {line}")
            print(f"    Parameter count: {sum(p.numel() for p in model.parameters()):,}")

            if node_dir and os.path.exists(node_dir):
                ckpt_exists = any(f.endswith(('.pt', '.pth')) for f in os.listdir(node_dir))
                print(f"    Checkpoints found: {ckpt_exists}")

                sampling_dir = os.path.join(node_dir, "sampling")
                sampling_exists = os.path.isdir(sampling_dir) and len(os.listdir(sampling_dir)) > 0
                print(f"    Sampling results found: {sampling_exists}")

                for label, filename in [("Train", "train_loss_hist.json"), ("Validation", "val_loss_hist.json")]:
                    path = os.path.join(node_dir, filename)
                    if os.path.exists(path):
                        try:
                            with open(path, "r") as f:
                                hist = json.load(f)
                            print(f"    {label} history: {len(hist)} epochs")
                        except Exception as e:
                            print(f"    {label} history: failed to load ({e})")
                    else:
                        print(f"    {label} history: not found")
            else:
                print("    [INFO] No experiment directory defined or missing for this node.")
        print("-" * 120)

    # ---------- TRAINING DATAFRAME ----------
    print("\n[3] TRAINING DATAFRAME")
    print("-" * 120)
    try:
        self.train_df.info()
    except AttributeError:
        print("No training DataFrame attached to this TramDagModel.")
    print("=" * 120 + "\n")

Print a multi-part textual summary of the TramDagModel.

The summary includes:

  1. Training metrics overview per node (best/last NLL, epochs).
  2. Node-specific details (thetas, linear shifts, optional architecture).
  3. Basic information about the attached training DataFrame, if present.

Parameters

verbose : bool, optional — If True, include extended per-node details such as the model architecture, parameter count, and availability of checkpoints and sampling results. Default is False.

Returns

None

Notes

This method prints to stdout and does not return structured data. It is intended for quick, human-readable inspection of the current training and model state.