
[Fix] zero optimizer w/ tensor parallel test #167

Merged (5 commits into EleutherAI:main on Mar 28, 2023)

Conversation

@yhna940 (Contributor) commented on Mar 27, 2023

Title

  • [Fix] zero optimizer w/ tensor parallel test

Description

  • ZeRO was not running in tensor parallel mode, so I fixed this by switching the test to a model from `transformers`; a rough sketch of the updated setup is shown below.

Linked Issues

  • N/A
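
For context, a rough sketch of what the updated test setup looks like with a `transformers` model. The GPT2 config sizes, the input shapes, and the plain `Adam` standing in for oslo's ZeRO sharded optimizer are illustrative assumptions, not the exact code in this PR; only `TensorParallel`, `oslo.ready`, and `ParallelContext.from_torch` appear elsewhere on this page.

```python
import torch
from torch import optim
from transformers import GPT2Config, GPT2LMHeadModel

import oslo
from oslo.torch.distributed.parallel_context import ParallelContext
from oslo.torch.distributed.parallel_mode import ParallelMode
from oslo.torch.nn.parallel import TensorParallel


def run(parallel_context: ParallelContext):
    # Use a transformers model so TensorParallel can shard it; per the PR
    # description, ZeRO was not actually running in tensor parallel mode
    # with the previous test model.
    model = GPT2LMHeadModel(GPT2Config(n_layer=2, n_head=4, n_embd=128))
    model = TensorParallel(model, parallel_context)
    oslo.ready(model, parallel_context)

    # The real test wraps the base optimizer with oslo's ZeRO sharded
    # optimizer (see tests_deprecated/.../zero/test_hybrid.py); a plain Adam
    # is used here only to keep the sketch self-contained.
    hybrid_optimizer = optim.Adam(model.parameters(), lr=1e-4)

    input_ids = torch.randint(0, 50257, (2, 16)).cuda()
    loss = model(input_ids=input_ids, labels=input_ids).loss
    loss.backward()
    hybrid_optimizer.step()
```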

@yhna940 added the ZeRO ZeroRedundancyOptimizer label on Mar 27, 2023
@yhna940 requested a review from KKIEEK on Mar 27, 2023 14:31
@yhna940 self-assigned this on Mar 27, 2023
@KKIEEK (Contributor) commented on Mar 27, 2023

I think you need to invoke the parallelize function or oslo.ready to use TensorParallel.

@hyunwoongko (Member) commented
You don't need to move the full model to GPU.
Just use `oslo.ready(model, pc)`.
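
For reference, the suggested pattern looks roughly like this; the model choice and the parallel sizes are illustrative, and `ParallelContext.from_torch` assumes the usual distributed environment variables are already set:

```python
import oslo
from oslo.torch.distributed.parallel_context import ParallelContext
from oslo.torch.distributed.parallel_mode import ParallelMode
from oslo.torch.nn.parallel import TensorParallel
from transformers import GPT2Config, GPT2Model

pc = ParallelContext.from_torch(
    tensor_parallel_size=2,
    tensor_parallel_mode=ParallelMode.TENSOR_1D,
)
model = TensorParallel(GPT2Model(GPT2Config(n_layer=2)), pc)
oslo.ready(model, pc)  # parallelizes the wrapped model; no manual .to(device) call needed
```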

@KKIEEK (Contributor) commented on Mar 28, 2023

Unfortunately, it still doesn't work because of these lines.

E       torch.multiprocessing.spawn.ProcessRaisedException: 
E       
E       -- Process 0 terminated with the following error:
E       Traceback (most recent call last):
E         File "/admin/home/.local/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap
E           fn(i, *args)
E         File "/admin/home/vit/tests_deprecated/torch/nn/parallel/data_parallel/zero/test_hybrid.py", line 103, in run_dist
E           run(parallel_context)
E         File "/admin/home/vit/tests_deprecated/torch/nn/parallel/data_parallel/zero/test_hybrid.py", line 89, in run
E           hybrid_optimizer.step()
E         File "/admin/home/vit/oslo/torch/nn/parallel/data_parallel/zero/sharded_optim/sharded_optim.py", line 581, in step
E           norm_group = compute_norm(
E         File "/admin/home/vit/oslo/torch/nn/parallel/data_parallel/zero/sharded_optim/_utils.py", line 230, in compute_norm
E           if is_model_parallel_parameter(p) or mp_rank == 0:
E         File "/admin/home/vit/oslo/torch/nn/parallel/data_parallel/zero/sharded_optim/_utils.py", line 39, in is_model_parallel_parameter
E           return ParallelMode.PIPELINE in parallel_mode or any(
E         File "/admin/home/vit/oslo/torch/nn/parallel/data_parallel/zero/sharded_optim/_utils.py", line 40, in <genexpr>
E           key.startswith("tensor") for key in parallel_mode
E       AttributeError: 'ParallelMode' object has no attribute 'startswith'
>>> from oslo.torch.distributed import ParallelMode
>>> ParallelMode.TENSOR_1D.startswith('tensor')
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
AttributeError: 'ParallelMode' object has no attribute 'startswith'

@yhna940 (Contributor, Author) commented on Mar 28, 2023

> Unfortunately, it still doesn't work because of these lines.

There was a small bug in the TP parameter check, and it has been fixed in #164:
https://github.com/EleutherAI/oslo/pull/164/files
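
For anyone reading along: the failing check iterated over `ParallelMode` enum members and called a string method on the member itself. Below is a minimal sketch of the kind of correction involved, assuming the enum values are strings such as `"tensor_1d"`; the actual patch is in the linked PR.

```python
from oslo.torch.distributed import ParallelMode

parallel_mode = [ParallelMode.TENSOR_1D]

# before: any(key.startswith("tensor") for key in parallel_mode) raises
# AttributeError, because `key` is a ParallelMode member, not a str.
# Going through the member's string value is one plausible correction
# (assumes ParallelMode is an Enum whose values are strings like "tensor_1d").
is_tensor_parallel = any(key.value.startswith("tensor") for key in parallel_mode)
print(is_tensor_parallel)
```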

@yhna940 merged commit 3357dac into EleutherAI:main on Mar 28, 2023
@l4d2boomer left a comment

👍

yhna940 added a commit to yhna940/oslo that referenced this pull request Apr 18, 2023
* import ParallelMode (EleutherAI#166)

## Fix typo in the tensor parallel tutorial

- `from oslo import ParallelContext, ParallelMode`

* [Fix] zero param check (EleutherAI#164)

## Title

- [Fix] zero param check

## Description

- ZeRO checks parameters for redundancy when calculating the norm. There
was a minor bug in the TP check that needed to be fixed.

## Linked Issues

- N/A

* [Fix] zero optimizer w/ tensor parallel test (EleutherAI#167)

## Title

- [Fix] zero optimizer w/ tensor parallel test

## Description

- ZeRO was not running in tensor parallel mode, so I fixed this by
switching to a model from `transformers`.

## Linked Issues

- N/A

* Add restarting model from saved model and fix bug (EleutherAI#171)

## Description

- load a model
- start training again from a saved point
- fix a bug where training_arg was not saved because of an NCCL error; the cause
was parallel_context, which is now removed before saving training_arg and
re-attached afterwards
- test load and restart with oslo TP

* Make decoder-only models able to generate with `inputs_embeds` (EleutherAI#172)

## Title
Make decoder-only models able to generate with `inputs_embeds`

## Description
Synchronize the GPT2 code with Hugging Face transformers so that GPT2 can
generate with `inputs_embeds`.

>Accepting `.generate()` calls with `inputs_embeds` on decoder-only
models is a long-standing request
(huggingface/transformers#6535) -- see
huggingface/transformers#6535 (comment) in
particular and its reactions.
>
>It has to be added on a per-model basis, and this PR adds the necessary
changes for GPT2. Other models will throw an informative exception if
the user passes `inputs_embeds`, asking them to check this PR and
implement the same pattern on the model they want to use it with 🤗
>
>Please note that it is still expected that the user passes `input_ids`,
i.e.

```python
outputs = model.generate(input_ids, inputs_embeds=inputs_embeds)
```

>This is because decoder-only models expect the prompt to be present in
the output, and this is the only way to preserve it! `input_ids` can also
be omitted and, in that case, the output won't contain the prompt.

For more details, please check out [this
PR](huggingface/transformers#21405).

* Wrong import in zero (EleutherAI#169)

## Title

Prevent the use of torch 2.0

## Description

- Some features have changed in torch 2.0, and oslo depends on
`torch._six`, which is no longer supported in torch 2.0 (a compatibility sketch follows below).

oslo dependency
-
https://github.com/EleutherAI/oslo/blob/910c789e7f46d2876b964c221d31984b7924974f/oslo/torch/nn/parallel/data_parallel/zero/sharded_optim/_utils.py#L19

other issues
- microsoft/DeepSpeed#2845

## Linked Issues

- resolved #00
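
For anyone hitting the same break, the usual workaround is a guarded import. This is a minimal sketch, not the oslo patch itself; it is an assumption that the oslo line referenced above imports `inf` from `torch._six`, which is what the linked DeepSpeed issue is about.

```python
# torch._six was removed in torch 2.0; fall back to the standard library constant.
try:
    from torch._six import inf  # torch < 2.0
except ImportError:
    from math import inf  # drop-in replacement for the float infinity constant
```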

* [Fix] Support gradient accumulation for DDP (EleutherAI#173)

## Description

To support gradient accumulation, I removed the `free_storage`
function, which can cause `CUDA error: an illegal memory access was
encountered` in many cases (but this change may lead to an increase in
memory consumption).
What do you guys think about this PR? @nijkah @jinwonkim93
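
As a plain-PyTorch illustration of why this matters (no oslo wrapper here; the model, data, and step counts are made up): gradient accumulation relies on the `.grad` buffers surviving across several backward passes before `optimizer.step()`, so freeing their storage between micro-batches breaks the pattern.

```python
import torch
from torch import nn, optim

model = nn.Linear(16, 4)
optimizer = optim.SGD(model.parameters(), lr=1e-2)
loss_fn = nn.MSELoss()
accumulation_steps = 4

optimizer.zero_grad()
for step in range(8):
    x, y = torch.randn(2, 16), torch.randn(2, 4)
    loss = loss_fn(model(x), y) / accumulation_steps
    loss.backward()  # gradients accumulate in .grad across micro-batches
    if (step + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
```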

* [Fix] minor bug for single output in _DistributedDataParallel (EleutherAI#177)

## Title

- Fix minor bug for single output in _DistributedDataParallel

## Description

- This PR addresses a minor bug in the `_DistributedDataParallel` class
when handling single output tensors. The changes include:

1. Update the `forward` method in `_DistributedDataParallel` to
correctly handle single output tensors (a sketch of the idea follows below).
2. Add new test cases in
`tests_deprecated/torch/nn/parallel/data_parallel/data_parallel.py` to
ensure the correct behavior for models with various output types (single
tensor, multiple tensors, and dictionary of tensors).

These updates will ensure that the `_DistributedDataParallel` class
works correctly with various output types, providing a more robust
solution for users.

## Linked Issues

- N/A
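
A minimal sketch of the kind of normalization such a fix needs (not the actual oslo patch; the helper name is made up): downstream logic that iterates over output tensors should see the same flat list whether the model returns a bare tensor, a tuple, or a dict.

```python
import torch

def _flatten_outputs(outputs):
    """Collect output tensors into a flat list, leaving the original structure alone."""
    if torch.is_tensor(outputs):
        return [outputs]
    if isinstance(outputs, dict):
        return [v for v in outputs.values() if torch.is_tensor(v)]
    if isinstance(outputs, (list, tuple)):
        return [v for v in outputs if torch.is_tensor(v)]
    return []

# single tensor, tuple of tensors, and dict of tensors all reduce to a flat list
assert len(_flatten_outputs(torch.ones(2))) == 1
assert len(_flatten_outputs((torch.ones(2), torch.ones(2)))) == 2
assert len(_flatten_outputs({"logits": torch.ones(2)})) == 1
```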

* [Enhance] Support ViT for TensorParallel (EleutherAI#155)

## Description

I added support for ViT in TensorParallel by appending config to
`_TensorParallelMapping`.
The `PatchEmbed` layer in ViT does not have a `weight` parameter, unlike the
`Embedding` layer, so I replaced the `weight` parameter with a dummy
value to prevent an `AttributeError`.

Any feedback is welcome.

### Memory usage
mode | world_size=1 | world_size=2 | world_size=4 | world_size=8
---|---|---|---|---
1D | 1760MiB | 1126MiB | 789MiB |
2D | | | 589MiB |
2.5D (d=1) | | | 589MiB |
2.5D (d=2) | | | | 586MiB
3D | | | |

### TODO
- [ ] Benchmark with `world_size=8`
- [ ] Refactor slicing patch embedding
- [ ] Fix slicing logic to return the same value as `TensorParallel1D`

<details><summary>code for testing</summary>
<p>

```python
import os
import torch.multiprocessing as mp

import torch
from torch import nn
from torch import optim
import torch.distributed as dist
from transformers import ViTModel, ViTForImageClassification, ViTConfig

import oslo
from oslo.torch.distributed.parallel_context import ParallelContext
from oslo.torch.distributed.parallel_mode import ParallelMode
from oslo.torch.nn.parallel import TensorParallel


def setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12340"
    os.environ["RANK"] = str(rank)
    os.environ["LOCAL_RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["LOCAL_WORLD_SIZE"] = str(world_size)


def cleanup():
    dist.destroy_process_group()


def train(rank, world_size):
    print(f"Running oslo TP example on rank {rank}.")
    setup(rank, world_size)
    parallel_context = ParallelContext.from_torch(
        tensor_parallel_size=world_size,
        tensor_parallel_mode=ParallelMode.TENSOR_1D,
    )  # or ParallelMode.TENSOR_2D / ParallelMode.TENSOR_2P5D

    model = ViTForImageClassification(ViTConfig(num_labels=1000)).to(rank)
    model = TensorParallel(model, parallel_context)
    optimizer = optim.SGD(model.parameters(), lr=1e-4)
    loss_fn = nn.MSELoss()

    oslo.ready(model, parallel_context)

    for _ in range(100):
        model.zero_grad()
        logits = model(pixel_values=torch.ones(8, 3, 224, 224).to(rank)).logits
        labels = torch.ones(8, 1000).to(rank) * 100
        loss = loss_fn(logits, labels)
        loss.backward()
        optimizer.step()
        print(logits)
        print(torch.cuda.max_memory_allocated() / 1024**2)  # MB

    cleanup()


def main(world_size):
    mp.spawn(train, args=(world_size,), nprocs=world_size, join=True)


if __name__ == "__main__":
    main(4)
```

</p>
</details> 

## Linked Issues

Related to EleutherAI#152

---------

Co-authored-by: Minho Ryu <[email protected]>
Co-authored-by: Hansol Park <[email protected]>
Co-authored-by: Ingyu Seong <[email protected]>
Co-authored-by: whooray <[email protected]>
Co-authored-by: Junhwa Song <[email protected]>
dyanos pushed a commit that referenced this pull request Jun 8, 2023
## Title

- [Fix] zero optimizer w/ tensor parallel test

## Description

- ZeRO was not running in tensor parallel mode, so I fixed this by
switching to a model from `transformers`.

## Linked Issues

- N/A
Labels: ZeRO ZeroRedundancyOptimizer

4 participants