Skip to content

Commit

Permalink
Add multisets
Browse files Browse the repository at this point in the history
  • Loading branch information
duvenaud committed Aug 22, 2023
1 parent 3cbde4c commit 4302fdd
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 8 deletions.
70 changes: 62 additions & 8 deletions lib/set.dx
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
'# Sets and Set-Indexed Arrays
'# Sets, Set-Indexed Arrays, and Multisets

import sort

Expand All @@ -21,25 +21,42 @@ def all_except_last(xs:n=>a) -> List a given (n|Ix, a) =
allButLast = for i:shortSize. xs[unsafe_from_ordinal (ordinal i)]
AsList _ allButLast

def merge_unique_sorted_lists(xlist:List a, ylist:List a) -> List a given (a|Eq) =
def all_except_first(xs:n=>a) -> List a given (n|Ix, a) =
  -- Returns the elements of xs after the first one, packaged as a List.
  -- The result has (size n -| 1) entries; entry i is read from ordinal
  -- position (i + 1) of xs, so the first element is skipped.
  tail_ix = Fin (size n -| 1)
  tail = for i:tail_ix. xs[unsafe_from_ordinal (1 + ordinal i)]
  AsList _ tail

def merge_unique_sorted_lists_with_aux(
combine_side_info : (side, side)->side,
xlist:List (a, side),
ylist:List (a, side)) -> List (a, side) given (a|Eq, side|Data) =
-- Concatenates two lists of (element, side info) pairs, merging the pair at
-- the boundary when the last element of xlist equals the first element of ylist.
-- This function is associative, for use in a monoidal reduction
-- (presumably this also needs combine_side_info to be associative; TODO confirm).
-- Assumes all xs are <= all ys.
-- The element at the end of xs might equal the
-- element at the beginning of ys. If so, this
-- function keeps a single copy of that element when concatenating the lists
-- and combines the two side values with combine_side_info.
AsList(nx, xs) = xlist
-- NOTE(review): the two bindings of ys below are identical apart from spacing;
-- they look like the before/after lines of a diff shown together. Confirm only one remains.
AsList(_ , ys) = ylist
AsList(_, ys) = ylist
case last xs of
-- xs has no last element: the result is ylist unchanged.
Nothing -> ylist
Just last_x -> case first ys of
-- ys has no first element: the result is xlist unchanged.
Nothing -> xlist
-- NOTE(review): two `Just first_y` alternatives follow. The first compares whole
-- values and drops the duplicate with all_except_last; the second compares only
-- the element part and combines side info. They read like the old and new
-- versions of the same branch. Confirm which one is meant to remain.
Just first_y -> case last_x == first_y of
False -> concat [xlist, ylist]
True -> concat [all_except_last xs, ylist]
Just first_y ->
-- Split both boundary pairs into element and side info.
(last_x_inner, last_x_side) = last_x
(first_y_inner, first_y_side) = first_y
case last_x_inner == first_y_inner of
-- Boundary elements differ: plain concatenation.
False -> xlist <> ylist
True ->
-- Boundary elements match: replace both boundary pairs by one pair whose
-- side info is the combination of the two.
combined = AsList 1 [(first_y_inner, combine_side_info last_x_side first_y_side)]
all_except_last xs <> combined <> all_except_first ys

def remove_duplicates_from_sorted(xs:n=>a) -> List a given (n|Ix, a|Eq) =
-- Wraps every element of xs in a one-element list and reduces those lists with
-- a merge that drops an element repeated across the boundary of two lists.
-- NOTE(review): the next two lines and the lines after the "Special case" comment
-- both build singleton lists and reduce them; they read like the old and new
-- versions of the body shown together. Confirm which one is meant to remain.
xlists = for i:n. (AsList 1 [xs[i]])
reduce (AsList 0 []) merge_unique_sorted_lists xlists
-- Special case for ordinary sets, which don't have any side information.
-- Each element is paired with () as placeholder side info.
xlist = for i. AsList(_, [(xs[i], ())])
-- Combining two placeholder side values yields () again.
ignore = \a b. ()
AsList(_, set_with_aux) =
reduce AsList(0, []) (\x y. merge_unique_sorted_lists_with_aux(ignore, x, y)) xlist
-- Strip the placeholder side info, keeping only the elements.
AsList(_, for i. fst set_with_aux[i])


'## Sets
Expand Down Expand Up @@ -119,3 +136,40 @@ instance Eq(Element set) given (a|Ord, set:Set a)
instance Ord(Element set) given (a|Ord, set:Set a)
  -- Two set elements are compared through their ordinals.
  def (<)(lhs, rhs) = ordinal lhs < ordinal rhs
  def (>)(lhs, rhs) = ordinal lhs > ordinal rhs


'## Multisets

def remove_duplicates_from_sorted_with_counts(xs:n=>(a, Nat)) ->
    List (a, Nat) given (n|Ix, a|Eq|Data) =
  -- Takes a table of (element, count) pairs and reduces it to a List in which
  -- neighbouring pairs with equal elements are merged into one pair whose
  -- count is the sum of the merged counts.
  -- Each pair starts out as its own one-element list; the reduction then
  -- joins those lists with merge_unique_sorted_lists_with_aux.
  singletons = for i. AsList(1, [xs[i]])
  reduce AsList(_, []) (\l r. merge_unique_sorted_lists_with_aux(\c d. c + d, l, r)) singletons

