-
Notifications
You must be signed in to change notification settings - Fork 20
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactor Joint into multivariate Delta and Contraction #169
Conversation
Can you confirm that I forget the original set of terms for which |
Yes, via the definition of |
All tests pass, so I'm going to merge this into #157 to keep the git history from getting any uglier. |
* Add Contraction normal form for multilinear ops and normalize interpretation * switch to variadic patterns and remove contraction subclasses * add tests and push down unary reduce * add joint smoke tests * add plated einsum test cases * update negation test cases * add joint-gaussian smoke tests * make all tests pass * nit * add moment matching pattern * nits * add FUNSOR_NORMALIZE flag to interpreter * pull in changes * Update eager patterns in contraction and interpretation resolution order in eager * nit * Add unit propagation * remove dead code * remove __future__ impotrs * skip adjoint tests * Fix tests and lint * ignore annoying test case * lint * fix examples * Refactor optimizer to use Contraction and normalize (#165) * Add a type-aware eager contraction evaluator * Update optimizer to use Contraction * Remove funsor.contract * Make tests pass and resolve recursion issues * revert file removal * Refactor Affine into Contraction (#173) * Remove Affine * remove duplicated test * nit * generalize is_affine * register Funsor reciprocal * fix recursion error and joint test * pull in some upstream changes to cnf * lint * cons-cache result of normalize interpretation * update naive_contract_einsum * Add pass and fail cases to test_reduce_subset * tweak sequential test * Pull in changes from cnf-joint that are not joint-related * Pull in cnf-joint changes to cnf.py that are unrelated to joint * Refactor Joint into multivariate Delta and Contraction (#169) * start removing some joint patterns * Update joint smoke tests * add a rule for permuting joint inputs * make more joint and gaussian tests pass * Add multidelta term * remove Joint, integrator, dead code * remove remaining Joint appearances * remove duplicate test cases * fix smoke tests * lint * remove duplicate moment matching test * make commutativity pattern less of a hack * fix bug in delta * move joint patterns to joint.py * remove redundant pattern * remove Delta entirely in favor of MultiDelta * refactor MultiDelta to have a single log-density tensor * have Tensor.unscaled_sample return a single MultiDelta * revert Tensor.unscaled_sample to Delta * fix moment matching * lint * remove incorrect tensor contraction * another attempt at scaling * removed faulty pattern that was causing gaussian integration tests to fail * fix one bug in minipyro and expose another * fix minipyro.Distribution.expand_inputs * increase tolerance in sequential_sum_product test * fix a couple more tests * fix wrong log pattern * sketch independent? * All integrate tests pass * Add basic align method to Contraction * nit * remove inplace op in reciprocal * fix advanced indexing tests * fix independent * Add normalize patterns for eliminating log(exp), exp(log), neg(neg), and remove cons-caching in normalize interpreter * fix smoke tests * fix adjoint * use normalize when computing adjoints * Squashed commit of the following: commit c8b851615cbf7c3da9526cc42383805fda34464b Author: Eli <[email protected]> Date: Sat Aug 31 19:39:32 2019 -0700 fix fusion condition commit 211379337db7912f70ce09e9ad4903f27b2e1c0e Author: Eli <[email protected]> Date: Sat Aug 31 19:12:42 2019 -0700 fix affine commit 31b2e4dc7680e428bcbb6cb422ebbf3cc8483086 Merge: 7d18851 6fca2ba Author: Eli <[email protected]> Date: Mon Aug 19 14:14:12 2019 -0700 Merge branch 'contraction-normal-form' into separate-unfold-from-cnf commit 7d1885129d262fbb0c6c7d27360987a09c0da78f Merge: 1edfa18 4a6cc0c Author: Eli <[email protected]> Date: Fri Aug 16 21:23:09 2019 -0700 Merge branch 'contraction-normal-form' into separate-unfold-from-cnf commit 1edfa1880ff9361b909d5ee6c906dc69ad344560 Author: Eli <[email protected]> Date: Fri Aug 16 21:17:57 2019 -0700 change semiring commit 354277b3249a8e0d6a93ecc5d2c954bd8f43fd10 Merge: 2038544 2575e97 Author: Eli <[email protected]> Date: Fri Aug 16 18:50:36 2019 -0700 Merge branch 'contraction-normal-form' into separate-unfold-from-cnf commit 2038544ac0314153b8b3a4527b8aa7c33b358366 Author: Eli <[email protected]> Date: Fri Aug 16 18:46:28 2019 -0700 change to new api commit 69c5f639b1c24cde651d310110f8967eb80da12d Merge: 6e33dd4 6cc4a1f Author: Eli <[email protected]> Date: Fri Aug 16 18:09:16 2019 -0700 Merge branch 'contraction-normal-form' into separate-unfold-from-cnf commit 6e33dd42a76d1b0aef6cd58f8d3740974e58f652 Author: Eli <[email protected]> Date: Thu Aug 15 14:42:46 2019 -0700 separate unfolding from normalization * fix merge errors * update moment matching to match master * use pattern matching in cnf * add number-gaussian mixture * simplify some patterns * remove some obsolete patterns * remove obsolete sample patterns * remove obsolete joint pattern * switch MultiDelta to one density per part * remove dead code * tweak persistently failing tests * use nested union types * make moment matching pattern more specific * fixed infinite loop in Integrate evaluation, now seeing different errors in adjoint and einsum * fix regressions in adjoint and einsum * nit * patch test_reduce_subset * xfail test_joint.py::test_reduce_moment_matching_moment for missing patterns, all other tests and examples now pass * normalize by default * move all joint-specific patterns out of cnf and into joint * revert code movement, is there an import missing somewhere? * rename AnyOp to NullOp * remove Delta alias, rename MultiDelta to Delta * remove FUNSOR_NORMALIZE flag and fix error type in cat * fix error in Integrate(Gaussian, Variable) * add missing Independent patterns, vae example and bart example pass smoke tests * fix test_normal_independent * address review comments * remove commented code * fix assertion * remove comment
This PR removes
Joint
in favor of a normalized multivariateDelta
which is closed under addition andContraction
from #156 and internally represents all log-densities with a single tensor. It also deletes lots of duplicate test code made redundant by the removal ofJoint
, while preserving behavior in all existing tests that exercisedJoint
.This is quite a large PR with lots of changes since
Joint
was used throughout the codebase. I'll add a review/merge plan for this series of PRs to the master issue #156.It's almost complete, except for these remaining todos:
moment_matching
so that tests passMultiDelta
Independent
forMultiDelta
,Contraction