Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPMD] Enable GPU CI for Distributed Tensor #333

Merged
merged 72 commits into from
Aug 31, 2022
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
72 commits
Select commit Hold shift + click to select a range
7e41a7c
Enable GPU test for Distributed Tensor
fduwjj Aug 9, 2022
a70c626
change test specs
fduwjj Aug 9, 2022
a49c086
Fix errors
fduwjj Aug 9, 2022
0c37fbe
Shard test
fduwjj Aug 9, 2022
715771b
Install pytest-shard
fduwjj Aug 9, 2022
03e3f8a
Remove relative import and add new e2e test
fduwjj Aug 9, 2022
8175fd6
Merge with main
fduwjj Aug 9, 2022
3b05a48
Reformat
fduwjj Aug 9, 2022
77b8e99
Patch fix and retest
fduwjj Aug 9, 2022
a42dce2
Fix formart
fduwjj Aug 10, 2022
3b85bdd
Fix CI test
fduwjj Aug 10, 2022
f485d14
Fix CI Python version
fduwjj Aug 10, 2022
d0c54ca
Docker image update
fduwjj Aug 10, 2022
5819ab0
Revert docker change and change code instead
fduwjj Aug 10, 2022
01eb5cb
test
fduwjj Aug 10, 2022
600b108
Merge branch 'main' into enable_gpu_test
fduwjj Aug 10, 2022
2d4e779
fix test
fduwjj Aug 10, 2022
d616ccd
CI test
fduwjj Aug 10, 2022
49e31b8
remove pippy install
fduwjj Aug 10, 2022
6316f8d
Fix linter
fduwjj Aug 10, 2022
1641b03
merge with main
fduwjj Aug 10, 2022
278db1f
format
fduwjj Aug 10, 2022
efd640c
remove unnecessary change
fduwjj Aug 10, 2022
acd047e
revert test chage
fduwjj Aug 10, 2022
302be98
Split commit
fduwjj Aug 10, 2022
560e23f
continue revert
fduwjj Aug 10, 2022
7521575
revert all test related change
fduwjj Aug 10, 2022
5e0b372
Merge branch 'main' into enable_gpu_test
fduwjj Aug 11, 2022
b830340
Merge with main
fduwjj Aug 11, 2022
4be00a2
Merge branch 'main' into enable_gpu_test
fduwjj Aug 14, 2022
4b5d3d5
Format
fduwjj Aug 14, 2022
d2c5ed0
Add back pytest
fduwjj Aug 14, 2022
4022b09
fix
fduwjj Aug 14, 2022
4271245
Merge with main
fduwjj Aug 17, 2022
a75d361
Change name
fduwjj Aug 17, 2022
d6ca983
Merge branch 'main' into enable_gpu_test
fduwjj Aug 20, 2022
d82dd2a
Reformat
fduwjj Aug 22, 2022
e1d95a5
Merge branch 'main' into enable_gpu_test
fduwjj Aug 22, 2022
aeb6630
update nvidia driver version
fduwjj Aug 22, 2022
09fe4e3
Change driver version
fduwjj Aug 22, 2022
2e62a9f
Change cuda version
fduwjj Aug 22, 2022
448f35c
Update docker image and pytorch version
fduwjj Aug 23, 2022
b19c472
Update docker
fduwjj Aug 23, 2022
5d398fe
Merge branch 'main' into enable_gpu_test
fduwjj Aug 23, 2022
9d1cc29
Narrow down to only one test
fduwjj Aug 24, 2022
5b21cd0
debug
fduwjj Aug 24, 2022
a926292
debug 2
fduwjj Aug 24, 2022
3014d4b
debug 3
fduwjj Aug 24, 2022
cd15563
debug 4
fduwjj Aug 24, 2022
7b7b577
debug 5
fduwjj Aug 24, 2022
b2e1364
debug 5
fduwjj Aug 24, 2022
e3555f8
debug 7
fduwjj Aug 24, 2022
6e58e43
debug 8
fduwjj Aug 24, 2022
4a6868b
add ssh to CI machine
fduwjj Aug 25, 2022
49e4dd3
Fix machine cleaning up part
fduwjj Aug 25, 2022
38087d6
fix CI
fduwjj Aug 25, 2022
0ee4cdf
Fix script
fduwjj Aug 26, 2022
3acacd5
Change permission of file
fduwjj Aug 26, 2022
01415a5
Use new CI machines
fduwjj Aug 26, 2022
a6993b9
Use AWS EC2 p4 machine
fduwjj Aug 26, 2022
a3ff713
Merge with main
fduwjj Aug 30, 2022
6b767ba
Update machine
fduwjj Aug 30, 2022
be3f1e4
Add share memory config
fduwjj Aug 30, 2022
6d802eb
Comment out remove program
fduwjj Aug 30, 2022
02c9c02
Update command
fduwjj Aug 31, 2022
bbc9740
Reformat and skip test_dtensor_op
fduwjj Aug 31, 2022
7c734a2
Make Linter happy
fduwjj Aug 31, 2022
43778cb
Merge branch 'main' into enable_gpu_test
fduwjj Aug 31, 2022
a17d2a3
Comment out failing test for CI
fduwjj Aug 31, 2022
c69a4f0
reformat
fduwjj Aug 31, 2022
6366548
Refresh CI
fduwjj Aug 31, 2022
cc00ae4
Fix linter
fduwjj Aug 31, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions .github/workflows/spmd_gpu_tests.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#!/bin/bash

