
Resolve same global rank in DDP with Lightning Trainer #250

Closed
csy1204 opened this issue Jul 22, 2024 · 2 comments
Labels: bug (Something isn't working), help wanted (Extra attention is needed)

csy1204 commented Jul 22, 2024

🐛 Bug

When training with DDP, the data is duplicated across processes: even though _DistributedEnv reports a different global rank for each process, every rank receives the same batches. This happens when using the Lightning Trainer, and I would like to find a solution.

To Reproduce

Steps to reproduce the behavior:

### Generate dummy data
import litdata as ld


def random_images(index):
    return {"index": index}


if __name__ == "__main__":
    ld.optimize(
        fn=random_images,
        inputs=list(range(100)),
        output_dir="./dummy_dataset",
        num_workers=3,
        chunk_size=25,
        mode="overwrite",
    )
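
For reference, a quick way to sanity-check the optimized dataset before involving the Trainer (a hypothetical snippet, not part of the original report; it assumes StreamingDataset exposes __len__ and __getitem__, which may vary by litdata version):

import litdata as ld

ds = ld.StreamingDataset("./dummy_dataset")
print(len(ds))  # expected: 100 samples written by optimize()
print(ds[0])    # expected: a dict like {"index": <int>}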
import lightning as L
import litdata as ld


dataset = ld.StreamingDataset("./dummy_dataset", drop_last=True, shuffle=True)


class LDataModule(L.LightningDataModule):
    def __init__(self):
        super().__init__()

    def setup(self, stage: str):
        pass

    def train_dataloader(self):
        return ld.StreamingDataLoader(
            dataset,
            num_workers=2,
            shuffle=True,
            batch_size=25,
            drop_last=True,
        )

import torch
from litdata.utilities.env import _DistributedEnv, _WorkerEnv


class LitModule(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = torch.nn.Sequential(torch.nn.Linear(1, 10), torch.nn.ReLU(), torch.nn.Linear(10, 1))
        self.register_buffer("tensor_memory", torch.tensor([]))

    def training_step(self, batch, batch_idx):
        distributed_env = _DistributedEnv.detect()
        worker_env = _WorkerEnv.detect()

        data_id = batch["index"].view(-1, 1).to(self.device, dtype=torch.float)

        rank = self.global_rank
        data_indices = batch["index"].tolist()
        print(f"Global rank: {rank}, Processed data indices: {data_indices}\n\n")
        print(f"Global rank: {rank}: {worker_env=} {distributed_env=}")

        x_hat = data_id
        y = self.model(data_id)
        loss = torch.nn.functional.mse_loss(x_hat, y)

        self.log("train_loss", loss)

        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer


dm = LDataModule()
lit_module = LitModule()

trainer = L.Trainer(
    max_epochs=3, log_every_n_steps=1, strategy="ddp_notebook", devices=4, profiler="simple", accelerator="cpu"
)

trainer.fit(
    model=lit_module,
    datamodule=dm,
)

Result

Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/4
Initializing distributed: GLOBAL_RANK: 1, MEMBER: 2/4
Initializing distributed: GLOBAL_RANK: 2, MEMBER: 3/4
Initializing distributed: GLOBAL_RANK: 3, MEMBER: 4/4
----------------------------------------------------------------------------------------------------
distributed_backend=gloo
All distributed processes registered. Starting with 4 processes
----------------------------------------------------------------------------------------------------


  | Name  | Type       | Params | Mode 
---------------------------------------------
0 | model | Sequential | 31     | train
---------------------------------------------
31        Trainable params
0         Non-trainable params
31        Total params
0.000     Total estimated model params size (MB)
/Users/user/litdata-env/lib/python3.10/site-packages/lightning/pytorch/utilities/data.py:122: Your `IterableDataset` has `__len__` defined. In combination with multi-process data loading (when num_workers > 1), `__len__` could be inaccurate if each worker is not configured independently to avoid having duplicate data.
Global rank: 3, Processed data indices: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24]
Global rank: 1, Processed data indices: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24]
Global rank: 0, Processed data indices: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24]
Global rank: 2, Processed data indices: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24]

Global rank: 2: worker_env=_WorkerEnv(world_size: 1, rank: 0) distributed_env=_DistributedEnv(world_size: 4, global_rank: 2)
Global rank: 0: worker_env=_WorkerEnv(world_size: 1, rank: 0) distributed_env=_DistributedEnv(world_size: 4, global_rank: 0)
Global rank: 1: worker_env=_WorkerEnv(world_size: 1, rank: 0) distributed_env=_DistributedEnv(world_size: 4, global_rank: 1)
Global rank: 3: worker_env=_WorkerEnv(world_size: 1, rank: 0) distributed_env=_DistributedEnv(world_size: 4, global_rank: 3)

Global rank: 2, Processed data indices: [50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74]
Global rank: 0, Processed data indices: [50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74]
Global rank: 3, Processed data indices: [50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74]
Global rank: 1, Processed data indices: [50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74]

Global rank: 3: worker_env=_WorkerEnv(world_size: 1, rank: 0) distributed_env=_DistributedEnv(world_size: 4, global_rank: 3)
Global rank: 0: worker_env=_WorkerEnv(world_size: 1, rank: 0) distributed_env=_DistributedEnv(world_size: 4, global_rank: 0)
Global rank: 1: worker_env=_WorkerEnv(world_size: 1, rank: 0) distributed_env=_DistributedEnv(world_size: 4, global_rank: 1)
Global rank: 2: worker_env=_WorkerEnv(world_size: 1, rank: 0) distributed_env=_DistributedEnv(world_size: 4, global_rank: 2)

Global rank: 0, Processed data indices: [25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49]
Global rank: 3, Processed data indices: [25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49]
Global rank: 1, Processed data indices: [25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49]
Global rank: 2, Processed data indices: [25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49]

Global rank: 1: worker_env=_WorkerEnv(world_size: 1, rank: 0) distributed_env=_DistributedEnv(world_size: 4, global_rank: 1)
Global rank: 0: worker_env=_WorkerEnv(world_size: 1, rank: 0) distributed_env=_DistributedEnv(world_size: 4, global_rank: 0)
Global rank: 2: worker_env=_WorkerEnv(world_size: 1, rank: 0) distributed_env=_DistributedEnv(world_size: 4, global_rank: 2)
Global rank: 3: worker_env=_WorkerEnv(world_size: 1, rank: 0) distributed_env=_DistributedEnv(world_size: 4, global_rank: 3)

Global rank: 3, Processed data indices: [75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99]
Global rank: 0, Processed data indices: [75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99]
Global rank: 2, Processed data indices: [75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99]
Global rank: 1, Processed data indices: [75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99]

Global rank: 1: worker_env=_WorkerEnv(world_size: 1, rank: 0) distributed_env=_DistributedEnv(world_size: 4, global_rank: 1)
Global rank: 0: worker_env=_WorkerEnv(world_size: 1, rank: 0) distributed_env=_DistributedEnv(world_size: 4, global_rank: 0)
Global rank: 2: worker_env=_WorkerEnv(world_size: 1, rank: 0) distributed_env=_DistributedEnv(world_size: 4, global_rank: 2)
Global rank: 3: worker_env=_WorkerEnv(world_size: 1, rank: 0) distributed_env=_DistributedEnv(world_size: 4, global_rank: 3)

Code sample

See the two scripts under "To Reproduce" above.

Expected behavior

Each global rank should receive a disjoint shard of the dataset (here, 25 distinct indices per rank per epoch) rather than all four ranks processing identical batches.

Environment

  • PyTorch Version: torch 2.3.1 (pytorch-lightning 2.3.3, torchmetrics 1.4.0.post0, torchvision 0.18.1, torch-tb-profiler 0.4.3)
  • OS (e.g., Linux): macOS
  • How you installed PyTorch (conda, pip, source): pip
  • Build command you used (if compiling from source): N/A
  • Python version: 3.10.13
  • CUDA/cuDNN version: None
  • GPU models and configuration: None
  • Any other relevant information: litdata installed from GitHub (pip install git+https://github.com/lightning-AI/litdata)

csy1204 commented Jul 22, 2024

I think _DistributedEnv is not picked up correctly inside the dataset's __iter__: there it always reports global_rank 0, so every process streams the same shard.
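
A minimal way to check this (a hypothetical snippet, not from the issue; it assumes StreamingDataset keeps the environment it detected at construction time on a distributed_env attribute, which may differ across litdata versions):

import litdata as ld
from litdata.utilities.env import _DistributedEnv


class DebugStreamingDataset(ld.StreamingDataset):
    def __iter__(self):
        # Compare the env the dataset is holding with a fresh detection at iteration time.
        print(f"dataset env: {self.distributed_env}, detected now: {_DistributedEnv.detect()}")
        return super().__iter__()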

csy1204 changed the title from "Resolve data duplication issues in DDP with Lightning Trainer" to "Resolve same global rank in DDP with Lightning Trainer" on Jul 22, 2024

csy1204 commented Jul 22, 2024

I solved it: constructing the StreamingDataset inside train_dataloader() (instead of at module import time) removes the duplication, presumably because the dataset then detects the distributed environment after the DDP processes have been initialized.

class LDataModule(L.LightningDataModule):
    def __init__(self):
        super().__init__()

    def setup(self, stage: str):
        pass

    def train_dataloader(self):
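        # Build the dataset here so it is created after DDP has been initialized on each rank.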
        dataset = ld.StreamingDataset("./dummy_dataset", drop_last=True, shuffle=True)
        return ld.StreamingDataLoader(
            dataset,
            num_workers=2,
            shuffle=True,
            batch_size=25,
            drop_last=True,
        )
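
A variant that should work for the same reason (a sketch, not from the issue): build the dataset in setup(), which Lightning calls on every rank after the distributed processes are initialized, and reuse it from train_dataloader().

import lightning as L
import litdata as ld


class LDataModule(L.LightningDataModule):
    def setup(self, stage: str):
        # setup() runs on each rank after DDP init, so the dataset detects the correct global rank.
        self.dataset = ld.StreamingDataset("./dummy_dataset", drop_last=True, shuffle=True)

    def train_dataloader(self):
        return ld.StreamingDataLoader(
            self.dataset,
            num_workers=2,
            shuffle=True,
            batch_size=25,
            drop_last=True,
        )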
