tramdag.TramDagDataset

Copyright 2025 Zurich University of Applied Sciences (ZHAW) Pascal Buehler, Beate Sick, Oliver Duerr

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

"""
Copyright 2025 Zurich University of Applied Sciences (ZHAW)
Pascal Buehler, Beate Sick, Oliver Duerr

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""


import inspect
import os

import pandas as pd
from torch.utils.data import Dataset, DataLoader

from .utils.data import GenericDataset, GenericDatasetPrecomputed


class TramDagDataset(Dataset):
    """
    TramDagDataset
    ==============

    The `TramDagDataset` class handles structured data preparation for TRAM-DAG
    models. It wraps a pandas DataFrame together with its configuration and provides
    utilities for scaling, transformation, and efficient DataLoader construction
    for each node in a DAG-based configuration.

    ---------------------------------------------------------------------
    Core Responsibilities
    ---------------------------------------------------------------------
    - Validate and store configuration metadata (`TramDagConfig`).
    - Manage per-node settings for DataLoader creation (batch size, shuffling, workers).
    - Compute scaling information (quantile-based min/max).
    - Optionally precompute and cache dataset representations.
    - Expose PyTorch Dataset and DataLoader interfaces for model training.

    ---------------------------------------------------------------------
    Key Attributes
    ---------------------------------------------------------------------
    - **df** : pandas.DataFrame
      The dataset content used for building loaders and computing scaling.

    - **cfg** : TramDagConfig
      Configuration object defining nodes and variable metadata.

    - **nodes_dict** : dict
      Mapping of variable names to node specifications from the configuration.

    - **loaders** : dict
      Mapping of node names to `torch.utils.data.DataLoader` instances or `GenericDataset` objects.

    - **DEFAULTS** : dict
      Default DataLoader and dataset-related settings (e.g., batch_size, shuffle, num_workers).

    ---------------------------------------------------------------------
    Main Methods
    ---------------------------------------------------------------------
    - **from_dataframe(df, cfg, **kwargs)**
      Construct the dataset directly from a pandas DataFrame.

    - **compute_scaling(df=None, write=True)**
      Compute per-variable min/max scaling values from data.

    - **summary()**
      Print a dataset overview including shape, dtypes, statistics, and node settings.

    ---------------------------------------------------------------------
    Notes
    ---------------------------------------------------------------------
    - Intended for training data; `compute_scaling()` should use only training subsets.
    - Compatible with both CPU and GPU DataLoader options.
    - Strict validation of keyword arguments against `DEFAULTS` prevents silent misconfiguration.

    ---------------------------------------------------------------------
    Example
    ---------------------------------------------------------------------
    >>> cfg = TramDagConfig.from_json("config.json")
    >>> dataset = TramDagDataset.from_dataframe(train_df, cfg, batch_size=1024, debug=True)
    >>> dataset.summary()
    >>> minmax = dataset.compute_scaling(train_df)
    >>> loader = dataset.loaders["variable_x"]
    >>> next(iter(loader))
    """

    DEFAULTS = {
        "batch_size": 32_000,
        "shuffle": True,
        "num_workers": 4,
        "pin_memory": True,
        "return_intercept_shift": True,
        "debug": False,
        "transform": None,
        "use_dataloader": True,
        "use_precomputed": False,
        # DataLoader extras
        "sampler": None,
        "batch_sampler": None,
        "collate_fn": None,
        "drop_last": False,
        "timeout": 0,
        "worker_init_fn": None,
        "multiprocessing_context": None,
        "generator": None,
        "prefetch_factor": 2,
        "persistent_workers": True,
        "pin_memory_device": "",
    }

    def __init__(self):
        """
        Initialize an empty TramDagDataset shell.

        Notes
        -----
        This constructor does not attach data or configuration. Use
        `TramDagDataset.from_dataframe` to obtain a ready-to-use instance.
        """
        pass

    @classmethod
    def from_dataframe(cls, df, cfg, **kwargs):
        """
        Create a TramDagDataset instance directly from a pandas DataFrame.

        This classmethod:

        1. Validates keyword arguments against `DEFAULTS`.
        2. Merges user overrides with defaults into a resolved settings dict.
        3. Stores the configuration and verifies its completeness.
        4. Applies settings to the instance.
        5. Builds per-node datasets and DataLoaders.

        Parameters
        ----------
        df : pandas.DataFrame
            Input DataFrame containing the dataset.
        cfg : TramDagConfig
            Configuration object defining nodes and variable metadata.
        **kwargs
            Optional overrides for `DEFAULTS`. All keys must exist in
            `TramDagDataset.DEFAULTS`. Common keys include:

            batch_size : int
                Batch size for DataLoaders.
            shuffle : bool
                Whether to shuffle samples per epoch.
            num_workers : int
                Number of DataLoader workers.
            pin_memory : bool
                Whether to pin memory for faster host-to-device transfers.
            return_intercept_shift : bool
                Whether datasets should return intercept/shift information.
            debug : bool
                Enable debug printing.
            transform : callable or dict or None
                Optional transform(s) applied to samples.
            use_dataloader : bool
                If True, construct DataLoaders; otherwise store raw Dataset objects.
            use_precomputed : bool
                If True, precompute the dataset representation to disk and reload it.

        Returns
        -------
        TramDagDataset
            Initialized dataset instance.

        Raises
        ------
        TypeError
            If `df` is not a pandas DataFrame.
        ValueError
            If unknown keyword arguments are provided.

        Notes
        -----
        If `shuffle=True` and the inferred variable name of `df` suggests
        validation/test data (e.g. "val", "test"), a warning is printed.
        """
        self = cls()
        if not isinstance(df, pd.DataFrame):
            raise TypeError(f"[ERROR] df must be a pandas DataFrame, but got {type(df)}")

        # validate kwargs against DEFAULTS so typos fail loudly
        self._validate_kwargs(kwargs, context="from_dataframe")

        # merge defaults with overrides
        settings = dict(cls.DEFAULTS)
        settings.update(kwargs)

        # store config and verify
        self.cfg = cfg
        self.cfg._verify_completeness()

        # output all settings if debug
        if settings.get("debug", False):
            print("[DEBUG] TramDagDataset.from_dataframe() settings (after defaults + overrides):")
            for k, v in settings.items():
                print(f"    {k}: {v}")

        # infer the caller's variable name for df (best effort, for warnings only)
        callers_locals = inspect.currentframe().f_back.f_locals
        inferred = None
        for var_name, var_val in callers_locals.items():
            if var_val is df:
                inferred = var_name
                break
        df_name = inferred or "dataframe"

        if settings["shuffle"]:
            if any(x in df_name.lower() for x in ["val", "validation", "test"]):
                print(f"[WARNING] DataFrame '{df_name}' looks like a validation/test set → shuffle=True. Are you sure?")

        # store a copy of the data, apply settings, and build per-node loaders
        self.df = df.copy()
        self._apply_settings(settings)
        self._build_dataloaders()
        return self

    def compute_scaling(self, df: pd.DataFrame = None, write: bool = True):
        """
        Compute variable-wise scaling parameters from data.

        Per variable, this method computes approximate minimum and maximum
        values using the 5th and 95th percentiles. This is typically used
        to derive robust normalization/clipping ranges from training data.

        Parameters
        ----------
        df : pandas.DataFrame or None, optional
            DataFrame used to compute scaling. If None, `self.df` is used.
        write : bool, optional
            Unused placeholder kept for interface compatibility with other
            components and potential future extensions. Default is True.

        Returns
        -------
        dict
            Mapping `{column_name: [min_value, max_value]}`, where the values
            are the 0.05 and 0.95 quantiles.

        Notes
        -----
        If `self.debug` is True, the method emits debug messages about the
        data source. Only training data should be used to avoid leakage.
        """
        if self.debug:
            print("[DEBUG] Make sure to provide only training data to compute_scaling!")
        if df is None:
            df = self.df
            if self.debug:
                print("[DEBUG] No DataFrame provided, using internal df.")
        quantiles = df.quantile([0.05, 0.95])
        min_vals = quantiles.loc[0.05]
        max_vals = quantiles.loc[0.95]
        minmax_dict = pd.concat([min_vals, max_vals], axis=1).T.to_dict('list')
        return minmax_dict

    def summary(self):
        """
        Print a structured overview of the dataset and configuration.

        The summary includes:

        1. DataFrame information:
           - Shape
           - Columns
           - Head (first rows)
           - Dtypes
           - Descriptive statistics

        2. Configuration overview:
           - Number of nodes
           - Loader mode (DataLoader vs. raw Dataset)
           - Precomputation status

        3. Node settings:
           - Batch size
           - Shuffle flag
           - num_workers
           - pin_memory
           - return_intercept_shift
           - debug
           - transform

        4. DataLoader overview:
           - Type and length of each loader.

        Returns
        -------
        None

        Notes
        -----
        Intended for quick inspection and debugging. Uses `print` statements
        and does not return structured metadata.
        """
        print("\n[TramDagDataset Summary]")
        print("=" * 60)

        print("\n[DataFrame]")
        print("Shape:", self.df.shape)
        print("Columns:", list(self.df.columns))
        print("\nHead:")
        print(self.df.head())

        print("\nDtypes:")
        print(self.df.dtypes)

        print("\nDescribe:")
        print(self.df.describe(include="all"))

        print("\n[Configuration]")
        print(f"Nodes: {len(self.nodes_dict)}")
        print(f"Loader mode: {'DataLoader' if self.use_dataloader else 'Direct dataset'}")
        print(f"Precomputed: {getattr(self, 'use_precomputed', False)}")

        print("\n[Node Settings]")
        for node in self.nodes_dict.keys():
            batch_size = self.batch_size[node] if isinstance(self.batch_size, dict) else self.batch_size
            shuffle_flag = self.shuffle[node] if isinstance(self.shuffle, dict) else self.shuffle
            num_workers = self.num_workers[node] if isinstance(self.num_workers, dict) else self.num_workers
            pin_memory = self.pin_memory[node] if isinstance(self.pin_memory, dict) else self.pin_memory
            rshift = self.return_intercept_shift[node] if isinstance(self.return_intercept_shift, dict) else self.return_intercept_shift
            debug_flag = self.debug[node] if isinstance(self.debug, dict) else self.debug
            transform = self.transform[node] if isinstance(self.transform, dict) else self.transform

            print(
                f" Node '{node}': "
                f"batch_size={batch_size}, "
                f"shuffle={shuffle_flag}, "
                f"num_workers={num_workers}, "
                f"pin_memory={pin_memory}, "
                f"return_intercept_shift={rshift}, "
                f"debug={debug_flag}, "
                f"transform={transform}"
            )

        if hasattr(self, "loaders"):
            print("\n[DataLoaders]")
            for node, loader in self.loaders.items():
                try:
                    length = len(loader)
                except Exception:
                    length = "?"
                print(f"  {node}: {type(loader).__name__}, len={length}")

        print("=" * 60 + "\n")

    def _validate_kwargs(self, kwargs: dict, defaults_attr: str = "DEFAULTS", context: str = None):
        """
        Validate keyword arguments against a defaults dictionary.

        Parameters
        ----------
        kwargs : dict
            Keyword arguments to validate.
        defaults_attr : str, optional
            Name of the attribute on this class containing the default keys
            (e.g. "DEFAULTS"). Default is "DEFAULTS".
        context : str or None, optional
            Optional context label (e.g. caller name) included in error messages.

        Raises
        ------
        AttributeError
            If the attribute named by `defaults_attr` does not exist.
        ValueError
            If any keys in `kwargs` are not present in the defaults dictionary.
        """
        defaults = getattr(self, defaults_attr, None)
        if defaults is None:
            raise AttributeError(f"{self.__class__.__name__} has no attribute '{defaults_attr}'")

        unknown = set(kwargs) - set(defaults)
        if unknown:
            prefix = f"[{context}] " if context else ""
            raise ValueError(f"{prefix}Unknown parameter(s): {', '.join(sorted(unknown))}")

    def _apply_settings(self, settings: dict):
        """
        Apply resolved settings to the dataset instance.

        This method:

        1. Stores all key–value pairs from `settings` as attributes on `self`.
        2. Extracts `nodes_dict` from the configuration.
        3. Validates that dict-valued core attributes (batch_size, shuffle, etc.)
           have keys matching the node set.

        Parameters
        ----------
        settings : dict
            Resolved settings dictionary, usually built from `DEFAULTS` plus
            user overrides.

        Returns
        -------
        None

        Raises
        ------
        ValueError
            If any dict-valued core attribute has keys that do not match
            `cfg.conf_dict["nodes"].keys()`.
        """
        for k, v in settings.items():
            setattr(self, k, v)

        self.nodes_dict = self.cfg.conf_dict["nodes"]

        # validate only the most important per-node settings
        for name in ["batch_size", "shuffle", "num_workers", "pin_memory",
                     "return_intercept_shift", "debug", "transform"]:
            self._check_keys(name, getattr(self, name))

    def _build_dataloaders(self):
        """Build node-specific dataloaders or raw datasets depending on settings."""
        self.loaders = {}
        use_precomputed = getattr(self, "use_precomputed", False)
        if use_precomputed:
            # one temporary file, reused per node (each node reloads it
            # immediately after writing) and removed once all loaders exist
            os.makedirs("temp", exist_ok=True)
            pth = os.path.join("temp", "precomputed.pt")

        for node in self.nodes_dict:
            ds = GenericDataset(
                self.df,
                target_col=node,
                target_nodes=self.nodes_dict,
                transform=self.transform if not isinstance(self.transform, dict) else self.transform[node],
                return_intercept_shift=self.return_intercept_shift if not isinstance(self.return_intercept_shift, dict) else self.return_intercept_shift[node],
                debug=self.debug if not isinstance(self.debug, dict) else self.debug[node],
            )

            # optionally precompute the dataset to disk and reload it
            if use_precomputed:
                if callable(getattr(ds, "save_precomputed", None)):
                    ds.save_precomputed(pth)
                    ds = GenericDatasetPrecomputed(pth)
                else:
                    print("[WARNING] Dataset has no 'save_precomputed()' method — skipping precomputation.")

            if self.use_dataloader:
                # resolve per-node overrides
                kwargs = {
                    "batch_size": self.batch_size[node] if isinstance(self.batch_size, dict) else self.batch_size,
                    "shuffle": self.shuffle[node] if isinstance(self.shuffle, dict) else self.shuffle,
                    "num_workers": self.num_workers[node] if isinstance(self.num_workers, dict) else self.num_workers,
                    "pin_memory": self.pin_memory[node] if isinstance(self.pin_memory, dict) else self.pin_memory,
                    "sampler": self.sampler,
                    "batch_sampler": self.batch_sampler,
                    "collate_fn": self.collate_fn,
                    "drop_last": self.drop_last,
                    "timeout": self.timeout,
                    "worker_init_fn": self.worker_init_fn,
                    "multiprocessing_context": self.multiprocessing_context,
                    "generator": self.generator,
                    "prefetch_factor": self.prefetch_factor,
                    "persistent_workers": self.persistent_workers,
                    "pin_memory_device": self.pin_memory_device,
                }
                self.loaders[node] = DataLoader(ds, **kwargs)
            else:
                self.loaders[node] = ds

        # clean up the temporary precomputed file once all loaders are built
        if use_precomputed and os.path.exists(pth):
            try:
                os.remove(pth)
                if self.debug:
                    print(f"[INFO] Removed temporary precomputed file: {pth}")
            except OSError as e:
                print(f"[WARNING] Could not remove {pth}: {e}")

    def _check_keys(self, attr_name, attr_value):
        """
        Check that dict-valued attributes use node names as keys.

        Parameters
        ----------
        attr_name : str
            Name of the attribute being checked (for error messages).
        attr_value : Any
            Attribute value. If it is a dict, its keys are validated.

        Returns
        -------
        None

        Raises
        ------
        ValueError
            If `attr_value` is a dict and its keys do not exactly match
            `cfg.conf_dict["nodes"].keys()`.

        Notes
        -----
        This check prevents partial or mismatched per-node settings such as
        batch sizes or shuffle flags.
        """
        if isinstance(attr_value, dict):
            expected_keys = set(self.nodes_dict.keys())
            given_keys = set(attr_value.keys())
            if expected_keys != given_keys:
                raise ValueError(
                    f"[ERROR] The keys of attribute '{attr_name}' do not match cfg.conf_dict['nodes'].keys().\n"
                    f"Expected: {expected_keys}, but got: {given_keys}\n"
                    f"Please provide values for all variables."
                )

    def __getitem__(self, idx):
        return self.df.iloc[idx].to_dict()

    def __len__(self):
        return len(self.df)
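The quantile-based range computation in `compute_scaling()` can be illustrated in isolation. The sketch below is a hypothetical standalone helper (`quantile_minmax` is not part of the package) that mirrors the same 0.05/0.95 quantile logic on a toy DataFrame:

```python
import pandas as pd


def quantile_minmax(df: pd.DataFrame, lo: float = 0.05, hi: float = 0.95) -> dict:
    """Return {column: [lo_quantile, hi_quantile]}, mirroring compute_scaling()."""
    quantiles = df.quantile([lo, hi])
    # concat the two quantile rows side by side, then transpose so that
    # columns become variable names and each entry is a [min, max] list
    return pd.concat([quantiles.loc[lo], quantiles.loc[hi]], axis=1).T.to_dict("list")


# x spans 0..100, so with linear interpolation the 5th/95th
# percentiles are exactly 5.0 and 95.0 (and 10.0/190.0 for y)
train_df = pd.DataFrame({"x": range(101), "y": [v * 2.0 for v in range(101)]})
minmax = quantile_minmax(train_df)
print(minmax["x"])  # [5.0, 95.0]
```

Because the range is cut at the 5th and 95th percentiles rather than the observed extremes, a handful of outliers in the training data cannot blow up the normalization range.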
337            shuffle_flag = self.shuffle[node] if isinstance(self.shuffle, dict) else self.shuffle
338            num_workers = self.num_workers[node] if isinstance(self.num_workers, dict) else self.num_workers
339            pin_memory = self.pin_memory[node] if isinstance(self.pin_memory, dict) else self.pin_memory
340            rshift = self.return_intercept_shift[node] if isinstance(self.return_intercept_shift, dict) else self.return_intercept_shift
341            debug_flag = self.debug[node] if isinstance(self.debug, dict) else self.debug
342            transform = self.transform[node] if isinstance(self.transform, dict) else self.transform
343
344            print(
345                f" Node '{node}': "
346                f"batch_size={batch_size}, "
347                f"shuffle={shuffle_flag}, "
348                f"num_workers={num_workers}, "
349                f"pin_memory={pin_memory}, "
350                f"return_intercept_shift={rshift}, "
351                f"debug={debug_flag}, "
352                f"transform={transform}"
353            )
354
355        if hasattr(self, "loaders"):
356            print("\n[DataLoaders]")
357            for node, loader in self.loaders.items():
358                try:
359                    length = len(loader)
360                except Exception:
361                    length = "?"
362                print(f"  {node}: {type(loader).__name__}, len={length}")
363
364        print("=" * 60 + "\n")
365
366    def _validate_kwargs(self, kwargs: dict, defaults_attr: str = "DEFAULTS", context: str = None):
367        """
368        Validate keyword arguments against a defaults dictionary.
369
370        Parameters
371        ----------
372        kwargs : dict
373            Keyword arguments to validate.
374        defaults_attr : str, optional
375            Name of the attribute on this class containing the default keys
376            (e.g. "DEFAULTS"). Default is "DEFAULTS".
377        context : str or None, optional
378            Optional context label (e.g. caller name) included in error messages.
379
380        Raises
381        ------
382        AttributeError
383            If the attribute named by `defaults_attr` does not exist.
384        ValueError
385            If any keys in `kwargs` are not present in the defaults dictionary.
386        """
387        defaults = getattr(self, defaults_attr, None)
388        if defaults is None:
389            raise AttributeError(f"{self.__class__.__name__} has no attribute '{defaults_attr}'")
390
391        unknown = set(kwargs) - set(defaults)
392        if unknown:
393            prefix = f"[{context}] " if context else ""
394            raise ValueError(f"{prefix}Unknown parameter(s): {', '.join(sorted(unknown))}")
395
396    def _apply_settings(self, settings: dict):
397        """
398        Apply resolved settings to the dataset instance.
399
400        This method:
401
402        1. Stores all key–value pairs from `settings` as attributes on `self`.
403        2. Extracts `nodes_dict` from the configuration.
404        3. Validates that dict-valued core attributes (batch_size, shuffle, etc.)
405        have keys matching the node set.
406
407        Parameters
408        ----------
409        settings : dict
410            Resolved settings dictionary, usually built from `DEFAULTS` plus
411            user overrides.
412
413        Returns
414        -------
415        None
416
417        Raises
418        ------
419        ValueError
420            If any dict-valued core attribute has keys that do not match
421            `cfg.conf_dict["nodes"].keys()`.
422        """
423        for k, v in settings.items():
424            setattr(self, k, v)
425
426        self.nodes_dict = self.cfg.conf_dict["nodes"]
427
428        # validate only the most important ones
429        for name in ["batch_size", "shuffle", "num_workers", "pin_memory",
430                     "return_intercept_shift", "debug", "transform"]:
431            self._check_keys(name, getattr(self, name))
432
433    def _build_dataloaders(self):
434        """Build node-specific dataloaders or raw datasets depending on settings."""
435        self.loaders = {}
436        for node in self.nodes_dict:
437            ds = GenericDataset(
438                self.df,
439                target_col=node,
440                target_nodes=self.nodes_dict,
441                transform=self.transform if not isinstance(self.transform, dict) else self.transform[node],
442                return_intercept_shift=self.return_intercept_shift if not isinstance(self.return_intercept_shift, dict) else self.return_intercept_shift[node],
443                debug=self.debug if not isinstance(self.debug, dict) else self.debug[node],
444            )
445
446            # optional precomputation: serialize the dataset to disk and reload it
447            if hasattr(self, "use_precomputed") and self.use_precomputed:
448                os.makedirs("temp", exist_ok=True) 
449                pth = os.path.join("temp", "precomputed.pt")
450
451                if hasattr(ds, "save_precomputed") and callable(getattr(ds, "save_precomputed")):
452                    ds.save_precomputed(pth)
453                    ds = GenericDatasetPrecomputed(pth)
454                else:
455                    print("[WARNING] Dataset has no 'save_precomputed()' method — skipping precomputation.")
456
457
458            if self.use_dataloader:
459                # resolve per-node overrides
460                kwargs = {
461                    "batch_size": self.batch_size[node] if isinstance(self.batch_size, dict) else self.batch_size,
462                    "shuffle": self.shuffle[node] if isinstance(self.shuffle, dict) else self.shuffle,
463                    "num_workers": self.num_workers[node] if isinstance(self.num_workers, dict) else self.num_workers,
464                    "pin_memory": self.pin_memory[node] if isinstance(self.pin_memory, dict) else self.pin_memory,
465                    "sampler": self.sampler,
466                    "batch_sampler": self.batch_sampler,
467                    "collate_fn": self.collate_fn,
468                    "drop_last": self.drop_last,
469                    "timeout": self.timeout,
470                    "worker_init_fn": self.worker_init_fn,
471                    "multiprocessing_context": self.multiprocessing_context,
472                    "generator": self.generator,
473                    "prefetch_factor": self.prefetch_factor,
474                    "persistent_workers": self.persistent_workers,
475                    "pin_memory_device": self.pin_memory_device,
476                }
477                self.loaders[node] = DataLoader(ds, **kwargs)
478            else:
479                self.loaders[node] = ds
480                
481        if hasattr(self, "use_precomputed") and self.use_precomputed:
482            if os.path.exists(pth):
483                try:
484                    os.remove(pth)
485                    if self.debug:
486                        print(f"[INFO] Removed temporary precomputed file: {pth}")
487                except Exception as e:
488                    print(f"[WARNING] Could not remove {pth}: {e}")
489
490    def _check_keys(self, attr_name, attr_value):
491        """
492        Check that dict-valued attributes use node names as keys.
493
494        Parameters
495        ----------
496        attr_name : str
497            Name of the attribute being checked (for error messages).
498        attr_value : Any
499            Attribute value. If it is a dict, its keys are validated.
500
501        Returns
502        -------
503        None
504
505        Raises
506        ------
507        ValueError
508            If `attr_value` is a dict and its keys do not exactly match
509            `cfg.conf_dict["nodes"].keys()`.
510
511        Notes
512        -----
513        This check prevents partial or mismatched per-node settings such as
514        batch sizes or shuffle flags.
515        """
516        if isinstance(attr_value, dict):
517            expected_keys = set(self.nodes_dict.keys())
518            given_keys = set(attr_value.keys())
519            if expected_keys != given_keys:
520                raise ValueError(
521                    f"[ERROR] keys of the provided attribute '{attr_name}' do not match cfg.conf_dict['nodes'].keys().\n"
522                    f"Expected: {expected_keys}, got: {given_keys}\n"
523                    f"Please provide values for all node variables."
524                )
525
526    def __getitem__(self, idx):
527        return self.df.iloc[idx].to_dict()
528
529    def __len__(self):
530        return len(self.df)

TramDagDataset

The TramDagDataset class handles structured data preparation for TRAM-DAG models. It wraps a pandas DataFrame together with its configuration and provides utilities for scaling, transformation, and efficient DataLoader construction for each node in a DAG-based configuration.


Core Responsibilities

  • Validate and store configuration metadata (TramDagConfig).
  • Manage per-node settings for DataLoader creation (batch size, shuffling, workers).
  • Compute scaling information (quantile-based min/max).
  • Optionally precompute and cache dataset representations.
  • Expose PyTorch Dataset and DataLoader interfaces for model training.

Key Attributes

  • df : pandas.DataFrame
    The dataset content used for building loaders and computing scaling.

  • cfg : TramDagConfig
    Configuration object defining nodes and variable metadata.

  • nodes_dict : dict
    Mapping of variable names to node specifications from the configuration.

  • loaders : dict
    Mapping of node names to torch.utils.data.DataLoader instances or GenericDataset objects.

  • DEFAULTS : dict
    Default DataLoader and dataset-related settings (e.g., batch_size, shuffle, num_workers, etc.).


Main Methods

  • from_dataframe(df, cfg, **kwargs)
    Construct the dataset directly from a pandas DataFrame.

  • compute_scaling(df=None, write=True)
    Compute per-variable min/max scaling values from data.

  • summary()
    Print dataset overview including shape, dtypes, statistics, and node settings.


Notes

  • Intended for training data; compute_scaling() should use only training subsets.
  • Compatible with both CPU and GPU DataLoader options.
  • Strict validation of keyword arguments against DEFAULTS prevents silent misconfiguration.

Example

>>> cfg = TramDagConfig.from_json("config.json")
>>> dataset = TramDagDataset.from_dataframe(train_df, cfg, batch_size=1024, debug=True)
>>> dataset.summary()
>>> minmax = dataset.compute_scaling(train_df)
>>> loader = dataset.loaders["variable_x"]
>>> next(iter(loader))
TramDagDataset()

Initialize an empty TramDagDataset shell.

Notes

This constructor does not attach data or configuration. Use TramDagDataset.from_dataframe to obtain a ready-to-use instance.

DEFAULTS = {'batch_size': 32000, 'shuffle': True, 'num_workers': 4, 'pin_memory': True, 'return_intercept_shift': True, 'debug': False, 'transform': None, 'use_dataloader': True, 'use_precomputed': False, 'sampler': None, 'batch_sampler': None, 'collate_fn': None, 'drop_last': False, 'timeout': 0, 'worker_init_fn': None, 'multiprocessing_context': None, 'generator': None, 'prefetch_factor': 2, 'persistent_workers': True, 'pin_memory_device': ''}
@classmethod
def from_dataframe(cls, df, cfg, **kwargs):

Create a TramDagDataset instance directly from a pandas DataFrame.

This classmethod:

  1. Validates keyword arguments against DEFAULTS.
  2. Merges user overrides with defaults into a resolved settings dict.
  3. Stores the configuration and verifies its completeness.
  4. Applies settings to the instance.
  5. Builds per-node datasets and DataLoaders.

Parameters

df : pandas.DataFrame
    Input DataFrame containing the dataset.
cfg : TramDagConfig
    Configuration object defining nodes and variable metadata.
**kwargs
    Optional overrides for DEFAULTS. All keys must exist in
    TramDagDataset.DEFAULTS. Common keys include:

batch_size : int
    Batch size for DataLoaders.
shuffle : bool
    Whether to shuffle samples per epoch.
num_workers : int
    Number of DataLoader workers.
pin_memory : bool
    Whether to pin memory for faster host-to-device transfers.
return_intercept_shift : bool
    Whether datasets should return intercept/shift information.
debug : bool
    Enable debug printing.
transform : callable or dict or None
    Optional transform(s) applied to samples.
use_dataloader : bool
    If True, construct DataLoaders; else store raw Dataset objects.
use_precomputed : bool
    If True, precompute dataset representation to disk and reload it.

Returns

TramDagDataset
    Initialized dataset instance.

Raises

TypeError
    If df is not a pandas DataFrame.
ValueError
    If unknown keyword arguments are provided (when validation is enabled).

Notes

If shuffle=True and the inferred variable name of df suggests validation/test data (e.g. "val", "test"), a warning is printed.
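The defaults-plus-overrides merge and the per-node interpretation of dict-valued settings can be sketched in plain Python. This is an illustrative reconstruction, not the library implementation; the node names "x1" and "x2" and the trimmed DEFAULTS are made up for the example.

```python
# Sketch of the settings resolution in from_dataframe(): copy DEFAULTS,
# merge user overrides on top, then read dict-valued settings per node.
DEFAULTS = {"batch_size": 32000, "shuffle": True, "num_workers": 4}
overrides = {"batch_size": {"x1": 256, "x2": 1024}, "shuffle": False}

settings = dict(DEFAULTS)   # merge defaults with overrides
settings.update(overrides)

def resolve(value, node):
    # A dict maps node names to per-node values; a scalar applies to all nodes.
    return value[node] if isinstance(value, dict) else value

per_node = {
    node: {key: resolve(val, node) for key, val in settings.items()}
    for node in ("x1", "x2")
}
```

This is the same dict-or-scalar pattern the class uses when building per-node DataLoaders and when printing node settings in `summary()`.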

def compute_scaling(self, df: pandas.core.frame.DataFrame = None, write: bool = True):

Compute variable-wise scaling parameters from data.

Per variable, this method computes approximate minimum and maximum values using the 5th and 95th percentiles. This is typically used to derive robust normalization/clipping ranges from training data.

Parameters

df : pandas.DataFrame or None, optional
    DataFrame used to compute scaling. If None, self.df is used.
write : bool, optional
    Unused placeholder for interface compatibility with other components.
    Kept for potential future extensions. Default is True.

Returns

dict
    Mapping {column_name: [min_value, max_value]}, where values
    are derived from the 0.05 and 0.95 quantiles.

Notes

If self.debug is True, the method emits debug messages about the data source. Only training data should be used to avoid leakage.
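The quantile logic can be reproduced with plain pandas. A minimal sketch on a toy frame (the data here is illustrative only; in practice this runs on the training DataFrame):

```python
import pandas as pd

# Toy frame standing in for training data.
df = pd.DataFrame({"x": range(1, 101)})

# Same steps as compute_scaling(): robust "min"/"max" taken from the 5th
# and 95th percentiles rather than the true extremes.
quantiles = df.quantile([0.05, 0.95])
minmax = pd.concat([quantiles.loc[0.05], quantiles.loc[0.95]], axis=1).T.to_dict("list")
# minmax is {"x": [q05, q95]} with q05 ≈ 5.95 and q95 ≈ 95.05 here
```

Using percentile bounds instead of the observed min/max makes the scaling robust to outliers in the training data.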

def summary(self):

Print a structured overview of the dataset and configuration.

The summary includes:

  1. DataFrame information:
  • Shape
  • Columns
  • Head (first rows)
  • Dtypes
  • Descriptive statistics
  2. Configuration overview:
  • Number of nodes
  • Loader mode (DataLoader vs. raw Dataset)
  • Precomputation status
  3. Node settings:
  • Batch size
  • Shuffle flag
  • num_workers
  • pin_memory
  • return_intercept_shift
  • debug
  • transform
  4. DataLoader overview:
  • Type and length of each loader.

Parameters

None

Returns

None

Notes

Intended for quick inspection and debugging. Uses print statements and does not return structured metadata.
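The per-node settings shown in the summary must cover exactly the configured nodes; internally `_check_keys` rejects partial dicts. A simplified sketch of that check (node names "x1"/"x2" are illustrative):

```python
# Simplified version of _check_keys(): a dict-valued setting must name
# exactly the configured nodes; scalar settings are never checked.
nodes = {"x1": {}, "x2": {}}

def check_keys(attr_name, attr_value):
    if isinstance(attr_value, dict):
        expected, given = set(nodes), set(attr_value)
        if expected != given:
            raise ValueError(
                f"'{attr_name}' keys {given} do not match nodes {expected}"
            )

check_keys("batch_size", {"x1": 256, "x2": 512})  # exact match: accepted
check_keys("shuffle", True)                       # scalar: applies to all nodes
try:
    check_keys("batch_size", {"x1": 256})         # missing 'x2': rejected
    failed = False
except ValueError:
    failed = True
```

Failing fast here prevents a partially specified per-node setting (e.g. a batch size for only one node) from silently falling back to defaults for the others.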