From 5fb714ffdcb51ab004af1b02ccd9281f348c0ddb Mon Sep 17 00:00:00 2001
From: Schaechtle
Date: Mon, 8 Jul 2024 12:58:03 -0400
Subject: [PATCH 1/3] feat: Extend CRP to Pitman-Yor Process (co-authored with
 Joao)

This takes code from Joao's branch and modifies it so that all tests pass.
---
 src/crosscat/sampling.py | 8 +-
 src/crosscat/state.py | 136 ++++++++--------------------
 src/mixtures/dim.py | 64 +++++++++++-----
 src/mixtures/view.py | 24 +++---
 src/primitives/categorical.py | 41 ++++++----
 src/primitives/crp.py | 71 +++++++++++-------
 src/primitives/normal.py | 41 +++++-----
 src/utils/general.py | 72 +++++++++++-------
 src/utils/grid.py | 98 ++++++++++++++++++++++++
 9 files changed, 333 insertions(+), 222 deletions(-)
 create mode 100644 src/utils/grid.py

diff --git a/src/crosscat/sampling.py b/src/crosscat/sampling.py
index 93824b47..34bb9b2e 100644
--- a/src/crosscat/sampling.py
+++ b/src/crosscat/sampling.py
@@ -74,7 +74,9 @@ def view_logpdf(view, rowid, targets, constraints):
 Nk = view.Nk()
 N_rows = len(view.Zr())
 K = view.crp.clusters[0].gibbs_tables(-1)
- lp_crp = [Crp.calc_predictive_logp(k, N_rows, Nk, view.alpha()) for k in K]
+ lp_crp = [Crp.calc_predictive_logp(
+ k, N_rows, Nk, view.crp.hypers["alpha"], view.crp.hypers["discount"])
+ for k in K]
 lp_constraints = [_logpdf_row(view, constraints, k) for k in K]
 if all(np.isinf(lp_constraints)):
 raise ValueError('Zero density constraints: %s' % (constraints,))
@@ -89,7 +91,9 @@ def view_simulate(view, rowid, targets, constraints, N):
 Nk = view.Nk()
 N_rows = len(view.Zr())
 K = view.crp.clusters[0].gibbs_tables(-1)
- lp_crp = [Crp.calc_predictive_logp(k, N_rows, Nk, view.alpha()) for k in K]
+ lp_crp = [Crp.calc_predictive_logp(
+ k, N_rows, Nk, view.crp.hypers['alpha'],
+ view.crp.hypers['discount']) for k in K]
 lp_constraints = [_logpdf_row(view, constraints, k) for k in K]
 if all(np.isinf(lp_constraints)):
 raise ValueError('Zero density constraints: %s' % (constraints,))
diff --git a/src/crosscat/state.py b/src/crosscat/state.py
index 5590e719..2dac2a86 100644
--- a/src/crosscat/state.py
+++ b/src/crosscat/state.py
@@ -54,7 +54,7 @@ class State(CGpm):
 def __init__(
 self, X, outputs=None, inputs=None, cctypes=None,
- distargs=None, Zv=None, Zrv=None, alpha=None, view_alphas=None,
+ distargs=None, Zv=None, Zrv=None, structure_hypers=None, view_structure_hypers=None,
 hypers=None, Cd=None, Ci=None, Rd=None, Ri=None,
 diagnostics=None, loom_path=None, rng=None):
 """ Construct a State.
@@ -136,13 +136,12 @@ def __init__(
 self.Rd = {}
 self.Ri = {}
 # Prepare the GPM for the column crp.
- crp_alpha = None if alpha is None else {'alpha': alpha}
 self.crp_id = 5**8
 self.crp = Dim(
 outputs=[self.crp_id],
 inputs=[-1],
 cctype='crp',
- hypers=crp_alpha,
+ hypers=structure_hypers,
 rng=self.rng
 )
 self.crp.transition_hyper_grids([1]*self.n_cols())
@@ -158,13 +157,19 @@ def __init__(
 # Independence constraints are specified; simulate
 # the non-exchangeable version of the constrained crp.
 Zv = gu.simulate_crp_constrained(
- self.n_cols(), self.alpha(), self.Cd, self.Ci,
+ self.n_cols(),
+ self.crp.hypers['alpha'],
+ self.crp.hypers['discount'],
+ self.Cd, self.Ci,
 self.Rd, self.Ri, rng=self.rng)
 else:
 # Only dependence constraints are specified; simulate
 # the exchangeable version of the constrained crp.
 Zv = gu.simulate_crp_constrained_dependent(
- self.n_cols(), self.alpha(), self.Cd, self.rng)
+ self.n_cols(),
+ self.crp.hypers['alpha'],
+ self.crp.hypers['discount'],
+ self.Cd, self.rng)
 # Incorporate the data.
for c, z in zip(self.outputs, Zv): self.crp.incorporate(c, {self.crp_id: z}, {-1:0}) @@ -184,7 +189,7 @@ def __init__( cctypes = cctypes or [None] * len(self.outputs) distargs = distargs or [None] * len(self.outputs) hypers = hypers or [None] * len(self.outputs) - view_alphas = view_alphas or {} + view_structure_hypers = view_structure_hypers or {} # If the user specifies Zrv, then the keys of Zrv must match the views # which are values in Zv. @@ -206,7 +211,7 @@ def __init__( outputs=[self.crp_id_view+v] + v_outputs, inputs=None, Zr=Zrv.get(v, None), - alpha=view_alphas.get(v, None), + structure_hypers=view_structure_hypers.get(v, None), cctypes=v_cctypes, distargs=v_distargs, hypers=v_hypers, @@ -477,8 +482,8 @@ def logpdf_score_crp(self): if not self.Cd: return self.crp.logpdf_score() # Compute constrained CRP probability. - Zv, alpha = self.Zv(), self.alpha() - return gu.logp_crp_constrained_dependent(Zv, alpha, self.Cd) + Zv, alpha, discount = self.Zv(), self.crp.hypers['alpha'], self.crp.hypers['discount'] + return gu.logp_crp_constrained_dependent(Zv, alpha, discount, self.Cd) def logpdf_likelihood(self): logp_views = sum(v.logpdf_likelihood() for v in self.views.values()) @@ -903,10 +908,10 @@ def transition( # Default order of crosscat kernels is important. _kernel_lookup = OrderedDict([ - ('alpha', - lambda : self.transition_crp_alpha()), - ('view_alphas', - lambda : self.transition_view_alphas(views=views, cols=cols)), + ('structure_hypers', + lambda : self.transition_structure_hypers()), + ('view_structure_hypers', + lambda : self.transition_view_structure_hypers(views=views, cols=cols)), ('column_params', lambda : self.transition_dim_params(cols=cols)), ('column_hypers', @@ -928,16 +933,16 @@ def transition( self._transition_generic( kernel_funcs, N=N, S=S, progress=progress, checkpoint=checkpoint) - def transition_crp_alpha(self): + def transition_structure_hypers(self): self.crp.transition_hypers() - self._increment_iterations('alpha') + self._increment_iterations('structure_hypers') - def transition_view_alphas(self, views=None, cols=None): + def transition_view_structure_hypers(self, views=None, cols=None): if views is None: views = set(self.Zv(col) for col in cols) if cols else self.views for v in views: - self.views[v].transition_crp_alpha() - self._increment_iterations('view_alphas') + self.views[v].transition_crp_hypers() + self._increment_iterations('view_structure_hypers') def transition_dim_params(self, cols=None): if cols is None: @@ -977,40 +982,6 @@ def transition_dims(self, cols=None, m=1): self._gibbs_transition_dim(c, m) self._increment_iterations('columns') - def transition_lovecat( - self, N=None, S=None, kernels=None, rowids=None, cols=None, - progress=None, checkpoint=None): - # This function in its entirely is one major hack. - # XXX TODO: Temporarily convert all cctypes into normal/categorical. - if any(c not in ['normal','categorical'] for c in self.cctypes()): - raise ValueError( - 'Only normal and categorical cgpms supported by lovecat: %s' - % (self.cctypes())) - if any(d.is_conditional() for d in self.dims()): - raise ValueError('Cannot transition lovecat with conditional dims.') - from cgpm.crosscat import lovecat - seed = self.rng.randint(1, 2**31-1) - lovecat.transition( - self, N=N, S=S, kernels=kernels, rowids=rowids, cols=cols, - seed=seed, progress=progress, checkpoint=checkpoint) - # Transition the non-structural parameters. 
- self.transition_dim_hypers() - self.transition_crp_alpha() - self.transition_view_alphas() - - def transition_loom( - self, N=None, S=None, kernels=None, progress=None, - checkpoint=None, seed=None): - from cgpm.crosscat import loomcat - loomcat.transition( - self, N=N, S=S, kernels=kernels, progress=progress, - checkpoint=checkpoint, seed=seed) - # Transition the non-structural parameters. - self.transition( - N=1, - kernels=['column_hypers', 'column_params', 'alpha', 'view_alphas' - ]) - def transition_foreign( self, N=None, S=None, cols=None, progress=None): # Build foreign kernels. @@ -1078,7 +1049,7 @@ def _increment_iterations(self, kernel, N=1): def _increment_diagnostics(self): self.diagnostics['logscore'].append(self.logpdf_score()) - self.diagnostics['column_crp_alpha'].append(self.alpha()) + self.diagnostics['structure_hypers'].append(self.crp.hypers) self.diagnostics['column_partition'].append(list(self.Zv().items())) def _progress(self, percentage): @@ -1120,53 +1091,9 @@ def set_outputs(self, outputs): self.outputs = list(outputs) self.outputs_set = set(self.outputs) - # -------------------------------------------------------------------------- - # Plotting - - def plot(self): - """Plots observation histogram and posterior distirbution of Dims.""" - import matplotlib.pyplot as plt - from cgpm.utils import plots as pu - layout = pu.get_state_plot_layout(self.n_cols()) - fig = plt.figure( - num=None, - figsize=(layout['plot_inches_y'], layout['plot_inches_x']), - dpi=75, - facecolor='w', - edgecolor='k', - frameon=False, - tight_layout=True - ) - # Do not plot more than 6 by 4. - if self.n_cols() > 24: - return - fig.clear() - for i, dim in enumerate(self.dims()): - index = dim.index - ax = fig.add_subplot(layout['plots_x'], layout['plots_y'], i+1) - dim.plot_dist(self.X[dim.index], ax=ax) - ax.text( - 1,1, "K: %i " % len(dim.clusters), - transform=ax.transAxes, - fontsize=12, - weight='bold', - color='blue', - horizontalalignment='right', - verticalalignment='top' - ) - ax.grid() - # XXX TODO: Write png to disk rather than slow matplotlib animation. - # plt.draw() - # plt.ion() - # plt.show() - return fig - # -------------------------------------------------------------------------- # Internal CRP utils. - def alpha(self): - return self.crp.hypers['alpha'] - def Nv(self, v=None): Nv = self.crp.clusters[0].counts return Nv[v] if v is not None else Nv.copy() @@ -1332,7 +1259,7 @@ def hypothetical(self, rowid): def _check_partitions(self): if not cu.check_env_debug(): return - assert self.alpha() > 0. + assert self.crp.hypers['alpha'] > 0. assert all(len(self.views[v].dims) == self.crp.clusters[0].counts[v] for v in self.views) # All outputs should be in the dataset keys. @@ -1352,8 +1279,6 @@ def _check_partitions(self): assert vu.validate_crp_constrained_partition( [self.Zv(c) for c in self.outputs], self.Cd, self.Ci, self.Rd, self.Ri) - - # -------------------------------------------------------------------------- # Serialize def to_metadata(self): @@ -1364,7 +1289,7 @@ def to_metadata(self): metadata['outputs'] = self.outputs # View partition data. - metadata['alpha'] = self.alpha() + metadata['structure_hypers'] = self.crp.hypers metadata['Zv'] = list(self.Zv().items()) # Column data. @@ -1387,11 +1312,11 @@ def to_metadata(self): # View data. 
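 # Each view now serializes its full CRP hypers dict (both alpha and
 # discount) rather than a scalar concentration.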
metadata['Zrv'] = [] - metadata['view_alphas'] = [] + metadata['view_structure_hypers'] = [] for v, view in self.views.items(): rowids = sorted(view.Zr()) metadata['Zrv'].append((v, [view.Zr(i) for i in rowids])) - metadata['view_alphas'].append((v, view.alpha())) + metadata['view_structure_hypers'].append((v, view.crp.hypers)) # Diagnostic data. metadata['diagnostics'] = self.diagnostics @@ -1424,10 +1349,10 @@ def from_metadata(cls, metadata, rng=None): outputs=metadata.get('outputs', None), cctypes=metadata.get('cctypes', None), distargs=metadata.get('distargs', None), - alpha=metadata.get('alpha', None), + structure_hypers=to_dict(metadata.get('structure_hypers', None)), Zv=to_dict(metadata.get('Zv', None)), Zrv=to_dict(metadata.get('Zrv', None)), - view_alphas=to_dict(metadata.get('view_alphas', None)), + view_structure_hypers=to_dict(metadata.get('view_structure_hypers', None)), hypers=metadata.get('hypers', None), Cd=metadata.get('Cd', None), Ci=metadata.get('Ci', None), @@ -1452,4 +1377,3 @@ def from_pickle(cls, fileptr, rng=None): else: metadata = pickle.load(fileptr) return cls.from_metadata(metadata, rng=rng) - diff --git a/src/mixtures/dim.py b/src/mixtures/dim.py index a62c8d98..25c60440 100644 --- a/src/mixtures/dim.py +++ b/src/mixtures/dim.py @@ -151,36 +151,64 @@ def transition_params(self): def transition_hypers(self): """Transitions the hyperparameters of each cluster.""" - hypers = list(self.hypers.keys()) - self.rng.shuffle(hypers) - # For each hyper. - for hyper in hypers: - logps = [] - # For each grid point. - for grid_value in self.hyper_grids[hyper]: - # Compute the probability of the grid point. - self.hypers[hyper] = grid_value - logp_k = 0 - for k in self.clusters: - self.clusters[k].set_hypers(self.hypers) - logp_k += self.clusters[k].logpdf_score() - logps.append(logp_k) - # Sample a new hyperparameter from the grid. - index = gu.log_pflip(logps, rng=self.rng) - self.hypers[hyper] = self.hyper_grids[hyper][index] + if self.model.name() in ['crp']: + self.transition_hypers_joint() + else: + hypers = list(self.hypers.keys()) + self.rng.shuffle(hypers) + # For each hyper. + for hyper in hypers: + logps = [] + # For each grid point. + for grid_value in self.hyper_grids[hyper]: + # Compute the probability of the grid point. + self.hypers[hyper] = grid_value + logp_k = 0 + for k in self.clusters: + self.clusters[k].set_hypers(self.hypers) + logp_k += self.clusters[k].logpdf_score() + logps.append(logp_k) + # Sample a new hyperparameter from the grid. + index = gu.log_pflip(logps, rng=self.rng) + self.hypers[hyper] = self.hyper_grids[hyper][index] # Set the hyperparameters in each cluster. for k in self.clusters: self.clusters[k].set_hypers(self.hypers) self.aux_model = self.create_aux_model() + def transition_hypers_joint(self): + logps = [] + for values in zip(*self.hyper_grids.values()): + for i, hyper in enumerate(self.hyper_grids): + self.hypers[hyper] = values[i] + logp_k = 0 + for k in self.clusters: + self.clusters[k].set_hypers(self.hypers) + logp_k += self.clusters[k].logpdf_score() + logps.append(logp_k) + + index = gu.log_pflip(logps, rng=self.rng) + for hyper, val in self.hyper_grids.items(): + self.hypers[hyper] = val[index] + def transition_hyper_grids(self, X, n_grid=30): """Transitions hyperparameter grids using empirical Bayes.""" self.hyper_grids = self.model.construct_hyper_grids( [x for x in X if not math.isnan(x)], n_grid=n_grid) # Only transition the hypers if previously uninstantiated. 
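+ # For categorical dims the hyper grid is a defaultdict that serves the
+ # same grid for any key (see construct_hyper_grids in
+ # primitives/categorical.py), so the per-category hyper names
+ # alpha_0 .. alpha_{k-1} are reconstructed from distargs['k'] below
+ # instead of being read off the grid's keys.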
+ if self.cctype == "categorical": + hyper_list = [f'alpha_{i}' for i in range(int(self.distargs['k']))] + else: + + hyper_list = list(self.hyper_grids.keys()) + if not self.hypers: - for h in self.hyper_grids: + for h in hyper_list: self.hypers[h] = self.rng.choice(self.hyper_grids[h]) + + + # TODO when creating a pitman-yor model, + # initialize from the list of dicts hyper grid self.aux_model = self.create_aux_model() # -------------------------------------------------------------------------- diff --git a/src/mixtures/view.py b/src/mixtures/view.py index ab8f688e..923ea6f2 100644 --- a/src/mixtures/view.py +++ b/src/mixtures/view.py @@ -36,7 +36,7 @@ class View(CGpm): """CGpm represnting a multivariate Dirichlet process mixture of CGpms.""" def __init__( - self, X, outputs=None, inputs=None, alpha=None, + self, X, outputs=None, inputs=None, structure_hypers=None, cctypes=None, distargs=None, hypers=None, Zr=None, rng=None): """View constructor provides a convenience method for bulk incorporate and unincorporate by specifying the data and optional row partition. @@ -53,8 +53,8 @@ def __init__( output variables. inputs : list Currently disabled. - alpha : float, optional. - Concentration parameter for row CRP. + structure_hypers : dict[str, float], optional. + Concentration parameters for row pitman-yor. cctypes : list, optional. A `len(outputs[1:])` list of cctypes, see `utils.config` for names. distargs : list, optional. @@ -95,7 +95,7 @@ def __init__( outputs=[self.outputs[0]], inputs=[-1], cctype='crp', - hypers=None if alpha is None else {'alpha': alpha}, + hypers=structure_hypers, rng=self.rng ) n_rows = len(self.X[list(self.X.keys())[0]]) @@ -228,11 +228,10 @@ def update_cctype(self, col, cctype, distargs=None): def transition(self, N): for _ in range(N): self.transition_rows() - self.transition_crp_alpha() + self.transition_crp_hypers() self.transition_dim_hypers() - def transition_crp_alpha(self): - self.crp.transition_hypers() + def transition_crp_hypers(self): self.crp.transition_hypers() def transition_dim_hypers(self, cols=None): @@ -434,9 +433,6 @@ def _migrate_row(self, rowid, k): # -------------------------------------------------------------------------- # Internal crp utils. - def alpha(self): - return self.crp.hypers['alpha'] - def Nk(self, k=None): Nk = self.crp.clusters[0].counts return Nk[k] if k is not None else Nk @@ -534,7 +530,7 @@ def _check_partitions(self): if not cu.check_env_debug(): return # For debugging only. - assert self.alpha() > 0. + assert self.crp.hypers['alpha'] > 0. # Check that the number of dims actually assigned to the view # matches the count in Nv. Zr = self.Zr() @@ -563,8 +559,6 @@ def _check_partitions(self): data = [[self.X[c][r] for c in cols] for r in rowids_k] rowids_nan = np.any(np.isnan(data), axis=1) if data else [] assert (dim.clusters[k].N + np.sum(rowids_nan) == Nk[k]) - - # -------------------------------------------------------------------------- # Metadata def to_metadata(self): @@ -577,7 +571,7 @@ def to_metadata(self): # View partition data. rowids = sorted(self.Zr().keys()) metadata['Zr'] = [self.Zr(i) for i in rowids] - metadata['alpha'] = self.alpha() + metadata['structure_hypers'] = self.crp.hypers # Column data. 
 metadata['cctypes'] = []
@@ -603,7 +597,7 @@ def from_metadata(cls, metadata, rng=None):
 metadata.get('X'),
 outputs=metadata.get('outputs', None),
 inputs=metadata.get('inputs', None),
- alpha=metadata.get('alpha', None),
+ structure_hypers=metadata.get('structure_hypers', None),
 cctypes=metadata.get('cctypes', None),
 distargs=metadata.get('distargs', None),
 hypers=metadata.get('hypers', None),
diff --git a/src/primitives/categorical.py b/src/primitives/categorical.py
index 183e78a8..2f48e062 100644
--- a/src/primitives/categorical.py
+++ b/src/primitives/categorical.py
@@ -16,6 +16,7 @@

 from builtins import range
 from math import log
+from collections import defaultdict

 import numpy as np

@@ -23,6 +24,7 @@

 from cgpm.primitives.distribution import DistributionGpm
 from cgpm.utils import general as gu
+from cgpm.utils.grid import dd_alpha


 class Categorical(DistributionGpm):
@@ -32,7 +34,7 @@ class Categorical(DistributionGpm):
 k := distarg
 v ~ Symmetric-Dirichlet(alpha/k)
 x ~ Categorical(v)
- http://www.cs.berkeley.edu/~stephentu/writeups/dirichlet-conjugate-prior.pdf
+ https://web.archive.org/web/20170603230204/https://people.eecs.berkeley.edu/~stephentu/writeups/dirichlet-conjugate-prior.pdf
 """

 def __init__(self, outputs, inputs, hypers=None, params=None,
@@ -48,8 +50,13 @@ def __init__(self, outputs, inputs, hypers=None, params=None,
 self.N = 0
 self.counts = np.zeros(self.k)
 # Hyperparameters.
- if hypers is None: hypers = {}
- self.alpha = hypers.get('alpha', 1.)
+ if not hypers:
+ hypers = {}
+ self.alpha = np.ones(self.k)
+ else:
+ self.alpha = np.array([
+ hypers[f"alpha_{i}"]
+ for i in range(self.k)])

 def incorporate(self, rowid, observation, inputs=None):
 DistributionGpm.incorporate(self, rowid, observation, inputs)
@@ -79,6 +86,7 @@ def simulate(self, rowid, targets, constraints=None, inputs=None, N=None):
 DistributionGpm.simulate(self, rowid, targets, constraints, inputs, N)
 if rowid in self.data:
 return {self.outputs[0]: self.data[rowid]}
+
 x = gu.pflip(self.counts + self.alpha, rng=self.rng)
 return {self.outputs[0]: x}

@@ -93,11 +101,15 @@ def transition_params(self):
 return

 def set_hypers(self, hypers):
- assert hypers['alpha'] > 0
- self.alpha = hypers['alpha']
+ assert set(hypers.keys()) == set([f'alpha_{i}' for i in range(self.k)])
+ for i in range(self.k):
+ assert hypers[f'alpha_{i}'] > 0
+ self.alpha = np.array([
+ hypers[f"alpha_{i}"]
+ for i in range(self.k)])

 def get_hypers(self):
- return {'alpha': self.alpha}
+ return {f'alpha_{i}': self.alpha[i] for i in range(self.k)}

 def get_params(self):
 return {}
@@ -110,8 +122,11 @@ def get_distargs(self):

 @staticmethod
 def construct_hyper_grids(X, n_grid=30):
- grids = dict()
- grids['alpha'] = gu.log_linspace(1., float(len(X)), n_grid)
+ # construct_hyper_grids is a static method, so it has no access to
+ # self.k; circumvent this by making the grid a defaultdict that
+ # serves the same alpha grid for every alpha_i key.
+ grids = defaultdict(
+ lambda: dd_alpha
+ )
 return grids

 @staticmethod
@@ -144,13 +159,13 @@ def validate(x, K):

 @staticmethod
 def calc_predictive_logp(x, N, counts, alpha):
- numer = log(alpha + counts[x])
- denom = log(np.sum(counts) + alpha * len(counts))
+ numer = log(alpha[x] + counts[x])
+ denom = log(np.sum(counts + alpha))
 return numer - denom

 @staticmethod
 def calc_logpdf_marginal(N, counts, alpha):
 K = len(counts)
- A = K * alpha
- lg = sum(gammaln(counts[k] + alpha) for k in range(K))
- return gammaln(A) - gammaln(A+N) + lg - K * gammaln(alpha)
+ A = sum(alpha)
+ lg = sum(gammaln(counts[k] + alpha[k]) for k in range(K))
+ return gammaln(A) - gammaln(A+N) + lg - sum(gammaln(alpha[k]) for k in range(K))
diff --git a/src/primitives/crp.py b/src/primitives/crp.py
index b0dc1160..1fbe4509 100644
--- a/src/primitives/crp.py
+++ b/src/primitives/crp.py
@@ -17,11 +17,13 @@
 from builtins import range
 from collections import OrderedDict
 from math import log
+import numpy as np

 from scipy.special import gammaln

 from cgpm.primitives.distribution import DistributionGpm
 from cgpm.utils import general as gu
+from cgpm.utils.grid import pitman_yor


 class Crp(DistributionGpm):
@@ -42,6 +44,7 @@ def __init__(
 # Hyperparameters.
 if hypers is None: hypers = {}
 self.alpha = hypers.get('alpha', 1.)
+ self.discount = hypers.get('discount', 0.)

 def incorporate(self, rowid, observation, inputs=None):
 DistributionGpm.incorporate(self, rowid, observation, inputs)
@@ -67,7 +70,7 @@ def logpdf(self, rowid, targets, constraints=None, inputs=None):
 x = int(targets[self.outputs[0]])
 if rowid in self.data:
 return 0 if self.data[rowid] == x else -float('inf')
- return Crp.calc_predictive_logp(x, self.N, self.counts, self.alpha)
+ return Crp.calc_predictive_logp(x, self.N, self.counts, self.alpha, self.discount)

 @gu.simulate_many
 def simulate(self, rowid, targets, constraints=None, inputs=None, N=None):
@@ -82,7 +85,7 @@ def simulate(self, rowid, targets, constraints=None, inputs=None, N=None):
 return {self.outputs[0]: x}

 def logpdf_score(self):
- return Crp.calc_logpdf_marginal(self.N, self.counts, self.alpha)
+ return Crp.calc_logpdf_marginal(self.N, self.counts.values(), self.alpha, self.discount)

 ##################
 # NON-GPM METHOD #
@@ -92,11 +95,14 @@ def transition_params(self):
 return

 def set_hypers(self, hypers):
- assert hypers['alpha'] > 0
+ assert hypers['discount'] >= 0
+ assert hypers['discount'] < 1
+ assert hypers['alpha'] > -hypers["discount"]
 self.alpha = hypers['alpha']
+ self.discount = hypers['discount']

 def get_hypers(self):
- return {'alpha': self.alpha}
+ return {'alpha': self.alpha, 'discount': self.discount}

 def get_params(self):
 return {}
@@ -110,19 +116,20 @@ def get_distargs(self):
 # Some Gibbs utils.

 def gibbs_logps(self, rowid, m=1):
- """Compute the CRP probabilities for a Gibbs transition of rowid,
- with table counts Nk, table assignments Z, and m auxiliary tables."""
- assert rowid in self.data
- assert 0 < m
- singleton = self.singleton(rowid)
- p_aux = self.alpha / float(m)
- p_rowid = p_aux if singleton else self.counts[self.data[rowid]]-1
- tables = self.gibbs_tables(rowid, m=m)
- def p_table(t):
- if t == self.data[rowid]: return p_rowid # rowid table.
- if t not in self.counts: return p_aux # auxiliary table.
- return self.counts[t] # regular table.
- return [log(p_table(t)) for t in tables]
+ """Compute the normalized Pitman-Yor log probabilities for a Gibbs
+ transition of rowid, with m auxiliary tables."""
+ total_count = sum(self.counts.values()) - 1
+ logdenom = log(total_count + self.alpha)
+ # Get all counts but don't count current rowid.
+ current_table_id = self.data[rowid]
+ counts = {table_id: (count - 1 if table_id == current_table_id else count) for table_id, count in self.counts.items()}
+ new_table_prob = log(len(counts) * self.discount + self.alpha) - logdenom
+ probs = [new_table_prob - log(m) if count == 0 else
+ log(count - self.discount) - logdenom
+ for count in counts.values()
+ ]
+ if counts[current_table_id] == 0:
+ return probs + [new_table_prob - log(m)]*(m-1)
+ return probs + [new_table_prob - log(m)]*m

 def gibbs_tables(self, rowid, m=1):
 """Retrieve a list of possible tables for rowid.
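A standalone numerical check of the Gibbs weights computed above (pure
Python, no cgpm imports; the counts, alpha, and discount here are
illustrative values, not anything taken from the test suite):

    from math import exp, log

    counts = {0: 2, 1: 3}    # table sizes, including the row being resampled
    alpha, discount, m = 1.5, 0.2, 1
    current = 0              # table currently hosting the row
    counts[current] -= 1     # unincorporate the row: {0: 1, 1: 3}
    denom = sum(counts.values()) + alpha              # 4 + 1.5 = 5.5
    logps = [log(c - discount) - log(denom) for c in counts.values()]
    new_table = log(len(counts) * discount + alpha) - log(denom)
    logps += [new_table - log(m)] * m
    # (1-0.2)/5.5 + (3-0.2)/5.5 + (2*0.2+1.5)/5.5 == 1, so the returned
    # weights are already normalized.
    assert abs(sum(exp(lp) for lp in logps) - 1) < 1e-12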
@@ -148,8 +155,11 @@ def singleton(self, rowid): @staticmethod def construct_hyper_grids(X, n_grid=30): - grids = dict() - grids['alpha'] = gu.log_linspace(1./len(X), len(X), n_grid) + grids = pitman_yor(alpha_count=n_grid, d_count=n_grid) + grids = { + 'alpha': [g['alpha'] for g in grids], + 'discount': [g['d'] for g in grids], + } return grids @staticmethod @@ -177,13 +187,24 @@ def is_numeric(): ################## @staticmethod - def calc_predictive_logp(x, N, counts, alpha): - numerator = counts.get(x, alpha) + def calc_predictive_logp(x, N, counts, alpha, discount): + if x in counts.keys(): + numerator = counts.get(x, alpha) - discount + else: + numerator = len(counts.keys()) * discount + alpha + denominator = N + alpha return log(numerator) - log(denominator) @staticmethod - def calc_logpdf_marginal(N, counts, alpha): - # http://gershmanlab.webfactional.com/pubs/GershmanBlei12.pdf#page=4 (eq 8) - return len(counts) * log(alpha) + sum(gammaln(list(counts.values()))) \ - + gammaln(alpha) - gammaln(N + alpha) + def calc_logpdf_marginal(N, counts, alpha, discount): + # as seen in the PClean implementation: + # https://github.com/probcomp/PClean/blob/cef451a17749a14960f6f46dc6c2b92ed846dc05/src/model/trace.jl#L65 + n_references = 0 + logprob = 0. + for (n_objects, size) in enumerate(counts): + logprob += log(n_objects * discount + alpha) - log(n_references + alpha) + if size > 1: + logprob += sum(log(i - discount) - log(n_references + i + alpha) for i in range(1, size)) + n_references += size + return logprob diff --git a/src/primitives/normal.py b/src/primitives/normal.py index bb10045f..d959bc04 100644 --- a/src/primitives/normal.py +++ b/src/primitives/normal.py @@ -24,6 +24,7 @@ from cgpm.primitives.distribution import DistributionGpm from cgpm.utils import general as gu +from cgpm.utils.grid import DEFAULTS LOG2 = log(2) @@ -53,14 +54,20 @@ def __init__(self, outputs, inputs, hypers=None, params=None, self.sum_x = 0 self.sum_x_sq = 0 # Hyper parameters. - if hypers is None: hypers = {} - self.m = hypers.get('m', 0.) - self.r = hypers.get('r', 1.) - self.s = hypers.get('s', 1.) - self.nu = hypers.get('nu', 1.) - assert self.s > 0. - assert self.r > 0. - assert self.nu > 0. + if not hypers: + self.m = 0. + self.r = 1. + self.s = 1. + self.nu = 1. + else: + assert set(hypers.keys()) == set(['m', 'r', 's', 'nu']) + self.m = hypers['m'] + self.r = hypers['r'] + self.s = hypers['s'] + self.nu = hypers['nu'] + assert self.s > 0. + assert self.r > 0. + assert self.nu > 0. def incorporate(self, rowid, observation, inputs=None): DistributionGpm.incorporate(self, rowid, observation, inputs) @@ -128,15 +135,12 @@ def get_distargs(self): @staticmethod def construct_hyper_grids(X, n_grid=30): - grids = dict() - # Plus 1 for single observation case. - N = len(X) + 1. - ssqdev = np.var(X) * len(X) + 1. - # Data dependent heuristics. - grids['m'] = np.linspace(min(X), max(X) + 5, n_grid) - grids['r'] = gu.log_linspace(1. / N, N, n_grid) - grids['s'] = gu.log_linspace(ssqdev / 100., ssqdev, n_grid) - grids['nu'] = gu.log_linspace(1., N, n_grid) # df >= 1 + loom_grid = DEFAULTS['nich'] + grids = {} + grids['m'] = loom_grid['mu'] + grids['r'] = loom_grid['kappa'] + grids['s'] = loom_grid['sigmasq'] + grids['nu'] = loom_grid['nu'] return grids @staticmethod @@ -187,7 +191,8 @@ def posterior_hypers(N, sum_x, sum_x_sq, m, r, s, nu): nun = nu + float(N) mn = old_div((r*m + sum_x),rn) sn = s + sum_x_sq + r*m*m - rn*mn*mn - if sn == 0: + # XXX: why do I need this??? 
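+ # sn is mathematically positive, but the subtraction above can cancel
+ # catastrophically in floating point when the posterior variance is
+ # tiny, leaving sn <= 0; falling back to s keeps the parameters valid.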
+ if sn <= 0:
 sn = s
 return mn, rn, sn, nun
diff --git a/src/utils/general.py b/src/utils/general.py
index e3b540d0..be854db6 100644
--- a/src/utils/general.py
+++ b/src/utils/general.py
@@ -91,13 +91,16 @@ def normalize(p):
 """Normalizes a np array of probabilites."""
 return old_div(np.asarray(p, dtype=float), sum(p))

-def logp_crp(N, Nk, alpha):
- """Returns the log normalized P(N,K|alpha), where N is the number of
- customers and K is the number of tables.
- http://gershmanlab.webfactional.com/pubs/GershmanBlei12.pdf#page=4 (eq 8)
- """
- return len(Nk)*log(alpha) + np.sum(lgamma(c) for c in Nk) \
- + lgamma(alpha) - lgamma(N+alpha)
+def logp_crp(N, Nk, alpha, discount):
+ # XXX: copy pasta from crp.py (see `def calc_logpdf_marginal`).
+ n_references = 0
+ logprob = 0.
+ for (n_objects, size) in enumerate(Nk):
+ logprob += log(n_objects * discount + alpha) - log(n_references + alpha)
+ if size > 1:
+ logprob += sum(log(i - discount) - log(n_references + i + alpha) for i in range(1, size))
+ n_references += size
+ return logprob

 def logp_crp_unorm(N, K, alpha):
 """Returns the log unnormalized P(N,K|alpha), where N is the number of
@@ -106,18 +109,32 @@ def logp_crp_unorm(N, K, alpha):
 """
 return K*log(alpha) + lgamma(alpha) - lgamma(N+alpha)

+
+def aux_table_probs(new_table_prob, m):
+ """Split the new-table mass evenly across the m auxiliary tables."""
+ if m == 0:
+ return []
+ return [new_table_prob - log(m)]*m
+
-def logp_crp_gibbs(Nk, Z, i, alpha, m):
+def logp_crp_gibbs(Nk, Z, i, alpha, discount, m):
 """Compute the CRP probabilities for a Gibbs transition of customer i,
 with table counts Nk, table assignments Z, and m auxiliary tables."""
-
- # XXX F ME
- K = sorted(Nk) if isinstance(Nk, dict) else range(len(Nk))
- singleton = Nk[Z[i]] == 1
- m_aux = m-1 if singleton else m
- p_table_aux = alpha/float(m)
- p_current = lambda : p_table_aux if singleton else Nk[Z[i]]-1
- p_other = lambda t : Nk[t]
- p_table = lambda t: p_current() if t == Z[i] else p_other(t)
- return [log(p_table(t)) for t in K] + [log(p_table_aux)]*m_aux
+
+ # XXX F ME (edit: found this great comment 8 years after it was made)
+ # XXX: this is kinda copy-pasta from crp.py not sure why we have it.
+
+ total_count = sum(Nk) - 1
+ logdenom = log(total_count + alpha)
+ # Get all counts but don't count current rowid.
+ current_table_id = Z[i]
+ counts = {table_id: (count - 1 if table_id == current_table_id else count) for table_id, count in enumerate(Nk)}
+ new_table_prob = log(len(counts) * discount + alpha) - logdenom
+ probs = [new_table_prob - log(m) if count == 0 else
+ log(count - discount) - logdenom
+ for count in counts.values()
+ ]
+ if counts[current_table_id] == 0:
+ return probs + [new_table_prob - log(m)]*(m-1)
+ return probs + [new_table_prob - log(m)]*m

 def logp_crp_fresh(N, Nk, alpha, m=1):
 """Compute the CRP probabilities for a fresh customer i=N+1, with
@@ -216,12 +233,13 @@ def log_nCk(n, k):
 return 0
 return log(n) + lgamma(n) - log(k) - lgamma(k) - log(n-k) - lgamma(n-k)

-def simulate_crp(N, alpha, rng=None):
+def simulate_crp(N, alpha, discount, rng=None):
 """Generates random N-length partition from the CRP with parameter
 alpha."""
 if rng is None:
 rng = gen_rng()
 assert N > 0 and alpha > 0.
+ assert 0 <= discount < 1
 alpha = float(alpha)
 partition = [0]*N
 ps = np.zeros(K+1)
 for k in range(K):
- ps[k] = float(Nk[k])
- ps[K] = alpha
+ ps[k] = float(Nk[k]) - discount
+ # Pitman-Yor predictive weights: each occupied table is discounted
+ # by d, and the new table gets mass K*d + alpha. The weights then
+ # sum to i - 1 + alpha, so the denominator below remains correct.
+ ps[K] = K*discount + alpha
ps /= (float(i) - 1 + alpha) assignment = pflip(ps, rng=rng) if assignment == K: @@ -251,7 +271,7 @@ def simulate_crp(N, alpha, rng=None): # rng.shuffle(partition) return partition -def simulate_crp_constrained(N, alpha, Cd, Ci, Rd, Ri, rng=None): +def simulate_crp_constrained(N, alpha, discount, Cd, Ci, Rd, Ri, rng=None): """Simulates a CRP with N customers and concentration alpha. Cd is a list, where each entry is a list of friends. Ci is a list of tuples, where each tuple is a pair of enemies.""" @@ -286,7 +306,8 @@ def simulate_crp_constrained(N, alpha, Cd, Ci, Rd, Ri, rng=None): prob_table[t] = 0 break # Choose from valid tables using CRP. - prob_table.append(alpha) + # XXX - edit by ulli - is this correct? + prob_table.append(len(prob_table)*discount + alpha) assignment = pflip(prob_table, rng=rng) for f in friends.get(cust, [cust]): Z[f] = assignment @@ -296,7 +317,7 @@ def simulate_crp_constrained(N, alpha, Cd, Ci, Rd, Ri, rng=None): assert vu.validate_crp_constrained_partition(Z, Cd, Ci, Rd, Ri) return Z -def simulate_crp_constrained_dependent(N, alpha, Cd, rng=None): +def simulate_crp_constrained_dependent(N, alpha, discount, Cd, rng=None): """Simulates a CRP with N customers and concentration alpha. Cd is a list, where each entry is a list of friends. Each clique of friends are effectively treated as one customer, since they are assigned to a table @@ -316,7 +337,7 @@ def simulate_crp_constrained_dependent(N, alpha, Cd, rng=None): num_simulate = get_crp_constrained_num_effective(N, Cd) # Simulate a CRP based on the block structure. - crp_block = simulate_crp(num_simulate, alpha, rng=rng) + crp_block = simulate_crp(num_simulate, alpha, discount=discount, rng=rng) # Prepare the overall partition of length N. partition = [-1] * N @@ -338,13 +359,14 @@ def simulate_crp_constrained_dependent(N, alpha, Cd, rng=None): return partition -def logp_crp_constrained_dependent(Z, alpha, Cd): +def logp_crp_constrained_dependent(Z, alpha, discount, Cd): """Compute logp of CRP simulated by simulate_crp_constrained_dependent. Z is a map from each customer to the table assignment. Cd is a list, where each entry is a list of friends. Each clique of friends are effectively treated as one customer, since they are assigned to a table jointly. 
""" + if not vu.validate_crp_constrained_partition(Z, Cd, [], [], []): return -float('inf') @@ -358,7 +380,7 @@ def logp_crp_constrained_dependent(Z, alpha, Cd): Nk = list(counts.values()) assert sum(Nk) == num_simulate - return logp_crp(num_simulate, Nk, alpha) + return logp_crp(num_simulate, Nk, alpha, discount) def get_crp_constrained_num_effective(N, Cd): diff --git a/src/utils/grid.py b/src/utils/grid.py new file mode 100644 index 00000000..90fbe7fa --- /dev/null +++ b/src/utils/grid.py @@ -0,0 +1,98 @@ +import numpy + +dd_alpha = numpy.logspace(-1, 2, 12).tolist() + +pos_logspace = numpy.logspace(-8, 8, 100).tolist() +neg_logspace = (-numpy.logspace(-8, 8, 100)).tolist() + +def uniform(min_val, max_val, point_count): + grid = numpy.array(range(point_count)) + 0.5 + grid *= (max_val - min_val) / float(point_count) + grid += min_val + return grid + + +def center_heavy(min_val, max_val, point_count): + grid = uniform(-1, 1, point_count) + grid = numpy.arcsin(grid) / numpy.pi + 0.5 + grid *= max_val - min_val + grid += min_val + return grid + + +def left_heavy(min_val, max_val, point_count): + grid = uniform(0, 1, point_count) + grid = grid ** 2 + grid *= max_val - min_val + grid += min_val + return grid + + +def right_heavy(min_val, max_val, point_count): + grid = left_heavy(max_val, min_val, point_count) + return grid[::-1].copy() + + +def pitman_yor( + min_alpha=0.1, + max_alpha=100, + min_d=0, + max_d=0.5, + alpha_count=20, + d_count=10): + ''' + For d = 0, this degenerates to the CRP, where the expected number of + tables is: + E[table_count] = O(alpha log(customer_count)) + ''' + min_alpha = float(min_alpha) + max_alpha = float(max_alpha) + min_d = float(min_d) + max_d = float(max_d) + + lower_triangle = [ + (x, y) + for x in center_heavy(0, 1, alpha_count) + for y in left_heavy(0, 1, d_count) + if x + y < 1 + ] + alpha = lambda x: min_alpha * (max_alpha / min_alpha) ** x + d = lambda y: min_d + (max_d - min_d) * y + grid = [ + {'alpha': alpha(x), 'd': d(y)} + for (x, y) in lower_triangle + ] + return grid + +pitman_yor_grid = pitman_yor() + +DEFAULTS = { + 'topology': pitman_yor_grid, + 'clustering': pitman_yor_grid, + 'bb': { + 'alpha': dd_alpha, + 'beta': dd_alpha, + }, + 'dd': { + 'alpha': dd_alpha, + }, + 'dpd': { + 'gamma': (10 ** left_heavy(-1, 2, 30)).tolist(), + 'alpha': (10 ** right_heavy(-1, 1, 20)).tolist(), + }, + 'gp': { + 'alpha': numpy.logspace(-1, 5, 100).tolist(), + 'inv_beta': numpy.logspace(-5, 1, 100).tolist(), + }, + 'bnb': { + 'alpha': [2.0 ** p for p in range(-3, 1)], + 'beta': [2.0 ** p for p in range(13)], + 'r': [2 ** p for p in range(13)], + }, + 'nich': { + 'mu': neg_logspace + [0] + pos_logspace, + 'sigmasq': pos_logspace, + 'kappa': numpy.logspace(-2, 2, 30).tolist(), + 'nu': numpy.logspace(0, 2, 30).tolist(), + }, +} \ No newline at end of file From 0c43c23f36b5e6f75b71d25da04e00788b9aefa2 Mon Sep 17 00:00:00 2001 From: Schaechtle Date: Mon, 8 Jul 2024 13:02:07 -0400 Subject: [PATCH 2/3] test: Make crp tests pass for Pitman-Yor with discount = 0 --- src/utils/test.py | 20 +++++------ tests/test_cmi.py | 2 +- tests/test_constr_crp.py | 39 +++++++++++----------- tests/test_crp.py | 50 ++++++++++++++++------------ tests/test_dependence_constraints.py | 2 +- tests/test_iter_counter.py | 12 +++---- tests/test_logpdf_score.py | 2 +- tests/test_relevance.py | 13 ++++---- tests/test_serialize.py | 2 +- tests/test_update_cctype.py | 2 +- tests/test_view_logpdf_cluster.py | 6 ++-- 11 files changed, 79 insertions(+), 71 deletions(-) diff --git 
a/src/utils/test.py b/src/utils/test.py
index a0d13236..2ce676b6 100644
--- a/src/utils/test.py
+++ b/src/utils/test.py
@@ -315,11 +315,11 @@ def gen_simple_engine(multiprocess=1):
 rng=gu.gen_rng(1),
 multiprocess=multiprocess,
 outputs=outputs,
- alpha=1.,
+ structure_hypers=1.,
 cctypes=['bernoulli']*D,
 distargs={i: {'alpha': 1., 'beta': 1.} for i in outputs},
 Zv={0: 0, 1: 0, 2: 1},
- view_alphas=[1.]*D,
+ view_structure_hypers=[1.]*D,
 Zrv={0: [0]*R, 1: [0]*R})
 return engine

@@ -331,11 +331,11 @@ def gen_simple_state():
 state = State(
 X=data,
 outputs=outputs,
- alpha=1.,
+ structure_hypers=1.,
 cctypes=['bernoulli']*D,
 hypers=[{'alpha': 1., 'beta': 1.} for i in outputs],
 Zv={0: 0, 1: 0, 2: 1},
- view_alphas=[1.]*D,
+ view_structure_hypers=[1.]*D,
 Zrv={0: [0]*R, 1: [0]*R})
 return state

@@ -349,7 +349,7 @@ def gen_simple_view():
 view = View(
 X,
 outputs=[1000] + outputs,
- alpha=1.,
+ structure_hypers=1.,
 cctypes=['bernoulli']*D,
 hypers={i: {'alpha': 1., 'beta': 1.} for i in outputs},
 Zr=Zr)
@@ -399,7 +399,7 @@ def change_column_hyperparameters(cgpm, value):
 new_hypers[c] = {'alpha': value, 'beta': value}

 elif cctypes[c] == 'categorical':
- new_hypers[c] = {'alpha': value}
+ new_hypers[c] = {k: value for k in metadata["hypers"][columns.index(c)]}

 elif cctypes[c] == 'normal':
 new_hypers[c] = {'m': value, 'nu': value, 'r': value, 's': value}

@@ -421,13 +421,13 @@ def change_concentration_hyperparameters(cgpm, value):
 # Retrieve metadata from cgpm
 metadata = cgpm.to_metadata()

- # Alter metadata.alpha to extreme values according to data type
+ # Alter metadata.structure_hypers to extreme values according to data type
 new_metadata = metadata
 if isinstance(cgpm, View):
- new_metadata['alpha'] = value
+ new_metadata['structure_hypers'] = value
 elif isinstance(cgpm, State):
- old_view_alphas = metadata['view_alphas']
- new_metadata['view_alphas'] = [(a[0], value) for a in old_view_alphas]
+ old_view_structure_hypers = metadata['view_structure_hypers']
+ new_metadata['view_structure_hypers'] = [(a[0], value) for a in old_view_structure_hypers]
 return cgpm.from_metadata(new_metadata)

 def restrict_evidence_to_query(query, evidence):
diff --git a/tests/test_cmi.py b/tests/test_cmi.py
index f7994b75..9d2e126d 100644
--- a/tests/test_cmi.py
+++ b/tests/test_cmi.py
@@ -106,7 +106,7 @@ def test_cmi_different_views__ci_():
 rng=rng
 )
 state.transition(N=30,
- kernels=['alpha','view_alphas','column_params','column_hypers','rows'])
+ kernels=['structure_hypers','view_structure_hypers','column_params','column_hypers','rows'])
 mi01 = state.mutual_information([0], [1])
 mi02 = state.mutual_information([0], [2])

diff --git a/tests/test_constr_crp.py b/tests/test_constr_crp.py
index 6d08a226..da704e24 100644
--- a/tests/test_constr_crp.py
+++ b/tests/test_constr_crp.py
@@ -100,20 +100,21 @@ def test_valid_constraints():
 assert vu.validate_crp_constrained_input(7, Cd, Ci, Rd, Ri)

-# Tests for simulate_crp_constrained and simulate_crp_constrained_dependent.
-
+# Globally set the discount parameter.
+DISCOUNT = 0.

+# Tests for simulate_crp_constrained and simulate_crp_constrained_dependent.
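These constrained tests pin DISCOUNT to zero, so the Pitman-Yor code path
reduces to the original CRP. As a sketch of what a nonzero discount changes
(assuming only the patched gu.simulate_crp above), the number of occupied
tables grows like N**discount instead of log(N):

    from cgpm.utils import general as gu

    z_crp = gu.simulate_crp(2000, 1., 0., rng=gu.gen_rng(0))
    z_py = gu.simulate_crp(2000, 1., 0.5, rng=gu.gen_rng(0))
    # With d = 0.5 we expect on the order of sqrt(2000) tables versus
    # roughly log(2000) for the plain CRP, so this almost surely holds.
    assert len(set(z_py)) > len(set(z_crp))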
def test_no_constraints(): N, alpha = 10, .4 Cd = Ci = [] Rd = Ri = {} Z = gu.simulate_crp_constrained( - N, alpha, Cd, Ci, Rd, Ri, rng=gu.gen_rng(0)) + N, alpha, DISCOUNT, Cd, Ci, Rd, Ri, rng=gu.gen_rng(0)) assert vu.validate_crp_constrained_partition(Z, Cd, Ci, Rd, Ri) Z = gu.simulate_crp_constrained_dependent( - N, alpha, Cd, rng=gu.gen_rng(0)) + N, alpha, DISCOUNT, Cd, rng=gu.gen_rng(0)) assert vu.validate_crp_constrained_partition(Z, Cd, [], [], []) def test_all_friends(): @@ -123,11 +124,11 @@ def test_all_friends(): Rd = Ri = {} Z = gu.simulate_crp_constrained( - N, alpha, Cd, Ci, Rd, Ri, rng=gu.gen_rng(0)) + N, alpha, DISCOUNT, Cd, Ci, Rd, Ri, rng=gu.gen_rng(0)) assert vu.validate_crp_constrained_partition(Z, Cd, Ci, Rd, Ri) Z = gu.simulate_crp_constrained_dependent( - N, alpha, Cd, rng=gu.gen_rng(0)) + N, alpha, DISCOUNT, Cd, rng=gu.gen_rng(0)) assert vu.validate_crp_constrained_partition(Z, Cd, [], [], []) def test_all_enemies(): @@ -136,7 +137,7 @@ def test_all_enemies(): Ci = list(itertools.combinations(list(range(N)), 2)) Rd = Ri = {} Z = gu.simulate_crp_constrained( - N, alpha, Cd, Ci, Rd, Ri, rng=gu.gen_rng(0)) + N, alpha, DISCOUNT, Cd, Ci, Rd, Ri, rng=gu.gen_rng(0)) assert vu.validate_crp_constrained_partition(Z, Cd, Ci, Rd, Ri) def test_all_enemies_rows(): @@ -147,7 +148,7 @@ def test_all_enemies_rows(): Rd = {0:[[0,1]], 1:[[1,2]], 2:[[2,3]]} Ri = {0:[(1,2)], 1:[(2,3)], 2:[(0,1)]} Z = gu.simulate_crp_constrained( - N, alpha, Cd, Ci, Rd, Ri, rng=gu.gen_rng(0)) + N, alpha, DISCOUNT, Cd, Ci, Rd, Ri, rng=gu.gen_rng(0)) assert vu.validate_crp_constrained_partition(Z, Cd, Ci, Rd, Ri) def test_complex_relationships(): @@ -156,7 +157,7 @@ def test_complex_relationships(): Ci = [(2,8), (0,3)] Rd = Ri = {} Z = gu.simulate_crp_constrained( - N, alpha, Cd, Ci, Rd, Ri, rng=gu.gen_rng(0)) + N, alpha, DISCOUNT, Cd, Ci, Rd, Ri, rng=gu.gen_rng(0)) assert vu.validate_crp_constrained_partition(Z, Cd, Ci, Rd, Ri) # Tests for simulate_crp_constrained and simulate_crp_constrained_dependent. 
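The logp tests in the next hunk exercise gu.logp_crp, which now accumulates
the marginal sequentially (PClean-style). A standalone cross-check that the
sequential form agrees with the classical closed-form CRP marginal at
discount = 0 (seq_marginal is a hypothetical name used only for this sketch):

    from math import lgamma, log

    def seq_marginal(Nk, alpha, d):
        # Seat customers table by table; exchangeability makes order irrelevant.
        n, lp = 0, 0.
        for k, size in enumerate(Nk):
            lp += log(k * d + alpha) - log(n + alpha)   # first customer opens table k
            lp += sum(log(i - d) - log(n + i + alpha) for i in range(1, size))
            n += size
        return lp

    Nk, alpha = [3, 2, 1], 1.7
    closed = (len(Nk) * log(alpha) + sum(lgamma(c) for c in Nk)
              + lgamma(alpha) - lgamma(sum(Nk) + alpha))
    assert abs(seq_marginal(Nk, alpha, 0.) - closed) < 1e-9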
@@ -172,42 +173,42 @@ def test_logp_no_dependence_constraints(): Z = {0:0, 1:0, 2:0, 3:1} alpha = 2 Cd = [] - lp0 = gu.logp_crp(len(Z), get_partition_counts(Z), alpha) - lp1 = gu.logp_crp_constrained_dependent(Z, alpha, Cd) + lp0 = gu.logp_crp(len(Z), get_partition_counts(Z), alpha, DISCOUNT) + lp1 = gu.logp_crp_constrained_dependent(Z, alpha, DISCOUNT, Cd) assert np.allclose(lp0, lp1) Z = {0:1, 1:2, 2:3, 3:4} alpha = 2 Cd = [] - lp0 = gu.logp_crp(len(Z), get_partition_counts(Z), alpha) - lp1 = gu.logp_crp_constrained_dependent(Z, alpha, Cd) + lp0 = gu.logp_crp(len(Z), get_partition_counts(Z), alpha, DISCOUNT) + lp1 = gu.logp_crp_constrained_dependent(Z, alpha, DISCOUNT, Cd) assert np.allclose(lp0, lp1) def test_logp_simple_dependence_constraints(): Z = {0:0, 1:0, 2:0, 3:1} alpha = 2 Cd = [[0,1]] - lp0 = gu.logp_crp(len(Z)-1, [2, 1], alpha) - lp1 = gu.logp_crp_constrained_dependent(Z, alpha, Cd) + lp0 = gu.logp_crp(len(Z)-1, [2, 1], alpha, DISCOUNT) + lp1 = gu.logp_crp_constrained_dependent(Z, alpha, DISCOUNT, Cd) assert np.allclose(lp0, lp1) Z = {0:0, 1:1, 2:1, 3:0} alpha = 2 Cd = [[0,3], [1,2]] - lp0 = gu.logp_crp(len(Z)-2, [1,1], alpha) - lp1 = gu.logp_crp_constrained_dependent(Z, alpha, Cd) + lp0 = gu.logp_crp(len(Z)-2, [1,1], alpha, DISCOUNT) + lp1 = gu.logp_crp_constrained_dependent(Z, alpha, DISCOUNT, Cd) assert np.allclose(lp0, lp1) def test_logp_impossible(): Z = {0:0, 1:1, 2:0} alpha = 2 Cd = [[0,1]] - lp = gu.logp_crp_constrained_dependent(Z, alpha, Cd) + lp = gu.logp_crp_constrained_dependent(Z, alpha, DISCOUNT, Cd) assert lp == -float('inf') def test_logp_deterministic(): Z = {0:0, 1:0, 2:0, 3:0} alpha = 2 Cd = [[0,1,2,3]] - lp1 = gu.logp_crp_constrained_dependent(Z, alpha, Cd) + lp1 = gu.logp_crp_constrained_dependent(Z, alpha, DISCOUNT, Cd) assert np.allclose(0, lp1) diff --git a/tests/test_crp.py b/tests/test_crp.py index f14d4459..790ab33d 100644 --- a/tests/test_crp.py +++ b/tests/test_crp.py @@ -27,7 +27,7 @@ def simulate_crp_gpm(N, alpha, rng): - crp = Crp(outputs=[0], inputs=None, hypers={'alpha':alpha}, rng=rng) + crp = Crp(outputs=[0], inputs=None, hypers={'alpha':alpha, 'discount':0}, rng=rng) for i in range(N): s = crp.simulate(i, [0], None) crp.incorporate(i, s, None) @@ -35,6 +35,9 @@ def simulate_crp_gpm(N, alpha, rng): def assert_crp_equality(alpha, Nk, crp): + # Because we're assessing CRP probabilities, + # we want to set the discount to 0. + discount = 0. N = sum(Nk) Z = list(itertools.chain.from_iterable( [i]*n for i, n in enumerate(Nk))) @@ -48,25 +51,25 @@ def assert_crp_equality(alpha, Nk, crp): [crp.logpdf(-1, {0:v}, None) for v in probe_values]) # Data probability. assert np.allclose( - gu.logp_crp(N, Nk, alpha), - crp.logpdf_score()) + gu.logp_crp(N, Nk, alpha, discount), + crp.logpdf_score()), (gu.logp_crp(N, Nk, alpha, discount), crp.logpdf_score()) # Gibbs transition probabilities. Z = list(crp.data.values()) for i, rowid in enumerate(crp.data): assert np.allclose( - gu.logp_crp_gibbs(Nk, Z, i, alpha, 1), + gu.logp_crp_gibbs(Nk, Z, i, alpha, discount, 1), crp.gibbs_logps(rowid)) N = [2**i for i in range(8)] alpha = gu.log_linspace(.001, 100, 10) seed = [5] - +DISCOUNT = 0 @pytest.mark.parametrize('N, alpha, seed', itertools.product(N, alpha, seed)) def test_crp_simple(N, alpha, seed): # Obtain the partitions. 
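 # (These tests thread DISCOUNT = 0 through simulate_crp, so the Pitman-Yor
 # code path must reproduce the original CRP partitions exactly.)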
- A = gu.simulate_crp(N, alpha, rng=gu.gen_rng(seed)) + A = gu.simulate_crp(N, alpha, DISCOUNT, rng=gu.gen_rng(seed)) Nk = list(np.bincount(A)) crp = simulate_crp_gpm(N, alpha, rng=gu.gen_rng(seed)) @@ -77,7 +80,7 @@ def test_crp_simple(N, alpha, seed): @pytest.mark.parametrize('N, alpha, seed', itertools.product(N, alpha, seed)) def test_crp_decrement(N, alpha, seed): - A = gu.simulate_crp(N, alpha, rng=gu.gen_rng(seed)) + A = gu.simulate_crp(N, alpha, DISCOUNT, rng=gu.gen_rng(seed)) Nk = list(np.bincount(A)) # Decrement all counts by 1. Nk = [n-1 if n > 1 else n for n in Nk] @@ -98,7 +101,7 @@ def test_crp_decrement(N, alpha, seed): @pytest.mark.parametrize('N, alpha, seed', itertools.product(N, alpha, seed)) def test_crp_increment(N, alpha, seed): - A = gu.simulate_crp(N, alpha, rng=gu.gen_rng(seed)) + A = gu.simulate_crp(N, alpha, DISCOUNT, rng=gu.gen_rng(seed)) Nk = list(np.bincount(A)) # Add 3 new classes. Nk.extend([2, 3, 1]) @@ -116,9 +119,12 @@ def test_crp_increment(N, alpha, seed): assert_crp_equality(alpha, Nk, crp) +def norm(l): + return np.array(l)/np.sum(l) + def test_gibbs_tables_logps(): crp = Crp( - outputs=[0], inputs=None, hypers={'alpha': 1.5}, rng=gu.gen_rng(1)) + outputs=[0], inputs=None, hypers={'alpha': 1.5, 'discount':0.}, rng=gu.gen_rng(1)) assignments = [ (0, {0: 0}), @@ -136,53 +142,52 @@ def test_gibbs_tables_logps(): K01 = crp.gibbs_tables(0, m=1) assert K01 == [0, 2, 6, 7] P01 = crp.gibbs_logps(0, m=1) - assert np.allclose(np.exp(P01), [1, 3, 1, 1.5]) + assert np.allclose(np.exp(P01), norm([1, 3, 1, 1.5])) K02 = crp.gibbs_tables(0, m=2) assert K02 == [0, 2, 6, 7, 8] P02 = crp.gibbs_logps(0, m=2) - assert np.allclose(np.exp(P02), [1, 3, 1, 1.5/2, 1.5/2]) + assert np.allclose(np.exp(P02), norm([1, 3, 1, 1.5/2, 1.5/2])) K03 = crp.gibbs_tables(0, m=3) assert K03 == [0, 2, 6, 7, 8, 9] P03 = crp.gibbs_logps(0, m=3) - assert np.allclose(np.exp(P03), [1, 3, 1, 1.5/3, 1.5/3, 1.5/3]) + assert np.allclose(np.exp(P03), norm([1, 3, 1, 1.5/3, 1.5/3, 1.5/3])) K21 = crp.gibbs_tables(2, m=1) assert K21 == [0, 2, 6, 7] P21 = crp.gibbs_logps(2, m=1) - assert np.allclose(np.exp(P21), [2, 2, 1, 1.5]) + assert np.allclose(np.exp(P21), norm([2, 2, 1, 1.5])) K22 = crp.gibbs_tables(2, m=2) assert K22 == [0, 2, 6, 7, 8] P22 = crp.gibbs_logps(2, m=2) - assert np.allclose(np.exp(P22), [2, 2, 1, 1.5/2, 1.5/2]) + assert np.allclose(np.exp(P22), norm([2, 2, 1, 1.5/2, 1.5/2])) K23 = crp.gibbs_tables(2, m=3) P23 = crp.gibbs_logps(2, m=3) assert K23 == [0, 2, 6, 7, 8, 9] - assert np.allclose(np.exp(P23), [2, 2, 1, 1.5/3, 1.5/3, 1.5/3]) + assert np.allclose(np.exp(P23), norm([2, 2, 1, 1.5/3, 1.5/3, 1.5/3])) K51 = crp.gibbs_tables(5, m=1) assert K51 == [0, 2, 6] P51 = crp.gibbs_logps(5, m=1) - assert np.allclose(np.exp(P51), [2, 3, 1.5]) + assert np.allclose(np.exp(P51), norm([2, 3, 1.5])) K52 = crp.gibbs_tables(5, m=2) assert K52 == [0, 2, 6, 7] P52 = crp.gibbs_logps(5, m=2) - assert np.allclose(np.exp(P52), [2, 3, 1.5/2, 1.5/2]) + assert np.allclose(np.exp(P52), norm([2, 3, 1.5/2, 1.5/2])) K53 = crp.gibbs_tables(5, m=3) P53 = crp.gibbs_logps(5, m=3) assert K53 == [0, 2, 6, 7, 8] - assert np.allclose(np.exp(P53), [2, 3, 1.5/3, 1.5/3, 1.5/3]) + assert np.allclose(np.exp(P53), norm([2, 3, 1.5/3, 1.5/3, 1.5/3])) def test_crp_logpdf_score(): """Ensure that logpdf_marginal agrees with sequence of predictives.""" - crp = Crp( - outputs=[0], inputs=None, hypers={'alpha': 1.5}, rng=gu.gen_rng(1)) + crp = Crp(outputs=[0], inputs=None, hypers={'alpha': 1.5, 'discount': 0.}, rng=gu.gen_rng(1)) 
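+ # With discount = 0 the Pitman-Yor marginal reduces exactly to the CRP
+ # marginal that this test was originally written against.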
assignments = [ (0, {0: 0}), @@ -258,7 +263,7 @@ def test_crp_same_table_probability(): where kQ is list of tables in the CRP plus a fresh singleton. """ crp = Crp( - outputs=[0], inputs=None, hypers={'alpha': 1.5}, rng=gu.gen_rng(1)) + outputs=[0], inputs=None, hypers={'alpha': 1.5, 'discount': 0.}, rng=gu.gen_rng(1)) assignments = [ (0, {0: 0}), @@ -393,4 +398,5 @@ def get_tables_different(tables): # Confirm no mutation has occured. assert crp.data == crp_data_full - assert crp.logpdf_score() == logpdf_score_full + # Before, we had equality here, but that now fails on the 15th decimal. + assert np.allclose(crp.logpdf_score(), logpdf_score_full) diff --git a/tests/test_dependence_constraints.py b/tests/test_dependence_constraints.py index ed87cb7c..4fe0aa28 100644 --- a/tests/test_dependence_constraints.py +++ b/tests/test_dependence_constraints.py @@ -59,7 +59,7 @@ def test_simple_dependence_constraint(Ci): state.transition(N=10, kernels=['columns'], progress=0) state.transition( N=10, - kernels=['rows', 'alpha', 'column_hypers', 'alpha', 'view_alphas'], + kernels=['rows', 'structure_hypers', 'column_hypers', 'structure_hypers', 'view_structure_hypers'], progress=False) vu.validate_crp_constrained_partition(state.Zv(), Cd, Ci, {}, {}) diff --git a/tests/test_iter_counter.py b/tests/test_iter_counter.py index 04b4ac04..699df798 100644 --- a/tests/test_iter_counter.py +++ b/tests/test_iter_counter.py @@ -37,19 +37,19 @@ def test_individual_kernels(): rng = gu.gen_rng(0) X = rng.normal(size=(5,5)) state = State(X, cctypes=['normal']*5) - state.transition(N=3, kernels=['alpha', 'rows']) + state.transition(N=3, kernels=['structure_hypers', 'rows']) check_expected_counts( state.diagnostics['iterations'], - {'alpha':3, 'rows':3}) - state.transition(N=5, kernels=['view_alphas', 'column_params']) + {'structure_hypers':3, 'rows':3}) + state.transition(N=5, kernels=['view_structure_hypers', 'column_params']) check_expected_counts( state.to_metadata()['diagnostics']['iterations'], - {'alpha':3, 'rows':3, 'view_alphas':5, 'column_params':5}) + {'structure_hypers':3, 'rows':3, 'view_structure_hypers':5, 'column_params':5}) state.transition( - N=1, kernels=['view_alphas', 'column_params', 'column_hypers']) + N=1, kernels=['view_structure_hypers', 'column_params', 'column_hypers']) check_expected_counts( state.to_metadata()['diagnostics']['iterations'], - {'alpha':3, 'rows':3, 'view_alphas':6, 'column_params':6, + {'structure_hypers':3, 'rows':3, 'view_structure_hypers':6, 'column_params':6, 'column_hypers':1}) diff --git a/tests/test_logpdf_score.py b/tests/test_logpdf_score.py index 3d8f98da..b270bed0 100644 --- a/tests/test_logpdf_score.py +++ b/tests/test_logpdf_score.py @@ -30,7 +30,7 @@ def test_logpdf_score_crash(): assert np.all(logpdf_score_initial < logpdf_likelihood_initial) # assert np.all(logpdf_likelihood_initial < logpdf_score_initial) engine.transition(N=100) - engine.transition(kernels=['column_hypers','view_alphas'], N=10) + engine.transition(kernels=['column_hypers','view_structure_hypers'], N=10) logpdf_likelihood_final = np.asarray(engine.logpdf_likelihood()) logpdf_score_final = np.asarray(engine.logpdf_score()) assert np.all(logpdf_score_final < logpdf_likelihood_final) diff --git a/tests/test_relevance.py b/tests/test_relevance.py index 683a4e62..455ce68d 100644 --- a/tests/test_relevance.py +++ b/tests/test_relevance.py @@ -70,7 +70,7 @@ def gen_view_cgpm(get_data): view = View( outputs=[1000]+outputs, X={output: data[:, i] for i, output in enumerate(outputs)}, - alpha=1.5, + 
structure_hypers={'alpha': 1.5, 'discount': 0.},
 cctypes=cctypes,
 distargs=distargs,
 Zr=assignments,
@@ -90,7 +90,7 @@ def gen_state_cgpm(get_data):
 distargs=distargs,
 Zv={output: 0 for output in outputs},
 Zrv={0: assignments},
- view_alphas={0: 1.5},
+ view_structure_hypers={0: {'alpha': 1.5, 'discount': 0.}},
 rng=gu.gen_rng(1)
 )

@@ -179,7 +179,8 @@ def test_relevance_commutative_single_query_row():
 def test_relevance_large_concentration_hypers():
 """Confirm crp_alpha -> infty, implies rp(target, query) -> 0."""
 view = gen_view_cgpm(get_data_separated)
- lim_view = tu.change_concentration_hyperparameters(view, 1e5)
+ structure_hypers = {"alpha": 1e5, "discount": 0}
+ lim_view = tu.change_concentration_hyperparameters(view, structure_hypers)
 rp_view_0 = lim_view.relevance_probability(1, [4, 6, 7], 1)
 rp_view_1 = lim_view.relevance_probability(3, [8], 1)

 assert np.allclose(rp_view_1, 0, atol=1e-5)

 state = gen_state_cgpm(get_data_separated)
- ext_state = tu.change_concentration_hyperparameters(state, 1e5)
+ ext_state = tu.change_concentration_hyperparameters(state, structure_hypers)
 rp_state_0 = ext_state.relevance_probability(1, [4, 6, 7], 1)
 rp_state_1 = ext_state.relevance_probability(3, [8], 1)

@@ -227,7 +228,7 @@ def test_relevance_with_itself():
 def test_relevance_analytically():
 view = gen_view_cgpm(get_data_all_ones)
 n = view.n_rows()
- a = view.alpha() # crp_alpha
+ a = view.crp.hypers["alpha"] # crp_alpha
 b1 = view.dims[1].hypers['alpha'] # bernoulli pseudocounts for one
 b0 = view.dims[1].hypers['beta'] # bernoulli pseudocounts for zero

@@ -271,7 +272,7 @@ def test_crash_missing():
 def test_missing_analytical():
 view = gen_view_cgpm(get_data_missing)
 n = view.n_rows()
- a = view.alpha() # crp_alpha
+ a = view.crp.hypers["alpha"] # crp_alpha
 b1 = view.dims[1].hypers['alpha'] # bernoulli pseudocounts for one
 b0 = view.dims[1].hypers['beta'] # bernoulli pseudocounts for zero
diff --git a/tests/test_serialize.py b/tests/test_serialize.py
index a71ec313..cc1dc38b 100644
--- a/tests/test_serialize.py
+++ b/tests/test_serialize.py
@@ -101,7 +101,7 @@ def test_view_serialize():
 builder = getattr(modname, metadata['factory'][1])
 model2 = builder.from_metadata(metadata)
 # Pick out some data.
- assert np.allclose(model.alpha(), model.alpha())
+ assert np.allclose(model.crp.hypers["alpha"], model2.crp.hypers["alpha"])
 assert dict(model2.Zr()) == dict(model.Zr())
 assert np.allclose(
 model.logpdf(-1, {0:0, 1:1}, {2:0}),
diff --git a/tests/test_update_cctype.py b/tests/test_update_cctype.py
index 61df3779..3df533bd 100644
--- a/tests/test_update_cctype.py
+++ b/tests/test_update_cctype.py
@@ -170,7 +170,7 @@ def test_linreg_missing_data_ignore():
 # Make sure that missing covariates are handled as missing cell.
 state.update_cctype(2, 'linear_regression', distargs={'inputs': [0,1]})
 assert state.dim_for(2).inputs[1:] == [0,1]
- state.transition(N=5, kernels=['rows', 'column_hypers', 'view_alphas'])
+ state.transition(N=5, kernels=['rows', 'column_hypers', 'view_structure_hypers'])
 state.update_cctype(2, 'normal', distargs={'inputs': [0,1]})

 # Make sure that specified inputs are set correctly.
state.update_cctype(2, 'linear_regression', distargs={'inputs': [1]})
diff --git a/tests/test_view_logpdf_cluster.py b/tests/test_view_logpdf_cluster.py
index 8180eab6..d549682c 100644
--- a/tests/test_view_logpdf_cluster.py
+++ b/tests/test_view_logpdf_cluster.py
@@ -38,7 +38,7 @@ def retrieve_view():
 return View(
 {c: data[:,i].tolist() for i, c in enumerate(outputs)},
 outputs=[1000] + outputs,
- alpha=2.,
+ structure_hypers={'alpha': 2., 'discount': 0},
 cctypes=['normal'] * len(outputs),
 Zr=[0,0,0,1,1,]
 )
@@ -46,11 +46,11 @@ def retrieve_view():

 def test_crp_prior_logpdf():
 view = retrieve_view()
- crp_normalizer = view.alpha() + 5.
+ crp_normalizer = view.crp.hypers["alpha"] + 5.
 cluster_logps = np.log(np.asarray([
 old_div(3, crp_normalizer),
 old_div(2, crp_normalizer),
- old_div(view.alpha(), crp_normalizer)
+ old_div(view.crp.hypers["alpha"], crp_normalizer)
 ]))
 # Test the crp probabilities agree for a hypothetical row.
 for k in [0,1,2]:

From 1e5a059666ae24afa24861446ee7e9eaed16c86b Mon Sep 17 00:00:00 2001
From: Schaechtle
Date: Tue, 9 Jul 2024 11:14:30 -0400
Subject: [PATCH 3/3] chore: Mark a long-running test with the continuous
 integration flag

---
 tests/test_serialize.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/test_serialize.py b/tests/test_serialize.py
index cc1dc38b..4232843c 100644
--- a/tests/test_serialize.py
+++ b/tests/test_serialize.py
@@ -111,7 +111,7 @@ def test_view_serialize():
 model2.logpdf(-1, {0:0, 1:1}))


-def test_serialize_composite_cgpm():
+def test_serialize_composite_cgpm__ci_():
 rng = gu.gen_rng(2)

 # Generate the data.