data Multiset(a|Ord) =
-- A table of n (element, count) pairs, where each count is a Nat.
-- The table is guaranteed to be in sorted order only
-- as long as no one else uses this constructor directly.
-- Build values with the "to_multiset" function below instead.
UnsafeAsMultiset(n:Nat, elements:(Fin n => (a, Nat)))

def to_multiset(xs:n=>a) -> Multiset a given (n|Ix, a|Ord) =
  -- Builds a Multiset from an arbitrary table:
  --   1. sort the elements,
  --   2. attach a count of 1 to each one,
  --   3. merge equal neighbours with remove_duplicates_from_sorted_with_counts.
  ordered = sort xs
  tagged = for i. (ordered[i], 1)
  AsList(num_distinct, counted) = remove_duplicates_from_sorted_with_counts tagged
  UnsafeAsMultiset num_distinct counted

instance Eq(Multiset a) given (a|Ord)
  -- Two multisets are equal when their (element, count) tables,
  -- viewed as Lists, are equal.
  def (==)(lhs, rhs) =
    UnsafeAsMultiset(_, lhs_pairs) = lhs
    UnsafeAsMultiset(_, rhs_pairs) = rhs
    (AsList _ lhs_pairs) == (AsList _ rhs_pairs)

def multiset_add(sx:Multiset a, sy:Multiset a) -> Multiset a given (a|Ord) =
  -- Adds two multisets: the result holds every element of either input.
  -- The two (element, count) tables are merged with merge_sorted_tables, then
  -- remove_duplicates_from_sorted_with_counts collapses the merged table.
  -- NOTE(review): this relies on merge_sorted_tables placing pairs with the
  -- same element next to each other; confirm against the Ord instance for pairs.
  -- The sizes of the input tables are never used, so they are not bound to names.
  UnsafeAsMultiset(_, xs) = sx
  UnsafeAsMultiset(_, ys) = sy
  combined = merge_sorted_tables xs ys
  AsList(_, unique_xs) = remove_duplicates_from_sorted_with_counts combined
  UnsafeAsMultiset(_, unique_xs)

instance Add(Multiset a) given (a|Ord)
  -- (+) delegates to multiset_add; zero is the multiset built from an empty table.
  def (+)(lhs, rhs) = multiset_add lhs rhs
  zero = to_multiset []
48 changes: 48 additions & 0 deletions tests/set-tests.dx
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,51 @@ setix : Person = from_just $ member "Bob" names2
-- Look up "Charlie" in names2 and print the resulting element index.
setix2 : Person = from_just $ member "Charlie" names2
:p setix2
> Element(2)


'#### Multiset tests

-- check order invariance.
:p (to_multiset ["Bob", "Alice", "Bob", "Charlie"]) == (to_multiset ["Charlie", "Bob", "Alice", "Bob"])
> True

-- check counts matter for equality.
:p (to_multiset ["Bob", "Alice", "Bob", "Charlie"]) == (to_multiset ["Charlie", "Bob", "Alice"])
> False

multiset1 = to_multiset ["Xeno", "Alice", "Bob", "Bob"]
multiset2 = to_multiset ["Bob", "Xeno", "Charlie", "Xeno", "Alice"]

:p multiset1 == multiset2
> False

-- addition sums the counts of matching elements.
:p multiset1 + multiset2
> (UnsafeAsMultiset 4 [("Alice", 2), ("Bob", 3), ("Charlie", 1), ("Xeno", 3)])

-- NOTE(review): multiset_intersect is not defined in the visible part of lib/set.dx.
-- Its result is compared against a Multiset further down, so it must print as
-- UnsafeAsMultiset rather than UnsafeAsSet. The expected value assumes the
-- intersection keeps the smaller count of each shared element; confirm.
:p multiset_intersect multiset1 multiset2
> (UnsafeAsMultiset 3 [("Alice", 1), ("Bob", 1), ("Xeno", 1)])

-- adding a non-empty multiset to itself changes the counts.
:p multiset1 == (multiset1 + multiset1)
> False

-- intersecting a multiset with itself leaves it unchanged.
:p multiset1 == (multiset_intersect multiset1 multiset1)
> True

'#### Empty multiset tests

-- The empty table is annotated with its full type (Fin 0)=>String,
-- presumably because an empty literal alone does not determine the element type.
emptymultiset = to_multiset ([]::(Fin 0)=>String)

:p emptymultiset == emptymultiset
> True

-- adding two empty multisets gives the empty multiset.
:p emptymultiset == (emptymultiset + emptymultiset)
> True

-- NOTE(review): multiset_intersect is not defined in the visible part of lib/set.dx; confirm it exists.
:p emptymultiset == (multiset_intersect emptymultiset emptymultiset)
> True

-- the empty multiset is a right identity for addition.
:p multiset1 == (multiset1 + emptymultiset)
> True

-- intersecting with the empty multiset gives the empty multiset.
:p emptymultiset == (multiset_intersect multiset1 emptymultiset)
> True

0 comments on commit 4302fdd

Please sign in to comment.