Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend CRP to Pitman Yor and make hyper-parameters Loom compatible #256

Open
wants to merge 3 commits into
base: zane/hardcode-version
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions src/utils/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,11 +315,11 @@ def gen_simple_engine(multiprocess=1):
rng=gu.gen_rng(1),
multiprocess=multiprocess,
outputs=outputs,
alpha=1.,
structure_hypers=1.,

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe this should be a dict with alpha and discount as keys?

cctypes=['bernoulli']*D,
distargs={i: {'alpha': 1., 'beta': 1.} for i in outputs},
Zv={0: 0, 1: 0, 2: 1},
view_alphas=[1.]*D,
view_structure_hypers=[1.]*D,
Zrv={0: [0]*R, 1: [0]*R})
return engine

Expand All @@ -331,11 +331,11 @@ def gen_simple_state():
state = State(
X=data,
outputs=outputs,
alpha=1.,
structure_hypers=1.,

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

here too

cctypes=['bernoulli']*D,
hypers=[{'alpha': 1., 'beta': 1.} for i in outputs],
Zv={0: 0, 1: 0, 2: 1},
view_alphas=[1.]*D,
view_structure_hypers=[1.]*D,
Zrv={0: [0]*R, 1: [0]*R})
return state

Expand All @@ -349,7 +349,7 @@ def gen_simple_view():
view = View(
X,
outputs=[1000] + outputs,
alpha=1.,
structure_hypers=1.,
cctypes=['bernoulli']*D,
hypers={i: {'alpha': 1., 'beta': 1.} for i in outputs},
Zr=Zr)
Expand Down Expand Up @@ -399,7 +399,7 @@ def change_column_hyperparameters(cgpm, value):
new_hypers[c] = {'alpha': value, 'beta': value}

elif cctypes[c] == 'categorical':
new_hypers[c] = {'alpha': value}
new_hypers[c] = {k:value for k,_ in metadata["hypers"][columns.index(c)].items()}

elif cctypes[c] == 'normal':
new_hypers[c] = {'m': value, 'nu': value, 'r': value, 's': value}
Expand All @@ -421,13 +421,13 @@ def change_concentration_hyperparameters(cgpm, value):
# Retrieve metadata from cgpm
metadata = cgpm.to_metadata()

# Alter metadata.alpha to extreme values according to data type
# Alter metadata.structure_hypers to extreme values according to data type
new_metadata = metadata
if isinstance(cgpm, View):
new_metadata['alpha'] = value
new_metadata['structure_hypers'] = value
elif isinstance(cgpm, State):
old_view_alphas = metadata['view_alphas']
new_metadata['view_alphas'] = [(a[0], value) for a in old_view_alphas]
old_view_structure_hypers = metadata['view_structure_hypers']
new_metadata['view_structure_hypers'] = [(a[0], value) for a in old_view_structure_hypers]
return cgpm.from_metadata(new_metadata)

def restrict_evidence_to_query(query, evidence):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_cmi.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def test_cmi_different_views__ci_():
rng=rng
)
state.transition(N=30,
kernels=['alpha','view_alphas','column_params','column_hypers','rows'])
kernels=['structure_hypers','view_structure_hypers','column_params','column_hypers','rows'])

mi01 = state.mutual_information([0], [1])
mi02 = state.mutual_information([0], [2])
Expand Down
39 changes: 20 additions & 19 deletions tests/test_constr_crp.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,20 +100,21 @@ def test_valid_constraints():
assert vu.validate_crp_constrained_input(7, Cd, Ci, Rd, Ri)


# Tests for simulate_crp_constrained and simulate_crp_constrained_dependent.

# Globally set discount parameter.
DISCOUNT = 0.

# Tests for simulate_crp_constrained and simulate_crp_constrained_dependent.
def test_no_constraints():
N, alpha = 10, .4
Cd = Ci = []
Rd = Ri = {}

Z = gu.simulate_crp_constrained(
N, alpha, Cd, Ci, Rd, Ri, rng=gu.gen_rng(0))
N, alpha, DISCOUNT, Cd, Ci, Rd, Ri, rng=gu.gen_rng(0))
assert vu.validate_crp_constrained_partition(Z, Cd, Ci, Rd, Ri)

