Skip to content

Commit

Permalink
call to model accepts kwargs
Browse files Browse the repository at this point in the history
  • Loading branch information
rlouf committed Oct 29, 2020
1 parent d9466b7 commit 555b374
Showing 1 changed file with 15 additions and 6 deletions.
21 changes: 15 additions & 6 deletions mcx/model.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import ast
import functools
from typing import Callable, List, Union
from typing import Any, Callable, List, Tuple, Union

import jax
import jax.numpy as np
import numpy

import mcx.core as core
from mcx.distributions import Distribution
from mcx.predict import sample_forward

__all__ = ["model", "seed"]

Expand Down Expand Up @@ -177,15 +177,24 @@ def __init__(self, fn: Callable) -> None:
self.rng_key = jax.random.PRNGKey(53)
functools.update_wrapper(self, fn)

def __call__(self, *args) -> numpy.ndarray:
def __call__(self, *args, **kwargs) -> numpy.ndarray:
"""Return a sample from the prior predictive distribution. A different
value is returned at each all.
"""
_, self.rng_key = jax.random.split(self.rng_key)

forward_sampler, _, _, _ = core.compile_to_sampler(self.graph, self.namespace)
samples = forward_sampler(self.rng_key, *args)
print("Forward pass through the model")
# convert all numbers to arrays
arguments: Tuple[Any, ...] = ()
for arg in args:
try:
arguments += (np.atleast_1d(arg),)
except Exception:
arguments += (arg,)

prior_sampler, _, _, _ = core.compile_to_prior_sampler(
self.graph, self.namespace
)
samples = prior_sampler(self.rng_key, *arguments, **kwargs)
return numpy.asarray(samples).squeeze()

def __getitem__(self, name: str):
Expand Down

0 comments on commit 555b374

Please sign in to comment.