Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend CRP to Pitman Yor and make hyper-parameters Loom compatible #256

Open
wants to merge 3 commits into
base: zane/hardcode-version
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions src/crosscat/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,9 @@ def view_logpdf(view, rowid, targets, constraints):
Nk = view.Nk()
N_rows = len(view.Zr())
K = view.crp.clusters[0].gibbs_tables(-1)
lp_crp = [Crp.calc_predictive_logp(k, N_rows, Nk, view.alpha()) for k in K]
lp_crp = [Crp.calc_predictive_logp(
k, N_rows, Nk, view.crp.hypers["alpha"], view.crp.hypers["discount"])
for k in K]
lp_constraints = [_logpdf_row(view, constraints, k) for k in K]
if all(np.isinf(lp_constraints)):
raise ValueError('Zero density constraints: %s' % (constraints,))
Expand All @@ -89,7 +91,9 @@ def view_simulate(view, rowid, targets, constraints, N):
Nk = view.Nk()
N_rows = len(view.Zr())
K = view.crp.clusters[0].gibbs_tables(-1)
lp_crp = [Crp.calc_predictive_logp(k, N_rows, Nk, view.alpha()) for k in K]
lp_crp = [Crp.calc_predictive_logp(
k, N_rows, Nk, view.crp.hypers['alpha'],
view.crp.hypers['discount']) for k in K]
lp_constraints = [_logpdf_row(view, constraints, k) for k in K]
if all(np.isinf(lp_constraints)):
raise ValueError('Zero density constraints: %s' % (constraints,))
Expand Down
136 changes: 30 additions & 106 deletions src/crosscat/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ class State(CGpm):