Z = gu.simulate_crp_constrained_dependent(
N, alpha, Cd, rng=gu.gen_rng(0))
N, alpha, DISCOUNT, Cd, rng=gu.gen_rng(0))
assert vu.validate_crp_constrained_partition(Z, Cd, [], [], [])

def test_all_friends():
Expand All @@ -123,11 +124,11 @@ def test_all_friends():
Rd = Ri = {}

Z = gu.simulate_crp_constrained(
N, alpha, Cd, Ci, Rd, Ri, rng=gu.gen_rng(0))
N, alpha, DISCOUNT, Cd, Ci, Rd, Ri, rng=gu.gen_rng(0))
assert vu.validate_crp_constrained_partition(Z, Cd, Ci, Rd, Ri)

Z = gu.simulate_crp_constrained_dependent(
N, alpha, Cd, rng=gu.gen_rng(0))
N, alpha, DISCOUNT, Cd, rng=gu.gen_rng(0))
assert vu.validate_crp_constrained_partition(Z, Cd, [], [], [])

def test_all_enemies():
Expand All @@ -136,7 +137,7 @@ def test_all_enemies():
Ci = list(itertools.combinations(list(range(N)), 2))
Rd = Ri = {}
Z = gu.simulate_crp_constrained(
N, alpha, Cd, Ci, Rd, Ri, rng=gu.gen_rng(0))
N, alpha, DISCOUNT, Cd, Ci, Rd, Ri, rng=gu.gen_rng(0))
assert vu.validate_crp_constrained_partition(Z, Cd, Ci, Rd, Ri)

def test_all_enemies_rows():
Expand All @@ -147,7 +148,7 @@ def test_all_enemies_rows():
Rd = {0:[[0,1]], 1:[[1,2]], 2:[[2,3]]}
Ri = {0:[(1,2)], 1:[(2,3)], 2:[(0,1)]}
Z = gu.simulate_crp_constrained(
N, alpha, Cd, Ci, Rd, Ri, rng=gu.gen_rng(0))
N, alpha, DISCOUNT, Cd, Ci, Rd, Ri, rng=gu.gen_rng(0))
assert vu.validate_crp_constrained_partition(Z, Cd, Ci, Rd, Ri)

def test_complex_relationships():
Expand All @@ -156,7 +157,7 @@ def test_complex_relationships():
Ci = [(2,8), (0,3)]
Rd = Ri = {}
Z = gu.simulate_crp_constrained(
N, alpha, Cd, Ci, Rd, Ri, rng=gu.gen_rng(0))
N, alpha, DISCOUNT, Cd, Ci, Rd, Ri, rng=gu.gen_rng(0))
assert vu.validate_crp_constrained_partition(Z, Cd, Ci, Rd, Ri)

# Tests for simulate_crp_constrained and simulate_crp_constrained_dependent.
Expand All @@ -172,42 +173,42 @@ def test_logp_no_dependence_constraints():
Z = {0:0, 1:0, 2:0, 3:1}
alpha = 2
Cd = []
lp0 = gu.logp_crp(len(Z), get_partition_counts(Z), alpha)
lp1 = gu.logp_crp_constrained_dependent(Z, alpha, Cd)
lp0 = gu.logp_crp(len(Z), get_partition_counts(Z), alpha, DISCOUNT)
lp1 = gu.logp_crp_constrained_dependent(Z, alpha, DISCOUNT, Cd)
assert np.allclose(lp0, lp1)

Z = {0:1, 1:2, 2:3, 3:4}
alpha = 2
Cd = []
lp0 = gu.logp_crp(len(Z), get_partition_counts(Z), alpha)
lp1 = gu.logp_crp_constrained_dependent(Z, alpha, Cd)
lp0 = gu.logp_crp(len(Z), get_partition_counts(Z), alpha, DISCOUNT)
lp1 = gu.logp_crp_constrained_dependent(Z, alpha, DISCOUNT, Cd)
assert np.allclose(lp0, lp1)