set -x

# Print test options
echo "VERBOSE: ${VERBOSE}"

nvidia-smi
nvcc --version
which python3
python3 --version
which pip3
pip3 --version

# Install git
apt-get update
apt-get install git -y

# Install dependencies
# Turn off progress bar to save logs
pip3 install --upgrade pip
pip3 config set global.progress_bar off
pip3 install flake8 pytest pytest-cov pytest-shard numpy expecttest
if [ -f requirements.txt ]; then pip3 install -r requirements.txt --find-links https://download.pytorch.org/whl/nightly/cu102/torch_nightly.html; fi

# Install pippy
python3 smpd/setup.py install
fduwjj marked this conversation as resolved.
Show resolved Hide resolved

# Run all integration tests
python3 test/spmd/tensor/test_megatron_example.py
python3 test/spmd/tensor/test_ddp.py
python3 test/spmd/tensor/test_tp_sharding_ops.py
66 changes: 66 additions & 0 deletions .github/workflows/spmd_tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,69 @@ jobs:
- name: Test with pytest
run: |
pytest --cov=spmd test/spmd/

pytest_tests_gpu:
runs-on: linux.16xlarge.nvidia.gpu
strategy:
matrix:
num-gpus: ["4"]
env:
DOCKER_IMAGE: qts8n/cuda-python:devel
PIPPY_ROOT: /PiPPy
VERBOSE: "0"
OMP_NUM_THREADS: "1"

