Source: texcla/utils/generators.py#L0
ProcessingSequence
Base object for fitting to a sequence of data, such as a dataset.
Every Sequence must implement the __getitem__ and the __len__ methods.
If you want to modify your dataset between epochs you may implement on_epoch_end.
The method __getitem__ should return a complete batch.
Notes
Sequence objects are a safer way to do multiprocessing. This structure guarantees that the network will train only once on each sample per epoch, which is not the case with generators.
Examples
from skimage.io import imread
from skimage.transform import resize
import numpy as np
# Assumption: Sequence is imported from Keras; adjust the path to your installation.
from keras.utils import Sequence

# Here, `x_set` is a list of paths to the images
# and `y_set` holds the associated classes.
class CIFAR10Sequence(Sequence):

    def __init__(self, x_set, y_set, batch_size):
        self.x, self.y = x_set, y_set
        self.batch_size = batch_size

    def __len__(self):
        # Number of batches per epoch; the last batch may be smaller.
        return int(np.ceil(len(self.x) / float(self.batch_size)))

    def __getitem__(self, idx):
        batch_x = self.x[idx * self.batch_size:(idx + 1) * self.batch_size]
        batch_y = self.y[idx * self.batch_size:(idx + 1) * self.batch_size]
        # Load and resize each image in the batch.
        return np.array([
            resize(imread(file_name), (200, 200))
            for file_name in batch_x]), np.array(batch_y)
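The on_epoch_end hook mentioned above is not exercised by the image example. The following is a minimal sketch of how it might be used to reshuffle between epochs while still visiting every sample exactly once per epoch. `ShuffledBatches` and its `seed` argument are hypothetical names, and the class is a plain Python stand-in for the protocol; in real use it would subclass the Keras Sequence class.

```python
import numpy as np


class ShuffledBatches:
    """Hypothetical stand-in that follows the Sequence protocol."""

    def __init__(self, x, y, batch_size, seed=0):
        self.x, self.y = np.asarray(x), np.asarray(y)
        self.batch_size = batch_size
        self.rng = np.random.default_rng(seed)
        # Visit order for the current epoch; starts unshuffled.
        self.order = np.arange(len(self.x))

    def __len__(self):
        # Number of batches per epoch; the last batch may be smaller.
        return int(np.ceil(len(self.x) / self.batch_size))

    def __getitem__(self, idx):
        # Return one complete batch, selected through the current order.
        sel = self.order[idx * self.batch_size:(idx + 1) * self.batch_size]
        return self.x[sel], self.y[sel]

    def on_epoch_end(self):
        # Draw a fresh permutation so the next epoch sees a new order.
        self.order = self.rng.permutation(len(self.x))


seq = ShuffledBatches(np.arange(10), np.arange(10) % 2, batch_size=4)
seen = np.concatenate([seq[i][0] for i in range(len(seq))])
seq.on_epoch_end()
seen_after = np.concatenate([seq[i][0] for i in range(len(seq))])
```

Because each epoch indexes through a permutation of all positions, every sample appears once per epoch regardless of the shuffle.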
ProcessingSequence.__init__
__init__(self, X, y, batch_size, process_fn=None)
A Sequence implementation that can pre-process a mini-batch via process_fn.
Args:
- X: The numpy array of inputs.
- y: The numpy array of targets.
- batch_size: The generator mini-batch size.
- process_fn: The preprocessing function to apply on X.
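To illustrate the arguments above, here is a rough sketch of how a sequence with this signature could apply process_fn to each mini-batch of X. It is not the library's source: `ProcessingSketch` is a hypothetical name, the identity fallback when process_fn is None is an assumption, and so is the choice to round the batch count up.

```python
import numpy as np


class ProcessingSketch:
    """Hypothetical illustration of a pre-processing sequence."""

    def __init__(self, X, y, batch_size, process_fn=None):
        self.X, self.y = X, y
        self.batch_size = batch_size
        # Assumption: with no process_fn, batches pass through unchanged.
        self.process_fn = process_fn if process_fn is not None else (lambda batch: batch)

    def __len__(self):
        # Assumption: a trailing partial batch counts as a batch.
        return int(np.ceil(len(self.X) / self.batch_size))

    def __getitem__(self, idx):
        rows = slice(idx * self.batch_size, (idx + 1) * self.batch_size)
        # Only the inputs are pre-processed; targets are returned as-is.
        return self.process_fn(self.X[rows]), self.y[rows]


X = np.arange(12, dtype=float).reshape(6, 2)
y = np.array([0, 1, 0, 1, 0, 1])
seq = ProcessingSketch(X, y, batch_size=4, process_fn=lambda batch: batch / 10.0)
first_X, first_y = seq[0]
```

Applying process_fn inside __getitem__ keeps the stored arrays untouched, so the same data can be reused with a different pre-processing function.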
BalancedSequence
BalancedSequence.__init__
__init__(self, X, y, batch_size, process_fn=None)
A Sequence implementation that returns balanced y by undersampling the majority class.
Args:
- X: The numpy array of inputs.
- y: The numpy array of targets.
- batch_size: The generator mini-batch size.
- process_fn: The preprocessing function to apply on X.
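Undersampling the majority class can be sketched as follows. This shows the general technique only and is not taken from the library: `undersample_indices` is a hypothetical helper that keeps, for every class in y, as many randomly chosen samples as the rarest class has.

```python
import numpy as np


def undersample_indices(y, rng):
    """Return shuffled indices with an equal number of samples per class."""
    y = np.asarray(y)
    classes, counts = np.unique(y, return_counts=True)
    n = counts.min()  # size of the rarest class
    # Draw n distinct positions from each class, without replacement.
    picked = [rng.choice(np.where(y == c)[0], size=n, replace=False) for c in classes]
    # Mix the classes so batches are not grouped by label.
    return rng.permutation(np.concatenate(picked))


rng = np.random.default_rng(0)
y = np.array([0, 0, 0, 0, 0, 0, 1, 1])
idx = undersample_indices(y, rng)
balanced_y = y[idx]
```

Redrawing the indices at the start of each epoch would let different majority-class samples be seen over time, at the cost of epochs that no longer cover the full dataset.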