
[IPEX] Support SDE samplers #1579

Merged · 1 commit into vladmandic:master on Jul 5, 2023
Conversation

@Nuullll (Contributor) commented Jul 5, 2023

Description

This is a workaround (W/A): the torch.Generator() API doesn't support the xpu backend at the moment, so it is replaced with the torch.xpu.Generator() API provided by IPEX.
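A minimal sketch of the idea (the make_generator helper below is illustrative, not code from this PR): select the generator class by device type, since torch.Generator() rejects xpu devices while IPEX supplies torch.xpu.Generator() for that backend.

```python
import torch

def make_generator(device):
    # Illustrative helper (not from this PR): torch.Generator() raises
    # "Device type XPU is not supported" for xpu devices, so fall back
    # to the IPEX-provided torch.xpu.Generator() in that case.
    device = torch.device(device)
    if device.type == "xpu":
        return torch.xpu.Generator(device)
    return torch.Generator(device)

# Seeded generators keep sampling reproducible regardless of backend.
g = make_generator("cpu").manual_seed(1234)
noise = torch.randn(4, generator=g)
```

On a CPU or CUDA build this takes the plain torch.Generator() path; only under IPEX with an xpu device does the fallback branch run.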

Notes

Original error message for IPEX when the DPM++ 2M SDE or DPM++ 2M SDE Karras sampler is used:

08:08:02-165066 ERROR    Please apply this patch to repositories/k-diffusion/k_diffusion/sampling.py:
                         https://github.com/crowsonkb/k-diffusion/pull/68/files
08:08:02-241772 ERROR    Exception: Device type XPU is not supported for torch.Generator() api.
08:08:02-243317 ERROR    Arguments: args=('task(vf0n135fwtjw20x)', 'iron man', '', [], 20, 1, False, False, 1, 1, 6, 1,
                         -1.0, -1.0, 0, 0, 0, False, 512, 512, False, 0.7, 2, 'Latent', 0, 0, 0, [], 0, False,
                         'MultiDiffusion', False, True, 1024, 1024, 96, 96, 48, 4, 'None', 2, False, 10, 1, 1, 64,
                         False, False, False, False, False, 0.4, 0.4, 0.2, 0.2, '', '', 'Background', 0.2, -1.0, False,
                         0.4, 0.4, 0.2, 0.2, '', '', 'Background', 0.2, -1.0, False, 0.4, 0.4, 0.2, 0.2, '', '',
                         'Background', 0.2, -1.0, False, 0.4, 0.4, 0.2, 0.2, '', '', 'Background', 0.2, -1.0, False,
                         0.4, 0.4, 0.2, 0.2, '', '', 'Background', 0.2, -1.0, False, 0.4, 0.4, 0.2, 0.2, '', '',
                         'Background', 0.2, -1.0, False, 0.4, 0.4, 0.2, 0.2, '', '', 'Background', 0.2, -1.0, False,
                         0.4, 0.4, 0.2, 0.2, '', '', 'Background', 0.2, -1.0, False, 512, 64, True, True, True, False,
                         False, 7, 100, 'Constant', 0, 'Constant', 0, 4,
                         <scripts.controlnet_ui.controlnet_ui_group.UiControlNetUnit object at 0x7fe29c1aedd0>,
                         <scripts.controlnet_ui.controlnet_ui_group.UiControlNetUnit object at 0x7fe29c1ae920>,
                         <scripts.controlnet_ui.controlnet_ui_group.UiControlNetUnit object at 0x7fe3d3c02350>, False,
                         False, 'positive', 'comma', 0, False, False, '', 0, '', [], 0, '', [], 0, '', [], True, False,
                         False, False, 0, False, None, None, False, None, None, False, None, None, False, 50) kwargs={}
08:08:02-247736 ERROR    gradio call: RuntimeError
╭───────────────────────────────────────── Traceback (most recent call last) ──────────────────────────────────────────╮
│ /sd-webui/modules/call_queue.py:34 in f                                                                              │
│                                                                                                                      │
│    33 │   │   │   try:                                                                                               │
│ ❱  34 │   │   │   │   res = func(*args, **kwargs)                                                                    │
│    35 │   │   │   │   progress.record_results(id_task, res)                                                          │
│                                                                                                                      │
│ /sd-webui/modules/txt2img.py:56 in txt2img                                                                           │
│                                                                                                                      │
│   55 │   if processed is None:                                                                                       │
│ ❱ 56 │   │   processed = processing.process_images(p)                                                                │
│   57 │   p.close()                                                                                                   │
│                                                                                                                      │
│                                               ... 14 frames hidden ...                                               │
│                                                                                                                      │
│ /deps/venv/lib/python3.10/site-packages/torchsde/_brownian/brownian_interval.py:234 in _randn                        │
│                                                                                                                      │
│   233 │   │   size = self._top._size                                                                                 │
│ ❱ 234 │   │   return _randn(size, self._top._dtype, self._top._device, seed)                                         │
│   235                                                                                                                │
│                                                                                                                      │
│ /deps/venv/lib/python3.10/site-packages/torchsde/_brownian/brownian_interval.py:32 in _randn                         │
│                                                                                                                      │
│    31 def _randn(size, dtype, device, seed):                                                                         │
│ ❱  32 │   generator = torch.Generator(device).manual_seed(int(seed))                                                 │
│    33 │   return torch.randn(size, dtype=dtype, device=device, generator=generator)                                  │
╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
RuntimeError: Device type XPU is not supported for torch.Generator() api.

After this fix, the SDE samplers can be used with IPEX without manually applying the local k_diffusion patch. Performance appears to be on par with the k_diffusion patch approach: crowsonkb/k-diffusion#68

Environment and Testing

OS: Windows 11
Docker: nuullll/ipex-arc-sd:v0.2 (oneAPI 2023.1)

arch: x86_64
cpu: x86_64
system: Linux
release: 5.15.90.1-microsoft-standard-WSL2
python: 3.10.6
torch: 1.13.0a0+gitb1dde16 Autocast  half
device: Intel(R) Graphics [0x56a0] (1)  # Arc A770
ipex: 1.13.120+xpu
  • DPM++ 2M SDE: ~5.9 it/s
  • DPM++ 2M SDE Karras: ~5.9 it/s

This is a workaround since the `torch.Generator()` API doesn't support
the `xpu` backend at the moment, so it is replaced with the
`torch.xpu.Generator()` API provided by IPEX.
@Nuullll (Contributor, Author) commented Jul 5, 2023

Adding @Disty0 who might be interested in this :-)

@Disty0 Disty0 merged commit 860bf8e into vladmandic:master Jul 5, 2023
@Disty0 (Collaborator) commented Jul 5, 2023

Nice! Fewer steps for end users.

Also removed the `.vscode` line from the `.gitignore` file, since it already contains `.vscode/`.
