- Published on
Keras ImageDataGenerator and tf.data.Dataset in TensorFlow 2.0
- Authors
- Name
- Yann Le Guilly
- @yannlg_
For this, I followed the TensorFlow documentation here: https://www.tensorflow.org/guide/data.
I'm going to use the flowers dataset,
as it is used in the doc. The code starts as usual:
import tensorflow as tf

flowers = tf.keras.utils.get_file(
    'flower_photos',
    'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
    untar=True)

# Define some global parameters
BATCH_SIZE = 32
IMG_DIM = 224
NB_CLASSES = 5

# Define the ImageDataGenerator and the data augmentations
img_gen = tf.keras.preprocessing.image.ImageDataGenerator(
    rescale=1./255,
    rotation_range=20
)
...
From this point, you can wrap the ImageDataGenerator as a tf.data.Dataset as follows:
...
ds = tf.data.Dataset.from_generator(
    img_gen.flow_from_directory, args=[flowers],
    output_types=(tf.float32, tf.float32),
    output_shapes=([32, 256, 256, 3], [32, 5])
)

# Then just to try it out
it = iter(ds)
batch = next(it)
But in my case, I wanted to pass more arguments than just flowers
. Looking at the docstring, you can read:
Signature:
img_gen.flow_from_directory(
    directory,
    target_size=(256, 256),
    color_mode='rgb',
    classes=None,
    class_mode='categorical',
    batch_size=32,
    shuffle=True,
    seed=None,
    save_to_dir=None,
    save_prefix='',
    save_format='png',
    follow_links=False,
    subset=None,
    interpolation='nearest',
)
But when passing those arguments through args
inside tf.data.Dataset.from_generator()
, I got this error:
---------------------------------------------------------------------------
InvalidArgumentError Traceback (most recent call last)
<ipython-input-7-5cabd7ffcaf2> in <module>
...
ValueError: ('Invalid color mode:', b'rgb', '; expected "rgb", "rgba", or "grayscale".')
[[{{node PyFunc}}]] [Op:IteratorGetNextSync]
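As far as I understand, the cause is that from_generator converts every value in args into a tensor, and string tensors come back into Python as byte strings. So flow_from_directory receives b'rgb' where it expects the plain string 'rgb', and the comparison fails. A minimal pure-Python illustration of the mismatch (no TensorFlow needed):

```python
# Strings passed through `args` are converted to tensors and arrive in
# the generator as bytes, e.g. b'rgb' instead of 'rgb'.
color_mode = b'rgb'            # what flow_from_directory actually receives
assert color_mode != 'rgb'     # bytes never compare equal to str in Python 3
assert color_mode.decode('utf-8') == 'rgb'  # decoding would restore the match
```

This is why the error message shows b'rgb' even though 'rgb' is in the list of expected values.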
Finally, on GitHub (https://github.com/tensorflow/tensorflow/issues/33133#issuecomment-539418486) a very helpful person suggested the following:
import tensorflow as tf

flowers = tf.keras.utils.get_file(
    'flower_photos',
    'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
    untar=True)

# Define some global parameters
BATCH_SIZE = 32
IMG_DIM = 224
NB_CLASSES = 5

# Define the ImageDataGenerator and the data augmentations
img_gen = tf.keras.preprocessing.image.ImageDataGenerator()
gen = img_gen.flow_from_directory(
    flowers,
    (IMG_DIM, IMG_DIM),
    'rgb',
    class_mode='categorical',
    batch_size=BATCH_SIZE,
    shuffle=False
)

ds = tf.data.Dataset.from_generator(
    lambda: gen,
    output_types=(tf.float32, tf.float32),
    output_shapes=([BATCH_SIZE, IMG_DIM, IMG_DIM, 3],
                   [BATCH_SIZE, NB_CLASSES])
)

it = iter(ds)
batch = next(it)
print(batch)
In case you didn't notice, the generator is now created directly with its arguments:
...
img_gen = tf.keras.preprocessing.image.ImageDataGenerator()
gen = img_gen.flow_from_directory(
    flowers,
    (IMG_DIM, IMG_DIM),
    'rgb',
    class_mode='categorical',
    batch_size=BATCH_SIZE,
    shuffle=False
)
...
Then it is passed to the tf.data.Dataset via a lambda function, without using args at all:
...
ds = tf.data.Dataset.from_generator(
    lambda: gen,
    output_types=(tf.float32, tf.float32),
    output_shapes=([BATCH_SIZE, IMG_DIM, IMG_DIM, 3],
                   [BATCH_SIZE, NB_CLASSES])
)
...
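Why does the lambda work? from_generator only needs a zero-argument callable that returns an iterable, and `lambda: gen` simply closes over the already-configured Keras generator, so nothing has to cross the tensor-serialization boundary. Here is a small pure-Python sketch of the pattern (make_dataset is a hypothetical stand-in for from_generator, not a real API):

```python
# Hypothetical stand-in for tf.data.Dataset.from_generator: all it needs
# is a zero-argument callable returning an iterable.
def make_dataset(generator_fn):
    return iter(generator_fn())

# Stand-in for the already-configured Keras generator.
gen = (batch_idx for batch_idx in range(3))

# The lambda closes over `gen`; its arguments were fixed at creation time,
# so no strings or tuples get converted to tensors along the way.
ds = make_dataset(lambda: gen)
assert next(ds) == 0
assert next(ds) == 1
```

The real Keras generator loops over the data indefinitely, so the wrapped dataset keeps yielding batches epoch after epoch.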
Now let's say your case is binary classification and you're loading your images from JPG files. You have a directory tree structured as follows:
├── data
│ └── trainset
│ ├── Good
│ │ ├── good_0001.jpg
│ │ ├── ...
│ ├── Not_Good
│ │ ├── defect_0001.jpg
│ │ ├── defect_0002.jpg
│ │ ├── defect_0003.jpg
│ │ ├── ...
Then your code would look like this:
import pathlib

import numpy as np
import tensorflow as tf

# Define where your dataset is
data_dir = "data/trainset/"
data_dir = pathlib.Path(data_dir)

# Get the class names
CLASS_NAMES = np.array([item.name for item in data_dir.glob('*')])
print(f"There are 2 classes: {CLASS_NAMES[0]} and {CLASS_NAMES[1]}.")

# Define general parameters
SEED = 42
BATCH_SIZE = 128
IMG_HEIGHT = 640
IMG_WIDTH = 640

# ImageDataGenerator and data augmentations using Keras
train_gen = tf.keras.preprocessing.image.ImageDataGenerator(
    rescale=1./255,
    horizontal_flip=True,
    vertical_flip=True)

# Here we can easily define the parameters
train_data_gen = train_gen.flow_from_directory(
    batch_size=BATCH_SIZE,
    directory=str(data_dir),
    shuffle=True,
    target_size=(IMG_HEIGHT, IMG_WIDTH),
    class_mode='binary'
)

# And then wrap the Keras generator. Note that with class_mode='binary'
# the labels come out as a 1-D vector, so the label shape is
# [BATCH_SIZE], not [BATCH_SIZE, len(CLASS_NAMES)].
train_ds = tf.data.Dataset.from_generator(
    lambda: train_data_gen,
    output_types=(tf.float32, tf.float32),
    output_shapes=([BATCH_SIZE, IMG_HEIGHT, IMG_WIDTH, 3],
                   [BATCH_SIZE]))
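One thing to double-check: the output_shapes you declare must match the class_mode you picked. With class_mode='binary' Keras yields one scalar label per image, while 'categorical' yields one-hot rows. A small NumPy illustration of the two label layouts (the arrays are made up for illustration):

```python
import numpy as np

# class_mode='binary': one scalar label per image -> shape (batch_size,)
binary_labels = np.array([0., 1., 1., 0.])
assert binary_labels.shape == (4,)

# class_mode='categorical': one-hot rows -> shape (batch_size, n_classes)
categorical_labels = np.array([[1., 0.], [0., 1.], [0., 1.], [1., 0.]])
assert categorical_labels.shape == (4, 2)
```

If the declared shapes and the generator's actual output disagree, the mismatch only surfaces at iteration time, which makes it easy to miss.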