tramdag.TramDagConfig
Copyright 2025 Zurich University of Applied Sciences (ZHAW) Pascal Buehler, Beate Sick, Oliver Duerr
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
"""
Copyright 2025 Zurich University of Applied Sciences (ZHAW)
Pascal Buehler, Beate Sick, Oliver Duerr

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import matplotlib.pyplot as plt
import numpy as np
import networkx as nx
from matplotlib.patches import Patch
import pandas as pd
import os

from .utils.configuration import *


# rename set_meta_adj_matrix
class TramDagConfig:
    """
    Configuration manager for TRAM-DAG experiments.

    This class encapsulates:

    - The experiment configuration dictionary (`conf_dict`).
    - Its backing file path (`CONF_DICT_PATH`).
    - Utilities to load, validate, modify, and persist configuration.
    - DAG visualization and interactive editing helpers.

    Typical usage
    -------------
    - Load existing configuration from disk via `TramDagConfig.load_json`.
    - Or create/reuse experiment setup via `TramDagConfig().setup_configuration`.
    - Update sections such as `data_type`, adjacency matrix, and neural network
      model names using the provided methods.
    """

    def __init__(self, conf_dict: dict = None, CONF_DICT_PATH: str = None, _verify: bool = False, **kwargs):
        """
        Initialize a TramDagConfig instance.

        Parameters
        ----------
        conf_dict : dict or None, optional
            Configuration dictionary. If None, an empty dict is used and can
            be populated later. Default is None.
        CONF_DICT_PATH : str or None, optional
            Path to the configuration file on disk. Default is None.
        _verify : bool, optional
            If True, run `_verify_completeness()` after initialization.
            Default is False.
        **kwargs
            Additional attributes to be set on the instance. Keys "conf_dict"
            and "CONF_DICT_PATH" are forbidden and raise a ValueError.

        Raises
        ------
        ValueError
            If any key in `kwargs` is "conf_dict" or "CONF_DICT_PATH".

        Notes
        -----
        By default, `debug` and `verbose` are set to False. They can be
        overridden via `kwargs`.
        """

        self.debug = False
        self.verbose = False

        for key, value in kwargs.items():
            if key in ['conf_dict', 'CONF_DICT_PATH']:
                raise ValueError(f"Cannot override '{key}' via kwargs.")
            setattr(self, key, value)

        self.conf_dict = conf_dict or {}
        self.CONF_DICT_PATH = CONF_DICT_PATH

        # verification
        if _verify:
            self._verify_completeness()

    @classmethod
    def load_json(cls, CONF_DICT_PATH: str, debug: bool = False):
        """
        Load a configuration from a JSON file and construct a TramDagConfig.

        Parameters
        ----------
        CONF_DICT_PATH : str
            Path to the configuration JSON file.
        debug : bool, optional
            If True, initialize the instance with `debug=True`. Default is False.

        Returns
        -------
        TramDagConfig
            Newly created configuration instance with `conf_dict` loaded from
            `CONF_DICT_PATH` and `_verify_completeness()` executed.

        Raises
        ------
        FileNotFoundError
            If the configuration file cannot be found (propagated by
            `load_configuration_dict`).
        """

        conf = load_configuration_dict(CONF_DICT_PATH)
        return cls(conf, CONF_DICT_PATH=CONF_DICT_PATH, debug=debug, _verify=True)

    def update(self):
        """
        Reload the latest configuration from disk into this instance.

        Parameters
        ----------
        None

        Returns
        -------
        None

        Raises
        ------
        ValueError
            If `CONF_DICT_PATH` is not set on the instance.

        Notes
        -----
        The current in-memory `conf_dict` is overwritten by the contents
        loaded from `CONF_DICT_PATH`.
        """

        if not hasattr(self, "CONF_DICT_PATH") or self.CONF_DICT_PATH is None:
            raise ValueError("CONF_DICT_PATH not set; cannot update configuration.")

        self.conf_dict = load_configuration_dict(self.CONF_DICT_PATH)

    def save(self, CONF_DICT_PATH: str = None):
        """
        Persist the current configuration dictionary to disk.

        Parameters
        ----------
        CONF_DICT_PATH : str or None, optional
            Target path for the configuration file. If None, uses
            `self.CONF_DICT_PATH`. Default is None.

        Returns
        -------
        None

        Raises
        ------
        ValueError
            If neither the function argument nor `self.CONF_DICT_PATH`
            provides a valid path.

        Notes
        -----
        The resulting file is written via `write_configuration_dict`.
        """
        path = CONF_DICT_PATH or self.CONF_DICT_PATH
        if path is None:
            raise ValueError("No CONF_DICT_PATH provided to save config.")
        write_configuration_dict(self.conf_dict, path)

    def _verify_completeness(self):
        """
        Check that the configuration is structurally complete and consistent.

        The following checks are performed:

        1. Top-level mandatory keys:
           - "experiment_name"
           - "PATHS"
           - "nodes"
           - "data_type"
           - "adj_matrix"
           - "model_names"

        2. Per-node mandatory keys:
           - "data_type"
           - "node_type"
           - "parents"
           - "parents_datatype"
           - "transformation_terms_in_h()"
           - "transformation_term_nn_models_in_h()"

        3. Ordinal / categorical levels:
           - All ordinal variables must have a corresponding "levels" entry
             under `conf_dict["nodes"][var]`.

        4. Experiment name:
           - Must be non-empty.

        5. Adjacency matrix:
           - Must be valid under `validate_adj_matrix`.

        Parameters
        ----------
        None

        Returns
        -------
        None

        Notes
        -----
        - Missing or invalid components are reported via printed warnings.
        - Detailed debug messages are printed when `self.debug=True`.
        """
        mandatory_keys = ["experiment_name", "PATHS", "nodes", "data_type", "adj_matrix", "model_names"]
        optional_keys = ["date_of_creation", "seed"]

        # ---- 1. Check that mandatory top-level keys exist
        missing = [k for k in mandatory_keys if k not in self.conf_dict]
        if missing:
            print(f"[WARNING] Missing mandatory keys in configuration: {missing}"
                  "\n Please add them to the configuration dict and reload.")

        # ---- 2. Check that mandatory keys are present in each node dict
        mandatory_keys_nodes = ['data_type', 'node_type', 'parents', 'parents_datatype',
                                'transformation_terms_in_h()', 'transformation_term_nn_models_in_h()']
        optional_keys_nodes = ["levels"]
        for node, node_dict in self.conf_dict.get("nodes", {}).items():
            # check missing mandatory keys
            missing_node_keys = [k for k in mandatory_keys_nodes if k not in node_dict]
            if missing_node_keys:
                print(f"[WARNING] Node '{node}' is missing mandatory keys: {missing_node_keys}")

        # ---- 3. Check that levels exist for all ordinal variables
        if self._verify_levels_dict():
            if self.debug:
                print("[DEBUG] levels are present for all ordinal variables in configuration dict.")
        else:
            print("[WARNING] levels are missing for some ordinal variables in configuration dict. THIS will FAIL in model training later!\n"
                  " Please provide levels manually to config and reload or compute levels from data using the method compute_levels().\n"
                  " e.g. cfg.compute_levels(train_df) # computes levels from training data and writes to cfg")

        # ---- 4. Check the experiment name
        if self._verify_experiment_name():
            if self.debug:
                print("[DEBUG] experiment_name is valid in configuration dict.")
        else:
            print("[WARNING] experiment_name is missing or empty in configuration dict.")

        # ---- 5. Check the adjacency matrix
        if self._verify_adj_matrix():
            if self.debug:
                print("[DEBUG] adj_matrix is valid in configuration dict.")
        else:
            print("[WARNING] adj_matrix is missing or invalid in configuration dict.")

    def _verify_levels_dict(self):
        """
        Verify that all ordinal variables have levels specified in the config.

        Parameters
        ----------
        None

        Returns
        -------
        bool
            True if all variables declared as ordinal in ``conf_dict["data_type"]``
            have a "levels" entry in ``conf_dict["nodes"][var]``.
            False otherwise.

        Notes
        -----
        This method does not modify the configuration; it only checks presence
        of level information.
        """
        data_type = self.conf_dict.get('data_type', {})
        nodes = self.conf_dict.get('nodes', {})
        for var, dtype in data_type.items():
            if 'ordinal' in dtype:
                if var not in nodes or 'levels' not in nodes[var]:
                    return False
        return True

    def _verify_experiment_name(self):
        """
        Check whether the experiment name in the configuration is valid.

        Parameters
        ----------
        None

        Returns
        -------
        bool
            True if ``conf_dict["experiment_name"]`` exists and is non-empty
            after stripping whitespace. False otherwise.
        """
        experiment_name = self.conf_dict.get("experiment_name")
        if experiment_name is None or str(experiment_name).strip() == "":
            return False
        return True

    def _verify_adj_matrix(self):
        """
        Validate the adjacency matrix stored in the configuration.

        Parameters
        ----------
        None

        Returns
        -------
        bool
            True if the adjacency matrix passes `validate_adj_matrix`, False otherwise.

        Notes
        -----
        If the adjacency matrix is stored as a list, it is converted to a
        NumPy array before validation.
        """
        adj_matrix = self.conf_dict.get('adj_matrix')
        if adj_matrix is None:
            # a missing matrix is already reported by the mandatory-key check
            return False
        if isinstance(adj_matrix, list):
            adj_matrix = np.array(adj_matrix)
        return bool(validate_adj_matrix(adj_matrix))

    def compute_levels(self, df: pd.DataFrame, write: bool = True):
        """
        Infer and update ordinal/categorical levels from data.

        The configuration is first reloaded from `CONF_DICT_PATH` via
        `update()`, so unsaved in-memory edits are discarded. Then, for each
        variable in the configuration's `data_type` section, this method uses
        the provided DataFrame to construct a levels dictionary and injects the
        corresponding "levels" entry into `conf_dict["nodes"]`.

        Parameters
        ----------
        df : pandas.DataFrame
            DataFrame used to infer levels for configured variables.
        write : bool, optional
            If True and `CONF_DICT_PATH` is set, the updated configuration is
            written back to disk. Default is True.

        Returns
        -------
        None

        Raises
        ------
        ValueError
            If `CONF_DICT_PATH` is not set (raised by `update()`).

        Notes
        -----
        - Variables present in `levels_dict` but not in `conf_dict["nodes"]`
          trigger a warning and are skipped.
        - Errors raised while saving are caught and reported as an "[ERROR]"
          message instead of being propagated.
        - If `self.verbose` or `self.debug` is True, a success message is printed
          when the configuration is saved.
        """
        self.update()
        levels_dict = create_levels_dict(df, self.conf_dict['data_type'])

        # update nodes dict with levels
        for var, levels in levels_dict.items():
            if var in self.conf_dict['nodes']:
                self.conf_dict['nodes'][var]['levels'] = levels
            else:
                print(f"[WARNING] Variable '{var}' not found in nodes dict. Cannot add levels.")

        if write and self.CONF_DICT_PATH is not None:
            try:
                self.save(self.CONF_DICT_PATH)
                if self.verbose or self.debug:
                    print(f'[INFO] Configuration with updated levels saved to {self.CONF_DICT_PATH}')
            except Exception as e:
                print(f'[ERROR] Failed to save configuration: {e}')

    def plot_dag(self, seed: int = 42, causal_order: bool = False):
        """
        Visualize the DAG defined by the configuration.

        Nodes are categorized and colored as:
        - Source nodes (no incoming edges): green.
        - Sink nodes (no outgoing edges): red.
        - Intermediate nodes: light blue.

        Parameters
        ----------
        seed : int, optional
            Random seed for layout stability in the spring layout fallback.
            Default is 42.
        causal_order : bool, optional
            If True, attempt to use Graphviz 'dot' layout via
            `networkx.nx_agraph.graphviz_layout` to preserve causal ordering.
            If False or if Graphviz is unavailable, use `spring_layout`.
            Default is False.

        Returns
        -------
        None

        Raises
        ------
        ValueError
            If `adj_matrix` or `data_type` is missing or inconsistent with each other,
            or if the adjacency matrix fails validation.

        Notes
        -----
        Edge labels are colored by prefix:
        - "ci": blue
        - "ls": red
        - "cs": green
        - other: black
        """
        adj_matrix = self.conf_dict.get("adj_matrix")
        data_type = self.conf_dict.get("data_type")

        if adj_matrix is None or data_type is None:
            raise ValueError("Configuration must include 'adj_matrix' and 'data_type'.")

        if isinstance(adj_matrix, list):
            adj_matrix = np.array(adj_matrix)

        if not validate_adj_matrix(adj_matrix):
            raise ValueError("Invalid adjacency matrix.")
        if len(data_type) != adj_matrix.shape[0]:
            raise ValueError("data_type must match adjacency matrix size.")

        node_labels = list(data_type.keys())
        G, edge_labels = create_nx_graph(adj_matrix, node_labels)

        # sources have no incoming edges, sinks no outgoing edges;
        # every other node is drawn as intermediate
        sources = {n for n in G.nodes if G.in_degree(n) == 0}
        sinks = {n for n in G.nodes if G.out_degree(n) == 0}

        node_colors = [
            "green" if n in sources
            else "red" if n in sinks
            else "lightblue"
            for n in G.nodes
        ]

        if causal_order:
            try:
                pos = nx.nx_agraph.graphviz_layout(G, prog="dot")
            except (ImportError, nx.NetworkXException):
                pos = nx.spring_layout(G, seed=seed, k=1.5, iterations=100)
        else:
            pos = nx.spring_layout(G, seed=seed, k=1.5, iterations=100)

        plt.figure(figsize=(8, 6))
        nx.draw(
            G, pos,
            with_labels=True,
            node_color=node_colors,
            edge_color="gray",
            node_size=2500,
            arrowsize=20
        )

        for (u, v), lbl in edge_labels.items():
            color = (
                "blue" if lbl.startswith("ci")
                else "red" if lbl.startswith("ls")
                else "green" if lbl.startswith("cs")
                else "black"
            )
            nx.draw_networkx_edge_labels(
                G, pos,
                edge_labels={(u, v): lbl},
                font_color=color,
                font_size=12
            )

        legend_items = [
            Patch(facecolor="green", edgecolor="black", label="Source"),
            Patch(facecolor="red", edgecolor="black", label="Sink"),
            Patch(facecolor="lightblue", edgecolor="black", label="Intermediate")
        ]
        plt.legend(handles=legend_items, loc="upper right", frameon=True)

        plt.title("TRAM DAG")
        plt.axis("off")
        plt.tight_layout()
        plt.show()

    def setup_configuration(self, experiment_name=None, EXPERIMENT_DIR=None, debug=False, _verify=False):
        """
        Create or reuse a configuration for an experiment.

        This method behaves differently depending on how it is called:

        1. Class call, i.e. with the class passed explicitly as the first
           argument (e.g. `TramDagConfig.setup_configuration(TramDagConfig, ...)`):
           - Creates or loads a configuration at the resolved path.
           - Returns a new `TramDagConfig` instance.

        2. Instance call (e.g. `cfg.setup_configuration(...)`):
           - Updates `self.conf_dict` and `self.CONF_DICT_PATH` in place.
           - Optionally verifies completeness.
           - Returns None.

        Parameters
        ----------
        experiment_name : str or None, optional
            Name of the experiment. If None, defaults to "experiment_1".
        EXPERIMENT_DIR : str or None, optional
            Directory for the experiment. If None, defaults to
            `<cwd>/<experiment_name>`.
        debug : bool, optional
            If True, initialize / update with `debug=True`. Default is False.
        _verify : bool, optional
            If True, call `_verify_completeness()` after loading. Default is False.

        Returns
        -------
        TramDagConfig or None
            - A new instance when called on the class.
            - None when called on an existing instance.

        Notes
        -----
        - A configuration file named "configuration.json" is created if it does
          not exist yet.
        - Underlying creation uses `create_and_write_new_configuration_dict`
          and `load_configuration_dict`.
        """
        is_class_call = isinstance(self, type)
        cls = self if is_class_call else self.__class__

        if experiment_name is None:
            experiment_name = "experiment_1"
        if EXPERIMENT_DIR is None:
            EXPERIMENT_DIR = os.path.join(os.getcwd(), experiment_name)

        CONF_DICT_PATH = os.path.join(EXPERIMENT_DIR, "configuration.json")
        DATA_PATH = EXPERIMENT_DIR

        os.makedirs(EXPERIMENT_DIR, exist_ok=True)

        if os.path.exists(CONF_DICT_PATH):
            print(f"Configuration already exists: {CONF_DICT_PATH}")
        else:
            _ = create_and_write_new_configuration_dict(
                experiment_name, CONF_DICT_PATH, EXPERIMENT_DIR, DATA_PATH, None
            )
            print(f"Created new configuration file at {CONF_DICT_PATH}")

        conf = load_configuration_dict(CONF_DICT_PATH)

        if is_class_call:
            return cls(conf, CONF_DICT_PATH=CONF_DICT_PATH, debug=debug, _verify=_verify)
        else:
            self.conf_dict = conf
            self.CONF_DICT_PATH = CONF_DICT_PATH
            if debug:
                # honor debug=True for instance calls as well
                self.debug = True
            if _verify:
                self._verify_completeness()

    def set_data_type(self, data_type: dict, CONF_DICT_PATH: str = None) -> None:
        """
        Update or write the `data_type` section of a configuration file.

        Supports both class-level and instance-level usage:

        - Class call:
          - Requires `CONF_DICT_PATH` argument.
          - Reads the file if it exists, or starts from an empty dict.
          - Writes updated configuration to `CONF_DICT_PATH`.

        - Instance call:
          - Uses `self.CONF_DICT_PATH` if available, otherwise defaults to
            `<cwd>/configuration.json` if no path is provided.
          - Updates `self.conf_dict` and `self.CONF_DICT_PATH` after writing.

        Parameters
        ----------
        data_type : dict
            Mapping `{variable_name: type_spec}`, where `type_spec` encodes
            modeling types (e.g. continuous, ordinal, etc.).
        CONF_DICT_PATH : str or None, optional
            Path to the configuration file. Must be provided for class calls.
            For instance calls, defaults as described above.

        Returns
        -------
        None

        Raises
        ------
        ValueError
            If `CONF_DICT_PATH` is missing when called on the class.

        Notes
        -----
        - Variable names are validated via `validate_variable_names`.
        - Data type values are validated via `validate_data_types`.
        - Validation and I/O errors are caught and reported as
          "Failed to update configuration: ..." instead of being raised.
        - A textual summary of modeling settings is printed via
          `print_data_type_modeling_setting`, if possible.
        """
        is_class_call = isinstance(self, type)

        # resolve path
        if CONF_DICT_PATH is None:
            if not is_class_call and getattr(self, "CONF_DICT_PATH", None):
                CONF_DICT_PATH = self.CONF_DICT_PATH
            elif not is_class_call:
                CONF_DICT_PATH = os.path.join(os.getcwd(), "configuration.json")
            else:
                raise ValueError("CONF_DICT_PATH must be provided when called on the class.")

        try:
            # load existing or create empty configuration
            configuration_dict = (
                load_configuration_dict(CONF_DICT_PATH)
                if os.path.exists(CONF_DICT_PATH)
                else {}
            )

            validate_variable_names(data_type.keys())
            if not validate_data_types(data_type):
                raise ValueError("Invalid data types in the provided dictionary.")

            configuration_dict["data_type"] = data_type
            write_configuration_dict(configuration_dict, CONF_DICT_PATH)

            # safe printing
            try:
                print_data_type_modeling_setting(data_type or {})
            except Exception as e:
                print(f"[WARNING] Could not print data type modeling settings: {e}")

            if not is_class_call:
                self.conf_dict = configuration_dict
                self.CONF_DICT_PATH = CONF_DICT_PATH

        except Exception as e:
            print(f"Failed to update configuration: {e}")
        else:
            print(f"Configuration updated successfully at {CONF_DICT_PATH}.")

    def set_meta_adj_matrix(self, CONF_DICT_PATH: str = None, seed: int = 5):
        """
        Launch the interactive editor to set or modify the adjacency matrix.

        This method:

        1. Resolves the configuration path either from the argument or, for
           instances, from `self.CONF_DICT_PATH`.
        2. Invokes `interactive_adj_matrix` to edit the adjacency matrix.
        3. For instances, reloads the updated configuration into `self.conf_dict`.

        Parameters
        ----------
        CONF_DICT_PATH : str or None, optional
            Path to the configuration file. Must be provided when called
            on the class. For instance calls, defaults to `self.CONF_DICT_PATH`.
        seed : int, optional
            Random seed for any layout or stochastic behavior in the interactive
            editor. Default is 5.

        Returns
        -------
        None

        Raises
        ------
        ValueError
            If `CONF_DICT_PATH` is not provided and cannot be inferred
            (e.g. in a class call without path).

        Notes
        -----
        For instances with a stored `CONF_DICT_PATH`, `self.update()` is called
        first so that the in-memory config is in sync with the file before
        launching the editor.
        """
        is_class_call = isinstance(self, type)

        # sync in-memory config with the file (instances with a stored path only)
        if not is_class_call and getattr(self, "CONF_DICT_PATH", None):
            self.update()

        # resolve path
        if CONF_DICT_PATH is None:
            if not is_class_call and getattr(self, "CONF_DICT_PATH", None):
                CONF_DICT_PATH = self.CONF_DICT_PATH
            else:
                raise ValueError("CONF_DICT_PATH must be provided when called on the class.")

        # launch interactive editor
        interactive_adj_matrix(CONF_DICT_PATH, seed=seed)

        # reload config if instance
        if not is_class_call:
            self.conf_dict = load_configuration_dict(CONF_DICT_PATH)
            self.CONF_DICT_PATH = CONF_DICT_PATH

    def set_tramdag_nn_models(self, CONF_DICT_PATH: str = None):
        """
        Launch the interactive editor to set TRAM-DAG neural network model names.

        Depending on call context:

        - Class call:
          - Requires `CONF_DICT_PATH` argument.
          - Returns nothing and does not modify a specific instance.

        - Instance call:
          - Resolves `CONF_DICT_PATH` from the argument or `self.CONF_DICT_PATH`.
          - Updates `self.conf_dict` and `self.CONF_DICT_PATH` if the editor
            returns an updated configuration.

        Parameters
        ----------
        CONF_DICT_PATH : str or None, optional
            Path to the configuration file. Must be provided when called on
            the class. For instance calls, defaults to `self.CONF_DICT_PATH`.

        Returns
        -------
        None

        Raises
        ------
        ValueError
            If `CONF_DICT_PATH` is not provided and cannot be inferred
            (e.g. in a class call without path).

        Notes
        -----
        The interactive editor is invoked via `interactive_nn_names_matrix`.
        If it returns `None`, the instance configuration is left unchanged.
        """
        is_class_call = isinstance(self, type)
        if CONF_DICT_PATH is None:
            if not is_class_call and getattr(self, "CONF_DICT_PATH", None):
                CONF_DICT_PATH = self.CONF_DICT_PATH
            else:
                raise ValueError("CONF_DICT_PATH must be provided when called on the class.")

        updated_conf = interactive_nn_names_matrix(CONF_DICT_PATH)
        if updated_conf is not None and not is_class_call:
            self.conf_dict = updated_conf
            self.CONF_DICT_PATH = CONF_DICT_PATH
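Steps 1 and 2 of `_verify_completeness` reduce to set-membership tests on the configuration dictionary. The sketch below re-implements those two steps as a pure function that returns the missing keys instead of printing warnings. `find_missing_keys` and the toy configuration (including its placeholder values) are hypothetical illustrations, not part of tramdag.

```python
# Hypothetical sketch of steps 1 and 2 of _verify_completeness (not tramdag's API).
MANDATORY_TOP_LEVEL = ["experiment_name", "PATHS", "nodes", "data_type",
                       "adj_matrix", "model_names"]
MANDATORY_PER_NODE = ["data_type", "node_type", "parents", "parents_datatype",
                      "transformation_terms_in_h()",
                      "transformation_term_nn_models_in_h()"]


def find_missing_keys(conf: dict) -> dict:
    """Map "<top>" and each incomplete node name to its list of missing keys."""
    report = {}
    missing_top = [k for k in MANDATORY_TOP_LEVEL if k not in conf]
    if missing_top:
        report["<top>"] = missing_top
    for node, node_dict in conf.get("nodes", {}).items():
        missing_node = [k for k in MANDATORY_PER_NODE if k not in node_dict]
        if missing_node:
            report[node] = missing_node
    return report


# Toy configuration with placeholder values; several keys are deliberately absent.
conf = {
    "experiment_name": "demo",
    "data_type": {"x1": "continuous", "x2": "ordinal"},
    "nodes": {"x1": {"data_type": "continuous", "node_type": "source"}},
}
for where, keys in find_missing_keys(conf).items():
    print(f"[WARNING] {where} is missing: {keys}")
```

A variable that appears in `data_type` but not in `nodes` (here `x2`) is not flagged by these two steps. This matches the listing, whose per-node loop iterates only over `conf_dict["nodes"]`.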
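Step 3 of the completeness check and `compute_levels` both revolve around the "levels" entry of ordinal nodes. The sketch below shows the check and the injection step. For illustration only, it assumes that levels are the sorted distinct values observed in the training data. The real `create_levels_dict` is not shown in this listing and may store something different (for example, a count). `missing_levels`, `infer_levels`, and the column data are hypothetical.

```python
# Hypothetical helpers; the real create_levels_dict may compute levels differently.
def missing_levels(conf: dict) -> list:
    """Ordinal variables without a "levels" entry (same test as _verify_levels_dict)."""
    nodes = conf.get("nodes", {})
    return [var for var, dtype in conf.get("data_type", {}).items()
            if "ordinal" in dtype and "levels" not in nodes.get(var, {})]


def infer_levels(columns: dict, data_type: dict) -> dict:
    """Assumption: levels are the sorted distinct values seen in each ordinal column."""
    return {var: sorted(set(columns[var]))
            for var, dtype in data_type.items()
            if "ordinal" in dtype and var in columns}


conf = {
    "data_type": {"x1": "continuous", "x2": "ordinal"},
    "nodes": {"x1": {}, "x2": {}},
}
train_columns = {"x1": [0.3, 1.2, -0.7], "x2": [2, 0, 1, 2, 0]}

before = missing_levels(conf)
# Inject levels into the node dicts, skipping unknown variables as compute_levels does.
for var, levels in infer_levels(train_columns, conf["data_type"]).items():
    if var in conf["nodes"]:
        conf["nodes"][var]["levels"] = levels
    else:
        print(f"[WARNING] Variable '{var}' not found in nodes dict. Cannot add levels.")
after = missing_levels(conf)
print(before, after, conf["nodes"]["x2"]["levels"])  # ['x2'] [] [0, 1, 2]
```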
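`plot_dag` colors nodes by their in-degree and out-degree, and colors edge labels by prefix. The dependency-free sketch below reproduces that classification directly on an adjacency matrix of strings. It assumes that a non-"0" entry at row i, column j denotes an edge from node i to node j. Check this convention against `create_nx_graph` before relying on it. `classify_nodes` and `edge_label_color` are hypothetical names.

```python
# Assumption: adj[i][j] != "0" encodes an edge from node i (parent) to node j (child).
def classify_nodes(adj: list, labels: list) -> dict:
    """Label each node "source", "sink" or "intermediate", as plot_dag's colors do."""
    n = len(labels)
    has_parent = [any(adj[i][j] != "0" for i in range(n)) for j in range(n)]
    has_child = [any(adj[i][j] != "0" for j in range(n)) for i in range(n)]
    roles = {}
    for k, name in enumerate(labels):
        if not has_parent[k]:
            roles[name] = "source"        # no incoming edges (green)
        elif not has_child[k]:
            roles[name] = "sink"          # no outgoing edges (red)
        else:
            roles[name] = "intermediate"  # light blue
    return roles


def edge_label_color(label: str) -> str:
    """Same prefix-to-color rule as plot_dag's edge labels."""
    for prefix, color in (("ci", "blue"), ("ls", "red"), ("cs", "green")):
        if label.startswith(prefix):
            return color
    return "black"


labels = ["x1", "x2", "x3"]
adj = [
    ["0", "ls", "cs"],
    ["0", "0", "ls"],
    ["0", "0", "0"],
]
print(classify_nodes(adj, labels))
print([edge_label_color(lbl) for row in adj for lbl in row if lbl != "0"])
```

As in `plot_dag`, the source test runs first. An isolated node, which has no edges at all, is therefore reported as a source.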
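With `causal_order=True`, `plot_dag` asks Graphviz's `dot` layout to arrange the nodes in layers that follow the edge direction. A topological sort expresses the same causal ordering as a plain list. The sketch below uses Kahn's algorithm. It assumes that a non-"0" entry at row i, column j is an edge from node i to node j, and it raises on cycles, since any DAG must be acyclic. Whether `validate_adj_matrix` performs this exact check is not visible in the listing. `causal_order` here is a hypothetical standalone function, unrelated to the `plot_dag` parameter of the same name.

```python
from collections import deque


def causal_order(adj: list, labels: list) -> list:
    """Topological order via Kahn's algorithm; raises ValueError on a cycle.

    Assumption: adj[i][j] != "0" means an edge from labels[i] to labels[j].
    """
    n = len(labels)
    # number of parents of each node
    indegree = [sum(adj[i][j] != "0" for i in range(n)) for j in range(n)]
    queue = deque(k for k in range(n) if indegree[k] == 0)  # start at the sources
    order = []
    while queue:
        i = queue.popleft()
        order.append(labels[i])
        for j in range(n):
            if adj[i][j] != "0":
                indegree[j] -= 1          # one parent of j has now been placed
                if indegree[j] == 0:
                    queue.append(j)
    if len(order) != n:
        raise ValueError("adjacency matrix contains a cycle; not a DAG")
    return order


labels = ["b", "a", "c"]
adj = [
    ["0", "0", "cs"],   # b -> c
    ["ls", "0", "0"],   # a -> b
    ["0", "0", "0"],
]
print(causal_order(adj, labels))  # ['a', 'b', 'c']
```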
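`setup_configuration` resolves `<cwd>/<experiment_name>/configuration.json`. It creates the file only when it is missing and then always loads it from disk. The sketch below mirrors that create-or-reuse flow with the standard library. The JSON it writes is a placeholder, because the template produced by `create_and_write_new_configuration_dict` is not shown in the listing. `resolve_experiment_paths` and `load_or_create` are hypothetical names.

```python
import json
import os
import tempfile


def resolve_experiment_paths(experiment_name=None, experiment_dir=None, cwd=None):
    """Defaults as in setup_configuration; cwd is a parameter here only for testing."""
    if experiment_name is None:
        experiment_name = "experiment_1"
    if experiment_dir is None:
        experiment_dir = os.path.join(cwd or os.getcwd(), experiment_name)
    return experiment_dir, os.path.join(experiment_dir, "configuration.json")


def load_or_create(conf_path: str, experiment_name: str):
    """Create a placeholder JSON config if absent, then load it; returns (conf, created)."""
    os.makedirs(os.path.dirname(conf_path), exist_ok=True)
    created = not os.path.exists(conf_path)
    if created:
        # Placeholder content only; the real template comes from
        # create_and_write_new_configuration_dict.
        with open(conf_path, "w", encoding="utf-8") as f:
            json.dump({"experiment_name": experiment_name}, f, indent=2)
    with open(conf_path, encoding="utf-8") as f:
        return json.load(f), created


with tempfile.TemporaryDirectory() as tmp:
    exp_dir, conf_path = resolve_experiment_paths(cwd=tmp)
    first, created_first = load_or_create(conf_path, "experiment_1")
    second, created_second = load_or_create(conf_path, "experiment_1")
    print(created_first, created_second)  # True False
```

Calling the flow a second time with the same experiment name leaves the existing file untouched. This matches the "Configuration already exists" branch in the listing.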
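`set_data_type`, `set_meta_adj_matrix` and `set_tramdag_nn_models` share the same resolution order for `CONF_DICT_PATH`. An explicit argument wins, and the instance's stored path comes next. Only `set_data_type` then falls back to `<cwd>/configuration.json`, while the other two raise `ValueError`. The sketch below condenses those branches into one function. `resolve_conf_path` and its `cwd_fallback` flag are hypothetical.

```python
import os


def resolve_conf_path(explicit=None, stored=None, is_class_call=False,
                      cwd_fallback=True, cwd=None):
    """Hypothetical condensation of the CONF_DICT_PATH resolution branches."""
    if explicit is not None:
        return explicit                  # an explicit argument always wins
    if not is_class_call and stored:
        return stored                    # instance: reuse self.CONF_DICT_PATH
    if not is_class_call and cwd_fallback:
        # set_data_type only: default to <cwd>/configuration.json
        return os.path.join(cwd or os.getcwd(), "configuration.json")
    raise ValueError("CONF_DICT_PATH must be provided when called on the class.")


print(resolve_conf_path("a.json", stored="b.json"))  # a.json
print(resolve_conf_path(stored="b.json"))            # b.json
print(resolve_conf_path(cwd="/work"))                # /work/configuration.json on POSIX
try:
    # e.g. set_meta_adj_matrix on an instance without a stored path
    resolve_conf_path(cwd_fallback=False)
except ValueError as err:
    print("ValueError:", err)
```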
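The class-call branches in `setup_configuration`, `set_data_type`, `set_meta_adj_matrix` and `set_tramdag_nn_models` all test `isinstance(self, type)`. No `classmethod` decorator appears on these methods in the listing, so they are ordinary functions. The class branch is therefore reached only when the class object itself is passed as the first argument. The minimal sketch below demonstrates this standard Python behavior with a hypothetical stand-in class, `Demo`.

```python
class Demo:
    def describe(self):
        # Same test the TramDagConfig methods use to detect a "class call".
        return "class" if isinstance(self, type) else "instance"


instance_result = Demo().describe()          # bound method: self is the instance
explicit_class_result = Demo.describe(Demo)  # class passed explicitly as self
try:
    Demo.describe()                          # no argument is bound to self
    plain_class_call_failed = False
except TypeError:
    plain_class_call_failed = True

print(instance_result, explicit_class_result, plain_class_call_failed)  # instance class True
```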
lbl}, 481 font_color=color, 482 font_size=12 483 ) 484 485 legend_items = [ 486 Patch(facecolor="green", edgecolor="black", label="Source"), 487 Patch(facecolor="red", edgecolor="black", label="Sink"), 488 Patch(facecolor="lightblue", edgecolor="black", label="Intermediate") 489 ] 490 plt.legend(handles=legend_items, loc="upper right", frameon=True) 491 492 plt.title(f"TRAM DAG") 493 plt.axis("off") 494 plt.tight_layout() 495 plt.show() 496 497 498 def setup_configuration(self, experiment_name=None, EXPERIMENT_DIR=None, debug=False, _verify=False): 499 """ 500 Create or reuse a configuration for an experiment. 501 502 This method behaves differently depending on how it is called: 503 504 1. Class call (e.g. `TramDagConfig.setup_configuration(...)`): 505 - Creates or loads a configuration at the resolved path. 506 - Returns a new `TramDagConfig` instance. 507 508 2. Instance call (e.g. `cfg.setup_configuration(...)`): 509 - Updates `self.conf_dict` and `self.CONF_DICT_PATH` in place. 510 - Optionally verifies completeness. 511 - Returns None. 512 513 Parameters 514 ---------- 515 experiment_name : str or None, optional 516 Name of the experiment. If None, defaults to "experiment_1". 517 EXPERIMENT_DIR : str or None, optional 518 Directory for the experiment. If None, defaults to 519 `<cwd>/<experiment_name>`. 520 debug : bool, optional 521 If True, initialize / update with `debug=True`. Default is False. 522 _verify : bool, optional 523 If True, call `_verify_completeness()` after loading. Default is False. 524 525 Returns 526 ------- 527 TramDagConfig or None 528 - A new instance when called on the class. 529 - None when called on an existing instance. 530 531 Notes 532 ----- 533 - A configuration file named "configuration.json" is created if it does 534 not exist yet. 535 - Underlying creation uses `create_and_write_new_configuration_dict` 536 and `load_configuration_dict`. 
537 """ 538 is_class_call = isinstance(self, type) 539 cls = self if is_class_call else self.__class__ 540 541 if experiment_name is None: 542 experiment_name = "experiment_1" 543 if EXPERIMENT_DIR is None: 544 EXPERIMENT_DIR = os.path.join(os.getcwd(), experiment_name) 545 546 CONF_DICT_PATH = os.path.join(EXPERIMENT_DIR, "configuration.json") 547 DATA_PATH = EXPERIMENT_DIR 548 549 os.makedirs(EXPERIMENT_DIR, exist_ok=True) 550 551 if os.path.exists(CONF_DICT_PATH): 552 print(f"Configuration already exists: {CONF_DICT_PATH}") 553 else: 554 _ = create_and_write_new_configuration_dict( 555 experiment_name, CONF_DICT_PATH, EXPERIMENT_DIR, DATA_PATH, None 556 ) 557 print(f"Created new configuration file at {CONF_DICT_PATH}") 558 559 conf = load_configuration_dict(CONF_DICT_PATH) 560 561 if is_class_call: 562 return cls(conf, CONF_DICT_PATH=CONF_DICT_PATH, debug=debug, _verify=_verify) 563 else: 564 self.conf_dict = conf 565 self.CONF_DICT_PATH = CONF_DICT_PATH 566 if _verify: 567 self._verify_completeness() 568 569 def set_data_type(self, data_type: dict, CONF_DICT_PATH: str = None) -> None: 570 """ 571 Update or write the `data_type` section of a configuration file. 572 573 Supports both class-level and instance-level usage: 574 575 - Class call: 576 - Requires `CONF_DICT_PATH` argument. 577 - Reads the file if it exists, or starts from an empty dict. 578 - Writes updated configuration to `CONF_DICT_PATH`. 579 580 - Instance call: 581 - Uses `self.CONF_DICT_PATH` if available, otherwise defaults to 582 `<cwd>/configuration.json` if no path is provided. 583 - Updates `self.conf_dict` and `self.CONF_DICT_PATH` after writing. 584 585 Parameters 586 ---------- 587 data_type : dict 588 Mapping `{variable_name: type_spec}`, where `type_spec` encodes 589 modeling types (e.g. continuous, ordinal, etc.). 590 CONF_DICT_PATH : str or None, optional 591 Path to the configuration file. Must be provided for class calls. 592 For instance calls, defaults as described above. 
593 594 Returns 595 ------- 596 None 597 598 Raises 599 ------ 600 ValueError 601 If `CONF_DICT_PATH` is missing when called on the class, or if 602 validation of data types fails. 603 604 Notes 605 ----- 606 - Variable names are validated via `validate_variable_names`. 607 - Data type values are validated via `validate_data_types`. 608 - A textual summary of modeling settings is printed via 609 `print_data_type_modeling_setting`, if possible. 610 """ 611 is_class_call = isinstance(self, type) 612 cls = self if is_class_call else self.__class__ 613 614 # resolve path 615 if CONF_DICT_PATH is None: 616 if not is_class_call and getattr(self, "CONF_DICT_PATH", None): 617 CONF_DICT_PATH = self.CONF_DICT_PATH 618 elif not is_class_call: 619 CONF_DICT_PATH = os.path.join(os.getcwd(), "configuration.json") 620 else: 621 raise ValueError("CONF_DICT_PATH must be provided when called on the class.") 622 623 try: 624 # load existing or create empty configuration 625 configuration_dict = ( 626 load_configuration_dict(CONF_DICT_PATH) 627 if os.path.exists(CONF_DICT_PATH) 628 else {} 629 ) 630 631 validate_variable_names(data_type.keys()) 632 if not validate_data_types(data_type): 633 raise ValueError("Invalid data types in the provided dictionary.") 634 635 configuration_dict["data_type"] = data_type 636 write_configuration_dict(configuration_dict, CONF_DICT_PATH) 637 638 # safe printing 639 try: 640 print_data_type_modeling_setting(data_type or {}) 641 except Exception as e: 642 print(f"[WARNING] Could not print data type modeling settings: {e}") 643 644 if not is_class_call: 645 self.conf_dict = configuration_dict 646 self.CONF_DICT_PATH = CONF_DICT_PATH 647 648 649 except Exception as e: 650 print(f"Failed to update configuration: {e}") 651 else: 652 print(f"Configuration updated successfully at {CONF_DICT_PATH}.") 653 654 def set_meta_adj_matrix(self, CONF_DICT_PATH: str = None, seed: int = 5): 655 """ 656 Launch the interactive editor to set or modify the adjacency matrix. 
657 658 This method: 659 660 1. Resolves the configuration path either from the argument or, for 661 instances, from `self.CONF_DICT_PATH`. 662 2. Invokes `interactive_adj_matrix` to edit the adjacency matrix. 663 3. For instances, reloads the updated configuration into `self.conf_dict`. 664 665 Parameters 666 ---------- 667 CONF_DICT_PATH : str or None, optional 668 Path to the configuration file. Must be provided when called 669 on the class. For instance calls, defaults to `self.CONF_DICT_PATH`. 670 seed : int, optional 671 Random seed for any layout or stochastic behavior in the interactive 672 editor. Default is 5. 673 674 Returns 675 ------- 676 None 677 678 Raises 679 ------ 680 ValueError 681 If `CONF_DICT_PATH` is not provided and cannot be inferred 682 (e.g. in a class call without path). 683 684 Notes 685 ----- 686 `self.update()` is called at the start to ensure the in-memory config 687 is in sync with the file before launching the editor. 688 """ 689 self.update() 690 is_class_call = isinstance(self, type) 691 # resolve path 692 if CONF_DICT_PATH is None: 693 if not is_class_call and getattr(self, "CONF_DICT_PATH", None): 694 CONF_DICT_PATH = self.CONF_DICT_PATH 695 else: 696 raise ValueError("CONF_DICT_PATH must be provided when called on the class.") 697 698 # launch interactive editor 699 700 interactive_adj_matrix(CONF_DICT_PATH, seed=seed) 701 702 # reload config if instance 703 if not is_class_call: 704 self.conf_dict = load_configuration_dict(CONF_DICT_PATH) 705 self.CONF_DICT_PATH = CONF_DICT_PATH 706 707 708 def set_tramdag_nn_models(self, CONF_DICT_PATH: str = None): 709 """ 710 Launch the interactive editor to set TRAM-DAG neural network model names. 711 712 Depending on call context: 713 714 - Class call: 715 - Requires `CONF_DICT_PATH` argument. 716 - Returns nothing and does not modify a specific instance. 717 718 - Instance call: 719 - Resolves `CONF_DICT_PATH` from the argument or `self.CONF_DICT_PATH`. 
720 - Updates `self.conf_dict` and `self.CONF_DICT_PATH` if the editor 721 returns an updated configuration. 722 723 Parameters 724 ---------- 725 CONF_DICT_PATH : str or None, optional 726 Path to the configuration file. Must be provided when called on 727 the class. For instance calls, defaults to `self.CONF_DICT_PATH`. 728 729 Returns 730 ------- 731 None 732 733 Raises 734 ------ 735 ValueError 736 If `CONF_DICT_PATH` is not provided and cannot be inferred 737 (e.g. in a class call without path). 738 739 Notes 740 ----- 741 The interactive editor is invoked via `interactive_nn_names_matrix`. 742 If it returns `None`, the instance configuration is left unchanged. 743 """ 744 is_class_call = isinstance(self, type) 745 if CONF_DICT_PATH is None: 746 if not is_class_call and getattr(self, "CONF_DICT_PATH", None): 747 CONF_DICT_PATH = self.CONF_DICT_PATH 748 else: 749 raise ValueError("CONF_DICT_PATH must be provided when called on the class.") 750 751 updated_conf = interactive_nn_names_matrix(CONF_DICT_PATH) 752 if updated_conf is not None and not is_class_call: 753 self.conf_dict = updated_conf 754 self.CONF_DICT_PATH = CONF_DICT_PATH
Configuration manager for TRAM-DAG experiments.
This class encapsulates:
- The experiment configuration dictionary (conf_dict).
- Its backing file path (CONF_DICT_PATH).
- Utilities to load, validate, modify, and persist configuration.
- DAG visualization and interactive editing helpers.
Typical usage
- Load an existing configuration from disk via TramDagConfig.load_json.
- Or create/reuse an experiment setup via TramDagConfig().setup_configuration.
- Update sections such as data_type, the adjacency matrix, and neural network model names using the provided methods.
Initialize a TramDagConfig instance.
Parameters
conf_dict : dict or None, optional
Configuration dictionary. If None, an empty dict is used and can
be populated later. Default is None.
CONF_DICT_PATH : str or None, optional
Path to the configuration file on disk. Default is None.
_verify : bool, optional
If True, run _verify_completeness() after initialization.
Default is False.
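**kwargs
Additional attributes to be set on the instance (e.g. debug, verbose).
Keys "conf_dict" and "CONF_DICT_PATH" are rejected with a ValueError;
since both are also named parameters, Python binds them to those
parameters first, so in practice they never reach kwargs.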
Raises
ValueError
If any key in kwargs is "conf_dict" or "CONF_DICT_PATH".
Notes
By default, debug and verbose are set to False. They can be
overridden via kwargs.
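The keyword handling described above can be mirrored by a small stand-in class. This is a sketch only: `ConfigStub` is a hypothetical name, and it reproduces just the default flags, the kwargs-to-attribute loop, and the `_verify` hook. It also shows that passing `conf_dict` or `CONF_DICT_PATH` by keyword binds to the named parameter instead of landing in `**kwargs`.

```python
class ConfigStub:
    """Hypothetical stand-in mirroring the documented __init__ behaviour."""

    def __init__(self, conf_dict=None, CONF_DICT_PATH=None, _verify=False, **kwargs):
        # Defaults are set first so that kwargs can override them.
        self.debug = False
        self.verbose = False
        for key, value in kwargs.items():
            # Defensive only: these names are bound to the named parameters
            # above, so Python never places them in kwargs.
            if key in ("conf_dict", "CONF_DICT_PATH"):
                raise ValueError(f"Cannot override '{key}' via kwargs.")
            setattr(self, key, value)
        # An empty dict is used when no configuration is given.
        self.conf_dict = conf_dict or {}
        self.CONF_DICT_PATH = CONF_DICT_PATH
        self.verified = False
        if _verify:
            self._verify_completeness()

    def _verify_completeness(self):
        # Placeholder for the real checks; it only records that it ran.
        self.verified = True


cfg = ConfigStub(debug=True, seed=7)
print(cfg.debug, cfg.verbose, cfg.seed, cfg.conf_dict)  # True False 7 {}
```

Any extra keyword (here the hypothetical `seed`) simply becomes an attribute, which is how `debug` and `verbose` get overridden.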
Load a configuration from a JSON file and construct a TramDagConfig.
Parameters
CONF_DICT_PATH : str
Path to the configuration JSON file.
debug : bool, optional
If True, initialize the instance with debug=True. Default is False.
Returns
TramDagConfig
Newly created configuration instance with conf_dict loaded from
CONF_DICT_PATH and _verify_completeness() executed.
Raises
FileNotFoundError
If the configuration file cannot be found (propagated by
load_configuration_dict).
Reload the latest configuration from disk into this instance.
Parameters
None
Returns
None
Raises
ValueError
If CONF_DICT_PATH is not set on the instance.
Notes
The current in-memory conf_dict is overwritten by the contents
loaded from CONF_DICT_PATH.
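Because `update()` replaces `conf_dict` wholesale, any unsaved in-memory edit is lost on the next reload. The sketch below demonstrates this with plain JSON. It assumes the configuration file is ordinary JSON (the real `load_configuration_dict` may do more), and `ReloadableConfig` is a hypothetical stand-in, not the library class.

```python
import json
import os
import tempfile


class ReloadableConfig:
    """Hypothetical stand-in with the same update() semantics."""

    def __init__(self, conf_dict=None, CONF_DICT_PATH=None):
        self.conf_dict = conf_dict or {}
        self.CONF_DICT_PATH = CONF_DICT_PATH

    def update(self):
        if getattr(self, "CONF_DICT_PATH", None) is None:
            raise ValueError("CONF_DICT_PATH not set; cannot update configuration.")
        # The in-memory dict is replaced, not merged.
        with open(self.CONF_DICT_PATH, "r", encoding="utf-8") as f:
            self.conf_dict = json.load(f)


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "configuration.json")
    with open(path, "w", encoding="utf-8") as f:
        json.dump({"experiment_name": "exp"}, f)

    cfg = ReloadableConfig(CONF_DICT_PATH=path)
    cfg.update()
    cfg.conf_dict["experiment_name"] = "edited"  # changed in memory only
    cfg.update()                                 # reload discards the edit
    print(cfg.conf_dict["experiment_name"])      # exp
```

To keep an in-memory change, call `save()` before anything that triggers `update()`.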
Persist the current configuration dictionary to disk.
Parameters
CONF_DICT_PATH : str or None, optional
Target path for the configuration file. If None, uses
self.CONF_DICT_PATH. Default is None.
Returns
None
Raises
ValueError
If neither the function argument nor self.CONF_DICT_PATH
provides a valid path.
Notes
The resulting file is written via write_configuration_dict.
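The path fallback in `save()` (explicit argument first, then the stored `CONF_DICT_PATH`) can be sketched as a standalone round trip. This assumes plain JSON on disk. `save_config` and `load_config` are hypothetical helpers, not the library's `write_configuration_dict` and `load_configuration_dict`.

```python
import json
import os
import tempfile


def save_config(conf_dict, path=None, default_path=None):
    """Write conf_dict as JSON to path, falling back to default_path."""
    target = path or default_path  # same "argument or stored path" rule as save()
    if target is None:
        raise ValueError("No CONF_DICT_PATH provided to save config.")
    with open(target, "w", encoding="utf-8") as f:
        json.dump(conf_dict, f, indent=4)
    return target


def load_config(path):
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


with tempfile.TemporaryDirectory() as tmp:
    default = os.path.join(tmp, "configuration.json")
    written = save_config({"experiment_name": "exp"}, default_path=default)
    roundtrip = load_config(written)
    print(written == default, roundtrip)
```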
Infer and update ordinal/categorical levels from data.
For each variable in the configuration's data_type section, this
method uses the provided DataFrame to construct a levels dictionary
and injects the corresponding "levels" entry into conf_dict["nodes"].
Parameters
df : pandas.DataFrame
DataFrame used to infer levels for configured variables.
write : bool, optional
If True and CONF_DICT_PATH is set, the updated configuration is
written back to disk. Default is True.
Returns
None
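Raises
ValueError
If CONF_DICT_PATH is not set (raised by update()).
Notes
- The configuration is first reloaded from disk via update(), so unsaved in-memory changes are discarded.
- Variables present in levels_dict but not in conf_dict["nodes"] trigger a warning and are skipped.
- Errors raised while saving are caught and printed as an "[ERROR]" message rather than re-raised.
- If self.verbose or self.debug is True, a success message is printed when the configuration is saved.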
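The levels step can be sketched roughly as follows. The sketch collects the sorted unique values of every variable whose type string contains "ordinal", writes them into the nodes dict, warns on unknown variables, and then applies the same completeness rule as `_verify_levels_dict`. The exact rules of `create_levels_dict` are not shown on this page, so `infer_levels` is a hypothetical stand-in. The plain `"ordinal"`/`"continuous"` type strings and the dict-of-lists data are also assumptions standing in for a DataFrame.

```python
def infer_levels(columns, data_type):
    """Hypothetical stand-in for create_levels_dict: sorted unique values
    for every variable whose type string contains 'ordinal'."""
    return {
        var: sorted(set(columns[var]))
        for var, dtype in data_type.items()
        if "ordinal" in dtype and var in columns
    }


def apply_levels(conf_dict, levels_dict):
    # Write levels into the nodes dict; unknown variables are skipped with a warning.
    for var, levels in levels_dict.items():
        if var in conf_dict["nodes"]:
            conf_dict["nodes"][var]["levels"] = levels
        else:
            print(f"[WARNING] Variable '{var}' not found in nodes dict. Cannot add levels.")


def levels_complete(conf_dict):
    # Same rule as _verify_levels_dict: every ordinal variable needs "levels".
    nodes = conf_dict.get("nodes", {})
    return all(
        var in nodes and "levels" in nodes[var]
        for var, dtype in conf_dict.get("data_type", {}).items()
        if "ordinal" in dtype
    )


conf = {
    "data_type": {"x1": "continuous", "x2": "ordinal"},
    "nodes": {"x1": {}, "x2": {}},
}
columns = {"x1": [0.3, 1.2, 0.7], "x2": [2, 0, 1, 2, 0]}

print(levels_complete(conf))  # False
apply_levels(conf, infer_levels(columns, conf["data_type"]))
print(conf["nodes"]["x2"]["levels"], levels_complete(conf))  # [0, 1, 2] True
```

Levels should normally be inferred from training data only, so that validation or test splits do not add categories the model never saw.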
Visualize the DAG defined by the configuration.
Nodes are categorized and colored as:
- Source nodes (no incoming edges): green.
- Sink nodes (no outgoing edges): red.
- Intermediate nodes: light blue.
Parameters
seed : int, optional
Random seed for spring_layout, which is used by default and as the
fallback when Graphviz is unavailable. Default is 42.
causal_order : bool, optional
If True, attempt to use Graphviz 'dot' layout via
networkx.nx_agraph.graphviz_layout to preserve causal ordering.
If False or if Graphviz is unavailable, use spring_layout.
Default is False.
Returns
None
Raises
ValueError
If adj_matrix or data_type is missing, if their sizes do not match,
or if the adjacency matrix fails validation.
Notes
Edge labels are colored by prefix:
- "ci": blue
- "ls": red
- "cs": green
- other: black
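The node categories and edge-label colors can be computed without drawing anything. The sketch below assumes an adjacency matrix in which entry `[i][j]` holds a label string such as `"ls"` or `"cs"` for an edge from node `i` to node `j`, and `"0"` for no edge. That encoding is an assumption about what `create_nx_graph` consumes, not something stated on this page. The helper names are hypothetical.

```python
def classify_nodes(adj_matrix, labels):
    """Return (sources, sinks, intermediates) for a label-encoded adjacency matrix.

    Assumes adj_matrix[i][j] != "0" encodes an edge labels[i] -> labels[j].
    """
    n = len(labels)

    def has_edge(i, j):
        return str(adj_matrix[i][j]) != "0"

    in_deg = {labels[j]: sum(has_edge(i, j) for i in range(n)) for j in range(n)}
    out_deg = {labels[i]: sum(has_edge(i, j) for j in range(n)) for i in range(n)}
    sources = {v for v in labels if in_deg[v] == 0}
    sinks = {v for v in labels if out_deg[v] == 0}
    return sources, sinks, set(labels) - sources - sinks


def node_color(node, sources, sinks):
    # Same precedence as plot_dag: the source test is applied first.
    if node in sources:
        return "green"
    if node in sinks:
        return "red"
    return "lightblue"


def edge_label_color(label):
    # Prefix-based colors, matching the Notes above.
    for prefix, color in (("ci", "blue"), ("ls", "red"), ("cs", "green")):
        if label.startswith(prefix):
            return color
    return "black"


labels = ["x1", "x2", "x3"]
adj = [
    ["0", "ls", "cs"],
    ["0", "0", "ls"],
    ["0", "0", "0"],
]
sources, sinks, intermediates = classify_nodes(adj, labels)
print(sorted(sources), sorted(sinks), sorted(intermediates))  # ['x1'] ['x3'] ['x2']
print([node_color(v, sources, sinks) for v in labels])  # ['green', 'lightblue', 'red']
print(edge_label_color("ci"), edge_label_color("cs"))  # blue green
```

An isolated node has neither incoming nor outgoing edges, so it counts as both source and sink. Because the source test comes first, it is drawn green.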
Create or reuse a configuration for an experiment.
This method behaves differently depending on how it is called:
1. Class call, with the class passed explicitly as self (e.g. TramDagConfig.setup_configuration(TramDagConfig, ...)):
   - Creates or loads a configuration at the resolved path.
   - Returns a new TramDagConfig instance.
2. Instance call (e.g. cfg.setup_configuration(...)):
   - Updates self.conf_dict and self.CONF_DICT_PATH in place.
   - Optionally verifies completeness.
   - Returns None.
Parameters
experiment_name : str or None, optional
Name of the experiment. If None, defaults to "experiment_1".
EXPERIMENT_DIR : str or None, optional
Directory for the experiment. If None, defaults to
<cwd>/<experiment_name>.
debug : bool, optional
If True, the new instance is created with debug=True. Only used on class calls; ignored on instance calls. Default is False.
_verify : bool, optional
If True, call _verify_completeness() after loading. Default is False.
Returns
TramDagConfig or None
- A new instance when called on the class.
- None when called on an existing instance.
Notes
- A configuration file named "configuration.json" is created if it does not exist yet.
- Underlying creation uses create_and_write_new_configuration_dict and load_configuration_dict.
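`setup_configuration` chooses between its class and instance behavior with `isinstance(self, type)`. The method is not declared as a `classmethod`, so the class branch is reached only when the class itself is passed as the first argument. Calling it on the class without that argument raises `TypeError`. The stand-in below (`HybridDemo`, hypothetical) reproduces the dispatch and the path defaults. It does not touch the file system: a small dict takes the place of the loaded file.

```python
import os


class HybridDemo:
    """Hypothetical stand-in for the class/instance dispatch in setup_configuration."""

    def __init__(self, conf_dict=None, CONF_DICT_PATH=None):
        self.conf_dict = conf_dict or {}
        self.CONF_DICT_PATH = CONF_DICT_PATH

    def setup_configuration(self, experiment_name=None, EXPERIMENT_DIR=None):
        is_class_call = isinstance(self, type)
        cls = self if is_class_call else self.__class__

        # Same defaults as the documented method.
        if experiment_name is None:
            experiment_name = "experiment_1"
        if EXPERIMENT_DIR is None:
            EXPERIMENT_DIR = os.path.join(os.getcwd(), experiment_name)
        conf_path = os.path.join(EXPERIMENT_DIR, "configuration.json")
        conf = {"experiment_name": experiment_name}  # stands in for the loaded file

        if is_class_call:
            return cls(conf, CONF_DICT_PATH=conf_path)
        self.conf_dict = conf
        self.CONF_DICT_PATH = conf_path
        return None


# Class branch: the class must be passed explicitly as `self`.
new_cfg = HybridDemo.setup_configuration(HybridDemo, "demo", "/tmp/demo")
print(type(new_cfg).__name__, new_cfg.CONF_DICT_PATH)

# Instance branch: updates in place and returns None.
cfg = HybridDemo()
result = cfg.setup_configuration()
print(result, cfg.conf_dict)  # None {'experiment_name': 'experiment_1'}
```

A `classmethod` or a custom descriptor would make `TramDagConfig.setup_configuration(...)` work without the explicit class argument. The sketch keeps the plain-method form so that it matches the code above.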
Update or write the data_type section of a configuration file.
Supports both class-level and instance-level usage:
- Class call (the class passed explicitly as self):
   - Requires the CONF_DICT_PATH argument.
   - Reads the file if it exists, or starts from an empty dict.
   - Writes the updated configuration to CONF_DICT_PATH.
- Instance call:
   - If no path is provided, uses self.CONF_DICT_PATH when set, otherwise <cwd>/configuration.json.
   - Updates self.conf_dict and self.CONF_DICT_PATH after writing.
Parameters
data_type : dict
Mapping {variable_name: type_spec}, where type_spec encodes
modeling types (e.g. continuous or ordinal).
CONF_DICT_PATH : str or None, optional
Path to the configuration file. Must be provided for class calls.
For instance calls, defaults as described above.
Returns
None
Raises
ValueError
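If CONF_DICT_PATH is missing when called on the class. Validation
failures, and other errors raised while loading or writing the
configuration, are caught inside the method and reported with a printed
"Failed to update configuration" message instead of being propagated.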
Notes
- Variable names are validated via validate_variable_names.
- Data type values are validated via validate_data_types.
- A textual summary of modeling settings is printed via print_data_type_modeling_setting, if possible.
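The update pattern (validate, replace only the data_type section, write back) can be sketched with the standard library. This is a hypothetical, simplified version: it assumes a plain JSON file, uses json.load and json.dump in place of the library's load_configuration_dict and write_configuration_dict, and takes the validator as a parameter instead of calling validate_variable_names and validate_data_types. Like the source listing, it reports failures with a printed message rather than raising.

```python
import json
import os
import tempfile


def update_section(path, key, value, validate=lambda v: True):
    """Hypothetical sketch of the load -> validate -> merge -> write flow.

    Assumes the configuration is plain JSON. Returns True on success and
    False on failure (the real method returns None).
    """
    try:
        # Load the existing configuration, or start from an empty dict.
        if os.path.exists(path):
            with open(path, "r", encoding="utf-8") as f:
                conf = json.load(f)
        else:
            conf = {}
        if not validate(value):
            raise ValueError("Invalid data types in the provided dictionary.")
        # Replace only the requested section; all other sections are kept.
        conf[key] = value
        with open(path, "w", encoding="utf-8") as f:
            json.dump(conf, f, indent=2)
    except Exception as e:
        # Errors are printed, not re-raised, mirroring set_data_type.
        print(f"Failed to update configuration: {e}")
        return False
    else:
        print(f"Configuration updated successfully at {path}.")
        return True


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "configuration.json")
    # Pre-existing file with a hypothetical unrelated section.
    with open(path, "w", encoding="utf-8") as f:
        json.dump({"other_section": {"keep": True}}, f)

    # Placeholder type value; real type_spec strings are not shown here.
    update_section(path, "data_type", {"x1": "placeholder_type"})
    # A failing validator: the error is printed and the file is left as is.
    update_section(path, "data_type", {"x2": "bad"}, validate=lambda v: False)

    with open(path, encoding="utf-8") as f:
        print(json.load(f))  # both sections present; x2 was never written
```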
    def set_meta_adj_matrix(self, CONF_DICT_PATH: str = None, seed: int = 5):
        """
        Launch the interactive editor to set or modify the adjacency matrix.

        This method:

        1. Resolves the configuration path either from the argument or, for
           instances, from `self.CONF_DICT_PATH`.
        2. Invokes `interactive_adj_matrix` to edit the adjacency matrix.
        3. For instances, reloads the updated configuration into `self.conf_dict`.

        Parameters
        ----------
        CONF_DICT_PATH : str or None, optional
            Path to the configuration file. Must be provided when called
            on the class. For instance calls, defaults to `self.CONF_DICT_PATH`.
        seed : int, optional
            Random seed for any layout or stochastic behavior in the interactive
            editor. Default is 5.

        Returns
        -------
        None

        Raises
        ------
        ValueError
            If `CONF_DICT_PATH` is not provided and cannot be inferred
            (e.g. in a class call without path).

        Notes
        -----
        `self.update()` is called at the start to ensure the in-memory config
        is in sync with the file before launching the editor.
        """
        self.update()
        is_class_call = isinstance(self, type)
        # resolve path
        if CONF_DICT_PATH is None:
            if not is_class_call and getattr(self, "CONF_DICT_PATH", None):
                CONF_DICT_PATH = self.CONF_DICT_PATH
            else:
                raise ValueError("CONF_DICT_PATH must be provided when called on the class.")

        # launch interactive editor

        interactive_adj_matrix(CONF_DICT_PATH, seed=seed)

        # reload config if instance
        if not is_class_call:
            self.conf_dict = load_configuration_dict(CONF_DICT_PATH)
            self.CONF_DICT_PATH = CONF_DICT_PATH
Launch the interactive editor to set or modify the adjacency matrix.
This method:
1. Resolves the configuration path either from the argument or, for instances, from self.CONF_DICT_PATH.
2. Invokes interactive_adj_matrix to edit the adjacency matrix.
3. For instances, reloads the updated configuration into self.conf_dict.
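The three steps can be mimicked with a non-interactive stand-in for the editor. In this sketch EditableConfig, edit_with, and fake_editor are hypothetical names, and the "adjacency" key is an assumed placeholder, not the layout interactive_adj_matrix actually writes. The design point it illustrates comes from the source listing: the editor's return value is ignored and the instance re-reads the file, so the file on disk is the single source of truth.

```python
import json
import os
import tempfile


class EditableConfig:
    """Hypothetical stand-in illustrating the resolve -> edit -> reload flow."""

    def __init__(self, conf_dict_path=None):
        self.CONF_DICT_PATH = conf_dict_path
        self.conf_dict = {}

    def edit_with(self, editor, conf_dict_path=None):
        # 1. Resolve the path from the argument or the stored attribute.
        path = conf_dict_path if conf_dict_path is not None else self.CONF_DICT_PATH
        if not path:
            raise ValueError("CONF_DICT_PATH must be provided.")
        # 2. The editor modifies the file; its return value is not used.
        editor(path)
        # 3. Reload so the in-memory copy matches what is now on disk.
        with open(path, encoding="utf-8") as f:
            self.conf_dict = json.load(f)
        self.CONF_DICT_PATH = path


def fake_editor(path):
    """Hypothetical non-interactive editor that writes one section to the file."""
    conf = {}
    if os.path.exists(path):
        with open(path, encoding="utf-8") as f:
            conf = json.load(f)
    conf["adjacency"] = [[0, 1], [0, 0]]  # placeholder key and layout
    with open(path, "w", encoding="utf-8") as f:
        json.dump(conf, f)


with tempfile.TemporaryDirectory() as tmp:
    cfg = EditableConfig(os.path.join(tmp, "configuration.json"))
    cfg.edit_with(fake_editor)
    print(cfg.conf_dict)
```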
Parameters
CONF_DICT_PATH : str or None, optional
Path to the configuration file. Must be provided when called
on the class. For instance calls, defaults to self.CONF_DICT_PATH.
seed : int, optional
Random seed for any layout or stochastic behavior in the interactive
editor. Default is 5.
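How the editor consumes seed is not shown in this section. A common reason to expose a seed is to make a randomized graph layout reproducible; the sketch below assumes that use, and random_layout is a hypothetical function. It draws from a dedicated random.Random(seed) instance so the global random state is left untouched.

```python
import random


def random_layout(nodes, seed=5):
    """Hypothetical layout: place each node at a random point in the unit square."""
    # A private generator seeded once gives the same sequence on every call.
    rng = random.Random(seed)
    return {node: (rng.random(), rng.random()) for node in nodes}


nodes = ["x1", "x2", "x3"]
print(random_layout(nodes) == random_layout(nodes))  # same seed, same layout
print(random_layout(nodes, seed=1) == random_layout(nodes, seed=2))
```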
Returns
None
Raises
ValueError
If CONF_DICT_PATH is not provided and cannot be inferred
(e.g. in a class call without path).
Notes
self.update() is called at the start to ensure the in-memory config
is in sync with the file before launching the editor.
    def set_tramdag_nn_models(self, CONF_DICT_PATH: str = None):
        """
        Launch the interactive editor to set TRAM-DAG neural network model names.

        Depending on call context:

        - Class call:
            - Requires `CONF_DICT_PATH` argument.
            - Returns nothing and does not modify a specific instance.

        - Instance call:
            - Resolves `CONF_DICT_PATH` from the argument or `self.CONF_DICT_PATH`.
            - Updates `self.conf_dict` and `self.CONF_DICT_PATH` if the editor
              returns an updated configuration.

        Parameters
        ----------
        CONF_DICT_PATH : str or None, optional
            Path to the configuration file. Must be provided when called on
            the class. For instance calls, defaults to `self.CONF_DICT_PATH`.

        Returns
        -------
        None

        Raises
        ------
        ValueError
            If `CONF_DICT_PATH` is not provided and cannot be inferred
            (e.g. in a class call without path).

        Notes
        -----
        The interactive editor is invoked via `interactive_nn_names_matrix`.
        If it returns `None`, the instance configuration is left unchanged.
        """
        is_class_call = isinstance(self, type)
        if CONF_DICT_PATH is None:
            if not is_class_call and getattr(self, "CONF_DICT_PATH", None):
                CONF_DICT_PATH = self.CONF_DICT_PATH
            else:
                raise ValueError("CONF_DICT_PATH must be provided when called on the class.")

        updated_conf = interactive_nn_names_matrix(CONF_DICT_PATH)
        if updated_conf is not None and not is_class_call:
            self.conf_dict = updated_conf
            self.CONF_DICT_PATH = CONF_DICT_PATH
Launch the interactive editor to set TRAM-DAG neural network model names.
Depending on call context:
- Class call:
  - Requires the CONF_DICT_PATH argument.
  - Returns nothing and does not modify a specific instance.
- Instance call:
  - Resolves CONF_DICT_PATH from the argument or self.CONF_DICT_PATH.
  - Updates self.conf_dict and self.CONF_DICT_PATH if the editor returns an updated configuration.
Parameters
CONF_DICT_PATH : str or None, optional
Path to the configuration file. Must be provided when called on
the class. For instance calls, defaults to self.CONF_DICT_PATH.
Returns
None
Raises
ValueError
If CONF_DICT_PATH is not provided and cannot be inferred
(e.g. in a class call without path).
Notes
The interactive editor is invoked via interactive_nn_names_matrix.
If it returns None, the instance configuration is left unchanged.
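Unlike set_meta_adj_matrix, which ignores the editor's return value and reloads from disk, this method takes the returned dictionary directly and replaces self.conf_dict with it as a whole. A minimal sketch of that handling follows; NNConfigDemo, cancelled_editor, and accepting_editor are hypothetical names, and "model_names" is a placeholder key, not the structure interactive_nn_names_matrix actually returns.

```python
from typing import Callable, Optional


class NNConfigDemo:
    """Hypothetical stand-in showing how an editor's return value is applied."""

    def __init__(self, conf_dict_path: str, conf_dict: dict):
        self.CONF_DICT_PATH = conf_dict_path
        self.conf_dict = conf_dict

    def apply_editor(self, editor: Callable[[str], Optional[dict]]) -> None:
        updated = editor(self.CONF_DICT_PATH)
        # None means "nothing to apply": the current config is kept.
        if updated is not None:
            # A returned dict replaces the in-memory config as a whole.
            self.conf_dict = updated


def cancelled_editor(path: str) -> Optional[dict]:
    return None


def accepting_editor(path: str) -> Optional[dict]:
    # Placeholder mapping; real key names may differ.
    return {"model_names": {"x1": "SomeNet"}}


cfg = NNConfigDemo("configuration.json", {"data_type": {}})
cfg.apply_editor(cancelled_editor)
print(cfg.conf_dict)  # unchanged
cfg.apply_editor(accepting_editor)
print(cfg.conf_dict)
```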