# Source code for data

#!/usr/bin/env python
# -*- coding: utf-8 -*-

from __future__ import division
from __future__ import print_function

import os
import time
import numpy as np
from glob import glob
from threading import Thread
from functools import partial

from NumPyNet.image import Image

__author__ = ['Mattia Ceccarelli', 'Nico Curti']
__email__ = ['mattia.ceccarelli3@studio.unibo.it', 'nico.curti2@unibo.it']


class DataGenerator (object):

  '''
  Data generator in detached thread.

  The generator pre-loads batches of (data, label) pairs in a daemon
  thread so that training can overlap with I/O.

  Parameters
  ----------
    load_func : function or lambda
      Function to apply for the preprocessing on a single data/label pair

    batch_size : int
      Dimension of batch to load

    source_path : str (default=None)
      Path to the source files

    source_file : str (default=None)
      Filename in which is stored the list of source files

    label_path : str (default=None)
      Path to the label files

    label_file : str (default=None)
      Filename in which is stored the list of label files

    source_extension : str (default='')
      Extension of the source files

    label_extension : str (default='')
      Extension of the label files

    seed : int
      Random seed

    **load_func_kwargs : dict
      Optional parameters to use in the load_func

  Example
  -------
  >>> import pylab as plt
  >>>
  >>> train_gen = DataGenerator(load_func=load_segmentation, batch_size=2,
  >>>                           source_path='/path/to/train/images',
  >>>                           label_path='/path/to/mask/images',
  >>>                           source_extension='.png',
  >>>                           label_extension='.png'
  >>>                           )
  >>> train_gen.start()
  >>>
  >>> fig, ((ax00, ax01), (ax10, ax11)) = plt.subplots(nrows=2, ncols=2)
  >>>
  >>> for i in range(10):
  >>>   grabbed = False
  >>>
  >>>   while not grabbed:
  >>>     (data1, data2), (label1, label2), grabbed = train_gen.load_data()
  >>>
  >>>   ax00.imshow(data1.get(), cmap='gray')
  >>>   ax00.axis('off')
  >>>
  >>>   ax01.imshow(label1.get(), cmap='gray')
  >>>   ax01.axis('off')
  >>>
  >>>   ax10.imshow(data2.get(), cmap='gray')
  >>>   ax10.axis('off')
  >>>
  >>>   ax11.imshow(label2.get(), cmap='gray')
  >>>   ax11.axis('off')
  >>>
  >>>   plt.pause(1e-2)
  >>>
  >>> plt.show()
  >>>
  >>> train_gen.stop()
  '''

  def __init__ (self, load_func, batch_size,
                source_path=None, source_file=None,
                label_path=None, label_file=None,
                source_extension='', label_extension='',
                seed=123, **load_func_kwargs):

    np.random.seed(seed)

    if source_path is None and source_file is None:
      raise ValueError('Source path and Source file can not be both null. Please give one of them')

    # sources: either every matching file in a directory, or an explicit list file
    if source_path is not None:
      if not os.path.exists(source_path):
        raise ValueError('Source path does not exist')
      source_files = sorted(glob(source_path + '/*{}'.format(source_extension)))
    else:
      with open(source_file) as fp:
        source_files = fp.read().splitlines()

    source_files = np.asarray(source_files)

    # labels are optional: directory of files, list file (converted to
    # numeric class ids), or absent entirely
    if label_path is not None:
      if not os.path.exists(label_path):
        raise ValueError('Labels path does not exist')
      label_files = sorted(glob(label_path + '/*{}'.format(label_extension)))
      label_files = np.asarray(label_files)
    elif label_file is not None:
      with open(label_file) as fp:
        label_files = fp.read().splitlines()
      # convert to unique numbers
      _, label_files = np.unique(label_files, return_inverse=True)
      label_files = label_files.astype(float)
    else:
      label_files = None

    self._num_data = source_files.size
    source_files, label_files = self._randomize(source_files, label_files)

    # freeze the optional kwargs into the loading function
    load_func = partial(load_func, **load_func_kwargs)

    self.load_func = load_func
    self._batch = batch_size
    # daemon thread so it never blocks interpreter shutdown
    self._thread = Thread(target=self._update, args=(source_files, label_files))
    self._thread.daemon = True
    self._current_batch = 0
    self._stopped = False
    self._data, self._label = (None, None)

  @property
  def num_data (self):
    '''
    Get the number of data
    '''
    return self._num_data

  def _randomize (self, source, label=None):
    '''
    Randomize the source and labels arrays

    Parameters
    ----------
      source : array-like
        List of source files

      label : array-like (default = None)
        List of label files

    Return
    ------
      source : array-like
        Array of source shuffled

      label : array-like
        Array of labels shuffled
    '''
    if label is not None:
      # shuffle with a shared permutation so pairs stay aligned
      random_index = np.arange(0, self._num_data)
      np.random.shuffle(random_index)
      source = source[random_index]
      label = label[random_index]
    else:
      np.random.shuffle(source)

    return (source, label)

  def _load (self, sources, labels=None):
    '''
    Map the loading function over the sources and labels

    Parameters
    ----------
      sources : list
        List of filenames to load

      labels : list (default=None)
        List of labels filenames to load

    Returns
    -------
      data : array-like
        Data read according to the load_func

      label : array-like
        Labels read according to the load_func
    '''
    if labels is not None:
      try:
        # load_func returns (data, label) pairs; zip(*...) unzips them
        # into a data tuple and a label tuple
        self._data, self._label = zip(*map(self.load_func, sources, labels))
      except Exception as e:
        # stop the reader thread before propagating the failure
        self._stopped = True
        raise e
      return (self._data, self._label)

    else:
      try:
        # NOTE: bug fix — the original used zip(*map(...)) here, which
        # unpacks each loaded item instead of collecting the batch;
        # without labels load_func returns a single object per source
        self._data = tuple(map(self.load_func, sources))
      except Exception as e:
        self._stopped = True
        raise e
      return (self._data, None)

  def _update (self, source_files, label_files):
    '''
    Infinite loop of batch reading.
    Each batch is read only if necessary (the previous is already used).

    Parameters
    ----------
      source_files : list
        List of source files to load

      label_files : list
        List of label files to load
    '''
    start_time = time.time()
    elapsed = 1.

    while not self._stopped:

      if self._data is None:

        # we reach the end of batch
        if self._current_batch + self._batch >= self._num_data:
          source_files, label_files = self._randomize(source_files, label_files)
          self._current_batch = 0

        if label_files is not None:
          self._data, self._label = self._load(source_files[self._current_batch : self._current_batch + self._batch],
                                               label_files[self._current_batch : self._current_batch + self._batch])
        else:
          self._data, self._label = self._load(source_files[self._current_batch : self._current_batch + self._batch])

        self._current_batch += self._batch

        elapsed = time.time() - start_time
        print('Elapsed {:.3f} sec.'.format(elapsed))

      else:
        # consumer has not taken the batch yet: wait roughly one load time
        time.sleep(elapsed + .05)
        start_time = time.time()

  def start (self):
    '''
    Start the thread
    '''
    self._thread.start()
    # give the reader a head start on the first batch
    time.sleep(1.)
    return self

  def stop (self):
    '''
    Stop the thread
    '''
    self._stopped = True
    self._thread.join()

  def load_data (self):
    '''
    Get a batch of images and labels

    Returns
    -------
      data : obj
        Loaded data

      label : obj
        Loaded label

      stopped : bool
        Check if the end of the list is achieved
    '''
    data, label = (self._data, self._label)
    # mark the batch as consumed so the reader thread loads the next one
    self._data = None

    if label is not None:
      return (data, label, not self._stopped)
    else:
      return (data, not self._stopped)