def test_logp_simple_dependence_constraints():
Z = {0:0, 1:0, 2:0, 3:1}
alpha = 2
Cd = [[0,1]]
lp0 = gu.logp_crp(len(Z)-1, [2, 1], alpha)
lp1 = gu.logp_crp_constrained_dependent(Z, alpha, Cd)
lp0 = gu.logp_crp(len(Z)-1, [2, 1], alpha, DISCOUNT)
lp1 = gu.logp_crp_constrained_dependent(Z, alpha, DISCOUNT, Cd)
assert np.allclose(lp0, lp1)

Z = {0:0, 1:1, 2:1, 3:0}
alpha = 2
Cd = [[0,3], [1,2]]
lp0 = gu.logp_crp(len(Z)-2, [1,1], alpha)
lp1 = gu.logp_crp_constrained_dependent(Z, alpha, Cd)
lp0 = gu.logp_crp(len(Z)-2, [1,1], alpha, DISCOUNT)
lp1 = gu.logp_crp_constrained_dependent(Z, alpha, DISCOUNT, Cd)
assert np.allclose(lp0, lp1)

def test_logp_impossible():
Z = {0:0, 1:1, 2:0}
alpha = 2
Cd = [[0,1]]
lp = gu.logp_crp_constrained_dependent(Z, alpha, Cd)
lp = gu.logp_crp_constrained_dependent(Z, alpha, DISCOUNT, Cd)
assert lp == -float('inf')

def test_logp_deterministic():
Z = {0:0, 1:0, 2:0, 3:0}
alpha = 2
Cd = [[0,1,2,3]]
lp1 = gu.logp_crp_constrained_dependent(Z, alpha, Cd)
lp1 = gu.logp_crp_constrained_dependent(Z, alpha, DISCOUNT, Cd)
assert np.allclose(0, lp1)
50 changes: 28 additions & 22 deletions tests/test_crp.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,17 @@


def simulate_crp_gpm(N, alpha, rng):
crp = Crp(outputs=[0], inputs=None, hypers={'alpha':alpha}, rng=rng)
crp = Crp(outputs=[0], inputs=None, hypers={'alpha':alpha, 'discount':0}, rng=rng)
for i in range(N):
s = crp.simulate(i, [0], None)
crp.incorporate(i, s, None)
return crp


def assert_crp_equality(alpha, Nk, crp):
# Because we're assessing CRP probabilities,
# we want to set the discount to 0.
discount = 0.
N = sum(Nk)
Z = list(itertools.chain.from_iterable(
[i]*n for i, n in enumerate(Nk)))
Expand All @@ -48,25 +51,25 @@ def assert_crp_equality(alpha, Nk, crp):
[crp.logpdf(-1, {0:v}, None) for v in probe_values])
# Data probability.
assert np.allclose(
gu.logp_crp(N, Nk, alpha),
crp.logpdf_score())
gu.logp_crp(N, Nk, alpha, discount),
crp.logpdf_score()), (gu.logp_crp(N, Nk, alpha, discount), crp.logpdf_score())
# Gibbs transition probabilities.
Z = list(crp.data.values())
for i, rowid in enumerate(crp.data):
assert np.allclose(
gu.logp_crp_gibbs(Nk, Z, i, alpha, 1),
gu.logp_crp_gibbs(Nk, Z, i, alpha, discount, 1),
crp.gibbs_logps(rowid))


N = [2**i for i in range(8)]
alpha = gu.log_linspace(.001, 100, 10)
seed = [5]

DISCOUNT = 0

@pytest.mark.parametrize('N, alpha, seed', itertools.product(N, alpha, seed))
def test_crp_simple(N, alpha, seed):
# Obtain the partitions.
A = gu.simulate_crp(N, alpha, rng=gu.gen_rng(seed))
A = gu.simulate_crp(N, alpha, DISCOUNT, rng=gu.gen_rng(seed))
Nk = list(np.bincount(A))

crp = simulate_crp_gpm(N, alpha, rng=gu.gen_rng(seed))
Expand All @@ -77,7 +80,7 @@ def test_crp_simple(N, alpha, seed):

@pytest.mark.parametrize('N, alpha, seed', itertools.product(N, alpha, seed))
def test_crp_decrement(N, alpha, seed):
A = gu.simulate_crp(N, alpha, rng=gu.gen_rng(seed))
A = gu.simulate_crp(N, alpha, DISCOUNT, rng=gu.gen_rng(seed))
Nk = list(np.bincount(A))
# Decrement all counts by 1.
Nk = [n-1 if n > 1 else n for n in Nk]
Expand All @@ -98,7 +101,7 @@ def test_crp_decrement(N, alpha, seed):

