
Commit

add pt distrib test
vandanavk committed May 16, 2020
1 parent 251cdf8 commit a8fc661
Showing 1 changed file with 53 additions and 24 deletions.
tests/pytorch/test_distributed_training.py (77 changes: 53 additions & 24 deletions)
@@ -10,6 +10,7 @@
 import os
 import shutil
 import time
+from pathlib import Path
 
 # Third Party
 import numpy as nn
@@ -64,7 +65,15 @@ def train(model, device, optimizer, num_steps=10):
         optimizer.step()
 
 
-def run(rank, size, include_workers="one", num_epochs=10, batch_size=128, num_batches=10):
+def run(
+    rank,
+    size,
+    include_workers="one",
+    num_epochs=10,
+    batch_size=128,
+    num_batches=10,
+    test_timeline_writer=False,
+):
     """Distributed function to be implemented later."""
     torch.manual_seed(1234)
     device = torch.device("cpu")
@@ -84,15 +93,16 @@ def run(rank, size, include_workers="one", num_epochs=10, batch_size=128, num_batches=10):
 
     for epoch in range(num_epochs):
         epoch_loss = 0.0
-        start_time = time.time()
-        hook._write_trace_event_summary(
-            training_phase="Training",
-            op_name="TrainingEpochStart",
-            phase="B",
-            timestamp=start_time,
-            rank=rank,
-            epoch=epoch,
-        )
+        if test_timeline_writer:
+            start_time = time.time()
+            hook._write_trace_event_summary(
+                training_phase="Training",
+                op_name="TrainingEpochStart",
+                phase="B",
+                timestamp=start_time,
+                rank=rank,
+                epoch=epoch,
+            )
         for _ in range(num_batches):
             optimizer.zero_grad()
             data, target = dataset(batch_size)
@@ -102,16 +112,17 @@ def run(rank, size, include_workers="one", num_epochs=10, batch_size=128, num_batches=10):
             loss.backward()
             average_gradients(model)
             optimizer.step()
-        end_time = time.time()
-        hook._write_trace_event_summary(
-            training_phase="Training",
-            op_name="TrainingEpochEnd",
-            phase="E",
-            timestamp=end_time,
-            rank=rank,
-            duration=end_time - start_time,
-            epoch=epoch,
-        )
+        if test_timeline_writer:
+            end_time = time.time()
+            hook._write_trace_event_summary(
+                training_phase="Training",
+                op_name="TrainingEpochEnd",
+                phase="E",
+                timestamp=end_time,
+                rank=rank,
+                duration=end_time - start_time,
+                epoch=epoch,
+            )
         # print(f"Rank {dist.get_rank()}, epoch {epoch}: {epoch_loss / num_batches}")
 
     assert hook._get_worker_name() == f"worker_{dist.get_rank()}"
Expand All @@ -131,15 +142,15 @@ def average_gradients(model):
param.grad.data /= size


def init_processes(rank, size, include_workers, fn, backend="gloo"):
def init_processes(rank, size, include_workers, test_timeline_writer, fn, backend="gloo"):
"""Initialize the distributed environment."""
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "29500"
dist.init_process_group(backend, rank=rank, world_size=size)
fn(rank, size, include_workers)
fn(rank, size, include_workers, test_timeline_writer)


def _run_net_distributed(include_workers="one"):
def _run_net_distributed(include_workers="one", test_timeline_writer=False):
"""Runs a single linear layer on 2 processes."""
# torch.distributed is empty on Mac on Torch <= 1.2
if not hasattr(dist, "is_initialized"):
@@ -148,7 +159,9 @@ def _run_net_distributed(include_workers="one"):
     size = 2
     processes = []
     for rank in range(size):
-        p = Process(target=init_processes, args=(rank, size, include_workers, run))
+        p = Process(
+            target=init_processes, args=(rank, size, include_workers, test_timeline_writer, run)
+        )
         p.start()
         processes.append(p)
 
@@ -199,3 +212,19 @@ def test_run_net_distributed_save_one_worker():
     trial = _run_net_distributed(include_workers="one")
     assert len(trial.workers()) == 1, f"trial.workers() = {trial.workers()}"
     assert len(trial.steps()) == 3, f"trial.steps() = {trial.steps()}"
+
+
+@pytest.mark.parametrize("include_workers", ["one", "all"])
+def test_run_net_distributed_timeline_file_writer(include_workers):
+    num_workers = {"one": 1, "all": 2}
+    trial = _run_net_distributed(include_workers=include_workers, test_timeline_writer=True)
+    assert (
+        len(trial.workers()) == num_workers[include_workers]
+    ), f"trial.workers() = {trial.workers()}"
+    assert len(trial.steps()) == 3, f"trial.steps() = {trial.steps()}"
+
+    files = []
+    for path in Path(out_dir + "/framework/pevents").rglob("*.json"):
+        files.append(path)
+
+    assert len(files) == num_workers[include_workers]

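Note: the new test_run_net_distributed_timeline_file_writer test only asserts that *.json trace files appear under <out_dir>/framework/pevents; it does not inspect their contents. For orientation, below is a minimal, hypothetical sketch of what a begin/end ("B"/"E") event pair for one epoch could look like if the timeline writer follows the Chrome Trace Event convention that the phase, timestamp, and duration arguments in the diff suggest. The function name epoch_trace_events and the exact field layout are illustrative assumptions, not smdebug's actual API or schema.

# Hypothetical illustration only: NOT smdebug's implementation or file schema.
# It sketches a Chrome-trace-style begin/end event pair mirroring the phase="B"/"E",
# timestamp, rank, epoch, and duration arguments used in the test above.
import json
import time


def epoch_trace_events(rank, epoch, start_time, end_time):
    """Build a begin ("B") and end ("E") event for one training epoch."""
    begin = {
        "name": "TrainingEpochStart",
        "cat": "Training",
        "ph": "B",                    # begin phase
        "ts": int(start_time * 1e6),  # Chrome trace timestamps are in microseconds
        "pid": rank,                  # one process per distributed rank
        "args": {"epoch": epoch},
    }
    end = {
        "name": "TrainingEpochEnd",
        "cat": "Training",
        "ph": "E",                    # end phase
        "ts": int(end_time * 1e6),
        "pid": rank,
        "args": {"epoch": epoch, "duration_sec": end_time - start_time},
    }
    return [begin, end]


if __name__ == "__main__":
    start = time.time()
    time.sleep(0.01)  # stand-in for one epoch of training work
    end = time.time()
    print(json.dumps(epoch_trace_events(rank=0, epoch=0, start_time=start, end_time=end), indent=2))

Gating the trace-event writes behind test_timeline_writer=False keeps the pre-existing distributed tests unchanged in what they write, while the new parametrized test opts in explicitly. The new case can be run on its own with the usual pytest selection, e.g. pytest tests/pytorch/test_distributed_training.py -k timeline_file_writer.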