"""Source code for skqulacs.dataloader: mini-batch data loading utilities."""
from dataclasses import dataclass, field
from typing import List, Optional
import numpy as np
from numpy.random import Generator
from numpy.typing import NDArray
@dataclass
class DataLoader:
    """
    Data loader: an iterable that yields ``(x_batch, y_batch)`` mini-batches.

    Every call to ``iter()`` starts a fresh pass over the data. Each batch
    holds ``batch_size`` samples, except possibly the last one, which holds
    the remainder.

    Args:
        x: Input samples; the first axis indexes samples.
        y: Targets; must have the same length along the first axis as ``x``.
        batch_size: Number of samples per mini-batch. Must be at least 1.
        shuffle: Whether to visit the samples in a random order.
        seed: Seed for the shuffling generator. The generator is re-created
            from this seed at the start of every pass, so a fixed seed yields
            the same order on every pass, while ``None`` draws fresh entropy
            (a different order) each pass.

    Raises:
        ValueError: If ``x`` and ``y`` differ in length, or if ``batch_size``
            is smaller than 1.
    """

    # ``np.float64`` rather than ``np.float_``: that alias was removed in
    # NumPy 2.0, and dataclass field annotations are evaluated as soon as the
    # class is defined, so the old spelling breaks importing the module.
    x: NDArray[np.float64]
    y: NDArray[np.float64]
    batch_size: int = 1
    shuffle: bool = False
    seed: Optional[int] = None

    def __post_init__(self) -> None:
        if self.x.shape[0] != self.y.shape[0]:
            raise ValueError("x and y must have the same length.")
        # A zero batch size makes __len__ divide by zero and makes iteration
        # spin forever on empty batches; negative sizes never terminate either.
        if self.batch_size < 1:
            raise ValueError("batch_size must be a positive integer.")

    def __iter__(self) -> "_DataLoaderIterator":
        """Start a new pass over the data."""
        return _DataLoaderIterator(self)

    def __len__(self) -> int:
        """Return the number of mini-batches per pass, ceil(len(x) / batch_size)."""
        return (len(self.x) + self.batch_size - 1) // self.batch_size
@dataclass
class _DataLoaderIterator:
    """
    Single-pass iterator over the mini-batches of a ``DataLoader``.

    On construction it builds the sample visiting order: ``0 .. len(x) - 1``,
    shuffled in place when ``loader.shuffle`` is set, using a generator seeded
    with ``loader.seed``. It then yields consecutive ``loader.batch_size``-long
    slices of that order as ``(x_batch, y_batch)`` pairs.
    """

    # String annotation so this class does not require ``DataLoader`` to be
    # bound at class-definition time.
    loader: "DataLoader"
    rng: Generator = field(init=False)
    indices: List[int] = field(init=False, default_factory=list)
    current_index: int = field(init=False, default=0)

    def __post_init__(self) -> None:
        self.rng = np.random.default_rng(self.loader.seed)
        self.indices = list(range(len(self.loader.x)))
        if self.loader.shuffle:
            self.rng.shuffle(self.indices)

    def __iter__(self) -> "_DataLoaderIterator":
        # Required by the iterator protocol: without it the object returned by
        # ``iter(loader)`` cannot itself be passed to ``for``, ``list()``, etc.
        return self

    def __next__(self):
        """
        Return the next ``(x_batch, y_batch)`` pair.

        Raises:
            StopIteration: Once every sample index has been consumed.
        """
        if self.current_index >= len(self.indices):
            raise StopIteration
        end = self.current_index + self.loader.batch_size
        selected = self.indices[self.current_index : end]
        self.current_index = end
        # Fancy indexing with a list of positions copies the selected rows.
        return self.loader.x[selected], self.loader.y[selected]