#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
This tutorial demonstrates how to build and train a conditional generative adversarial network (cGAN) called pix2pix that learns a mapping from input images to output images, as described in Image-to-image translation with conditional adversarial networks by Isola et al. (2017). pix2pix is not application specific—it can be applied to a wide range of tasks, including synthesizing photos from label maps, generating colorized photos from black and white images, turning Google Maps photos into aerial images, and even transforming sketches into photos.
In this example, your network will generate images of building facades using the CMP Facade Database provided by the Center for Machine Perception at the Czech Technical University in Prague. To keep it short, you will use a preprocessed copy of this dataset created by the pix2pix authors.
In the pix2pix cGAN, you condition on input images and generate corresponding output images. cGANs were first proposed in Conditional Generative Adversarial Nets (Mirza and Osindero, 2014).
The architecture of your network will contain:
- A generator with a U-Net-based architecture.
- A discriminator represented by a convolutional PatchGAN classifier.
Note that each epoch can take around 15 seconds on a single V100 GPU.
Below are some examples of the output generated by the pix2pix cGAN after training for 200 epochs on the facades dataset (80k steps).
import tensorflow as tf
import os
import pathlib
import time
import datetime
from matplotlib import pyplot as plt
from IPython import display
dataset_name = "night2day" #@param ["cityscapes", "edges2handbags", "edges2shoes", "facades", "maps", "night2day"]
# _URL = f'http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/{dataset_name}.tar.gz'
# path_to_zip = tf.keras.utils.get_file(
# fname=f"{dataset_name}.tar.gz",
# origin=_URL,
# extract=True)
# path_to_zip = pathlib.Path(path_to_zip)
# PATH = path_to_zip.parent/dataset_name
# list(PATH.parent.iterdir())
Each original image is of size 256 x 512, containing two 256 x 256 images:
sample_image = tf.io.read_file(str('./real-paired/1-0.1.jpg'))
sample_image = tf.io.decode_jpeg(sample_image)
print(sample_image.shape)
(256, 512, 3)
plt.figure()
plt.imshow(sample_image)
<matplotlib.image.AxesImage at 0x7fce528d3460>
You need to separate real building facade images from the architecture label images—all of which will be of size 256 x 256.
Define a function that loads image files and outputs two image tensors:
def load(image_file):
  # Read and decode an image file to a uint8 tensor
  image = tf.io.read_file(image_file)
  image = tf.io.decode_jpeg(image)

  # Split each image tensor into two tensors:
  # - one with a real building facade image
  # - one with an architecture label image
  w = tf.shape(image)[1]
  w = w // 2
  real_image = image[:, w:, :]
  input_image = image[:, :w, :]

  # Convert both images to float32 tensors
  input_image = tf.cast(input_image, tf.float32)
  real_image = tf.cast(real_image, tf.float32)

  return input_image, real_image
Plot a sample of the input (architecture label image) and real (building facade photo) images:
inp, re = load(str('./real-paired/1-0.1.jpg'))
# Casting to int for matplotlib to display the images
plt.figure()
plt.imshow(inp/255)
plt.figure()
plt.imshow(re/255)
<matplotlib.image.AxesImage at 0x7fce52a7ed60>
As described in the pix2pix paper, you need to apply random jittering and mirroring to preprocess the training set.
Define several functions that:
1. Resize each 256 x 256 image to a larger height and width—286 x 286.
2. Randomly crop it back to 256 x 256.
3. Randomly flip the image horizontally, i.e., left to right (random mirroring).
4. Normalize the images to the [-1, 1] range.

# The facade training set consists of 400 images
BUFFER_SIZE = 400
# The batch size of 1 produced better results for the U-Net in the original pix2pix experiment
BATCH_SIZE = 1
# Each image is 256x256 in size
IMG_WIDTH = 256
IMG_HEIGHT = 256
def resize(input_image, real_image, height, width):
  input_image = tf.image.resize(input_image, [height, width],
                                method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
  real_image = tf.image.resize(real_image, [height, width],
                               method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

  return input_image, real_image
def random_crop(input_image, real_image):
  stacked_image = tf.stack([input_image, real_image], axis=0)
  cropped_image = tf.image.random_crop(
      stacked_image, size=[2, IMG_HEIGHT, IMG_WIDTH, 3])

  return cropped_image[0], cropped_image[1]
# Normalizing the images to [-1, 1]
def normalize(input_image, real_image):
  input_image = (input_image / 127.5) - 1
  real_image = (real_image / 127.5) - 1

  return input_image, real_image
@tf.function()
def random_jitter(input_image, real_image):
  # Resizing to 286x286
  input_image, real_image = resize(input_image, real_image, 286, 286)

  # Random cropping back to 256x256
  input_image, real_image = random_crop(input_image, real_image)

  if tf.random.uniform(()) > 0.5:
    # Random mirroring
    input_image = tf.image.flip_left_right(input_image)
    real_image = tf.image.flip_left_right(real_image)

  return input_image, real_image
You can inspect some of the preprocessed output:
plt.figure(figsize=(6, 6))
for i in range(4):
  rj_inp, rj_re = random_jitter(inp, re)
  plt.subplot(2, 2, i + 1)
  plt.imshow(rj_inp / 255.0)
  plt.axis('off')
plt.show()
Having checked that the loading and preprocessing works, let's define a couple of helper functions that load and preprocess the training and test sets:
def load_image_train(image_file):
  input_image, real_image = load(image_file)
  input_image, real_image = random_jitter(input_image, real_image)
  input_image, real_image = normalize(input_image, real_image)

  return input_image, real_image
def load_image_test(image_file):
  input_image, real_image = load(image_file)
  input_image, real_image = resize(input_image, real_image,
                                   IMG_HEIGHT, IMG_WIDTH)
  input_image, real_image = normalize(input_image, real_image)

  return input_image, real_image
Build an input pipeline with tf.data:

train_dataset = tf.data.Dataset.list_files(str('./real-paired/*.jpg'))
train_dataset = train_dataset.map(load_image_train,
num_parallel_calls=tf.data.AUTOTUNE)
train_dataset = train_dataset.shuffle(BUFFER_SIZE)
train_dataset = train_dataset.batch(BATCH_SIZE)
test_dataset = tf.data.Dataset.list_files(str('./test/*.jpg'))
test_dataset = test_dataset.map(load_image_test)
test_dataset = test_dataset.batch(BATCH_SIZE)
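As an optional sanity check (not part of the original notebook), you can peek at one batch and confirm that the pipeline yields pairs of shape (BATCH_SIZE, IMG_HEIGHT, IMG_WIDTH, 3):

# Optional: print the shapes of one (input, target) batch from the pipeline.
for example_input, example_real in train_dataset.take(1):
  print(example_input.shape, example_real.shape)  # expected: (1, 256, 256, 3) each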
The generator of your pix2pix cGAN is a modified U-Net. A U-Net consists of an encoder (downsampler) and decoder (upsampler). (You can find out more about it in the Image segmentation tutorial and on the U-Net project website.)
Define the downsampler (encoder):
OUTPUT_CHANNELS = 3
def downsample(filters, size, apply_batchnorm=True):
  initializer = tf.random_normal_initializer(0., 0.02)

  result = tf.keras.Sequential()
  result.add(
      tf.keras.layers.Conv2D(filters, size, strides=2, padding='same',
                             kernel_initializer=initializer, use_bias=False))

  if apply_batchnorm:
    result.add(tf.keras.layers.BatchNormalization())

  result.add(tf.keras.layers.LeakyReLU())

  return result
down_model = downsample(3, 4)
down_result = down_model(tf.expand_dims(inp, 0))
print (down_result.shape)
(1, 128, 128, 3)
Define the upsampler (decoder):
def upsample(filters, size, apply_dropout=False):
  initializer = tf.random_normal_initializer(0., 0.02)

  result = tf.keras.Sequential()
  result.add(
      tf.keras.layers.Conv2DTranspose(filters, size, strides=2,
                                      padding='same',
                                      kernel_initializer=initializer,
                                      use_bias=False))

  result.add(tf.keras.layers.BatchNormalization())

  if apply_dropout:
    result.add(tf.keras.layers.Dropout(0.5))

  result.add(tf.keras.layers.ReLU())

  return result
up_model = upsample(3, 4)
up_result = up_model(down_result)
print (up_result.shape)
(1, 256, 256, 3)
Define the generator with the downsampler and the upsampler:
def Generator():
  inputs = tf.keras.layers.Input(shape=[256, 256, 3])

  down_stack = [
    downsample(64, 4, apply_batchnorm=False),  # (batch_size, 128, 128, 64)
    downsample(128, 4),  # (batch_size, 64, 64, 128)
    downsample(256, 4),  # (batch_size, 32, 32, 256)
    downsample(512, 4),  # (batch_size, 16, 16, 512)
    downsample(512, 4),  # (batch_size, 8, 8, 512)
    downsample(512, 4),  # (batch_size, 4, 4, 512)
    downsample(512, 4),  # (batch_size, 2, 2, 512)
    downsample(512, 4),  # (batch_size, 1, 1, 512)
  ]

  up_stack = [
    upsample(512, 4),  # (batch_size, 2, 2, 1024)
    upsample(512, 4),  # (batch_size, 4, 4, 1024)
    upsample(512, 4),  # (batch_size, 8, 8, 1024)
    upsample(512, 4),  # (batch_size, 16, 16, 1024)
    upsample(256, 4),  # (batch_size, 32, 32, 512)
    upsample(128, 4),  # (batch_size, 64, 64, 256)
    upsample(64, 4),  # (batch_size, 128, 128, 128)
  ]

  # Alternative, smaller stacks (commented out):
  # down_stack = [
  #   downsample(32, 4, apply_batchnorm=False),
  #   downsample(64, 4),
  #   downsample(128, 4),
  #   downsample(128, 4) ]
  # up_stack = [
  #   upsample(128, 4),
  #   upsample(64, 4),
  #   upsample(32, 4),
  # ]

  initializer = tf.random_normal_initializer(0., 0.02)
  last = tf.keras.layers.Conv2DTranspose(OUTPUT_CHANNELS, 4,
                                         strides=2,
                                         padding='same',
                                         kernel_initializer=initializer,
                                         activation='tanh')  # (batch_size, 256, 256, 3)

  x = inputs

  # Downsampling through the model
  skips = []
  for down in down_stack:
    x = down(x)
    skips.append(x)

  skips = reversed(skips[:-1])

  # Upsampling and establishing the skip connections
  for up, skip in zip(up_stack, skips):
    x = up(x)
    x = tf.keras.layers.Concatenate()([x, skip])

  x = last(x)

  return tf.keras.Model(inputs=inputs, outputs=x)
Visualize the generator model architecture:
generator = Generator()
tf.keras.utils.plot_model(generator, show_shapes=True, dpi=64)
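Note that tf.keras.utils.plot_model relies on the optional pydot and Graphviz packages. If they aren't installed in your environment, generator.summary() gives a text-only view of the same architecture:

# Text-only alternative to plot_model (no pydot/Graphviz required).
generator.summary()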
Test the generator:
gen_output = generator(inp[tf.newaxis, ...], training=False)
plt.imshow(gen_output[0, ...])
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
<matplotlib.image.AxesImage at 0x7fce16a1f7c0>
GANs learn a loss that adapts to the data, while cGANs learn a structured loss that penalizes a possible structure that differs from the network output and the target image, as described in the pix2pix paper.
The generator loss is a sigmoid cross-entropy loss of the generated images and an array of ones. The pix2pix paper also adds an L1 loss, the MAE (mean absolute error) between the generated image and the target image, which encourages the generated image to become structurally similar to the target. The formula for the total generator loss is gan_loss + LAMBDA * l1_loss, where LAMBDA = 100. This value was decided by the authors of the paper.

LAMBDA = 100
loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)
def generator_loss(disc_generated_output, gen_output, target):
  gan_loss = loss_object(tf.ones_like(disc_generated_output), disc_generated_output)

  # Mean absolute error
  l1_loss = tf.reduce_mean(tf.abs(target - gen_output))

  total_gen_loss = gan_loss + (LAMBDA * l1_loss)

  return total_gen_loss, gan_loss, l1_loss
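As a quick check of this definition (a hypothetical example, not part of the original notebook): all-zero discriminator logits correspond to a probability of 0.5, so the GAN term equals log(2) ≈ 0.69, and identical generated and target images give an L1 term of 0:

# Hypothetical sanity check of generator_loss:
# - zero logits -> sigmoid cross-entropy of log(2) ≈ 0.693 against the "ones" target
# - identical generated and target images -> L1 term of 0
dummy_disc_output = tf.zeros([1, 30, 30, 1])
dummy_image = tf.zeros([1, 256, 256, 3])
total, gan, l1 = generator_loss(dummy_disc_output, dummy_image, dummy_image)
print(total.numpy(), gan.numpy(), l1.numpy())  # ≈ 0.693, 0.693, 0.0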
The training procedure for the generator is as follows:
The discriminator in the pix2pix cGAN is a convolutional PatchGAN classifier—it tries to classify if each image patch is real or not real, as described in the pix2pix paper.
- The shape of the output after the last layer is (batch_size, 30, 30, 1).
- Each 30 x 30 image patch of the output classifies a 70 x 70 portion of the input image.
- The discriminator receives two inputs: the input image and the target image, which it should classify as real, and the input image and the generated image, which it should classify as fake. Use tf.concat([inp, tar], axis=-1) to concatenate these two inputs together.

Let's define the discriminator:
def Discriminator():
  initializer = tf.random_normal_initializer(0., 0.02)

  inp = tf.keras.layers.Input(shape=[256, 256, 3], name='input_image')
  tar = tf.keras.layers.Input(shape=[256, 256, 3], name='target_image')

  x = tf.keras.layers.concatenate([inp, tar])  # (batch_size, 256, 256, channels*2)

  down1 = downsample(64, 4, False)(x)  # (batch_size, 128, 128, 64)
  down2 = downsample(128, 4)(down1)  # (batch_size, 64, 64, 128)
  down3 = downsample(256, 4)(down2)  # (batch_size, 32, 32, 256)

  zero_pad1 = tf.keras.layers.ZeroPadding2D()(down3)  # (batch_size, 34, 34, 256)
  conv = tf.keras.layers.Conv2D(512, 4, strides=1,
                                kernel_initializer=initializer,
                                use_bias=False)(zero_pad1)  # (batch_size, 31, 31, 512)

  batchnorm1 = tf.keras.layers.BatchNormalization()(conv)

  leaky_relu = tf.keras.layers.LeakyReLU()(batchnorm1)

  zero_pad2 = tf.keras.layers.ZeroPadding2D()(leaky_relu)  # (batch_size, 33, 33, 512)

  last = tf.keras.layers.Conv2D(1, 4, strides=1,
                                kernel_initializer=initializer)(zero_pad2)  # (batch_size, 30, 30, 1)

  return tf.keras.Model(inputs=[inp, tar], outputs=last)
Visualize the discriminator model architecture:
discriminator = Discriminator()
tf.keras.utils.plot_model(discriminator, show_shapes=True, dpi=64)
Test the discriminator:
disc_out = discriminator([inp[tf.newaxis, ...], gen_output], training=False)
plt.imshow(disc_out[0, ..., -1], vmin=-20, vmax=20, cmap='RdBu_r')
plt.colorbar()
<matplotlib.colorbar.Colorbar at 0x7fcddb93fa30>
The discriminator_loss function takes 2 inputs: real images and generated images.
- real_loss is a sigmoid cross-entropy loss of the real images and an array of ones (since these are the real images).
- generated_loss is a sigmoid cross-entropy loss of the generated images and an array of zeros (since these are the fake images).
- total_loss is the sum of real_loss and generated_loss.

def discriminator_loss(disc_real_output, disc_generated_output):
  real_loss = loss_object(tf.ones_like(disc_real_output), disc_real_output)

  generated_loss = loss_object(tf.zeros_like(disc_generated_output), disc_generated_output)

  total_disc_loss = real_loss + generated_loss

  return total_disc_loss
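To connect this with the log(2) ≈ 0.69 reference point discussed later: a maximally uncertain discriminator (all-zero logits) contributes about 0.69 for each term, so roughly 1.39 in total. This is a hypothetical check, not part of the original notebook:

# Hypothetical sanity check: all-zero logits mean the discriminator assigns
# probability 0.5 everywhere, so each term is log(2) ≈ 0.693 and the sum ≈ 1.386.
dummy_logits = tf.zeros([1, 30, 30, 1])
print(discriminator_loss(dummy_logits, dummy_logits).numpy())  # ≈ 1.386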
The training procedure for the discriminator is shown below.
To learn more about the architecture and the hyperparameters you can refer to the pix2pix paper.
generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
discriminator_optimizer=discriminator_optimizer,
generator=generator,
discriminator=discriminator)
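As an optional variation (an assumption about your workflow, not used by the rest of this notebook), a tf.train.CheckpointManager can limit how many checkpoint files accumulate on disk:

# Optional: a CheckpointManager keeps only the most recent checkpoints instead
# of accumulating every ckpt-* file written by checkpoint.save().
manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir, max_to_keep=3)
# Inside the training loop you would then call manager.save() instead of
# checkpoint.save(file_prefix=checkpoint_prefix).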
Write a function to plot some images during training.
Note: The training=True is intentional here since you want the batch statistics while running the model on the test dataset. If you use training=False, you get the accumulated statistics learned from the training dataset (which you don't want).
def generate_images(model, test_input, tar):
  prediction = model(test_input, training=True)
  plt.figure(figsize=(15, 15))

  display_list = [test_input[0], tar[0], prediction[0]]
  title = ['Input Image', 'Ground Truth', 'Predicted Image']

  for i in range(3):
    plt.subplot(1, 3, i+1)
    plt.title(title[i])
    # Getting the pixel values in the [0, 1] range to plot.
    plt.imshow(display_list[i] * 0.5 + 0.5)
    plt.axis('off')
  plt.show()
Test the function:
for example_input, example_target in test_dataset.take(1):
  generate_images(generator, example_input, example_target)
The training procedure works as follows:
- For each example input, the generator produces an output image.
- The discriminator receives the input_image and the generated image as the first input. The second input is the input_image and the target_image.
- Next, calculate the generator and the discriminator loss.
- Then, calculate the gradients of the loss with respect to both the generator and the discriminator variables (inputs) and apply those to the optimizers.
- Finally, log the losses to TensorBoard.

log_dir="logs/"
summary_writer = tf.summary.create_file_writer(
log_dir + "fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
@tf.function
def train_step(input_image, target, step):
  with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
    gen_output = generator(input_image, training=True)

    disc_real_output = discriminator([input_image, target], training=True)
    disc_generated_output = discriminator([input_image, gen_output], training=True)

    gen_total_loss, gen_gan_loss, gen_l1_loss = generator_loss(disc_generated_output, gen_output, target)
    disc_loss = discriminator_loss(disc_real_output, disc_generated_output)

  generator_gradients = gen_tape.gradient(gen_total_loss,
                                          generator.trainable_variables)
  discriminator_gradients = disc_tape.gradient(disc_loss,
                                               discriminator.trainable_variables)

  generator_optimizer.apply_gradients(zip(generator_gradients,
                                          generator.trainable_variables))
  discriminator_optimizer.apply_gradients(zip(discriminator_gradients,
                                              discriminator.trainable_variables))

  with summary_writer.as_default():
    tf.summary.scalar('gen_total_loss', gen_total_loss, step=step//1000)
    tf.summary.scalar('gen_gan_loss', gen_gan_loss, step=step//1000)
    tf.summary.scalar('gen_l1_loss', gen_l1_loss, step=step//1000)
    tf.summary.scalar('disc_loss', disc_loss, step=step//1000)
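Before launching a long run, you can optionally execute a single step on one batch to confirm everything is wired up (shapes, losses, gradients, summaries). This smoke test is an addition, not part of the original notebook:

# Optional smoke test: one training step on one batch. The step argument is only
# used as the TensorBoard summary step, so a constant int64 zero is fine here.
for smoke_input, smoke_target in train_dataset.take(1):
  train_step(smoke_input, smoke_target, tf.constant(0, dtype=tf.int64))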
The actual training loop. Since this tutorial can run on more than one dataset, and the datasets vary greatly in size, the training loop is set up to work in steps instead of epochs:
- It iterates over the number of steps.
- Every 10 steps it prints a dot (.).
- Every 1k steps it clears the display and runs generate_images to show the progress.
- Every 5k steps it saves a checkpoint.

def fit(train_ds, test_ds, steps):
  example_input, example_target = next(iter(test_ds.take(1)))
  start = time.time()

  for step, (input_image, target) in train_ds.repeat().take(steps).enumerate():
    if (step) % 1000 == 0:
      display.clear_output(wait=True)

      if step != 0:
        print(f'Time taken for 1000 steps: {time.time()-start:.2f} sec\n')

      start = time.time()

      generate_images(generator, example_input, example_target)
      print(f"Step: {step//1000}k")

    # Training step
    train_step(input_image, target, step)

    if (step+1) % 10 == 0:
      print('.', end='', flush=True)

    # Save (checkpoint) the model every 5k steps
    if (step + 1) % 5000 == 0:
      checkpoint.save(file_prefix=checkpoint_prefix)
This training loop saves logs that you can view in TensorBoard to monitor the training progress.
If you work on a local machine, you would launch a separate TensorBoard process. When working in a notebook, launch the viewer before starting the training to monitor with TensorBoard.
To launch the viewer paste the following into a code-cell:
%load_ext tensorboard
%tensorboard --logdir {log_dir}
The tensorboard extension is already loaded. To reload it, use: %reload_ext tensorboard
Reusing TensorBoard on port 6006 (pid 31815), started 0:49:30 ago. (Use '!kill 31815' to kill it.)
Finally, run the training loop:
fit(train_dataset, test_dataset, steps=300000)
Time taken for 1000 steps: 41.01 sec
Step: 299k ....................................................................................................
If you want to share the TensorBoard results publicly, you can upload the logs to TensorBoard.dev by copying the following into a code-cell.
Note: This requires a Google account.
!tensorboard dev upload --logdir {log_dir}
Caution: This command does not terminate. It's designed to continuously upload the results of long-running experiments. Once your data is uploaded you need to stop it using the "interrupt execution" option in your notebook tool.
You can view the results of a previous run of this notebook on TensorBoard.dev.
TensorBoard.dev is a managed experience for hosting, tracking, and sharing ML experiments with everyone.
It can also be included inline using an <iframe>:
display.IFrame(
src="https://tensorboard.dev/experiment/lZ0C6FONROaUMfjYkVyJqw",
width="100%",
height="1000px")
Interpreting the logs is more subtle when training a GAN (or a cGAN like pix2pix) compared to a simple classification or regression model. Things to look for:
- Check that neither the generator nor the discriminator model has "won". If either the gen_gan_loss or the disc_loss gets very low, it's an indicator that this model is dominating the other, and you are not successfully training the combined model.
- The value log(2) = 0.69 is a good reference point for these losses, as it indicates a perplexity of 2: the discriminator is, on average, equally uncertain about the two options.
- For the disc_loss, a value below 0.69 means the discriminator is doing better than random on the combined set of real and generated images.
- For the gen_gan_loss, a value below 0.69 means the generator is doing better than random at fooling the discriminator.
- As training progresses, the gen_l1_loss should go down.

!ls {checkpoint_dir}
checkpoint ckpt-40.data-00000-of-00001 ckpt-1.data-00000-of-00001 ckpt-40.index ckpt-1.index ckpt-41.data-00000-of-00001 ckpt-10.data-00000-of-00001 ckpt-41.index ckpt-10.index ckpt-42.data-00000-of-00001 ckpt-11.data-00000-of-00001 ckpt-42.index ckpt-11.index ckpt-43.data-00000-of-00001 ckpt-12.data-00000-of-00001 ckpt-43.index ckpt-12.index ckpt-44.data-00000-of-00001 ckpt-13.data-00000-of-00001 ckpt-44.index ckpt-13.index ckpt-45.data-00000-of-00001 ckpt-14.data-00000-of-00001 ckpt-45.index ckpt-14.index ckpt-46.data-00000-of-00001 ckpt-15.data-00000-of-00001 ckpt-46.index ckpt-15.index ckpt-47.data-00000-of-00001 ckpt-16.data-00000-of-00001 ckpt-47.index ckpt-16.index ckpt-48.data-00000-of-00001 ckpt-17.data-00000-of-00001 ckpt-48.index ckpt-17.index ckpt-49.data-00000-of-00001 ckpt-18.data-00000-of-00001 ckpt-49.index ckpt-18.index ckpt-5.data-00000-of-00001 ckpt-19.data-00000-of-00001 ckpt-5.index ckpt-19.index ckpt-50.data-00000-of-00001 ckpt-2.data-00000-of-00001 ckpt-50.index ckpt-2.index ckpt-51.data-00000-of-00001 ckpt-20.data-00000-of-00001 ckpt-51.index ckpt-20.index ckpt-52.data-00000-of-00001 ckpt-21.data-00000-of-00001 ckpt-52.index ckpt-21.index ckpt-53.data-00000-of-00001 ckpt-22.data-00000-of-00001 ckpt-53.index ckpt-22.index ckpt-54.data-00000-of-00001 ckpt-23.data-00000-of-00001 ckpt-54.index ckpt-23.index ckpt-55.data-00000-of-00001 ckpt-24.data-00000-of-00001 ckpt-55.index ckpt-24.index ckpt-56.data-00000-of-00001 ckpt-25.data-00000-of-00001 ckpt-56.index ckpt-25.index ckpt-57.data-00000-of-00001 ckpt-26.data-00000-of-00001 ckpt-57.index ckpt-26.index ckpt-58.data-00000-of-00001 ckpt-27.data-00000-of-00001 ckpt-58.index ckpt-27.index ckpt-59.data-00000-of-00001 ckpt-28.data-00000-of-00001 ckpt-59.index ckpt-28.index ckpt-6.data-00000-of-00001 ckpt-29.data-00000-of-00001 ckpt-6.index ckpt-29.index ckpt-60.data-00000-of-00001 ckpt-3.data-00000-of-00001 ckpt-60.index ckpt-3.index ckpt-61.data-00000-of-00001 ckpt-30.data-00000-of-00001 ckpt-61.index ckpt-30.index ckpt-62.data-00000-of-00001 ckpt-31.data-00000-of-00001 ckpt-62.index ckpt-31.index ckpt-63.data-00000-of-00001 ckpt-32.data-00000-of-00001 ckpt-63.index ckpt-32.index ckpt-64.data-00000-of-00001 ckpt-33.data-00000-of-00001 ckpt-64.index ckpt-33.index ckpt-65.data-00000-of-00001 ckpt-34.data-00000-of-00001 ckpt-65.index ckpt-34.index ckpt-66.data-00000-of-00001 ckpt-35.data-00000-of-00001 ckpt-66.index ckpt-35.index ckpt-67.data-00000-of-00001 ckpt-36.data-00000-of-00001 ckpt-67.index ckpt-36.index ckpt-68.data-00000-of-00001 ckpt-37.data-00000-of-00001 ckpt-68.index ckpt-37.index ckpt-7.data-00000-of-00001 ckpt-38.data-00000-of-00001 ckpt-7.index ckpt-38.index ckpt-8.data-00000-of-00001 ckpt-39.data-00000-of-00001 ckpt-8.index ckpt-39.index ckpt-9.data-00000-of-00001 ckpt-4.data-00000-of-00001 ckpt-9.index ckpt-4.index
# Restoring the latest checkpoint in checkpoint_dir
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7fcd8f4ac5e0>
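If you only need the generator for inference, a sketch of an alternative restore (an assumption, not part of the original notebook) is to track just the generator and call expect_partial() so that the discriminator and optimizer variables left in the checkpoint don't trigger warnings:

# Hypothetical inference-only restore: track just the generator and ignore the
# other objects stored in the checkpoint.
inference_ckpt = tf.train.Checkpoint(generator=generator)
inference_ckpt.restore(tf.train.latest_checkpoint(checkpoint_dir)).expect_partial()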
# Run the trained model on a few examples from the test set
for inp, tar in test_dataset.take(5):
  generate_images(generator, inp, tar)
checkpoint_test = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
discriminator_optimizer=discriminator_optimizer,
generator=generator,
discriminator=discriminator)
# Restoring the latest checkpoint in checkpoint_dir
checkpoint_test.restore(tf.train.latest_checkpoint(checkpoint_dir))
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7fcd8f21a940>
# Run the trained model on a few examples from the test set
for inp, tar in test_dataset.take(5):
  generate_images(generator, inp, tar)