Add FP32 fallback support on sd_vae_approx
This retries `interpolate` in FP32 if the initial call fails.

Background: in some environments, such as macOS devices with Mx chips, the following error occurs:

```
"torch/nn/functional.py", line 3931, in interpolate
        return torch._C._nn.upsample_nearest2d(input, output_size, scale_factors)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    RuntimeError: "upsample_nearest2d_channels_last" not implemented for 'Half'
```

In this case, ```--no-half``` does not resolve the error, so this commit adds an FP32 fallback to work around it.

Note that the submodules may require additional modifications. The following is an example modification for another submodule.

```repositories/stable-diffusion-stability-ai/ldm/modules/diffusionmodules/openaimodel.py

class Upsample(nn.Module):
..snip..
    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.dims == 3:
            x = F.interpolate(
                x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
            )
        else:
            try:
                x = F.interpolate(x, scale_factor=2, mode="nearest")
            except:
                x = F.interpolate(x.to(th.float32), scale_factor=2, mode="nearest").to(x.dtype)
        if self.use_conv:
            x = self.conv(x)
        return x
..snip..
```

This applies the same FP32 fallback as in sd_vae_approx.py.
hidenorly committed Nov 20, 2023
1 parent 5f36f6a commit 58c1954
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion modules/sd_vae_approx.py
@@ -21,7 +21,13 @@ def __init__(self):

     def forward(self, x):
         extra = 11
-        x = nn.functional.interpolate(x, (x.shape[2] * 2, x.shape[3] * 2))
+        try:
+            x = nn.functional.interpolate(x, (x.shape[2] * 2, x.shape[3] * 2))
+        except RuntimeError as e:
+            if "not implemented for" in str(e) and "Half" in str(e):
+                x = nn.functional.interpolate(x.to(torch.float32), (x.shape[2] * 2, x.shape[3] * 2)).to(x.dtype)
+            else:
+                print(f"An unexpected RuntimeError occurred: {str(e)}")
         x = nn.functional.pad(x, (extra, extra, extra, extra))

         for layer in [self.conv1, self.conv2, self.conv3, self.conv4, self.conv5, self.conv6, self.conv7, self.conv8, ]:
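
The submodule example and the change to sd_vae_approx.py follow the same pattern: try the operation, and retry in FP32 if the Half kernel is missing. In the change to sd_vae_approx.py, the `else` branch only prints the error, so `x` would reach `pad` without being upsampled. A minimal sketch of one possible shared helper that re-raises unrelated errors instead is shown below. The names `is_half_unsupported` and `call_with_fp32_fallback` are hypothetical and not part of this commit, and the helper is written against any tensor-like object with `.dtype` and `.to(dtype)` so that it does not depend on a particular backend.

```python
def is_half_unsupported(err: RuntimeError) -> bool:
    """Return True for errors such as: "..." not implemented for 'Half'."""
    message = str(err)
    return "not implemented for" in message and "Half" in message


def call_with_fp32_fallback(op, x, fp32_dtype):
    """Run op(x); retry in FP32 only if the failure is a missing Half kernel.

    Hypothetical helper (an assumption, not code from this commit).
    `op` is any callable taking a tensor, `x` is a tensor-like object with
    `.dtype` and `.to(dtype)`, and `fp32_dtype` would be torch.float32.
    """
    try:
        return op(x)
    except RuntimeError as err:
        if not is_half_unsupported(err):
            raise  # unrelated errors propagate instead of being printed and ignored
        # Upcast, run the op, then cast back to the caller's original dtype.
        return op(x.to(fp32_dtype)).to(x.dtype)
```

Under these assumptions, `forward` could call it as `x = call_with_fp32_fallback(lambda t: nn.functional.interpolate(t, (t.shape[2] * 2, t.shape[3] * 2)), x, torch.float32)`. Whether re-raising is preferable to printing depends on how callers are expected to handle a failed preview.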