Module openpack_torch.data.datamodule

Utilities for PyTorch Lightning DataModule.

Todo

  • Add usage (Example Section).
  • Add unit-test.
Expand source code
"""Utilities for PyTorch Lightning DataModule.

Todo:
    * Add usage (Example Section).
    * Add unit-test.
"""
import copy
from logging import getLogger
from typing import Dict, List, Optional, Tuple

import pytorch_lightning as pl
import torch
from omegaconf import DictConfig
from torch.utils.data import DataLoader

from openpack_torch.data.utils import assemble_sequence_list_from_cfg

# Module-level logger; the datamodules' setup() methods report the datasets they build here.
log = getLogger(__name__)


class OpenPackBaseDataModule(pl.LightningDataModule):
    """Base class of PyTorch Lightning DataModule for the OpenPack dataset.

    A datamodule is a shareable, reusable class that encapsulates all the steps needed to
    process data: ``setup()`` builds the train / validation / test / submission datasets and
    the ``*_dataloader()`` methods wrap them in ``DataLoader`` objects.

    Subclasses must set ``dataset_class``. It is instantiated as
    ``dataset_class(cfg, user_session_list, **kwargs)`` where ``kwargs`` is built by
    ``get_kwargs_for_datasets()``.

    Attributes:
        dataset_class (torch.utils.data.Dataset): dataset class (not an instance). It is
            called to create the dataset instances.
        cfg (DictConfig): config object. All parameters used to initialize the dataset class
            should be included in this object.
        batch_size (int): batch size. ``cfg.train.debug.batch_size`` in debug mode,
            ``cfg.train.batch_size`` otherwise.
        debug (bool): If True, enable debug mode.
    """

    # Holds a Dataset *class*; subclasses assign it and setup() calls it.
    dataset_class: torch.utils.data.Dataset

    def __init__(self, cfg: DictConfig):
        super().__init__()
        self.cfg = cfg

        self.debug = cfg.debug
        # Debug runs read their batch size from a dedicated config section.
        if cfg.debug:
            self.batch_size = cfg.train.debug.batch_size
        else:
            self.batch_size = cfg.train.batch_size

    def get_kwargs_for_datasets(self, stage: Optional[str] = None) -> Dict:
        """Build the kwargs used to initialize the dataset class. Called in ``setup()``.

        Override this method to pass extra keyword arguments to ``dataset_class``. The
        default implementation ignores ``stage``. ``setup()`` may mutate the returned dict
        (it adds ``submission=True`` for the submission stage), so overrides should return
        a fresh dict on every call.

        Args:
            stage (str, optional): dataset type. One of {train, validate, test, submission}.

        Example:

            ::

                def get_kwargs_for_datasets(self, stage: Optional[str] = None) -> Dict:
                    kwargs = {
                        "window": self.cfg.train.window,
                        "debug": self.cfg.debug,
                    }
                    return kwargs

        Returns:
            Dict: keyword arguments forwarded to ``dataset_class``.
        """
        kwargs = {
            "window": self.cfg.train.window,
            "debug": self.cfg.debug,
        }
        return kwargs

    def _init_datasets(
        self,
        user_session: List[Tuple[int, int]],
        kwargs: Dict,
    ) -> Dict[str, torch.utils.data.Dataset]:
        """Create one dataset instance per ``(user, session)`` pair.

        Args:
            user_session (List[Tuple[int, int]]): sequence of ``(user, session)`` pairs.
            kwargs (Dict): keyword arguments forwarded to ``dataset_class``.

        Returns:
            Dict[str, torch.utils.data.Dataset]: dataset objects keyed by
            ``"{user}-{session}"``, in the order of ``user_session``.
        """
        # Every dataset receives its own deep copy of the config, so the datasets are
        # isolated from each other and from ``self.cfg``.
        return {
            f"{user}-{session}": self.dataset_class(
                copy.deepcopy(self.cfg), [(user, session)], **kwargs
            )
            for user, session in user_session
        }

    def setup(self, stage: Optional[str] = None) -> None:
        """Build the datasets needed for ``stage`` and reset the others to ``None``.

        ``stage=None`` builds every dataset. Besides Lightning's stages this method also
        accepts ``"submission"`` and ``"test-on-submission"``; the latter builds the
        submission datasets without the ``submission=True`` flag.
        """
        # Sequence lists live under ``cfg.dataset.split.spec`` when that key exists.
        if hasattr(self.cfg.dataset.split, "spec"):
            split = self.cfg.dataset.split.spec
        else:
            split = self.cfg.dataset.split

        if stage in (None, "fit"):
            kwargs = self.get_kwargs_for_datasets(stage="train")
            # A single training dataset covers every training sequence.
            self.op_train = self.dataset_class(self.cfg, split.train, **kwargs)
            if self.cfg.train.random_crop:
                self.op_train.random_crop = True
                log.debug(f"enable random_crop in training dataset: {self.op_train}")
        else:
            self.op_train = None

        if stage in (None, "fit", "validate"):
            kwargs = self.get_kwargs_for_datasets(stage="validate")
            self.op_val = self._init_datasets(split.val, kwargs)
        else:
            self.op_val = None

        if stage in (None, "test"):
            kwargs = self.get_kwargs_for_datasets(stage="test")
            self.op_test = self._init_datasets(split.test, kwargs)
        else:
            self.op_test = None

        if stage in (None, "submission"):
            kwargs = self.get_kwargs_for_datasets(stage="submission")
            kwargs.update({"submission": True})
            self.op_submission = self._init_datasets(split.submission, kwargs)
        elif stage == "test-on-submission":
            # Same sequences as the submission stage, but without the submission flag.
            kwargs = self.get_kwargs_for_datasets(stage="submission")
            self.op_submission = self._init_datasets(split.submission, kwargs)
        else:
            self.op_submission = None

        log.info(f"dataset[train]: {self.op_train}")
        log.info(f"dataset[val]: {self.op_val}")
        log.info(f"dataset[test]: {self.op_test}")
        log.info(f"dataset[submission]: {self.op_submission}")

    def train_dataloader(self) -> DataLoader:
        """Return a shuffling DataLoader over the training dataset."""
        return DataLoader(
            self.op_train,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.cfg.train.num_workers,
        )

    def val_dataloader(self) -> List[DataLoader]:
        """Return one non-shuffling DataLoader per validation ``(user, session)`` dataset."""
        return [
            DataLoader(
                dataset,
                batch_size=self.batch_size,
                shuffle=False,
                num_workers=self.cfg.train.num_workers,
            )
            for dataset in self.op_val.values()
        ]

    def test_dataloader(self) -> List[DataLoader]:
        """Return one non-shuffling DataLoader per test ``(user, session)`` dataset."""
        return [
            DataLoader(
                dataset,
                batch_size=self.batch_size,
                shuffle=False,
                num_workers=self.cfg.train.num_workers,
            )
            for dataset in self.op_test.values()
        ]

    def submission_dataloader(self) -> List[DataLoader]:
        """Return one non-shuffling DataLoader per submission ``(user, session)`` dataset."""
        return [
            DataLoader(
                dataset,
                batch_size=self.batch_size,
                shuffle=False,
                num_workers=self.cfg.train.num_workers,
            )
            for dataset in self.op_submission.values()
        ]


