DataGenerator
- class data.DataGenerator(load_func, batch_size, source_path=None, source_file=None, label_path=None, label_file=None, source_extension='', label_extension='', seed=123, **load_func_kwargs)
Bases: object
Data generator that loads batches in a detached thread.
- Parameters
load_func (function or lambda) – Function applied to preprocess a single data/label pair
batch_size (int) – Number of data/label pairs in each batch
source_path (str (default=None)) – Path to the source files
source_file (str (default=None)) – Name of the file that stores the list of source files
label_path (str (default=None)) – Path to the label files
label_file (str (default=None)) – Name of the file that stores the list of label files
source_extension (str (default='')) – Extension of the source files
label_extension (str (default='')) – Extension of the label files
seed (int (default=123)) – Random seed
**load_func_kwargs (dict) – Optional keyword arguments passed to load_func
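The parameters above imply that source files and label files are matched up and then shuffled with `seed`. This page does not document the matching rule, so the following is only a sketch. It assumes a label file shares its source file's base name, and it covers only the directory mode (`source_path`/`label_path`). `list_pairs` is a hypothetical helper and is not part of this module.

```python
import os
import random


def list_pairs(source_path, label_path, source_extension='', label_extension='', seed=123):
    '''Hypothetical helper: pair source and label files by base name, then shuffle.'''
    # keep only regular files that carry the source extension, in a stable order
    sources = sorted(
        f for f in os.listdir(source_path)
        if f.endswith(source_extension) and os.path.isfile(os.path.join(source_path, f))
    )
    pairs = []
    for src in sources:
        # strip the source extension (a no-op when it is '') to get the base name
        stem = src[:len(src) - len(source_extension)]
        # assumption: the label is the file with the same base name plus label_extension
        label = os.path.join(label_path, stem + label_extension)
        if os.path.isfile(label):
            pairs.append((os.path.join(source_path, src), label))
    # a dedicated Random instance makes the order reproducible without touching global state
    random.Random(seed).shuffle(pairs)
    return pairs
```

This sketch silently skips source files that have no matching label. Raising an error instead would be a stricter design choice.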
Example
>>> import matplotlib.pyplot as plt
>>> from data import DataGenerator, load_segmentation
>>>
>>> train_gen = DataGenerator(load_func=load_segmentation, batch_size=2,
>>>                           source_path='/path/to/train/images',
>>>                           label_path='/path/to/mask/images',
>>>                           source_extension='.png',
>>>                           label_extension='.png'
>>>                           )
>>> train_gen.start()
>>>
>>> fig, ((ax00, ax01), (ax10, ax11)) = plt.subplots(nrows=2, ncols=2)
>>>
>>> for i in range(10):
>>>     grabbed = False
>>>
>>>     # poll until the background thread has a batch ready
>>>     while not grabbed:
>>>         data, label, grabbed = train_gen.load_data()
>>>
>>>     # unpack only after a batch has actually been grabbed
>>>     (data1, data2), (label1, label2) = data, label
>>>
>>>     ax00.imshow(data1.get(), cmap='gray')
>>>     ax00.axis('off')
>>>
>>>     ax01.imshow(label1.get(), cmap='gray')
>>>     ax01.axis('off')
>>>
>>>     ax10.imshow(data2.get(), cmap='gray')
>>>     ax10.axis('off')
>>>
>>>     ax11.imshow(label2.get(), cmap='gray')
>>>     ax11.axis('off')
>>>
>>>     plt.pause(1e-2)
>>>
>>> plt.show()
>>>
>>> train_gen.stop()
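A generator that runs "in a detached thread" follows a producer/consumer design. A background worker fills a queue with batches, and the caller polls it without blocking. The following sketch shows that idea using the standard `threading` and `queue` modules. It is not this module's implementation. `ThreadedBatchLoader` and its `max_queue` parameter are hypothetical, and the sketch assumes `load_func` takes a (source, label) filename pair. Its `load_data()` follows the example above and returns True as the third value when a batch was grabbed.

```python
import queue
import threading


class ThreadedBatchLoader:
    '''Hypothetical sketch: load batches in a daemon thread, poll them without blocking.'''

    def __init__(self, pairs, load_func, batch_size, max_queue=4, **load_func_kwargs):
        self._pairs = list(pairs)
        self._load_func = load_func
        self._batch_size = batch_size
        self._kwargs = load_func_kwargs
        # bounded queue: the worker waits instead of loading the whole dataset at once
        self._queue = queue.Queue(maxsize=max_queue)
        self._stop_event = threading.Event()
        self._thread = threading.Thread(target=self._worker, daemon=True)

    def start(self):
        self._thread.start()

    def stop(self):
        self._stop_event.set()
        self._thread.join(timeout=1.0)

    def _worker(self):
        index = 0
        while not self._stop_event.is_set():
            # take the next batch_size pairs, wrapping around the end of the list
            batch = [self._pairs[(index + k) % len(self._pairs)] for k in range(self._batch_size)]
            index = (index + self._batch_size) % len(self._pairs)
            loaded = [self._load_func(*pair, **self._kwargs) for pair in batch]
            data, label = zip(*loaded)
            # put with a timeout so a full queue does not hide a stop request
            while not self._stop_event.is_set():
                try:
                    self._queue.put((data, label), timeout=0.1)
                    break
                except queue.Full:
                    continue

    def load_data(self):
        '''Return (data, label, True) if a batch was ready, else (None, None, False).'''
        try:
            data, label = self._queue.get(block=False)
            return data, label, True
        except queue.Empty:
            return None, None, False
```

Because `load_data()` never blocks, the caller decides how to wait, as the polling loop in the example does.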
- load_data()
Get a batch of images and labels
- Returns
data (obj) – Loaded data
label (obj) – Loaded label
stopped (bool) – Whether the end of the list has been reached
- property num_data
Number of data samples available to the generator
- data.load_segmentation(source_image_filename, mask_image_filename)
Load Segmentation data.
- Parameters
source_image_filename (str) – Filename of the source image
mask_image_filename (str) – Filename of the corresponding mask image in binary format
- Returns
src_image (Image) – Loaded Image object
mask_image (Image) – Label image given by the mask
Notes
In a segmentation model, the input is a plain image and the labels are given by the binary mask of the same image, in which the segmented regions are highlighted. No checks are performed on the compatibility between the source image and the corresponding mask file; the only check is on the image size (channels are excluded).
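The only compatibility check described above compares image sizes and ignores channels. The following sketch shows how such a check could look. It assumes both images are NumPy arrays laid out as (height, width[, channels]), and `check_same_size` is a hypothetical name.

```python
import numpy as np


def check_same_size(src, mask):
    '''Hypothetical helper: compare spatial sizes only, ignoring any channel axis.'''
    # assumption: axis 0 is height and axis 1 is width; a trailing channel axis is optional
    if src.shape[:2] != mask.shape[:2]:
        raise ValueError(
            f'Image and mask sizes differ: {src.shape[:2]} vs {mask.shape[:2]}')
```

An RGB image of shape (4, 5, 3) therefore passes against a single-channel mask of shape (4, 5), while a (5, 4) mask would be rejected.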
- data.load_super_resolution(hr_image_filename, patch_size=(48, 48), scale=4)
Load Super resolution data.
- Parameters
hr_image_filename (str) – Filename of the high-resolution image
patch_size (tuple (default=(48, 48))) – Size of the patch to crop
scale (int (default=4)) – Downsampling scale factor
- Returns
data (Image obj) – Loaded Image object
label (Image obj) – Generated Image label
Notes
In super-resolution (SR) models, the labels are given by the high-resolution (HR) image, while the input data are obtained from the same image after downsampling/resizing. The upsampling factor learned by the SR model is therefore the same scale used inside this function.
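The following NumPy sketch makes the relation between `patch_size`, `scale`, data and label concrete. It makes several assumptions:

- The HR image is a (height, width, channels) array.
- `patch_size` is (height, width).
- The label is a random HR crop of that size.
- Block averaging stands in for whatever resizing the library actually applies.

`make_sr_pair` is hypothetical.

```python
import numpy as np


def make_sr_pair(hr_image, patch_size=(48, 48), scale=4, seed=None):
    '''Hypothetical sketch: random HR crop as label, block-averaged crop as input.'''
    rng = np.random.default_rng(seed)
    ph, pw = patch_size  # assumption: (height, width)
    h, w = hr_image.shape[:2]
    if h < ph or w < pw:
        raise ValueError('patch_size is larger than the image')
    if ph % scale or pw % scale:
        raise ValueError('patch_size must be divisible by scale')
    # pick a random top-left corner so the whole patch fits inside the image
    top = rng.integers(0, h - ph + 1)
    left = rng.integers(0, w - pw + 1)
    label = hr_image[top:top + ph, left:left + pw]
    channels = label.shape[2]
    # average each non-overlapping scale x scale block (stand-in for the real resize)
    data = label.reshape(ph // scale, scale, pw // scale, scale, channels).mean(axis=(1, 3))
    return data, label
```

With the default `patch_size=(48, 48)` and `scale=4`, the input patch is 12 x 12 per channel. The model must upsample it by a factor of 4 to match the 48 x 48 label.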