steps:
- name: Clean working directory
shell: bash
run: |
sudo rm -rf /home/ec2-user/actions-runner/_work/PiPPy/PiPPy/* || true
- uses: actions/checkout@v2
- name: Clean up previous CUDA driver installations
shell: bash
run: |
set -x
yum list installed | grep nvidia || true
yum list installed | grep cuda || true
sudo yum remove -y cuda || true
sudo yum remove -y cuda-drivers || true
sudo yum remove -y "*nvidia*" || true
- name: Install nvidia driver, nvidia-docker runtime, set GPU_FLAG
run: |
bash .github/workflows/install_nvidia_utils_linux.sh || true
echo "GPU_FLAG=--gpus all" >> "${GITHUB_ENV}"
- name: Pull Docker image
run: |
retry () {
"$@" || (sleep 1 && "$@") || (sleep 2 && "$@")
}
retry docker pull "${DOCKER_IMAGE}"
- name: Test docker run
run: |
set -x
# shellcheck disable=SC2086,SC2090
container_name=$(docker run \
--gpus all \
-e VERBOSE \
-e OMP_NUM_THREADS \
--tty \
--detach \
-v "$(pwd):${PIPPY_ROOT}" \
-w "${PIPPY_ROOT}" \
"${DOCKER_IMAGE}"
)
# Run GPU tests and return error signal from docker
docker exec -t -w "${PIPPY_ROOT}" "${container_name}" bash -c "bash .github/workflows/spmd_gpu_tests.sh; exit $?"
fduwjj marked this conversation as resolved.
Show resolved Hide resolved
- name: Chown workspace
if: always()
run: |
# Ensure the working directory gets chowned back to the current user
docker run --rm -v "$(pwd):${PIPPY_ROOT}" -w "${PIPPY_ROOT}" "${DOCKER_IMAGE}" chown -R "$(id -u):$(id -g)" .
- name: Kill containers, clean up images
if: always()
run: |
# ignore expansion of "docker ps -q" since it could be empty
# shellcheck disable=SC2046
docker stop $(docker ps -q) || true
# Prune all of the docker images
docker system prune -af
4 changes: 4 additions & 0 deletions spmd/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ def distribute_tensor(
scatter_shape = list(tensor.size())
scatter_shape[shard_dim] = chunk_size
local_tensor = device_mesh.scatter(tensor_list, mesh_dim=idx)
# scatter call could not return a tensor with correct requires_grad
# field, as ProcessGroupNCCL refuse to take a tensor with requires_grad
# to do inplace update! So we manually set it here
local_tensor.requires_grad_(tensor.requires_grad)
dist_tensor = DTensor(
local_tensor,
device_mesh,
Expand Down
4 changes: 3 additions & 1 deletion spmd/tensor/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,9 @@ def backward(ctx, grad_output: torch.Tensor): # type: ignore

class FromTorchTensor(torch.autograd.Function):
@staticmethod
def forward(ctx, input: torch.Tensor, device_mesh, placements, run_check): # type: ignore
def forward(
ctx, input: torch.Tensor, device_mesh, placements, run_check
): # type: ignore
ctx.previous_placement = placements
ctx.previous_device_mesh = device_mesh

Expand Down
7 changes: 2 additions & 5 deletions spmd/tensor/device_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,7 @@ def __init__(
# pg or not, it's required that all ranks participate
# in subgroup construction
new_subgroup = new_group(
ranks=subgroup_ranks,
backend=backend_name,
ranks=subgroup_ranks, backend=backend_name
)
# only add to dim_groups if the current rank in the subgroup
if self.get_rank() in subgroup_ranks:
Expand Down Expand Up @@ -240,9 +239,7 @@ def scatter(
src_for_dim = 0
if dim_group is not GroupMember.WORLD:
src_for_dim = _get_global_rank(dim_group, 0)
tensor = torch.empty_like(
to_scatter[0], requires_grad=to_scatter[0].requires_grad
)
tensor = torch.empty_like(to_scatter[0])
if src_for_dim == get_rank():
scatter(
tensor,
Expand Down
2 changes: 1 addition & 1 deletion spmd/tensor/ops/dropout.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def _dist_dropout(
raise RuntimeError("Not supported!")
else:
local_tensor, mask = torch.ops.aten.native_dropout(
self.to_local(), p=p, train=train
self._local_tensor, p=p, train=train
)
return (
DTensor(
Expand Down
8 changes: 2 additions & 6 deletions spmd/tensor/ops/matrix_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,9 @@
# implement matrix related ops for distributed tensor
from typing import Optional
from spmd.tensor.dispatch import OpSchema
from spmd.tensor.placement_types import (
PlacementSpec,
)
from spmd.tensor.placement_types import PlacementSpec
from spmd.tensor.ops.prop_rules import einop_prop, mm_prop, pointwise_prop
from spmd.tensor.ops.utils import (
register_prop_rule,
)
from spmd.tensor.ops.utils import register_prop_rule


@register_prop_rule("aten.mm.default")
Expand Down
4 changes: 1 addition & 3 deletions spmd/tensor/ops/pointwise_ops.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
from typing import Optional
from spmd.tensor.dispatch import OpSchema
from spmd.tensor.placement_types import (
PlacementSpec,
)
from spmd.tensor.placement_types import PlacementSpec
from spmd.tensor.ops.prop_rules import pointwise_prop

# leave the pointwise_ops list here for convenience,
Expand Down
4 changes: 1 addition & 3 deletions spmd/tensor/ops/tp_sharding_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@
from typing import List
from spmd.tensor.api import DTensor
from spmd.tensor.placement_types import Shard
from spmd.tensor.utils import (
unwrap_local_tensor,
)
from spmd.tensor.utils import unwrap_local_tensor
from spmd.tensor.ops.utils import unwrap_single_placement, register_impl

"""
Expand Down
File renamed without changes.
2 changes: 1 addition & 1 deletion test/spmd/tensor/test_ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import torch
import torch.nn as nn
from torch.testing._internal.common_utils import run_tests
from ..test_utils import DistTensorTestBase, with_comms
from spmd.test._utils import DistTensorTestBase, with_comms
from spmd import (
distribute_tensor,
distribute_module,
Expand Down
2 changes: 1 addition & 1 deletion test/spmd/tensor/test_device_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
_get_global_rank,
)
from torch.testing._internal.common_utils import run_tests
from ..test_utils import DistTensorTestBase, with_comms
from spmd.test._utils import DistTensorTestBase, with_comms
from spmd.tensor import DeviceMesh, DTensor, Shard, Replicate


Expand Down
2 changes: 1 addition & 1 deletion test/spmd/tensor/test_math_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from spmd.tensor.ops.prop_rules import einop_prop
from spmd.tensor.placement_types import PlacementSpec
from ..test_utils import DistTensorTestBase, with_comms
from spmd.test._utils import DistTensorTestBase, with_comms
from spmd import distribute_tensor, DeviceMesh, Shard, Replicate


Expand Down
2 changes: 1 addition & 1 deletion test/spmd/tensor/test_matrix_ops.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
import torch
from torch.testing._internal.common_utils import run_tests
from ..test_utils import DistTensorTestBase, with_comms
from spmd.test._utils import DistTensorTestBase, with_comms
from spmd import distribute_tensor, DeviceMesh, Shard, Replicate


Expand Down
114 changes: 114 additions & 0 deletions test/spmd/tensor/test_megatron_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
import torch
import torch.distributed as dist
import functools
from torch.testing._internal.common_utils import run_tests
from spmd.test._utils import DistTensorTestBase, with_comms
from spmd import distribute_tensor, DeviceMesh, DTensor, Shard, Replicate


class SimpleModel(torch.nn.Module):
def __init__(self):
super(SimpleModel, self).__init__()
self.net1 = torch.nn.Linear(10, 16)
self.relu = torch.nn.ReLU()
self.net2 = torch.nn.Linear(16, 12)

def forward(self, x):
return self.net2(self.relu(self.net1(x)))


def _aggregate_local_tensor(module: torch.nn.Module) -> torch.nn.Module:
def hook_func(_module, _input, output):
if isinstance(output, DTensor):
replica_placement = [Replicate()]
return (
output.redistribute(output.device_mesh, replica_placement)
.contiguous()
.local_tensor()
)

module.register_forward_hook(hook_func)
return module


def _replicate_input_tensor(
module: torch.nn.Module, device_mesh, replica_placement
) -> torch.nn.Module:
def hook_func(_, input):
if not isinstance(input[0], DTensor):
return DTensor.from_local(input[0], device_mesh, replica_placement)

module.register_forward_pre_hook(hook_func)
return module


def _gradient_hook(param, grad):
param._local_tensor.grad = grad._local_tensor


def shard_module(m, device_type):
pg = dist.distributed_c10d._get_default_group()
start_idx = 0
device_mesh = DeviceMesh(
device_type,
list(range(start_idx, start_idx + pg.size())),
dim_groups=[pg],
)
col_wise_sharding = [Shard(0)]
row_wise_sharding = [Shard(1)]
replicate = [Replicate()]
m.net1.weight = torch.nn.Parameter(
distribute_tensor(m.net1.weight, device_mesh, col_wise_sharding)
)
m.net2.weight = torch.nn.Parameter(
distribute_tensor(m.net2.weight, device_mesh, row_wise_sharding)
)
m.net1.bias = torch.nn.Parameter(
distribute_tensor(m.net1.bias, device_mesh, col_wise_sharding)
)
m.net2.bias = torch.nn.Parameter(
distribute_tensor(m.net2.bias, device_mesh, replicate)
)
m = _replicate_input_tensor(m, device_mesh, replicate)
m.net2 = _aggregate_local_tensor(m.net2)
m.net1.weight.register_hook(
functools.partial(_gradient_hook, m.net1.weight)
)


class DistTensorMegatronTest(DistTensorTestBase):
@with_comms
def test_simple_megatron_e2e(self):
LR = 0.5
inp_size = [5, 10]
torch.manual_seed(0)
inp = torch.rand(*inp_size, device=self.device_type)
torch.manual_seed(5)
model = SimpleModel()
torch.manual_seed(5)
model_tp = SimpleModel()
shard_module(model_tp, self.device_type)

output = model(inp)
output_tp = model_tp(inp)
self.assertEqual(output, output_tp)

output.sum().backward()
output_tp.sum().backward()
# self.assertTrue(model_tp.net1.weight.local_tensor().grad is not None)

optim = torch.optim.SGD(model.parameters(), lr=LR)
optim.step()
optim = torch.optim.SGD(model_tp.parameters(), lr=LR)
optim.step()

torch.manual_seed(3)
inp = torch.rand(*inp_size).cuda(self.rank)
output = model(inp)
output_tp = model_tp(inp)
self.assertEqual(output, output_tp)


if __name__ == "__main__":
run_tests()
8 changes: 5 additions & 3 deletions test/spmd/tensor/test_pointwise_ops.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
import torch
from torch.testing._internal.common_utils import run_tests
from ..test_utils import DistTensorTestBase, with_comms, TEST_GPU_NUM
from spmd.test._utils import DistTensorTestBase, with_comms, TEST_GPU_NUM
from spmd import DeviceMesh, DTensor, Shard, Replicate, _Partial
from torch.distributed.distributed_c10d import ReduceOp
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
Expand All @@ -13,8 +13,10 @@ def _run_sharded_elementwise_ops(
self, mesh, spec, input_size, op, reset_seed=None, **kwargs
):
torch.manual_seed(self.rank)
input_tensor = torch.randn(*input_size, requires_grad=True)
dist_tensor = DTensor.from_local(input_tensor, mesh, spec)
input_tensor = torch.randn(
*input_size, device=self.device_type, requires_grad=True
)
dist_tensor = DTensor(input_tensor, mesh, spec)
reset_seed() if reset_seed else None
dt = op(dist_tensor, **kwargs)
reset_seed() if reset_seed else None
Expand Down
2 changes: 1 addition & 1 deletion test/spmd/tensor/test_redistribute.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from torch.testing._internal.common_utils import run_tests

from ..test_utils import DistTensorTestBase, with_comms
from spmd.test._utils import DistTensorTestBase, with_comms
from spmd.tensor import DeviceMesh, DTensor, Replicate, Shard, _Partial


Expand Down
2 changes: 1 addition & 1 deletion test/spmd/tensor/test_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from torch.distributed.distributed_c10d import ReduceOp

from torch.testing._internal.common_utils import run_tests
from ..test_utils import DistTensorTestBase, with_comms
from spmd.test._utils import DistTensorTestBase, with_comms
from spmd.tensor import DeviceMesh, DTensor, Replicate, Shard, _Partial


Expand Down
2 changes: 1 addition & 1 deletion test/spmd/tensor/test_tensor_ops.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
import torch
from torch.testing._internal.common_utils import run_tests
from ..test_utils import DistTensorTestBase, with_comms
from spmd.test._utils import DistTensorTestBase, with_comms
from spmd import distribute_tensor, DeviceMesh, DTensor, Shard, Replicate


Expand Down
Loading