tramdag.TramDagConfig

Copyright 2025 Zurich University of Applied Sciences (ZHAW) Pascal Buehler, Beate Sick, Oliver Duerr

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

"""
Copyright 2025 Zurich University of Applied Sciences (ZHAW)
Pascal Buehler, Beate Sick, Oliver Duerr

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import matplotlib.pyplot as plt
import numpy as np
import networkx as nx
from matplotlib.patches import Patch
import pandas as pd
import os

from .utils.configuration import *


# TODO: rename set_meta_adj_matrix
class TramDagConfig:
    """
    Configuration manager for TRAM-DAG experiments.

    This class encapsulates:

    - The experiment configuration dictionary (`conf_dict`).
    - Its backing file path (`CONF_DICT_PATH`).
    - Utilities to load, validate, modify, and persist the configuration.
    - DAG visualization and interactive editing helpers.

    Typical usage
    -------------
    - Load an existing configuration from disk via `TramDagConfig.load_json`.
    - Or create/reuse an experiment setup via `TramDagConfig().setup_configuration`.
    - Update sections such as `data_type`, the adjacency matrix, and the neural
      network model names using the provided methods.
    """

    def __init__(self, conf_dict: dict = None, CONF_DICT_PATH: str = None, _verify: bool = False, **kwargs):
        """
        Initialize a TramDagConfig instance.

        Parameters
        ----------
        conf_dict : dict or None, optional
            Configuration dictionary. If None, an empty dict is used and can
            be populated later. Default is None.
        CONF_DICT_PATH : str or None, optional
            Path to the configuration file on disk. Default is None.
        _verify : bool, optional
            If True, run `_verify_completeness()` after initialization.
            Default is False.
        **kwargs
            Additional attributes to be set on the instance. Keys "conf_dict"
            and "CONF_DICT_PATH" are forbidden and raise a ValueError.

        Raises
        ------
        ValueError
            If any key in `kwargs` is "conf_dict" or "CONF_DICT_PATH".

        Notes
        -----
        By default, `debug` and `verbose` are set to False. They can be
        overridden via `kwargs`.
        """

        self.debug = False
        self.verbose = False

        for key, value in kwargs.items():
            if key in ['conf_dict', 'CONF_DICT_PATH']:
                raise ValueError(f"Cannot override '{key}' via kwargs.")
            setattr(self, key, value)

        self.conf_dict = conf_dict or {}
        self.CONF_DICT_PATH = CONF_DICT_PATH

        # verification
        if _verify:
            self._verify_completeness()

    @classmethod
    def load_json(cls, CONF_DICT_PATH: str, debug: bool = False):
        """
        Load a configuration from a JSON file and construct a TramDagConfig.

        Parameters
        ----------
        CONF_DICT_PATH : str
            Path to the configuration JSON file.
        debug : bool, optional
            If True, initialize the instance with `debug=True`. Default is False.

        Returns
        -------
        TramDagConfig
            Newly created configuration instance with `conf_dict` loaded from
            `CONF_DICT_PATH` and `_verify_completeness()` executed.

        Raises
        ------
        FileNotFoundError
            If the configuration file cannot be found (propagated by
            `load_configuration_dict`).
        """

        conf = load_configuration_dict(CONF_DICT_PATH)
        return cls(conf, CONF_DICT_PATH=CONF_DICT_PATH, debug=debug, _verify=True)

    def update(self):
        """
        Reload the latest configuration from disk into this instance.

        Parameters
        ----------
        None

        Returns
        -------
        None

        Raises
        ------
        ValueError
            If `CONF_DICT_PATH` is not set on the instance.

        Notes
        -----
        The current in-memory `conf_dict` is overwritten by the contents
        loaded from `CONF_DICT_PATH`.
        """

        if not hasattr(self, "CONF_DICT_PATH") or self.CONF_DICT_PATH is None:
            raise ValueError("CONF_DICT_PATH not set; cannot update configuration.")

        self.conf_dict = load_configuration_dict(self.CONF_DICT_PATH)

    def save(self, CONF_DICT_PATH: str = None):
        """
        Persist the current configuration dictionary to disk.

        Parameters
        ----------
        CONF_DICT_PATH : str or None, optional
            Target path for the configuration file. If None, uses
            `self.CONF_DICT_PATH`. Default is None.

        Returns
        -------
        None

        Raises
        ------
        ValueError
            If neither the function argument nor `self.CONF_DICT_PATH`
            provides a valid path.

        Notes
        -----
        The resulting file is written via `write_configuration_dict`.
        """
        path = CONF_DICT_PATH or self.CONF_DICT_PATH
        if path is None:
            raise ValueError("No CONF_DICT_PATH provided to save config.")
        write_configuration_dict(self.conf_dict, path)

    def _verify_completeness(self):
        """
        Check that the configuration is structurally complete and consistent.

        The following checks are performed:

        1. Top-level mandatory keys:
           - "experiment_name"
           - "PATHS"
           - "nodes"
           - "data_type"
           - "adj_matrix"
           - "model_names"

        2. Per-node mandatory keys:
           - "data_type"
           - "node_type"
           - "parents"
           - "parents_datatype"
           - "transformation_terms_in_h()"
           - "transformation_term_nn_models_in_h()"

        3. Ordinal levels:
           - Every variable whose type string in `conf_dict["data_type"]`
             contains "ordinal" must have a "levels" entry under
             `conf_dict["nodes"][var]`.

        4. Experiment name:
           - Must be non-empty.

        5. Adjacency matrix:
           - Must be present and valid under `validate_adj_matrix`.

        Parameters
        ----------
        None

        Returns
        -------
        None

        Notes
        -----
        - Missing or invalid components are reported via printed warnings.
        - Detailed debug messages are printed when `self.debug=True`.
        """
        mandatory_keys = ["experiment_name", "PATHS", "nodes", "data_type", "adj_matrix", "model_names"]
        optional_keys = ["date_of_creation", "seed"]

        # ---- 1. Check that mandatory top-level keys exist
        missing = [k for k in mandatory_keys if k not in self.conf_dict]
        if missing:
            print(f"[WARNING] Missing mandatory keys in configuration: {missing}"
                  "\n Please add them to the configuration dict and reload.")

        # ---- 2. Check that mandatory keys are present in each node dict
        mandatory_keys_nodes = ['data_type', 'node_type', 'parents', 'parents_datatype',
                                'transformation_terms_in_h()', 'transformation_term_nn_models_in_h()']
        optional_keys_nodes = ["levels"]
        for node, node_dict in self.conf_dict.get("nodes", {}).items():
            missing_node_keys = [k for k in mandatory_keys_nodes if k not in node_dict]
            if missing_node_keys:
                print(f"[WARNING] Node '{node}' is missing mandatory keys: {missing_node_keys}")

        # ---- 3. Check that levels exist for all ordinal variables
        if self._verify_levels_dict():
            if self.debug:
                print("[DEBUG] levels are present for all ordinal variables in configuration dict.")
        else:
            print("[WARNING] Levels are missing for some ordinal variables in the configuration dict. "
                  "This will fail later during model training!\n"
                  " Please provide levels manually in the config and reload, or compute levels from data "
                  "with compute_levels().\n"
                  " e.g. cfg.compute_levels(train_df)  # computes levels from training data and writes them to cfg")

        # ---- 4. Check the experiment name
        if self._verify_experiment_name():
            if self.debug:
                print("[DEBUG] experiment_name is valid in configuration dict.")
        else:
            print("[WARNING] experiment_name is missing or empty in configuration dict.")

        # ---- 5. Check the adjacency matrix
        if self._verify_adj_matrix():
            if self.debug:
                print("[DEBUG] adj_matrix is valid in configuration dict.")
        else:
            print("[WARNING] adj_matrix is missing or invalid in configuration dict.")

    def _verify_levels_dict(self):
        """
        Verify that all ordinal variables have levels specified in the config.

        Parameters
        ----------
        None

        Returns
        -------
        bool
            True if every variable whose type string in ``conf_dict["data_type"]``
            contains "ordinal" has a "levels" entry in ``conf_dict["nodes"][var]``.
            False otherwise.

        Notes
        -----
        This method does not modify the configuration; it only checks for the
        presence of level information.
        """
        data_type = self.conf_dict.get('data_type', {})
        nodes = self.conf_dict.get('nodes', {})
        for var, dtype in data_type.items():
            if 'ordinal' in dtype:
                if var not in nodes or 'levels' not in nodes[var]:
                    return False
        return True

    def _verify_experiment_name(self):
        """
        Check whether the experiment name in the configuration is valid.

        Parameters
        ----------
        None

        Returns
        -------
        bool
            True if ``conf_dict["experiment_name"]`` exists and is non-empty
            after stripping whitespace. False otherwise.
        """
        experiment_name = self.conf_dict.get("experiment_name")
        if experiment_name is None or str(experiment_name).strip() == "":
            return False
        return True

    def _verify_adj_matrix(self):
        """
        Validate the adjacency matrix stored in the configuration.

        Parameters
        ----------
        None

        Returns
        -------
        bool
            True if the adjacency matrix is present and passes
            `validate_adj_matrix`, False otherwise.

        Notes
        -----
        If the adjacency matrix is stored as a list, it is converted to a
        NumPy array before validation.
        """
        adj_matrix = self.conf_dict.get('adj_matrix')
        if adj_matrix is None:
            # missing key: report invalid instead of raising KeyError
            return False
        if isinstance(adj_matrix, list):
            adj_matrix = np.array(adj_matrix)
        return bool(validate_adj_matrix(adj_matrix))

    def compute_levels(self, df: pd.DataFrame, write: bool = True):
        """
        Infer and update ordinal/categorical levels from data.

        For each variable in the configuration's `data_type` section, this
        method uses the provided DataFrame to construct a levels dictionary
        (via `create_levels_dict`) and injects the corresponding "levels"
        entry into `conf_dict["nodes"]`.

        Parameters
        ----------
        df : pandas.DataFrame
            DataFrame used to infer levels for configured variables.
        write : bool, optional
            If True and `CONF_DICT_PATH` is set, the updated configuration is
            written back to disk. Default is True.

        Returns
        -------
        None

        Notes
        -----
        - If `CONF_DICT_PATH` is set, the configuration is first reloaded from
          disk via `update()`, so unsaved in-memory changes are discarded.
        - Variables present in `levels_dict` but not in `conf_dict["nodes"]`
          trigger a warning and are skipped.
        - Errors raised while saving are caught and reported as an "[ERROR]"
          message instead of being propagated.
        - If `self.verbose` or `self.debug` is True, a success message is printed
          when the configuration is saved.
        """
        if self.CONF_DICT_PATH is not None:
            self.update()
        levels_dict = create_levels_dict(df, self.conf_dict['data_type'])

        # update nodes dict with levels
        nodes = self.conf_dict.get('nodes', {})
        for var, levels in levels_dict.items():
            if var in nodes:
                nodes[var]['levels'] = levels
            else:
                print(f"[WARNING] Variable '{var}' not found in nodes dict. Cannot add levels.")

        if write and self.CONF_DICT_PATH is not None:
            try:
                self.save(self.CONF_DICT_PATH)
                if self.verbose or self.debug:
                    print(f'[INFO] Configuration with updated levels saved to {self.CONF_DICT_PATH}')
            except Exception as e:
                print(f'[ERROR] Failed to save configuration: {e}')

    def plot_dag(self, seed: int = 42, causal_order: bool = False):
        """
        Visualize the DAG defined by the configuration.

        Nodes are categorized and colored as:
        - Source nodes (no incoming edges, including isolated nodes): green.
        - Sink nodes (incoming but no outgoing edges): red.
        - Intermediate nodes: light blue.

        Parameters
        ----------
        seed : int, optional
            Random seed for layout stability in the spring layout fallback.
            Default is 42.
        causal_order : bool, optional
            If True, attempt to use the Graphviz 'dot' layout via
            `networkx.nx_agraph.graphviz_layout` to preserve causal ordering.
            If False, or if Graphviz is unavailable, use `spring_layout`.
            Default is False.

        Returns
        -------
        None

        Raises
        ------
        ValueError
            If `adj_matrix` or `data_type` is missing or their sizes are
            inconsistent, or if the adjacency matrix fails validation.

        Notes
        -----
        Edge labels are colored by prefix:
        - "ci": blue
        - "ls": red
        - "cs": green
        - other: black
        """
        adj_matrix = self.conf_dict.get("adj_matrix")
        data_type  = self.conf_dict.get("data_type")

        if adj_matrix is None or data_type is None:
            raise ValueError("Configuration must include 'adj_matrix' and 'data_type'.")

        if isinstance(adj_matrix, list):
            adj_matrix = np.array(adj_matrix)

        if not validate_adj_matrix(adj_matrix):
            raise ValueError("Invalid adjacency matrix.")
        if len(data_type) != adj_matrix.shape[0]:
            raise ValueError("data_type must match adjacency matrix size.")

        node_labels = list(data_type.keys())
        G, edge_labels = create_nx_graph(adj_matrix, node_labels)

        sources       = {n for n in G.nodes if G.in_degree(n) == 0}
        sinks         = {n for n in G.nodes if G.out_degree(n) == 0}
        intermediates = set(G.nodes) - sources - sinks

        node_colors = [
            "green" if n in sources
            else "red" if n in sinks
            else "lightblue"
            for n in G.nodes
        ]

        if causal_order:
            try:
                pos = nx.nx_agraph.graphviz_layout(G, prog="dot")
            except (ImportError, nx.NetworkXException):
                pos = nx.spring_layout(G, seed=seed, k=1.5, iterations=100)
        else:
            pos = nx.spring_layout(G, seed=seed, k=1.5, iterations=100)

        plt.figure(figsize=(8, 6))
        nx.draw(
            G, pos,
            with_labels=True,
            node_color=node_colors,
            edge_color="gray",
            node_size=2500,
            arrowsize=20
        )

        for (u, v), lbl in edge_labels.items():
            color = (
                "blue"  if lbl.startswith("ci")
                else "red"   if lbl.startswith("ls")
                else "green" if lbl.startswith("cs")
                else "black"
            )
            nx.draw_networkx_edge_labels(
                G, pos,
                edge_labels={(u, v): lbl},
                font_color=color,
                font_size=12
            )

        legend_items = [
            Patch(facecolor="green",     edgecolor="black", label="Source"),
            Patch(facecolor="red",       edgecolor="black", label="Sink"),
            Patch(facecolor="lightblue", edgecolor="black", label="Intermediate")
        ]
        plt.legend(handles=legend_items, loc="upper right", frameon=True)

        plt.title("TRAM DAG")
        plt.axis("off")
        plt.tight_layout()
        plt.show()

    def setup_configuration(self, experiment_name=None, EXPERIMENT_DIR=None, debug=False, _verify=False):
        """
        Create or reuse a configuration for an experiment.

        This is a plain method (not a `classmethod`) that checks whether it
        received the class itself as `self`, and behaves accordingly:

        1. Class call, with the class passed explicitly as the first argument
           (e.g. `TramDagConfig.setup_configuration(TramDagConfig, ...)`):
           - Creates or loads a configuration at the resolved path.
           - Returns a new `TramDagConfig` instance.

        2. Instance call (e.g. `cfg.setup_configuration(...)`):
           - Updates `self.conf_dict` and `self.CONF_DICT_PATH` in place.
           - Optionally verifies completeness.
           - Returns None.

        Parameters
        ----------
        experiment_name : str or None, optional
            Name of the experiment. If None, defaults to "experiment_1".
        EXPERIMENT_DIR : str or None, optional
            Directory for the experiment. If None, defaults to
            `<cwd>/<experiment_name>`.
        debug : bool, optional
            If True, initialize / update with `debug=True`. Default is False.
        _verify : bool, optional
            If True, call `_verify_completeness()` after loading. Default is False.

        Returns
        -------
        TramDagConfig or None
            - A new instance when called on the class.
            - None when called on an existing instance.

        Notes
        -----
        - `EXPERIMENT_DIR` is created if it does not exist, and a configuration
          file named "configuration.json" is created inside it if it does not
          exist yet.
        - Underlying creation uses `create_and_write_new_configuration_dict`
          and `load_configuration_dict`.
        """
        is_class_call = isinstance(self, type)
        cls = self if is_class_call else self.__class__

        if experiment_name is None:
            experiment_name = "experiment_1"
        if EXPERIMENT_DIR is None:
            EXPERIMENT_DIR = os.path.join(os.getcwd(), experiment_name)

        CONF_DICT_PATH = os.path.join(EXPERIMENT_DIR, "configuration.json")
        DATA_PATH = EXPERIMENT_DIR

        os.makedirs(EXPERIMENT_DIR, exist_ok=True)

        if os.path.exists(CONF_DICT_PATH):
            print(f"Configuration already exists: {CONF_DICT_PATH}")
        else:
            _ = create_and_write_new_configuration_dict(
                experiment_name, CONF_DICT_PATH, EXPERIMENT_DIR, DATA_PATH, None
            )
            print(f"Created new configuration file at {CONF_DICT_PATH}")

        conf = load_configuration_dict(CONF_DICT_PATH)

        if is_class_call:
            return cls(conf, CONF_DICT_PATH=CONF_DICT_PATH, debug=debug, _verify=_verify)
        else:
            self.conf_dict = conf
            self.CONF_DICT_PATH = CONF_DICT_PATH
            if debug:
                # honor the documented `debug` flag for instance calls as well
                self.debug = True
            if _verify:
                self._verify_completeness()

    def set_data_type(self, data_type: dict, CONF_DICT_PATH: str = None) -> None:
        """
        Update or write the `data_type` section of a configuration file.

        Supports both class-level and instance-level usage:

        - Class call (the class passed explicitly as the first argument,
          e.g. `TramDagConfig.set_data_type(TramDagConfig, data_type, path)`):
          - Requires the `CONF_DICT_PATH` argument.
          - Reads the file if it exists, or starts from an empty dict.
          - Writes the updated configuration to `CONF_DICT_PATH`.

        - Instance call:
          - Uses the `CONF_DICT_PATH` argument if given, otherwise
            `self.CONF_DICT_PATH` if set, otherwise `<cwd>/configuration.json`.
          - Updates `self.conf_dict` and `self.CONF_DICT_PATH` after writing.

        Parameters
        ----------
        data_type : dict
            Mapping `{variable_name: type_spec}`, where `type_spec` encodes
            modeling types (e.g. continuous, ordinal, etc.).
        CONF_DICT_PATH : str or None, optional
            Path to the configuration file. Must be provided for class calls.
            For instance calls, defaults as described above.

        Returns
        -------
        None

        Raises
        ------
        ValueError
            If `CONF_DICT_PATH` is missing when called on the class.

        Notes
        -----
        - Variable names are validated via `validate_variable_names`.
        - Data type values are validated via `validate_data_types`.
        - Errors during loading, validation (including invalid data types),
          or writing are caught and reported as a printed
          "Failed to update configuration" message rather than raised.
        - A textual summary of modeling settings is printed via
          `print_data_type_modeling_setting`, if possible.
        """
        is_class_call = isinstance(self, type)
        cls = self if is_class_call else self.__class__

        # resolve path
        if CONF_DICT_PATH is None:
            if not is_class_call and getattr(self, "CONF_DICT_PATH", None):
                CONF_DICT_PATH = self.CONF_DICT_PATH
            elif not is_class_call:
                CONF_DICT_PATH = os.path.join(os.getcwd(), "configuration.json")
            else:
                raise ValueError("CONF_DICT_PATH must be provided when called on the class.")

        try:
            # load existing or create empty configuration
            configuration_dict = (
                load_configuration_dict(CONF_DICT_PATH)
                if os.path.exists(CONF_DICT_PATH)
                else {}
            )

            validate_variable_names(data_type.keys())
            if not validate_data_types(data_type):
                raise ValueError("Invalid data types in the provided dictionary.")

            configuration_dict["data_type"] = data_type
            write_configuration_dict(configuration_dict, CONF_DICT_PATH)

            # safe printing
            try:
                print_data_type_modeling_setting(data_type or {})
            except Exception as e:
                print(f"[WARNING] Could not print data type modeling settings: {e}")

            if not is_class_call:
                self.conf_dict = configuration_dict
                self.CONF_DICT_PATH = CONF_DICT_PATH

        except Exception as e:
            print(f"Failed to update configuration: {e}")
        else:
            print(f"Configuration updated successfully at {CONF_DICT_PATH}.")

    def set_meta_adj_matrix(self, CONF_DICT_PATH: str = None, seed: int = 5):
        """
        Launch the interactive editor to set or modify the adjacency matrix.

        This method:

        1. Resolves the configuration path either from the argument or, for
           instances, from `self.CONF_DICT_PATH`.
        2. Invokes `interactive_adj_matrix` to edit the adjacency matrix.
        3. For instances, reloads the updated configuration into `self.conf_dict`.

        Parameters
        ----------
        CONF_DICT_PATH : str or None, optional
            Path to the configuration file. Must be provided when called
            on the class. For instance calls, defaults to `self.CONF_DICT_PATH`.
        seed : int, optional
            Random seed for any layout or stochastic behavior in the interactive
            editor. Default is 5.

        Returns
        -------
        None

        Raises
        ------
        ValueError
            If `CONF_DICT_PATH` is not provided and cannot be inferred
            (e.g. in a class call without path).

        Notes
        -----
        For instance calls with `self.CONF_DICT_PATH` set, `self.update()` is
        called first so that the in-memory config is in sync with the file
        before the editor is launched. Class calls skip this step, since
        `update()` needs an instance.
        """
        is_class_call = isinstance(self, type)

        # sync in-memory config with disk (instances with a stored path only)
        if not is_class_call and getattr(self, "CONF_DICT_PATH", None):
            self.update()

        # resolve path
        if CONF_DICT_PATH is None:
            if not is_class_call and getattr(self, "CONF_DICT_PATH", None):
                CONF_DICT_PATH = self.CONF_DICT_PATH
            else:
                raise ValueError("CONF_DICT_PATH must be provided (required for class calls, "
                                 "or set self.CONF_DICT_PATH).")

        # launch interactive editor
        interactive_adj_matrix(CONF_DICT_PATH, seed=seed)

        # reload config if instance
        if not is_class_call:
            self.conf_dict = load_configuration_dict(CONF_DICT_PATH)
            self.CONF_DICT_PATH = CONF_DICT_PATH

    def set_tramdag_nn_models(self, CONF_DICT_PATH: str = None):
        """
        Launch the interactive editor to set TRAM-DAG neural network model names.

        Depending on call context:

        - Class call (the class passed explicitly as the first argument):
          - Requires the `CONF_DICT_PATH` argument.
          - Returns nothing and does not modify a specific instance.

        - Instance call:
          - Resolves `CONF_DICT_PATH` from the argument or `self.CONF_DICT_PATH`.
          - Updates `self.conf_dict` and `self.CONF_DICT_PATH` if the editor
            returns an updated configuration.

        Parameters
        ----------
        CONF_DICT_PATH : str or None, optional
            Path to the configuration file. Must be provided when called on
            the class. For instance calls, defaults to `self.CONF_DICT_PATH`.

        Returns
        -------
        None

        Raises
        ------
        ValueError
            If `CONF_DICT_PATH` is not provided and cannot be inferred
            (e.g. in a class call without path).

        Notes
        -----
        The interactive editor is invoked via `interactive_nn_names_matrix`.
        If it returns `None`, the instance configuration is left unchanged.
        """
        is_class_call = isinstance(self, type)
        if CONF_DICT_PATH is None:
            if not is_class_call and getattr(self, "CONF_DICT_PATH", None):
                CONF_DICT_PATH = self.CONF_DICT_PATH
            else:
                raise ValueError("CONF_DICT_PATH must be provided (required for class calls, "
                                 "or set self.CONF_DICT_PATH).")

        updated_conf = interactive_nn_names_matrix(CONF_DICT_PATH)
        if updated_conf is not None and not is_class_call:
            self.conf_dict = updated_conf
            self.CONF_DICT_PATH = CONF_DICT_PATH
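The path defaults used by `setup_configuration` and the path lookup order used by `set_data_type` can be summarized as two small pure functions. This is a sketch for illustration only: `resolve_experiment_paths` and `resolve_conf_path` are hypothetical helpers, not part of tramdag. Unlike `setup_configuration`, they create no directories or files.

```python
import os
from typing import Optional, Tuple


def resolve_experiment_paths(experiment_name: Optional[str] = None,
                             experiment_dir: Optional[str] = None) -> Tuple[str, str, str]:
    """Hypothetical helper mirroring the defaults in setup_configuration."""
    if experiment_name is None:
        experiment_name = "experiment_1"
    if experiment_dir is None:
        # default experiment directory: <cwd>/<experiment_name>
        experiment_dir = os.path.join(os.getcwd(), experiment_name)
    # the configuration file always lives inside the experiment directory
    conf_dict_path = os.path.join(experiment_dir, "configuration.json")
    return experiment_name, experiment_dir, conf_dict_path


def resolve_conf_path(arg_path: Optional[str], stored_path: Optional[str],
                      is_class_call: bool) -> str:
    """Hypothetical helper mirroring the lookup order in set_data_type."""
    if arg_path is not None:
        return arg_path                      # 1. explicit argument wins
    if not is_class_call and stored_path:
        return stored_path                   # 2. path stored on the instance
    if not is_class_call:
        # 3. instance fallback: <cwd>/configuration.json
        return os.path.join(os.getcwd(), "configuration.json")
    raise ValueError("CONF_DICT_PATH must be provided when called on the class.")
```

For example, `resolve_experiment_paths("exp_a")` yields `<cwd>/exp_a/configuration.json` as the configuration path. A class call without an explicit path is the only case that raises an error.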
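`save` and `update` form a write/reload cycle in which the file on disk wins over unsaved in-memory edits. `compute_levels` calls `update()` before adding levels, so this ordering matters in practice. The sketch below assumes that `write_configuration_dict` and `load_configuration_dict` behave like a plain JSON dump and load. That is an assumption, because their implementations are not shown here. `write_conf` and `load_conf` are hypothetical stand-ins for them.

```python
import json
import os
import tempfile


def write_conf(conf: dict, path: str) -> None:
    # hypothetical stand-in for write_configuration_dict (assumed plain JSON)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(conf, f, indent=4)


def load_conf(path: str) -> dict:
    # hypothetical stand-in for load_configuration_dict (assumed plain JSON)
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "configuration.json")
    conf = {"experiment_name": "experiment_1", "nodes": {}}
    write_conf(conf, path)                            # analogous to cfg.save()
    conf["experiment_name"] = "edited in memory only"
    conf = load_conf(path)                            # analogous to cfg.update(): disk wins
    reloaded_name = conf["experiment_name"]
```

After the reload, `reloaded_name` is `"experiment_1"`, and the unsaved edit is gone. To keep an in-memory change, call `save()` before any method that reloads from disk.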
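The checks in `_verify_completeness` and `_verify_levels_dict` can be expressed as a single function that collects warnings instead of printing them. This is a simplified sketch. `completeness_warnings` is a hypothetical helper, and the key lists are copied from the source above. The data type strings `"continuous"` and `"ordinal"` are placeholders, because the levels check only looks for the substring `"ordinal"`.

```python
from typing import List

# key lists taken from _verify_completeness
MANDATORY_KEYS = ["experiment_name", "PATHS", "nodes", "data_type",
                  "adj_matrix", "model_names"]
MANDATORY_NODE_KEYS = ["data_type", "node_type", "parents", "parents_datatype",
                       "transformation_terms_in_h()",
                       "transformation_term_nn_models_in_h()"]


def completeness_warnings(conf: dict) -> List[str]:
    """Hypothetical helper: return warnings instead of printing them."""
    warnings = []
    missing = [k for k in MANDATORY_KEYS if k not in conf]
    if missing:
        warnings.append(f"missing top-level keys: {missing}")
    nodes = conf.get("nodes", {})
    for node, node_dict in nodes.items():
        missing_node = [k for k in MANDATORY_NODE_KEYS if k not in node_dict]
        if missing_node:
            warnings.append(f"node '{node}' missing: {missing_node}")
    # same rule as _verify_levels_dict: any type string containing
    # "ordinal" needs a "levels" entry in the matching node dict
    for var, dtype in conf.get("data_type", {}).items():
        if "ordinal" in dtype and "levels" not in nodes.get(var, {}):
            warnings.append(f"node '{var}' has no levels")
    return warnings
```

With a configuration whose ordinal node `x2` lacks `"levels"`, the only warning returned is `"node 'x2' has no levels"`. Adding a `"levels"` entry clears it. `compute_levels` is intended to fill in exactly this entry.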
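`plot_dag` colors nodes by their in-degree and out-degree and colors edge labels by prefix. The pure-Python sketch below reproduces those rules without networkx or matplotlib. It assumes, without confirmation from the source shown, that `adj[i][j]` holds an edge label such as `"ls"`, `"cs"` or `"ci"` when node `i` is a parent of node `j`, and `"0"` when there is no edge. `node_roles` and `edge_label_color` are hypothetical names.

```python
from typing import Dict, List


def _has_edge(entry) -> bool:
    # assumption: "0", "", 0 and None mean "no edge"; anything else is a label
    return entry not in ("0", "", 0, None)


def node_roles(adj: List[List[str]], labels: List[str]) -> Dict[str, str]:
    """Hypothetical helper: classify nodes the way plot_dag colors them."""
    n = len(labels)
    in_deg = {labels[j]: sum(_has_edge(adj[i][j]) for i in range(n)) for j in range(n)}
    out_deg = {labels[i]: sum(_has_edge(adj[i][j]) for j in range(n)) for i in range(n)}
    roles = {}
    for name in labels:
        if in_deg[name] == 0:
            roles[name] = "source"        # green; isolated nodes also land here
        elif out_deg[name] == 0:
            roles[name] = "sink"          # red
        else:
            roles[name] = "intermediate"  # light blue
    return roles


def edge_label_color(label: str) -> str:
    """Hypothetical helper: same prefix rules as plot_dag."""
    if label.startswith("ci"):
        return "blue"
    if label.startswith("ls"):
        return "red"
    if label.startswith("cs"):
        return "green"
    return "black"


labels = ["x1", "x2", "x3"]
adj = [["0", "ls", "0"],   # x1 -> x2
       ["0", "0", "cs"],   # x2 -> x3
       ["0", "0", "0"]]
roles = node_roles(adj, labels)
```

For the chain `x1 -> x2 -> x3`, `x1` is a source, `x2` is intermediate and `x3` is a sink. Because the source test runs first, a node with no edges at all is drawn green, as in `plot_dag`.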
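`setup_configuration`, `set_data_type`, `set_meta_adj_matrix` and `set_tramdag_nn_models` are ordinary methods that detect a class call with `isinstance(self, type)`. None of them is a `classmethod`, so Python does not bind the class automatically. The class branch is therefore only reached when the class is passed explicitly as the first argument. The toy class `Demo` below is hypothetical and only reproduces this dispatch pattern.

```python
class Demo:
    """Hypothetical toy class reproducing the class/instance dispatch."""

    def setup(self, name="experiment_1"):
        # plain method: `self` is whatever object is passed first
        is_class_call = isinstance(self, type)
        cls = self if is_class_call else self.__class__
        if is_class_call:
            obj = cls()           # class call: build and return a new instance
            obj.name = name
            return obj
        self.name = name          # instance call: update in place
        return None


inst = Demo()
inst_result = inst.setup("a")            # instance call, returns None
new = Demo.setup(Demo, "b")              # class passed explicitly as `self`

try:
    Demo.setup()                         # no `self` at all
    missing_self_raised = False
except TypeError:
    missing_self_raised = True
```

The explicit form `Demo.setup(Demo, "b")` works. The bare `Demo.setup()` raises `TypeError` because the required `self` argument is missing. A `classmethod`, or a custom descriptor, would be needed for `Cls.method(...)` to work without passing the class.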
class TramDagConfig:
178        write_configuration_dict(self.conf_dict, path)
179
180    def _verify_completeness(self):
181        """
182        Check that the configuration is structurally complete and consistent.
183
184        The following checks are performed:
185
186        1. Top-level mandatory keys:
187        - "experiment_name"
188        - "PATHS"
189        - "nodes"
190        - "data_type"
191        - "adj_matrix"
192        - "model_names"
193
194        2. Per-node mandatory keys:
195        - "data_type"
196        - "node_type"
197        - "parents"
198        - "parents_datatype"
199        - "transformation_terms_in_h()"
200        - "transformation_term_nn_models_in_h()"
201
202        3. Ordinal / categorical levels:
203        - All ordinal variables must have a corresponding "levels" entry
204            under `conf_dict["nodes"][var]`.
205
206        4. Experiment name:
207        - Must be non-empty.
208
209        5. Adjacency matrix:
210        - Must be valid under `validate_adj_matrix`.
211
212        Parameters
213        ----------
214        None
215
216        Returns
217        -------
218        None
219
220        Notes
221        -----
222        - Missing or invalid components are reported via printed warnings.
223        - Detailed debug messages are printed when `self.debug=True`.
224        """
225        mandatory_keys = ["experiment_name", "PATHS", "nodes", "data_type", "adj_matrix", "model_names"]
226        optional_keys = ["date_of_creation", "seed"]
227
228        # ---- 1. Check mandatory keys exist
229        missing = [k for k in mandatory_keys if k not in self.conf_dict]
230        if missing:
231            print(f"[WARNING] Missing mandatory keys in configuration: {missing}"
232                "\n Please add them to the configuration dict and reload.")
233            
234        # --- 2. Check  if mandatory keys in nodesdict are present
235        mandatory_keys_nodes = ['data_type', 'node_type','parents','parents_datatype','transformation_terms_in_h()','transformation_term_nn_models_in_h()']
236        optional_keys_nodes = ["levels"]
237        for node, node_dict in self.conf_dict.get("nodes", {}).items():
238            # check missing mandatory keys
239            missing_node_keys = [k for k in mandatory_keys_nodes if k not in node_dict]
240            if missing_node_keys:
241                print(f"[WARNING] Node '{node}' is missing mandatory keys: {missing_node_keys}")
242                
243
244        
245        if self._verify_levels_dict():
246            if self.debug:
247                print("[DEBUG] levels are present for all ordinal variables in configuration dict.")
248            pass
249        else:
250            print("[WARNING] Levels are missing for some ordinal variables in configuration dict. THIS will FAIL in model training later!\n"
251                " Please provide levels manually to config and reload or compute levels from data using the method compute_levels().\n"
252                " e.g. cfg.compute_levels(train_df) # computes levels from training data and writes to cfg")
253
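254        if not self._verify_experiment_name():
255            print("[WARNING] experiment_name is missing or empty in configuration dict.")
256        elif self.debug:
257            print("[DEBUG] experiment_name is valid in configuration dict.")
258
259        if not self._verify_adj_matrix():
260            print("[WARNING] adj_matrix is missing or invalid in configuration dict.")
261        elif self.debug:
262            print("[DEBUG] adj_matrix is valid in configuration dict.")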
263
264    def _verify_levels_dict(self):
265        """
266        Verify that all ordinal variables have levels specified in the config.
267
268        Parameters
269        ----------
270        None
271
272        Returns
273        -------
274        bool
275            True if all variables declared as ordinal in ``conf_dict["data_type"]``
276            have a "levels" entry in ``conf_dict["nodes"][var]``.
277            False otherwise.
278
279        Notes
280        -----
281        This method does not modify the configuration; it only checks presence
282        of level information.
283        """
284        data_type = self.conf_dict.get('data_type', {})
285        nodes = self.conf_dict.get('nodes', {})
286        for var, dtype in data_type.items():
287            if 'ordinal' in dtype:
288                if var not in nodes or 'levels' not in nodes[var]:
289                    return False
290        return True
291
292    def _verify_experiment_name(self):
293        """
294        Check whether the experiment name in the configuration is valid.
295
296        Parameters
297        ----------
298        None
299
300        Returns
301        -------
302        bool
303            True if ``conf_dict["experiment_name"]`` exists and is non-empty
304            after stripping whitespace. False otherwise.
305        """
306        experiment_name = self.conf_dict.get("experiment_name")
307        if experiment_name is None or str(experiment_name).strip() == "":
308            return False
309        return True
310        
311    def _verify_adj_matrix(self):
312        """
313        Validate the adjacency matrix stored in the configuration.
314
315        Parameters
316        ----------
317        None
318
319        Returns
320        -------
321        bool
322            True if `adj_matrix` is present and passes `validate_adj_matrix`, False otherwise.
323
324        Notes
325        -----
326        If the adjacency matrix is stored as a list, it is converted to a
327        NumPy array before validation.
328        """
329        adj_matrix = self.conf_dict.get('adj_matrix')
330        if adj_matrix is None:
331            return False
332        if isinstance(adj_matrix, list):
333            adj_matrix = np.array(adj_matrix)
334        # normalize the validator's result to a strict bool
335        return bool(validate_adj_matrix(adj_matrix))
336
337    def compute_levels(self, df: pd.DataFrame, write: bool = True):
338        """
339        Infer and update ordinal/categorical levels from data.
340
341        For each variable in the configuration's `data_type` section, this
342        method uses the provided DataFrame to construct a levels dictionary
343        and injects the corresponding "levels" entry into `conf_dict["nodes"]`.
344
345        Parameters
346        ----------
347        df : pandas.DataFrame
348            DataFrame used to infer levels for configured variables.
349        write : bool, optional
350            If True and `CONF_DICT_PATH` is set, the updated configuration is
351            written back to disk. Default is True.
352
353        Returns
354        -------
355        None
356
357        Raises
358        ------
359        ValueError
360            If `CONF_DICT_PATH` is not set (propagated by `update()`).
361
362        Notes
363        -----
364        - `update()` reloads the file first; unsaved in-memory edits are lost.
365        - Variables missing from `conf_dict["nodes"]` are skipped with a warning.
366        - Save errors are caught and printed, not raised. If `self.verbose` or
367          `self.debug` is True, a success message is printed after saving.
368        """
369        self.update()
370        levels_dict = create_levels_dict(df, self.conf_dict['data_type'])
371        
372        # update nodes dict with levels
373        for var, levels in levels_dict.items():
374            if var in self.conf_dict['nodes']:
375                self.conf_dict['nodes'][var]['levels'] = levels
376            else:
377                print(f"[WARNING] Variable '{var}' not found in nodes dict. Cannot add levels.")
378        
379        if write and self.CONF_DICT_PATH is not None:
380            try:
381                self.save(self.CONF_DICT_PATH)
382                if self.verbose or self.debug:
383                    print(f'[INFO] Configuration with updated levels saved to {self.CONF_DICT_PATH}')
384            except Exception as e:
385                print(f'[ERROR] Failed to save configuration: {e}')
386
387    def plot_dag(self, seed: int = 42, causal_order: bool = False):
388        """
389        Visualize the DAG defined by the configuration.
390
391        Nodes are categorized and colored as:
392        - Source nodes (no incoming edges): green.
393        - Sink nodes (no outgoing edges): red.
394        - Intermediate nodes: light blue.
395
396        Parameters
397        ----------
398        seed : int, optional
399            Random seed for layout stability in the spring layout fallback.
400            Default is 42.
401        causal_order : bool, optional
402            If True, attempt to use Graphviz 'dot' layout via
403            `networkx.nx_agraph.graphviz_layout` to preserve causal ordering.
404            If False or if Graphviz is unavailable, use `spring_layout`.
405            Default is False.
406
407        Returns
408        -------
409        None
410
411        Raises
412        ------
413        ValueError
414            If `adj_matrix` or `data_type` is missing or inconsistent with each other,
415            or if the adjacency matrix fails validation.
416
417        Notes
418        -----
419        Edge labels are colored by prefix:
420        - "ci": blue
421        - "ls": red
422        - "cs": green
423        - other: black
424        """
425        adj_matrix = self.conf_dict.get("adj_matrix")
426        data_type  = self.conf_dict.get("data_type")
427
428        if adj_matrix is None or data_type is None:
429            raise ValueError("Configuration must include 'adj_matrix' and 'data_type'.")
430
431        if isinstance(adj_matrix, list):
432            adj_matrix = np.array(adj_matrix)
433
434        if not validate_adj_matrix(adj_matrix):
435            raise ValueError("Invalid adjacency matrix.")
436        if len(data_type) != adj_matrix.shape[0]:
437            raise ValueError("data_type must match adjacency matrix size.")
438
439        node_labels = list(data_type.keys())
440        G, edge_labels = create_nx_graph(adj_matrix, node_labels)
441
442        sources       = {n for n in G.nodes if G.in_degree(n) == 0}
443        sinks         = {n for n in G.nodes if G.out_degree(n) == 0}
444        intermediates = set(G.nodes) - sources - sinks
445
446        node_colors = [
447            "green" if n in sources
448            else "red" if n in sinks
449            else "lightblue"
450            for n in G.nodes
451        ]
452
453        if causal_order:
454            try:
455                pos = nx.nx_agraph.graphviz_layout(G, prog="dot")
456            except (ImportError, nx.NetworkXException):
457                pos = nx.spring_layout(G, seed=seed, k=1.5, iterations=100)
458        else:
459            pos = nx.spring_layout(G, seed=seed, k=1.5, iterations=100)
460
461        plt.figure(figsize=(8, 6))
462        nx.draw(
463            G, pos,
464            with_labels=True,
465            node_color=node_colors,
466            edge_color="gray",
467            node_size=2500,
468            arrowsize=20
469        )
470
471        for (u, v), lbl in edge_labels.items():
472            color = (
473                "blue"  if lbl.startswith("ci")
474                else "red"   if lbl.startswith("ls")
475                else "green" if lbl.startswith("cs")
476                else "black"
477            )
478            nx.draw_networkx_edge_labels(
479                G, pos,
480                edge_labels={(u, v): lbl},
481                font_color=color,
482                font_size=12
483            )
484
485        legend_items = [
486            Patch(facecolor="green",     edgecolor="black", label="Source"),
487            Patch(facecolor="red",       edgecolor="black", label="Sink"),
488            Patch(facecolor="lightblue", edgecolor="black", label="Intermediate")
489        ]
490        plt.legend(handles=legend_items, loc="upper right", frameon=True)
491
492        plt.title("TRAM DAG")
493        plt.axis("off")
494        plt.tight_layout()
495        plt.show()
496
497  
498    def setup_configuration(self, experiment_name=None, EXPERIMENT_DIR=None, debug=False, _verify=False):
499        """
500        Create or reuse a configuration for an experiment.
501
502        This method behaves differently depending on how it is called:
503
504        1. Class call (`TramDagConfig.setup_configuration(TramDagConfig, ...)`):
505           - Creates or loads a configuration at the resolved path.
506           - Returns a new `TramDagConfig` instance.
507
508        2. Instance call (e.g. `cfg.setup_configuration(...)`):
509           - Updates `self.conf_dict` and `self.CONF_DICT_PATH` in place.
510           - Optionally verifies completeness.
511           - Returns None.
512
513        Parameters
514        ----------
515        experiment_name : str or None, optional
516            Name of the experiment. If None, defaults to "experiment_1".
517        EXPERIMENT_DIR : str or None, optional
518            Directory for the experiment. If None, defaults to
519            `<cwd>/<experiment_name>`.
520        debug : bool, optional
521            If True, initialize / update with `debug=True`. Default is False.
522        _verify : bool, optional
523            If True, call `_verify_completeness()` after loading. Default is False.
524
525        Returns
526        -------
527        TramDagConfig or None
528            - A new instance when called on the class.
529            - None when called on an existing instance.
530
531        Notes
532        -----
533        - A configuration file named "configuration.json" is created if it does
534        not exist yet.
535        - Underlying creation uses `create_and_write_new_configuration_dict`
536        and `load_configuration_dict`.
537        """
538        is_class_call = isinstance(self, type)
539        cls = self if is_class_call else self.__class__
540
541        if experiment_name is None:
542            experiment_name = "experiment_1"
543        if EXPERIMENT_DIR is None:
544            EXPERIMENT_DIR = os.path.join(os.getcwd(), experiment_name)
545
546        CONF_DICT_PATH = os.path.join(EXPERIMENT_DIR, "configuration.json")
547        DATA_PATH = EXPERIMENT_DIR
548
549        os.makedirs(EXPERIMENT_DIR, exist_ok=True)
550
551        if os.path.exists(CONF_DICT_PATH):
552            print(f"Configuration already exists: {CONF_DICT_PATH}")
553        else:
554            _ = create_and_write_new_configuration_dict(
555                experiment_name, CONF_DICT_PATH, EXPERIMENT_DIR, DATA_PATH, None
556            )
557            print(f"Created new configuration file at {CONF_DICT_PATH}")
558
559        conf = load_configuration_dict(CONF_DICT_PATH)
560
561        if is_class_call:
562            return cls(conf, CONF_DICT_PATH=CONF_DICT_PATH, debug=debug, _verify=_verify)
563        else:
564            self.conf_dict = conf
565            self.CONF_DICT_PATH = CONF_DICT_PATH
566            if _verify:
567                self._verify_completeness()
568                
569    def set_data_type(self, data_type: dict, CONF_DICT_PATH: str = None) -> None:
570        """
571        Update or write the `data_type` section of a configuration file.
572
573        Supports both class-level and instance-level usage:
574
575        - Class call:
576          - Requires the `CONF_DICT_PATH` argument.
577          - Reads the file if it exists, or starts from an empty dict.
578          - Writes the updated configuration to `CONF_DICT_PATH`.
579
580        - Instance call:
581          - If no path is given, uses `self.CONF_DICT_PATH` when set, otherwise
582            `<cwd>/configuration.json`.
583          - Updates `self.conf_dict` and `self.CONF_DICT_PATH` after writing.
584
585        Parameters
586        ----------
587        data_type : dict
588            Mapping `{variable_name: type_spec}`, where `type_spec` encodes
589            modeling types (e.g. continuous, ordinal, etc.).
590        CONF_DICT_PATH : str or None, optional
591            Path to the configuration file. Must be provided for class calls.
592            For instance calls, defaults as described above.
593
594        Returns
595        -------
596        None
597
598        Raises
599        ------
600        ValueError
601            If `CONF_DICT_PATH` is missing when called on the class. Validation
602            and I/O errors are caught and printed rather than raised.
603
604        Notes
605        -----
606        - Variable names are validated via `validate_variable_names`.
607        - Data type values are validated via `validate_data_types`.
608        - A textual summary of modeling settings is printed via
609        `print_data_type_modeling_setting`, if possible.
610        """
611        is_class_call = isinstance(self, type)
612        cls = self if is_class_call else self.__class__
613
614        # resolve path
615        if CONF_DICT_PATH is None:
616            if not is_class_call and getattr(self, "CONF_DICT_PATH", None):
617                CONF_DICT_PATH = self.CONF_DICT_PATH
618            elif not is_class_call:
619                CONF_DICT_PATH = os.path.join(os.getcwd(), "configuration.json")
620            else:
621                raise ValueError("CONF_DICT_PATH must be provided when called on the class.")
622
623        try:
624            # load existing or create empty configuration
625            configuration_dict = (
626                load_configuration_dict(CONF_DICT_PATH)
627                if os.path.exists(CONF_DICT_PATH)
628                else {}
629            )
630
631            validate_variable_names(data_type.keys())
632            if not validate_data_types(data_type):
633                raise ValueError("Invalid data types in the provided dictionary.")
634
635            configuration_dict["data_type"] = data_type
636            write_configuration_dict(configuration_dict, CONF_DICT_PATH)
637
638            # safe printing
639            try:
640                print_data_type_modeling_setting(data_type or {})
641            except Exception as e:
642                print(f"[WARNING] Could not print data type modeling settings: {e}")
643
644            if not is_class_call:
645                self.conf_dict = configuration_dict
646                self.CONF_DICT_PATH = CONF_DICT_PATH
647
648
649        except Exception as e:
650            print(f"Failed to update configuration: {e}")
651        else:
652            print(f"Configuration updated successfully at {CONF_DICT_PATH}.")
653            
654    def set_meta_adj_matrix(self, CONF_DICT_PATH: str = None, seed: int = 5):
655        """
656        Launch the interactive editor to set or modify the adjacency matrix.
657
658        This method:
659
660        1. Resolves the configuration path either from the argument or, for
661        instances, from `self.CONF_DICT_PATH`.
662        2. Invokes `interactive_adj_matrix` to edit the adjacency matrix.
663        3. For instances, reloads the updated configuration into `self.conf_dict`.
664
665        Parameters
666        ----------
667        CONF_DICT_PATH : str or None, optional
668            Path to the configuration file. Must be provided when called
669            on the class. For instance calls, defaults to `self.CONF_DICT_PATH`.
670        seed : int, optional
671            Random seed for any layout or stochastic behavior in the interactive
672            editor. Default is 5.
673
674        Returns
675        -------
676        None
677
678        Raises
679        ------
680        ValueError
681            If `CONF_DICT_PATH` is not provided and cannot be inferred
682            (e.g. in a class call without path).
683
684        Notes
685        -----
686        For instances with a stored path, `self.update()` is called first so the
687        in-memory config is in sync with the file before launching the editor.
688        """
689        is_class_call = isinstance(self, type)
690        # resolve path
691        if CONF_DICT_PATH is None:
692            if not is_class_call and getattr(self, "CONF_DICT_PATH", None):
693                CONF_DICT_PATH = self.CONF_DICT_PATH
694            else:
695                raise ValueError("CONF_DICT_PATH must be provided when called on the class.")
696        # sync with disk only when called on an instance that has a stored path
697        if not is_class_call and getattr(self, "CONF_DICT_PATH", None):
698            self.update()
699        # launch interactive editor
700        interactive_adj_matrix(CONF_DICT_PATH, seed=seed)
701
702        # reload config if instance
703        if not is_class_call:
704            self.conf_dict = load_configuration_dict(CONF_DICT_PATH)
705            self.CONF_DICT_PATH = CONF_DICT_PATH
706
707
708    def set_tramdag_nn_models(self, CONF_DICT_PATH: str = None):
709        """
710        Launch the interactive editor to set TRAM-DAG neural network model names.
711
712        Depending on call context:
713
714        - Class call:
715          - Requires the `CONF_DICT_PATH` argument.
716          - Returns nothing and does not modify a specific instance.
717
718        - Instance call:
719          - Resolves `CONF_DICT_PATH` from the argument or `self.CONF_DICT_PATH`.
720          - Updates `self.conf_dict` and `self.CONF_DICT_PATH` if the editor
721            returns an updated configuration.
722
723        Parameters
724        ----------
725        CONF_DICT_PATH : str or None, optional
726            Path to the configuration file. Must be provided when called on
727            the class. For instance calls, defaults to `self.CONF_DICT_PATH`.
728
729        Returns
730        -------
731        None
732
733        Raises
734        ------
735        ValueError
736            If `CONF_DICT_PATH` is not provided and cannot be inferred
737            (e.g. in a class call without path).
738
739        Notes
740        -----
741        The interactive editor is invoked via `interactive_nn_names_matrix`.
742        If it returns `None`, the instance configuration is left unchanged.
743        """
744        is_class_call = isinstance(self, type)
745        if CONF_DICT_PATH is None:
746            if not is_class_call and getattr(self, "CONF_DICT_PATH", None):
747                CONF_DICT_PATH = self.CONF_DICT_PATH
748            else:
749                raise ValueError("CONF_DICT_PATH must be provided when called on the class.")
750
751        updated_conf = interactive_nn_names_matrix(CONF_DICT_PATH)
752        if updated_conf is not None and not is_class_call:
753            self.conf_dict = updated_conf
754            self.CONF_DICT_PATH = CONF_DICT_PATH
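`plot_dag` colours each node purely from its in- and out-degree. The dependency-free sketch below reproduces that classification; it is not tramdag code, and it assumes (hypothetically) that the adjacency matrix is a list of lists of strings in which "0" means "no edge" and any other entry at row i, column j (such as "ls", "cs" or "ci") means an edge from node i to node j. `classify_nodes` and `node_color` are illustrative helpers.

```python
def classify_nodes(adj_matrix, labels):
    """Split node labels into (sources, sinks, intermediates)."""
    in_deg = {lbl: 0 for lbl in labels}
    out_deg = {lbl: 0 for lbl in labels}
    for i, parent in enumerate(labels):
        for j, child in enumerate(labels):
            if adj_matrix[i][j] != "0":  # assumed: "0" means "no edge"
                out_deg[parent] += 1
                in_deg[child] += 1
    sources = {lbl for lbl in labels if in_deg[lbl] == 0}
    sinks = {lbl for lbl in labels if out_deg[lbl] == 0}
    intermediates = set(labels) - sources - sinks
    return sources, sinks, intermediates


def node_color(lbl, sources, sinks):
    """Same precedence as plot_dag: source first, then sink, else intermediate."""
    if lbl in sources:
        return "green"
    if lbl in sinks:
        return "red"
    return "lightblue"


labels = ["x1", "x2", "x3", "x4"]
adj = [
    ["0", "ls", "cs", "0"],  # x1 -> x2, x1 -> x3
    ["0", "0",  "ls", "0"],  # x2 -> x3
    ["0", "0",  "0",  "0"],
    ["0", "0",  "0",  "0"],  # x4 has no edges at all
]
sources, sinks, inter = classify_nodes(adj, labels)
colors = [node_color(lbl, sources, sinks) for lbl in labels]
print(sources, sinks, inter, colors)
```

An isolated node such as `x4` is both a source and a sink; because the colour expression in `plot_dag` tests `sources` first, it is drawn green.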

Configuration manager for TRAM-DAG experiments.

This class encapsulates:

  • The experiment configuration dictionary (conf_dict).
  • Its backing file path (CONF_DICT_PATH).
  • Utilities to load, validate, modify, and persist configuration.
  • DAG visualization and interactive editing helpers.

Typical usage

  • Load existing configuration from disk via TramDagConfig.load_json.
  • Or create/reuse experiment setup via TramDagConfig().setup_configuration.
  • Update sections such as data_type, adjacency matrix, and neural network model names using the provided methods.
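The summary above points to `setup_configuration` for creating or reusing an experiment. Per its docstring in the source listing, the experiment name defaults to "experiment_1", the directory to `<cwd>/<experiment_name>`, and the file is always named `configuration.json`. A minimal standard-library sketch of that path resolution follows; `resolve_experiment_paths` and its `base_dir` parameter are hypothetical (`base_dir` replaces `os.getcwd()` so the example can run inside a temporary directory).

```python
import os
import tempfile


def resolve_experiment_paths(experiment_name=None, experiment_dir=None, base_dir=None):
    """Hypothetical helper mirroring the defaults documented for setup_configuration."""
    if experiment_name is None:
        experiment_name = "experiment_1"
    if base_dir is None:
        base_dir = os.getcwd()  # the original uses the current working directory
    if experiment_dir is None:
        experiment_dir = os.path.join(base_dir, experiment_name)
    conf_path = os.path.join(experiment_dir, "configuration.json")
    os.makedirs(experiment_dir, exist_ok=True)
    # setup_configuration only writes a new file when this flag is False
    return experiment_dir, conf_path, os.path.exists(conf_path)


with tempfile.TemporaryDirectory() as tmp:
    exp_dir, conf_path, already_exists = resolve_experiment_paths(base_dir=tmp)
    dir_created = os.path.isdir(exp_dir)
    print(exp_dir, conf_path, already_exists)
```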
TramDagConfig(conf_dict: dict = None, CONF_DICT_PATH: str = None, _verify: bool = False, **kwargs)

Initialize a TramDagConfig instance.

Parameters

conf_dict : dict or None, optional
    Configuration dictionary. If None, an empty dict is used and can be populated later. Default is None.
CONF_DICT_PATH : str or None, optional
    Path to the configuration file on disk. Default is None.
_verify : bool, optional
    If True, run _verify_completeness() after initialization. Default is False.
**kwargs
    Additional attributes to be set on the instance. Keys "conf_dict" and "CONF_DICT_PATH" are forbidden and raise a ValueError.

Raises

ValueError
    If any key in kwargs is "conf_dict" or "CONF_DICT_PATH".

Notes

By default, debug and verbose are set to False. They can be overridden via kwargs.
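The keyword handling described above can be checked with a minimal stand-in class (`ConfigSketch`, hypothetical, with verification and file I/O omitted; `run_tag` is an illustrative attribute name). One consequence of Python's argument binding is worth knowing: because `conf_dict` and `CONF_DICT_PATH` are named parameters, a keyword argument with either name binds to that parameter and never reaches `**kwargs`, so the ValueError guard is defensive; supplying the same name twice raises TypeError instead.

```python
class ConfigSketch:
    """Hypothetical stand-in for TramDagConfig.__init__ (no verification, no I/O)."""

    def __init__(self, conf_dict=None, CONF_DICT_PATH=None, _verify=False, **kwargs):
        self.debug = False
        self.verbose = False
        for key, value in kwargs.items():
            # Defensive only: a keyword named conf_dict or CONF_DICT_PATH binds to
            # the named parameter and never appears in kwargs.
            if key in ("conf_dict", "CONF_DICT_PATH"):
                raise ValueError(f"Cannot override '{key}' via kwargs.")
            setattr(self, key, value)
        self.conf_dict = conf_dict or {}
        self.CONF_DICT_PATH = CONF_DICT_PATH
        # _verify is accepted for signature parity; verification is omitted here.


cfg = ConfigSketch(debug=True, run_tag="trial")  # extra kwargs become attributes
cfg2 = ConfigSketch(conf_dict={"experiment_name": "exp"})  # binds to the parameter

try:
    ConfigSketch({}, conf_dict={})  # same name supplied twice
    duplicate_raised_type_error = False
except TypeError:
    duplicate_raised_type_error = True

print(cfg.debug, cfg.verbose, cfg.run_tag, cfg2.conf_dict, duplicate_raised_type_error)
```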

debug
verbose
conf_dict
CONF_DICT_PATH
@classmethod
def load_json(cls, CONF_DICT_PATH: str, debug: bool = False):

Load a configuration from a JSON file and construct a TramDagConfig.

Parameters

CONF_DICT_PATH : str
    Path to the configuration JSON file.
debug : bool, optional
    If True, initialize the instance with debug=True. Default is False.

Returns

TramDagConfig
    Newly created configuration instance with conf_dict loaded from CONF_DICT_PATH and _verify_completeness() executed.

Raises

FileNotFoundError
    If the configuration file cannot be found (propagated by load_configuration_dict).
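`load_json` always runs `_verify_completeness()`, which reports problems as printed warnings rather than exceptions; only a missing file raises (FileNotFoundError). The sketch below mimics that split under the assumption that the configuration file is plain JSON readable with `json.load`. `load_and_check` is a hypothetical helper, not tramdag's `load_configuration_dict`, and the key list follows the top-level mandatory keys checked in `_verify_completeness`.

```python
import json
import os
import tempfile

# Top-level keys that _verify_completeness treats as mandatory.
MANDATORY_KEYS = ["experiment_name", "PATHS", "nodes", "data_type", "adj_matrix", "model_names"]


def load_and_check(path):
    """Hypothetical helper: load a JSON config, then warn (not raise) on missing keys."""
    with open(path, encoding="utf-8") as fh:  # a missing file raises FileNotFoundError
        conf = json.load(fh)
    missing = [k for k in MANDATORY_KEYS if k not in conf]
    if missing:
        print(f"[WARNING] Missing mandatory keys in configuration: {missing}")
    return conf, missing


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "configuration.json")
    with open(path, "w", encoding="utf-8") as fh:
        json.dump({"experiment_name": "exp", "nodes": {}}, fh)
    conf, missing = load_and_check(path)  # prints a warning, returns normally

    try:
        load_and_check(os.path.join(tmp, "does_not_exist.json"))
        missing_file_raised = False
    except FileNotFoundError:
        missing_file_raised = True
```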

def update(self):

Reload the latest configuration from disk into this instance.

Parameters

None

Returns

None

Raises

ValueError
    If CONF_DICT_PATH is not set on the instance.

Notes

The current in-memory conf_dict is overwritten by the contents loaded from CONF_DICT_PATH.
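Because `update()` replaces `conf_dict` wholesale, any in-memory edit that has not been saved is lost. This matters in practice because `compute_levels()` calls `self.update()` before doing anything else (as its source shows), so call `save()` first if you have edited `conf_dict` directly. A stand-in class (`ReloadableConfig`, hypothetical, with JSON storage assumed) illustrates the semantics:

```python
import json
import os
import tempfile


class ReloadableConfig:
    """Hypothetical stand-in that mirrors the reload semantics of update()."""

    def __init__(self, conf_dict=None, CONF_DICT_PATH=None):
        self.conf_dict = conf_dict or {}
        self.CONF_DICT_PATH = CONF_DICT_PATH

    def update(self):
        if getattr(self, "CONF_DICT_PATH", None) is None:
            raise ValueError("CONF_DICT_PATH not set: cannot update configuration.")
        with open(self.CONF_DICT_PATH, encoding="utf-8") as fh:
            self.conf_dict = json.load(fh)  # replaces the dict; nothing is merged


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "configuration.json")
    with open(path, "w", encoding="utf-8") as fh:
        json.dump({"experiment_name": "on_disk"}, fh)

    cfg = ReloadableConfig({"experiment_name": "unsaved_edit"}, CONF_DICT_PATH=path)
    cfg.update()  # the unsaved in-memory edit is discarded

try:
    ReloadableConfig().update()
    no_path_raised = False
except ValueError:
    no_path_raised = True
```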

def save(self, CONF_DICT_PATH: str = None):

Persist the current configuration dictionary to disk.

Parameters

CONF_DICT_PATH : str or None, optional
    Target path for the configuration file. If None, uses self.CONF_DICT_PATH. Default is None.

Returns

None

Raises

ValueError
    If neither the function argument nor self.CONF_DICT_PATH provides a valid path.

Notes

The resulting file is written via write_configuration_dict.
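`save()` resolves its target with `CONF_DICT_PATH or self.CONF_DICT_PATH` and does not rebind `self.CONF_DICT_PATH`, so saving to another path acts like "save a copy": a later `update()` still reads the original file. Because `or` tests truthiness, an empty string also falls back to the stored path. The hypothetical `save_config` function below reproduces that resolution; writing JSON with `indent=2` is an assumption, since the format used by `write_configuration_dict` is not shown here.

```python
import json
import os
import tempfile


def save_config(conf_dict, CONF_DICT_PATH=None, stored_path=None):
    """Hypothetical stand-in for save(); stored_path plays the role of self.CONF_DICT_PATH."""
    path = CONF_DICT_PATH or stored_path  # both "" and None fall back
    if path is None:
        raise ValueError("No CONF_DICT_PATH provided to save config.")
    with open(path, "w", encoding="utf-8") as fh:
        json.dump(conf_dict, fh, indent=2)
    return path


with tempfile.TemporaryDirectory() as tmp:
    stored = os.path.join(tmp, "configuration.json")
    copy_path = os.path.join(tmp, "configuration_copy.json")
    used_default = save_config({"seed": 1}, stored_path=stored)
    used_explicit = save_config({"seed": 1}, copy_path, stored_path=stored)  # stored path unchanged
    used_empty = save_config({"seed": 1}, "", stored_path=stored)
    both_exist = os.path.exists(stored) and os.path.exists(copy_path)

try:
    save_config({"seed": 1})
    no_path_raised = False
except ValueError:
    no_path_raised = True
```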

def compute_levels(self, df: pandas.core.frame.DataFrame, write: bool = True):

Infer and update ordinal/categorical levels from data.

For each variable in the configuration's data_type section, this method uses the provided DataFrame to construct a levels dictionary and injects the corresponding "levels" entry into conf_dict["nodes"].

Parameters

df : pandas.DataFrame
    DataFrame used to infer levels for configured variables.
write : bool, optional
    If True and CONF_DICT_PATH is set, the updated configuration is written back to disk. Default is True.

Returns

None

Raises

KeyError
    If conf_dict has no data_type section, or has no nodes section when levels were inferred.

Notes

  • Variables present in levels_dict but not in conf_dict["nodes"] trigger a warning and are skipped.
  • Errors raised while saving (write=True) are caught and reported with an "[ERROR]" message; they are not re-raised.
  • If self.verbose or self.debug is True, a success message is printed when the configuration is saved.
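create_levels_dict is not shown in this section, so the following is only a stdlib sketch of the two steps compute_levels performs: infer levels, then inject them into the nodes dict while warning about unknown variables. It assumes, hypothetically, that levels are the sorted unique observed values and that type strings containing "ordinal" or "categorical" mark discrete variables; the real rules may differ. Columns are plain lists here instead of a DataFrame (with pandas, `sorted(df[var].unique())` would play the same role). infer_levels and inject_levels are hypothetical names.

```python
from typing import Dict, List


def infer_levels(columns: Dict[str, list], data_type: Dict[str, str]) -> Dict[str, list]:
    """Collect sorted unique values for variables whose type is assumed discrete.

    ASSUMPTION: "ordinal"/"categorical" substrings identify discrete variables.
    """
    levels = {}
    for var, dtype in data_type.items():
        if "ordinal" in dtype or "categorical" in dtype:
            levels[var] = sorted(set(columns[var]))
    return levels


def inject_levels(nodes: Dict[str, dict], levels: Dict[str, list]) -> List[str]:
    """Write levels into nodes in place; return the variables that had no node entry."""
    skipped = []
    for var, lv in levels.items():
        if var in nodes:
            nodes[var]["levels"] = lv
        else:
            print(f"[WARNING] Variable '{var}' not found in nodes dict. Cannot add levels.")
            skipped.append(var)
    return skipped


if __name__ == "__main__":
    columns = {"x1": [0.3, 1.2, 0.7], "x2": [2, 0, 2], "x3": ["b", "a", "b"]}
    data_type = {"x1": "continuous", "x2": "ordinal", "x3": "categorical"}
    nodes = {"x1": {}, "x2": {}}  # x3 deliberately has no node entry
    levels = infer_levels(columns, data_type)
    print(levels)  # {'x2': [0, 2], 'x3': ['a', 'b']}
    print(inject_levels(nodes, levels))  # prints the warning for x3, then ['x3']
```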
def plot_dag(self, seed: int = 42, causal_order: bool = False):
    def plot_dag(self, seed: int = 42, causal_order: bool = False):
        """
        Visualize the DAG defined by the configuration.

        Nodes are categorized and colored as:
        - Source nodes (no incoming edges, including isolated nodes): green.
        - Sink nodes (no outgoing edges): red.
        - Intermediate nodes: light blue.

        Parameters
        ----------
        seed : int, optional
            Random seed for layout stability in the spring layout fallback.
            Default is 42.
        causal_order : bool, optional
            If True, attempt to use Graphviz 'dot' layout via
            `networkx.nx_agraph.graphviz_layout` (requires pygraphviz) to
            preserve causal ordering. If False or if Graphviz is unavailable,
            use `spring_layout`. Default is False.

        Returns
        -------
        None

        Raises
        ------
        ValueError
            If `adj_matrix` or `data_type` is missing, if their sizes disagree,
            or if the adjacency matrix fails validation.

        Notes
        -----
        Edge labels are colored by prefix:
        - "ci": blue
        - "ls": red
        - "cs": green
        - other: black
        """
        adj_matrix = self.conf_dict.get("adj_matrix")
        data_type  = self.conf_dict.get("data_type")

        if adj_matrix is None or data_type is None:
            raise ValueError("Configuration must include 'adj_matrix' and 'data_type'.")

        if isinstance(adj_matrix, list):
            adj_matrix = np.array(adj_matrix)

        if not validate_adj_matrix(adj_matrix):
            raise ValueError("Invalid adjacency matrix.")
        if len(data_type) != adj_matrix.shape[0]:
            raise ValueError("data_type must match adjacency matrix size.")

        node_labels = list(data_type.keys())
        G, edge_labels = create_nx_graph(adj_matrix, node_labels)

        sources       = {n for n in G.nodes if G.in_degree(n) == 0}
        sinks         = {n for n in G.nodes if G.out_degree(n) == 0}
        intermediates = set(G.nodes) - sources - sinks

        node_colors = [
            "green" if n in sources
            else "red" if n in sinks
            else "lightblue"
            for n in G.nodes
        ]

        if causal_order:
            try:
                pos = nx.nx_agraph.graphviz_layout(G, prog="dot")
            except (ImportError, nx.NetworkXException):
                pos = nx.spring_layout(G, seed=seed, k=1.5, iterations=100)
        else:
            pos = nx.spring_layout(G, seed=seed, k=1.5, iterations=100)

        plt.figure(figsize=(8, 6))
        nx.draw(
            G, pos,
            with_labels=True,
            node_color=node_colors,
            edge_color="gray",
            node_size=2500,
            arrowsize=20
        )

        for (u, v), lbl in edge_labels.items():
            color = (
                "blue"  if lbl.startswith("ci")
                else "red"   if lbl.startswith("ls")
                else "green" if lbl.startswith("cs")
                else "black"
            )
            nx.draw_networkx_edge_labels(
                G, pos,
                edge_labels={(u, v): lbl},
                font_color=color,
                font_size=12
            )

        legend_items = [
            Patch(facecolor="green",     edgecolor="black", label="Source"),
            Patch(facecolor="red",       edgecolor="black", label="Sink"),
            Patch(facecolor="lightblue", edgecolor="black", label="Intermediate")
        ]
        plt.legend(handles=legend_items, loc="upper right", frameon=True)

        plt.title("TRAM DAG")
        plt.axis("off")
        plt.tight_layout()
        plt.show()

Visualize the DAG defined by the configuration.

Nodes are categorized and colored as:

  • Source nodes (no incoming edges, including isolated nodes): green.
  • Sink nodes (no outgoing edges): red.
  • Intermediate nodes: light blue.

Parameters

seed : int, optional
    Random seed for layout stability in the spring layout fallback. Default is 42.
causal_order : bool, optional
    If True, attempt to use the Graphviz 'dot' layout via networkx.nx_agraph.graphviz_layout (requires pygraphviz) to preserve causal ordering. If False, or if Graphviz is unavailable, use spring_layout. Default is False.

Returns

None

Raises

ValueError
    If adj_matrix or data_type is missing, if their sizes disagree, or if the adjacency matrix fails validation.

Notes

Edge labels are colored by prefix:

  • "ci": blue
  • "ls": red
  • "cs": green
  • other: black
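plot_dag delegates graph construction to create_nx_graph, which is not shown here. The stdlib-only sketch below reproduces just the coloring rules on a plain adjacency matrix. It assumes, hypothetically, that a non-"0" entry at row i, column j denotes an edge from labels[i] to labels[j] and holds that edge's label; classify_nodes and edge_label_color are hypothetical names. Because sources are tested first, a node with no edges at all ends up green.

```python
from typing import Dict, Sequence


def classify_nodes(adj: Sequence[Sequence[str]], labels: Sequence[str]) -> Dict[str, str]:
    """Assign each node the color plot_dag would give it.

    ASSUMPTION: adj[i][j] != "0" means an edge labels[i] -> labels[j].
    """
    n = len(labels)
    in_deg = [sum(adj[i][j] != "0" for i in range(n)) for j in range(n)]
    out_deg = [sum(adj[i][j] != "0" for j in range(n)) for i in range(n)]
    colors = {}
    for k, name in enumerate(labels):
        if in_deg[k] == 0:
            colors[name] = "green"      # source; checked first, so isolated nodes land here
        elif out_deg[k] == 0:
            colors[name] = "red"        # sink
        else:
            colors[name] = "lightblue"  # intermediate
    return colors


def edge_label_color(label: str) -> str:
    """Prefix-based edge-label color, following the rules listed above."""
    if label.startswith("ci"):
        return "blue"
    if label.startswith("ls"):
        return "red"
    if label.startswith("cs"):
        return "green"
    return "black"


if __name__ == "__main__":
    labels = ["x1", "x2", "x3", "x4"]
    adj = [
        ["0", "ls", "cs", "0"],  # x1 -> x2 (ls), x1 -> x3 (cs)
        ["0", "0", "ci", "0"],   # x2 -> x3 (ci)
        ["0", "0", "0", "0"],
        ["0", "0", "0", "0"],    # x4 is isolated
    ]
    print(classify_nodes(adj, labels))
    # {'x1': 'green', 'x2': 'lightblue', 'x3': 'red', 'x4': 'green'}
```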
def setup_configuration(self, experiment_name=None, EXPERIMENT_DIR=None, debug=False, _verify=False):
    def setup_configuration(self, experiment_name=None, EXPERIMENT_DIR=None, debug=False, _verify=False):
        """
        Create or reuse a configuration for an experiment.

        This method behaves differently depending on how it is called:

        1. Class call (e.g. `TramDagConfig.setup_configuration(...)`):
           - Creates or loads a configuration at the resolved path.
           - Returns a new `TramDagConfig` instance.

        2. Instance call (e.g. `cfg.setup_configuration(...)`):
           - Updates `self.conf_dict` and `self.CONF_DICT_PATH` in place.
           - Optionally verifies completeness.
           - Returns None.

        Parameters
        ----------
        experiment_name : str or None, optional
            Name of the experiment. If None, defaults to "experiment_1".
        EXPERIMENT_DIR : str or None, optional
            Directory for the experiment. If None, defaults to
            `<cwd>/<experiment_name>`.
        debug : bool, optional
            Passed as `debug` to the constructor on class calls; ignored on
            instance calls. Default is False.
        _verify : bool, optional
            If True, call `_verify_completeness()` after loading. Default is False.

        Returns
        -------
        TramDagConfig or None
            - A new instance when called on the class.
            - None when called on an existing instance.

        Notes
        -----
        - `EXPERIMENT_DIR` is created if it does not exist.
        - A configuration file named "configuration.json" is created in
          `EXPERIMENT_DIR` if it does not exist yet; otherwise the existing
          file is reused.
        - Underlying creation uses `create_and_write_new_configuration_dict`
          and `load_configuration_dict`.
        """
        is_class_call = isinstance(self, type)
        cls = self if is_class_call else self.__class__

        if experiment_name is None:
            experiment_name = "experiment_1"
        if EXPERIMENT_DIR is None:
            EXPERIMENT_DIR = os.path.join(os.getcwd(), experiment_name)

        CONF_DICT_PATH = os.path.join(EXPERIMENT_DIR, "configuration.json")
        DATA_PATH = EXPERIMENT_DIR

        os.makedirs(EXPERIMENT_DIR, exist_ok=True)

        if os.path.exists(CONF_DICT_PATH):
            print(f"Configuration already exists: {CONF_DICT_PATH}")
        else:
            _ = create_and_write_new_configuration_dict(
                experiment_name, CONF_DICT_PATH, EXPERIMENT_DIR, DATA_PATH, None
            )
            print(f"Created new configuration file at {CONF_DICT_PATH}")

        conf = load_configuration_dict(CONF_DICT_PATH)

        if is_class_call:
            return cls(conf, CONF_DICT_PATH=CONF_DICT_PATH, debug=debug, _verify=_verify)
        else:
            self.conf_dict = conf
            self.CONF_DICT_PATH = CONF_DICT_PATH
            if _verify:
                self._verify_completeness()

Create or reuse a configuration for an experiment.

This method behaves differently depending on how it is called:

  1. Class call (e.g. TramDagConfig.setup_configuration(...)):
      • Creates or loads a configuration at the resolved path.
      • Returns a new TramDagConfig instance.
  2. Instance call (e.g. cfg.setup_configuration(...)):
      • Updates self.conf_dict and self.CONF_DICT_PATH in place.
      • Optionally verifies completeness.
      • Returns None.

Parameters

experiment_name : str or None, optional
    Name of the experiment. If None, defaults to "experiment_1".
EXPERIMENT_DIR : str or None, optional
    Directory for the experiment. If None, defaults to <cwd>/<experiment_name>.
debug : bool, optional
    Passed as debug to the constructor on class calls; ignored on instance calls. Default is False.
_verify : bool, optional
    If True, call _verify_completeness() after loading. Default is False.

Returns

TramDagConfig or None
  • A new instance when called on the class.
  • None when called on an existing instance.

Notes

  • EXPERIMENT_DIR is created if it does not exist.
  • A configuration file named "configuration.json" is created in EXPERIMENT_DIR if it does not exist yet; otherwise the existing file is reused. A message is printed in both cases.
  • Underlying creation uses create_and_write_new_configuration_dict and load_configuration_dict.
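The isinstance(self, type) test distinguishes the two call styles only if looking the method up on the class binds the class itself as self. With a plain function that does not happen: a class-level call such as Cls.method(experiment_name="x") raises TypeError for the missing self argument. This section does not show how TramDagConfig arranges the binding; one common way is a descriptor. The sketch below uses a hypothetical hybridmethod descriptor and a hypothetical DemoConfig class that reproduces only the path defaults of setup_configuration, with no file I/O.

```python
import os
import types


class hybridmethod:
    """Hypothetical descriptor: binds the class on class access, the instance otherwise."""

    def __init__(self, func):
        self.func = func

    def __get__(self, obj, objtype=None):
        # obj is None when the attribute is looked up on the class itself
        return types.MethodType(self.func, objtype if obj is None else obj)


class DemoConfig:
    """Hypothetical stand-in for TramDagConfig; no files are touched."""

    def __init__(self, conf_dict=None, CONF_DICT_PATH=None):
        self.conf_dict = conf_dict or {}
        self.CONF_DICT_PATH = CONF_DICT_PATH

    @hybridmethod
    def setup_configuration(self, experiment_name=None, EXPERIMENT_DIR=None):
        is_class_call = isinstance(self, type)  # same check as the listing above
        if experiment_name is None:
            experiment_name = "experiment_1"
        if EXPERIMENT_DIR is None:
            EXPERIMENT_DIR = os.path.join(os.getcwd(), experiment_name)
        path = os.path.join(EXPERIMENT_DIR, "configuration.json")
        conf = {"experiment_name": experiment_name}
        if is_class_call:
            return self(conf, CONF_DICT_PATH=path)  # self is the class here
        self.conf_dict = conf  # instance call: update in place
        self.CONF_DICT_PATH = path
        return None


if __name__ == "__main__":
    cfg = DemoConfig.setup_configuration(experiment_name="demo", EXPERIMENT_DIR="runs")
    print(type(cfg).__name__, cfg.CONF_DICT_PATH)
    print(cfg.setup_configuration(experiment_name="other", EXPERIMENT_DIR="runs2"))  # None
```

A descriptor keeps a single method body for both call styles, which matches the branching on is_class_call seen in the listings; a separate classmethod plus instance method would be the more conventional alternative.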
def set_data_type(self, data_type: dict, CONF_DICT_PATH: str = None) -> None:
    def set_data_type(self, data_type: dict, CONF_DICT_PATH: str = None) -> None:
        """
        Update or write the `data_type` section of a configuration file.

        Supports both class-level and instance-level usage:

        - Class call:
          - Requires the `CONF_DICT_PATH` argument.
          - Reads the file if it exists, or starts from an empty dict.
          - Writes the updated configuration to `CONF_DICT_PATH`.

        - Instance call:
          - Uses `self.CONF_DICT_PATH` if available; if no path is provided
            and none is stored, defaults to `<cwd>/configuration.json`.
          - Updates `self.conf_dict` and `self.CONF_DICT_PATH` after writing.

        Parameters
        ----------
        data_type : dict
            Mapping `{variable_name: type_spec}`, where `type_spec` encodes
            the modeling type (e.g. continuous or ordinal).
        CONF_DICT_PATH : str or None, optional
            Path to the configuration file. Must be provided for class calls.
            For instance calls, defaults as described above.

        Returns
        -------
        None

        Raises
        ------
        ValueError
            If `CONF_DICT_PATH` is missing when called on the class.

        Notes
        -----
        - Variable names are validated via `validate_variable_names`.
        - Data type values are validated via `validate_data_types`.
        - Validation and I/O errors (including invalid data types) are caught
          and reported as "Failed to update configuration: ..." rather than
          raised; if validation fails, the file is not written.
        - A textual summary of modeling settings is printed via
          `print_data_type_modeling_setting`, if possible.
        """
        is_class_call = isinstance(self, type)
        cls = self if is_class_call else self.__class__

        # resolve path
        if CONF_DICT_PATH is None:
            if not is_class_call and getattr(self, "CONF_DICT_PATH", None):
                CONF_DICT_PATH = self.CONF_DICT_PATH
            elif not is_class_call:
                CONF_DICT_PATH = os.path.join(os.getcwd(), "configuration.json")
            else:
                raise ValueError("CONF_DICT_PATH must be provided when called on the class.")

        try:
            # load existing or create empty configuration
            configuration_dict = (
                load_configuration_dict(CONF_DICT_PATH)
                if os.path.exists(CONF_DICT_PATH)
                else {}
            )

            validate_variable_names(data_type.keys())
            if not validate_data_types(data_type):
                raise ValueError("Invalid data types in the provided dictionary.")

            configuration_dict["data_type"] = data_type
            write_configuration_dict(configuration_dict, CONF_DICT_PATH)

            # safe printing
            try:
                print_data_type_modeling_setting(data_type or {})
            except Exception as e:
                print(f"[WARNING] Could not print data type modeling settings: {e}")

            if not is_class_call:
                self.conf_dict = configuration_dict
                self.CONF_DICT_PATH = CONF_DICT_PATH

        except Exception as e:
            print(f"Failed to update configuration: {e}")
        else:
            print(f"Configuration updated successfully at {CONF_DICT_PATH}.")

Update or write the data_type section of a configuration file.

Supports both class-level and instance-level usage:

  • Class call:
      • Requires the CONF_DICT_PATH argument.
      • Reads the file if it exists, or starts from an empty dict.
      • Writes the updated configuration to CONF_DICT_PATH.
  • Instance call:
      • Uses self.CONF_DICT_PATH if available; if no path is provided and none is stored, defaults to <cwd>/configuration.json.
      • Updates self.conf_dict and self.CONF_DICT_PATH after writing.

Parameters

data_type : dict
    Mapping {variable_name: type_spec}, where type_spec encodes the modeling type (e.g. continuous or ordinal).
CONF_DICT_PATH : str or None, optional
    Path to the configuration file. Must be provided for class calls. For instance calls, defaults as described above.

Returns

None

Raises

ValueError
    If CONF_DICT_PATH is missing when called on the class.

Notes

  • Variable names are validated via validate_variable_names.
  • Data type values are validated via validate_data_types.
  • Validation and I/O errors (including invalid data types) are caught and reported as "Failed to update configuration: ..." rather than raised; if validation fails, the file is not written.
  • A textual summary of modeling settings is printed via print_data_type_modeling_setting, if possible.
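The read-modify-write step in set_data_type replaces only the data_type key and keeps every other section of the file. The stdlib sketch below shows that step together with the path resolution order described above. It assumes, hypothetically, that the configuration file is plain JSON (load_configuration_dict and write_configuration_dict are not shown here); resolve_data_type_path and write_section are hypothetical names. Unlike set_data_type, which catches errors and prints a message, this sketch lets exceptions propagate.

```python
import json
import os
import tempfile
from typing import Any, Optional


def resolve_data_type_path(arg_path: Optional[str], stored_path: Optional[str],
                           is_class_call: bool) -> str:
    """Resolution order: explicit argument, stored instance path, <cwd> default (instances only)."""
    if arg_path is not None:
        return arg_path
    if not is_class_call and stored_path:
        return stored_path
    if not is_class_call:
        return os.path.join(os.getcwd(), "configuration.json")
    raise ValueError("CONF_DICT_PATH must be provided when called on the class.")


def write_section(path: str, key: str, value: Any) -> dict:
    """Load the JSON file if present (else start from {}), replace one section, write back."""
    conf: dict = {}
    if os.path.exists(path):
        with open(path, encoding="utf-8") as fh:
            conf = json.load(fh)
    conf[key] = value  # other top-level sections are preserved
    with open(path, "w", encoding="utf-8") as fh:
        json.dump(conf, fh, indent=4)
    return conf


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        path = resolve_data_type_path(os.path.join(tmp, "configuration.json"), None, True)
        write_section(path, "experiment_name", "demo")
        print(write_section(path, "data_type", {"x1": "continuous"}))
        # {'experiment_name': 'demo', 'data_type': {'x1': 'continuous'}}
```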
def set_meta_adj_matrix(self, CONF_DICT_PATH: str = None, seed: int = 5):
    def set_meta_adj_matrix(self, CONF_DICT_PATH: str = None, seed: int = 5):
        """
        Launch the interactive editor to set or modify the adjacency matrix.

        This method:

        1. Resolves the configuration path either from the argument or, for
           instances, from `self.CONF_DICT_PATH`.
        2. Invokes `interactive_adj_matrix` to edit the adjacency matrix.
        3. For instances, reloads the updated configuration into `self.conf_dict`.

        Parameters
        ----------
        CONF_DICT_PATH : str or None, optional
            Path to the configuration file. Must be provided when called
            on the class. For instance calls, defaults to `self.CONF_DICT_PATH`.
        seed : int, optional
            Random seed for any layout or stochastic behavior in the interactive
            editor. Default is 5.

        Returns
        -------
        None

        Raises
        ------
        ValueError
            If `CONF_DICT_PATH` is not provided and cannot be inferred
            (e.g. in a class call without path).

        Notes
        -----
        For instance calls, `self.update()` is called first to ensure the
        in-memory config is in sync with the file before the editor is launched.
        """
        is_class_call = isinstance(self, type)
        if not is_class_call:
            # sync the in-memory config with the file before editing;
            # skipped on class calls, where there is no instance to update
            self.update()

        # resolve path
        if CONF_DICT_PATH is None:
            if not is_class_call and getattr(self, "CONF_DICT_PATH", None):
                CONF_DICT_PATH = self.CONF_DICT_PATH
            else:
                raise ValueError("CONF_DICT_PATH must be provided when called on the class.")

        # launch interactive editor
        interactive_adj_matrix(CONF_DICT_PATH, seed=seed)

        # reload config if instance
        if not is_class_call:
            self.conf_dict = load_configuration_dict(CONF_DICT_PATH)
            self.CONF_DICT_PATH = CONF_DICT_PATH

Launch the interactive editor to set or modify the adjacency matrix.

This method:

  1. Resolves the configuration path either from the argument or, for instances, from self.CONF_DICT_PATH.
  2. Invokes interactive_adj_matrix to edit the adjacency matrix.
  3. For instances, reloads the updated configuration into self.conf_dict.

Parameters

CONF_DICT_PATH : str or None, optional
    Path to the configuration file. Must be provided when called on the class. For instance calls, defaults to self.CONF_DICT_PATH.
seed : int, optional
    Random seed for any layout or stochastic behavior in the interactive editor. Default is 5.

Returns

None

Raises

ValueError If CONF_DICT_PATH is not provided and cannot be inferred (e.g. in a class call without path).

Notes

For instance calls, self.update() is called first to ensure the in-memory config is in sync with the file before the editor is launched.

def set_tramdag_nn_models(self, CONF_DICT_PATH: str = None):
    def set_tramdag_nn_models(self, CONF_DICT_PATH: str = None):
        """
        Launch the interactive editor to set TRAM-DAG neural network model names.

        Depending on call context:

        - Class call:
          - Requires the `CONF_DICT_PATH` argument.
          - Returns nothing and does not modify any instance.

        - Instance call:
          - Resolves `CONF_DICT_PATH` from the argument or `self.CONF_DICT_PATH`.
          - Updates `self.conf_dict` and `self.CONF_DICT_PATH` if the editor
            returns an updated configuration.

        Parameters
        ----------
        CONF_DICT_PATH : str or None, optional
            Path to the configuration file. Must be provided when called on
            the class. For instance calls, defaults to `self.CONF_DICT_PATH`.

        Returns
        -------
        None

        Raises
        ------
        ValueError
            If `CONF_DICT_PATH` is not provided and cannot be inferred
            (e.g. in a class call without path).

        Notes
        -----
        The interactive editor is invoked via `interactive_nn_names_matrix`.
        If it returns `None`, the instance configuration is left unchanged.
        """
        is_class_call = isinstance(self, type)
        if CONF_DICT_PATH is None:
            if not is_class_call and getattr(self, "CONF_DICT_PATH", None):
                CONF_DICT_PATH = self.CONF_DICT_PATH
            else:
                raise ValueError("CONF_DICT_PATH must be provided when called on the class.")

        updated_conf = interactive_nn_names_matrix(CONF_DICT_PATH)
        if updated_conf is not None and not is_class_call:
            self.conf_dict = updated_conf
            self.CONF_DICT_PATH = CONF_DICT_PATH

Launch the interactive editor to set TRAM-DAG neural network model names.

Depending on call context:

  • Class call:
      • Requires the CONF_DICT_PATH argument.
      • Returns nothing and does not modify any instance.
  • Instance call:
      • Resolves CONF_DICT_PATH from the argument or self.CONF_DICT_PATH.
      • Updates self.conf_dict and self.CONF_DICT_PATH if the editor returns an updated configuration.

Parameters

CONF_DICT_PATH : str or None, optional Path to the configuration file. Must be provided when called on the class. For instance calls, defaults to self.CONF_DICT_PATH.

Returns

None

Raises

ValueError If CONF_DICT_PATH is not provided and cannot be inferred (e.g. in a class call without path).

Notes

The interactive editor is invoked via interactive_nn_names_matrix. If it returns None, the instance configuration is left unchanged.