Module openpack_torch.data.utils
Expand source code
import copy
from logging import getLogger
from typing import Dict, List, Optional, Tuple
import torch
from omegaconf import DictConfig
from torch.utils.data import DataLoader
import openpack_torch as optorch
log = getLogger(__name__)
def assemble_sequence_list_data_volume_flexible_cv(cfg: DictConfig, stage: str):
src_set = cfg.metadata.labels.src_set
if stage in ("fit", "train"):
seq_list = cfg.dataset.split.spec.pool[src_set].train
elif stage == "test":
# Predict on all the other sessions.
seq_list = []
for key, d in cfg.dataset.split.spec.pool.items():
seq_list += list(d.train)
seq_list += list(d.test)
elif stage == "test-b2":
# Predict on all the other sessions.
seq_list = cfg.dataset.split.spec.pool[src_set].test
elif stage == "test-b3":
# Predict on all the other sessions.
seq_list = []
exclude_list = [
f"{user}-{sess}"
for user, sess in cfg.dataset.split.spec.pool[src_set].get(
"exclude_on_test", []
)
]
for key, d in cfg.dataset.split.spec.pool.items():
if key == src_set:
continue
for user, sess in d.train:
if f"{user}-{sess}" not in exclude_list:
seq_list.append([user, sess])
for user, sess in d.test:
if f"{user}-{sess}" not in exclude_list:
seq_list.append([user, sess])
else:
raise NotImplementedError(f"stage={stage} is not supported.")
return seq_list
def assemble_sequence_list_flexible_train_data_volume_setting(
cfg: DictConfig, stage: str
):
assert cfg.dataset.split.kind == "dataset/split/flexible-train-data-volume"
train_set = cfg.metadata.labels.train_set
if stage in ("fit", "train"):
seq_list = []
pool_keys = sorted(list(cfg.dataset.split.spec.pool.keys()))
log.debug(f"pool_keys={pool_keys}")
ind = pool_keys.index(train_set)
for key in pool_keys[: (ind + 1)]:
seq_list += cfg.dataset.split.spec.pool[key].train
elif stage == "test":
seq_list = cfg.dataset.split.spec.test
else:
raise NotImplementedError(f"stage={stage} is not supported.")
return seq_list
def assemble_sequence_list_leave_one_out_setting(cfg: DictConfig, stage: str):
assert cfg.dataset.split.kind == "dataset/split/leave-one-out-cv"
test_set = cfg.metadata.labels.test_set
if stage in ("fit", "train"):
seq_list = []
pool_keys = sorted(list(cfg.dataset.split.spec.pool.keys()))
log.debug(f"pool_keys={pool_keys}")
seq_list = []
for key, _seq_list in cfg.dataset.split.spec.pool.items():
if key != test_set:
seq_list += list(_seq_list)
elif stage == "test":
seq_list = cfg.dataset.split.spec.pool[test_set]
else:
raise NotImplementedError(f"stage={stage} is not supported.")
return seq_list
def assemble_sequence_list_from_cfg(cfg: DictConfig, stage: str):
split_kind = cfg.dataset.split.kind
if split_kind == "dataset/split/data-volume-flexible-cv":
seq_list = assemble_sequence_list_data_volume_flexible_cv(cfg, stage)
elif split_kind == "dataset/split/flexible-train-data-volume":
seq_list = assemble_sequence_list_flexible_train_data_volume_setting(cfg, stage)
elif split_kind == "dataset/split/leave-one-out-cv":
seq_list = assemble_sequence_list_leave_one_out_setting(cfg, stage)
else:
raise ValueError(f"unknown split type: {cfg.dataset.split.kind}")
return seq_list
# -----------------------------------------------------------------------------
def split_dataset(
cfg: DictConfig,
dataset: torch.utils.data.Dataset,
val_split_size: float = 0.2,
):
"""Split this instance into train and val dataset class.
Note:
Some leakage occurs between train and test with ``random_crop=True``.
"""
assert val_split_size > 0.0
# Split Index
original_index = dict()
for win in dataset.index:
if win.sequence_idx in original_index.keys():
original_index[win.sequence_idx] += [win]
else:
original_index[win.sequence_idx] = [win]
train_index, val_index = [], []
keys = sorted(original_index.keys())
for key in keys:
num_sample = len(original_index[key])
num_train = int(num_sample * (1.0 - val_split_size))
train_index += original_index[key][:num_train]
val_index += original_index[key][num_train:]
# Split
dataset_train = dataset.__class__(
cfg,
user_session_list=None,
classes=dataset.classes,
window=dataset.window,
random_crop=dataset.random_crop,
submission=dataset.submission,
debug=dataset.debug,
)
dataset_train.data = dataset.data.copy()
dataset_train.index = train_index
dataset_val = dataset.__class__(
cfg,
user_session_list=None,
classes=dataset.classes,
window=dataset.window,
random_crop=False,
submission=dataset.submission,
debug=dataset.debug,
)
dataset_val.data = dataset.data.copy()
dataset_val.index = val_index
return dataset_train, dataset_val
Functions
def assemble_sequence_list_data_volume_flexible_cv(cfg: omegaconf.dictconfig.DictConfig, stage: str)
-
Expand source code
def assemble_sequence_list_data_volume_flexible_cv(cfg: DictConfig, stage: str): src_set = cfg.metadata.labels.src_set if stage in ("fit", "train"): seq_list = cfg.dataset.split.spec.pool[src_set].train elif stage == "test": # Predict on all the other sessions. seq_list = [] for key, d in cfg.dataset.split.spec.pool.items(): seq_list += list(d.train) seq_list += list(d.test) elif stage == "test-b2": # Predict on all the other sessions. seq_list = cfg.dataset.split.spec.pool[src_set].test elif stage == "test-b3": # Predict on all the other sessions. seq_list = [] exclude_list = [ f"{user}-{sess}" for user, sess in cfg.dataset.split.spec.pool[src_set].get( "exclude_on_test", [] ) ] for key, d in cfg.dataset.split.spec.pool.items(): if key == src_set: continue for user, sess in d.train: if f"{user}-{sess}" not in exclude_list: seq_list.append([user, sess]) for user, sess in d.test: if f"{user}-{sess}" not in exclude_list: seq_list.append([user, sess]) else: raise NotImplementedError(f"stage={stage} is not supported.") return seq_list
def assemble_sequence_list_flexible_train_data_volume_setting(cfg: omegaconf.dictconfig.DictConfig, stage: str)
-
Expand source code
def assemble_sequence_list_flexible_train_data_volume_setting( cfg: DictConfig, stage: str ): assert cfg.dataset.split.kind == "dataset/split/flexible-train-data-volume" train_set = cfg.metadata.labels.train_set if stage in ("fit", "train"): seq_list = [] pool_keys = sorted(list(cfg.dataset.split.spec.pool.keys())) log.debug(f"pool_keys={pool_keys}") ind = pool_keys.index(train_set) for key in pool_keys[: (ind + 1)]: seq_list += cfg.dataset.split.spec.pool[key].train elif stage == "test": seq_list = cfg.dataset.split.spec.test else: raise NotImplementedError(f"stage={stage} is not supported.") return seq_list
def assemble_sequence_list_from_cfg(cfg: omegaconf.dictconfig.DictConfig, stage: str)
-
Expand source code
def assemble_sequence_list_from_cfg(cfg: DictConfig, stage: str): split_kind = cfg.dataset.split.kind if split_kind == "dataset/split/data-volume-flexible-cv": seq_list = assemble_sequence_list_data_volume_flexible_cv(cfg, stage) elif split_kind == "dataset/split/flexible-train-data-volume": seq_list = assemble_sequence_list_flexible_train_data_volume_setting(cfg, stage) elif split_kind == "dataset/split/leave-one-out-cv": seq_list = assemble_sequence_list_leave_one_out_setting(cfg, stage) else: raise ValueError(f"unknown split type: {cfg.dataset.split.kind}") return seq_list
def assemble_sequence_list_leave_one_out_setting(cfg: omegaconf.dictconfig.DictConfig, stage: str)
-
Expand source code
def assemble_sequence_list_leave_one_out_setting(cfg: DictConfig, stage: str): assert cfg.dataset.split.kind == "dataset/split/leave-one-out-cv" test_set = cfg.metadata.labels.test_set if stage in ("fit", "train"): seq_list = [] pool_keys = sorted(list(cfg.dataset.split.spec.pool.keys())) log.debug(f"pool_keys={pool_keys}") seq_list = [] for key, _seq_list in cfg.dataset.split.spec.pool.items(): if key != test_set: seq_list += list(_seq_list) elif stage == "test": seq_list = cfg.dataset.split.spec.pool[test_set] else: raise NotImplementedError(f"stage={stage} is not supported.") return seq_list
def split_dataset(cfg: omegaconf.dictconfig.DictConfig, dataset: torch.utils.data.dataset.Dataset, val_split_size: float = 0.2)
-
Split this instance into train and val dataset class.
Note
Some leakage occurs between train and test with
random_crop=True
.Expand source code
def split_dataset( cfg: DictConfig, dataset: torch.utils.data.Dataset, val_split_size: float = 0.2, ): """Split this instance into train and val dataset class. Note: Some leakage occurs between train and test with ``random_crop=True``. """ assert val_split_size > 0.0 # Split Index original_index = dict() for win in dataset.index: if win.sequence_idx in original_index.keys(): original_index[win.sequence_idx] += [win] else: original_index[win.sequence_idx] = [win] train_index, val_index = [], [] keys = sorted(original_index.keys()) for key in keys: num_sample = len(original_index[key]) num_train = int(num_sample * (1.0 - val_split_size)) train_index += original_index[key][:num_train] val_index += original_index[key][num_train:] # Split dataset_train = dataset.__class__( cfg, user_session_list=None, classes=dataset.classes, window=dataset.window, random_crop=dataset.random_crop, submission=dataset.submission, debug=dataset.debug, ) dataset_train.data = dataset.data.copy() dataset_train.index = train_index dataset_val = dataset.__class__( cfg, user_session_list=None, classes=dataset.classes, window=dataset.window, random_crop=False, submission=dataset.submission, debug=dataset.debug, ) dataset_val.data = dataset.data.copy() dataset_val.index = val_index return dataset_train, dataset_val