diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md index f038c0f99c..fa0c0d1f16 100644 --- a/RELEASE-NOTES.md +++ b/RELEASE-NOTES.md @@ -11,6 +11,7 @@ - Added the `broadcast_distribution_samples` function that helps broadcasting arrays of drawn samples, taking into account the requested `size` and the inferred distribution shape. This sometimes is needed by distributions that call several `rvs` separately within their `random` method, such as the `ZeroInflatedPoisson` (Fix issue #3310). - The `Wald`, `Kumaraswamy`, `LogNormal`, `Pareto`, `Cauchy`, `HalfCauchy`, `Weibull` and `ExGaussian` distributions `random` method used a hidden `_random` function that was written with scalars in mind. This could potentially lead to artificial correlations between random draws. Added shape guards and broadcasting of the distribution samples to prevent this (Similar to issue #3310). - Added a fix to allow the imputation of single missing values of observed data, which previously would fail (Fix issue #3122). +- Fix for #3346. The `draw_values` function was too permissive with what could be grabbed from inside `point`, which lead to an error when sampling posterior predictives of variables that depended on shared variables that had changed their shape after `pm.sample()` had been called. - Fix for #3354. `draw_values` now adds the theano graph descendants of `TensorConstant` or `SharedVariables` to the named relationship nodes stack, only if these descendants are `ObservedRV` or `MultiObservedRV` instances. ### Deprecations diff --git a/pymc3/distributions/distribution.py b/pymc3/distributions/distribution.py index d206b75528..16b9cf14d0 100644 --- a/pymc3/distributions/distribution.py +++ b/pymc3/distributions/distribution.py @@ -309,7 +309,8 @@ def draw_values(params, point=None, size=None): # param was drawn in related contexts v = drawn[(p, size)] evaluated[i] = v - elif name is not None and name in point: + # We filter out Deterministics by checking for `model` attribute + elif name is not None and hasattr(p, 'model') and name in point: # param.name is in point v = point[name] evaluated[i] = drawn[(p, size)] = v @@ -495,7 +496,7 @@ def _draw_value(param, point=None, givens=None, size=None): dist_tmp.shape = distshape try: - dist_tmp.random(point=point, size=size) + return dist_tmp.random(point=point, size=size) except (ValueError, TypeError): # reset shape to account for shape changes # with theano.shared inputs diff --git a/pymc3/tests/test_sampling.py b/pymc3/tests/test_sampling.py index 1febe0a34c..1d73e55edf 100644 --- a/pymc3/tests/test_sampling.py +++ b/pymc3/tests/test_sampling.py @@ -302,6 +302,32 @@ def test_model_not_drawable_prior(self): samples = pm.sample_posterior_predictive(trace, 50) assert samples['foo'].shape == (50, 200) + def test_model_shared_variable(self): + x = np.random.randn(100) + y = x > 0 + x_shared = theano.shared(x) + y_shared = theano.shared(y) + with pm.Model() as model: + coeff = pm.Normal('x', mu=0, sd=1) + logistic = pm.Deterministic('p', pm.math.sigmoid(coeff * x_shared)) + + obs = pm.Bernoulli('obs', p=logistic, observed=y_shared) + trace = pm.sample(100) + + x_shared.set_value([-1, 0, 1.]) + y_shared.set_value([0, 0, 0]) + + samples = 100 + with model: + post_pred = pm.sample_posterior_predictive(trace, + samples=samples, + vars=[logistic, obs]) + + expected_p = np.array([logistic.eval({coeff: val}) + for val in trace['x'][:samples]]) + assert post_pred['obs'].shape == (samples, 3) + assert np.allclose(post_pred['p'], expected_p) + def test_deterministic_of_observed(self): meas_in_1 = pm.theanof.floatX(2 + 4 * np.random.randn(100)) meas_in_2 = pm.theanof.floatX(5 + 4 * np.random.randn(100)) diff --git a/pymc3/tests/test_step.py b/pymc3/tests/test_step.py index 2e7766e97a..c2bc15dcd8 100644 --- a/pymc3/tests/test_step.py +++ b/pymc3/tests/test_step.py @@ -467,206 +467,206 @@ class TestStepMethods: # yield test doesn't work subclassing object ), SMC: np.array( [ - 0.53154128, - 1.69014751, - 0.38863621, - 1.36550742, - 0.54937705, - 0.85452502, - 1.34000193, - 0.04963276, - 0.73207585, - -0.45504452, - 0.99087969, - 0.53800418, - 1.69481083, - 0.19015456, - 0.54587521, - 0.51198155, - 0.17514563, - -0.62023686, - 0.73211032, - 1.0269751, - 0.82004582, - 1.07714066, - 1.27243655, - 0.8603388, - -0.96709536, - -0.4963755, - 0.47436677, - 0.34392296, - 0.08501226, - 0.95779747, - 1.21125461, - -0.04609757, - 0.29714065, - 0.89447118, - 0.00472546, - 0.50365803, - 1.73127064, - 1.04164544, - -0.22236077, - 1.33631993, - 0.96357527, - 1.06122196, - 0.12798557, - 0.4665272, - -0.1162582, - 1.62002463, - 1.44557222, - -0.49218985, - 1.2175313, - 0.25761981, - 0.82879531, - 0.16321047, - 1.34260731, - -0.05709803, - 0.18903618, - 0.76825821, - 0.08211472, - 0.53817434, - 0.53379232, - -0.47094362, - 1.14433914, - 0.03770296, - 1.30737805, - 0.39671022, - 1.22440728, - 0.09600212, - -0.49796137, - -0.44963869, - 0.95208986, - -0.04308567, - 0.45937807, - 2.59887219, - 0.36326674, - 1.16659465, - 2.26888158, - -0.64081701, - 0.13085995, - 1.5847621, - 0.29342994, - -0.7802778, - 0.62631291, - 0.56155063, - 0.63017152, - 1.88801376, - 0.32864795, - 0.19722366, - 0.62506725, - -0.04154236, - 0.74923865, - 0.64958051, - 1.05205509, - 1.12818507, - 0.35463018, - 1.49394093, - -1.32280176, - 0.48922758, - -0.67185953, - 0.01282045, - -0.00832875, - 0.60746178, - 1.04869967, - 0.43197615, - 0.14665959, - -0.08117829, - 0.43216574, - 0.87241428, - -0.07985268, - -0.93380876, - 1.73662159, - 0.23926283, - -0.69068641, - 1.17829179, - -0.16332134, - -0.5112194, - -0.43442261, - 0.34948852, - 1.11002685, - 0.42445302, - 0.68379355, - -0.12877628, - 0.59561974, - 0.67230016, - 1.67895815, - 1.51053172, - 1.14415702, - 1.00682996, - 1.09882483, - 0.28820149, - -0.75250142, - -0.66453929, - -0.0991988, - 0.2907921, - 0.04077525, - -0.44036405, - 0.44894708, - 0.68646345, - 0.03986746, - 0.50061203, - 1.18904715, - 0.36231722, - -0.16347099, - 0.35343108, - 1.15870795, - 0.5973369, - 1.50731862, - 0.69983246, - 1.50854043, - 0.97489667, - 0.25267479, - 0.26369507, - 1.59775053, - 1.56383915, - 0.1721522, - -0.96935772, - 1.47191825, - 0.79858327, - 0.69071774, - -0.17667758, - 0.61438524, - 0.99424152, - -0.23558854, - -0.27873225, - 0.16615446, - 0.02589567, - -0.38007309, - 0.24960815, - 1.17127086, - 1.96577002, - 0.83224965, - 0.9386008, - -0.16018964, - 0.25239747, - -0.09936852, - -0.20376629, - 1.39291948, - -0.2083358, - 0.51435573, - 1.38304537, - 0.23272616, - -0.15257289, - 0.77293718, - 0.33558962, - -0.99534345, - -0.03472376, - 0.07169895, - 1.62726823, - 0.08074445, - 0.38765492, - 0.7844363, - 0.89340893, - 0.28605836, - 0.83632054, - 0.54210362, - -0.55168005, - 0.91756515, - 0.16982656, - -0.36404392, - 1.0011945, - -0.2659181, - 0.31691263, + 0.96155575, + 0.12354058, + 1.37843394, + -0.64490769, + 0.56153344, + 0.56949613, + 1.48805567, + 0.86943208, + 1.28672225, + 1.11818376, + 0.99202034, + 0.76951474, + -0.80478442, + 0.25773324, + 1.71676489, + 0.47726659, + 1.41458346, + 0.91086253, + 0.36146025, + -0.05769771, + 0.59441151, + 1.86157896, + 0.84378345, + -0.08420864, + 0.96264496, + 1.18466616, + -0.08301892, + 0.69627504, + 0.52721449, + 0.69209438, + 1.22462443, + 1.3527388, + -0.96055421, + 0.95149676, + 0.64798478, + 0.74435388, + 0.67474009, + 0.16073648, + 0.6660611, + 0.62209787, + 0.72155144, + 0.93515369, + 0.62462479, + 0.93689047, + 0.66827631, + 1.22635052, + 1.40935767, + -0.12069125, + 0.29298952, + 0.32982792, + 0.34051724, + -0.24309767, + -0.29414865, + 0.37741563, + -0.29052578, + -0.36946086, + 0.70680747, + 0.282181, + 0.91349846, + 1.61566788, + 1.35593426, + 0.22115063, + -0.1904221, + -0.53213195, + 0.25832089, + 1.272686, + -0.34608613, + 0.7405656, + 0.86422801, + 0.4133873, + 0.02002896, + 1.52716732, + 1.11700266, + 0.01264456, + 1.28187506, + -0.14362652, + 0.02715474, + 2.23628253, + 1.04628416, + 1.02788034, + -0.16128468, + 0.5390423, + 1.34396764, + 1.53622121, + 0.9573948, + 2.14416466, + 0.71949634, + 0.78759446, + 1.39579606, + -0.03169538, + 1.47822573, + 0.01943775, + 1.31197471, + 0.47475589, + 1.35694485, + 1.10879123, + 1.68316079, + -0.3677566, + 0.01876864, + 1.10056079, + -0.73678253, + 0.41594072, + 1.11037177, + 0.89196759, + 0.06043841, + 0.11064186, + 1.03143181, + 0.64732913, + 0.12394593, + -0.26254331, + 0.36509784, + 0.7893701, + -0.40537559, + -0.78701105, + 1.54832725, + 0.41225113, + -0.72722386, + 1.51618746, + -0.34325617, + -0.37255259, + 0.3190562, + 0.30245186, + 0.61122402, + 0.416849, + 0.56464343, + 1.19103522, + 0.74743207, + 0.8641633, + 1.71882087, + 1.07100044, + 1.33798974, + 1.48935406, + 1.16031688, + 0.37036675, + -0.55450346, + -0.70492211, + -0.03610812, + 0.19853608, + 0.08956581, + -0.2381033, + 0.83095776, + 0.76563511, + -0.07106511, + 0.56329078, + 1.15931015, + 0.04533864, + -0.21308172, + 0.274915, + 1.57040134, + 0.78145705, + 1.92064424, + 0.28787777, + 1.46101294, + 1.08120648, + 0.18144586, + 1.34311923, + 1.46516965, + 1.51600453, + -0.09295297, + -0.85488172, + 1.40737347, + 0.73784988, + 0.06846084, + -0.1996966, + 0.3227957, + 1.20572832, + -0.31104901, + -0.01546281, + 0.0172552, + 0.22116851, + 0.0931431, + -0.22007601, + 0.03757681, + 1.5345242, + 0.88473831, + 1.22550042, + 0.10627061, + 0.23637588, + -0.11944451, + -0.21066274, + 1.33649839, + -1.0085795, + 0.48370976, + 1.38720187, + -0.20363393, + -0.01723264, + 0.33616712, + 0.57274987, + -0.70025566, + -0.05105602, + -0.37324572, + 2.01686044, + 0.15311289, + -0.32750163, + 0.56208031, + 0.90689661, + 0.43908833, + 0.35270695, + 0.37425973, + -0.80283989, ] ), }