Merge pull request #14562 from Nuullll/fix-ipex-xpu-generator
[IPEX] Fix xpu generator
AUTOMATIC1111 authored Jan 7, 2024
2 parents b00b429 + 818d6a1 commit 71e0057
Showing 1 changed file with 16 additions and 4 deletions.
20 changes: 16 additions & 4 deletions modules/xpu_specific.py
@@ -94,11 +94,23 @@ def torch_xpu_scaled_dot_product_attention(
     return torch.reshape(result, (*N, L, Ev))
 
 
+def is_xpu_device(device: str | torch.device = None):
+    if device is None:
+        return False
+    if isinstance(device, str):
+        return device.startswith("xpu")
+    return device.type == "xpu"
+
+
 if has_xpu:
-    # W/A for https://github.com/intel/intel-extension-for-pytorch/issues/452: torch.Generator API doesn't support XPU device
-    CondFunc('torch.Generator',
-        lambda orig_func, device=None: torch.xpu.Generator(device),
-        lambda orig_func, device=None: device is not None and device.type == "xpu")
+    try:
+        # torch.Generator supports "xpu" device since 2.1
+        torch.Generator("xpu")
+    except RuntimeError:
+        # W/A for https://github.com/intel/intel-extension-for-pytorch/issues/452: torch.Generator API doesn't support XPU device (for torch < 2.1)
+        CondFunc('torch.Generator',
+            lambda orig_func, device=None: torch.xpu.Generator(device),
+            lambda orig_func, device=None: is_xpu_device(device))
 
     # W/A for some OPs that could not handle different input dtypes
     CondFunc('torch.nn.functional.layer_norm',
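The key change is a runtime feature probe: instead of always monkey-patching torch.Generator through CondFunc, the module attempts torch.Generator("xpu") once at import time and installs the workaround only when that raises RuntimeError (torch < 2.1). The new is_xpu_device helper also accepts plain device strings such as "xpu" or "xpu:0", which the old predicate's device.type check could not handle. Below is a minimal sketch of the same probe-and-fallback pattern, assuming an XPU-capable PyTorch/IPEX build; make_xpu_generator is a hypothetical name for illustration, not part of the repository:

import torch

def make_xpu_generator(device="xpu"):
    # Sketch of the probe used in this commit: try the native API first,
    # fall back to the IPEX generator on older torch builds.
    try:
        # torch >= 2.1 accepts the "xpu" device string directly
        return torch.Generator(device)
    except RuntimeError:
        # torch < 2.1: use the generator shipped with intel-extension-for-pytorch
        return torch.xpu.Generator(device)

Probing once at module load, as the patched code does, means the CondFunc wrapper is only installed on builds that actually need it; newer torch versions pay no per-call overhead.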
