Commit

Even better shortcut
dcherian committed Mar 27, 2024
1 parent 9edd3cf commit 6a2c386
Showing 1 changed file with 31 additions and 18 deletions.
flox/core.py: 49 changes (31 additions & 18 deletions)
@@ -363,42 +363,55 @@ def invert(x) -> tuple[np.ndarray, ...]:
logger.info("find_group_cohorts: cohorts is preferred, chunking is perfect.")
return "cohorts", chunks_cohorts

# Containment = |Q & S| / |Q|
# We'll use containment to measure degree of overlap between labels.
# Containment C = |Q & S| / |Q|
# - |X| is the cardinality of set X
# - Q is the query set being tested
# - S is the existing set
# We'll use containment to measure degree of overlap between labels. The bitmask
# matrix allows us to calculate this pretty efficiently.
asfloat = bitmask.astype(float)
# Note: While A.T @ A is a symmetric matrix, the division by chunks_per_label
# makes it non-symmetric.
# Note: We haven't normalized this yet, since we will first check sparsity.
# We don't need the actual containment value for this check.
dotproduct = asfloat.T @ asfloat

# The containment matrix is a measure of how much the labels overlap
# with each other. We treat the sparsity = (nnz/size) as a summary measure of the net overlap.
# The bitmask matrix S allows us to calculate this pretty efficiently using a dot product.
# S.T @ S / chunks_per_label
#
# We treat the sparsity(C) = (nnz/size) as a summary measure of the net overlap.
# 1. For high enough sparsity, there is a lot of overlap and we should use "map-reduce".
# 2. When labels are uniformly distributed amongst all chunks
# (and number of labels < chunk size), sparsity is 1.
# 3. Time grouping cohorts (e.g. dayofyear) appear as lines in this matrix.
# 4. When there are no overlaps at all between labels, containment is a block diagonal matrix
# (approximately).
MAX_SPARSITY_FOR_COHORTS = 0.6 # arbitrary
sparsity = dotproduct.nnz / math.prod(dotproduct.shape)
#
# However computing S.T @ S can still be the slowest step, especially if S
# is not particularly sparse. Empirically the sparsity( S.T @ S ) > min(1, 2 x sparsity(S)).
# So we use sparsity(S) as a shortcut.
MAX_SPARSITY_FOR_COHORTS = 0.4 # arbitrary
sparsity = bitmask.nnz / math.prod(bitmask.shape)
preferred_method: Literal["map-reduce"] | Literal["cohorts"]
logger.debug(
"sparsity of bitmask is {}, threshold is {}".format( # noqa
sparsity, MAX_SPARSITY_FOR_COHORTS
)
)
if sparsity > MAX_SPARSITY_FOR_COHORTS:
logger.info("sparsity is {}".format(sparsity)) # noqa
if not merge:
logger.info("find_group_cohorts: merge=False, choosing 'map-reduce'")
logger.info(
"find_group_cohorts: bitmask sparsity={}, merge=False, choosing 'map-reduce'".format( # noqa
sparsity
)
)
return "map-reduce", {}
preferred_method = "map-reduce"
else:
preferred_method = "cohorts"

# Now normalize the dotproduct to get containment
containment = csr_array(dotproduct / chunks_per_label)
# Note: While A.T @ A is a symmetric matrix, the division by chunks_per_label
# makes it non-symmetric.
asfloat = bitmask.astype(float)
containment = csr_array(asfloat.T @ asfloat / chunks_per_label)

logger.debug(
"sparsity of containment matrix is {}".format( # noqa
containment.nnz / math.prod(containment.shape)
)
)
# Use a threshold to force some merging. We do not use the filtered
# containment matrix for estimating "sparsity" because it is a bit
# hard to reason about.
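
A minimal, self-contained sketch of the strategy the new comments describe is shown below. It is illustrative only: the toy chunk/label layout, the bitmask construction, and the direct use of scipy.sparse.csr_array are assumptions made for this example rather than flox's actual bitmask-building code; the names bitmask, chunks_per_label, and MAX_SPARSITY_FOR_COHORTS simply mirror the identifiers in the diff.

import math

import numpy as np
from scipy.sparse import csr_array

# Toy setup (hypothetical): 12 chunks along the core dimension and a
# "dayofyear"-style label that cycles through 4 groups, so each label
# recurs in every fourth chunk.
nchunks, nlabels = 12, 4
labels_per_chunk = [np.array([i % nlabels]) for i in range(nchunks)]

# Chunk x label bitmask S: S[i, j] is True when label j occurs in chunk i.
rows = np.concatenate([np.full(len(lab), i) for i, lab in enumerate(labels_per_chunk)])
cols = np.concatenate(labels_per_chunk)
bitmask = csr_array(
    (np.ones(len(rows), dtype=bool), (rows, cols)), shape=(nchunks, nlabels)
)

# The shortcut added by this commit: the sparsity of S itself decides between
# "map-reduce" and "cohorts" before any dot product is formed.
MAX_SPARSITY_FOR_COHORTS = 0.4  # same arbitrary threshold as the new code
sparsity = bitmask.nnz / math.prod(bitmask.shape)
preferred_method = "map-reduce" if sparsity > MAX_SPARSITY_FOR_COHORTS else "cohorts"

# Containment C = |Q & S| / |Q| for all label pairs at once: S.T @ S counts the
# chunks where two labels co-occur, normalized by the chunks each label occupies.
chunks_per_label = bitmask.sum(axis=0)
asfloat = bitmask.astype(float)
containment = csr_array(asfloat.T @ asfloat / chunks_per_label)

print(preferred_method)       # "cohorts" for this periodic pattern (sparsity = 0.25)
print(containment.toarray())  # identity here: labels never share a chunk

The ordering is the point of the commit: the sparsity of S is available from the bitmask's own nnz before any matrix product exists, so the map-reduce versus cohorts decision happens first, and the comparatively expensive S.T @ S plus its normalization into the containment matrix is deferred until after that check (and skipped entirely in the high-sparsity, merge=False case that returns early).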