
make_fx fails with jacfwd (when used with torch.add(Tensor, Scalar)) #1078

Open
kshitij12345 opened this issue Dec 7, 2022 · 1 comment


@kshitij12345 (Collaborator)

import torch
import functorch

dtype = torch.float32
device = torch.device('cpu')

def foo(x):
    return x + 1.0

x = torch.tensor([[0.0]], dtype=dtype, device=device)

functorch.make_fx(functorch.vmap(foo))(x)  # Works
functorch.make_fx(functorch.jacrev(foo))(x)  # Works
functorch.make_fx(functorch.jacfwd(foo))(x)  # Fails
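For reference, here is a minimal sketch of what the traced function should compute. It assumes PyTorch 2.0 or later, where these transforms live in `torch.func`; the report itself uses the older `functorch` namespace. In eager mode, with no `make_fx`, the Jacobian of `x + 1.0` is the identity, laid out with shape `output.shape + input.shape`. Because of that, `jacfwd` and `jacrev` should agree exactly.

```python
import torch
from torch.func import jacfwd, jacrev  # assumes PyTorch >= 2.0

def foo(x):
    # Same function as the reproducer: add(Tensor, Scalar)
    return x + 1.0

x = torch.tensor([[0.0]], dtype=torch.float32, device="cpu")

# Eager mode (no make_fx): forward- and reverse-mode Jacobians.
# Shape is output.shape + input.shape == (1, 1) + (1, 1).
jf = jacfwd(foo)(x)
jr = jacrev(foo)(x)

print(jf.shape)             # torch.Size([1, 1, 1, 1])
print(torch.equal(jf, jr))  # True
```

This gives a ground truth to compare against if `make_fx(jacfwd(foo))` is made to trace successfully.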

Error Message:

Traceback (most recent call last):
  File "/home/kshiteej/Pytorch/pytorch_functorch/test/test_scratch.py", line 31, in <module>
    functorch.make_fx(functorch.jacfwd(foo))(x)  # Fails
  File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/fx/experimental/proxy_tensor.py", line 683, in wrapped
    t = dispatch_trace(wrap_key(func, args, fx_tracer), tracer=fx_tracer, concrete_args=tuple(phs))
  File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/fx/experimental/proxy_tensor.py", line 441, in dispatch_trace
    graph = tracer.trace(root, concrete_args)
  File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/fx/_symbolic_trace.py", line 739, in trace
    (self.create_arg(fn(*args)),),
  File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/fx/experimental/proxy_tensor.py", line 457, in wrapped
    out = f(*tensors)
  File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/_functorch/eager_transforms.py", line 996, in wrapper_fn
    results = vmap(push_jvp, randomness=randomness)(basis)
  File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/_functorch/vmap.py", line 362, in wrapped
    return _flat_vmap(
  File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/_functorch/vmap.py", line 35, in fn
    return f(*args, **kwargs)
  File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/_functorch/vmap.py", line 489, in _flat_vmap
    batched_outputs = func(*batched_inputs, **kwargs)
  File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/_functorch/eager_transforms.py", line 989, in push_jvp
    output = _jvp_with_argnums(func, args, basis, argnums=argnums, has_aux=has_aux)
  File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/_functorch/vmap.py", line 35, in fn
    return f(*args, **kwargs)
  File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/_functorch/eager_transforms.py", line 837, in _jvp_with_argnums
    result_duals = func(*duals)
  File "/home/kshiteej/Pytorch/pytorch_functorch/test/test_scratch.py", line 26, in foo
    return x + 1.0
  File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/fx/experimental/proxy_tensor.py", line 483, in __torch_dispatch__
    return self.inner_torch_dispatch(func, types, args, kwargs)
  File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/fx/experimental/proxy_tensor.py", line 508, in inner_torch_dispatch
    out = proxy_call(self, func, args, kwargs)
  File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/fx/experimental/proxy_tensor.py", line 259, in proxy_call
    r = func.decompose(*args, **kwargs)
  File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/_ops.py", line 307, in decompose
    return self._op_dk(dk, *args, **kwargs)
  File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/fx/experimental/proxy_tensor.py", line 483, in __torch_dispatch__
    return self.inner_torch_dispatch(func, types, args, kwargs)
  File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/fx/experimental/proxy_tensor.py", line 508, in inner_torch_dispatch
    out = proxy_call(self, func, args, kwargs)
  File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/fx/experimental/proxy_tensor.py", line 393, in proxy_call
    track_tensor_tree(out, proxy_out, constant=constant, tracer=tracer)
  File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/fx/experimental/proxy_tensor.py", line 206, in track_tensor_tree
    wrap_with_proxy(inner_res, proxy_res, constant)
  File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/fx/experimental/proxy_tensor.py", line 185, in wrap_with_proxy
    set_meta(proxy, e)
  File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/fx/experimental/proxy_tensor.py", line 149, in set_meta
    proxy.node.meta['val'] = torch.empty_strided(val.shape, val.stride(), device=val.device, dtype=val.dtype)
  File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/_subclasses/fake_tensor.py", line 878, in __torch_dispatch__
    op_impl_out = op_impl(self, func, *args, **kwargs)
  File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/_subclasses/fake_tensor.py", line 325, in constructors
    return FakeTensor(fake_mode, r, out_device)
  File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/_subclasses/fake_tensor.py", line 560, in __init__
    assert device.type != "meta"
AssertionError
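To check whether a given PyTorch build still shows this behavior, you can use a small triage harness like the sketch below. It traces each transform of `foo` with `make_fx` and records any exception instead of stopping at the first failure. It assumes PyTorch 2.0 or later (`torch.func` and `torch.fx.experimental.proxy_tensor.make_fx`), and `try_trace` is a hypothetical helper name. The outcome depends on the installed version, since the traceback above comes from a December 2022 build.

```python
import torch
from torch.func import jacfwd, jacrev, vmap  # assumes PyTorch >= 2.0
from torch.fx.experimental.proxy_tensor import make_fx

def foo(x):
    return x + 1.0  # add(Tensor, Scalar), as in the report

x = torch.tensor([[0.0]], dtype=torch.float32)

def try_trace(transform):
    """Hypothetical helper: trace transform(foo) with make_fx.

    Returns None on success, otherwise the raised exception
    (for example the AssertionError shown in the traceback).
    """
    try:
        make_fx(transform(foo))(x)
        return None
    except Exception as exc:
        return exc

results = {
    name: try_trace(t)
    for name, t in [("vmap", vmap), ("jacrev", jacrev), ("jacfwd", jacfwd)]
}
for name, err in results.items():
    print(f"{name}: {'ok' if err is None else type(err).__name__}")
```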

@zou3519 (Contributor)

zou3519 commented Dec 7, 2022

Could be related to pytorch/pytorch#90065