class OpenPackBaseFlexSetDataModule(OpenPackBaseDataModule):
    """DataModule whose sequence lists are assembled from the config.

    The train and test sequence lists are built with ``assemble_sequence_list_from_cfg``.
    The validation set is randomly split off the training dataset. Its size is read from
    ``cfg.train.val_split_size``; the older misspelled key ``cfg.train.val_split_siz`` is
    still accepted.

    Attributes:
        dataset_train: training dataset (a ``torch.utils.data.Subset``) or None.
        dataset_val: validation dataset (a ``torch.utils.data.Subset``) or None.
        dataset_test (Dict[str, torch.utils.data.Dataset]): test datasets keyed by
            ``"{user}-{session}"``, or None.
        dataset_submission: submission datasets. ``setup()`` does not build them; a subclass
            must assign this attribute before ``submission_dataloader()`` is used.
        val_split_seed (int): seed of the train/val random split. It is fixed so that every
            process and every call of ``setup()`` produce the same split.
    """

    dataset_train = None
    dataset_val = None
    dataset_test = None
    dataset_submission = None
    val_split_seed: int = 0

    @staticmethod
    def _split_train_val(dataset, val_split_size, seed: int = 0):
        """Randomly split ``dataset`` into ``(train, val)`` subsets.

        Args:
            dataset: map-style dataset supporting ``len()``.
            val_split_size (float | int): a float in (0, 1) is the fraction of samples used
                for validation; an int is the absolute number of validation samples.
            seed (int): seed of the generator that draws the permutation.

        Returns:
            Tuple[Subset, Subset]: train and validation ``torch.utils.data.Subset`` objects.
            Attributes of the wrapped dataset are reachable through ``.dataset``.

        Raises:
            ValueError: if ``val_split_size`` is out of range or would leave either split
                empty.
        """
        n_total = len(dataset)
        if isinstance(val_split_size, float):
            if not 0.0 < val_split_size < 1.0:
                raise ValueError(
                    f"val_split_size as a fraction must be in (0, 1), got {val_split_size}"
                )
            n_val = int(round(n_total * val_split_size))
        else:
            n_val = int(val_split_size)
        if not 0 < n_val < n_total:
            raise ValueError(
                f"val_split_size={val_split_size} gives {n_val} validation samples out of "
                f"{n_total}; both train and val splits must be non-empty."
            )
        generator = torch.Generator().manual_seed(seed)
        train, val = torch.utils.data.random_split(
            dataset, [n_total - n_val, n_val], generator=generator
        )
        return train, val

    def _get_val_split_size(self):
        """Read the validation split size from ``cfg.train``.

        ``val_split_size`` is the intended key. The misspelled ``val_split_siz`` read by
        earlier versions of this class is accepted for backward compatibility.

        Raises:
            KeyError: if neither key is present.
        """
        train_cfg = self.cfg.train
        for key in ("val_split_size", "val_split_siz"):
            if key in train_cfg:
                return train_cfg[key]
        raise KeyError(
            "cfg.train.val_split_size is required to split the validation set off the "
            "training set."
        )

    def setup(self, stage: Optional[str] = None) -> None:
        """Build the datasets needed for ``stage`` and reset the others to ``None``.

        ``stage=None`` builds every dataset. The validation set is carved out of the
        training set, so the "validate" stage builds both.
        """
        if stage in (None, "fit", "validate"):
            kwargs = self.get_kwargs_for_datasets(stage="fit")
            # Pass the explicit stage name: with ``stage=None`` the helper would otherwise
            # receive None for the train and the test lists alike.
            sequences = assemble_sequence_list_from_cfg(self.cfg, "fit")
            full_train = self.dataset_class(self.cfg, sequences, **kwargs)
            self.dataset_train, self.dataset_val = self._split_train_val(
                full_train,
                self._get_val_split_size(),
                seed=self.val_split_seed,
            )
        else:
            self.dataset_train = None
            self.dataset_val = None

        if stage in (None, "test"):
            kwargs = self.get_kwargs_for_datasets(stage="test")
            sequences = assemble_sequence_list_from_cfg(self.cfg, "test")
            # Inherited helper: one dataset per (user, session) pair.
            self.dataset_test = self._init_datasets(sequences, kwargs)
        else:
            self.dataset_test = None

        log.info(f"dataset[train]: {self.dataset_train}")
        log.info(f"dataset[val]: {self.dataset_val}")
        log.info(f"dataset[test]: {self.dataset_test}")

    def train_dataloader(self) -> DataLoader:
        """Return a shuffling DataLoader over the training split."""
        return DataLoader(
            self.dataset_train,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.cfg.train.num_workers,
        )

    def val_dataloader(self) -> DataLoader:
        """Return a single non-shuffling DataLoader over the validation split."""
        return DataLoader(
            self.dataset_val,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.cfg.train.num_workers,
        )

    def test_dataloader(self) -> List[DataLoader]:
        """Return one non-shuffling DataLoader per test ``(user, session)`` dataset."""
        return [
            DataLoader(
                dataset,
                batch_size=self.batch_size,
                shuffle=False,
                num_workers=self.cfg.train.num_workers,
            )
            for dataset in self.dataset_test.values()
        ]

    def submission_dataloader(self) -> List[DataLoader]:
        """Return one non-shuffling DataLoader per submission dataset.

        Raises:
            RuntimeError: if ``dataset_submission`` has not been assigned.
        """
        if self.dataset_submission is None:
            raise RuntimeError(
                "dataset_submission is not initialized: "
                f"{type(self).__name__}.setup() does not build submission datasets."
            )
        return [
            DataLoader(
                dataset,
                batch_size=self.batch_size,
                shuffle=False,
                num_workers=self.cfg.train.num_workers,
            )
            for dataset in self.dataset_submission.values()
        ]

