DataGenerator

class data.DataGenerator(load_func, batch_size, source_path=None, source_file=None, label_path=None, label_file=None, source_extension='', label_extension='', seed=123, **load_func_kwargs)

Bases: object

Data generator that loads batches in a separate (detached) thread.

Parameters
  • load_func (function or lambda) – Function applied to preprocess a single data/label pair

  • batch_size (int) – Number of data/label pairs per batch

  • source_path (str (default=None)) – Path to the source files

  • source_file (str (default=None)) – Name of the file that stores the list of source files

  • label_path (str (default=None)) – Path to the label files

  • label_file (str (default=None)) – Name of the file that stores the list of label files

  • source_extension (str (default='')) – Extension of the source files

  • label_extension (str (default='')) – Extension of the label files

  • seed (int (default=123)) – Random seed

  • **load_func_kwargs (dict) – Optional keyword arguments forwarded to load_func
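As a rough illustration of how batch_size and seed interact, the sketch below pairs source and label filenames, shuffles them with a seeded local RNG and splits them into fixed-size batches. make_batches is a hypothetical helper, not part of the data module; pairing files by list position and dropping a trailing incomplete batch are assumptions of this sketch, not documented behaviour.

```python
import random
from typing import List, Optional, Sequence, Tuple

Pair = Tuple[str, Optional[str]]


def make_batches(source_files: Sequence[str],
                 label_files: Optional[Sequence[str]],
                 batch_size: int,
                 seed: int = 123) -> List[List[Pair]]:
    """Pair files by position, shuffle them reproducibly and split into full batches.

    Hypothetical helper: only illustrates the role of batch_size and seed.
    """
    if batch_size < 1:
        raise ValueError('batch_size must be >= 1')
    if label_files is None:
        pairs: List[Pair] = [(src, None) for src in source_files]
    else:
        if len(source_files) != len(label_files):
            raise ValueError('source and label lists must have the same length')
        pairs = list(zip(source_files, label_files))

    rng = random.Random(seed)  # local RNG: same seed -> same order, global state untouched
    rng.shuffle(pairs)

    # assumption: a trailing incomplete batch is dropped
    n_full = len(pairs) // batch_size
    return [pairs[k * batch_size:(k + 1) * batch_size] for k in range(n_full)]


if __name__ == '__main__':
    sources = [f'img_{i}.png' for i in range(5)]
    labels = [f'mask_{i}.png' for i in range(5)]
    for batch in make_batches(sources, labels, batch_size=2):
        print(batch)
```

Using a private random.Random instance rather than random.seed keeps the shuffle reproducible without touching any other code that relies on the global RNG.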

Example

>>> import matplotlib.pyplot as plt
>>> from data import DataGenerator, load_segmentation
>>>
>>> train_gen = DataGenerator(load_func=load_segmentation, batch_size=2,
...                           source_path='/path/to/train/images',
...                           label_path='/path/to/mask/images',
...                           source_extension='.png',
...                           label_extension='.png')
>>> train_gen.start()
>>>
>>> fig, ((ax00, ax01), (ax10, ax11)) = plt.subplots(nrows=2, ncols=2)
>>>
>>> for i in range(10):
...     grabbed = False
...     # poll until the background thread has a batch ready
...     while not grabbed:
...         (data1, data2), (label1, label2), grabbed = train_gen.load_data()
...     # left column: source images, right column: corresponding masks
...     for ax, img in ((ax00, data1), (ax01, label1), (ax10, data2), (ax11, label2)):
...         ax.imshow(img.get(), cmap='gray')
...         ax.axis('off')
...     plt.pause(1e-2)
...
>>> plt.show()
>>>
>>> train_gen.stop()
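The "detached thread" design can be sketched as a producer/consumer pair: a worker thread calls load_func on each data/label pair and pushes complete batches into a bounded queue, while load_data polls that queue without blocking. ThreadedLoader and its max_queued argument are hypothetical; this is a minimal sketch assuming a queue-based design, not the actual DataGenerator implementation.

```python
import queue
import threading
import time
from typing import Any, Callable, List, Optional, Tuple


class ThreadedLoader:
    """Toy background loader (hypothetical): a worker thread fills a bounded queue."""

    def __init__(self, load_func: Callable[..., Tuple[Any, Any]],
                 pairs: List[Tuple[str, str]], batch_size: int,
                 max_queued: int = 4, **load_func_kwargs: Any) -> None:
        self._load_func = load_func
        self._pairs = pairs
        self._batch_size = batch_size
        self._kwargs = load_func_kwargs  # forwarded to every load_func call
        self._queue = queue.Queue(maxsize=max_queued)
        self._stop_event = threading.Event()
        self._thread: Optional[threading.Thread] = None

    def _worker(self) -> None:
        i = 0
        while not self._stop_event.is_set():
            # build one batch, cycling endlessly over the list of pairs
            data, labels = [], []
            for _ in range(self._batch_size):
                src, lbl = self._pairs[i % len(self._pairs)]
                sample, target = self._load_func(src, lbl, **self._kwargs)
                data.append(sample)
                labels.append(target)
                i += 1
            # put with a timeout so a full queue cannot block stop() forever
            while not self._stop_event.is_set():
                try:
                    self._queue.put((data, labels), timeout=0.05)
                    break
                except queue.Full:
                    pass

    def start(self) -> None:
        self._stop_event.clear()
        self._thread = threading.Thread(target=self._worker, daemon=True)
        self._thread.start()

    def stop(self) -> None:
        self._stop_event.set()
        if self._thread is not None:
            self._thread.join()

    def load_data(self) -> Tuple[Optional[list], Optional[list], bool]:
        """Non-blocking: (data, labels, True) if a batch is ready, else (None, None, False)."""
        try:
            data, labels = self._queue.get_nowait()
        except queue.Empty:
            return None, None, False
        return data, labels, True


if __name__ == '__main__':
    def dummy_load(src, lbl, scale=1):
        # stand-in for a real load_func such as load_segmentation
        return src * scale, lbl * scale

    loader = ThreadedLoader(dummy_load, [('a', 'x'), ('b', 'y'), ('c', 'z')],
                            batch_size=2, scale=3)
    loader.start()
    grabbed = False
    while not grabbed:
        data, labels, grabbed = loader.load_data()
        if not grabbed:
            time.sleep(0.01)
    loader.stop()
    print(data, labels)  # ['aaa', 'bbb'] ['xxx', 'yyy']
```

The bounded queue caps memory use: once max_queued batches are waiting, the worker pauses until the consumer catches up.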
load_data()

Get a batch of images and labels

Returns

  • data (obj) – Loaded data

  • label (obj) – Loaded label

  • stopped (bool) – Whether the end of the list has been reached
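The Example above busy-waits on load_data. A gentler pattern sleeps between polls and gives up after a timeout. wait_for_batch is a hypothetical helper; it assumes, as the Example does, that the third returned value is True when a batch was retrieved.

```python
import time
from typing import Any, Callable, Tuple


def wait_for_batch(load_data: Callable[[], Tuple[Any, Any, bool]],
                   timeout: float = 10.0,
                   poll_interval: float = 0.01) -> Tuple[Any, Any]:
    """Poll a load_data-like callable until its flag is True, sleeping between attempts.

    Hypothetical helper. Raises TimeoutError if no batch arrives within `timeout` seconds.
    """
    deadline = time.monotonic() + timeout
    while True:
        data, label, ok = load_data()
        if ok:
            return data, label
        if time.monotonic() >= deadline:
            raise TimeoutError(f'no batch available after {timeout} s')
        time.sleep(poll_interval)  # yield the CPU instead of spinning
```

With a started generator this would read `data, label = wait_for_batch(train_gen.load_data, timeout=5.0)`.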

property num_data

Get the number of data samples.

start()

Start the thread

stop()

Stop the thread

data.load_segmentation(source_image_filename, mask_image_filename)

Load segmentation data.

Parameters
  • source_image_filename (str) – Filename of the source image

  • mask_image_filename (str) – Filename of the corresponding mask image in binary format

Returns

  • src_image (Image) – Loaded Image object

  • mask_image (Image) – Label image, given as a binary mask

Notes

In segmentation models, the input is a plain image and the label is the binary mask of the same image, in which the segmented regions are highlighted. No checks are performed on the compatibility between the source image and the corresponding mask file; the only check is on the image size (channels excluded).
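The size check described above can be sketched with NumPy arrays. check_mask_size is a hypothetical helper assuming channels-last arrays of shape (H, W) or (H, W, C); only height and width are compared, so the channel axis is ignored.

```python
import numpy as np


def check_mask_size(image: np.ndarray, mask: np.ndarray) -> None:
    """Raise ValueError if image and mask differ in height or width.

    Hypothetical helper. Assumes channels-last layout: (H, W) or (H, W, C).
    The channel axis is ignored, so an (H, W, 3) image matches an (H, W) or
    (H, W, 1) mask. Pixel content is not compared.
    """
    if image.shape[:2] != mask.shape[:2]:
        raise ValueError(f'size mismatch: image {image.shape[:2]} vs mask {mask.shape[:2]}')


if __name__ == '__main__':
    check_mask_size(np.zeros((64, 64, 3)), np.zeros((64, 64)))  # passes silently
    try:
        check_mask_size(np.zeros((64, 64, 3)), np.zeros((32, 64)))
    except ValueError as err:
        print(err)  # size mismatch: image (64, 64) vs mask (32, 64)
```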

data.load_super_resolution(hr_image_filename, patch_size=(48, 48), scale=4)

Load super-resolution data.

Parameters
  • hr_image_filename (str) – Filename of the high-resolution image

  • patch_size (tuple (default=(48, 48))) – Size of the patch to crop

  • scale (int (default=4)) – Downsampling scale factor

Returns

  • data (Image obj) – Loaded Image object

  • label (Image obj) – Generated Image label

Notes

In super-resolution (SR) models, the labels are given by the high-resolution (HR) image, while the input data are obtained from the same image after downsampling. The upsampling factor learned by the SR model is therefore the same scale factor used inside this function.
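A minimal NumPy sketch of how an (LR input, HR label) pair can be built from one HR image follows. make_sr_pair is hypothetical; it assumes channels-last arrays, that patch_size is the size of the HR crop, and that each patch dimension is divisible by scale. Block averaging is used as the downsampling method purely for simplicity; the resizing actually used by load_super_resolution may differ.

```python
from typing import Optional, Tuple

import numpy as np


def make_sr_pair(hr: np.ndarray,
                 patch_size: Tuple[int, int] = (48, 48),
                 scale: int = 4,
                 rng: Optional[np.random.Generator] = None) -> Tuple[np.ndarray, np.ndarray]:
    """Crop a random HR patch and downsample it by `scale` -> (LR data, HR label).

    Hypothetical helper; assumes channels-last arrays (H, W) or (H, W, C).
    """
    rng = np.random.default_rng() if rng is None else rng
    ph, pw = patch_size
    if ph % scale or pw % scale:
        raise ValueError('patch_size must be divisible by scale')
    h, w = hr.shape[:2]
    if h < ph or w < pw:
        raise ValueError('image is smaller than patch_size')

    # random top-left corner such that the whole patch fits inside the image
    top = rng.integers(0, h - ph + 1)
    left = rng.integers(0, w - pw + 1)
    label = hr[top:top + ph, left:left + pw]

    # block averaging: mean of each scale x scale block -> (ph // scale, pw // scale[, C])
    blocks = label.reshape(ph // scale, scale, pw // scale, scale, *label.shape[2:])
    data = blocks.mean(axis=(1, 3))
    return data, label


if __name__ == '__main__':
    hr = np.random.default_rng(0).random((120, 160, 3))
    data, label = make_sr_pair(hr, patch_size=(48, 48), scale=4,
                               rng=np.random.default_rng(1))
    print(data.shape, label.shape)  # (12, 12, 3) (48, 48, 3)
```

Because the label is the HR patch and the input is its 1/scale reduction, a model trained on such pairs learns an upsampling factor of exactly scale.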