-
Notifications
You must be signed in to change notification settings - Fork 356
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
docs: add example for Pytorch flexible primitives [DET-3202]
- Loading branch information
Showing
3 changed files
with
299 additions
and
0 deletions.
There are no files selected for viewing
47 changes: 47 additions & 0 deletions
47
examples/experimental/native/flexible_primitives_mnist_pytorch/data.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
import logging | ||
import os | ||
import shutil | ||
import urllib.parse | ||
from typing import Any, Dict | ||
|
||
import requests | ||
|
||
from torchvision import datasets, transforms | ||
|
||
|
||
def get_dataset(data_dir: str, train: bool) -> Any: | ||
return datasets.MNIST( | ||
data_dir, | ||
train=train, | ||
transform=transforms.Compose( | ||
[ | ||
transforms.ToTensor(), | ||
# These are the precomputed mean and standard deviation of the | ||
# MNIST data; this normalizes the data to have zero mean and unit | ||
# standard deviation. | ||
transforms.Normalize((0.1307,), (0.3081,)), | ||
] | ||
), | ||
) | ||
|
||
|
||
def download_dataset(download_directory: str, data_config: Dict[str, Any]) -> str: | ||
url = data_config["url"] | ||
url_path = urllib.parse.urlparse(url).path | ||
basename = url_path.rsplit("/", 1)[1] | ||
|
||
download_directory = os.path.join(download_directory, "MNIST") | ||
os.makedirs(download_directory, exist_ok=True) | ||
filepath = os.path.join(download_directory, basename) | ||
if not os.path.exists(filepath): | ||
logging.info("Downloading {} to {}".format(url, filepath)) | ||
|
||
r = requests.get(url, stream=True) | ||
with open(filepath, "wb") as f: | ||
for chunk in r.iter_content(chunk_size=8192): | ||
if chunk: | ||
f.write(chunk) | ||
|
||
shutil.unpack_archive(filepath, download_directory) | ||
|
||
return os.path.dirname(download_directory) |
199 changes: 199 additions & 0 deletions
199
examples/experimental/native/flexible_primitives_mnist_pytorch/model_def.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,199 @@ | ||
""" | ||
This example shows how to interact with the Determined PyTorch interface to | ||
build a basic MNIST network. | ||
""" | ||
|
||
from typing import Any, Dict, Union, Sequence | ||
|
||
import numpy as np | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
from torch.optim.lr_scheduler import LambdaLR | ||
import torchvision | ||
from determined.tensorboard.metric_writers.pytorch import TorchWriter | ||
from determined.pytorch import PyTorchTrial, PyTorchTrialContext, DataLoader, LRScheduler | ||
|
||
import data | ||
|
||
TorchData = Union[Dict[str, torch.Tensor], Sequence[torch.Tensor], torch.Tensor] | ||
|
||
|
||
class Generator(nn.Module): | ||
def __init__(self, latent_dim, img_shape): | ||
super(Generator, self).__init__() | ||
self.img_shape = img_shape | ||
|
||
def block(in_feat, out_feat, normalize=True): | ||
layers = [nn.Linear(in_feat, out_feat)] | ||
if normalize: | ||
layers.append(nn.BatchNorm1d(out_feat, 0.8)) | ||
layers.append(nn.LeakyReLU(0.2, inplace=True)) | ||
return layers | ||
|
||
self.model = nn.Sequential( | ||
*block(latent_dim, 128, normalize=False), | ||
*block(128, 256), | ||
*block(256, 512), | ||
*block(512, 1024), | ||
nn.Linear(1024, int(np.prod(img_shape))), | ||
nn.Tanh() | ||
) | ||
|
||
def forward(self, z): | ||
img = self.model(z) | ||
img = img.view(img.size(0), *self.img_shape) | ||
return img | ||
|
||
|
||
class Discriminator(nn.Module): | ||
def __init__(self, img_shape): | ||
super(Discriminator, self).__init__() | ||
|
||
self.model = nn.Sequential( | ||
nn.Linear(int(np.prod(img_shape)), 512), | ||
nn.LeakyReLU(0.2, inplace=True), | ||
nn.Linear(512, 256), | ||
nn.LeakyReLU(0.2, inplace=True), | ||
nn.Linear(256, 1), | ||
nn.Sigmoid(), | ||
) | ||
|
||
def forward(self, img): | ||
img_flat = img.view(img.size(0), -1) | ||
validity = self.model(img_flat) | ||
|
||
return validity | ||
|
||
|
||
class GAN(nn.Module): | ||
def __init__(self, trial_context: PyTorchTrialContext) -> None: | ||
super(GAN, self).__init__() | ||
|
||
self.context = trial_context | ||
|
||
mnist_shape = (1, 28, 28) | ||
self.generator = Generator(latent_dim=self.context.get_hparam("latent_dim"), img_shape=mnist_shape) | ||
self.discriminator = Discriminator(img_shape=mnist_shape) | ||
|
||
def forward(self, z): | ||
return self.generator(z) | ||
|
||
|
||
class MNistTrial(PyTorchTrial): | ||
def __init__(self, trial_context: PyTorchTrialContext) -> None: | ||
self.context = trial_context | ||
self.logger = TorchWriter() | ||
|
||
# Create a unique download directory for each rank so they don't overwrite each other. | ||
self.download_directory = f"/tmp/data-rank{self.context.distributed.get_rank()}" | ||
self.data_downloaded = False | ||
|
||
self.model = self.context._Model(GAN(trial_context)) | ||
|
||
lr = self.context.get_hparam("lr") | ||
b1 = self.context.get_hparam("b1") | ||
b2 = self.context.get_hparam("b2") | ||
|
||
self.opt_g = self.context._Optimizer(torch.optim.Adam(self.model.generator.parameters(), lr=lr, betas=(b1, b2))) | ||
self.opt_d = self.context._Optimizer(torch.optim.Adam(self.model.discriminator.parameters(), lr=lr, betas=(b1, b2))) | ||
|
||
# self.model, (self.opt_g, self.opt_d) = self.context._configure_apex_amp( | ||
# model=self.model, | ||
# optimizers=[self.opt_g, self.opt_d], | ||
# opt_level="O1", | ||
# ) | ||
|
||
self.lr_g = self.context._LRScheduler( | ||
lr_scheduler=LambdaLR(self.opt_g, lr_lambda=lambda epoch: 0.95 ** epoch), | ||
step_mode=LRScheduler.StepMode.STEP_EVERY_EPOCH, | ||
) | ||
|
||
def build_training_data_loader(self) -> DataLoader: | ||
if not self.data_downloaded: | ||
self.download_directory = data.download_dataset( | ||
download_directory=self.download_directory, | ||
data_config=self.context.get_data_config(), | ||
) | ||
self.data_downloaded = True | ||
|
||
train_data = data.get_dataset(self.download_directory, train=True) | ||
return DataLoader(train_data, batch_size=self.context.get_per_slot_batch_size()) | ||
|
||
def build_validation_data_loader(self) -> DataLoader: | ||
if not self.data_downloaded: | ||
self.download_directory = data.download_dataset( | ||
download_directory=self.download_directory, | ||
data_config=self.context.get_data_config(), | ||
) | ||
self.data_downloaded = True | ||
|
||
validation_data = data.get_dataset(self.download_directory, train=False) | ||
return DataLoader(validation_data, batch_size=self.context.get_per_slot_batch_size()) | ||
|
||
def train_batch( | ||
self, batch: TorchData, model: nn.Module, epoch_idx: int, batch_idx: int | ||
) -> Dict[str, torch.Tensor]: | ||
imgs, _ = batch | ||
|
||
# train generator | ||
# sample noise and match gpu device (or keep as cpu) | ||
z = torch.randn(imgs.shape[0], self.context.get_hparam("latent_dim")) | ||
z = self.context._to_device(z) | ||
|
||
# generate images | ||
generated_imgs = self.model.generator(z) | ||
|
||
# log sampled images to current directory | ||
sample_imgs = generated_imgs[:6] | ||
grid = torchvision.utils.make_grid(sample_imgs) | ||
self.logger.writer.add_image(f'generated_images_epoch_{epoch_idx}', grid, batch_idx) | ||
|
||
# ground truth result (ie: all fake) | ||
# put on GPU because we created this tensor inside training_loop | ||
valid = torch.ones(imgs.size(0), 1) | ||
valid = self.context._to_device(valid) | ||
|
||
# adversarial loss is binary cross-entropy | ||
g_loss = F.binary_cross_entropy(self.model.discriminator(generated_imgs), valid) | ||
|
||
self.context._backward(g_loss) | ||
self.context._step_optimizer(self.opt_g) | ||
|
||
|
||
# train discriminator | ||
# Measure discriminator's ability to classify real from generated samples | ||
|
||
# how well can it label as real? | ||
valid = torch.ones(imgs.size(0), 1) | ||
valid = self.context._to_device(valid) | ||
|
||
real_loss = F.binary_cross_entropy(self.model.discriminator(imgs), valid) | ||
|
||
# how well can it label as fake? | ||
fake = torch.zeros(imgs.size(0), 1) | ||
fake = self.context._to_device(fake) | ||
|
||
fake_loss = F.binary_cross_entropy( | ||
self.model.discriminator(generated_imgs.detach()), fake) | ||
|
||
# discriminator loss is the average of these | ||
d_loss = (real_loss + fake_loss) / 2 | ||
|
||
self.context._backward(d_loss) | ||
self.context._step_optimizer(self.opt_d) | ||
|
||
output = { | ||
'loss': d_loss, | ||
'g_loss': g_loss, | ||
'd_loss': d_loss, | ||
} | ||
return output | ||
|
||
def evaluate_batch(self, batch: TorchData, model: nn.Module) -> Dict[str, Any]: | ||
imgs, _ = batch | ||
valid = torch.ones(imgs.size(0), 1) | ||
valid = self.context._to_device(valid) | ||
loss = F.binary_cross_entropy(self.model.discriminator(imgs), valid) | ||
return {"loss": loss} | ||
|
53 changes: 53 additions & 0 deletions
53
examples/experimental/native/flexible_primitives_mnist_pytorch/trial_impl.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
""" | ||
This example demonstrates training a simple DNN with pytorch using the Determined | ||
Native API. | ||
""" | ||
import argparse | ||
import json | ||
import pathlib | ||
|
||
from determined import experimental | ||
import determined as det | ||
|
||
import model_def | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument( | ||
"--config", | ||
dest="config", | ||
help="Specifies Determined Experiment configuration.", | ||
default="{}", | ||
) | ||
parser.add_argument("--local", action="store_true", help="Specifies local mode") | ||
parser.add_argument("--test", action="store_true", help="Specifies test mode") | ||
args = parser.parse_args() | ||
|
||
config = { | ||
"data": { | ||
"url": "https://s3-us-west-2.amazonaws.com/determined-ai-test-data/pytorch_mnist.tar.gz" | ||
}, | ||
"hyperparameters": { | ||
"global_batch_size": 32, | ||
"lr": 0.0002, | ||
"b1": 0.5, | ||
"b2": 0.999, | ||
"latent_dim": 100 | ||
}, | ||
"searcher": { | ||
"name": "single", | ||
"metric": "loss", | ||
"max_steps": 40, | ||
"smaller_is_better": True, | ||
}, | ||
} | ||
config.update(json.loads(args.config)) | ||
|
||
experimental.create( | ||
trial_def=model_def.MNistTrial, | ||
config=config, | ||
local=args.local, | ||
test=args.test, | ||
context_dir=str(pathlib.Path.cwd()), | ||
) |