From 2f93d4df8261368e7731ef4417c383124f10b4fd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Thu, 29 Oct 2020 12:47:53 +0100 Subject: [PATCH] s/core/compiler --- mcx/{core => compiler}/__init__.py | 0 mcx/{core => compiler}/compiler.py | 10 +++++----- mcx/{core => compiler}/graph.py | 2 +- mcx/{core => compiler}/nodes.py | 0 mcx/{core => compiler}/parser.py | 2 +- mcx/model.py | 21 +++++++++++---------- mcx/predict.py | 8 ++++---- mcx/sample.py | 2 +- 8 files changed, 23 insertions(+), 22 deletions(-) rename mcx/{core => compiler}/__init__.py (100%) rename mcx/{core => compiler}/compiler.py (98%) rename mcx/{core => compiler}/graph.py (99%) rename mcx/{core => compiler}/nodes.py (100%) rename mcx/{core => compiler}/parser.py (99%) diff --git a/mcx/core/__init__.py b/mcx/compiler/__init__.py similarity index 100% rename from mcx/core/__init__.py rename to mcx/compiler/__init__.py diff --git a/mcx/core/compiler.py b/mcx/compiler/compiler.py similarity index 98% rename from mcx/core/compiler.py rename to mcx/compiler/compiler.py index 77e485f6..20767453 100644 --- a/mcx/core/compiler.py +++ b/mcx/compiler/compiler.py @@ -6,8 +6,8 @@ import networkx as nx import mcx -from mcx.core.graph import GraphicalModel -from mcx.core.nodes import Argument, RandVar +from mcx.compiler.graph import GraphicalModel +from mcx.compiler.nodes import Argument, RandVar class Artifact(NamedTuple): @@ -335,7 +335,7 @@ def compile_to_sampler(graph, namespace) -> Artifact: returned_vars = [ ast.Name(id=node.name, ctx=ast.Load()) for node in ordered_nodes - if not isinstance(node, mcx.core.graph.Var) + if not isinstance(node, mcx.compiler.graph.Var) ] returned = ast.Return( @@ -426,7 +426,7 @@ def compile_to_prior_sampler(graph, namespace, jit=False) -> Artifact: returned_vars = [ ast.Name(id=node.name, ctx=ast.Load()) for node in ordered_nodes - if not isinstance(node, mcx.core.graph.Var) and node.is_returned + if not isinstance(node, mcx.compiler.graph.Var) and node.is_returned ] if len(returned_vars) == 1: returned = ast.Return(returned_vars[0]) @@ -576,7 +576,7 @@ def compile_to_posterior_sampler(graph, namespace, jit=False) -> Artifact: returned_vars = [ ast.Name(id=node.name, ctx=ast.Load()) for node in ordered_nodes - if not isinstance(node, mcx.core.graph.Var) and node.is_returned + if not isinstance(node, mcx.compiler.graph.Var) and node.is_returned ] if len(returned_vars) == 1: returned = ast.Return(returned_vars[0]) diff --git a/mcx/core/graph.py b/mcx/compiler/graph.py similarity index 99% rename from mcx/core/graph.py rename to mcx/compiler/graph.py index 878fd3d1..64660dbc 100644 --- a/mcx/core/graph.py +++ b/mcx/compiler/graph.py @@ -4,7 +4,7 @@ import astor import networkx as nx -from mcx.core.nodes import Argument, RandVar, Transformation, Var +from mcx.compiler.nodes import Argument, RandVar, Transformation, Var class GraphicalModel(nx.DiGraph): diff --git a/mcx/core/nodes.py b/mcx/compiler/nodes.py similarity index 100% rename from mcx/core/nodes.py rename to mcx/compiler/nodes.py diff --git a/mcx/core/parser.py b/mcx/compiler/parser.py similarity index 99% rename from mcx/core/parser.py rename to mcx/compiler/parser.py index 95fb6643..0e259680 100644 --- a/mcx/core/parser.py +++ b/mcx/compiler/parser.py @@ -7,7 +7,7 @@ import astor import mcx -from mcx.core.graph import GraphicalModel +from mcx.compiler.graph import GraphicalModel def parse_definition(model: Callable, namespace: Dict) -> GraphicalModel: diff --git a/mcx/model.py b/mcx/model.py index d59fab3e..b6773dda 100644 --- a/mcx/model.py +++ b/mcx/model.py @@ -6,8 +6,9 @@ import jax.numpy as np import numpy -import mcx.core as core +import mcx.compiler as compiler from mcx.distributions import Distribution +from mcx.predict import sample_forward __all__ = ["model", "seed"] @@ -173,7 +174,7 @@ class model(Distribution): def __init__(self, fn: Callable) -> None: self.model_fn = fn self.namespace = fn.__globals__ # type: ignore - self.graph = core.parse_definition(fn, self.namespace) + self.graph = compiler.parse_definition(fn, self.namespace) self.rng_key = jax.random.PRNGKey(53) functools.update_wrapper(self, fn) @@ -191,7 +192,7 @@ def __call__(self, *args, **kwargs) -> numpy.ndarray: except Exception: arguments += (arg,) - prior_sampler, _, _, _ = core.compile_to_prior_sampler( + prior_sampler, _, _, _ = compiler.compile_to_prior_sampler( self.graph, self.namespace ) samples = prior_sampler(self.rng_key, *arguments, **kwargs) @@ -254,19 +255,19 @@ def forward(self, **kwargs): @property def forward_src(self) -> str: - artifact = core.compile_to_sampler(self.graph, self.namespace) + artifact = compiler.compile_to_sampler(self.graph, self.namespace) return artifact.fn_source def sample(self, *args, sample_shape=(1000,)) -> jax.numpy.DeviceArray: """Return forward samples from the distribution.""" - sampler, _, _, _ = core.compile_to_sampler(self.graph, self.namespace) + sampler, _, _, _ = compiler.compile_to_sampler(self.graph, self.namespace) _, self.rng_key = jax.random.split(self.rng_key) samples = sampler(self.rng_key, sample_shape, *args) return samples def logpdf(self, *args, **kwargs) -> float: """Compute the value of the distribution's logpdf.""" - logpdf, _, _, _ = core.compile_to_logpdf(self.graph, self.namespace) + logpdf, _, _, _ = compiler.compile_to_logpdf(self.graph, self.namespace) return logpdf(*args, **kwargs) @property @@ -274,7 +275,7 @@ def logpdf_src(self) -> str: """Return the source code of the log-probability density funtion generated by the compiler. """ - artifact = core.compile_to_logpdf(self.graph, self.namespace) + artifact = compiler.compile_to_logpdf(self.graph, self.namespace) return artifact.fn_source @property @@ -282,7 +283,7 @@ def loglikelihoods_src(self) -> str: """Return the source code of the log-probability density funtion generated by the compiler. """ - artifact = core.compile_to_loglikelihoods(self.graph, self.namespace) + artifact = compiler.compile_to_loglikelihoods(self.graph, self.namespace) return artifact.fn_source @property @@ -290,7 +291,7 @@ def sampler_src(self) -> str: """Return the source code of the forward sampling funtion generated by the compiler. """ - artifact = core.compile_to_sampler(self.graph, self.namespace) + artifact = compiler.compile_to_sampler(self.graph, self.namespace) return artifact.fn_source @property @@ -298,7 +299,7 @@ def posterior_sampler_src(self) -> str: """Return the source code of the forward sampling funtion generated by the compiler. """ - artifact = core.compile_to_posterior_sampler(self.graph, self.namespace) + artifact = compiler.compile_to_posterior_sampler(self.graph, self.namespace) return artifact.fn_source @property diff --git a/mcx/predict.py b/mcx/predict.py index cdcf4e93..e70eebdf 100644 --- a/mcx/predict.py +++ b/mcx/predict.py @@ -4,7 +4,7 @@ import numpy from jax import numpy as np -import mcx.core as core +import mcx.compiler as compiler from mcx.trace import Trace __all__ = ["predict", "sample_forward"] @@ -55,7 +55,7 @@ def predict( class posterior_predictive: def __init__(self, rng_key: jax.random.PRNGKey, model, trace: Trace) -> None: """Initialize the posterior predictive sampler.""" - artifact = core.compile_to_posterior_sampler(model.graph, model.namespace) + artifact = compiler.compile_to_posterior_sampler(model.graph, model.namespace) sampler = jax.jit(artifact.compiled_fn) self.model = model @@ -176,7 +176,7 @@ def sample_one_chain(*args): class prior_predictive: def __init__(self, rng_key: jax.random.PRNGKey, model) -> None: - artifact = core.compile_to_prior_sampler(model.graph, model.namespace) + artifact = compiler.compile_to_prior_sampler(model.graph, model.namespace) sampler = jax.jit(artifact.compiled_fn) self.model = model @@ -325,7 +325,7 @@ def sample_forward( else: out_axes = (1,) * len(model.variables) - artifact = core.compile_to_sampler(model.graph, model.namespace) + artifact = compiler.compile_to_sampler(model.graph, model.namespace) sampler = jax.jit(artifact.compiled_fn) samples = jax.vmap(sampler, in_axes, out_axes)(*sampler_args) diff --git a/mcx/sample.py b/mcx/sample.py index 400145bc..b4a2c5af 100644 --- a/mcx/sample.py +++ b/mcx/sample.py @@ -8,7 +8,7 @@ import mcx from mcx import sample_forward -from mcx.core import compile_to_loglikelihoods, compile_to_logpdf +from mcx.compiler import compile_to_loglikelihoods, compile_to_logpdf from mcx.jax import ravel_pytree as mcx_ravel_pytree from mcx.trace import Trace