Classes

class OpenPackBaseDataModule (cfg: omegaconf.dictconfig.DictConfig)

Base class of PyTorch Lightning DataModule. A datamodule is a shareable, reusable class that encapsulates all the steps needed to process data.

Attributes

dataset_class : torch.utils.data.Dataset
dataset class. It is called to create the dataset instances.
cfg : DictConfig
config object. All parameters used to initialize the dataset class should be included in this object.
batch_size : int
batch size.
debug : bool
If True, enable debug mode.

Attributes

prepare_data_per_node: If True, each LOCAL_RANK=0 will call prepare data. Otherwise only NODE_RANK=0, LOCAL_RANK=0 will prepare data. allow_zero_length_dataloader_with_multiple_devices: If True, dataloader with zero length within local rank is allowed. Default value is False.

Expand source code
class OpenPackBaseDataModule(pl.LightningDataModule):
    """Base class of PyTorch Lightning DataModule for the OpenPack dataset.

    A datamodule is a shareable, reusable class that encapsulates all the steps needed to
    process data: ``setup()`` builds the train / validation / test / submission datasets and
    the ``*_dataloader()`` methods wrap them in ``DataLoader`` objects.

    Subclasses must set ``dataset_class``. It is instantiated as
    ``dataset_class(cfg, user_session_list, **kwargs)`` where ``kwargs`` is built by
    ``get_kwargs_for_datasets()``.

    Attributes:
        dataset_class (torch.utils.data.Dataset): dataset class (not an instance). It is
            called to create the dataset instances.
        cfg (DictConfig): config object. All parameters used to initialize the dataset class
            should be included in this object.
        batch_size (int): batch size. ``cfg.train.debug.batch_size`` in debug mode,
            ``cfg.train.batch_size`` otherwise.
        debug (bool): If True, enable debug mode.
    """

    # Holds a Dataset *class*; subclasses assign it and setup() calls it.
    dataset_class: torch.utils.data.Dataset

    def __init__(self, cfg: DictConfig):
        super().__init__()
        self.cfg = cfg

        self.debug = cfg.debug
        # Debug runs read their batch size from a dedicated config section.
        if cfg.debug:
            self.batch_size = cfg.train.debug.batch_size
        else:
            self.batch_size = cfg.train.batch_size

    def get_kwargs_for_datasets(self, stage: Optional[str] = None) -> Dict:
        """Build the kwargs used to initialize the dataset class. Called in ``setup()``.

        Override this method to pass extra keyword arguments to ``dataset_class``. The
        default implementation ignores ``stage``. ``setup()`` may mutate the returned dict
        (it adds ``submission=True`` for the submission stage), so overrides should return
        a fresh dict on every call.

        Args:
            stage (str, optional): dataset type. One of {train, validate, test, submission}.

        Example:

            ::

                def get_kwargs_for_datasets(self, stage: Optional[str] = None) -> Dict:
                    kwargs = {
                        "window": self.cfg.train.window,
                        "debug": self.cfg.debug,
                    }
                    return kwargs

        Returns:
            Dict: keyword arguments forwarded to ``dataset_class``.
        """
        kwargs = {
            "window": self.cfg.train.window,
            "debug": self.cfg.debug,
        }
        return kwargs

    def _init_datasets(
        self,
        user_session: List[Tuple[int, int]],
        kwargs: Dict,
    ) -> Dict[str, torch.utils.data.Dataset]:
        """Create one dataset instance per ``(user, session)`` pair.

        Args:
            user_session (List[Tuple[int, int]]): sequence of ``(user, session)`` pairs.
            kwargs (Dict): keyword arguments forwarded to ``dataset_class``.

        Returns:
            Dict[str, torch.utils.data.Dataset]: dataset objects keyed by
            ``"{user}-{session}"``, in the order of ``user_session``.
        """
        # Every dataset receives its own deep copy of the config, so the datasets are
        # isolated from each other and from ``self.cfg``.
        return {
            f"{user}-{session}": self.dataset_class(
                copy.deepcopy(self.cfg), [(user, session)], **kwargs
            )
            for user, session in user_session
        }

    def setup(self, stage: Optional[str] = None) -> None:
        """Build the datasets needed for ``stage`` and reset the others to ``None``.

        ``stage=None`` builds every dataset. Besides Lightning's stages this method also
        accepts ``"submission"`` and ``"test-on-submission"``; the latter builds the
        submission datasets without the ``submission=True`` flag.
        """
        # Sequence lists live under ``cfg.dataset.split.spec`` when that key exists.
        if hasattr(self.cfg.dataset.split, "spec"):
            split = self.cfg.dataset.split.spec
        else:
            split = self.cfg.dataset.split

        if stage in (None, "fit"):
            kwargs = self.get_kwargs_for_datasets(stage="train")
            # A single training dataset covers every training sequence.
            self.op_train = self.dataset_class(self.cfg, split.train, **kwargs)
            if self.cfg.train.random_crop:
                self.op_train.random_crop = True
                log.debug(f"enable random_crop in training dataset: {self.op_train}")
        else:
            self.op_train = None

        if stage in (None, "fit", "validate"):
            kwargs = self.get_kwargs_for_datasets(stage="validate")
            self.op_val = self._init_datasets(split.val, kwargs)
        else:
            self.op_val = None

        if stage in (None, "test"):
            kwargs = self.get_kwargs_for_datasets(stage="test")
            self.op_test = self._init_datasets(split.test, kwargs)
        else:
            self.op_test = None

        if stage in (None, "submission"):
            kwargs = self.get_kwargs_for_datasets(stage="submission")
            kwargs.update({"submission": True})
            self.op_submission = self._init_datasets(split.submission, kwargs)
        elif stage == "test-on-submission":
            # Same sequences as the submission stage, but without the submission flag.
            kwargs = self.get_kwargs_for_datasets(stage="submission")
            self.op_submission = self._init_datasets(split.submission, kwargs)
        else:
            self.op_submission = None

        log.info(f"dataset[train]: {self.op_train}")
        log.info(f"dataset[val]: {self.op_val}")
        log.info(f"dataset[test]: {self.op_test}")
        log.info(f"dataset[submission]: {self.op_submission}")

    def train_dataloader(self) -> DataLoader:
        """Return a shuffling DataLoader over the training dataset."""
        return DataLoader(
            self.op_train,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.cfg.train.num_workers,
        )

    def val_dataloader(self) -> List[DataLoader]:
        """Return one non-shuffling DataLoader per validation ``(user, session)`` dataset."""
        return [
            DataLoader(
                dataset,
                batch_size=self.batch_size,
                shuffle=False,
                num_workers=self.cfg.train.num_workers,
            )
            for dataset in self.op_val.values()
        ]

    def test_dataloader(self) -> List[DataLoader]:
        """Return one non-shuffling DataLoader per test ``(user, session)`` dataset."""
        return [
            DataLoader(
                dataset,
                batch_size=self.batch_size,
                shuffle=False,
                num_workers=self.cfg.train.num_workers,
            )
            for dataset in self.op_test.values()
        ]

    def submission_dataloader(self) -> List[DataLoader]:
        """Return one non-shuffling DataLoader per submission ``(user, session)`` dataset."""
        return [
            DataLoader(
                dataset,
                batch_size=self.batch_size,
                shuffle=False,
                num_workers=self.cfg.train.num_workers,
            )
            for dataset in self.op_submission.values()
        ]

Ancestors

  • pytorch_lightning.core.datamodule.LightningDataModule
  • pytorch_lightning.core.hooks.DataHooks
  • pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin

Subclasses

  • OpenPackBaseFlexSetDataModule

Class variables

var dataset_class : torch.utils.data.dataset.Dataset

Methods

def get_kwargs_for_datasets(self, stage: Optional[str] = None) ‑> Dict

Build the kwargs used to initialize the dataset class. This method is called in setup().

Args

stage : str, optional
dataset type. {train, validate, test, submission}.

Example

::

def get_kwargs_for_datasets(self, stage: Optional[str] = None) -> Dict:
    kwargs = {
        "window": self.cfg.train.window,
        "debug": self.cfg.debug,
    }
    return kwargs

Returns

Dict:

Expand source code
def get_kwargs_for_datasets(self, stage: Optional[str] = None) -> Dict:
    """Build the kwargs used to initialize the dataset class. Called in ``setup()``.

    Override this method to pass extra keyword arguments to ``dataset_class``. The default
    implementation ignores ``stage``. ``setup()`` may mutate the returned dict (it adds
    ``submission=True`` for the submission stage), so overrides should return a fresh dict
    on every call.

    Args:
        stage (str, optional): dataset type. One of {train, validate, test, submission}.

    Example:

        ::

            def get_kwargs_for_datasets(self, stage: Optional[str] = None) -> Dict:
                kwargs = {
                    "window": self.cfg.train.window,
                    "debug": self.cfg.debug,
                }
                return kwargs

    Returns:
        Dict: keyword arguments forwarded to ``dataset_class``.
    """
    kwargs = {
        "window": self.cfg.train.window,
        "debug": self.cfg.debug,
    }
    return kwargs
