From f29e6d76f9e7664f827a993952cd8a62e1e60963 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Tue, 30 Apr 2019 11:34:11 -0700 Subject: [PATCH 1/7] Avoid legacy constructors in distributions --- pyro/distributions/avf_mvn.py | 2 +- pyro/distributions/conjugate.py | 2 +- pyro/distributions/delta.py | 2 +- pyro/distributions/diag_normal_mixture.py | 6 +++--- pyro/distributions/diag_normal_mixture_shared_cov.py | 4 ++-- pyro/distributions/lkj.py | 10 +++++----- pyro/distributions/omt_mvn.py | 2 +- pyro/distributions/relaxed_straight_through.py | 2 +- pyro/distributions/spanning_tree.py | 10 +++++----- pyro/distributions/testing/rejection_gamma.py | 9 ++++++--- pyro/distributions/util.py | 4 ++-- 11 files changed, 28 insertions(+), 25 deletions(-) diff --git a/pyro/distributions/avf_mvn.py b/pyro/distributions/avf_mvn.py index 5c91eeb4e3..ce5c15ad9a 100644 --- a/pyro/distributions/avf_mvn.py +++ b/pyro/distributions/avf_mvn.py @@ -57,7 +57,7 @@ def rsample(self, sample_shape=torch.Size()): class _AVFMVNSample(Function): @staticmethod def forward(ctx, loc, scale_tril, control_var, shape): - white = loc.new_empty(shape).normal_() + white = torch.randn(shape, dtype=loc.dtype, device=loc.device) z = torch.matmul(white, scale_tril.t()) ctx.save_for_backward(scale_tril, control_var, white) return loc + z diff --git a/pyro/distributions/conjugate.py b/pyro/distributions/conjugate.py index 2f82015c38..d46449ed63 100644 --- a/pyro/distributions/conjugate.py +++ b/pyro/distributions/conjugate.py @@ -121,7 +121,7 @@ class DirichletMultinomial(TorchDistribution): def __init__(self, concentration, total_count=1, is_sparse=False, validate_args=None): if isinstance(total_count, numbers.Number): - total_count = concentration.new_tensor(total_count) + total_count = torch.tensor(total_count, dtype=concentration.dtype, device=concentration.device) total_count_1 = total_count.unsqueeze(-1) concentration, total_count = torch.broadcast_tensors(concentration, total_count_1) total_count = total_count_1.squeeze(-1) diff --git a/pyro/distributions/delta.py b/pyro/distributions/delta.py index 0e6d0826f6..ad81f723c1 100644 --- a/pyro/distributions/delta.py +++ b/pyro/distributions/delta.py @@ -34,7 +34,7 @@ def __init__(self, v, log_density=0.0, event_dim=0, validate_args=None): batch_shape = v.shape[:batch_dim] event_shape = v.shape[batch_dim:] if isinstance(log_density, numbers.Number): - log_density = v.new_empty(batch_shape).fill_(log_density) + log_density = torch.full(batch_shape, log_density, dtype=v.dtype, device=v.device) elif validate_args and log_density.shape != batch_shape: raise ValueError('Expected log_density.shape = {}, actual {}'.format( log_density.shape, batch_shape)) diff --git a/pyro/distributions/diag_normal_mixture.py b/pyro/distributions/diag_normal_mixture.py index 521905a679..f87c6933c3 100644 --- a/pyro/distributions/diag_normal_mixture.py +++ b/pyro/distributions/diag_normal_mixture.py @@ -136,7 +136,7 @@ def backward(ctx, grad_output): mu_cd = locs.unsqueeze(-2) - locs.unsqueeze(-3) # b c d i mu_cd_norm = torch.pow(mu_cd, 2.0).sum(-1).sqrt() # b c d mu_cd /= mu_cd_norm.unsqueeze(-1) # b c d i - diagonals = z.new_empty((K,), dtype=torch.long) + diagonals = torch.empty((K,), dtype=torch.long, device=z.device) torch.arange(K, out=diagonals) mu_cd[..., diagonals, diagonals, :] = 0.0 @@ -145,7 +145,7 @@ def backward(ctx, grad_output): z_perp_cd = z.unsqueeze(-2).unsqueeze(-2) - z_ll_cd.unsqueeze(-1) * mu_cd # l b c d i z_perp_cd_sqr = torch.pow(z_perp_cd, 2.0).sum(-1) # l b c d - 
shift_indices = z.new_empty((dim,), dtype=torch.long) + shift_indices = torch.empty((dim,), dtype=torch.long, device=z.device) torch.arange(dim, out=shift_indices) shift_indices = shift_indices - 1 shift_indices[0] = 0 @@ -170,7 +170,7 @@ def backward(ctx, grad_output): shift_log_scales[..., 0] = 0.0 sigma_products = torch.cumsum(shift_log_scales, dim=-1).exp() # b j i - reverse_indices = z.new_tensor(range(dim - 1, -1, -1), dtype=torch.long) + reverse_indices = torch.tensor(range(dim - 1, -1, -1), dtype=torch.long, device=z.device) reverse_log_sigma_0 = sigma_0.log()[..., reverse_indices] # b 1 i sigma_0_products = torch.cumsum(reverse_log_sigma_0, dim=-1).exp()[..., reverse_indices - 1] # b 1 i sigma_0_products[..., -1] = 1.0 diff --git a/pyro/distributions/diag_normal_mixture_shared_cov.py b/pyro/distributions/diag_normal_mixture_shared_cov.py index fd1b44a8b9..16f305249d 100644 --- a/pyro/distributions/diag_normal_mixture_shared_cov.py +++ b/pyro/distributions/diag_normal_mixture_shared_cov.py @@ -106,7 +106,7 @@ class _MixDiagNormalSharedCovarianceSample(Function): @staticmethod def forward(ctx, locs, coord_scale, component_logits, pis, which, noise_shape): dim = coord_scale.size(-1) - white = locs.new(noise_shape).normal_() + white = torch.randn(noise_shape, dtype=locs.dtype, device=locs.device) n_unsqueezes = locs.dim() - which.dim() for _ in range(n_unsqueezes): which = which.unsqueeze(-1) @@ -130,7 +130,7 @@ def backward(ctx, grad_output): mu_ab = locs_tilde.unsqueeze(-2) - locs_tilde.unsqueeze(-3) # b k j i mu_ab_norm = torch.pow(mu_ab, 2.0).sum(-1).sqrt() # b k j mu_ab /= mu_ab_norm.unsqueeze(-1) # b k j i - diagonals = z.new_empty((K,), dtype=torch.long) + diagonals = torch.empty((K,), dtype=torch.long, device=z.device) torch.arange(K, out=diagonals) mu_ab[..., diagonals, diagonals, :] = 0.0 diff --git a/pyro/distributions/lkj.py b/pyro/distributions/lkj.py index 4b1c6ec3d5..8d9a814aef 100644 --- a/pyro/distributions/lkj.py +++ b/pyro/distributions/lkj.py @@ -39,12 +39,12 @@ def _vector_to_l_cholesky(z): if D % 1 != 0: raise ValueError("Correlation matrix transformation requires d choose 2 inputs") D = int(D) - x = z.new_zeros(list(z.shape[:-1]) + [D, D]) + x = torch.zeros(z.shape[:-1] + (D, D), dtype=z.dtype, device=z.device) x[..., 0, 0] = 1 x[..., 1:, 0] = z[..., :(D - 1)] i = D - 1 - last_squared_x = z.new_zeros(list(z.shape[:-1]) + [D]) + last_squared_x = torch.zeros(z.shape[:-1] + (D,), dtype=z.dtype, device=z.device) for j in range(1, D): distance_to_copy = D - 1 - j last_squared_x = last_squared_x[..., 1:] + x[..., j:, (j - 1)].clone()**2 @@ -83,7 +83,7 @@ def _inverse(self, y): raise ValueError("A matrix that isn't square can't be a Cholesky factor of a correlation matrix") D = y.shape[-1] - z_tri = y.new_zeros(y.shape[:-2] + (D - 2, D - 2)) + z_tri = torch.zeros(y.shape[:-2] + (D - 2, D - 2), dtype=y.dtype, device=y.device) z_stack = [ y[..., 1:, 0] ] @@ -149,7 +149,7 @@ def __init__(self, d, eta, validate_args=None): vector_size = (d * (d - 1)) // 2 alpha = eta.add(0.5 * (d - 1.0)) - concentrations = eta.new_empty(vector_size,) + concentrations = torch.empty(vector_size, dtype=eta.dtype, device=eta.device) i = 0 for k in range(d - 1): alpha -= .5 @@ -210,5 +210,5 @@ def log_prob(self, x): values += log_diagonals.mul(eta.mul(2).add(-2.0)) values = values.sum(-1) + lp - values, _ = torch.broadcast_tensors(values, values.new_empty(self.batch_shape)) + values, _ = torch.broadcast_tensors(values, torch.empty(self.batch_shape, dtype=values.dtype, device=values.device)) 
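Note: the pattern applied throughout this patch replaces the legacy Tensor.new_* constructors, which inherit dtype and device from an existing tensor, with explicit torch.* factory calls. The equivalences relied on here can be checked in isolation; a minimal sketch, with `loc` standing in for any existing tensor:

    import torch

    loc = torch.randn(3, dtype=torch.float64)

    # legacy: loc.new_zeros(5)
    a = torch.zeros(5, dtype=loc.dtype, device=loc.device)
    # legacy: loc.new_empty(5).normal_()
    b = torch.randn(5, dtype=loc.dtype, device=loc.device)
    # legacy: loc.new_full((5,), 2.5)
    c = torch.full((5,), 2.5, dtype=loc.dtype, device=loc.device)
    # when the target shape equals loc.shape, the *_like factories are shorter
    d = torch.zeros_like(loc)

    assert all(t.dtype == loc.dtype and t.device == loc.device for t in (a, b, c, d))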
return values diff --git a/pyro/distributions/omt_mvn.py b/pyro/distributions/omt_mvn.py index c3a393bc85..2aee5cb753 100644 --- a/pyro/distributions/omt_mvn.py +++ b/pyro/distributions/omt_mvn.py @@ -36,7 +36,7 @@ def rsample(self, sample_shape=torch.Size()): class _OMTMVNSample(Function): @staticmethod def forward(ctx, loc, scale_tril, shape): - white = loc.new_empty(shape).normal_() + white = torch.randn(shape, dtype=loc.dtype, device=loc.device) z = torch.matmul(white, scale_tril.t()) ctx.save_for_backward(z, white, scale_tril) return loc + z diff --git a/pyro/distributions/relaxed_straight_through.py b/pyro/distributions/relaxed_straight_through.py index 5d57cdf1f5..f72f5da82b 100644 --- a/pyro/distributions/relaxed_straight_through.py +++ b/pyro/distributions/relaxed_straight_through.py @@ -44,7 +44,7 @@ class QuantizeCategorical(torch.autograd.Function): @staticmethod def forward(ctx, soft_value): argmax = soft_value.max(-1)[1] - hard_value = soft_value.new_zeros(soft_value.shape) + hard_value = torch.zeros_like(soft_value) hard_value._unquantize = soft_value if argmax.dim() < hard_value.dim(): argmax = argmax.unsqueeze(-1) diff --git a/pyro/distributions/spanning_tree.py b/pyro/distributions/spanning_tree.py index cefdb0ef93..07dab7495f 100644 --- a/pyro/distributions/spanning_tree.py +++ b/pyro/distributions/spanning_tree.py @@ -114,7 +114,7 @@ def log_partition_function(self): grid = make_complete_graph(V) shift = self.edge_logits.max() edge_probs = (self.edge_logits - shift).exp() - adjacency = edge_probs.new_zeros(V, V) + adjacency = torch.zeros(V, V, dtype=edge_probs.dtype) adjacency[grid[0], grid[1]] = edge_probs adjacency[grid[1], grid[0]] = edge_probs laplacian = adjacency.sum(-1).diag() - adjacency @@ -319,7 +319,7 @@ def _sample_tree_mcmc(edge_logits, edges): # Convert edge ids to a canonical list of pairs. edge_ids = edge_ids.sort()[0] - edges = edge_logits.new_empty((E, 2), dtype=torch.long) + edges = torch.empty((E, 2), dtype=torch.long) edges[:, 0] = grid[0, edge_ids] edges[:, 1] = grid[1, edge_ids] return edges @@ -368,9 +368,9 @@ def _sample_tree_approx(edge_logits): # Each of E edges in the tree is stored as an id k in [0, K) indexing into # the complete graph. The id of an edge (v1,v2) is k = v1+v2*(v2-1)/2. - edge_ids = edge_logits.new_empty((E,), dtype=torch.long) + edge_ids = torch.empty((E,), dtype=torch.long) # This maps each vertex to whether it is a member of the cumulative tree. - components = edge_logits.new_zeros(V, dtype=torch.uint8) + components = torch.zeros(V, dtype=torch.uint8) # Sample the first edge at random. probs = (edge_logits - edge_logits.max()).exp() @@ -390,7 +390,7 @@ def _sample_tree_approx(edge_logits): # Convert edge ids to a canonical list of pairs. 
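Note: the triangular edge numbering used by the spanning-tree code above (an edge (v1, v2) with v1 < v2 has id k = v1 + v2*(v2-1)/2, as the comment says) is easy to verify by hand; a short standalone check, with an illustrative helper name:

    def edge_id(v1, v2):
        # id of edge (v1, v2), v1 < v2, in the complete graph on V vertices
        return v1 + v2 * (v2 - 1) // 2

    # For V = 4: (0,1)->0, (0,2)->1, (1,2)->2, (0,3)->3, (1,3)->4, (2,3)->5
    V = 4
    ids = [edge_id(v1, v2) for v2 in range(V) for v1 in range(v2)]
    assert ids == list(range(V * (V - 1) // 2))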
     edge_ids = edge_ids.sort()[0]
-    edges = edge_logits.new_empty((E, 2), dtype=torch.long)
+    edges = torch.empty((E, 2), dtype=torch.long)
     edges[:, 0] = grid[0, edge_ids]
     edges[:, 1] = grid[1, edge_ids]
     return edges
diff --git a/pyro/distributions/testing/rejection_gamma.py b/pyro/distributions/testing/rejection_gamma.py
index 137c39334a..db5775fe00 100644
--- a/pyro/distributions/testing/rejection_gamma.py
+++ b/pyro/distributions/testing/rejection_gamma.py
@@ -42,8 +42,10 @@ def expand(self, batch_shape, _instance=None):
         return new

     def propose(self, sample_shape=torch.Size()):
-        # Marsaglia & Tsang's x == Naesseth's epsilon
-        x = self.concentration.new_empty(sample_shape + self.concentration.shape).normal_()
+        # Marsaglia & Tsang's x == Naesseth's epsilon
+        x = torch.randn(sample_shape + self.concentration.shape,
+                        dtype=self.concentration.dtype,
+                        device=self.concentration.device)
         y = 1.0 + self._c * x
         v = y * y * y
         return (self._d * v).clamp_(1e-30, 1e30)
@@ -129,7 +131,8 @@ def rsample(self, sample_shape=torch.Size()):
         x = self._rejection_gamma.rsample(sample_shape)
         boosted_x = x.clone()
         for i in range(self._boost):
-            boosted_x *= (1 - x.new_empty(x.shape).uniform_()) ** (1 / (i + self.concentration))
+            u = torch.rand(x.shape, dtype=x.dtype, device=x.device)
+            boosted_x *= (1 - u) ** (1 / (i + self.concentration))
         self._unboost_x_cache = boosted_x, x
         return boosted_x
diff --git a/pyro/distributions/util.py b/pyro/distributions/util.py
index 0c73c7d112..a27a007919 100644
--- a/pyro/distributions/util.py
+++ b/pyro/distributions/util.py
@@ -106,7 +106,7 @@ def gather(value, index, dim):
     Broadcasted gather of indexed values along a named dim.
     """
     value, index = broadcast_all(value, index)
-    index = index.index_select(dim, index.new_tensor([0]))
+    index = index.index_select(dim, torch.tensor([0], device=index.device))
     return value.gather(dim, index)
@@ -194,7 +194,7 @@ def scale_and_mask(tensor, scale=1.0, mask=None):
 def eye_like(value, m, n=None):
     if n is None:
         n = m
-    eye = value.new_zeros(m, n)
+    eye = torch.zeros(m, n, dtype=value.dtype, device=value.device)
     eye.view(-1)[:min(m, n) * n:n + 1] = 1
     return eye

From 43571c300d39974a69c06fa33c81ebed0e325963 Mon Sep 17 00:00:00 2001
From: Fritz Obermeyer
Date: Tue, 30 Apr 2019 12:32:26 -0700
Subject: [PATCH 2/7] Avoid legacy constructors in pyro.infer

---
 pyro/distributions/lkj.py     |  2 +-
 pyro/distributions/util.py    |  4 ++++
 pyro/infer/abstract_infer.py  |  5 +++--
 pyro/infer/mcmc/adaptation.py |  4 +++-
 pyro/infer/mcmc/hmc.py        | 11 ++++++-----
 pyro/infer/mcmc/mcmc.py       |  3 ++-
 pyro/infer/mcmc/nuts.py       | 21 +++++++++------------
 pyro/infer/renyi_elbo.py      | 10 +++++-----
 pyro/infer/tracegraph_elbo.py |  2 +-
 pyro/infer/util.py            |  6 +++---
 10 files changed, 37 insertions(+), 31 deletions(-)

diff --git a/pyro/distributions/lkj.py b/pyro/distributions/lkj.py
index 8d9a814aef..b50aa700bb 100644
--- a/pyro/distributions/lkj.py
+++ b/pyro/distributions/lkj.py
@@ -210,5 +210,5 @@ def log_prob(self, x):
         values += log_diagonals.mul(eta.mul(2).add(-2.0))
         values = values.sum(-1) + lp
-        values, _ = torch.broadcast_tensors(values, torch.empty(self.batch_shape, dtype=values.dtype, device=values.device))
+        values, _ = torch.broadcast_tensors(values, torch.empty(self.batch_shape))
         return values
diff --git a/pyro/distributions/util.py b/pyro/distributions/util.py
index a27a007919..42fb5c227e 100644
--- a/pyro/distributions/util.py
+++ b/pyro/distributions/util.py
@@ -190,6 +190,10 @@ def scale_and_mask(tensor, scale=1.0, mask=None):
     return tensor


+def 
scalar_like(prototype, fill_value): + return torch.tensor(fill_value, dtype=prototype.dtype, device=prototype.device) + + # work around lack of jit support for torch.eye(..., out=value) def eye_like(value, m, n=None): if n is None: diff --git a/pyro/infer/abstract_infer.py b/pyro/infer/abstract_infer.py index 14a5dd40ba..db7fad6fd0 100644 --- a/pyro/infer/abstract_infer.py +++ b/pyro/infer/abstract_infer.py @@ -270,7 +270,7 @@ def information_criterion(self, pointwise=False): .log_prob(trace.nodes[obs_node]["value"])) ll = torch.stack(log_likelihoods, dim=0) - waic_value, p_waic = waic(ll, ll.new_tensor(self.log_weights), pointwise) + waic_value, p_waic = waic(ll, torch.tensor(self.log_weights, device=ll.device), pointwise) return OrderedDict([("waic", waic_value), ("p_waic", p_waic)]) @@ -317,7 +317,8 @@ def _adjust_to_data(self, trace, data_trace): # Select random sub-indices to replay values under conditionally independent stacks. # Otherwise, we assume there is an dependence of indexes between training data # and prediction data. - subidxs = Categorical(logits=site["value"].new_ones(site["value"].size(cis.dim))).sample([cis.size]) + logits = torch.ones(site["value"].size(cis.dim), device=site["value"].device) + subidxs = Categorical(logits=logits).sample([cis.size]) site["value"] = site["value"].index_select(cis.dim, subidxs) except KeyError: pass diff --git a/pyro/infer/mcmc/adaptation.py b/pyro/infer/mcmc/adaptation.py index ae16ae07f2..3e010b8209 100644 --- a/pyro/infer/mcmc/adaptation.py +++ b/pyro/infer/mcmc/adaptation.py @@ -104,7 +104,9 @@ def _update_step_size(self, accept_prob): self.step_size = math.exp(log_step_size) def _update_r_dist(self): - loc = self._inverse_mass_matrix.new_zeros(self._inverse_mass_matrix.size(0)) + loc = torch.zeros(self._inverse_mass_matrix.size(0), + dtype=self._inverse_mass_matrix.dtype, + device=self._inverse_mass_matrix.device) if self.is_diag_mass: self._r_dist = dist.Normal(loc, self._inverse_mass_matrix.rsqrt()) else: diff --git a/pyro/infer/mcmc/hmc.py b/pyro/infer/mcmc/hmc.py index 326fb4f27a..ed81aeb7ab 100644 --- a/pyro/infer/mcmc/hmc.py +++ b/pyro/infer/mcmc/hmc.py @@ -10,13 +10,14 @@ import pyro.distributions as dist import pyro.poutine as poutine from pyro.distributions.util import eye_like +from pyro.distributions.utils import scalar_like from pyro.infer import config_enumerate from pyro.infer.mcmc.adaptation import WarmupAdapter from pyro.infer.mcmc.trace_kernel import TraceKernel from pyro.infer.mcmc.util import TraceEinsumEvaluator from pyro.ops.integrator import velocity_verlet from pyro.poutine.subsample_messenger import _Subsample -from pyro.util import optional, torch_isinf, torch_isnan, ignore_jit_warnings +from pyro.util import ignore_jit_warnings, optional, torch_isinf, torch_isnan class HMC(TraceKernel): @@ -356,7 +357,7 @@ def _initialize_model_properties(self): if site_value is not None: mass_matrix_size = sum(self._r_numels.values()) if self._adapter.is_diag_mass: - initial_mass_matrix = site_value.new_ones(mass_matrix_size) + initial_mass_matrix = torch.full(mass_matrix_size, dtype=site_value.dtype, device=site_value.device) else: initial_mass_matrix = eye_like(site_value, mass_matrix_size) self._adapter.configure(self._warmup_steps, @@ -416,11 +417,11 @@ def sample(self, trace): # Set accept prob to 0.0 if delta_energy is `NaN` which may be # the case for a diverging trajectory when using a large step size. 
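Note: the calls below use the new scalar_like helper (added to pyro/distributions/util.py above), which is the 0-dim counterpart of the same pattern: a scalar constant placed on the prototype's device with its dtype. A quick sketch of what the calls evaluate to, reusing the helper's definition from this patch:

    import torch

    def scalar_like(prototype, fill_value):
        return torch.tensor(fill_value, dtype=prototype.dtype, device=prototype.device)

    delta_energy = torch.tensor(1.5, dtype=torch.float64)
    zero = scalar_like(delta_energy, 0.)
    assert zero.shape == () and zero.dtype == delta_energy.dtype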
if torch_isnan(delta_energy): - accept_prob = delta_energy.new_tensor(0.0) + accept_prob = scalar_like(delta_energy, 0.) else: accept_prob = (-delta_energy).exp().clamp(max=1.) - rand = pyro.sample("rand_t={}".format(self._t), dist.Uniform(accept_prob.new_tensor(0.), - accept_prob.new_tensor(1.))) + rand = pyro.sample("rand_t={}".format(self._t), dist.Uniform(scalar_like(accept_prob, 0.), + scalar_like(accept_prob, 1.))) if rand < accept_prob: self._accept_cnt += 1 z = z_new diff --git a/pyro/infer/mcmc/mcmc.py b/pyro/infer/mcmc/mcmc.py index 8dd903bc43..9844cc62c8 100644 --- a/pyro/infer/mcmc/mcmc.py +++ b/pyro/infer/mcmc/mcmc.py @@ -297,7 +297,8 @@ def diagnostics(self): try: site_stats["n_eff"] = stats.effective_sample_size(site_support) except NotImplementedError: - site_stats["n_eff"] = site_support.new_full(site_support.shape[2:], float("nan")) + site_stats["n_eff"] = torch.full(site_support.shape[2:], float("nan"), + dtype=site_support.dtype, device=site_support) site_stats["r_hat"] = stats.split_gelman_rubin(site_support) self._diagnostics[site] = site_stats return self._diagnostics diff --git a/pyro/infer/mcmc/nuts.py b/pyro/infer/mcmc/nuts.py index 976f8c21d5..ef7bb1be5c 100644 --- a/pyro/infer/mcmc/nuts.py +++ b/pyro/infer/mcmc/nuts.py @@ -7,6 +7,7 @@ import pyro import pyro.distributions as dist from pyro.distributions.util import logsumexp +from pyro.distributions.utils import scalar_like from pyro.infer.mcmc.hmc import HMC from pyro.ops.integrator import velocity_verlet from pyro.util import optional, torch_isnan @@ -167,7 +168,7 @@ def _build_basetree(self, z, r, z_grads, log_slice, direction, energy_current): r_new_flat = torch.cat([r_new[site_name].reshape(-1) for site_name in sorted(r_new)]) energy_new = potential_energy + self._kinetic_energy(r_new) # handle the NaN case - energy_new = energy_new.new_tensor(float("inf")) if torch_isnan(energy_new) else energy_new + energy_new = scalar_like(energy_new, float("inf")) if torch_isnan(energy_new) else energy_new sliced_energy = energy_new + log_slice diverging = (sliced_energy > self._max_sliced_energy) delta_energy = energy_new - energy_current @@ -180,8 +181,7 @@ def _build_basetree(self, z, r, z_grads, log_slice, direction, energy_current): # we eliminate states which p(z, r) < u, or dE > 0. # Due to this elimination (and stop doubling conditions), # the weight of binary tree might not equal to 2^tree_depth. - tree_weight = (sliced_energy.new_ones(()) if sliced_energy <= 0 - else sliced_energy.new_zeros(())) + tree_weight = scalar_like(sliced_energy, 1. if sliced_energy <= 0 else 0.) return _TreeInfo(z_new, r_new, z_grads, z_new, r_new, z_grads, z_new, potential_energy, z_grads, r_new_flat, tree_weight, False, diverging, accept_prob, 1) @@ -232,7 +232,7 @@ def _build_tree(self, z, r, z_grads, log_slice, direction, tree_depth, energy_cu # we choose the proposal from the first half # (any is fine, because the probability of picking it at the end is 0!). other_half_tree_prob = (other_half_tree.weight / tree_weight if tree_weight > 0 - else tree_weight.new_zeros(())) + else scalar_like(tree_weight, 0.)) is_other_half_tree = pyro.sample("is_other_half_tree", dist.Bernoulli(probs=other_half_tree_prob)) @@ -300,7 +300,7 @@ def sample(self, trace): # sample log_slice directly using `energy`, so as to avoid potential underflow or # overflow issues ([2]). 
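Note: the Exponential draw that follows is the usual log-space slice variable: if u ~ Uniform(0, 1) then -log(u) ~ Exponential(1), so log(u * p(z, r)) can be computed as -energy minus an Exponential(1) sample, without ever exponentiating -energy. A quick numerical check of the identity:

    import torch

    u = torch.rand(100000)
    e = torch.distributions.Exponential(1.).sample((100000,))
    # -log(u) and Exponential(1) samples share the same distribution;
    # both sample means approach 1.0
    print((-u.log()).mean(), e.mean())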
slice_exp_term = pyro.sample("slicevar_exp_t={}".format(self._t), - dist.Exponential(energy_current.new_tensor(1.))) + dist.Exponential(scalar_like(energy_current, 1.))) log_slice = -energy_current - slice_exp_term z_left = z_right = z @@ -310,10 +310,7 @@ def sample(self, trace): r_sum = r_flat sum_accept_probs = 0. num_proposals = 0 - if self.use_multinomial_sampling: - tree_weight = energy_current.new_zeros(()) - else: - tree_weight = energy_current.new_ones(()) + tree_weight = scalar_like(energy_current, 0. if self.use_multinomial_sampling else 1.) # Temporarily disable distributions args checking as # NaNs are expected during step size adaptation. @@ -322,7 +319,7 @@ def sample(self, trace): tree_depth = 0 while tree_depth < self._max_tree_depth: direction = pyro.sample("direction_t={}_treedepth={}".format(self._t, tree_depth), - dist.Bernoulli(probs=tree_weight.new_tensor(0.5))) + dist.Bernoulli(probs=scalar_like(tree_weight, 0.5))) direction = int(direction.item()) if direction == 1: # go to the right, start from the right leaf of current tree new_tree = self._build_tree(z_right, r_right, z_right_grads, log_slice, @@ -351,8 +348,8 @@ def sample(self, trace): else: new_tree_prob = new_tree.weight / tree_weight rand = pyro.sample("rand_t={}_treedepth={}".format(self._t, tree_depth), - dist.Uniform(new_tree_prob.new_tensor(0.), - new_tree_prob.new_tensor(1.))) + dist.Uniform(scalar_like(new_tree_prob, 0.), + scalar_like(new_tree_prob, 1.))) if rand < new_tree_prob: accepted = True z = new_tree.z_proposal diff --git a/pyro/infer/renyi_elbo.py b/pyro/infer/renyi_elbo.py index bc9dc6a127..441166a45e 100644 --- a/pyro/infer/renyi_elbo.py +++ b/pyro/infer/renyi_elbo.py @@ -182,15 +182,15 @@ def loss_and_grads(self, model, guide, *args, **kwargs): if is_identically_zero(elbo_particle): if tensor_holder is not None: - elbo_particle = tensor_holder.new_zeros(tensor_holder.shape) - surrogate_elbo_particle = tensor_holder.new_zeros(tensor_holder.shape) + elbo_particle = torch.zeros_like(tensor_holder) + surrogate_elbo_particle = torch.zeros_like(tensor_holder) else: # elbo_particle is not None if tensor_holder is None: - tensor_holder = elbo_particle.new_empty(elbo_particle.shape) + tensor_holder = torch.zeros_like(elbo_particle) # change types of previous `elbo_particle`s for i in range(len(elbo_particles)): - elbo_particles[i] = tensor_holder.new_zeros(tensor_holder.shape) - surrogate_elbo_particles[i] = tensor_holder.new_zeros(tensor_holder.shape) + elbo_particles[i] = torch.zeros_like(tensor_holder) + surrogate_elbo_particles[i] = torch.zeros_like(tensor_holder) elbo_particles.append(elbo_particle) surrogate_elbo_particles.append(surrogate_elbo_particle) diff --git a/pyro/infer/tracegraph_elbo.py b/pyro/infer/tracegraph_elbo.py index ef1b07eb54..a45c1570f2 100644 --- a/pyro/infer/tracegraph_elbo.py +++ b/pyro/infer/tracegraph_elbo.py @@ -134,7 +134,7 @@ def _compute_elbo_non_reparam(guide_trace, non_reparam_nodes, downstream_costs): param_name = "__baseline_avg_downstream_cost_" + node with torch.no_grad(): avg_downstream_cost_old = pyro.param(param_name, - guide_site['value'].new_zeros(dc_shape)) + torch.zeros(dc_shape, device=guide_site['value'].device)) avg_downstream_cost_new = (1 - baseline_beta) * downstream_cost + \ baseline_beta * avg_downstream_cost_old pyro.get_param_store()[param_name] = avg_downstream_cost_new diff --git a/pyro/infer/util.py b/pyro/infer/util.py index b564c63199..f9d3eea8f8 100644 --- a/pyro/infer/util.py +++ b/pyro/infer/util.py @@ -67,7 +67,7 @@ def 
zero_grads(tensors): """ for p in tensors: if p.grad is not None: - p.grad = p.grad.new_zeros(p.shape) + p.grad = torch.zeros_like(p.grad) def get_plate_stacks(trace): @@ -177,7 +177,7 @@ def __init__(self, guide_trace, ordering): log_prob = log_prob - log_prob.detach() log_prob = log_prob - math.log(num_samples) if not isinstance(log_prob, torch.Tensor): - log_prob = site["value"].new_tensor(log_prob) + log_prob = torch.tensor(float(log_prob), device=site["value"].device) log_prob._pyro_dims = dims # I don't know why the following broadcast is needed, but it makes tests pass: log_prob, _ = packed.broadcast_all(log_prob, site["packed"]["log_prob"]) @@ -236,7 +236,7 @@ def compute_expectation(self, costs): for cost in cost_terms: key = frozenset(cost._pyro_dims) if queries[key] is None: - query = cost.new_zeros(cost.shape) + query = torch.zeros_like(cost) query._pyro_dims = cost._pyro_dims log_factors.append(query) queries[key] = query From 33db24d8b2108e54038da8ff3aa5266983b259e4 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Tue, 30 Apr 2019 13:02:09 -0700 Subject: [PATCH 3/7] Avoid legacy constructors in ops and most of contrib --- pyro/contrib/autoguide/__init__.py | 2 +- pyro/contrib/minipyro.py | 2 +- pyro/contrib/tracking/assignment.py | 36 +++++++++++-------- pyro/contrib/tracking/distributions.py | 2 +- pyro/contrib/tracking/dynamic_models.py | 10 +++--- .../tracking/extended_kalman_filter.py | 2 +- pyro/contrib/tracking/measurements.py | 2 +- pyro/ops/contract.py | 3 +- pyro/ops/linalg.py | 4 +-- pyro/ops/packed.py | 2 +- pyro/ops/stats.py | 20 ++++++----- pyro/optim/clipped_adam.py | 5 +-- 12 files changed, 51 insertions(+), 39 deletions(-) diff --git a/pyro/contrib/autoguide/__init__.py b/pyro/contrib/autoguide/__init__.py index e8a02bc30f..db687e5ce2 100644 --- a/pyro/contrib/autoguide/__init__.py +++ b/pyro/contrib/autoguide/__init__.py @@ -431,7 +431,7 @@ def quantiles(self, quantiles, *args, **kwargs): :rtype: dict """ loc, scale = self._loc_scale(*args, **kwargs) - quantiles = loc.new_tensor(quantiles).unsqueeze(-1) + quantiles = torch.tensor(quantiles, dtype=loc.dtype, device=loc.device).unsqueeze(-1) latents = dist.Normal(loc, scale).icdf(quantiles) result = {} for latent in latents: diff --git a/pyro/contrib/minipyro.py b/pyro/contrib/minipyro.py index ea87e8a087..c55eee54ac 100644 --- a/pyro/contrib/minipyro.py +++ b/pyro/contrib/minipyro.py @@ -270,7 +270,7 @@ def step(self, *args, **kwargs): self.optim(params) # Zero out the gradients so that they don't accumulate. for p in params: - p.grad = p.new_zeros(p.shape) + p.grad = torch.zeros_like(p) return loss.item() diff --git a/pyro/contrib/tracking/assignment.py b/pyro/contrib/tracking/assignment.py index 013500cd87..eca16f0e63 100644 --- a/pyro/contrib/tracking/assignment.py +++ b/pyro/contrib/tracking/assignment.py @@ -119,7 +119,8 @@ def __init__(self, num_objects, num_detections, edges, exists_logits, assign_log # Wrap the results in Distribution objects. # This adds a final logit=0 element denoting spurious detection. 
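Note: the padding idiom rewritten below (fill a logits matrix with -inf, then write real values into selected entries) generalizes to any "optional extra category" construction, since Categorical assigns zero probability to -inf logits. A small sketch with illustrative shapes:

    import torch
    from torch.distributions import Categorical

    num_detections, num_objects = 2, 3
    logits = torch.full((num_detections, num_objects + 1), -float('inf'))
    logits[:, -1] = 0.      # the trailing spurious-detection category
    logits[0, 1] = 1.5      # one real detection-object edge
    print(Categorical(logits=logits).probs)  # -inf entries get probability 0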
- padded_assign = assign.new_empty(num_detections, num_objects + 1).fill_(-float('inf')) + padded_assign = torch.full((num_detections, num_objects + 1), -float('inf'), + dtype=assign.dtype, device=assign.device) padded_assign[:, -1] = 0 padded_assign[edges[0], edges[1]] = assign self.assign_dist = dist.Categorical(logits=padded_assign) @@ -198,9 +199,11 @@ def compute_marginals(exists_logits, assign_logits): """ num_detections, num_objects = assign_logits.shape assert exists_logits.shape == (num_objects,) + dtype = exists_logits.dtype + device = exists_logits.device - exists_probs = exists_logits.new_zeros(2, num_objects) # [not exist, exist] - assign_probs = assign_logits.new_zeros(num_detections, num_objects + 1) + exists_probs = torch.zeros(2, num_objects, dtype=dtype, device=device) # [not exist, exist] + assign_probs = torch.zeros(num_detections, num_objects + 1, dtype=dtype, device=device) for assign in itertools.product(range(num_objects + 1), repeat=num_detections): assign_part = sum(assign_logits[j, i] for j, i in enumerate(assign) if i < num_objects) for exists in itertools.product(*[[1] if i in assign else [0, 1] for i in range(num_objects)]): @@ -233,8 +236,8 @@ def compute_marginals_bp(exists_logits, assign_logits, bp_iters): belief propagation https://arxiv.org/abs/1209.6299 """ - message_e_to_a = exists_logits.new_zeros(assign_logits.shape) - message_a_to_e = exists_logits.new_zeros(assign_logits.shape) + message_e_to_a = torch.zeros_like(assign_logits) + message_a_to_e = torch.zeros_like(assign_logits) for i in range(bp_iters): message_e_to_a = -(message_a_to_e - message_a_to_e.sum(0, True) - exists_logits).exp().log1p() joint = (assign_logits + message_e_to_a).exp() @@ -267,13 +270,14 @@ def compute_marginals_sparse_bp(num_objects, num_detections, edges, def sparse_sum(x, dim, keepdim=False): assert dim in (0, 1) - x = x.new_zeros([num_objects, num_detections][dim]).scatter_add_(0, edges[1 - dim], x) + x = (torch.zeros([num_objects, num_detections][dim], dtype=x.dtype, device=x.device) + .scatter_add_(0, edges[1 - dim], x)) if keepdim: x = x[edges[1 - dim]] return x - message_e_to_a = exists_logits.new_zeros(assign_logits.shape) - message_a_to_e = exists_logits.new_zeros(assign_logits.shape) + message_e_to_a = torch.zeros_like(assign_logits) + message_a_to_e = torch.zeros_like(assign_logits) for i in range(bp_iters): message_e_to_a = -(message_a_to_e - sparse_sum(message_a_to_e, 0, True) - exists_factor).exp().log1p() joint = (assign_logits + message_e_to_a).exp() @@ -298,10 +302,12 @@ def compute_marginals_persistent(exists_logits, assign_logits): """ num_frames, num_detections, num_objects = assign_logits.shape assert exists_logits.shape == (num_objects,) + dtype = exists_logits.dtype + device = exists_logits.device total = 0 - exists_probs = exists_logits.new_zeros(num_objects) - assign_probs = assign_logits.new_zeros(num_frames, num_detections, num_objects) + exists_probs = torch.zeros(num_objects, dtype=dtype, device=device) + assign_probs = torch.zeros(num_frames, num_detections, num_objects, dtype=dtype, device=device) for exists in itertools.product([0, 1], repeat=num_objects): exists = [i for i, e in enumerate(exists) if e] exists_part = _exp(sum(exists_logits[i] for i in exists)) @@ -363,10 +369,12 @@ def compute_marginals_persistent_bp(exists_logits, assign_logits, bp_iters, bp_m assert 0 <= bp_momentum < 1, bp_momentum old, new = bp_momentum, 1 - bp_momentum num_frames, num_detections, num_objects = assign_logits.shape - message_b_to_a = 
assign_logits.new_zeros(num_frames, num_detections, num_objects) - message_a_to_b = assign_logits.new_zeros(num_frames, num_detections, num_objects) - message_b_to_e = assign_logits.new_zeros(num_frames, num_objects) - message_e_to_b = assign_logits.new_zeros(num_frames, num_objects) + dtype = assign_logits.dtype + device = assign_logits.device + message_b_to_a = torch.zeros(num_frames, num_detections, num_objects, dtype=dtype, device=device) + message_a_to_b = torch.zeros(num_frames, num_detections, num_objects, dtype=dtype, device=device) + message_b_to_e = torch.zeros(num_frames, num_objects, dtype=dtype, device=device) + message_e_to_b = torch.zeros(num_frames, num_objects, dtype=dtype, device=device) for i in range(bp_iters): odds_a = (assign_logits + message_b_to_a).exp() diff --git a/pyro/contrib/tracking/distributions.py b/pyro/contrib/tracking/distributions.py index 1643033640..6115091e19 100644 --- a/pyro/contrib/tracking/distributions.py +++ b/pyro/contrib/tracking/distributions.py @@ -75,7 +75,7 @@ def log_prob(self, value): state = EKFState(self.dynamic_model, self.x0, self.P0, time=0.) result = 0. assert value.shape == self.event_shape - zero = value.new_zeros(self.event_shape[-1]) + zero = torch.zeros(self.event_shape[-1], dtype=value.dtype, device=value.device) for i, measurement_mean in enumerate(value): if i: state = state.predict(self.dt) diff --git a/pyro/contrib/tracking/dynamic_models.py b/pyro/contrib/tracking/dynamic_models.py index 9b506404b4..bc35fe3696 100644 --- a/pyro/contrib/tracking/dynamic_models.py +++ b/pyro/contrib/tracking/dynamic_models.py @@ -118,7 +118,7 @@ def process_noise_dist(self, dt=0.): :return: :class:`~pyro.distributions.torch.MultivariateNormal`. ''' Q = self.process_noise_cov(dt) - return dist.MultivariateNormal(Q.new_zeros(Q.shape[-1]), Q) + return dist.MultivariateNormal(torch.zeros(Q.shape[-1], dtype=Q.dtype, device=Q.device), Q) class DifferentiableDynamicModel(DynamicModel): @@ -182,7 +182,7 @@ def mean2pv(self, x): :return: PV state estimate mean. ''' with torch.no_grad(): - x_pv = x.new_zeros(2*self._dimension) + x_pv = torch.zeros(2 * self._dimension, dtype=x.dtype, device=x.device) x_pv[:self._dimension] = x return x_pv @@ -197,7 +197,7 @@ def cov2pv(self, P): ''' d = 2*self._dimension with torch.no_grad(): - P_pv = P.new_zeros((d, d)) + P_pv = torch.zeros(d, d, dtype=P.dtype, device=P.device) P_pv[:self._dimension, :self._dimension] = P return P_pv @@ -372,7 +372,7 @@ def process_noise_cov(self, dt=0.): d = self._dimension dt2 = dt * dt dt3 = dt2 * dt - Q = self.sa2.new_zeros(d, d) + Q = torch.zeros(d, d, dtype=self.sa2.dtype, device=self.sa2.device) eye = eye_like(self.sa2, d//2) Q[:d//2, :d//2] = dt3 * eye / 3.0 Q[:d//2, d//2:] = dt2 * eye / 2.0 @@ -445,7 +445,7 @@ def process_noise_cov(self, dt=0.): dt2 = dt*dt dt3 = dt2*dt dt4 = dt2*dt2 - Q = self.sa2.new_zeros(d, d) + Q = torch.zeros(d, d, dtype=self.sa2.dtype, device=self.sa2.device) Q[:d//2, :d//2] = 0.25 * dt4 * eye_like(self.sa2, d//2) Q[:d//2, d//2:] = 0.5 * dt3 * eye_like(self.sa2, d//2) Q[d//2:, :d//2] = 0.5 * dt3 * eye_like(self.sa2, d//2) diff --git a/pyro/contrib/tracking/extended_kalman_filter.py b/pyro/contrib/tracking/extended_kalman_filter.py index 1e200de3e6..b2dacbcff2 100644 --- a/pyro/contrib/tracking/extended_kalman_filter.py +++ b/pyro/contrib/tracking/extended_kalman_filter.py @@ -161,7 +161,7 @@ def log_likelihood_of_update(self, measurement): :return: Likelihood of hypothetical update. 
''' dz, S = self.innovation(measurement) - return dist.MultivariateNormal(S.new_zeros(S.shape[-1]), + return dist.MultivariateNormal(torch.zeros(S.size(-1), dtype=S.dtype, device=S.device), S).log_prob(dz) def update(self, measurement): diff --git a/pyro/contrib/tracking/measurements.py b/pyro/contrib/tracking/measurements.py index 890c6b856a..39b07be7e7 100644 --- a/pyro/contrib/tracking/measurements.py +++ b/pyro/contrib/tracking/measurements.py @@ -116,7 +116,7 @@ def __init__(self, mean, cov, time=None, frame_num=None): super(PositionMeasurement, self).__init__(mean, cov, time=time, frame_num=frame_num) self._jacobian = torch.cat([ eye_like(mean, self.dimension), - mean.new_zeros((self.dimension, self.dimension))], dim=1) + torch.zeros(self.dimension, self.dimension, dtype=mean.dtype, device=mean.device)], dim=1) def __call__(self, x, do_normalization=True): ''' diff --git a/pyro/ops/contract.py b/pyro/ops/contract.py index 58797288d1..17d0dd52cb 100644 --- a/pyro/ops/contract.py +++ b/pyro/ops/contract.py @@ -511,7 +511,8 @@ def naive_ubersum(equation, *operands, **kwargs): flat_operands.append(_select(operand, offsets, index)) # Defer to unplated einsum. - result = operands[0].new_empty(torch.Size(sizes[d] for d in output)) + result = torch.empty(torch.Size(sizes[d] for d in output), + dtype=operands[0].dtype, device=operands[0].device) local_dims = [d for d in output if d in plates] offsets = [output.index(d) - len(output) for d in local_dims] for index in itertools.product(*(range(sizes[d]) for d in local_dims)): diff --git a/pyro/ops/linalg.py b/pyro/ops/linalg.py index 3a64a99253..4b26cbbb20 100644 --- a/pyro/ops/linalg.py +++ b/pyro/ops/linalg.py @@ -17,7 +17,7 @@ def rinverse(M, sym=False): return 1./M elif M.shape[-1] == 2: det = M[..., 0, 0]*M[..., 1, 1] - M[..., 1, 0]*M[..., 0, 1] - inv = M.new_empty(M.shape) + inv = torch.empty_like(M) inv[..., 0, 0] = M[..., 1, 1] inv[..., 1, 1] = M[..., 0, 0] inv[..., 0, 1] = -M[..., 0, 1] @@ -65,7 +65,7 @@ def inv3d(H, sym=False): Calculates the inverse of a batched 3-D matrix """ detH = determinant_3d(H) - Hinv = H.new_empty(H.shape) + Hinv = torch.empty_like(H) Hinv[..., 0, 0] = H[..., 1, 1] * H[..., 2, 2] - H[..., 1, 2] * H[..., 2, 1] Hinv[..., 1, 1] = H[..., 0, 0] * H[..., 2, 2] - H[..., 0, 2] * H[..., 2, 0] Hinv[..., 2, 2] = H[..., 0, 0] * H[..., 1, 1] - H[..., 0, 1] * H[..., 1, 0] diff --git a/pyro/ops/packed.py b/pyro/ops/packed.py index 712dc0b21e..bfcf5b60cc 100644 --- a/pyro/ops/packed.py +++ b/pyro/ops/packed.py @@ -91,7 +91,7 @@ def gather(value, index, dim): value, index = broadcast_all(value, index) dims = value._pyro_dims.replace(dim, '') pos = value._pyro_dims.index(dim) - index = index.index_select(pos, index.new_tensor([0])) + index = index.index_select(pos, torch.tensor([0], device=index.device)) value = value.gather(pos, index).squeeze(pos) value._pyro_dims = dims assert value.dim() == len(value._pyro_dims) diff --git a/pyro/ops/stats.py b/pyro/ops/stats.py index 4134f62a08..3414389735 100644 --- a/pyro/ops/stats.py +++ b/pyro/ops/stats.py @@ -115,20 +115,20 @@ def autocorrelation(input, dim=0): # centering and padding x centered_signal = input - input.mean(dim=-1, keepdim=True) - pad = input.new_zeros(input.shape[:-1] + (M2 - N,)) + pad = torch.zeros(input.shape[:-1] + (M2 - N,), dtype=input.dtype) centered_signal = torch.cat([centered_signal, pad], dim=-1) # Fourier transform freqvec = torch.rfft(centered_signal, signal_ndim=1, onesided=False) # take square of magnitude of freqvec (or freqvec x freqvec*) 
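Note: this function computes autocorrelation via the Wiener-Khinchin theorem: the autocorrelation of a signal is the inverse Fourier transform of its power spectrum, with zero-padding to at least 2N to avoid circular wrap-around. The patch targets the older torch.rfft API; on PyTorch 1.7 and later the same computation can be sketched with the torch.fft module (function name illustrative):

    import torch

    def autocorr_fft(x):
        # autocorrelation = iFFT(|FFT(x)|^2), padded to 2N to avoid wrap-around
        n = x.size(-1)
        x = x - x.mean(-1, keepdim=True)
        f = torch.fft.rfft(x, n=2 * n)
        ac = torch.fft.irfft(f * f.conj(), n=2 * n)[..., :n]
        ac = ac / torch.arange(n, 0, -1)  # divide by the number of overlapping terms per lag
        return ac / ac[..., :1]           # normalize so that ac[..., 0] == 1

    print(autocorr_fft(torch.randn(1000))[:5])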
freqvec_gram = freqvec.pow(2).sum(-1, keepdim=True) - freqvec_gram = torch.cat([freqvec_gram, input.new_zeros(freqvec_gram.shape)], dim=-1) + freqvec_gram = torch.cat([freqvec_gram, torch.zeros(freqvec_gram.shape, dtype=input.dtype)], dim=-1) # inverse Fourier transform autocorr = torch.irfft(freqvec_gram, signal_ndim=1, onesided=False) # truncate and normalize the result, then transpose back to original shape autocorr = autocorr[..., :N] - autocorr = autocorr / input.new_tensor(range(N, 0, -1)) + autocorr = autocorr / torch.tensor(range(N, 0, -1), dtype=input.dtype) autocorr = autocorr / autocorr[..., :1] return autocorr.transpose(dim, -1) @@ -154,7 +154,8 @@ def _cummin(input): # FIXME: is there a better trick to find accumulate min of a sequence? N = input.size(0) input_tril = input.unsqueeze(0).repeat((N,) + (1,) * input.dim()) - triu_mask = input.new_ones(N, N).triu(diagonal=1).reshape((N, N) + (1,) * (input.dim() - 1)) + triu_mask = (torch.ones(N, N, dtype=input.dtype, device=input.device) + .triu(diagonal=1).reshape((N, N) + (1,) * (input.dim() - 1))) triu_mask = triu_mask.expand((N, N) + input.shape[1:]) > 0.5 input_tril.masked_fill_(triu_mask, input.max()) return input_tril.min(dim=1)[0] @@ -229,7 +230,7 @@ def resample(input, num_samples, dim=0, replacement=False): :param int dim: dimension to draw from ``input``. :returns torch.Tensor: samples drawn randomly from ``input``. """ - weights = input.new_ones(input.size(dim)) + weights = torch.ones(input.size(dim), dtype=input.dtype, device=input.device) indices = torch.multinomial(weights, num_samples, replacement) return input.index_select(dim, indices) @@ -245,7 +246,7 @@ def quantile(input, probs, dim=0): :returns torch.Tensor: quantiles of ``input`` at ``probs``. """ if isinstance(probs, (numbers.Number, list, tuple)): - probs = input.new_tensor(probs) + probs = torch.tensor(probs, dtype=input.dtype, device=input.device) sorted_input = input.sort(dim)[0] max_index = input.size(dim) - 1 indices = probs * max_index @@ -290,9 +291,9 @@ def hpdi(input, prob, dim=0): mass = input.size(dim) index_length = int(prob * mass) intervals_left = sorted_input.index_select( - dim, input.new_tensor(range(mass - index_length), dtype=torch.long)) + dim, torch.tensor(range(mass - index_length), dtype=torch.long, device=input.device)) intervals_right = sorted_input.index_select( - dim, input.new_tensor(range(index_length, mass), dtype=torch.long)) + dim, torch.tensor(range(index_length, mass), dtype=torch.long, device=input.device)) intervals_length = intervals_right - intervals_left index_start = intervals_length.argmin(dim) indices = torch.stack([index_start, index_start + index_length], dim) @@ -329,7 +330,8 @@ def waic(input, log_weights=None, pointwise=False, dim=0): :param int dim: the sample dimension of ``input``. :returns tuple: tuple of WAIC and effective number of parameters. 
""" - log_weights = input.new_zeros(input.size(dim)) if log_weights is None else log_weights + if log_weights is None: + log_weights = torch.zeros(input.size(dim), dtype=input.dtype, device=input.device) # computes log pointwise predictive density: formula (3) of [1] dim = input.dim() + dim if dim < 0 else dim diff --git a/pyro/optim/clipped_adam.py b/pyro/optim/clipped_adam.py index 530c39d945..60cb6ab098 100644 --- a/pyro/optim/clipped_adam.py +++ b/pyro/optim/clipped_adam.py @@ -2,6 +2,7 @@ import math +import torch from torch.optim.optimizer import Optimizer @@ -56,9 +57,9 @@ def step(self, closure=None): if len(state) == 0: state['step'] = 0 # Exponential moving average of gradient values - state['exp_avg'] = grad.new_zeros(grad.shape) + state['exp_avg'] = torch.zeros_like(grad) # Exponential moving average of squared gradient values - state['exp_avg_sq'] = grad.new_zeros(grad.shape) + state['exp_avg_sq'] = torch.zeros_like(grad) exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] beta1, beta2 = group['betas'] From b698e2c8b47556e5ffd2921168765c359005b2a6 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Tue, 30 Apr 2019 13:29:53 -0700 Subject: [PATCH 4/7] Avoid legacy constructors in examples --- examples/air/air.py | 15 ++++++++------- examples/baseball.py | 14 +++++++------- examples/bayesian_regression.py | 9 +++++---- examples/contrib/oed/gp_bayes_opt.py | 7 ++++--- examples/dmm/polyphonic_data_loader.py | 2 +- examples/hmm.py | 3 ++- examples/lkj.py | 7 ++++--- examples/vae/ss_vae_M2.py | 10 +++++----- examples/vae/vae.py | 4 ++-- pyro/infer/mcmc/hmc.py | 3 +-- pyro/infer/mcmc/nuts.py | 3 +-- 11 files changed, 40 insertions(+), 37 deletions(-) diff --git a/examples/air/air.py b/examples/air/air.py index d5c10bbd21..226fad1df8 100644 --- a/examples/air/air.py +++ b/examples/air/air.py @@ -63,7 +63,8 @@ def __init__(self, self.baseline_scalar = baseline_scalar self.likelihood_sd = likelihood_sd self.use_cuda = use_cuda - self.prototype = torch.tensor(0.).cuda() if use_cuda else torch.tensor(0.) + prototype = torch.tensor(0.).cuda() if use_cuda else torch.tensor(0.) + self.options = dict(dtype=prototype.dtype, device=prototype.device) self.z_pres_size = 1 self.z_where_size = 3 @@ -107,8 +108,8 @@ def __init__(self, def prior(self, n, **kwargs): state = ModelState( - x=self.prototype.new_zeros([n, self.x_size, self.x_size]), - z_pres=self.prototype.new_ones([n, self.z_pres_size]), + x=torch.zeros(n, self.x_size, self.x_size, **self.options), + z_pres=torch.ones(n, self.z_pres_size, **self.options), z_where=None) z_pres = [] @@ -143,8 +144,8 @@ def prior_step(self, t, n, prev, z_pres_prior_p=default_z_pres_prior_p): # Sample latent code for contents of the attention window. 
z_what = pyro.sample('z_what_{}'.format(t), - dist.Normal(self.prototype.new_zeros([n, self.z_what_size]), - self.prototype.new_ones([n, self.z_what_size])) + dist.Normal(torch.zeros(n, self.z_what_size, **self.options), + torch.ones(n, self.z_what_size, **self.options)) .mask(sample_mask) .to_event(1)) @@ -169,7 +170,7 @@ def model(self, data, batch_size, **kwargs): (z_where, z_pres), x = self.prior(n, **kwargs) pyro.sample('obs', dist.Normal(x.view(n, -1), - (self.likelihood_sd * self.prototype.new_ones(n, self.x_size ** 2))) + (self.likelihood_sd * torch.ones(n, self.x_size ** 2, **self.options))) .to_event(1), obs=batch.view(n, -1)) @@ -207,7 +208,7 @@ def guide(self, data, batch_size, **kwargs): c=batch_expand(self.c_init, n), bl_h=batch_expand(self.bl_h_init, n), bl_c=batch_expand(self.bl_c_init, n), - z_pres=self.prototype.new_ones(n, self.z_pres_size), + z_pres=torch.ones(n, self.z_pres_size, **self.options), z_where=batch_expand(self.z_where_init, n), z_what=batch_expand(self.z_what_init, n)) diff --git a/examples/baseball.py b/examples/baseball.py index 0bac3c7b62..c6813ae3d4 100644 --- a/examples/baseball.py +++ b/examples/baseball.py @@ -9,7 +9,7 @@ import pyro from pyro.distributions import Beta, Binomial, HalfCauchy, Normal, Pareto, Uniform -from pyro.distributions.util import logsumexp +from pyro.distributions.util import logsumexp, scalar_like from pyro.infer.abstract_infer import TracePredictive from pyro.infer.mcmc import MCMC, NUTS @@ -67,7 +67,7 @@ def fully_pooled(at_bats, hits): :param (torch.Tensor) hits: Number of hits for the given at bats. :return: Number of hits predicted by the model. """ - phi_prior = Uniform(at_bats.new_tensor(0), at_bats.new_tensor(1)) + phi_prior = Uniform(scalar_like(at_bats, 0), scalar_like(at_bats, 1)) phi = pyro.sample("phi", phi_prior) return pyro.sample("obs", Binomial(at_bats, phi), obs=hits) @@ -83,7 +83,7 @@ def not_pooled(at_bats, hits): """ num_players = at_bats.shape[0] with pyro.plate("num_players", num_players): - phi_prior = Uniform(at_bats.new_tensor(0), at_bats.new_tensor(1)) + phi_prior = Uniform(scalar_like(at_bats, 0), scalar_like(at_bats, 1)) phi = pyro.sample("phi", phi_prior) return pyro.sample("obs", Binomial(at_bats, phi), obs=hits) @@ -101,8 +101,8 @@ def partially_pooled(at_bats, hits): :return: Number of hits predicted by the model. """ num_players = at_bats.shape[0] - m = pyro.sample("m", Uniform(at_bats.new_tensor(0), at_bats.new_tensor(1))) - kappa = pyro.sample("kappa", Pareto(at_bats.new_tensor(1), at_bats.new_tensor(1.5))) + m = pyro.sample("m", Uniform(scalar_like(at_bats, 0), scalar_like(at_bats, 1))) + kappa = pyro.sample("kappa", Pareto(scalar_like(at_bats, 1), scalar_like(at_bats, 1.5))) with pyro.plate("num_players", num_players): phi_prior = Beta(m * kappa, (1 - m) * kappa) phi = pyro.sample("phi", phi_prior) @@ -120,8 +120,8 @@ def partially_pooled_with_logit(at_bats, hits): :return: Number of hits predicted by the model. 
""" num_players = at_bats.shape[0] - loc = pyro.sample("loc", Normal(at_bats.new_tensor(-1), at_bats.new_tensor(1))) - scale = pyro.sample("scale", HalfCauchy(scale=at_bats.new_tensor(1))) + loc = pyro.sample("loc", Normal(scalar_like(at_bats, -1), scalar_like(at_bats, 1))) + scale = pyro.sample("scale", HalfCauchy(scale=scalar_like(at_bats, 1))) with pyro.plate("num_players", num_players): alpha = pyro.sample("alpha", Normal(loc, scale)) return pyro.sample("obs", Binomial(at_bats, logits=alpha), obs=hits) diff --git a/examples/bayesian_regression.py b/examples/bayesian_regression.py index bdf9cf094c..2e0d8908b7 100644 --- a/examples/bayesian_regression.py +++ b/examples/bayesian_regression.py @@ -53,10 +53,11 @@ def forward(self, x): def model(data): # Create unit normal priors over the parameters - loc = data.new_zeros(torch.Size((1, p))) - scale = 2 * data.new_ones(torch.Size((1, p))) - bias_loc = data.new_zeros(torch.Size((1,))) - bias_scale = 2 * data.new_ones(torch.Size((1,))) + options = dict(dtype=data.dtype, device=data.device) + loc = torch.zeros(1, p, **options) + scale = 2 * torch.ones(1, p, **options) + bias_loc = torch.zeros(1, **options) + bias_scale = 2 * torch.ones(1, **options) w_prior = Normal(loc, scale).to_event(1) b_prior = Normal(bias_loc, bias_scale).to_event(1) priors = {'linear.weight': w_prior, 'linear.bias': b_prior} diff --git a/examples/contrib/oed/gp_bayes_opt.py b/examples/contrib/oed/gp_bayes_opt.py index c1a27629b5..39812dccab 100644 --- a/examples/contrib/oed/gp_bayes_opt.py +++ b/examples/contrib/oed/gp_bayes_opt.py @@ -86,8 +86,8 @@ def opt_differentiable(self, differentiable, num_candidates=5): candidates = [] values = [] for j in range(num_candidates): - x_init = self.gpmodel.X.new_empty(1).uniform_( - self.constraints.lower_bound, self.constraints.upper_bound) + x_init = (torch.empty(1, dtype=self.gpmodel.X.dtype, device=self.gpmodel.X.device) + .uniform_(self.constraints.lower_bound, self.constraints.upper_bound)) x, y = self.find_a_candidate(differentiable, x_init) if torch.isnan(y): continue @@ -109,7 +109,8 @@ def acquire_thompson(self, num_acquisitions=1, **opt_params): """ # Initialize the return tensor - X = self.gpmodel.X.new_empty(num_acquisitions, *self.gpmodel.X.shape[1:]) + X = self.gpmodel.X + X = torch.empty(num_acquisitions, *X.shape[1:], dtype=X.dtype, device=X.device) for i in range(num_acquisitions): sampler = self.gpmodel.iter_sample(noiseless=False) diff --git a/examples/dmm/polyphonic_data_loader.py b/examples/dmm/polyphonic_data_loader.py index 74f9dacfbb..67a4874b66 100644 --- a/examples/dmm/polyphonic_data_loader.py +++ b/examples/dmm/polyphonic_data_loader.py @@ -101,7 +101,7 @@ def load_data(dataset): # this function takes a torch mini-batch and reverses each sequence # (w.r.t. the temporal axis, i.e. axis=1). def reverse_sequences(mini_batch, seq_lengths): - reversed_mini_batch = mini_batch.new_zeros(mini_batch.size()) + reversed_mini_batch = torch.zeros_like(mini_batch) for b in range(mini_batch.size(0)): T = seq_lengths[b] time_slice = torch.arange(T - 1, -1, -1, device=mini_batch.device) diff --git a/examples/hmm.py b/examples/hmm.py index be71c2537e..3346c11d94 100644 --- a/examples/hmm.py +++ b/examples/hmm.py @@ -388,7 +388,8 @@ def forward(self, x, y): # a bernoulli variable y. Whereas x will typically be enumerated, y will be observed. # We apply x_to_hidden independently from y_to_hidden, then broadcast the non-enumerated # y part up to the enumerated x part in the + operation. 
- x_onehot = y.new_zeros(x.shape[:-1] + (self.args.hidden_dim,)).scatter_(-1, x, 1) + x_onehot = (torch.zeros(x.shape[:-1] + (self.args.hidden_dim,), dtype=y.dtype, device=y.device) + .scatter_(-1, x, 1)) y_conv = self.relu(self.conv(y.unsqueeze(-2))).reshape(y.shape[:-1] + (-1,)) h = self.relu(self.x_to_hidden(x_onehot) + self.y_to_hidden(y_conv)) return self.hidden_to_logits(h) diff --git a/examples/lkj.py b/examples/lkj.py index c9c95f7e27..5164acfcf6 100644 --- a/examples/lkj.py +++ b/examples/lkj.py @@ -19,17 +19,18 @@ def model(y): d = y.shape[1] N = y.shape[0] + options = dict(dtype=y.dtype, device=y.device) # Vector of variances for each of the d variables - theta = pyro.sample("theta", dist.HalfCauchy(y.new_ones(d))) + theta = pyro.sample("theta", dist.HalfCauchy(torch.ones(d, **options))) # Lower cholesky factor of a correlation matrix - eta = y.new_ones(1) # Implies a uniform distribution over correlation matrices + eta = torch.ones(1, **options) # Implies a uniform distribution over correlation matrices L_omega = pyro.sample("L_omega", dist.LKJCorrCholesky(d, eta)) # Lower cholesky factor of the covariance matrix L_Omega = torch.mm(torch.diag(theta.sqrt()), L_omega) # For inference with SVI, one might prefer to use torch.bmm(theta.sqrt().diag_embed(), L_omega) # Vector of expectations - mu = y.new_zeros(d) + mu = torch.zeros(d, **options) with pyro.plate("observations", N): obs = pyro.sample("obs", dist.MultivariateNormal(mu, scale_tril=L_Omega), obs=y) diff --git a/examples/vae/ss_vae_M2.py b/examples/vae/ss_vae_M2.py index ccad14ff2f..d9de1c1fb5 100644 --- a/examples/vae/ss_vae_M2.py +++ b/examples/vae/ss_vae_M2.py @@ -102,16 +102,17 @@ def model(self, xs, ys=None): pyro.module("ss_vae", self) batch_size = xs.size(0) + options = dict(dtype=xs.dtype, device=xs.device) with pyro.plate("data"): # sample the handwriting style from the constant prior distribution - prior_loc = xs.new_zeros([batch_size, self.z_dim]) - prior_scale = xs.new_ones([batch_size, self.z_dim]) + prior_loc = torch.zeros(batch_size, self.z_dim, **options) + prior_scale = torch.ones(batch_size, self.z_dim, **options) zs = pyro.sample("z", dist.Normal(prior_loc, prior_scale).to_event(1)) # if the label y (which digit to write) is supervised, sample from the # constant prior, otherwise, observe the value (i.e. 
score it against the constant prior) - alpha_prior = xs.new_ones([batch_size, self.output_size]) / (1.0 * self.output_size) + alpha_prior = torch.ones(batch_size, self.output_size, **options) / (1.0 * self.output_size) ys = pyro.sample("y", dist.OneHotCategorical(alpha_prior), obs=ys) # finally, score the image (x) using the handwriting style (z) and @@ -167,8 +168,7 @@ def classifier(self, xs): res, ind = torch.topk(alpha, 1) # convert the digit(s) to one-hot tensor(s) - ys = xs.new_zeros(alpha.size()) - ys = ys.scatter_(1, ind, 1.0) + ys = torch.zeros_like(alpha).scatter_(1, ind, 1.0) return ys def model_classify(self, xs, ys=None): diff --git a/examples/vae/vae.py b/examples/vae/vae.py index 1d81f3635f..b07d4e9f5a 100644 --- a/examples/vae/vae.py +++ b/examples/vae/vae.py @@ -83,8 +83,8 @@ def model(self, x): pyro.module("decoder", self.decoder) with pyro.plate("data", x.shape[0]): # setup hyperparameters for prior p(z) - z_loc = x.new_zeros(torch.Size((x.shape[0], self.z_dim))) - z_scale = x.new_ones(torch.Size((x.shape[0], self.z_dim))) + z_loc = torch.zeros(x.shape[0], self.z_dim, dtype=x.dtype, device=x.device) + z_scale = torch.ones(x.shape[0], self.z_dim, dtype=x.dtype, device=x.device) # sample from prior (value will be sampled by guide when computing the ELBO) z = pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1)) # decode the latent code z diff --git a/pyro/infer/mcmc/hmc.py b/pyro/infer/mcmc/hmc.py index ed81aeb7ab..f32add3c48 100644 --- a/pyro/infer/mcmc/hmc.py +++ b/pyro/infer/mcmc/hmc.py @@ -9,8 +9,7 @@ import pyro import pyro.distributions as dist import pyro.poutine as poutine -from pyro.distributions.util import eye_like -from pyro.distributions.utils import scalar_like +from pyro.distributions.util import eye_like, scalar_like from pyro.infer import config_enumerate from pyro.infer.mcmc.adaptation import WarmupAdapter from pyro.infer.mcmc.trace_kernel import TraceKernel diff --git a/pyro/infer/mcmc/nuts.py b/pyro/infer/mcmc/nuts.py index ef7bb1be5c..a17cce8efc 100644 --- a/pyro/infer/mcmc/nuts.py +++ b/pyro/infer/mcmc/nuts.py @@ -6,8 +6,7 @@ import pyro import pyro.distributions as dist -from pyro.distributions.util import logsumexp -from pyro.distributions.utils import scalar_like +from pyro.distributions.util import logsumexp, scalar_like from pyro.infer.mcmc.hmc import HMC from pyro.ops.integrator import velocity_verlet from pyro.util import optional, torch_isnan From b92e53df0a5a8a9f31248434d8bd9f605f51f65a Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Tue, 30 Apr 2019 14:07:58 -0700 Subject: [PATCH 5/7] Fix typos --- pyro/infer/mcmc/hmc.py | 2 +- pyro/infer/mcmc/mcmc.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyro/infer/mcmc/hmc.py b/pyro/infer/mcmc/hmc.py index f32add3c48..7bd9e3af14 100644 --- a/pyro/infer/mcmc/hmc.py +++ b/pyro/infer/mcmc/hmc.py @@ -356,7 +356,7 @@ def _initialize_model_properties(self): if site_value is not None: mass_matrix_size = sum(self._r_numels.values()) if self._adapter.is_diag_mass: - initial_mass_matrix = torch.full(mass_matrix_size, dtype=site_value.dtype, device=site_value.device) + initial_mass_matrix = torch.ones(mass_matrix_size, dtype=site_value.dtype, device=site_value.device) else: initial_mass_matrix = eye_like(site_value, mass_matrix_size) self._adapter.configure(self._warmup_steps, diff --git a/pyro/infer/mcmc/mcmc.py b/pyro/infer/mcmc/mcmc.py index 9844cc62c8..18547aac5a 100644 --- a/pyro/infer/mcmc/mcmc.py +++ b/pyro/infer/mcmc/mcmc.py @@ -298,7 +298,7 @@ def 
diagnostics(self): site_stats["n_eff"] = stats.effective_sample_size(site_support) except NotImplementedError: site_stats["n_eff"] = torch.full(site_support.shape[2:], float("nan"), - dtype=site_support.dtype, device=site_support) + dtype=site_support.dtype, device=site_support.device) site_stats["r_hat"] = stats.split_gelman_rubin(site_support) self._diagnostics[site] = site_stats return self._diagnostics From 8616f7a13657b756cb0c24fa32ad6653d1bcc524 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Tue, 30 Apr 2019 14:14:13 -0700 Subject: [PATCH 6/7] Update test_jit.py --- tests/infer/test_jit.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/tests/infer/test_jit.py b/tests/infer/test_jit.py index c6327d0b21..71ae158e05 100644 --- a/tests/infer/test_jit.py +++ b/tests/infer/test_jit.py @@ -43,13 +43,13 @@ def f(x): logger.debug('Compiling f') f = torch.jit.trace(f, (y,), check_trace=False) logger.debug('Calling f(y)') - assert_equal(f(y), y.new_tensor([2., 2.])) + assert_equal(f(y), torch.tensor([2., 2.])) logger.debug('Calling f(y)') - assert_equal(f(y), y.new_tensor([2., 2.])) + assert_equal(f(y), torch.tensor([2., 2.])) logger.debug('Calling f(torch.zeros(2))') - assert_equal(f(torch.zeros(2)), y.new_tensor([1., 1.])) + assert_equal(f(torch.zeros(2)), torch.tensor([1., 1.])) logger.debug('Calling f(torch.zeros(5))') - assert_equal(f(torch.ones(5)), y.new_tensor([2., 2., 2., 2., 2.])) + assert_equal(f(torch.ones(5)), torch.tensor([2., 2., 2., 2., 2.])) def test_multi_output(): @@ -65,13 +65,13 @@ def f(x): logger.debug('Compiling f') f = torch.jit.trace(f, (y,), check_trace=False) logger.debug('Calling f(y)') - assert_equal(f(y)[1], y.new_tensor([2., 2.])) + assert_equal(f(y)[1], torch.tensor([2., 2.])) logger.debug('Calling f(y)') - assert_equal(f(y)[1], y.new_tensor([2., 2.])) + assert_equal(f(y)[1], torch.tensor([2., 2.])) logger.debug('Calling f(torch.zeros(2))') - assert_equal(f(torch.zeros(2))[1], y.new_tensor([1., 1.])) + assert_equal(f(torch.zeros(2))[1], torch.tensor([1., 1.])) logger.debug('Calling f(torch.zeros(5))') - assert_equal(f(torch.ones(5))[1], y.new_tensor([2., 2., 2., 2., 2.])) + assert_equal(f(torch.ones(5))[1], torch.tensor([2., 2., 2., 2., 2.])) def test_backward(): @@ -164,7 +164,7 @@ def f(y, mask): def test_scatter(): def make_one_hot(x, i): - return x.new_zeros(x.shape).scatter(-1, i.unsqueeze(-1), 1.0) + return torch.zeros_like(x).scatter(-1, i.unsqueeze(-1), 1.0) x = torch.randn(5, 4, 3) i = torch.randint(0, 3, torch.Size((5, 4))) @@ -175,7 +175,7 @@ def make_one_hot(x, i): def test_scatter_workaround(): def make_one_hot_expected(x, i): - return x.new_zeros(x.shape).scatter(-1, i.unsqueeze(-1), 1.0) + return torch.zeros_like(x).scatter(-1, i.unsqueeze(-1), 1.0) def make_one_hot_actual(x, i): eye = torch.eye(x.shape[-1], dtype=x.dtype, device=x.device) From e1da677dcbbe3a771bdd92d65631d33371e4f80d Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Tue, 30 Apr 2019 14:20:33 -0700 Subject: [PATCH 7/7] Make gather op jit compatible --- pyro/distributions/util.py | 5 ++++- pyro/ops/packed.py | 4 +++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/pyro/distributions/util.py b/pyro/distributions/util.py index 42fb5c227e..1b292d7f20 100644 --- a/pyro/distributions/util.py +++ b/pyro/distributions/util.py @@ -8,6 +8,7 @@ from torch import logsumexp from torch.distributions.utils import broadcast_all +from pyro.util import ignore_jit_warnings _VALIDATION_ENABLED = False @@ -106,7 +107,9 @@ def 
gather(value, index, dim): Broadcasted gather of indexed values along a named dim. """ value, index = broadcast_all(value, index) - index = index.index_select(dim, torch.tensor([0], device=index.device)) + with ignore_jit_warnings(): + zero = torch.zeros(1, dtype=torch.long, device=index.device) + index = index.index_select(dim, zero) return value.gather(dim, index) diff --git a/pyro/ops/packed.py b/pyro/ops/packed.py index bfcf5b60cc..13dd759041 100644 --- a/pyro/ops/packed.py +++ b/pyro/ops/packed.py @@ -91,7 +91,9 @@ def gather(value, index, dim): value, index = broadcast_all(value, index) dims = value._pyro_dims.replace(dim, '') pos = value._pyro_dims.index(dim) - index = index.index_select(pos, torch.tensor([0], device=index.device)) + with ignore_jit_warnings(): + zero = torch.zeros(1, dtype=torch.long, device=index.device) + index = index.index_select(pos, zero) value = value.gather(pos, index).squeeze(pos) value._pyro_dims = dims assert value.dim() == len(value._pyro_dims)
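Note: for context on this final patch — under torch.jit.trace, building an index with torch.tensor([0]) captures a Python constant and emits a TracerWarning, while torch.zeros(1, dtype=torch.long, device=...) stays traceable; the remaining trace noise is suppressed by the ignore_jit_warnings guard. The broadcasted-gather idiom itself, mirroring pyro.distributions.util.gather after this patch (minus that guard):

    import torch
    from torch.distributions.utils import broadcast_all

    def gather(value, index, dim):
        # broadcast value and index together, then select index 0 of the
        # broadcasted index along `dim` so gather keeps that dim with size 1
        value, index = broadcast_all(value, index)
        zero = torch.zeros(1, dtype=torch.long, device=index.device)
        index = index.index_select(dim, zero)
        return value.gather(dim, index)

    value = torch.arange(12.).reshape(3, 4)
    index = torch.tensor([1, 3, 0]).unsqueeze(-1)  # one column index per row
    print(gather(value, index, 1))                 # tensor([[1.], [7.], [8.]])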