tramdag.TramDagDataset
Copyright 2025 Zurich University of Applied Sciences (ZHAW) Pascal Buehler, Beate Sick, Oliver Duerr
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
1""" 2Copyright 2025 Zurich University of Applied Sciences (ZHAW) 3Pascal Buehler, Beate Sick, Oliver Duerr 4 5Licensed under the Apache License, Version 2.0 (the "License"); 6you may not use this file except in compliance with the License. 7You may obtain a copy of the License at 8 9 http://www.apache.org/licenses/LICENSE-2.0 10 11Unless required by applicable law or agreed to in writing, software 12distributed under the License is distributed on an "AS IS" BASIS, 13WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14See the License for the specific language governing permissions and 15limitations under the License. 16""" 17 18 19import inspect 20import pandas as pd 21import os 22from torch.utils.data import Dataset, DataLoader 23 24from .utils.data import GenericDataset, GenericDatasetPrecomputed 25 26 27 28class TramDagDataset(Dataset): 29 30 """ 31 TramDagDataset 32 ============== 33 34 The `TramDagDataset` class handles structured data preparation for TRAM-DAG 35 models. It wraps a pandas DataFrame together with its configuration and provides 36 utilities for scaling, transformation, and efficient DataLoader construction 37 for each node in a DAG-based configuration. 38 39 --------------------------------------------------------------------- 40 Core Responsibilities 41 --------------------------------------------------------------------- 42 - Validate and store configuration metadata (`TramDagConfig`). 43 - Manage per-node settings for DataLoader creation (batch size, shuffling, workers). 44 - Compute scaling information (quantile-based min/max). 45 - Optionally precompute and cache dataset representations. 46 - Expose PyTorch Dataset and DataLoader interfaces for model training. 
47 48 --------------------------------------------------------------------- 49 Key Attributes 50 --------------------------------------------------------------------- 51 - **df** : pandas.DataFrame 52 The dataset content used for building loaders and computing scaling. 53 54 - **cfg** : TramDagConfig 55 Configuration object defining nodes and variable metadata. 56 57 - **nodes_dict** : dict 58 Mapping of variable names to node specifications from the configuration. 59 60 - **loaders** : dict 61 Mapping of node names to `torch.utils.data.DataLoader` instances or `GenericDataset` objects. 62 63 - **DEFAULTS** : dict 64 Default DataLoader and dataset-related settings (e.g., batch_size, shuffle, num_workers, etc.). 65 66 --------------------------------------------------------------------- 67 Main Methods 68 --------------------------------------------------------------------- 69 - **from_dataframe(df, cfg, **kwargs)** 70 Construct the dataset directly from a pandas DataFrame. 71 72 - **compute_scaling(df=None, write=True)** 73 Compute per-variable min/max scaling values from data. 74 75 - **summary()** 76 Print dataset overview including shape, dtypes, statistics, and node settings. 77 78 --------------------------------------------------------------------- 79 Notes 80 --------------------------------------------------------------------- 81 - Intended for training data; `compute_scaling()` should use only training subsets. 82 - Compatible with both CPU and GPU DataLoader options. 83 - Strict validation of keyword arguments against `DEFAULTS` prevents silent misconfiguration. 
84 85 --------------------------------------------------------------------- 86 Example 87 --------------------------------------------------------------------- 88 >>> cfg = TramDagConfig.from_json("config.json") 89 >>> dataset = TramDagDataset.from_dataframe(train_df, cfg, batch_size=1024, debug=True) 90 >>> dataset.summary() 91 >>> minmax = dataset.compute_scaling(train_df) 92 >>> loader = dataset.loaders["variable_x"] 93 >>> next(iter(loader)) 94 """ 95 96 DEFAULTS = { 97 "batch_size": 32_000, 98 "shuffle": True, 99 "num_workers": 4, 100 "pin_memory": True, 101 "return_intercept_shift": True, 102 "debug": False, 103 "transform": None, 104 "use_dataloader": True, 105 "use_precomputed": False, 106 # DataLoader extras 107 "sampler": None, 108 "batch_sampler": None, 109 "collate_fn": None, 110 "drop_last": False, 111 "timeout": 0, 112 "worker_init_fn": None, 113 "multiprocessing_context": None, 114 "generator": None, 115 "prefetch_factor": 2, 116 "persistent_workers": True, 117 "pin_memory_device": "", 118 } 119 120 def __init__(self): 121 """ 122 Initialize an empty TramDagDataset shell. 123 124 Notes 125 ----- 126 This constructor does not attach data or configuration. Use 127 `TramDagDataset.from_dataframe` to obtain a ready-to-use instance. 128 """ 129 pass 130 131 @classmethod 132 def from_dataframe(cls, df, cfg, **kwargs): 133 """ 134 Create a TramDagDataset instance directly from a pandas DataFrame. 135 136 This classmethod: 137 138 1. Validates keyword arguments against `DEFAULTS`. 139 2. Merges user overrides with defaults into a resolved settings dict. 140 3. Stores the configuration and verifies its completeness. 141 4. Applies settings to the instance. 142 5. Builds per-node datasets and DataLoaders. 143 144 Parameters 145 ---------- 146 df : pandas.DataFrame 147 Input DataFrame containing the dataset. 148 cfg : TramDagConfig 149 Configuration object defining nodes and variable metadata. 150 **kwargs 151 Optional overrides for `DEFAULTS`. 
All keys must exist in 152 `TramDagDataset.DEFAULTS`. Common keys include: 153 154 batch_size : int 155 Batch size for DataLoaders. 156 shuffle : bool 157 Whether to shuffle samples per epoch. 158 num_workers : int 159 Number of DataLoader workers. 160 pin_memory : bool 161 Whether to pin memory for faster host-to-device transfers. 162 return_intercept_shift : bool 163 Whether datasets should return intercept/shift information. 164 debug : bool 165 Enable debug printing. 166 transform : callable or dict or None 167 Optional transform(s) applied to samples. 168 use_dataloader : bool 169 If True, construct DataLoaders; else store raw Dataset objects. 170 use_precomputed : bool 171 If True, precompute dataset representation to disk and reload it. 172 173 Returns 174 ------- 175 TramDagDataset 176 Initialized dataset instance. 177 178 Raises 179 ------ 180 TypeError 181 If `df` is not a pandas DataFrame. 182 ValueError 183 If unknown keyword arguments are provided (when validation is enabled). 184 185 Notes 186 ----- 187 If `shuffle=True` and the inferred variable name of `df` suggests 188 validation/test data (e.g. "val", "test"), a warning is printed. 
189 """ 190 self = cls() 191 if not isinstance(df, pd.DataFrame): 192 raise TypeError(f"[ERROR] df must be a pandas DataFrame, but got {type(df)}") 193 194 # validate kwargs 195 #self._validate_kwargs(kwargs, context="from_dataframe") 196 197 # merge defaults with overrides 198 settings = dict(cls.DEFAULTS) 199 settings.update(kwargs) 200 201 # store config and verify 202 self.cfg = cfg 203 self.cfg._verify_completeness() 204 205 # ouptu all setttings if debug 206 if settings.get("debug", False): 207 print("[DEBUG] TramDagDataset.from_dataframe() settings (after defaults + overrides):") 208 for k, v in settings.items(): 209 print(f" {k}: {v}") 210 211 # infer variable name automatically 212 callers_locals = inspect.currentframe().f_back.f_locals 213 inferred = None 214 for var_name, var_val in callers_locals.items(): 215 if var_val is df: 216 inferred = var_name 217 break 218 df_name = inferred or "dataframe" 219 220 if settings["shuffle"]: 221 if any(x in df_name.lower() for x in ["val", "validation", "test"]): 222 print(f"[WARNING] DataFrame '{df_name}' looks like a validation/test set → shuffle=True. Are you sure?") 223 224 # call again to ensure Warning messages if ordinal vars have missing levels 225 self.df = df.copy() 226 self._apply_settings(settings) 227 self._build_dataloaders() 228 return self 229 230 def compute_scaling(self, df: pd.DataFrame = None, write: bool = True): 231 """ 232 Compute variable-wise scaling parameters from data. 233 234 Per variable, this method computes approximate minimum and maximum 235 values using the 5th and 95th percentiles. This is typically used 236 to derive robust normalization/clipping ranges from training data. 237 238 Parameters 239 ---------- 240 df : pandas.DataFrame or None, optional 241 DataFrame used to compute scaling. If None, `self.df` is used. 242 write : bool, optional 243 Unused placeholder for interface compatibility with other components. 244 Kept for potential future extensions. Default is True. 
245 246 Returns 247 ------- 248 dict 249 Mapping `{column_name: [min_value, max_value]}`, where values 250 are derived from the 0.05 and 0.95 quantiles. 251 252 Notes 253 ----- 254 If `self.debug` is True, the method emits debug messages about the 255 data source. Only training data should be used to avoid leakage. 256 """ 257 if self.debug: 258 print("[DEBUG] Make sure to provide only training data to compute_scaling!") 259 if df is None: 260 df = self.df 261 if self.debug: 262 print("[DEBUG] No DataFrame provided, using internal df.") 263 quantiles = df.quantile([0.05, 0.95]) 264 min_vals = quantiles.loc[0.05] 265 max_vals = quantiles.loc[0.95] 266 minmax_dict = pd.concat([min_vals, max_vals], axis=1).T.to_dict('list') 267 return minmax_dict 268 269 def summary(self): 270 """ 271 Print a structured overview of the dataset and configuration. 272 273 The summary includes: 274 275 1. DataFrame information: 276 - Shape 277 - Columns 278 - Head (first rows) 279 - Dtypes 280 - Descriptive statistics 281 282 2. Configuration overview: 283 - Number of nodes 284 - Loader mode (DataLoader vs. raw Dataset) 285 - Precomputation status 286 287 3. Node settings: 288 - Batch size 289 - Shuffle flag 290 - num_workers 291 - pin_memory 292 - return_intercept_shift 293 - debug 294 - transform 295 296 4. DataLoader overview: 297 - Type and length of each loader. 298 299 Parameters 300 ---------- 301 None 302 303 Returns 304 ------- 305 None 306 307 Notes 308 ----- 309 Intended for quick inspection and debugging. Uses `print` statements 310 and does not return structured metadata. 
311 """ 312 313 print("\n[TramDagDataset Summary]") 314 print("=" * 60) 315 316 print("\n[DataFrame]") 317 print("Shape:", self.df.shape) 318 print("Columns:", list(self.df.columns)) 319 print("\nHead:") 320 print(self.df.head()) 321 322 print("\nDtypes:") 323 print(self.df.dtypes) 324 325 print("\nDescribe:") 326 print(self.df.describe(include="all")) 327 328 print("\n[Configuration]") 329 print(f"Nodes: {len(self.nodes_dict)}") 330 print(f"Loader mode: {'DataLoader' if self.use_dataloader else 'Direct dataset'}") 331 print(f"Precomputed: {getattr(self, 'use_precomputed', False)}") 332 333 print("\n[Node Settings]") 334 for node in self.nodes_dict.keys(): 335 batch_size = self.batch_size[node] if isinstance(self.batch_size, dict) else self.batch_size 336 shuffle_flag = self.shuffle[node] if isinstance(self.shuffle, dict) else self.shuffle 337 num_workers = self.num_workers[node] if isinstance(self.num_workers, dict) else self.num_workers 338 pin_memory = self.pin_memory[node] if isinstance(self.pin_memory, dict) else self.pin_memory 339 rshift = self.return_intercept_shift[node] if isinstance(self.return_intercept_shift, dict) else self.return_intercept_shift 340 debug_flag = self.debug[node] if isinstance(self.debug, dict) else self.debug 341 transform = self.transform[node] if isinstance(self.transform, dict) else self.transform 342 343 print( 344 f" Node '{node}': " 345 f"batch_size={batch_size}, " 346 f"shuffle={shuffle_flag}, " 347 f"num_workers={num_workers}, " 348 f"pin_memory={pin_memory}, " 349 f"return_intercept_shift={rshift}, " 350 f"debug={debug_flag}, " 351 f"transform={transform}" 352 ) 353 354 if hasattr(self, "loaders"): 355 print("\n[DataLoaders]") 356 for node, loader in self.loaders.items(): 357 try: 358 length = len(loader) 359 except Exception: 360 length = "?" 
361 print(f" {node}: {type(loader).__name__}, len={length}") 362 363 print("=" * 60 + "\n") 364 365 def _validate_kwargs(self, kwargs: dict, defaults_attr: str = "DEFAULTS", context: str = None): 366 """ 367 Validate keyword arguments against a defaults dictionary. 368 369 Parameters 370 ---------- 371 kwargs : dict 372 Keyword arguments to validate. 373 defaults_attr : str, optional 374 Name of the attribute on this class containing the default keys 375 (e.g. "DEFAULTS"). Default is "DEFAULTS". 376 context : str or None, optional 377 Optional context label (e.g. caller name) included in error messages. 378 379 Raises 380 ------ 381 AttributeError 382 If the attribute named by `defaults_attr` does not exist. 383 ValueError 384 If any keys in `kwargs` are not present in the defaults dictionary. 385 """ 386 defaults = getattr(self, defaults_attr, None) 387 if defaults is None: 388 raise AttributeError(f"{self.__class__.__name__} has no attribute '{defaults_attr}'") 389 390 unknown = set(kwargs) - set(defaults) 391 if unknown: 392 prefix = f"[{context}] " if context else "" 393 raise ValueError(f"{prefix}Unknown parameter(s): {', '.join(sorted(unknown))}") 394 395 def _apply_settings(self, settings: dict): 396 """ 397 Apply resolved settings to the dataset instance. 398 399 This method: 400 401 1. Stores all key–value pairs from `settings` as attributes on `self`. 402 2. Extracts `nodes_dict` from the configuration. 403 3. Validates that dict-valued core attributes (batch_size, shuffle, etc.) 404 have keys matching the node set. 405 406 Parameters 407 ---------- 408 settings : dict 409 Resolved settings dictionary, usually built from `DEFAULTS` plus 410 user overrides. 411 412 Returns 413 ------- 414 None 415 416 Raises 417 ------ 418 ValueError 419 If any dict-valued core attribute has keys that do not match 420 `cfg.conf_dict["nodes"].keys()`. 
421 """ 422 for k, v in settings.items(): 423 setattr(self, k, v) 424 425 self.nodes_dict = self.cfg.conf_dict["nodes"] 426 427 # validate only the most important ones 428 for name in ["batch_size", "shuffle", "num_workers", "pin_memory", 429 "return_intercept_shift", "debug", "transform"]: 430 self._check_keys(name, getattr(self, name)) 431 432 def _build_dataloaders(self): 433 """Build node-specific dataloaders or raw datasets depending on settings.""" 434 self.loaders = {} 435 for node in self.nodes_dict: 436 ds = GenericDataset( 437 self.df, 438 target_col=node, 439 target_nodes=self.nodes_dict, 440 transform=self.transform if not isinstance(self.transform, dict) else self.transform[node], 441 return_intercept_shift=self.return_intercept_shift if not isinstance(self.return_intercept_shift, dict) else self.return_intercept_shift[node], 442 debug=self.debug if not isinstance(self.debug, dict) else self.debug[node], 443 ) 444 445 ########## QUICK PATCH 446 if hasattr(self, "use_precomputed") and self.use_precomputed: 447 os.makedirs("temp", exist_ok=True) 448 pth = os.path.join("temp", "precomputed.pt") 449 450 if hasattr(ds, "save_precomputed") and callable(getattr(ds, "save_precomputed")): 451 ds.save_precomputed(pth) 452 ds = GenericDatasetPrecomputed(pth) 453 else: 454 print("[WARNING] Dataset has no 'save_precomputed()' method — skipping precomputation.") 455 456 457 if self.use_dataloader: 458 # resolve per-node overrides 459 kwargs = { 460 "batch_size": self.batch_size[node] if isinstance(self.batch_size, dict) else self.batch_size, 461 "shuffle": self.shuffle[node] if isinstance(self.shuffle, dict) else self.shuffle, 462 "num_workers": self.num_workers[node] if isinstance(self.num_workers, dict) else self.num_workers, 463 "pin_memory": self.pin_memory[node] if isinstance(self.pin_memory, dict) else self.pin_memory, 464 "sampler": self.sampler, 465 "batch_sampler": self.batch_sampler, 466 "collate_fn": self.collate_fn, 467 "drop_last": self.drop_last, 468 
"timeout": self.timeout, 469 "worker_init_fn": self.worker_init_fn, 470 "multiprocessing_context": self.multiprocessing_context, 471 "generator": self.generator, 472 "prefetch_factor": self.prefetch_factor, 473 "persistent_workers": self.persistent_workers, 474 "pin_memory_device": self.pin_memory_device, 475 } 476 self.loaders[node] = DataLoader(ds, **kwargs) 477 else: 478 self.loaders[node] = ds 479 480 if hasattr(self, "use_precomputed") and self.use_precomputed: 481 if os.path.exists(pth): 482 try: 483 os.remove(pth) 484 if self.debug: 485 print(f"[INFO] Removed existing precomputed file: {pth}") 486 except Exception as e: 487 print(f"[WARNING] Could not remove {pth}: {e}") 488 489 def _check_keys(self, attr_name, attr_value): 490 """ 491 Check that dict-valued attributes use node names as keys. 492 493 Parameters 494 ---------- 495 attr_name : str 496 Name of the attribute being checked (for error messages). 497 attr_value : Any 498 Attribute value. If it is a dict, its keys are validated. 499 500 Returns 501 ------- 502 None 503 504 Raises 505 ------ 506 ValueError 507 If `attr_value` is a dict and its keys do not exactly match 508 `cfg.conf_dict["nodes"].keys()`. 509 510 Notes 511 ----- 512 This check prevents partial or mismatched per-node settings such as 513 batch sizes or shuffle flags. 514 """ 515 if isinstance(attr_value, dict): 516 expected_keys = set(self.nodes_dict.keys()) 517 given_keys = set(attr_value.keys()) 518 if expected_keys != given_keys: 519 raise ValueError( 520 f"[ERROR] the provided attribute '{attr_name}' keys are not same as in cfg.conf_dict['nodes'].keys().\n" 521 f"Expected: {expected_keys}, but got: {given_keys}\n" 522 f"Please provide values for all variables." 523 ) 524 525 def __getitem__(self, idx): 526 return self.df.iloc[idx].to_dict() 527 528 def __len__(self): 529 return len(self.df)
TramDagDataset
The TramDagDataset class handles structured data preparation for TRAM-DAG
models. It wraps a pandas DataFrame together with its configuration and provides
utilities for scaling, transformation, and efficient DataLoader construction
for each node in a DAG-based configuration.
Core Responsibilities
- Validate and store configuration metadata (`TramDagConfig`).
- Manage per-node settings for DataLoader creation (batch size, shuffling, workers).
- Compute scaling information (quantile-based min/max).
- Optionally precompute and cache dataset representations.
- Expose PyTorch Dataset and DataLoader interfaces for model training.
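Per-node settings may be passed either as a single value (applied to every node) or as a dict keyed by node name, in which case the keys must cover the node set exactly. The sketch below mirrors the strict key check the class performs internally via `_check_keys`; the node names `x1`/`x2` are made up for illustration:

```python
def check_keys(attr_name, attr_value, nodes):
    # Dict-valued settings must name every node exactly -- no extras, no gaps.
    if isinstance(attr_value, dict) and set(attr_value) != set(nodes):
        raise ValueError(
            f"'{attr_name}' keys {sorted(attr_value)} do not match nodes {sorted(nodes)}"
        )

nodes = {"x1": {}, "x2": {}}                              # stand-in node specs
check_keys("batch_size", 256, nodes)                      # scalar: always valid
check_keys("batch_size", {"x1": 256, "x2": 1024}, nodes)  # complete dict: valid
try:
    check_keys("batch_size", {"x1": 256}, nodes)          # "x2" missing
except ValueError as err:
    print(err)
```

Rejecting partial dicts outright avoids silently training some nodes with default settings.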
Key Attributes
df : pandas.DataFrame
    The dataset content used for building loaders and computing scaling.
cfg : TramDagConfig
    Configuration object defining nodes and variable metadata.
nodes_dict : dict
    Mapping of variable names to node specifications from the configuration.
loaders : dict
    Mapping of node names to `torch.utils.data.DataLoader` instances or `GenericDataset` objects.
DEFAULTS : dict
    Default DataLoader and dataset-related settings (e.g., batch_size, shuffle, num_workers).
Main Methods
from_dataframe(df, cfg, **kwargs)
    Construct the dataset directly from a pandas DataFrame.
compute_scaling(df=None, write=True)
    Compute per-variable min/max scaling values from data.
summary()
    Print dataset overview including shape, dtypes, statistics, and node settings.
Notes
- Intended for training data; `compute_scaling()` should use only training subsets.
- Compatible with both CPU and GPU DataLoader options.
- Strict validation of keyword arguments against `DEFAULTS` prevents silent misconfiguration.
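The validate-then-merge step behind this can be sketched in isolation; the `DEFAULTS` values below are an illustrative subset, not the library's actual defaults:

```python
DEFAULTS = {"batch_size": 32, "shuffle": True, "num_workers": 0}  # illustrative subset

def resolve_settings(**kwargs):
    # Reject anything that is not a known default key ...
    unknown = set(kwargs) - set(DEFAULTS)
    if unknown:
        raise ValueError(f"Unknown parameter(s): {', '.join(sorted(unknown))}")
    # ... then overlay user overrides on a copy of the defaults.
    settings = dict(DEFAULTS)
    settings.update(kwargs)
    return settings

print(resolve_settings(batch_size=1024))
# {'batch_size': 1024, 'shuffle': True, 'num_workers': 0}
```

A misspelled keyword such as `btach_size=1024` fails loudly instead of being ignored.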
Example
>>> cfg = TramDagConfig.from_json("config.json")
>>> dataset = TramDagDataset.from_dataframe(train_df, cfg, batch_size=1024, debug=True)
>>> dataset.summary()
>>> minmax = dataset.compute_scaling(train_df)
>>> loader = dataset.loaders["variable_x"]
>>> next(iter(loader))
```python
    def __init__(self):
        """
        Initialize an empty TramDagDataset shell.

        Notes
        -----
        This constructor does not attach data or configuration. Use
        `TramDagDataset.from_dataframe` to obtain a ready-to-use instance.
        """
        pass
```
Initialize an empty TramDagDataset shell.
Notes
This constructor does not attach data or configuration. Use
TramDagDataset.from_dataframe to obtain a ready-to-use instance.
```python
    @classmethod
    def from_dataframe(cls, df, cfg, **kwargs):
        """
        Create a TramDagDataset instance directly from a pandas DataFrame.

        This classmethod:

        1. Validates keyword arguments against `DEFAULTS`.
        2. Merges user overrides with defaults into a resolved settings dict.
        3. Stores the configuration and verifies its completeness.
        4. Applies settings to the instance.
        5. Builds per-node datasets and DataLoaders.

        Parameters
        ----------
        df : pandas.DataFrame
            Input DataFrame containing the dataset.
        cfg : TramDagConfig
            Configuration object defining nodes and variable metadata.
        **kwargs
            Optional overrides for `DEFAULTS`. All keys must exist in
            `TramDagDataset.DEFAULTS`. Common keys include:

            batch_size : int
                Batch size for DataLoaders.
            shuffle : bool
                Whether to shuffle samples per epoch.
            num_workers : int
                Number of DataLoader workers.
            pin_memory : bool
                Whether to pin memory for faster host-to-device transfers.
            return_intercept_shift : bool
                Whether datasets should return intercept/shift information.
            debug : bool
                Enable debug printing.
            transform : callable or dict or None
                Optional transform(s) applied to samples.
            use_dataloader : bool
                If True, construct DataLoaders; else store raw Dataset objects.
            use_precomputed : bool
                If True, precompute dataset representation to disk and reload it.

        Returns
        -------
        TramDagDataset
            Initialized dataset instance.

        Raises
        ------
        TypeError
            If `df` is not a pandas DataFrame.
        ValueError
            If unknown keyword arguments are provided (when validation is enabled).

        Notes
        -----
        If `shuffle=True` and the inferred variable name of `df` suggests
        validation/test data (e.g. "val", "test"), a warning is printed.
        """
        self = cls()
        if not isinstance(df, pd.DataFrame):
            raise TypeError(f"[ERROR] df must be a pandas DataFrame, but got {type(df)}")

        # validate kwargs (currently disabled)
        # self._validate_kwargs(kwargs, context="from_dataframe")

        # merge defaults with overrides
        settings = dict(cls.DEFAULTS)
        settings.update(kwargs)

        # store config and verify
        self.cfg = cfg
        self.cfg._verify_completeness()

        # output all settings if debug
        if settings.get("debug", False):
            print("[DEBUG] TramDagDataset.from_dataframe() settings (after defaults + overrides):")
            for k, v in settings.items():
                print(f"  {k}: {v}")

        # infer variable name automatically
        callers_locals = inspect.currentframe().f_back.f_locals
        inferred = None
        for var_name, var_val in callers_locals.items():
            if var_val is df:
                inferred = var_name
                break
        df_name = inferred or "dataframe"

        if settings["shuffle"]:
            if any(x in df_name.lower() for x in ["val", "validation", "test"]):
                print(f"[WARNING] DataFrame '{df_name}' looks like a validation/test set → shuffle=True. Are you sure?")

        # call again to ensure Warning messages if ordinal vars have missing levels
        self.df = df.copy()
        self._apply_settings(settings)
        self._build_dataloaders()
        return self
```
Create a TramDagDataset instance directly from a pandas DataFrame.
This classmethod:
1. Validates keyword arguments against `DEFAULTS`.
2. Merges user overrides with defaults into a resolved settings dict.
3. Stores the configuration and verifies its completeness.
4. Applies settings to the instance.
5. Builds per-node datasets and DataLoaders.
Parameters
df : pandas.DataFrame
Input DataFrame containing the dataset.
cfg : TramDagConfig
Configuration object defining nodes and variable metadata.
**kwargs
Optional overrides for DEFAULTS. All keys must exist in
TramDagDataset.DEFAULTS. Common keys include:
batch_size : int
Batch size for DataLoaders.
shuffle : bool
Whether to shuffle samples per epoch.
num_workers : int
Number of DataLoader workers.
pin_memory : bool
Whether to pin memory for faster host-to-device transfers.
return_intercept_shift : bool
Whether datasets should return intercept/shift information.
debug : bool
Enable debug printing.
transform : callable or dict or None
Optional transform(s) applied to samples.
use_dataloader : bool
If True, construct DataLoaders; else store raw Dataset objects.
use_precomputed : bool
If True, precompute dataset representation to disk and reload it.
Returns
TramDagDataset
    Initialized dataset instance.
Raises
TypeError
If df is not a pandas DataFrame.
ValueError
If unknown keyword arguments are provided (when validation is enabled).
Notes
If shuffle=True and the inferred variable name of df suggests
validation/test data (e.g. "val", "test"), a warning is printed.
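The warning relies on scanning the caller's local variables for a name bound to the passed DataFrame object. That mechanism can be sketched standalone (a plain list stands in for the DataFrame here):

```python
import inspect

def infer_name(obj, fallback="dataframe"):
    # Walk the caller's namespace looking for a variable bound to this exact object.
    callers_locals = inspect.currentframe().f_back.f_locals
    for var_name, var_val in callers_locals.items():
        if var_val is obj:
            return var_name
    return fallback

val_df = ["stand-in for a DataFrame"]
name = infer_name(val_df)
if any(x in name.lower() for x in ["val", "validation", "test"]):
    print(f"[WARNING] DataFrame '{name}' looks like a validation/test set")
```

Note this is a heuristic: if the same object is bound to several names, or passed as an expression, the inferred name is arbitrary or falls back to the default.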
```python
    def compute_scaling(self, df: pd.DataFrame = None, write: bool = True):
        """
        Compute variable-wise scaling parameters from data.

        Per variable, this method computes approximate minimum and maximum
        values using the 5th and 95th percentiles. This is typically used
        to derive robust normalization/clipping ranges from training data.

        Parameters
        ----------
        df : pandas.DataFrame or None, optional
            DataFrame used to compute scaling. If None, `self.df` is used.
        write : bool, optional
            Unused placeholder for interface compatibility with other components.
            Kept for potential future extensions. Default is True.

        Returns
        -------
        dict
            Mapping `{column_name: [min_value, max_value]}`, where values
            are derived from the 0.05 and 0.95 quantiles.

        Notes
        -----
        If `self.debug` is True, the method emits debug messages about the
        data source. Only training data should be used to avoid leakage.
        """
        if self.debug:
            print("[DEBUG] Make sure to provide only training data to compute_scaling!")
        if df is None:
            df = self.df
            if self.debug:
                print("[DEBUG] No DataFrame provided, using internal df.")
        quantiles = df.quantile([0.05, 0.95])
        min_vals = quantiles.loc[0.05]
        max_vals = quantiles.loc[0.95]
        minmax_dict = pd.concat([min_vals, max_vals], axis=1).T.to_dict('list')
        return minmax_dict
```
Compute variable-wise scaling parameters from data.
Per variable, this method computes approximate minimum and maximum values using the 5th and 95th percentiles. This is typically used to derive robust normalization/clipping ranges from training data.
Parameters
df : pandas.DataFrame or None, optional
DataFrame used to compute scaling. If None, self.df is used.
write : bool, optional
Unused placeholder for interface compatibility with other components.
Kept for potential future extensions. Default is True.
Returns
dict
Mapping {column_name: [min_value, max_value]}, where values
are derived from the 0.05 and 0.95 quantiles.
Notes
If self.debug is True, the method emits debug messages about the
data source. Only training data should be used to avoid leakage.
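The returned mapping can be reproduced with plain pandas; a minimal example on synthetic data:

```python
import pandas as pd

# Synthetic training data: x1 ranges over 0..100, x2 = 2 * x1
df = pd.DataFrame({"x1": range(101), "x2": [2.0 * v for v in range(101)]})

# 5th/95th percentiles per column, reshaped to {column: [min, max]}
quantiles = df.quantile([0.05, 0.95])
minmax = pd.concat([quantiles.loc[0.05], quantiles.loc[0.95]], axis=1).T.to_dict("list")

print(minmax)  # approximately {'x1': [5.0, 95.0], 'x2': [10.0, 190.0]}
```

Using the 5th/95th percentiles rather than the raw min/max makes the range robust to outliers, at the cost of clipping extreme values.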
```python
    def summary(self):
        """
        Print a structured overview of the dataset and configuration.

        The summary includes:

        1. DataFrame information:
           - Shape
           - Columns
           - Head (first rows)
           - Dtypes
           - Descriptive statistics

        2. Configuration overview:
           - Number of nodes
           - Loader mode (DataLoader vs. raw Dataset)
           - Precomputation status

        3. Node settings:
           - Batch size
           - Shuffle flag
           - num_workers
           - pin_memory
           - return_intercept_shift
           - debug
           - transform

        4. DataLoader overview:
           - Type and length of each loader.

        Parameters
        ----------
        None

        Returns
        -------
        None

        Notes
        -----
        Intended for quick inspection and debugging. Uses `print` statements
        and does not return structured metadata.
        """

        print("\n[TramDagDataset Summary]")
        print("=" * 60)

        print("\n[DataFrame]")
        print("Shape:", self.df.shape)
        print("Columns:", list(self.df.columns))
        print("\nHead:")
        print(self.df.head())

        print("\nDtypes:")
        print(self.df.dtypes)

        print("\nDescribe:")
        print(self.df.describe(include="all"))

        print("\n[Configuration]")
        print(f"Nodes: {len(self.nodes_dict)}")
        print(f"Loader mode: {'DataLoader' if self.use_dataloader else 'Direct dataset'}")
        print(f"Precomputed: {getattr(self, 'use_precomputed', False)}")

        print("\n[Node Settings]")
        for node in self.nodes_dict.keys():
            batch_size = self.batch_size[node] if isinstance(self.batch_size, dict) else self.batch_size
            shuffle_flag = self.shuffle[node] if isinstance(self.shuffle, dict) else self.shuffle
            num_workers = self.num_workers[node] if isinstance(self.num_workers, dict) else self.num_workers
            pin_memory = self.pin_memory[node] if isinstance(self.pin_memory, dict) else self.pin_memory
            rshift = self.return_intercept_shift[node] if isinstance(self.return_intercept_shift, dict) else self.return_intercept_shift
            debug_flag = self.debug[node] if isinstance(self.debug, dict) else self.debug
            transform = self.transform[node] if isinstance(self.transform, dict) else self.transform

            print(
                f"  Node '{node}': "
                f"batch_size={batch_size}, "
                f"shuffle={shuffle_flag}, "
                f"num_workers={num_workers}, "
                f"pin_memory={pin_memory}, "
                f"return_intercept_shift={rshift}, "
                f"debug={debug_flag}, "
                f"transform={transform}"
            )

        if hasattr(self, "loaders"):
            print("\n[DataLoaders]")
            for node, loader in self.loaders.items():
                try:
                    length = len(loader)
                except Exception:
                    length = "?"
                print(f"  {node}: {type(loader).__name__}, len={length}")

        print("=" * 60 + "\n")
```
Print a structured overview of the dataset and configuration.
The summary includes:
1. DataFrame information:
   - Shape
   - Columns
   - Head (first rows)
   - Dtypes
   - Descriptive statistics
2. Configuration overview:
   - Number of nodes
   - Loader mode (DataLoader vs. raw Dataset)
   - Precomputation status
3. Node settings:
   - Batch size
   - Shuffle flag
   - num_workers
   - pin_memory
   - return_intercept_shift
   - debug
   - transform
4. DataLoader overview:
   - Type and length of each loader.
Parameters
None
Returns
None
Notes
Intended for quick inspection and debugging. Uses print statements
and does not return structured metadata.