def load_super_resolution (hr_image_filename, patch_size=(48, 48), scale=4):
  '''
  Load Super resolution data.

  Parameters
  ----------
    hr_image_filename : string
      Filename of the high resolution image

    patch_size : tuple (default=(48, 48))
      Dimension to cut

    scale : int (default=4)
      Downsampling scale factor

  Returns
  -------
    data : Image obj
      Loaded Image object

    label : Image obj
      Generated Image label

  Notes
  -----
  .. note::
    In SR models the labels are given by the HR image while the input data are
    obtained from the same image after a downsampling/resizing.
    The upsample scale factor learned by the SR model will be the same used
    inside this function.
  '''
  hr_image = Image(hr_image_filename)

  w, h, _ = hr_image.shape
  patch_x, patch_y = patch_size

  # NOTE: bug fix — the original drew the crop corner with
  # np.random.uniform, producing *float* pixel coordinates;
  # randint yields an integer corner in the same [0, dim - patch - 1] range
  dx = np.random.randint(low=0, high=w - patch_x)
  dy = np.random.randint(low=0, high=h - patch_y)

  hr_image = hr_image.crop(dsize=(dx, dy), size=patch_size)

  # random augmentation: ~1/3 transpose+flip, ~1/3 flip only, ~1/3 unchanged
  random_flip = np.random.uniform(low=0, high=1.)

  if random_flip >= .66:
    hr_image = hr_image.transpose()
    hr_image = hr_image.flip()
  elif random_flip >= .33:
    hr_image = hr_image.flip()
  else:
    pass

  # label is the HR patch; data is its resized (LR) counterpart
  # NOTE(review): assumes Image.resize(scale_factor=...) performs the
  # downsampling described in the note above — confirm against NumPyNet.image
  label = hr_image
  data = hr_image.resize(scale_factor=(scale, scale))

  return (data, label)