def setup(self, stage: Optional[str] = None) ‑> None

Called at the beginning of fit (train + validate), validate, test, or predict. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP.

Args

stage
either 'fit', 'validate', 'test', or 'predict'. The setup() of this module additionally accepts 'submission' and 'test-on-submission'.

Example::

class LitModel(...):
    def __init__(self):
        self.l1 = None

    def prepare_data(self):
        download_data()
        tokenize()

        # don't do this
        self.something = else

    def setup(self, stage):
        data = load_data(...)
        self.l1 = nn.Linear(28, data.num_classes)
Expand source code
def setup(self, stage: Optional[str] = None) -> None:
    """Instantiate the datasets required for ``stage``.

    Datasets not needed by ``stage`` are set to ``None``; ``stage=None`` builds all of
    them. Sequence lists come from ``cfg.dataset.split.spec`` when that key exists,
    otherwise from ``cfg.dataset.split``. ``"test-on-submission"`` builds the submission
    datasets without the ``submission=True`` flag.
    """
    split_cfg = self.cfg.dataset.split
    if hasattr(split_cfg, "spec"):
        split_cfg = split_cfg.spec

    # Training: one dataset spanning every training sequence.
    if stage is None or stage == "fit":
        train_kwargs = self.get_kwargs_for_datasets(stage="train")
        self.op_train = self.dataset_class(self.cfg, split_cfg.train, **train_kwargs)
        if self.cfg.train.random_crop:
            self.op_train.random_crop = True
            log.debug(f"enable random_crop in training dataset: {self.op_train}")
    else:
        self.op_train = None

    # Validation: one dataset per (user, session).
    if stage is None or stage in ("fit", "validate"):
        val_kwargs = self.get_kwargs_for_datasets(stage="validate")
        self.op_val = self._init_datasets(split_cfg.val, val_kwargs)
    else:
        self.op_val = None

    # Test: one dataset per (user, session).
    if stage is None or stage == "test":
        test_kwargs = self.get_kwargs_for_datasets(stage="test")
        self.op_test = self._init_datasets(split_cfg.test, test_kwargs)
    else:
        self.op_test = None

    # Submission: the real submission run also sets the submission flag.
    if stage is None or stage in ("submission", "test-on-submission"):
        submission_kwargs = self.get_kwargs_for_datasets(stage="submission")
        if stage != "test-on-submission":
            submission_kwargs["submission"] = True
        self.op_submission = self._init_datasets(split_cfg.submission, submission_kwargs)
    else:
        self.op_submission = None

    for label, built in (
        ("train", self.op_train),
        ("val", self.op_val),
        ("test", self.op_test),
        ("submission", self.op_submission),
    ):
        log.info(f"dataset[{label}]: {built}")
