Home

Awesome

Tensorflow-DatasetAPI

A simple TensorFlow Dataset API tutorial for reading images

Usage

1. glob images

ex) trainA_dataset = glob('./dataset/{}/*.*'.format(dataset_name + '/trainA'))

trainA_dataset = ['./dataset/cat/trainA/a.jpg', 
                  './dataset/cat/trainA/b.png', 
                  './dataset/cat/trainA/c.jpeg', 
                  ...]

2. Use from_tensor_slices

trainA = tf.data.Dataset.from_tensor_slices(trainA_dataset)

3. Use map for preprocessing


    def image_processing(filename):
        """Read one image file and return it resized and rescaled.

        Args:
            filename: path of the image file to load (string tensor supplied
                by ``Dataset.map``).

        Returns:
            A float32 image tensor resized to 256x256 with 3 channels, with
            values computed as ``pixel / 127.5 - 1``.
        """
        raw_bytes = tf.read_file(filename)

        # channels=3 requests RGB output.
        # decode_jpeg is used on purpose: the original author warns that
        # decode_image results in an error in this pipeline.
        decoded = tf.image.decode_jpeg(raw_bytes, channels=3)

        resized = tf.image.resize_images(decoded, [256, 256])
        return tf.cast(resized, tf.float32) / 127.5 - 1
        

trainA = trainA.map(image_processing, num_parallel_calls=8)

class ImageData:
    """Image loading and optional augmentation for a ``tf.data`` pipeline.

    Intended use: pass ``image_processing`` to ``Dataset.map`` over a dataset
    of image file names.

    Args:
        batch_size: batch size; stored on the instance, not used by the
            methods of this class.
        load_size: side length (pixels) the image is resized to.
        channels: number of colour channels requested from the decoder.
        augment_flag: when truthy, each image is augmented with
            probability 0.5.
    """

    def __init__(self, batch_size, load_size, channels, augment_flag):
        self.batch_size = batch_size
        self.load_size = load_size
        self.channels = channels
        self.augment_flag = augment_flag
        # Size the image is enlarged to before being cropped back to
        # load_size: +30 pixels for 256-pixel images, +15 otherwise.
        self.augment_size = load_size + (30 if load_size == 256 else 15)

    def image_processing(self, filename):
        """Read, decode, resize and rescale one image file.

        Args:
            filename: path of the image file (string tensor).

        Returns:
            A float32 tensor of size ``load_size`` x ``load_size`` with values
            computed as ``pixel / 127.5 - 1``; randomly augmented when
            ``augment_flag`` is set.
        """
        x = tf.read_file(filename)

        # decode_jpeg is used on purpose: the original author warns that
        # decode_image results in an error in this pipeline.
        x_decode = tf.image.decode_jpeg(x, channels=self.channels)

        img = tf.image.resize_images(x_decode, [self.load_size, self.load_size])
        img = tf.cast(img, tf.float32) / 127.5 - 1

        if self.augment_flag:
            # The coin flip must be a graph op. Dataset.map traces this
            # function once, so a Python-level random.random() would be drawn a
            # single time and the same decision reused for every image.
            plain = img
            p = tf.random_uniform([], minval=0.0, maxval=1.0)
            img = tf.cond(p > 0.5,
                          lambda: self.augmentation(plain),
                          lambda: plain)

        return img

    def augmentation(self, image):
        """Randomly flip the image, enlarge it, and crop back to its size.

        Args:
            image: image tensor to augment.

        Returns:
            A tensor with the same shape as ``image``.
        """
        # Op-level seed, drawn once in Python when this method is called.
        seed = random.randint(0, 2 ** 31 - 1)

        ori_image_shape = tf.shape(image)
        image = tf.image.random_flip_left_right(image, seed=seed)
        image = tf.image.resize_images(image, [self.augment_size, self.augment_size])
        image = tf.random_crop(image, ori_image_shape, seed=seed)

        return image
    
    
Image_Data_Class = ImageData(batch_size, img_size, img_ch, augment_flag)
trainA = trainA.map(Image_Data_Class.image_processing, num_parallel_calls=8)


4. Set prefetch & batch_size


# NOTE(review): the three `trainA = ...` statements below look like
# alternative configurations, not steps to run one after another (running all
# of them would shuffle and batch repeatedly) - pick one; TODO confirm.

# Variant 1: shuffle with a 10000-element buffer, prefetch, batch, repeat.
trainA = trainA.shuffle(buffer_size=10000).prefetch(buffer_size=batch_size).batch(batch_size).repeat()

# Variant 2: same chain, but batching goes through batch_and_drop_remainder.
trainA = trainA.shuffle(10000).prefetch(batch_size).apply(batch_and_drop_remainder(batch_size)).repeat()
# hyper-parameter examples
# (batch_size is already used by the lines above; define it before them)
gpu_device = '/gpu:0'
dataset_num = 10000
batch_size = 8

# Variant 3: shuffle_and_repeat, then map_and_batch, then prefetch_to_device.
# map_and_batch is given Image_Data_Class.image_processing itself, so
# presumably the separate `map` call from step 3 is not needed with this
# variant - TODO confirm.
trainA = trainA.apply(shuffle_and_repeat(dataset_num)).apply(map_and_batch(Image_Data_Class.image_processing, batch_size, num_parallel_batches=16, drop_remainder=True)).apply(prefetch_to_device(gpu_device, batch_size))

5. Set Iterator


trainA_iterator = trainA.make_one_shot_iterator()

data_A = trainA_iterator.get_next()
logit = network(data_A)
...


6. Run Model


def train() :
    for epoch ...
        for iteration ...


7. See Code

Author

Junho Kim