@pytest.mark.parametrize('N, alpha, seed', itertools.product(N, alpha, seed))
def test_crp_increment(N, alpha, seed):
A = gu.simulate_crp(N, alpha, rng=gu.gen_rng(seed))
A = gu.simulate_crp(N, alpha, DISCOUNT, rng=gu.gen_rng(seed))
Nk = list(np.bincount(A))
# Add 3 new classes.
Nk.extend([2, 3, 1])
Expand All @@ -116,9 +119,12 @@ def test_crp_increment(N, alpha, seed):

assert_crp_equality(alpha, Nk, crp)

def norm(l):
return np.array(l)/np.sum(l)

def test_gibbs_tables_logps():
crp = Crp(
outputs=[0], inputs=None, hypers={'alpha': 1.5}, rng=gu.gen_rng(1))
outputs=[0], inputs=None, hypers={'alpha': 1.5, 'discount':0.}, rng=gu.gen_rng(1))

assignments = [
(0, {0: 0}),
Expand All @@ -136,53 +142,52 @@ def test_gibbs_tables_logps():
K01 = crp.gibbs_tables(0, m=1)
assert K01 == [0, 2, 6, 7]
P01 = crp.gibbs_logps(0, m=1)
assert np.allclose(np.exp(P01), [1, 3, 1, 1.5])
assert np.allclose(np.exp(P01), norm([1, 3, 1, 1.5]))

K02 = crp.gibbs_tables(0, m=2)
assert K02 == [0, 2, 6, 7, 8]
P02 = crp.gibbs_logps(0, m=2)
assert np.allclose(np.exp(P02), [1, 3, 1, 1.5/2, 1.5/2])
assert np.allclose(np.exp(P02), norm([1, 3, 1, 1.5/2, 1.5/2]))

K03 = crp.gibbs_tables(0, m=3)
assert K03 == [0, 2, 6, 7, 8, 9]
P03 = crp.gibbs_logps(0, m=3)
assert np.allclose(np.exp(P03), [1, 3, 1, 1.5/3, 1.5/3, 1.5/3])
assert np.allclose(np.exp(P03), norm([1, 3, 1, 1.5/3, 1.5/3, 1.5/3]))

K21 = crp.gibbs_tables(2, m=1)
assert K21 == [0, 2, 6, 7]
P21 = crp.gibbs_logps(2, m=1)
assert np.allclose(np.exp(P21), [2, 2, 1, 1.5])
assert np.allclose(np.exp(P21), norm([2, 2, 1, 1.5]))

K22 = crp.gibbs_tables(2, m=2)
assert K22 == [0, 2, 6, 7, 8]
P22 = crp.gibbs_logps(2, m=2)
assert np.allclose(np.exp(P22), [2, 2, 1, 1.5/2, 1.5/2])
assert np.allclose(np.exp(P22), norm([2, 2, 1, 1.5/2, 1.5/2]))

K23 = crp.gibbs_tables(2, m=3)
P23 = crp.gibbs_logps(2, m=3)
assert K23 == [0, 2, 6, 7, 8, 9]
assert np.allclose(np.exp(P23), [2, 2, 1, 1.5/3, 1.5/3, 1.5/3])
assert np.allclose(np.exp(P23), norm([2, 2, 1, 1.5/3, 1.5/3, 1.5/3]))

K51 = crp.gibbs_tables(5, m=1)
assert K51 == [0, 2, 6]
P51 = crp.gibbs_logps(5, m=1)
assert np.allclose(np.exp(P51), [2, 3, 1.5])
assert np.allclose(np.exp(P51), norm([2, 3, 1.5]))

K52 = crp.gibbs_tables(5, m=2)
assert K52 == [0, 2, 6, 7]
P52 = crp.gibbs_logps(5, m=2)
assert np.allclose(np.exp(P52), [2, 3, 1.5/2, 1.5/2])
assert np.allclose(np.exp(P52), norm([2, 3, 1.5/2, 1.5/2]))

K53 = crp.gibbs_tables(5, m=3)
P53 = crp.gibbs_logps(5, m=3)
assert K53 == [0, 2, 6, 7, 8]
assert np.allclose(np.exp(P53), [2, 3, 1.5/3, 1.5/3, 1.5/3])
assert np.allclose(np.exp(P53), norm([2, 3, 1.5/3, 1.5/3, 1.5/3]))


def test_crp_logpdf_score():
"""Ensure that logpdf_marginal agrees with sequence of predictives."""
crp = Crp(
outputs=[0], inputs=None, hypers={'alpha': 1.5}, rng=gu.gen_rng(1))
crp = Crp(outputs=[0], inputs=None, hypers={'alpha': 1.5, 'discount': 0.}, rng=gu.gen_rng(1))

assignments = [
(0, {0: 0}),
Expand Down Expand Up @@ -258,7 +263,7 @@ def test_crp_same_table_probability():
where kQ is list of tables in the CRP plus a fresh singleton.
"""
crp = Crp(
outputs=[0], inputs=None, hypers={'alpha': 1.5}, rng=gu.gen_rng(1))
outputs=[0], inputs=None, hypers={'alpha': 1.5, 'discount': 0.}, rng=gu.gen_rng(1))

assignments = [
(0, {0: 0}),
Expand Down Expand Up @@ -393,4 +398,5 @@ def get_tables_different(tables):

# Confirm no mutation has occurred.
assert crp.data == crp_data_full
assert crp.logpdf_score() == logpdf_score_full
# Before, we had equality here, but that now fails on the 15th decimal.
assert np.allclose(crp.logpdf_score(), logpdf_score_full)
2 changes: 1 addition & 1 deletion tests/test_dependence_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def test_simple_dependence_constraint(Ci):
state.transition(N=10, kernels=['columns'], progress=0)
state.transition(
N=10,
kernels=['rows', 'alpha', 'column_hypers', 'alpha', 'view_alphas'],
kernels=['rows', 'structure_hypers', 'column_hypers', 'structure_hypers', 'view_structure_hypers'],
progress=False)
vu.validate_crp_constrained_partition(state.Zv(), Cd, Ci, {}, {})

Expand Down
12 changes: 6 additions & 6 deletions tests/test_iter_counter.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,19 +37,19 @@ def test_individual_kernels():
rng = gu.gen_rng(0)
X = rng.normal(size=(5,5))
state = State(X, cctypes=['normal']*5)
state.transition(N=3, kernels=['alpha', 'rows'])
state.transition(N=3, kernels=['structure_hypers', 'rows'])
check_expected_counts(
state.diagnostics['iterations'],
{'alpha':3, 'rows':3})
state.transition(N=5, kernels=['view_alphas', 'column_params'])
{'structure_hypers':3, 'rows':3})
state.transition(N=5, kernels=['view_structure_hypers', 'column_params'])
check_expected_counts(
state.to_metadata()['diagnostics']['iterations'],
{'alpha':3, 'rows':3, 'view_alphas':5, 'column_params':5})
{'structure_hypers':3, 'rows':3, 'view_structure_hypers':5, 'column_params':5})
state.transition(
N=1, kernels=['view_alphas', 'column_params', 'column_hypers'])
N=1, kernels=['view_structure_hypers', 'column_params', 'column_hypers'])
check_expected_counts(
state.to_metadata()['diagnostics']['iterations'],
{'alpha':3, 'rows':3, 'view_alphas':6, 'column_params':6,
{'structure_hypers':3, 'rows':3, 'view_structure_hypers':6, 'column_params':6,
'column_hypers':1})


Expand Down
2 changes: 1 addition & 1 deletion tests/test_logpdf_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def test_logpdf_score_crash():
assert np.all(logpdf_score_initial < logpdf_likelihood_initial)
# assert np.all(logpdf_likelihood_initial < logpdf_score_initial)
engine.transition(N=100)
engine.transition(kernels=['column_hypers','view_alphas'], N=10)
engine.transition(kernels=['column_hypers','view_structure_hypers'], N=10)
logpdf_likelihood_final = np.asarray(engine.logpdf_likelihood())
logpdf_score_final = np.asarray(engine.logpdf_score())
assert np.all(logpdf_score_final < logpdf_likelihood_final)
Expand Down
Loading