def submission_dataloader(self) ‑> List[torch.utils.data.dataloader.DataLoader]
Expand source code
def submission_dataloader(self) -> List[DataLoader]:
    """Return one non-shuffling DataLoader per submission ``(user, session)`` dataset.

    The loaders follow the insertion order of ``self.op_submission``.
    """
    return [
        DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.cfg.train.num_workers,
        )
        for dataset in self.op_submission.values()
    ]
def test_dataloader(self) ‑> List[torch.utils.data.dataloader.DataLoader]

An iterable or collection of iterables specifying test samples.

For more information about multiple dataloaders, see the "multiple dataloaders" section of the PyTorch Lightning documentation.

For data processing use the following pattern:

- download in `prepare_data()`
- process and split in `setup()`

However, the above are only necessary for distributed processing.

Warning: do not assign state in prepare_data

  • `pytorch_lightning.trainer.trainer.Trainer.test()`
  • `prepare_data()`
  • `setup()`

Note

Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Note

If you don't need a test dataset and a `test_step()`, you don't need to implement this method.

Expand source code
def test_dataloader(self) -> List[DataLoader]:
    dataloaders = []
    for key, dataset in self.op_test.items():
        dataloaders.append(
            DataLoader(
                dataset,
                batch_size=self.batch_size,
                shuffle=False,
                num_workers=self.cfg.train.num_workers,
            )
        )
    return dataloaders
def train_dataloader(self) ‑> torch.utils.data.dataloader.DataLoader

An iterable or collection of iterables specifying training samples.

For more information about multiple dataloaders, see the "multiple dataloaders" section of the PyTorch Lightning documentation.

The dataloader you return will not be reloaded unless you set `Trainer(reload_dataloaders_every_n_epochs=...)` to a positive integer.

For data processing use the following pattern:

- download in `prepare_data()`
- process and split in `setup()`

However, the above are only necessary for distributed processing.

Warning: do not assign state in prepare_data

  • `pytorch_lightning.trainer.trainer.Trainer.fit()`
  • `prepare_data()`
  • `setup()`

