Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid using legacy constructors Tensor.new_*() #1842

Merged
merged 7 commits into from
Apr 30, 2019
Merged

Conversation

fritzo
Copy link
Member

@fritzo fritzo commented Apr 30, 2019

The following constructors are deprecated and are unsupported by the jit:

  • Tensor.new_tensor()
  • Tensor.new_empty()
  • Tensor.new_zeros()
  • Tensor.new_ones()
  • Tensor.new_full()

Updated

  • pyro core
  • most of pyro.contrib
  • examples/
  • tests/infer/test_jit.py

Triaged (not updated)

  • pyro.contrib.gp (I'll let @fehiepsi handle pyro.contrib.gp)
  • tutorials
  • tests

Tested

  • refactoring is exercised by existing tests
  • ran CUDA_TEST=1 pytest -vx tests/test_examples.py::test_cuda against torch==1.0.1

@@ -34,7 +34,7 @@ def __init__(self, v, log_density=0.0, event_dim=0, validate_args=None):
batch_shape = v.shape[:batch_dim]
event_shape = v.shape[batch_dim:]
if isinstance(log_density, numbers.Number):
log_density = v.new_empty(batch_shape).fill_(log_density)
log_density = torch.full(batch_shape, log_density, dtype=v.dtype, device=v.device)
Copy link
Member

@neerajprad neerajprad Apr 30, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not relevant to the cases that I see here, but I just discovered that torch.full doesn't generalize well in JIT when the fill value is scalar. Using empty_like with fill_ seems to work fine in those cases.

e.g.

>>> a = torch.tensor([0, 0])

>>> def fn1(x):
        return torch.full(a.size(), x, dtype=a.dtype, device=a.device)

>>> torch.jit.trace(fn1, torch.tensor(1.))(torch.tensor(2.))
tensor([1, 1])

>>> def fn2(x):
        return torch.empty_like(a).fill_(x)

>>> torch.jit.trace(fn1, torch.tensor(1.))(torch.tensor(2.))
tensor([2, 2])

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting, it would be good to have this example in tests/infer/test_jit.py

Copy link
Member

@neerajprad neerajprad left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Many thanks for updating the codebase! Looks great.

@neerajprad
Copy link
Member

Will merge when tests pass.

@fehiepsi
Copy link
Member

(I'll let @fehiepsi handle pyro.contrib.gp)

Yes yes, I'll do it. :D

@jpchen jpchen deleted the legacy-constructors branch June 9, 2019 23:39
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants