Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add multisets #1335

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion examples/ctc.dx
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ instance Ix(FenceAndPosts n) given (n|Ix)
False -> Posts $ unsafe_from_ordinal (intdiv2 o)

instance NonEmpty(FenceAndPosts n) given (n|Ix)
first_ix = unsafe_from_ordinal 0
pass

instance Eq(FenceAndPosts a) given (a|Ix|Eq)
def (==)(x, y) = case x of
Expand Down Expand Up @@ -220,3 +220,11 @@ or the paper.
sum for i:(Fin 3=>Vocab).
ls_to_f $ ctc blank logits i
> 0.5653746


'One major advantage of Dex is its parallelism-preserving autodiff.
The original CTC paper, and most CUDA implementations, used hand-written
reverse-mode derivatives. Dex should be able to
produce an efficient one automatically. Let's check:

-- grad (\logits. ls_to_f $ ctc blank logits labels) logits
78 changes: 65 additions & 13 deletions lib/parser.dx
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,12 @@ def run_parser_partial(s:String, parser:Parser a) -> Maybe a given (a) =

'## Primitive combinators

def p_char(c:Char) -> Parser () = MkParser \h.
def p_char(c:Char) -> Parser Char = MkParser \h.
i = get h.offset
c' = index_list h.input i
assert (c == c')
h.offset := i + 1
c'

def p_eof() ->> Parser () = MkParser \h.
assert $ get h.offset >= list_length h.input
Expand Down Expand Up @@ -99,15 +100,16 @@ def parse_digit() ->> Parser Int = try $ MkParser \h.
def optional(p:Parser a) -> Parser (Maybe a) given (a) =
(MkParser \h. Just (parse h p)) <|> returnP Nothing

def parse_many(parser:Parser a) -> Parser (List a) given (a|Data) = MkParser \h.
yield_state (AsList _ []) \results.
iter \_.
maybeVal = parse h $ optional parser
case maybeVal of
Nothing -> Done ()
Just x ->
push results x
Continue
def parse_many(parser:Parser a) -> Parser (List a) given (a|Data) =
MkParser \h.
yield_state (AsList _ []) \results.
iter \_.
maybeVal = parse h $ optional parser
case maybeVal of
Nothing -> Done ()
Just x ->
push results x
Continue

def parse_some(parser:Parser a) -> Parser (List a) given (a|Data) =
MkParser \h.
Expand All @@ -125,9 +127,9 @@ def parse_int() ->> Parser Int = MkParser \h.
x = parse h $ parse_unsigned_int
case negSign of
Nothing -> x
Just () -> (-1) * x
Just _ -> (-1) * x

def bracketed(l:Parser (), r:Parser (), body:Parser a) -> Parser a given (a) =
def bracketed(l:Parser Char, r:Parser Char, body:Parser a) -> Parser a given (a) =
MkParser \h.
_ = parse h l
ans = parse h body
Expand All @@ -137,8 +139,11 @@ def bracketed(l:Parser (), r:Parser (), body:Parser a) -> Parser a given (a) =
-- Runs `parser` between a literal '(' and ')', delegating to `bracketed`
-- with the two `p_char` parsers as the left and right delimiters.
def parens(parser:Parser a) -> Parser a given (a) =
bracketed (p_char '(') (p_char ')') parser


'## String Utilities

def split(space:Char, s:String) -> List String =
def trailing_spaces(space:Parser (), body:Parser a) -> Parser a given (a) =
def trailing_spaces(space:Parser Char, body:Parser a) -> Parser a given (a) =
MkParser \h.
ans = parse h body
_ = parse h $ parse_many space
Expand All @@ -149,3 +154,50 @@ def split(space:Char, s:String) -> List String =
case run_parser s split_parser of
Just l -> l
Nothing -> AsList _ []

-- Concatenates the entries of `strings` into a single list, inserting one
-- `space` element between each pair of neighbouring entries (nothing is added
-- before the first entry or after the last one).
def join(space:a, strings:List(List a)) -> List a given (a|Data) =
  AsList(n_pieces, pieces) = strings
  separator = AsList(1, [space])
  yield_accum (ListMonoid a) \acc.
    -- Fences are the entries themselves; inner posts are the gaps between them.
    for slot:FenceAndInnerPosts(Fin n_pieces).
      case slot of
        Fence k -> acc += pieces[k]
        InnerPost _ -> acc += separator

-- Returns the index in `text` at which the first occurrence of `word` starts.
-- Yields `Nothing` when `word` is empty, is longer than `text`, or does not
-- occur in `text`.
def find_first(word:m=>a, text:n=>a) -> Maybe n given (n|Ix, m|Ix, a|Eq) =
  -- This implementation has O(nm) complexity, could be O(m + n).
  -- Could maybe be nicer using the AllButLast index set.
  case (size m > size n) || (size m == 0) of
    True -> Nothing
    False ->
      -- Valid start offsets are 0 .. (size n - size m) inclusive, i.e. one more
      -- candidate than the difference of the sizes. Without the `+ 1`, a match
      -- ending on the last element of `text` (including `word == text`) was
      -- never tried.
      num_starts = (unsafe_nat_diff (size n) (size m)) + 1
      bounded_iter num_starts Nothing \i.
        next_substring = for j:m. text[unsafe_from_ordinal (i + ordinal j)]
        case word == next_substring of
          True -> Done $ Just (unsafe_from_ordinal i)
          False -> Continue

-- put in prelude? Nothing particular to strings.
-- Splits the table `xs` at the post `at`: the first list holds the elements
-- before the post, the second list holds the remaining elements.
def split_at(xs:n=>a, at:Post n) -> (List a, List a) given (n|Ix, a) =
  n_before = ordinal at
  n_after = unsafe_nat_diff (size n) n_before
  before = AsList(n_before, for i. xs[unsafe_from_ordinal (ordinal i)])
  -- Elements after the post are read at an offset of `n_before`.
  after = AsList(n_after, for i. xs[unsafe_from_ordinal (n_before + ordinal i)])
  (before, after)

-- Replaces the `old_len` elements of `text` that start at `at_ix` with the
-- contents of `new`, returning the result as a list.
-- Nothing here checks that `ordinal at_ix + old_len` stays within `text`;
-- the split position is built with `unsafe_from_ordinal`.
def unsafe_replace_at(text:n=>a, new:List a, old_len:Nat, at_ix:n) -> List a given (n|Ix, a|Data)=
  -- First cut off everything that follows the span being replaced ...
  (beginning_and_middle, end) = split_at text (unsafe_from_ordinal (ordinal at_ix + old_len))
  AsList(_, beginning_and_middle_tab) = beginning_and_middle
  -- ... then cut off the span itself, keeping only what precedes `at_ix`.
  (beginning, _) = split_at beginning_and_middle_tab (unsafe_from_ordinal (ordinal (left_post at_ix)))
  concat [beginning, new, end]

-- Replaces occurrences of `to_replace` in `old` with `new` by repeatedly
-- rewriting the first match reported by `find_first` until it reports none.
-- NOTE(review): every pass searches the current string from its start, so if
-- `new` itself contains `to_replace` a match is found again after each
-- rewrite and the loop looks like it never terminates. Confirm, and consider
-- resuming the search after the inserted text.
-- NOTE(review): `new_table` is bound below but never used.
def replace_all(old:List a, to_replace:List a, new:List a) -> List a given (a|Eq|Data) =
AsList(_, new_table) = new
AsList(to_replace_length, to_replace_table) = to_replace
yield_state old \cur_str.
-- The `while` body returns True to request another pass, False to stop.
while \.
AsList(_, cur_str_table) = get cur_str
case find_first to_replace_table cur_str_table of
Nothing -> False
Just i ->
cur_str := unsafe_replace_at cur_str_table new to_replace_length i
True
110 changes: 100 additions & 10 deletions lib/prelude.dx
Original file line number Diff line number Diff line change
Expand Up @@ -874,27 +874,27 @@ instance Ix(Maybe a) given (a|Ix)
True -> Nothing

interface NonEmpty(n|Ix)
first_ix : n
pass

instance NonEmpty(())
first_ix = unsafe_from_ordinal(0)
pass

instance NonEmpty(Bool)
first_ix = unsafe_from_ordinal 0
pass

instance NonEmpty((a,b)) given (a|NonEmpty, b|NonEmpty)
first_ix = unsafe_from_ordinal 0
pass

instance NonEmpty(Either(a,b)) given (a|NonEmpty, b|Ix)
first_ix = unsafe_from_ordinal 0
pass

-- The below instance is valid, but causes "multiple candidate dictionaries"
-- errors if both Left and Right are NonEmpty.
-- instance NonEmpty (a|b) given {a b} [Ix a, NonEmpty b]
-- first_ix = unsafe_from_ordinal _ 0
-- pass

instance NonEmpty(Maybe a) given (a|Ix)
first_ix = unsafe_from_ordinal 0
pass

'## Fencepost index sets

Expand Down Expand Up @@ -924,11 +924,14 @@ def right_fence(p:Post n) -> Maybe n given (n|Ix) =
then Nothing
else Just $ unsafe_from_ordinal ix

-- The element of the index set `n` with ordinal 0.
-- Requires `n|NonEmpty`; presumably that constraint is what makes ordinal 0 valid.
def first_ix() ->> n given (n|NonEmpty) =
unsafe_from_ordinal(0)

-- The element of the index set `n` with ordinal `size n - 1`.
-- The subtraction is done on the integer form of `size n` and converted back
-- with `unsafe_i_to_n`; presumably `n|NonEmpty` is what keeps it non-negative.
def last_ix() ->> n given (n|NonEmpty) =
unsafe_from_ordinal(unsafe_i_to_n(n_to_i(size n) - 1))

instance NonEmpty(Post n) given (n|Ix)
first_ix = unsafe_from_ordinal(n=Post n, 0)
pass

def scan(
init:a,
Expand Down Expand Up @@ -1704,7 +1707,7 @@ def from_ordinal(i:Nat) -> n given (n|Ix) =
False -> error $ from_ordinal_error(i, size n)

-- TODO: should this be called `from_ordinal`?
def to_ix(i:Nat) -> Maybe n given (n|Ix) =
def to_ix(i:Nat) -> Maybe n given (n|Ix) =
case i < size n of
True -> Just $ unsafe_from_ordinal i
False -> Nothing
Expand Down Expand Up @@ -2266,6 +2269,93 @@ instance Subset(b, Either(a,b)) given (a|Data, b|Data)
Left( x) -> error "Can't project Left branch to Right branch"
Right(x) -> x

-- Lifts a `Subset a b` relation elementwise to tables sharing the index set `n`.
-- Injection maps every element; projection succeeds only if every element
-- projects, so a single failing element rejects the whole table.
instance Subset(n=>a, n=>b) given (n|Ix, a|Data, b|Data) (Subset a b)
  def inject'(xs) = for i. inject xs[i]
  def project'(xs') =
    attempts = for i. project xs'[i]
    if any_sat(is_nothing, attempts)
      then Nothing
      else Just $ each attempts from_just
  def unsafe_project'(xs') =
    attempts = for i. project xs'[i]
    if any_sat(is_nothing, attempts)
      then error "Couldn't project table."
      else each attempts from_just