Note

Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Expand source code
def train_dataloader(self) -> DataLoader:
    """Wrap the training dataset ``self.op_train`` in a shuffling DataLoader."""
    loader_options = dict(
        batch_size=self.batch_size,
        shuffle=True,
        num_workers=self.cfg.train.num_workers,
    )
    return DataLoader(self.op_train, **loader_options)
def val_dataloader(self) ‑> List[torch.utils.data.dataloader.DataLoader]

An iterable or collection of iterables specifying validation samples.

For more information about multiple dataloaders, see the "multiple dataloaders" section of the PyTorch Lightning documentation.

The dataloader you return will not be reloaded unless you set :paramref:~pytorch_lightning.trainer.trainer.Trainer.reload_dataloaders_every_n_epochs to a positive integer.

It's recommended that all data downloads and preparation happen in `prepare_data()`.

  • `pytorch_lightning.trainer.trainer.Trainer.fit()`
  • `pytorch_lightning.trainer.trainer.Trainer.validate()`
  • `prepare_data()`
  • `setup()`

Note

Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Note

If you don't need a validation dataset and a `validation_step()`, you don't need to implement this method.

Expand source code
def val_dataloader(self) -> List[DataLoader]:
    """Return one non-shuffling DataLoader per validation ``(user, session)`` dataset.

    The loaders follow the insertion order of ``self.op_val``.
    """
    return [
        DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.cfg.train.num_workers,
        )
        for dataset in self.op_val.values()
    ]
class OpenPackBaseFlexSetDataModule (cfg: omegaconf.dictconfig.DictConfig)

Base class of PyTorch Lightning DataModule. A datamodule is a shareable, reusable class that encapsulates all the steps needed to process data.

Attributes

dataset_class : torch.utils.data.Dataset
dataset class. It is called to create the dataset instances.
cfg : DictConfig
config object. All parameters used to initialize the dataset class should be included in this object.
batch_size : int
batch size.
debug : bool
If True, enable debug mode.

Attributes

prepare_data_per_node: If True, each LOCAL_RANK=0 will call prepare data. Otherwise only NODE_RANK=0, LOCAL_RANK=0 will prepare data. allow_zero_length_dataloader_with_multiple_devices: If True, dataloader with zero length within local rank is allowed. Default value is False.

