Skip to content

Commit

Permalink
zluda hijack rfftn
Browse files Browse the repository at this point in the history
  • Loading branch information
lshqqytiger committed Nov 14, 2024
1 parent c4f9032 commit b59a21f
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions modules/zluda_hijacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@ def fft_ifftn(input: torch.Tensor, *args, **kwargs) -> torch.Tensor: # pylint: d
return _fft_ifftn(input.cpu(), *args, **kwargs).to(input.device)


_fft_rfftn = torch.fft.rfftn
def fft_rfftn(input: torch.Tensor, *args, **kwargs) -> torch.Tensor: # pylint: disable=redefined-builtin
return _fft_rfftn(input.cpu(), *args, **kwargs).to(input.device)


def jit_script(f, *_, **__): # experiment / provide dummy graph
f.graph = torch._C.Graph() # pylint: disable=protected-access
return f
Expand All @@ -29,4 +34,5 @@ def do_hijack():
torch.topk = topk
torch.fft.fftn = fft_fftn
torch.fft.ifftn = fft_ifftn
torch.fft.rfftn = fft_rfftn
torch.jit.script = jit_script

0 comments on commit b59a21f

Please sign in to comment.