[Fix] zero optimizer w/ tensor parallel test #167

Merged 5 commits on Mar 28, 2023
@@ -7,29 +7,18 @@
import torch
import torch.multiprocessing as mp
import torch.nn as nn
import oslo
from oslo.torch.distributed.parallel_context import ParallelContext
from oslo.torch.utils import get_free_port, set_seed
from oslo.torch.nn.parallel.data_parallel.zero import ZeroRedundancyOptimizer
from torch.testing import assert_close
from oslo.torch.nn.parallel import TensorParallel
from transformers import AutoModelForSequenceClassification, AutoTokenizer

skip_if_dist_unavailable = pytest.mark.skipif(
    torch.cuda.device_count() < 2, reason="dist required"
)


class MlpModel(nn.Module):
    def __init__(self):
        super(MlpModel, self).__init__()
        self.linear1 = nn.Linear(128, 256)
        self.linear2 = nn.Linear(256, 512)

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        return x


def assert_shard_close(
    tensor: torch.Tensor,
    shard: torch.Tensor,
@@ -40,14 +29,14 @@ def assert_shard_close(
):
    assert tensor.ndim == shard.ndim
    if tensor.shape == shard.shape:
        return assert_close(tensor, shard, rtol=rtol, atol=atol)
        return torch.allclose(tensor, shard, rtol=rtol, atol=atol)
    else:
        dims_not_eq = torch.nonzero(
            torch.tensor(tensor.shape) != torch.tensor(shard.shape)
        )
        if dims_not_eq.numel() == 1:
            dim = dims_not_eq.item()
            return assert_close(
            return torch.allclose(
                tensor.chunk(world_size, dim)[rank], shard, rtol=rtol, atol=atol
            )
        else:
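
The switch from assert_close to torch.allclose in this hunk matters because torch.testing.assert_close raises AssertionError on a mismatch and returns None, whereas torch.allclose returns a bool, which is what the assert assert_shard_close(...) call later in this file expects. A minimal illustration, outside the diff:

import torch
from torch.testing import assert_close

a = torch.ones(2)
b = torch.full((2,), 1.5)

print(torch.allclose(a, b))  # False: returns a bool, so it composes with a plain `assert`

try:
    assert_close(a, b)  # raises AssertionError on mismatch and returns None on success
except AssertionError as err:
    print("values differ:", err)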
@@ -58,46 +47,53 @@ def run(parallel_context: ParallelContext):
    local_rank = torch.distributed.get_rank()

    # create model
    model = MlpModel().cuda()
    model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
    hybrid_model = TensorParallel(
        copy.deepcopy(model), parallel_context=parallel_context
    )
    zero_model = model
    oslo.ready(hybrid_model, parallel_context)
    zero_model = model.cuda()

    # create optimizer
    hybrid_optimizer = ZeroRedundancyOptimizer(
        torch.optim.Adam(hybrid_model.parameters(), lr=1),
        torch.optim.Adam(hybrid_model.parameters(), lr=1e-2),
        parallel_context=parallel_context,
        overlap_communication=True,
        partition_grad=True,
    )
    zero_optimizer = ZeroRedundancyOptimizer(
        torch.optim.Adam(zero_model.parameters(), lr=1),
        torch.optim.Adam(zero_model.parameters(), lr=1e-2),
        parallel_context=parallel_context,
        overlap_communication=True,
    )
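    # The hybrid optimizer above additionally partitions gradients across ranks
    # (partition_grad=True, ZeRO stage-2 style); the baseline zero_optimizer only
    # shards optimizer states.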

    # create tokenizer
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

    # create data
    set_seed(2021 + local_rank)
    input_data = torch.randn(32, 128).cuda()
    input_text = ["This is a sample text."] * 32
    inputs = tokenizer(
        input_text, return_tensors="pt", padding=True, truncation=True
    ).to("cuda")
    labels = torch.randint(0, model.config.num_labels, (32,)).long().cuda()

    # zero-dp forward
    hybrid_output = hybrid_model(input_data)
    zero_output = zero_model(input_data)
    hybrid_output = hybrid_model(**inputs, labels=labels).loss
    zero_output = zero_model(**inputs, labels=labels).loss

    assert torch.allclose(hybrid_output, zero_output)

    # zero-dp backward
    hybrid_output.sum().float().backward()
    zero_output.sum().float().backward()
    hybrid_output.backward()
    zero_output.backward()

    # step
    hybrid_optimizer.step()
    zero_optimizer.step()

    # check updated param
    for hp, zp in zip(hybrid_model.parameters(), zero_model.parameters()):
        assert torch.allclose(hp.data, zp.data)
        assert assert_shard_close(
            zp.data, hp.data, local_rank, torch.distributed.get_world_size()
        )
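
In the updated check, hp comes from the tensor-parallel model and is sharded along one dimension, while zp from the ZeRO-only baseline is full-size, so a plain allclose no longer applies. assert_shard_close chunks the full tensor along the single mismatched dimension and compares this rank's chunk; a standalone sketch of that idea, with toy shapes and assuming a split along one dimension:

import torch

world_size, rank = 2, 1
full = torch.arange(8.0).reshape(4, 2)       # full-size parameter, as held by the ZeRO-only baseline
shard = full.chunk(world_size, dim=0)[rank]  # the slice a tensor-parallel rank would hold locally

# the split dimension is the one whose size differs between the two tensors
dim = torch.nonzero(torch.tensor(full.shape) != torch.tensor(shard.shape)).item()
print(torch.allclose(full.chunk(world_size, dim)[rank], shard))  # True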


def run_dist(rank, world_size):
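
The rest of the file is collapsed in this view. For orientation only, a hypothetical sketch of how such a test is typically launched using the pieces visible above; the test name, the use of os.environ, and the ParallelContext setup inside run_dist are assumptions, not the PR's actual code:

import os

@skip_if_dist_unavailable
def test_zero_optimizer_with_tensor_parallel():
    # Hypothetical entrypoint: spawn one process per GPU; run_dist is expected to build
    # a ParallelContext (e.g. via ParallelContext.from_torch) and call run(parallel_context).
    world_size = 2
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = str(get_free_port())  # pick the rendezvous port once, in the parent
    mp.spawn(run_dist, args=(world_size,), nprocs=world_size)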