Keras ImageDataGenerator and tf.data.Dataset in TensorFlow 2.0

For this case, I used the TensorFlow documentation here: https://www.tensorflow.org/guide/data.

I'm going to use the flowers dataset, just as it is used in the guide. The code starts as usual:

import tensorflow as tf

flowers = tf.keras.utils.get_file(
    'flower_photos',
    'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
    untar=True)

# Define some global parameters
BATCH_SIZE = 32
IMG_DIM = 224
NB_CLASSES = 5

# Define the ImageDataGenerator and the data augmentations
img_gen = tf.keras.preprocessing.image.ImageDataGenerator(
    rescale=1./255,
    rotation_range=20
)


...

From this point, you can wrap the ImageDataGenerator in a tf.data.Dataset as follows:

...

ds = tf.data.Dataset.from_generator(
    img_gen.flow_from_directory, args=[flowers],
    output_types=(tf.float32, tf.float32),
    output_shapes=([32, 256, 256, 3], [32, 5])
)

# Then, just to try it out, pull one batch
it = iter(ds)
batch = next(it)
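
You can unpack the batch to check that the shapes match what was declared in output_shapes (a quick sketch, not from the original guide):

images, labels = batch
print(images.shape)  # expected: (32, 256, 256, 3)
print(labels.shape)  # expected: (32, 5)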


But in my case, I wanted to pass more arguments than just the flowers directory. Looking at the docstring of flow_from_directory, you can read:


Signature:
train_gen.flow_from_directory(
    directory,
    target_size=(256, 256),
    color_mode='rgb',
    classes=None,
    class_mode='categorical',
    batch_size=32,
    shuffle=True,
    seed=None,
    save_to_dir=None,
    save_prefix='',
    save_format='png',
    follow_links=False,
    subset=None,
    interpolation='nearest',
)
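
For illustration, the call I tried looked roughly like this (a sketch: the exact argument values are an assumption, but passing a string such as 'rgb' through args is what triggers the error below):

# Sketch of the failing call (argument values are an assumption)
ds = tf.data.Dataset.from_generator(
    img_gen.flow_from_directory,
    args=[flowers, (IMG_DIM, IMG_DIM), 'rgb'],  # directory, target_size, color_mode
    output_types=(tf.float32, tf.float32),
    output_shapes=([BATCH_SIZE, IMG_DIM, IMG_DIM, 3], [BATCH_SIZE, NB_CLASSES])
)

it = iter(ds)
batch = next(it)  # fails: flow_from_directory receives b'rgb' instead of 'rgb'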

But when passing those arguments through args=[...] in tf.data.Dataset.from_generator(), I got this error:

---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-7-5cabd7ffcaf2> in <module>

...

ValueError: ('Invalid color mode:', b'rgb', '; expected "rgb", "rgba", or "grayscale".')


     [[{{node PyFunc}}]] [Op:IteratorGetNextSync]


Finally, on GitHub (https://github.com/tensorflow/tensorflow/issues/33133#issuecomment-539418486), a very helpful person suggested the following:


import tensorflow as tf

flowers = tf.keras.utils.get_file(
    'flower_photos',
    'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
    untar=True)

# Define some global parameters
BATCH_SIZE = 32
IMG_DIM = 224
NB_CLASSES = 5

# Define the ImageDataGenerator and the data augmentations
img_gen = tf.keras.preprocessing.image.ImageDataGenerator()
gen = img_gen.flow_from_directory(
    flowers,
    (IMG_DIM, IMG_DIM),
    'rgb',
    class_mode='categorical',
    batch_size=BATCH_SIZE,
    shuffle=False
)

ds = tf.data.Dataset.from_generator(lambda: gen,
                     output_types=(tf.float32, tf.float32),
                     output_shapes=([BATCH_SIZE, IMG_DIM, IMG_DIM, 3],
                                    [BATCH_SIZE, NB_CLASSES])
                     )

it = iter(ds)
batch = next(it)
print(batch)


In case you didn't notice, the generator is now created directly, with all its arguments:

...
img_gen = tf.keras.preprocessing.image.ImageDataGenerator()
gen = img_gen.flow_from_directory(
    flowers,
    (IMG_DIM, IMG_DIM),
    'rgb',
    class_mode='categorical',
    batch_size=BATCH_SIZE,
    shuffle=False
)

...

Then it is passed via a lambda to tf.data.Dataset.from_generator():

...
ds = tf.data.Dataset.from_generator(lambda: gen,
                     output_types=(tf.float32, tf.float32),
                     output_shapes=([BATCH_SIZE, IMG_DIM, IMG_DIM, 3],
                                    [BATCH_SIZE, NB_CLASSES])
                     )

No args=[...] needed. The reason args caused trouble is that from_generator() converts them to tensors and hands them to the generator as NumPy values, so a Python string like 'rgb' arrives as the byte string b'rgb', which flow_from_directory rejects (hence the error above).
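
If you really do want to keep args=[...], a possible workaround (my own sketch, not from the GitHub issue) is to wrap flow_from_directory in a small function that decodes the byte strings before using them:

# Sketch of a possible workaround (not from the GitHub issue):
# from_generator() delivers args as NumPy/bytes values, so decode them first.
def make_flow(directory, color_mode):
    return img_gen.flow_from_directory(
        directory.decode('utf-8'),              # b'/path/to/flowers' -> str
        target_size=(IMG_DIM, IMG_DIM),
        color_mode=color_mode.decode('utf-8'),  # b'rgb' -> 'rgb'
        class_mode='categorical',
        batch_size=BATCH_SIZE)

ds = tf.data.Dataset.from_generator(
    make_flow, args=[flowers, 'rgb'],
    output_types=(tf.float32, tf.float32),
    output_shapes=([BATCH_SIZE, IMG_DIM, IMG_DIM, 3], [BATCH_SIZE, NB_CLASSES])
)

Still, the lambda approach from the issue is simpler, since the generator is built once with plain Python values.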

Now let's say your case is a binary classification problem and you're loading your images from JPEG files. Your directory tree looks like this:

├── data
│   └── trainset
│       ├── Good
│       │   ├── good_0001.jpg
│       │   ├── ...
│       ├── Not_Good
│       │   ├── defect_0001.jpg
│       │   ├── defect_0002.jpg
│       │   ├── defect_0003.jpg
│       │   ├── ...

Then your code would look like this:

import pathlib

import numpy as np
import tensorflow as tf

# Define where your dataset is
data_dir = "data/trainset/"
data_dir = pathlib.Path(data_dir)

# Get the class names from the directory names
CLASS_NAMES = np.array([item.name for item in data_dir.glob('*')])
print(f"There are 2 classes: {CLASS_NAMES[0]} and {CLASS_NAMES[1]}.")

# Define general parameters
SEED = 42
BATCH_SIZE = 128
IMG_HEIGHT = 640
IMG_WIDTH = 640

# ImageDataGenerator and Data Augmentations by using keras
train_gen = tf.keras.preprocessing.image.ImageDataGenerator(
    rescale=1./255,
    horizontal_flip=True,
    vertical_flip=True)

# Here we can easily define the parameters
train_data_gen = train_gen.flow_from_directory(
    directory=str(data_dir),
    batch_size=BATCH_SIZE,
    shuffle=True,
    target_size=(IMG_HEIGHT, IMG_WIDTH),
    class_mode='binary'
)

# And then wrap the Keras generator
train_ds = tf.data.Dataset.from_generator(
    lambda: train_data_gen,
    output_types=(tf.float32, tf.float32),
    # class_mode='binary' yields a 1-D label array per batch, hence [BATCH_SIZE]
    output_shapes=([BATCH_SIZE, IMG_HEIGHT, IMG_WIDTH, 3],
                   [BATCH_SIZE]))
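
As a quick sanity check (my own addition, following the same pattern as before), pull one batch and look at the shapes. Also note that the Keras generator loops forever, so the wrapped dataset is infinite and model.fit() will need a steps_per_epoch value:

# Pull one batch to verify the declared shapes
it = iter(train_ds)
images, labels = next(it)
print(images.shape)  # expected: (128, 640, 640, 3)
print(labels.shape)  # expected: (128,)

# The wrapped generator never ends, so give model.fit() a number of steps,
# e.g. (assuming a compiled Keras `model` already exists):
# model.fit(train_ds, steps_per_epoch=train_data_gen.samples // BATCH_SIZE, epochs=10)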