diff --git a/nemo/collections/diffusion/models/dit/dit_model.py b/nemo/collections/diffusion/models/dit/dit_model.py index 4393eca71983..d653cdfb3a48 100644 --- a/nemo/collections/diffusion/models/dit/dit_model.py +++ b/nemo/collections/diffusion/models/dit/dit_model.py @@ -224,7 +224,7 @@ def forward( * B, dtype=torch.bfloat16, ), - ).view(-1) + ).view(-1).cuda() if self.pre_process: # transpose to match x_B_S_D = self.x_embedder(x)