This example demonstrates how to build a simple GAN on the MNIST dataset using Determined's TensorFlow Keras API. It is adapted from the TensorFlow DCGAN tutorial.
The DCGAN Keras model featured in this example subclasses tf.keras.Model and defines a custom train_step() and test_step(). This functionality was first added in TensorFlow 2.2.
- dc_gan.py: The code defining the model.
- data.py: The data loading and preparation code for the model.
- model_def.py: Organizes the model into Determined's TensorFlow Keras API.
- export.py: Exports a trained checkpoint and uses it to generate images.
- const.yaml: Train the model with constant hyperparameter values.
- distributed.yaml: Same as const.yaml, but trains on multiple GPUs (distributed training).
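The following is a sketch of the kind of preparation data.py is responsible for. It follows the TensorFlow DCGAN tutorial: it adds a channel axis, scales pixels from [0, 255] to [-1, 1] to match a tanh generator, then shuffles and batches the images. The function name make_mnist_dataset and its defaults are hypothetical. The example's actual data.py may load and split the data differently.

```python
import numpy as np
import tensorflow as tf


def make_mnist_dataset(images: np.ndarray, batch_size: int = 256,
                       buffer_size: int = 60000) -> tf.data.Dataset:
    """Hypothetical helper: raw uint8 images (N, 28, 28) -> shuffled, batched dataset."""
    # Add a channel axis and convert to float32: (N, 28, 28) -> (N, 28, 28, 1).
    images = images.reshape(images.shape[0], 28, 28, 1).astype("float32")
    # Map pixel values from [0, 255] to [-1, 1].
    images = (images - 127.5) / 127.5
    return tf.data.Dataset.from_tensor_slices(images).shuffle(buffer_size).batch(batch_size)


if __name__ == "__main__":
    # Real pixels would come from tf.keras.datasets.mnist.load_data();
    # random values keep this sketch offline.
    raw = np.random.randint(0, 256, size=(10, 28, 28), dtype=np.uint8)
    for batch in make_mnist_dataset(raw, batch_size=4).take(1):
        print(batch.shape)  # (4, 28, 28, 1)
```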
Installation instructions can be found under docs/install-admin.html or on the Determined installation page.
After configuring the settings in const.yaml, run the following command from the example directory (the trailing . is the model definition directory): det -m <master host:port> experiment create -f const.yaml .
Once the model has been trained, its top checkpoint can be exported and used to generate images by running:
python export.py --experiment-id <experiment_id> --master-url <master:port>