data/benchmarks #422

Draft: wants to merge 25 commits into base main

Changes from 4 commits
benchmarks/README.md (new file, 34 additions)
# Install dependencies

```
pip3 install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cu113
python setup.py develop
```

# Usage instructions


```
usage: run_benchmark.py [-h] [--dataset DATASET] [--model_name MODEL_NAME] [--batch_size BATCH_SIZE]
                        [--device DEVICE] [--num_epochs NUM_EPOCHS] [--report_location REPORT_LOCATION]
                        [--num_workers NUM_WORKERS] [--shuffle] [--dataloaderv DATALOADERV]
```
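
For example, a run that overrides the defaults (the flag values here are illustrative):

```
python run_benchmark.py --dataset gtsrb --model_name resnext50_32x4d --batch_size 64 --num_epochs 2 --num_workers 4 --shuffle
```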

## Available metrics
* [x] Total time
* [x] Time per batch
* [x] Time per epoch
* [x] Precision over time
* [x] CPU Load
* [x] GPU Load
* [x] Memory usage
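
The script itself currently records only the timing metrics. One way the load and memory metrics could be sampled per batch is sketched below; this assumes `psutil` and `pynvml` are available (neither is installed by the steps above), and `sample_system_metrics` is a hypothetical helper, not part of this PR:

```
import psutil   # CPU load and resident memory (assumed installed)
import pynvml   # NVIDIA GPU utilization (assumed installed)
import torch

pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(0)

def sample_system_metrics():
    """Return a dict of instantaneous load/memory readings."""
    util = pynvml.nvmlDeviceGetUtilizationRates(handle)
    return {
        "cpu_percent": psutil.cpu_percent(interval=None),
        "rss_bytes": psutil.Process().memory_info().rss,
        "gpu_percent": util.gpu,
        "gpu_mem_allocated": torch.cuda.max_memory_allocated(),
    }
```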

## Additional profiling

```
pip install scalene
pip install torch-tb-profiler
```


`scalene run_benchmark.py`
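
torch-tb-profiler renders traces produced by `torch.profiler` in TensorBoard. A minimal sketch of wrapping the training loop is below; the schedule numbers are illustrative, and `train_step` is a hypothetical stand-in for the loop body:

```
import torch
from torch.profiler import profile, schedule, tensorboard_trace_handler

with profile(
    schedule=schedule(wait=1, warmup=1, active=3),
    on_trace_ready=tensorboard_trace_handler("./log"),
) as prof:
    for i, batch in enumerate(dl):
        train_step(batch)  # hypothetical stand-in for one optimizer step
        prof.step()        # advance the profiler schedule each batch
```

The resulting trace can then be viewed with `tensorboard --logdir ./log`.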
benchmarks/run_benchmark.py (new file, 128 additions)
import argparse
import torchvision
import torch
try:
    import transformers  # optional; only needed for the commented-out "bert-base" entry
except ImportError:
    pass
from torchvision.prototype.datasets import load
import torch.nn.functional as F
from torchvision import transforms
import time
from statistics import mean
import torch.optim as optim



parser = argparse.ArgumentParser()
parser.add_argument("--dataset", type=str, default="gtsrb", help="The name of the dataset")
parser.add_argument("--model_name", type=str, default="resnext50_32x4d", help="The name of the model")
parser.add_argument("--batch_size", type=int, default=1, help="")
parser.add_argument("--device", type=str, default="cuda:0", help="Options are are cpu or cuda:0")
parser.add_argument("--num_epochs", type=int, default=2)
parser.add_argument("--report_location", type=str, default="./report.md", help="The location where the generated report will be stored")
parser.add_argument("--num_workers", type=int, default=1, help="Number of dataloader workers")
parser.add_argument("--shuffle", action="store_true")
parser.add_argument("--dataloaderv", type=int, default=1)

args = parser.parse_args()
dataset = args.dataset
model_name = args.model_name
batch_size = args.batch_size
device = args.device
num_epochs = args.num_epochs
report_location = args.report_location
num_workers = args.num_workers
shuffle = args.shuffle
dataloaderv = args.dataloaderv

if dataloaderv == 1:
    from torch.utils.data import DataLoader
elif dataloaderv == 2:
    from torch.utils.data.dataloader_experimental import DataLoader2 as DataLoader
Review discussion:

Contributor: Please avoid using this one; perhaps you need to create your own wrapper, which will use DLv2 from the torchdata repo and automatically create a MultiProcessingReadingService.

Contributor: This is up to @NivekT, if he wants to have this update as a follow-up or within this PR.

Member (Author): @NivekT do you have more context? I'm not sure I follow what change is being requested.

Member: I think dataloader_experimental.DataLoader2 is bound to be removed eventually. The actual DataLoader2 is in torchdata.dataloader2. I have an example in my benchmarking PR, if you're interested: https://github.com/pytorch/vision/pull/6196/files#diff-32b42103e815b96c670a0b5f0db055fe63f10fc8776ccbb6aa9b61a6940abba0R207-R211

Contributor: Yep, I think using torchdata.dataloader2 with MultiProcessingReadingService is preferred.
else:
    raise ValueError(f"dataloaderv{dataloaderv} is not a valid option")

# Util function for multiprocessing
def init_fn(worker_id):
    info = torch.utils.data.get_worker_info()
    num_workers = info.num_workers
    datapipe = info.dataset
    torch.utils.data.graph_settings.apply_sharding(datapipe, num_workers, worker_id)

# Download model
model_map = {
    "resnext50_32x4d": torchvision.models.resnext50_32x4d,
    "mobilenet_v3_large": torchvision.models.mobilenet_v3_large,
    # Note: torch.nn.TransformerEncoder has required constructor arguments,
    # so the no-argument call below will fail for this entry.
    "transformerencoder": torch.nn.TransformerEncoder,
    # "bert-base": transformers.BertModel,
}

model = model_map[model_name]().to(torch.device(device))

# setup data pipe
dp = load(dataset, split="train")
print(f"batch size {batch_size}")
print(f"Dataset name {dp}")
print(f"Dataset length {len(dp)}")

# Datapipe format
print(f"data format is {next(iter(dp))}")

# Setup data loader
if num_workers == 1:
    dl = DataLoader(dataset=dp, batch_size=batch_size, shuffle=shuffle)

# Shuffle won't work in distributed yet.
# Reviewer note (Member): shuffle and sharding won't work out of the box with
# DDP. There are some suggestions in pytorch/text#1755, but no definite
# recommended practices yet.

else:
    dl = DataLoader(dataset=dp, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, worker_init_fn=init_fn, multiprocessing_context="spawn")


# TODO: Add measurements of time per batch, per epoch and total time here

total_start = time.time()
per_epoch_durations = []
batch_durations = []

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
for epoch in range(num_epochs):
    epoch_start = time.time()
    running_loss = 0
    for i, elem in enumerate(dl):
        batch_start = time.time()
        # Should image preprocessing be done online or offline?
        # This is all image specific; need to refactor this out or create a
        # training loop per model/dataset combo
        input_image = torch.unsqueeze(elem["image"], 0).to(torch.device(device))
        input_image = transforms.Resize(size=(96, 98))(input_image)
        input_image = input_image.reshape(64, 3, 7, 7) / 255

        labels = elem["label"].to(torch.device(device))

        # TODO: remove; this is wrong. The single label is tiled to match the
        # fake batch dimension created by the reshape above.
        labels = labels.repeat(64)
        optimizer.zero_grad()

        outputs = model(input_image)

        # Without the repeat above this raises:
        # ValueError: Expected input batch_size (64) to match target batch_size (1).
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        # if i % 2000 == 1999:  # print every 2000 mini-batches
        print(f"[{epoch + 1}, {i + 1:5d}] loss: {running_loss:.3f}")
        running_loss = 0.0

        batch_end = time.time()
        batch_duration = batch_end - batch_start
        batch_durations.append(batch_duration)
    epoch_end = time.time()
    epoch_duration = epoch_end - epoch_start
    per_epoch_durations.append(epoch_duration)
total_end = time.time()
total_duration = total_end - total_start

print(f"Total duration is {total_duration}")
print(f"Per epoch duration {mean(per_epoch_durations)}")
print(f"Per batch duration {mean(batch_durations)}")