Expand source code
class OpenPackBaseFlexSetDataModule(OpenPackBaseDataModule):
    """DataModule whose sequence lists are assembled from the config.

    The train and test sequence lists are built with ``assemble_sequence_list_from_cfg``.
    The validation set is randomly split off the training dataset. Its size is read from
    ``cfg.train.val_split_size``; the older misspelled key ``cfg.train.val_split_siz`` is
    still accepted.

    Attributes:
        dataset_train: training dataset (a ``torch.utils.data.Subset``) or None.
        dataset_val: validation dataset (a ``torch.utils.data.Subset``) or None.
        dataset_test (Dict[str, torch.utils.data.Dataset]): test datasets keyed by
            ``"{user}-{session}"``, or None.
        dataset_submission: submission datasets. ``setup()`` does not build them; a subclass
            must assign this attribute before ``submission_dataloader()`` is used.
        val_split_seed (int): seed of the train/val random split. It is fixed so that every
            process and every call of ``setup()`` produce the same split.
    """

    dataset_train = None
    dataset_val = None
    dataset_test = None
    dataset_submission = None
    val_split_seed: int = 0

    @staticmethod
    def _split_train_val(dataset, val_split_size, seed: int = 0):
        """Randomly split ``dataset`` into ``(train, val)`` subsets.

        Args:
            dataset: map-style dataset supporting ``len()``.
            val_split_size (float | int): a float in (0, 1) is the fraction of samples used
                for validation; an int is the absolute number of validation samples.
            seed (int): seed of the generator that draws the permutation.

        Returns:
            Tuple[Subset, Subset]: train and validation ``torch.utils.data.Subset`` objects.
            Attributes of the wrapped dataset are reachable through ``.dataset``.

        Raises:
            ValueError: if ``val_split_size`` is out of range or would leave either split
                empty.
        """
        n_total = len(dataset)
        if isinstance(val_split_size, float):
            if not 0.0 < val_split_size < 1.0:
                raise ValueError(
                    f"val_split_size as a fraction must be in (0, 1), got {val_split_size}"
                )
            n_val = int(round(n_total * val_split_size))
        else:
            n_val = int(val_split_size)
        if not 0 < n_val < n_total:
            raise ValueError(
                f"val_split_size={val_split_size} gives {n_val} validation samples out of "
                f"{n_total}; both train and val splits must be non-empty."
            )
        generator = torch.Generator().manual_seed(seed)
        train, val = torch.utils.data.random_split(
            dataset, [n_total - n_val, n_val], generator=generator
        )
        return train, val

    def _get_val_split_size(self):
        """Read the validation split size from ``cfg.train``.

        ``val_split_size`` is the intended key. The misspelled ``val_split_siz`` read by
        earlier versions of this class is accepted for backward compatibility.

        Raises:
            KeyError: if neither key is present.
        """
        train_cfg = self.cfg.train
        for key in ("val_split_size", "val_split_siz"):
            if key in train_cfg:
                return train_cfg[key]
        raise KeyError(
            "cfg.train.val_split_size is required to split the validation set off the "
            "training set."
        )

    def setup(self, stage: Optional[str] = None) -> None:
        """Build the datasets needed for ``stage`` and reset the others to ``None``.

        ``stage=None`` builds every dataset. The validation set is carved out of the
        training set, so the "validate" stage builds both.
        """
        if stage in (None, "fit", "validate"):
            kwargs = self.get_kwargs_for_datasets(stage="fit")
            # Pass the explicit stage name: with ``stage=None`` the helper would otherwise
            # receive None for the train and the test lists alike.
            sequences = assemble_sequence_list_from_cfg(self.cfg, "fit")
            full_train = self.dataset_class(self.cfg, sequences, **kwargs)
            self.dataset_train, self.dataset_val = self._split_train_val(
                full_train,
                self._get_val_split_size(),
                seed=self.val_split_seed,
            )
        else:
            self.dataset_train = None
            self.dataset_val = None

        if stage in (None, "test"):
            kwargs = self.get_kwargs_for_datasets(stage="test")
            sequences = assemble_sequence_list_from_cfg(self.cfg, "test")
            # Inherited helper: one dataset per (user, session) pair.
            self.dataset_test = self._init_datasets(sequences, kwargs)
        else:
            self.dataset_test = None

        log.info(f"dataset[train]: {self.dataset_train}")
        log.info(f"dataset[val]: {self.dataset_val}")
        log.info(f"dataset[test]: {self.dataset_test}")

    def train_dataloader(self) -> DataLoader:
        """Return a shuffling DataLoader over the training split."""
        return DataLoader(
            self.dataset_train,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.cfg.train.num_workers,
        )

    def val_dataloader(self) -> DataLoader:
        """Return a single non-shuffling DataLoader over the validation split."""
        return DataLoader(
            self.dataset_val,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.cfg.train.num_workers,
        )

    def test_dataloader(self) -> List[DataLoader]:
        """Return one non-shuffling DataLoader per test ``(user, session)`` dataset."""
        return [
            DataLoader(
                dataset,
                batch_size=self.batch_size,
                shuffle=False,
                num_workers=self.cfg.train.num_workers,
            )
            for dataset in self.dataset_test.values()
        ]

    def submission_dataloader(self) -> List[DataLoader]:
        """Return one non-shuffling DataLoader per submission dataset.

        Raises:
            RuntimeError: if ``dataset_submission`` has not been assigned.
        """
        if self.dataset_submission is None:
            raise RuntimeError(
                "dataset_submission is not initialized: "
                f"{type(self).__name__}.setup() does not build submission datasets."
            )
        return [
            DataLoader(
                dataset,
                batch_size=self.batch_size,
                shuffle=False,
                num_workers=self.cfg.train.num_workers,
            )
            for dataset in self.dataset_submission.values()
        ]

Ancestors

  • OpenPackBaseDataModule
  • pytorch_lightning.core.datamodule.LightningDataModule
  • pytorch_lightning.core.hooks.DataHooks
  • pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin

Class variables

var dataset_test
var dataset_train
var dataset_val

Methods

def submission_dataloader(self) ‑> List[torch.utils.data.dataloader.DataLoader]
Expand source code
def submission_dataloader(self) -> List[DataLoader]:
    """Return one non-shuffling DataLoader per submission dataset.

    ``setup()`` of the flex-set datamodule does not build submission datasets, so
    ``dataset_submission`` must be assigned beforehand.

    Raises:
        RuntimeError: if ``dataset_submission`` is missing or None.
    """
    # getattr: the attribute may never have been assigned on this object.
    submission_datasets = getattr(self, "dataset_submission", None)
    if submission_datasets is None:
        raise RuntimeError(
            "dataset_submission is not initialized: "
            f"{type(self).__name__}.setup() does not build submission datasets."
        )
    return [
        DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.cfg.train.num_workers,
        )
        for dataset in submission_datasets.values()
    ]

Inherited members