Skip to content

Commit

Permalink
s/core/compiler
Browse files Browse the repository at this point in the history
  • Loading branch information
rlouf committed Oct 29, 2020
1 parent 549effa commit 2f93d4d
Show file tree
Hide file tree
Showing 8 changed files with 23 additions and 22 deletions.
File renamed without changes.
10 changes: 5 additions & 5 deletions mcx/core/compiler.py → mcx/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
import networkx as nx

import mcx
from mcx.core.graph import GraphicalModel
from mcx.core.nodes import Argument, RandVar
from mcx.compiler.graph import GraphicalModel
from mcx.compiler.nodes import Argument, RandVar


class Artifact(NamedTuple):
Expand Down Expand Up @@ -335,7 +335,7 @@ def compile_to_sampler(graph, namespace) -> Artifact:
returned_vars = [
ast.Name(id=node.name, ctx=ast.Load())
for node in ordered_nodes
if not isinstance(node, mcx.core.graph.Var)
if not isinstance(node, mcx.compiler.graph.Var)
]

returned = ast.Return(
Expand Down Expand Up @@ -426,7 +426,7 @@ def compile_to_prior_sampler(graph, namespace, jit=False) -> Artifact:
returned_vars = [
ast.Name(id=node.name, ctx=ast.Load())
for node in ordered_nodes
if not isinstance(node, mcx.core.graph.Var) and node.is_returned
if not isinstance(node, mcx.compiler.graph.Var) and node.is_returned
]
if len(returned_vars) == 1:
returned = ast.Return(returned_vars[0])
Expand Down Expand Up @@ -576,7 +576,7 @@ def compile_to_posterior_sampler(graph, namespace, jit=False) -> Artifact:
returned_vars = [
ast.Name(id=node.name, ctx=ast.Load())
for node in ordered_nodes
if not isinstance(node, mcx.core.graph.Var) and node.is_returned
if not isinstance(node, mcx.compiler.graph.Var) and node.is_returned
]
if len(returned_vars) == 1:
returned = ast.Return(returned_vars[0])
Expand Down
2 changes: 1 addition & 1 deletion mcx/core/graph.py → mcx/compiler/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import astor
import networkx as nx

from mcx.core.nodes import Argument, RandVar, Transformation, Var
from mcx.compiler.nodes import Argument, RandVar, Transformation, Var


class GraphicalModel(nx.DiGraph):
Expand Down
File renamed without changes.
2 changes: 1 addition & 1 deletion mcx/core/parser.py → mcx/compiler/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import astor

import mcx
from mcx.core.graph import GraphicalModel
from mcx.compiler.graph import GraphicalModel


def parse_definition(model: Callable, namespace: Dict) -> GraphicalModel:
Expand Down
21 changes: 11 additions & 10 deletions mcx/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
import jax.numpy as np
import numpy

import mcx.core as core
import mcx.compiler as compiler
from mcx.distributions import Distribution
from mcx.predict import sample_forward

__all__ = ["model", "seed"]

Expand Down Expand Up @@ -173,7 +174,7 @@ class model(Distribution):
def __init__(self, fn: Callable) -> None:
self.model_fn = fn
self.namespace = fn.__globals__ # type: ignore
self.graph = core.parse_definition(fn, self.namespace)
self.graph = compiler.parse_definition(fn, self.namespace)
self.rng_key = jax.random.PRNGKey(53)
functools.update_wrapper(self, fn)

Expand All @@ -191,7 +192,7 @@ def __call__(self, *args, **kwargs) -> numpy.ndarray:
except Exception:
arguments += (arg,)

prior_sampler, _, _, _ = core.compile_to_prior_sampler(
prior_sampler, _, _, _ = compiler.compile_to_prior_sampler(
self.graph, self.namespace
)
samples = prior_sampler(self.rng_key, *arguments, **kwargs)
Expand Down Expand Up @@ -254,51 +255,51 @@ def forward(self, **kwargs):

@property
def forward_src(self) -> str:
artifact = core.compile_to_sampler(self.graph, self.namespace)
artifact = compiler.compile_to_sampler(self.graph, self.namespace)
return artifact.fn_source

def sample(self, *args, sample_shape=(1000,)) -> jax.numpy.DeviceArray:
"""Return forward samples from the distribution."""
sampler, _, _, _ = core.compile_to_sampler(self.graph, self.namespace)
sampler, _, _, _ = compiler.compile_to_sampler(self.graph, self.namespace)
_, self.rng_key = jax.random.split(self.rng_key)
samples = sampler(self.rng_key, sample_shape, *args)
return samples

def logpdf(self, *args, **kwargs) -> float:
"""Compute the value of the distribution's logpdf."""
logpdf, _, _, _ = core.compile_to_logpdf(self.graph, self.namespace)
logpdf, _, _, _ = compiler.compile_to_logpdf(self.graph, self.namespace)
return logpdf(*args, **kwargs)

@property
def logpdf_src(self) -> str:
"""Return the source code of the log-probability density funtion
generated by the compiler.
"""
artifact = core.compile_to_logpdf(self.graph, self.namespace)
artifact = compiler.compile_to_logpdf(self.graph, self.namespace)
return artifact.fn_source

@property
def loglikelihoods_src(self) -> str:
"""Return the source code of the log-probability density funtion
generated by the compiler.
"""
artifact = core.compile_to_loglikelihoods(self.graph, self.namespace)
artifact = compiler.compile_to_loglikelihoods(self.graph, self.namespace)
return artifact.fn_source

@property
def sampler_src(self) -> str:
"""Return the source code of the forward sampling funtion
generated by the compiler.
"""
artifact = core.compile_to_sampler(self.graph, self.namespace)
artifact = compiler.compile_to_sampler(self.graph, self.namespace)
return artifact.fn_source

@property
def posterior_sampler_src(self) -> str:
"""Return the source code of the forward sampling funtion
generated by the compiler.
"""
artifact = core.compile_to_posterior_sampler(self.graph, self.namespace)
artifact = compiler.compile_to_posterior_sampler(self.graph, self.namespace)
return artifact.fn_source

@property
Expand Down
8 changes: 4 additions & 4 deletions mcx/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy
from jax import numpy as np

import mcx.core as core
import mcx.compiler as compiler
from mcx.trace import Trace

__all__ = ["predict", "sample_forward"]
Expand Down Expand Up @@ -55,7 +55,7 @@ def predict(
class posterior_predictive:
def __init__(self, rng_key: jax.random.PRNGKey, model, trace: Trace) -> None:
"""Initialize the posterior predictive sampler."""
artifact = core.compile_to_posterior_sampler(model.graph, model.namespace)
artifact = compiler.compile_to_posterior_sampler(model.graph, model.namespace)
sampler = jax.jit(artifact.compiled_fn)

self.model = model
Expand Down Expand Up @@ -176,7 +176,7 @@ def sample_one_chain(*args):

class prior_predictive:
def __init__(self, rng_key: jax.random.PRNGKey, model) -> None:
artifact = core.compile_to_prior_sampler(model.graph, model.namespace)
artifact = compiler.compile_to_prior_sampler(model.graph, model.namespace)
sampler = jax.jit(artifact.compiled_fn)

self.model = model
Expand Down Expand Up @@ -325,7 +325,7 @@ def sample_forward(
else:
out_axes = (1,) * len(model.variables)

artifact = core.compile_to_sampler(model.graph, model.namespace)
artifact = compiler.compile_to_sampler(model.graph, model.namespace)
sampler = jax.jit(artifact.compiled_fn)
samples = jax.vmap(sampler, in_axes, out_axes)(*sampler_args)

Expand Down
2 changes: 1 addition & 1 deletion mcx/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import mcx
from mcx import sample_forward
from mcx.core import compile_to_loglikelihoods, compile_to_logpdf
from mcx.compiler import compile_to_loglikelihoods, compile_to_logpdf
from mcx.jax import ravel_pytree as mcx_ravel_pytree
from mcx.trace import Trace

Expand Down

0 comments on commit 2f93d4d

Please sign in to comment.