Avoid using legacy constructors Tensor.new_*() #1842

Merged: 7 commits, Apr 30, 2019
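The pattern applied throughout this PR is the same in every file: legacy `Tensor.new_*()` constructors, which infer `dtype` and `device` from the tensor they are called on, are replaced by the corresponding `torch.*` factory functions with `dtype` and `device` passed explicitly, often collected once into an `options` dict. A minimal sketch of the idea (illustrative code, not taken from the diff itself):

```python
import torch

prototype = torch.tensor(0.)  # any tensor whose dtype/device should be matched

# Legacy constructor: dtype/device are inherited implicitly from `prototype`.
old_style = prototype.new_zeros((3, 5))

# Preferred: pass dtype/device explicitly to the torch.* factory function.
options = dict(dtype=prototype.dtype, device=prototype.device)
new_style = torch.zeros(3, 5, **options)

assert old_style.dtype == new_style.dtype and old_style.device == new_style.device
```

Collecting the keyword arguments into a dict keeps call sites short when many tensors share the same dtype and device, which is the approach taken in air.py, bayesian_regression.py, lkj.py, and ss_vae_M2.py below.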
15 changes: 8 additions & 7 deletions examples/air/air.py
@@ -63,7 +63,8 @@ def __init__(self,
self.baseline_scalar = baseline_scalar
self.likelihood_sd = likelihood_sd
self.use_cuda = use_cuda
- self.prototype = torch.tensor(0.).cuda() if use_cuda else torch.tensor(0.)
+ prototype = torch.tensor(0.).cuda() if use_cuda else torch.tensor(0.)
+ self.options = dict(dtype=prototype.dtype, device=prototype.device)

self.z_pres_size = 1
self.z_where_size = 3
@@ -107,8 +108,8 @@ def __init__(self,
def prior(self, n, **kwargs):

state = ModelState(
- x=self.prototype.new_zeros([n, self.x_size, self.x_size]),
- z_pres=self.prototype.new_ones([n, self.z_pres_size]),
+ x=torch.zeros(n, self.x_size, self.x_size, **self.options),
+ z_pres=torch.ones(n, self.z_pres_size, **self.options),
z_where=None)

z_pres = []
@@ -143,8 +144,8 @@ def prior_step(self, t, n, prev, z_pres_prior_p=default_z_pres_prior_p):

# Sample latent code for contents of the attention window.
z_what = pyro.sample('z_what_{}'.format(t),
- dist.Normal(self.prototype.new_zeros([n, self.z_what_size]),
- self.prototype.new_ones([n, self.z_what_size]))
+ dist.Normal(torch.zeros(n, self.z_what_size, **self.options),
+ torch.ones(n, self.z_what_size, **self.options))
.mask(sample_mask)
.to_event(1))

@@ -169,7 +170,7 @@ def model(self, data, batch_size, **kwargs):
(z_where, z_pres), x = self.prior(n, **kwargs)
pyro.sample('obs',
dist.Normal(x.view(n, -1),
- (self.likelihood_sd * self.prototype.new_ones(n, self.x_size ** 2)))
+ (self.likelihood_sd * torch.ones(n, self.x_size ** 2, **self.options)))
.to_event(1),
obs=batch.view(n, -1))

@@ -207,7 +208,7 @@ def guide(self, data, batch_size, **kwargs):
c=batch_expand(self.c_init, n),
bl_h=batch_expand(self.bl_h_init, n),
bl_c=batch_expand(self.bl_c_init, n),
- z_pres=self.prototype.new_ones(n, self.z_pres_size),
+ z_pres=torch.ones(n, self.z_pres_size, **self.options),
z_where=batch_expand(self.z_where_init, n),
z_what=batch_expand(self.z_what_init, n))

14 changes: 7 additions & 7 deletions examples/baseball.py
@@ -9,7 +9,7 @@

import pyro
from pyro.distributions import Beta, Binomial, HalfCauchy, Normal, Pareto, Uniform
- from pyro.distributions.util import logsumexp
+ from pyro.distributions.util import logsumexp, scalar_like
from pyro.infer.abstract_infer import TracePredictive
from pyro.infer.mcmc import MCMC, NUTS

@@ -67,7 +67,7 @@ def fully_pooled(at_bats, hits):
:param (torch.Tensor) hits: Number of hits for the given at bats.
:return: Number of hits predicted by the model.
"""
- phi_prior = Uniform(at_bats.new_tensor(0), at_bats.new_tensor(1))
+ phi_prior = Uniform(scalar_like(at_bats, 0), scalar_like(at_bats, 1))
phi = pyro.sample("phi", phi_prior)
return pyro.sample("obs", Binomial(at_bats, phi), obs=hits)

@@ -83,7 +83,7 @@ def not_pooled(at_bats, hits):
"""
num_players = at_bats.shape[0]
with pyro.plate("num_players", num_players):
- phi_prior = Uniform(at_bats.new_tensor(0), at_bats.new_tensor(1))
+ phi_prior = Uniform(scalar_like(at_bats, 0), scalar_like(at_bats, 1))
phi = pyro.sample("phi", phi_prior)
return pyro.sample("obs", Binomial(at_bats, phi), obs=hits)

@@ -101,8 +101,8 @@ def partially_pooled(at_bats, hits):
:return: Number of hits predicted by the model.
"""
num_players = at_bats.shape[0]
- m = pyro.sample("m", Uniform(at_bats.new_tensor(0), at_bats.new_tensor(1)))
- kappa = pyro.sample("kappa", Pareto(at_bats.new_tensor(1), at_bats.new_tensor(1.5)))
+ m = pyro.sample("m", Uniform(scalar_like(at_bats, 0), scalar_like(at_bats, 1)))
+ kappa = pyro.sample("kappa", Pareto(scalar_like(at_bats, 1), scalar_like(at_bats, 1.5)))
with pyro.plate("num_players", num_players):
phi_prior = Beta(m * kappa, (1 - m) * kappa)
phi = pyro.sample("phi", phi_prior)
@@ -120,8 +120,8 @@ def partially_pooled_with_logit(at_bats, hits):
:return: Number of hits predicted by the model.
"""
num_players = at_bats.shape[0]
- loc = pyro.sample("loc", Normal(at_bats.new_tensor(-1), at_bats.new_tensor(1)))
- scale = pyro.sample("scale", HalfCauchy(scale=at_bats.new_tensor(1)))
+ loc = pyro.sample("loc", Normal(scalar_like(at_bats, -1), scalar_like(at_bats, 1)))
+ scale = pyro.sample("scale", HalfCauchy(scale=scalar_like(at_bats, 1)))
with pyro.plate("num_players", num_players):
alpha = pyro.sample("alpha", Normal(loc, scale))
return pyro.sample("obs", Binomial(at_bats, logits=alpha), obs=hits)
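`scalar_like`, imported at the top of this file, is a small helper in pyro.distributions.util; its definition is not part of this diff. As a rough sketch of what the call sites above assume (a hypothetical stand-in, not the actual source):

```python
import torch

def scalar_like(prototype, fill_value):
    # Hypothetical stand-in: a 0-dim tensor carrying the prototype's dtype and device.
    return torch.tensor(fill_value, dtype=prototype.dtype, device=prototype.device)

at_bats = torch.tensor([4., 5., 3.])
low, high = scalar_like(at_bats, 0), scalar_like(at_bats, 1)  # bounds for the Uniform prior
```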
9 changes: 5 additions & 4 deletions examples/bayesian_regression.py
@@ -53,10 +53,11 @@ def forward(self, x):

def model(data):
# Create unit normal priors over the parameters
- loc = data.new_zeros(torch.Size((1, p)))
- scale = 2 * data.new_ones(torch.Size((1, p)))
- bias_loc = data.new_zeros(torch.Size((1,)))
- bias_scale = 2 * data.new_ones(torch.Size((1,)))
+ options = dict(dtype=data.dtype, device=data.device)
+ loc = torch.zeros(1, p, **options)
+ scale = 2 * torch.ones(1, p, **options)
+ bias_loc = torch.zeros(1, **options)
+ bias_scale = 2 * torch.ones(1, **options)
w_prior = Normal(loc, scale).to_event(1)
b_prior = Normal(bias_loc, bias_scale).to_event(1)
priors = {'linear.weight': w_prior, 'linear.bias': b_prior}
7 changes: 4 additions & 3 deletions examples/contrib/oed/gp_bayes_opt.py
@@ -86,8 +86,8 @@ def opt_differentiable(self, differentiable, num_candidates=5):
candidates = []
values = []
for j in range(num_candidates):
- x_init = self.gpmodel.X.new_empty(1).uniform_(
- self.constraints.lower_bound, self.constraints.upper_bound)
+ x_init = (torch.empty(1, dtype=self.gpmodel.X.dtype, device=self.gpmodel.X.device)
+ .uniform_(self.constraints.lower_bound, self.constraints.upper_bound))
x, y = self.find_a_candidate(differentiable, x_init)
if torch.isnan(y):
continue
@@ -109,7 +109,8 @@ def acquire_thompson(self, num_acquisitions=1, **opt_params):
"""

# Initialize the return tensor
- X = self.gpmodel.X.new_empty(num_acquisitions, *self.gpmodel.X.shape[1:])
+ X = self.gpmodel.X
+ X = torch.empty(num_acquisitions, *X.shape[1:], dtype=X.dtype, device=X.device)

for i in range(num_acquisitions):
sampler = self.gpmodel.iter_sample(noiseless=False)
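For in-place initializers the change is the same: allocate with `torch.empty`, forwarding the reference tensor's dtype and device, then call the in-place method. A small self-contained sketch of the uniform-initialization idiom above, with made-up bounds:

```python
import torch

X = torch.randn(10, 2)   # stands in for self.gpmodel.X
lower, upper = 0.0, 1.0  # stands in for the box constraints

x_init = (torch.empty(1, dtype=X.dtype, device=X.device)
          .uniform_(lower, upper))  # random starting point for the inner optimization
```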
2 changes: 1 addition & 1 deletion examples/dmm/polyphonic_data_loader.py
@@ -101,7 +101,7 @@ def load_data(dataset):
# this function takes a torch mini-batch and reverses each sequence
# (w.r.t. the temporal axis, i.e. axis=1).
def reverse_sequences(mini_batch, seq_lengths):
- reversed_mini_batch = mini_batch.new_zeros(mini_batch.size())
+ reversed_mini_batch = torch.zeros_like(mini_batch)
for b in range(mini_batch.size(0)):
T = seq_lengths[b]
time_slice = torch.arange(T - 1, -1, -1, device=mini_batch.device)
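When the new tensor should match the reference in shape as well as dtype and device, the `torch.*_like` factories are the shortest replacement, as in `reverse_sequences` above. A tiny sketch with an illustrative shape:

```python
import torch

mini_batch = torch.randn(8, 20, 88)                 # (batch, time, features)
reversed_mini_batch = torch.zeros_like(mini_batch)  # same shape, dtype and device as mini_batch
```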
3 changes: 2 additions & 1 deletion examples/hmm.py
@@ -388,7 +388,8 @@ def forward(self, x, y):
# a bernoulli variable y. Whereas x will typically be enumerated, y will be observed.
# We apply x_to_hidden independently from y_to_hidden, then broadcast the non-enumerated
# y part up to the enumerated x part in the + operation.
- x_onehot = y.new_zeros(x.shape[:-1] + (self.args.hidden_dim,)).scatter_(-1, x, 1)
+ x_onehot = (torch.zeros(x.shape[:-1] + (self.args.hidden_dim,), dtype=y.dtype, device=y.device)
+ .scatter_(-1, x, 1))
y_conv = self.relu(self.conv(y.unsqueeze(-2))).reshape(y.shape[:-1] + (-1,))
h = self.relu(self.x_to_hidden(x_onehot) + self.y_to_hidden(y_conv))
return self.hidden_to_logits(h)
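The hmm.py change keeps the scatter-based one-hot encoding but builds the target tensor with an explicit dtype and device. A standalone sketch of the same pattern with illustrative shapes:

```python
import torch

hidden_dim = 4
x = torch.tensor([[1], [3], [0]])   # integer codes with a trailing singleton dimension
y = torch.randn(3, 10)              # used only as a dtype/device reference

x_onehot = (torch.zeros(x.shape[:-1] + (hidden_dim,), dtype=y.dtype, device=y.device)
            .scatter_(-1, x, 1))
# each row of x_onehot has a single 1 at the position given by x
```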
7 changes: 4 additions & 3 deletions examples/lkj.py
@@ -19,17 +19,18 @@
def model(y):
d = y.shape[1]
N = y.shape[0]
+ options = dict(dtype=y.dtype, device=y.device)
# Vector of variances for each of the d variables
- theta = pyro.sample("theta", dist.HalfCauchy(y.new_ones(d)))
+ theta = pyro.sample("theta", dist.HalfCauchy(torch.ones(d, **options)))
# Lower cholesky factor of a correlation matrix
- eta = y.new_ones(1)  # Implies a uniform distribution over correlation matrices
+ eta = torch.ones(1, **options)  # Implies a uniform distribution over correlation matrices
L_omega = pyro.sample("L_omega", dist.LKJCorrCholesky(d, eta))
# Lower cholesky factor of the covariance matrix
L_Omega = torch.mm(torch.diag(theta.sqrt()), L_omega)
# For inference with SVI, one might prefer to use torch.bmm(theta.sqrt().diag_embed(), L_omega)

# Vector of expectations
- mu = y.new_zeros(d)
+ mu = torch.zeros(d, **options)

with pyro.plate("observations", N):
obs = pyro.sample("obs", dist.MultivariateNormal(mu, scale_tril=L_Omega), obs=y)
10 changes: 5 additions & 5 deletions examples/vae/ss_vae_M2.py
@@ -102,16 +102,17 @@ def model(self, xs, ys=None):
pyro.module("ss_vae", self)

batch_size = xs.size(0)
+ options = dict(dtype=xs.dtype, device=xs.device)
with pyro.plate("data"):

# sample the handwriting style from the constant prior distribution
- prior_loc = xs.new_zeros([batch_size, self.z_dim])
- prior_scale = xs.new_ones([batch_size, self.z_dim])
+ prior_loc = torch.zeros(batch_size, self.z_dim, **options)
+ prior_scale = torch.ones(batch_size, self.z_dim, **options)
zs = pyro.sample("z", dist.Normal(prior_loc, prior_scale).to_event(1))

# if the label y (which digit to write) is supervised, sample from the
# constant prior, otherwise, observe the value (i.e. score it against the constant prior)
- alpha_prior = xs.new_ones([batch_size, self.output_size]) / (1.0 * self.output_size)
+ alpha_prior = torch.ones(batch_size, self.output_size, **options) / (1.0 * self.output_size)
ys = pyro.sample("y", dist.OneHotCategorical(alpha_prior), obs=ys)

# finally, score the image (x) using the handwriting style (z) and
@@ -167,8 +168,7 @@ def classifier(self, xs):
res, ind = torch.topk(alpha, 1)

# convert the digit(s) to one-hot tensor(s)
- ys = xs.new_zeros(alpha.size())
- ys = ys.scatter_(1, ind, 1.0)
+ ys = torch.zeros_like(alpha).scatter_(1, ind, 1.0)
return ys

def model_classify(self, xs, ys=None):
4 changes: 2 additions & 2 deletions examples/vae/vae.py
@@ -83,8 +83,8 @@ def model(self, x):
pyro.module("decoder", self.decoder)
with pyro.plate("data", x.shape[0]):
# setup hyperparameters for prior p(z)
- z_loc = x.new_zeros(torch.Size((x.shape[0], self.z_dim)))
- z_scale = x.new_ones(torch.Size((x.shape[0], self.z_dim)))
+ z_loc = torch.zeros(x.shape[0], self.z_dim, dtype=x.dtype, device=x.device)
+ z_scale = torch.ones(x.shape[0], self.z_dim, dtype=x.dtype, device=x.device)
# sample from prior (value will be sampled by guide when computing the ELBO)
z = pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))
# decode the latent code z
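Because the prior hyperparameters now follow `x` rather than a stored prototype, the same model code runs unchanged whether the mini-batch lives on CPU or GPU. A tiny illustration of that property (hypothetical helper, not part of the diff):

```python
import torch

def prior_params(x, z_dim):
    # Hypothetical helper mirroring the model() change above.
    z_loc = torch.zeros(x.shape[0], z_dim, dtype=x.dtype, device=x.device)
    z_scale = torch.ones(x.shape[0], z_dim, dtype=x.dtype, device=x.device)
    return z_loc, z_scale

loc, scale = prior_params(torch.rand(16, 784), z_dim=50)  # placed wherever the input lives
```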
2 changes: 1 addition & 1 deletion pyro/contrib/autoguide/__init__.py
@@ -431,7 +431,7 @@ def quantiles(self, quantiles, *args, **kwargs):
:rtype: dict
"""
loc, scale = self._loc_scale(*args, **kwargs)
- quantiles = loc.new_tensor(quantiles).unsqueeze(-1)
+ quantiles = torch.tensor(quantiles, dtype=loc.dtype, device=loc.device).unsqueeze(-1)
latents = dist.Normal(loc, scale).icdf(quantiles)
result = {}
for latent in latents:
2 changes: 1 addition & 1 deletion pyro/contrib/minipyro.py
@@ -270,7 +270,7 @@ def step(self, *args, **kwargs):
self.optim(params)
# Zero out the gradients so that they don't accumulate.
for p in params:
- p.grad = p.new_zeros(p.shape)
+ p.grad = torch.zeros_like(p)
return loss.item()


36 changes: 22 additions & 14 deletions pyro/contrib/tracking/assignment.py
@@ -119,7 +119,8 @@ def __init__(self, num_objects, num_detections, edges, exists_logits, assign_log

# Wrap the results in Distribution objects.
# This adds a final logit=0 element denoting spurious detection.
- padded_assign = assign.new_empty(num_detections, num_objects + 1).fill_(-float('inf'))
+ padded_assign = torch.full((num_detections, num_objects + 1), -float('inf'),
+ dtype=assign.dtype, device=assign.device)
padded_assign[:, -1] = 0
padded_assign[edges[0], edges[1]] = assign
self.assign_dist = dist.Categorical(logits=padded_assign)
@@ -198,9 +199,11 @@ def compute_marginals(exists_logits, assign_logits):
"""
num_detections, num_objects = assign_logits.shape
assert exists_logits.shape == (num_objects,)
+ dtype = exists_logits.dtype
+ device = exists_logits.device

- exists_probs = exists_logits.new_zeros(2, num_objects)  # [not exist, exist]
- assign_probs = assign_logits.new_zeros(num_detections, num_objects + 1)
+ exists_probs = torch.zeros(2, num_objects, dtype=dtype, device=device)  # [not exist, exist]
+ assign_probs = torch.zeros(num_detections, num_objects + 1, dtype=dtype, device=device)
for assign in itertools.product(range(num_objects + 1), repeat=num_detections):
assign_part = sum(assign_logits[j, i] for j, i in enumerate(assign) if i < num_objects)
for exists in itertools.product(*[[1] if i in assign else [0, 1] for i in range(num_objects)]):
@@ -233,8 +236,8 @@ def compute_marginals_bp(exists_logits, assign_logits, bp_iters):
belief propagation
https://arxiv.org/abs/1209.6299
"""
- message_e_to_a = exists_logits.new_zeros(assign_logits.shape)
- message_a_to_e = exists_logits.new_zeros(assign_logits.shape)
+ message_e_to_a = torch.zeros_like(assign_logits)
+ message_a_to_e = torch.zeros_like(assign_logits)
for i in range(bp_iters):
message_e_to_a = -(message_a_to_e - message_a_to_e.sum(0, True) - exists_logits).exp().log1p()
joint = (assign_logits + message_e_to_a).exp()
@@ -267,13 +270,14 @@ def compute_marginals_sparse_bp(num_objects, num_detections, edges,

def sparse_sum(x, dim, keepdim=False):
assert dim in (0, 1)
- x = x.new_zeros([num_objects, num_detections][dim]).scatter_add_(0, edges[1 - dim], x)
+ x = (torch.zeros([num_objects, num_detections][dim], dtype=x.dtype, device=x.device)
+ .scatter_add_(0, edges[1 - dim], x))
if keepdim:
x = x[edges[1 - dim]]
return x

- message_e_to_a = exists_logits.new_zeros(assign_logits.shape)
- message_a_to_e = exists_logits.new_zeros(assign_logits.shape)
+ message_e_to_a = torch.zeros_like(assign_logits)
+ message_a_to_e = torch.zeros_like(assign_logits)
for i in range(bp_iters):
message_e_to_a = -(message_a_to_e - sparse_sum(message_a_to_e, 0, True) - exists_factor).exp().log1p()
joint = (assign_logits + message_e_to_a).exp()
@@ -298,10 +302,12 @@ def compute_marginals_persistent(exists_logits, assign_logits):
"""
num_frames, num_detections, num_objects = assign_logits.shape
assert exists_logits.shape == (num_objects,)
+ dtype = exists_logits.dtype
+ device = exists_logits.device

total = 0
- exists_probs = exists_logits.new_zeros(num_objects)
- assign_probs = assign_logits.new_zeros(num_frames, num_detections, num_objects)
+ exists_probs = torch.zeros(num_objects, dtype=dtype, device=device)
+ assign_probs = torch.zeros(num_frames, num_detections, num_objects, dtype=dtype, device=device)
for exists in itertools.product([0, 1], repeat=num_objects):
exists = [i for i, e in enumerate(exists) if e]
exists_part = _exp(sum(exists_logits[i] for i in exists))
@@ -363,10 +369,12 @@ def compute_marginals_persistent_bp(exists_logits, assign_logits, bp_iters, bp_m
assert 0 <= bp_momentum < 1, bp_momentum
old, new = bp_momentum, 1 - bp_momentum
num_frames, num_detections, num_objects = assign_logits.shape
- message_b_to_a = assign_logits.new_zeros(num_frames, num_detections, num_objects)
- message_a_to_b = assign_logits.new_zeros(num_frames, num_detections, num_objects)
- message_b_to_e = assign_logits.new_zeros(num_frames, num_objects)
- message_e_to_b = assign_logits.new_zeros(num_frames, num_objects)
+ dtype = assign_logits.dtype
+ device = assign_logits.device
+ message_b_to_a = torch.zeros(num_frames, num_detections, num_objects, dtype=dtype, device=device)
+ message_a_to_b = torch.zeros(num_frames, num_detections, num_objects, dtype=dtype, device=device)
+ message_b_to_e = torch.zeros(num_frames, num_objects, dtype=dtype, device=device)
+ message_e_to_b = torch.zeros(num_frames, num_objects, dtype=dtype, device=device)

for i in range(bp_iters):
odds_a = (assign_logits + message_b_to_a).exp()
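`new_empty(...).fill_(value)` collapses into a single `torch.full` call with explicit options, as in the padded-assignment change at the top of this file. A minimal sketch with illustrative sizes:

```python
import torch

num_detections, num_objects = 3, 2
assign = torch.randn(4)  # reference tensor for dtype/device

padded_assign = torch.full((num_detections, num_objects + 1), -float('inf'),
                           dtype=assign.dtype, device=assign.device)
padded_assign[:, -1] = 0  # final column is the spurious-detection logit
```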
2 changes: 1 addition & 1 deletion pyro/contrib/tracking/distributions.py
@@ -75,7 +75,7 @@ def log_prob(self, value):
state = EKFState(self.dynamic_model, self.x0, self.P0, time=0.)
result = 0.
assert value.shape == self.event_shape
- zero = value.new_zeros(self.event_shape[-1])
+ zero = torch.zeros(self.event_shape[-1], dtype=value.dtype, device=value.device)
for i, measurement_mean in enumerate(value):
if i:
state = state.predict(self.dt)
10 changes: 5 additions & 5 deletions pyro/contrib/tracking/dynamic_models.py
@@ -118,7 +118,7 @@ def process_noise_dist(self, dt=0.):
:return: :class:`~pyro.distributions.torch.MultivariateNormal`.
'''
Q = self.process_noise_cov(dt)
- return dist.MultivariateNormal(Q.new_zeros(Q.shape[-1]), Q)
+ return dist.MultivariateNormal(torch.zeros(Q.shape[-1], dtype=Q.dtype, device=Q.device), Q)


class DifferentiableDynamicModel(DynamicModel):
@@ -182,7 +182,7 @@ def mean2pv(self, x):
:return: PV state estimate mean.
'''
with torch.no_grad():
- x_pv = x.new_zeros(2*self._dimension)
+ x_pv = torch.zeros(2 * self._dimension, dtype=x.dtype, device=x.device)
x_pv[:self._dimension] = x
return x_pv

@@ -197,7 +197,7 @@ def cov2pv(self, P):
'''
d = 2*self._dimension
with torch.no_grad():
- P_pv = P.new_zeros((d, d))
+ P_pv = torch.zeros(d, d, dtype=P.dtype, device=P.device)
P_pv[:self._dimension, :self._dimension] = P
return P_pv

@@ -372,7 +372,7 @@ def process_noise_cov(self, dt=0.):
d = self._dimension
dt2 = dt * dt
dt3 = dt2 * dt
- Q = self.sa2.new_zeros(d, d)
+ Q = torch.zeros(d, d, dtype=self.sa2.dtype, device=self.sa2.device)
eye = eye_like(self.sa2, d//2)
Q[:d//2, :d//2] = dt3 * eye / 3.0
Q[:d//2, d//2:] = dt2 * eye / 2.0
@@ -445,7 +445,7 @@ def process_noise_cov(self, dt=0.):
dt2 = dt*dt
dt3 = dt2*dt
dt4 = dt2*dt2
- Q = self.sa2.new_zeros(d, d)
+ Q = torch.zeros(d, d, dtype=self.sa2.dtype, device=self.sa2.device)
Q[:d//2, :d//2] = 0.25 * dt4 * eye_like(self.sa2, d//2)
Q[:d//2, d//2:] = 0.5 * dt3 * eye_like(self.sa2, d//2)
Q[d//2:, :d//2] = 0.5 * dt3 * eye_like(self.sa2, d//2)
2 changes: 1 addition & 1 deletion pyro/contrib/tracking/extended_kalman_filter.py
@@ -161,7 +161,7 @@ def log_likelihood_of_update(self, measurement):
:return: Likelihood of hypothetical update.
'''
dz, S = self.innovation(measurement)
- return dist.MultivariateNormal(S.new_zeros(S.shape[-1]),
+ return dist.MultivariateNormal(torch.zeros(S.size(-1), dtype=S.dtype, device=S.device),
S).log_prob(dz)

def update(self, measurement):