Skip to content

Commit

Permalink
docs: add example for Pytorch flexible primitives [DET-3202]
Browse files Browse the repository at this point in the history
  • Loading branch information
shiyuann committed Jun 25, 2020
1 parent e8167a1 commit cc528df
Show file tree
Hide file tree
Showing 3 changed files with 299 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import logging
import os
import shutil
import urllib.parse
from typing import Any, Dict

import requests

from torchvision import datasets, transforms


def get_dataset(data_dir: str, train: bool) -> Any:
    """Return the torchvision MNIST dataset stored under ``data_dir``.

    ``train`` is forwarded to ``datasets.MNIST`` to pick the split. Every
    sample is converted to a tensor and then normalized.
    """
    # 0.1307 and 0.3081 are the precomputed mean and standard deviation of
    # the MNIST data, so normalizing with them yields zero mean and unit
    # standard deviation.
    preprocessing = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ]
    )
    return datasets.MNIST(data_dir, train=train, transform=preprocessing)


def download_dataset(download_directory: str, data_config: Dict[str, Any]) -> str:
    """Download and unpack the archive named by ``data_config["url"]``.

    The archive is saved as ``<download_directory>/MNIST/<file name from URL>``
    and unpacked into that ``MNIST`` directory. If a file with that name is
    already present, nothing is downloaded or unpacked.

    Args:
        download_directory: Directory under which the ``MNIST`` folder is created.
        data_config: Mapping with a ``"url"`` entry pointing at the archive.

    Returns:
        The parent directory of the ``MNIST`` folder.

    Raises:
        ValueError: If no file name can be derived from the URL path.
        requests.HTTPError: If the server responds with an error status.
    """
    url = data_config["url"]
    basename = os.path.basename(urllib.parse.urlparse(url).path)
    if not basename:
        # Without a file name, ``filepath`` would be the directory itself and
        # the download would be skipped silently.
        raise ValueError("Cannot derive an archive file name from URL: {}".format(url))

    download_directory = os.path.join(download_directory, "MNIST")
    os.makedirs(download_directory, exist_ok=True)
    filepath = os.path.join(download_directory, basename)
    if not os.path.exists(filepath):
        logging.info("Downloading {} to {}".format(url, filepath))

        # Stream into a side file first: ``filepath`` doubles as the "already
        # downloaded" marker, so it must never hold a truncated archive.
        partial_path = filepath + ".part"
        try:
            # The timeout keeps a stalled connection from hanging forever.
            with requests.get(url, stream=True, timeout=60) as r:
                # Fail loudly instead of saving an HTTP error page as the archive.
                r.raise_for_status()
                with open(partial_path, "wb") as f:
                    for chunk in r.iter_content(chunk_size=8192):
                        if chunk:
                            f.write(chunk)

            os.replace(partial_path, filepath)
            shutil.unpack_archive(filepath, download_directory)
        except BaseException:
            # Remove leftovers so the next call retries rather than trusting a
            # broken or half-unpacked archive; the error is always re-raised.
            for leftover in (partial_path, filepath):
                if os.path.exists(leftover):
                    os.remove(leftover)
            raise

    return os.path.dirname(download_directory)
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
"""
This example shows how to interact with the Determined PyTorch interface to
train a basic GAN (a generator and a discriminator) on MNIST.
"""

from typing import Any, Dict, Union, Sequence

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import LambdaLR
import torchvision
from determined.tensorboard.metric_writers.pytorch import TorchWriter
from determined.pytorch import PyTorchTrial, PyTorchTrialContext, DataLoader, LRScheduler

import data

# Type of a batch passed to the trial's batch methods: a single tensor, a
# sequence of tensors, or a dict of named tensors.
TorchData = Union[Dict[str, torch.Tensor], Sequence[torch.Tensor], torch.Tensor]


class Generator(nn.Module):
    """Fully connected network mapping latent vectors to images of ``img_shape``."""

    def __init__(self, latent_dim, img_shape):
        super().__init__()
        self.img_shape = img_shape

        # Feature widths of the stacked Linear layers; each consecutive pair
        # becomes Linear -> (BatchNorm1d) -> LeakyReLU.
        widths = [latent_dim, 128, 256, 512, 1024]
        layers = []
        for idx, (fan_in, fan_out) in enumerate(zip(widths, widths[1:])):
            layers.append(nn.Linear(fan_in, fan_out))
            # The first stage is the only one without batch normalization.
            if idx > 0:
                # NOTE(review): the second positional argument of BatchNorm1d
                # is ``eps``, so this sets eps=0.8 - confirm momentum was not
                # what was intended.
                layers.append(nn.BatchNorm1d(fan_out, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))

        # Project to one value per output pixel and squash into [-1, 1].
        layers.append(nn.Linear(1024, int(np.prod(img_shape))))
        layers.append(nn.Tanh())
        self.model = nn.Sequential(*layers)

    def forward(self, z):
        """Generate a batch of images from the latent batch ``z``."""
        flat = self.model(z)
        return flat.view(flat.size(0), *self.img_shape)


class Discriminator(nn.Module):
    """Fully connected network scoring each image with a value in (0, 1)."""

    def __init__(self, img_shape):
        super().__init__()

        # Images are flattened before the first layer, so it takes one input
        # feature per pixel.
        n_pixels = int(np.prod(img_shape))
        stack = [
            nn.Linear(n_pixels, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        ]
        self.model = nn.Sequential(*stack)

    def forward(self, img):
        """Return one validity score per image in the batch ``img``."""
        return self.model(img.view(img.size(0), -1))


class GAN(nn.Module):
    """Container module holding the MNIST generator and discriminator."""

    def __init__(self, trial_context: PyTorchTrialContext) -> None:
        super().__init__()

        self.context = trial_context

        # MNIST images: one channel, 28x28 pixels.
        mnist_shape = (1, 28, 28)
        latent_dim = self.context.get_hparam("latent_dim")
        self.generator = Generator(latent_dim=latent_dim, img_shape=mnist_shape)
        self.discriminator = Discriminator(img_shape=mnist_shape)

    def forward(self, z):
        """Run only the generator on the latent batch ``z``."""
        images = self.generator(z)
        return images


class MNistTrial(PyTorchTrial):
    """Determined PyTorch trial that trains the ``GAN`` module on MNIST.

    Reads the hyperparameters ``latent_dim``, ``lr``, ``b1`` and ``b2`` from the
    trial context, and the dataset location from the context's data config.
    """

    def __init__(self, trial_context: PyTorchTrialContext) -> None:
        self.context = trial_context
        self.logger = TorchWriter()

        # Create a unique download directory for each rank so they don't overwrite each other.
        self.download_directory = f"/tmp/data-rank{self.context.distributed.get_rank()}"
        self.data_downloaded = False

        self.model = self.context._Model(GAN(trial_context))

        lr = self.context.get_hparam("lr")
        b1 = self.context.get_hparam("b1")
        b2 = self.context.get_hparam("b2")

        # One Adam optimizer per sub-network so each can be stepped on its own.
        self.opt_g = self.context._Optimizer(
            torch.optim.Adam(self.model.generator.parameters(), lr=lr, betas=(b1, b2))
        )
        self.opt_d = self.context._Optimizer(
            torch.optim.Adam(self.model.discriminator.parameters(), lr=lr, betas=(b1, b2))
        )

        # Apex AMP configuration, left disabled:
        # self.model, (self.opt_g, self.opt_d) = self.context._configure_apex_amp(
        #     model=self.model,
        #     optimizers=[self.opt_g, self.opt_d],
        #     opt_level="O1",
        # )

        # Only the generator optimizer gets a schedule: its learning rate is
        # multiplied by 0.95 ** epoch.
        self.lr_g = self.context._LRScheduler(
            lr_scheduler=LambdaLR(self.opt_g, lr_lambda=lambda epoch: 0.95 ** epoch),
            step_mode=LRScheduler.StepMode.STEP_EVERY_EPOCH,
        )

    def _ensure_data_downloaded(self) -> None:
        """Download the dataset once and remember where it was placed."""
        if self.data_downloaded:
            return
        self.download_directory = data.download_dataset(
            download_directory=self.download_directory,
            data_config=self.context.get_data_config(),
        )
        self.data_downloaded = True

    def build_training_data_loader(self) -> DataLoader:
        """Return a loader over the MNIST training split."""
        self._ensure_data_downloaded()
        train_data = data.get_dataset(self.download_directory, train=True)
        return DataLoader(train_data, batch_size=self.context.get_per_slot_batch_size())

    def build_validation_data_loader(self) -> DataLoader:
        """Return a loader over the MNIST non-training split."""
        self._ensure_data_downloaded()
        validation_data = data.get_dataset(self.download_directory, train=False)
        return DataLoader(validation_data, batch_size=self.context.get_per_slot_batch_size())

    def train_batch(
        self, batch: TorchData, model: nn.Module, epoch_idx: int, batch_idx: int
    ) -> Dict[str, torch.Tensor]:
        """Run one generator update followed by one discriminator update.

        Args:
            batch: ``(images, labels)`` pair; the labels are ignored.
            model: Unused; the trial works on ``self.model`` directly.
            epoch_idx: Current epoch, used in the tag of the logged image grid.
            batch_idx: Current batch, used as the step of the logged image grid.

        Returns:
            Dict with ``"loss"`` and ``"d_loss"`` (both the discriminator loss)
            and ``"g_loss"`` (the generator loss).
        """
        imgs, _ = batch
        batch_size = imgs.size(0)

        # Target labels, moved to the training device: ones stand for "real",
        # zeros for "generated".
        valid = self.context._to_device(torch.ones(batch_size, 1))
        fake = self.context._to_device(torch.zeros(batch_size, 1))

        # ---- train generator ----
        # Sample noise and move it to the training device.
        z = torch.randn(batch_size, self.context.get_hparam("latent_dim"))
        z = self.context._to_device(z)

        generated_imgs = self.model.generator(z)

        # Log a grid of up to six generated samples through the TorchWriter.
        grid = torchvision.utils.make_grid(generated_imgs[:6])
        self.logger.writer.add_image(f"generated_images_epoch_{epoch_idx}", grid, batch_idx)

        # g_loss is backpropagated through the discriminator. Freeze the
        # discriminator's parameters for this step so they collect no gradients
        # here; opt_g does not own them, and any left behind would be added to
        # the discriminator update below.
        self.model.discriminator.requires_grad_(False)

        # The generator is rewarded when the discriminator labels its samples
        # as real, hence the all-ones target. The loss is binary cross-entropy.
        g_loss = F.binary_cross_entropy(self.model.discriminator(generated_imgs), valid)

        self.context._backward(g_loss)
        self.context._step_optimizer(self.opt_g)

        self.model.discriminator.requires_grad_(True)

        # ---- train discriminator ----
        # Measure the discriminator's ability to tell real from generated samples.

        # How well can it label real images as real?
        real_loss = F.binary_cross_entropy(self.model.discriminator(imgs), valid)

        # How well can it label generated images as fake? detach() keeps this
        # loss from backpropagating into the generator.
        fake_loss = F.binary_cross_entropy(
            self.model.discriminator(generated_imgs.detach()), fake
        )

        # Discriminator loss is the average of the two.
        d_loss = (real_loss + fake_loss) / 2

        self.context._backward(d_loss)
        self.context._step_optimizer(self.opt_d)

        return {
            "loss": d_loss,
            "g_loss": g_loss,
            "d_loss": d_loss,
        }

    def evaluate_batch(self, batch: TorchData, model: nn.Module) -> Dict[str, Any]:
        """Score the discriminator on a batch of real images.

        Returns a dict whose ``"loss"`` is the binary cross-entropy between the
        discriminator's output on the images and an all-ones target.
        """
        imgs, _ = batch
        valid = self.context._to_device(torch.ones(imgs.size(0), 1))
        loss = F.binary_cross_entropy(self.model.discriminator(imgs), valid)
        return {"loss": loss}

Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
"""
This example demonstrates training a simple GAN on MNIST with PyTorch using the
Determined Native API.
"""
import argparse
import json
import pathlib

from determined import experimental
import determined as det

import model_def


if __name__ == "__main__":
    # Command line: an optional JSON config override plus two mode switches.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--config",
        dest="config",
        default="{}",
        help="Specifies Determined Experiment configuration.",
    )
    for flag, description in (
        ("--local", "Specifies local mode"),
        ("--test", "Specifies test mode"),
    ):
        cli.add_argument(flag, action="store_true", help=description)
    args = cli.parse_args()

    # Default experiment configuration.
    default_data = {
        "url": "https://s3-us-west-2.amazonaws.com/determined-ai-test-data/pytorch_mnist.tar.gz"
    }
    default_hyperparameters = {
        "global_batch_size": 32,
        "lr": 0.0002,
        "b1": 0.5,
        "b2": 0.999,
        "latent_dim": 100,
    }
    default_searcher = {
        "name": "single",
        "metric": "loss",
        "max_steps": 40,
        "smaller_is_better": True,
    }
    defaults = {
        "data": default_data,
        "hyperparameters": default_hyperparameters,
        "searcher": default_searcher,
    }

    # Top-level keys supplied via --config replace the matching default
    # section as a whole (shallow merge).
    config = {**defaults, **json.loads(args.config)}

    experimental.create(
        trial_def=model_def.MNistTrial,
        config=config,
        local=args.local,
        test=args.test,
        context_dir=str(pathlib.Path.cwd()),
    )

0 comments on commit cc528df

Please sign in to comment.