def __init__(
self, X, outputs=None, inputs=None, cctypes=None,
distargs=None, Zv=None, Zrv=None, alpha=None, view_alphas=None,
distargs=None, Zv=None, Zrv=None, structure_hypers=None, view_structure_hypers=None,
hypers=None, Cd=None, Ci=None, Rd=None, Ri=None, diagnostics=None,
loom_path=None, rng=None):
""" Construct a State.
Expand Down Expand Up @@ -136,13 +136,12 @@ def __init__(
self.Rd = {}
self.Ri = {}
# Prepare the GPM for the column crp.
crp_alpha = None if alpha is None else {'alpha': alpha}
self.crp_id = 5**8
self.crp = Dim(
outputs=[self.crp_id],
inputs=[-1],
cctype='crp',
hypers=crp_alpha,
hypers=structure_hypers,
rng=self.rng
)
self.crp.transition_hyper_grids([1]*self.n_cols())
Expand All @@ -158,13 +157,19 @@ def __init__(
# Independence constraints are specified; simulate
# the non-exchangeable version of the constrained crp.
Zv = gu.simulate_crp_constrained(
self.n_cols(), self.alpha(), self.Cd, self.Ci,
self.n_cols(),
self.crp.hypers['alpha'],
self.crp.hypers['discount'],
self.Cd, self.Ci,
self.Rd, self.Ri, rng=self.rng)
else:
# Only dependence constraints are specified; simulate
# the exchangeable version of the constrained crp.
Zv = gu.simulate_crp_constrained_dependent(
self.n_cols(), self.alpha(), self.Cd, self.rng)
self.n_cols(),
self.crp.hypers['alpha'],
self.crp.hypers['discount'],
self.Cd, self.rng)
# Incorporate hte the data.
for c, z in zip(self.outputs, Zv):
self.crp.incorporate(c, {self.crp_id: z}, {-1:0})
Expand All @@ -184,7 +189,7 @@ def __init__(
cctypes = cctypes or [None] * len(self.outputs)
distargs = distargs or [None] * len(self.outputs)
hypers = hypers or [None] * len(self.outputs)
view_alphas = view_alphas or {}
view_structure_hypers = view_structure_hypers or {}

# If the user specifies Zrv, then the keys of Zrv must match the views
# which are values in Zv.
Expand All @@ -206,7 +211,7 @@ def __init__(
outputs=[self.crp_id_view+v] + v_outputs,
inputs=None,
Zr=Zrv.get(v, None),
alpha=view_alphas.get(v, None),
structure_hypers=view_structure_hypers.get(v, None),
cctypes=v_cctypes,
distargs=v_distargs,
hypers=v_hypers,
Expand Down Expand Up @@ -477,8 +482,8 @@ def logpdf_score_crp(self):
if not self.Cd:
return self.crp.logpdf_score()
# Compute constrained CRP probability.
Zv, alpha = self.Zv(), self.alpha()
return gu.logp_crp_constrained_dependent(Zv, alpha, self.Cd)
Zv, alpha, discount = self.Zv(), self.crp.hypers['alpha'], self.crp.hypers['discount']
return gu.logp_crp_constrained_dependent(Zv, alpha, discount, self.Cd)

def logpdf_likelihood(self):
logp_views = sum(v.logpdf_likelihood() for v in self.views.values())
Expand Down Expand Up @@ -903,10 +908,10 @@ def transition(

# Default order of crosscat kernels is important.
_kernel_lookup = OrderedDict([
('alpha',
lambda : self.transition_crp_alpha()),
('view_alphas',
lambda : self.transition_view_alphas(views=views, cols=cols)),
('structure_hypers',
lambda : self.transition_structure_hypers()),
('view_structure_hypers',
lambda : self.transition_view_structure_hypers(views=views, cols=cols)),
('column_params',
lambda : self.transition_dim_params(cols=cols)),
('column_hypers',
Expand All @@ -928,16 +933,16 @@ def transition(
self._transition_generic(
kernel_funcs, N=N, S=S, progress=progress, checkpoint=checkpoint)

def transition_crp_alpha(self):
def transition_structure_hypers(self):
self.crp.transition_hypers()
self._increment_iterations('alpha')
self._increment_iterations('structure_hypers')

def transition_view_alphas(self, views=None, cols=None):
def transition_view_structure_hypers(self, views=None, cols=None):
if views is None:
views = set(self.Zv(col) for col in cols) if cols else self.views
for v in views:
self.views[v].transition_crp_alpha()
self._increment_iterations('view_alphas')
self.views[v].transition_crp_hypers()
self._increment_iterations('view_structure_hypers')

def transition_dim_params(self, cols=None):
if cols is None:
Expand Down Expand Up @@ -977,40 +982,6 @@ def transition_dims(self, cols=None, m=1):
self._gibbs_transition_dim(c, m)
self._increment_iterations('columns')

def transition_lovecat(
self, N=None, S=None, kernels=None, rowids=None, cols=None,
progress=None, checkpoint=None):
# This function in its entirely is one major hack.
# XXX TODO: Temporarily convert all cctypes into normal/categorical.
if any(c not in ['normal','categorical'] for c in self.cctypes()):
raise ValueError(
'Only normal and categorical cgpms supported by lovecat: %s'
% (self.cctypes()))
if any(d.is_conditional() for d in self.dims()):
raise ValueError('Cannot transition lovecat with conditional dims.')
from cgpm.crosscat import lovecat
seed = self.rng.randint(1, 2**31-1)
lovecat.transition(
self, N=N, S=S, kernels=kernels, rowids=rowids, cols=cols,
seed=seed, progress=progress, checkpoint=checkpoint)
# Transition the non-structural parameters.
self.transition_dim_hypers()
self.transition_crp_alpha()
self.transition_view_alphas()

def transition_loom(
self, N=None, S=None, kernels=None, progress=None,
checkpoint=None, seed=None):
from cgpm.crosscat import loomcat
loomcat.transition(
self, N=N, S=S, kernels=kernels, progress=progress,
checkpoint=checkpoint, seed=seed)
# Transition the non-structural parameters.
self.transition(
N=1,
kernels=['column_hypers', 'column_params', 'alpha', 'view_alphas'
])

def transition_foreign(
self, N=None, S=None, cols=None, progress=None):
# Build foreign kernels.
Expand Down Expand Up @@ -1078,7 +1049,7 @@ def _increment_iterations(self, kernel, N=1):

def _increment_diagnostics(self):
self.diagnostics['logscore'].append(self.logpdf_score())
self.diagnostics['column_crp_alpha'].append(self.alpha())
self.diagnostics['structure_hypers'].append(self.crp.hypers)
self.diagnostics['column_partition'].append(list(self.Zv().items()))

def _progress(self, percentage):
Expand Down Expand Up @@ -1120,53 +1091,9 @@ def set_outputs(self, outputs):
self.outputs = list(outputs)
self.outputs_set = set(self.outputs)

# --------------------------------------------------------------------------
# Plotting

def plot(self):
"""Plots observation histogram and posterior distirbution of Dims."""
import matplotlib.pyplot as plt
from cgpm.utils import plots as pu
layout = pu.get_state_plot_layout(self.n_cols())
fig = plt.figure(
num=None,
figsize=(layout['plot_inches_y'], layout['plot_inches_x']),
dpi=75,
facecolor='w',
edgecolor='k',
frameon=False,
tight_layout=True
)
# Do not plot more than 6 by 4.
if self.n_cols() > 24:
return
fig.clear()
for i, dim in enumerate(self.dims()):
index = dim.index
ax = fig.add_subplot(layout['plots_x'], layout['plots_y'], i+1)
dim.plot_dist(self.X[dim.index], ax=ax)
ax.text(
1,1, "K: %i " % len(dim.clusters),
transform=ax.transAxes,
fontsize=12,
weight='bold',
color='blue',
horizontalalignment='right',
verticalalignment='top'
)
ax.grid()
# XXX TODO: Write png to disk rather than slow matplotlib animation.
# plt.draw()
# plt.ion()
# plt.show()
return fig

# --------------------------------------------------------------------------
# Internal CRP utils.

def alpha(self):
return self.crp.hypers['alpha']

def Nv(self, v=None):
Nv = self.crp.clusters[0].counts
return Nv[v] if v is not None else Nv.copy()
Expand Down Expand Up @@ -1332,7 +1259,7 @@ def hypothetical(self, rowid):
def _check_partitions(self):
if not cu.check_env_debug():
return
assert self.alpha() > 0.
assert self.crp.hypers['alpha'] > 0.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You might want to also check that the discount is between 0 and 1 (optionally, you could change the alpha check to the more permissive > -discount, even though given cgpm and loom's grid, we'll never enter the negative regime.

assert all(len(self.views[v].dims) == self.crp.clusters[0].counts[v]
for v in self.views)
# All outputs should be in the dataset keys.
Expand All @@ -1352,8 +1279,6 @@ def _check_partitions(self):
assert vu.validate_crp_constrained_partition(
[self.Zv(c) for c in self.outputs], self.Cd, self.Ci,
self.Rd, self.Ri)

# --------------------------------------------------------------------------
# Serialize

def to_metadata(self):
Expand All @@ -1364,7 +1289,7 @@ def to_metadata(self):
metadata['outputs'] = self.outputs

# View partition data.
metadata['alpha'] = self.alpha()
metadata['structure_hypers'] = self.crp.hypers
metadata['Zv'] = list(self.Zv().items())

# Column data.
Expand All @@ -1387,11 +1312,11 @@ def to_metadata(self):

# View data.
metadata['Zrv'] = []
metadata['view_alphas'] = []
metadata['view_structure_hypers'] = []
for v, view in self.views.items():
rowids = sorted(view.Zr())
metadata['Zrv'].append((v, [view.Zr(i) for i in rowids]))
metadata['view_alphas'].append((v, view.alpha()))
metadata['view_structure_hypers'].append((v, view.crp.hypers))

# Diagnostic data.
metadata['diagnostics'] = self.diagnostics
Expand Down Expand Up @@ -1424,10 +1349,10 @@ def from_metadata(cls, metadata, rng=None):
outputs=metadata.get('outputs', None),
cctypes=metadata.get('cctypes', None),
distargs=metadata.get('distargs', None),
alpha=metadata.get('alpha', None),
structure_hypers=to_dict(metadata.get('structure_hypers', None)),
Zv=to_dict(metadata.get('Zv', None)),
Zrv=to_dict(metadata.get('Zrv', None)),
view_alphas=to_dict(metadata.get('view_alphas', None)),
view_structure_hypers=to_dict(metadata.get('view_structure_hypers', None)),
hypers=metadata.get('hypers', None),
Cd=metadata.get('Cd', None),
Ci=metadata.get('Ci', None),
Expand All @@ -1452,4 +1377,3 @@ def from_pickle(cls, fileptr, rng=None):
else:
metadata = pickle.load(fileptr)
return cls.from_metadata(metadata, rng=rng)

64 changes: 46 additions & 18 deletions src/mixtures/dim.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,36 +151,64 @@ def transition_params(self):

def transition_hypers(self):
"""Transitions the hyperparameters of each cluster."""
hypers = list(self.hypers.keys())
self.rng.shuffle(hypers)
# For each hyper.
for hyper in hypers:
logps = []
# For each grid point.
for grid_value in self.hyper_grids[hyper]:
# Compute the probability of the grid point.
self.hypers[hyper] = grid_value
logp_k = 0
for k in self.clusters:
self.clusters[k].set_hypers(self.hypers)
logp_k += self.clusters[k].logpdf_score()
logps.append(logp_k)
# Sample a new hyperparameter from the grid.
index = gu.log_pflip(logps, rng=self.rng)
self.hypers[hyper] = self.hyper_grids[hyper][index]
if self.model.name() in ['crp']:
self.transition_hypers_joint()
else:
hypers = list(self.hypers.keys())
self.rng.shuffle(hypers)
# For each hyper.
for hyper in hypers:
logps = []
# For each grid point.
for grid_value in self.hyper_grids[hyper]:
# Compute the probability of the grid point.
self.hypers[hyper] = grid_value
logp_k = 0
for k in self.clusters:
self.clusters[k].set_hypers(self.hypers)
logp_k += self.clusters[k].logpdf_score()
logps.append(logp_k)
# Sample a new hyperparameter from the grid.
index = gu.log_pflip(logps, rng=self.rng)
self.hypers[hyper] = self.hyper_grids[hyper][index]
# Set the hyperparameters in each cluster.
for k in self.clusters:
self.clusters[k].set_hypers(self.hypers)
self.aux_model = self.create_aux_model()

def transition_hypers_joint(self):
logps = []
for values in zip(*self.hyper_grids.values()):
for i, hyper in enumerate(self.hyper_grids):
self.hypers[hyper] = values[i]
logp_k = 0
for k in self.clusters:
self.clusters[k].set_hypers(self.hypers)
logp_k += self.clusters[k].logpdf_score()
logps.append(logp_k)

index = gu.log_pflip(logps, rng=self.rng)
for hyper, val in self.hyper_grids.items():
self.hypers[hyper] = val[index]

def transition_hyper_grids(self, X, n_grid=30):
"""Transitions hyperparameter grids using empirical Bayes."""
self.hyper_grids = self.model.construct_hyper_grids(
[x for x in X if not math.isnan(x)], n_grid=n_grid)
# Only transition the hypers if previously uninstantiated.
if self.cctype == "categorical":
hyper_list = [f'alpha_{i}' for i in range(int(self.distargs['k']))]
else:

hyper_list = list(self.hyper_grids.keys())

if not self.hypers:
for h in self.hyper_grids:
for h in hyper_list:
self.hypers[h] = self.rng.choice(self.hyper_grids[h])


# TODO when creating a pitman-yor model,
# initialize from the list of dicts hyper grid
self.aux_model = self.create_aux_model()

# --------------------------------------------------------------------------
Expand Down
Loading