def load_segmentation (source_image_filename, mask_image_filename):
  '''
  Load Segmentation data.

  Parameters
  ----------
    source_image_filename : str
      Filename of the source image

    mask_image_filename : str
      Filename of the corresponding mask image in binary format

  Returns
  -------
    src_image : Image
      Loaded Image object

    mask_image : Image
      Image label as mask image

  Notes
  -----
  .. note::
    In Segmentation model we have to feed the model with a simple image and
    the labels will be given by the mask (binary) of the same image in which
    the segmentation parts are highlight
    No checks are performed on the compatibility between source image and
    corresponding mask file.
    The only checks are given on the image size (channels are excluded)
  '''
  source = Image(source_image_filename)
  mask = Image(mask_image_filename)

  # only width/height must agree; channel counts may legitimately differ
  sizes_agree = source.shape[:2] == mask.shape[:2]
  if not sizes_agree:
    raise ValueError('Incorrect shapes found. The source image and the corresponding mask have different sizes')

  return (source, mask)
if __name__ == '__main__':

  import pylab as plt

  # demo: stream segmentation batches and display them as they arrive
  train_gen = DataGenerator(load_func=load_segmentation, batch_size=2,
                            source_path='/path/to/train/images',
                            label_path='/path/to/mask/images',
                            source_extension='.png',
                            label_extension='.png'
                            )
  train_gen.start()

  fig, ((ax00, ax01), (ax10, ax11)) = plt.subplots(nrows=2, ncols=2)

  for _ in range(10):

    # spin until the reader thread publishes the next batch
    grabbed = False
    while not grabbed:
      (data1, data2), (label1, label2), grabbed = train_gen.load_data()

    # panels: data1 | label1 on the top row, data2 | label2 on the bottom
    panels = zip((ax00, ax01, ax10, ax11), (data1, label1, data2, label2))
    for axis, img in panels:
      axis.imshow(img.get(), cmap='gray')
      axis.axis('off')

    plt.pause(1e-2)

  plt.show()

  train_gen.stop()