Skip to content

Commit

Permalink
Test RandomWalk change size and fix bug
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Oct 9, 2022
1 parent 6c5e022 commit 0e159a6
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 1 deletion.
2 changes: 1 addition & 1 deletion pymc/distributions/timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def rv_op(cls, init_dist, innovation_dist, steps, size=None):

# If not explicit, size is determined by the shapes of the input distributions
if size is None:
size = at.broadcast_shape(init_dist, innovation_dist[..., 0])
size = at.broadcast_shape(init_dist, at.atleast_1d(innovation_dist)[..., 0])
innovation_size = tuple(size) + (steps,)

# Resize input distributions
Expand Down
13 changes: 13 additions & 0 deletions pymc/tests/distributions/test_timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,19 @@ def test_dists_not_registered_check(self):
):
RandomWalk("rw", init_dist=init_dist, innovation_dist=innovation, steps=5)

def test_change_size(self):
init_dist = Normal.dist()
innovation_dist = Normal.dist()

# size = 5
rw = RandomWalk.dist(init_dist=init_dist, innovation_dist=innovation_dist, shape=(5, 100))

new_rw = change_dist_size(rw, new_size=(7,))
assert tuple(new_rw.shape.eval()) == (7, 100)

new_rw = change_dist_size(rw, new_size=(4, 3), expand=True)
assert tuple(new_rw.shape.eval()) == (4, 3, 5, 100)


class TestGaussianRandomWalk:
def test_logp(self):
Expand Down

0 comments on commit 0e159a6

Please sign in to comment.