-- add instance for subset n=>a m=>a given subset n m

-- Lifts a `Subset a b` relation to lists: the list is unpacked into its length
-- and table, the table is injected/projected as a whole, and the result is
-- repacked with the same length.
instance Subset(List a, List b) given (a|Data, b|Data) (Subset a b)
  def inject'(xs') =
    AsList(len, elems) = xs'
    AsList(len, inject elems)
  def project'(l) =
    AsList(len, elems) = l
    case project elems of
      Nothing -> Nothing
      Just narrowed -> Just AsList(len, narrowed)
  def unsafe_project'(l) =
    AsList(len, elems) = l
    case project elems of
      Nothing -> error "Couldn't project list."
      Just narrowed -> AsList(len, narrowed)

'### All but Last Index set
All the indices of the original index set except the last `n` (so `AllButLast 1 a` drops just the last one).

-- Wrapper around an index of `a`, tagged with the number `n` of trailing
-- indices that the wrapped index set leaves out.
struct AllButLast(n:Nat, a|Ix) =
-- The underlying index in `a`.
val : a

-- Index-set instance: `AllButLast n a` has `(size a) -| n` elements and shares
-- ordinals with the wrapped index of `a`.
-- (`-|` is presumably the prelude's saturating Nat subtraction -- confirm.)
instance Ix(AllButLast n a) given (n:Nat, a|Ix|Data)
def size'() = (size a) -| n
def ordinal(i) = ordinal i.val
def unsafe_from_ordinal(o) = AllButLast $ unsafe_from_ordinal o

-- `AllButLast n a` as a subset of `a`: injecting unwraps the index, projecting
-- wraps it when its ordinal is below `(size a) -| n` and fails otherwise.
instance Subset(AllButLast n a, a) given (n:Nat, a|Ix)
def inject'(x) = x.val
def project'(x) = case (ordinal x) < ((size a) -| n) of
True -> Just (AllButLast x)
False -> Nothing
-- No bounds check: the index is wrapped as-is.
def unsafe_project'(x) = (AllButLast x)

-- Two `AllButLast` indices are equal exactly when their wrapped indices are.
instance Eq(AllButLast n a) given (n:Nat, a|Eq|Ix)
def (==)(x, y) = x.val == y.val

-- Returns the index whose ordinal is one greater than that of `i`.
-- NOTE(review): despite the `unsafe_` prefix this calls `from_ordinal`, not
-- `unsafe_from_ordinal`; presumably it raises an error when `i` is the last
-- index rather than producing an invalid one -- confirm this is intended.
def unsafe_increment(i:n) -> n given (n|Ix) = from_ordinal (ordinal i + 1)
-- The successor in `n` of an index drawn from `AllButLast 1 n`, obtained by
-- incrementing the wrapped index.
def next(i: AllButLast 1 n) -> n given (n|Ix) = unsafe_increment i.val
-- Collects the elements of `tab` at ordinals `ordinal i`, `ordinal i + 1`, ...
-- into a list, one element per index of `Fin m`.
def get_next_m(tab:n=>a, i:AllButLast m n) -> List a given (n|Ix, m:Nat, a) =
-- The table built below is indexed by `Fin m`, so the list has m elements
-- (not n - m), but the return type can't spell that yet.
to_list $ for j:(Fin m). tab[unsafe_from_ordinal (ordinal i + ordinal j)]

'### Fence and Inner Posts
A custom datatype and index set
that interleaves the elements of a table with another set
of values representing all the spaces in between those elements,
not including the 2 ends.

-- An index that is either an element of `n` (`Fence`) or a gap between two
-- neighbouring elements (`InnerPost`). Gaps are indexed by `AllButLast 1 n`,
-- i.e. there is one fewer gap than there are elements.
data FenceAndInnerPosts(n|Ix) =
Fence(n)
InnerPost(AllButLast 1 n)

-- Index-set instance that interleaves fences and inner posts: fences take the
-- even ordinals (2 * j) and inner posts the odd ordinals (2 * j + 1).
instance Ix(FenceAndInnerPosts n) given (n|Ix)
-- NOTE(review): intended as (2 * size n) -| 1; confirm the precedence of `-|`
-- relative to `*` gives that grouping.
def size'() = 2 * size n -| 1
def ordinal(i) = case i of
Fence j -> 2 * ordinal j
InnerPost j -> 2 * ordinal j + 1
-- Inverse of `ordinal`: parity picks the constructor, `intdiv2 o` is passed on
-- as the ordinal of the wrapped index.
def unsafe_from_ordinal(o) =
case is_odd o of
False -> Fence $ unsafe_from_ordinal (intdiv2 o)
True -> InnerPost $ unsafe_from_ordinal (intdiv2 o)

-- Two indices are equal when they use the same constructor and their payloads
-- are equal; a `Fence` never equals an `InnerPost`.
instance Eq(FenceAndInnerPosts a) given (a|Ix|Eq)
def (==)(x, y) = case x of
Fence x -> case y of
Fence y -> x == y
InnerPost y -> False
InnerPost x -> case y of
Fence y -> False
InnerPost y -> x == y


'### Index set for tables

def int_to_reversed_digits(k:Nat) -> a=>b given (a|Ix, b|Ix) =
Expand All @@ -2291,7 +2381,7 @@ instance Ix(a=>b) given (a|Ix, b|Ix)
def unsafe_from_ordinal(i) = int_to_reversed_digits i

instance NonEmpty(a=>b) given (a|Ix, b|NonEmpty)
first_ix = unsafe_from_ordinal 0
pass

'### Stack

Expand Down
Loading
Loading