Commit

[air] DreamBooth example: Fix code for batch size > 1 (ray-project#34398)

The DreamBooth finetuning example currently throws an error when the batch size is > 1, even when the GPU memory is large enough, because the training batches are not created correctly.

This PR fixes the batch format and includes in-line comments to explain the new behavior.

Signed-off-by: Kai Fricke <[email protected]>
krfricke authored and vitsai committed Apr 17, 2023
1 parent e73c0f3 commit b2c750f
Showing 2 changed files with 16 additions and 2 deletions.
17 changes: 15 additions & 2 deletions python/ray/air/examples/dreambooth/dataset.py
@@ -81,10 +81,23 @@ def collate(batch, device, dtype):
     # of the batch.
     # During training, a batch will be chunked into 2 sub-batches for prior
     # preserving loss calculation.
-    images = torch.squeeze(torch.stack([batch["image"], batch["image_1"]]))
+
+    # batch["image"] = image1, image2
+    # batch["image_1"] = reg1, reg2
+    # After cat, we will have [image1, image2, reg1, reg2]
+
+    images = torch.cat([batch["image"], batch["image_1"]], dim=0)
     images = images.to(memory_format=torch.contiguous_format).float()
 
-    prompt_ids = torch.cat([batch["prompt_ids"], batch["prompt_ids_1"]], dim=0)
+    batch_size = len(batch["prompt_ids"])
+
+    # batch["prompt_ids"] = pr1, pr2
+    # batch["prompt_ids_1"] = rr1, rr2
+    # After stack+reshape, we will have [pr1, rr1, pr2, rr2]
+
+    prompt_ids = torch.stack(
+        [batch["prompt_ids"], batch["prompt_ids_1"]], dim=1
+    ).reshape(batch_size * 2, -1)
 
     return {
         "prompt_ids": prompt_ids.to(device),  # token ids should stay int.
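The shape and ordering effects of the `dataset.py` hunk can be checked in isolation. The following is a minimal sketch, not the example's actual data pipeline: the tiny tensor sizes, the helper names `make_batch`, `old_images`, and `new_images`, and the token id values are assumptions chosen purely for illustration. It suggests why the removed `torch.squeeze(torch.stack(...))` line only yields a 4-D image batch when the batch size is 1, and shows which row order `torch.cat` and `torch.stack` plus `reshape` each produce.

```python
import torch


def make_batch(batch_size):
    # Hypothetical stand-in for a dataset batch: tiny 3x4x4 "images".
    return {
        "image": torch.zeros(batch_size, 3, 4, 4),   # instance images
        "image_1": torch.ones(batch_size, 3, 4, 4),  # regularization images
    }


def old_images(batch):
    # Removed line: stack adds a new leading dim, squeeze drops size-1 dims.
    return torch.squeeze(torch.stack([batch["image"], batch["image_1"]]))


def new_images(batch):
    # Added line: concatenate along the existing batch dimension.
    return torch.cat([batch["image"], batch["image_1"]], dim=0)


old_shapes = {b: tuple(old_images(make_batch(b)).shape) for b in (1, 2)}
new_shapes = {b: tuple(new_images(make_batch(b)).shape) for b in (1, 2)}
print(old_shapes)  # {1: (2, 3, 4, 4), 2: (2, 2, 3, 4, 4)}
print(new_shapes)  # {1: (2, 3, 4, 4), 2: (4, 3, 4, 4)}

# Row order, using made-up token ids so each row is easy to recognize.
instance_ids = torch.tensor([[1, 1], [2, 2]])    # pr1, pr2
reg_ids = torch.tensor([[10, 10], [20, 20]])     # rr1, rr2
batch_size = len(instance_ids)

# cat along dim 0: all instance rows first, then all regularization rows.
cat_rows = torch.cat([instance_ids, reg_ids], dim=0)[:, 0].tolist()

# stack along dim 1, then reshape: instance and regularization rows alternate.
interleaved_rows = (
    torch.stack([instance_ids, reg_ids], dim=1)
    .reshape(batch_size * 2, -1)[:, 0]
    .tolist()
)
print(cat_rows)          # [1, 2, 10, 20]
print(interleaved_rows)  # [1, 10, 2, 20]
```

With a batch size of 2, the old expression keeps a 5-D tensor because no dimension has size 1 to squeeze away, while `torch.cat` always returns a 4-D batch of `2 * batch_size` images. The two row orders also differ, which matters to any code that later splits the batch in half (for example with `torch.chunk`), since the halves contain different samples under each layout.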
1 change: 1 addition & 0 deletions release/release_tests.yaml
@@ -672,6 +672,7 @@
   run:
     timeout: 1800
     script: bash dreambooth_run.sh
+    artifact_path: /tmp/artifacts/example_out.jpg


 - name: air_example_gptj_deepspeed_fine_tuning
