Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement array patterns #1912

Merged
merged 1 commit into from
May 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 38 additions & 5 deletions core/src/parser/grammar.lalrpop
Original file line number Diff line number Diff line change
Expand Up @@ -562,6 +562,7 @@ PatternF<F>: Pattern = {
#[inline]
PatternDataF<F>: PatternData = {
RecordPattern => PatternData::Record(<>),
ArrayPattern => PatternData::Array(<>),
ConstantPattern => PatternData::Constant(<>),
EnumPatternF<F> => PatternData::Enum(<>),
Ident => PatternData::Any(<>),
Expand Down Expand Up @@ -597,15 +598,15 @@ RecordPattern: RecordPattern = {
let tail = match last {
Some(LastPattern::Normal(m)) => {
field_pats.push(*m);
RecordPatternTail::Empty
TailPattern::Empty
},
Some(LastPattern::Ellipsis(Some(captured))) => {
RecordPatternTail::Capture(captured)
TailPattern::Capture(captured)
}
Some(LastPattern::Ellipsis(None)) => {
RecordPatternTail::Open
TailPattern::Open
}
None => RecordPatternTail::Empty,
None => TailPattern::Empty,
};

let pattern = RecordPattern {
Expand All @@ -619,6 +620,32 @@ RecordPattern: RecordPattern = {
},
};

ArrayPattern: ArrayPattern = {
<start: @L> "[" <mut patterns: (<Pattern> ",")*> <last: LastElemPat?> "]" <end: @R> => {
let tail = match last {
Some(LastPattern::Normal(m)) => {
patterns.push(*m);
TailPattern::Empty
},
Some(LastPattern::Ellipsis(Some(captured))) => {
TailPattern::Capture(captured)
}
Some(LastPattern::Ellipsis(None)) => {
TailPattern::Open
}
None => TailPattern::Empty,
};

let pattern = ArrayPattern{
patterns,
tail,
pos: mk_pos(src_id, start, end)
};

pattern
},
};

EnumPatternF<F>: EnumPattern = {
<start: @L> <tag: EnumTag> <end: @R> => EnumPattern {
tag,
Expand Down Expand Up @@ -663,12 +690,18 @@ FieldPattern: FieldPattern = {
},
};

// Last field of a pattern
// Last field pattern of a record pattern
LastFieldPat: LastPattern<FieldPattern> = {
FieldPattern => LastPattern::Normal(Box::new(<>)),
".." <Ident?> => LastPattern::Ellipsis(<>),
};

// Last pattern of an array pattern
LastElemPat: LastPattern<Pattern> = {
Pattern => LastPattern::Normal(Box::new(<>)),
".." <Ident?> => LastPattern::Ellipsis(<>),
}

// A default annotation in a pattern.
DefaultAnnot: RichTerm = "?" <t: Term> => t;

Expand Down
37 changes: 34 additions & 3 deletions core/src/pretty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -588,6 +588,7 @@ where
PatternData::Wildcard => allocator.text("_"),
PatternData::Any(id) => allocator.as_string(id),
PatternData::Record(rp) => rp.pretty(allocator),
PatternData::Array(ap) => ap.pretty(allocator),
PatternData::Enum(evp) => evp.pretty(allocator),
PatternData::Constant(cp) => cp.pretty(allocator),
}
Expand Down Expand Up @@ -688,9 +689,9 @@ where
allocator.line()
),
match tail {
RecordPatternTail::Empty => allocator.nil(),
RecordPatternTail::Open => docs![allocator, allocator.line(), ".."],
RecordPatternTail::Capture(id) =>
TailPattern::Empty => allocator.nil(),
TailPattern::Open => docs![allocator, allocator.line(), ".."],
TailPattern::Capture(id) =>
docs![allocator, allocator.line(), "..", id.ident().to_string()],
},
]
Expand All @@ -701,6 +702,36 @@ where
}
}

impl<'a, D, A> Pretty<'a, D, A> for &ArrayPattern
where
D: NickelAllocatorExt<'a, A>,
D::Doc: Clone,
A: Clone + 'a,
{
fn pretty(self, allocator: &'a D) -> DocBuilder<'a, D, A> {
docs![
allocator,
allocator.intersperse(
self.patterns.iter(),
docs![allocator, ",", allocator.line()],
),
if !self.patterns.is_empty() && self.is_open() {
docs![allocator, ",", allocator.line()]
} else {
allocator.nil()
},
match self.tail {
TailPattern::Empty => allocator.nil(),
TailPattern::Open => allocator.text(".."),
TailPattern::Capture(id) => docs![allocator, "..", id.ident().to_string()],
},
]
.nest(2)
.brackets()
.group()
}
}

impl<'a, D, A> Pretty<'a, D, A> for &RichTerm
where
D: NickelAllocatorExt<'a, A>,
Expand Down
177 changes: 171 additions & 6 deletions core/src/term/pattern/compile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ use super::*;
use crate::{
mk_app,
term::{
make, record::FieldMetadata, BinaryOp, MatchBranch, MatchData, RecordExtKind, RecordOpKind,
RichTerm, Term, UnaryOp,
make, record::FieldMetadata, BinaryOp, MatchBranch, MatchData, NAryOp, RecordExtKind,
RecordOpKind, RichTerm, Term, UnaryOp,
},
};

Expand Down Expand Up @@ -258,6 +258,7 @@ impl CompilePart for PatternData {
insert_binding(*id, value_id, bindings_id)
}
PatternData::Record(pat) => pat.compile_part(value_id, bindings_id),
PatternData::Array(pat) => pat.compile_part(value_id, bindings_id),
PatternData::Enum(pat) => pat.compile_part(value_id, bindings_id),
PatternData::Constant(pat) => pat.compile_part(value_id, bindings_id),
}
Expand Down Expand Up @@ -312,7 +313,7 @@ impl CompilePart for RecordPattern {
//
// We don't have tuples, and to avoid adding an indirection (by storing the current state
// as `{rest, bindings}` where bindings itself is a record), we store this rest alongside
// the bindings in a special field which is a freshly generated indentifier. This is an
// the bindings in a special field which is a freshly generated identifier. This is an
// implementation detail which isn't very hard to change, should we have to.
//
// if %typeof% value_id == 'Record
Expand Down Expand Up @@ -519,7 +520,7 @@ impl CompilePart for RecordPattern {
// null
// else
// %record_remove% "<REST>" final_bindings_id
RecordPatternTail::Empty => make::if_then_else(
TailPattern::Empty => make::if_then_else(
make::op1(
UnaryOp::BoolNot(),
make::op2(
Expand All @@ -539,7 +540,7 @@ impl CompilePart for RecordPattern {
// final_bindings_id
// (%static_access% <REST_FIELD> final_bindings_id)
// )
RecordPatternTail::Capture(rest) => make::op2(
TailPattern::Capture(rest) => make::op2(
BinaryOp::DynRemove(RecordOpKind::ConsiderAllFields),
Term::Str(rest_field.into()),
mk_app!(
Expand All @@ -555,7 +556,7 @@ impl CompilePart for RecordPattern {
),
),
// %record_remove% "<REST>" final_bindings_id
RecordPatternTail::Open => bindings_without_rest,
TailPattern::Open => bindings_without_rest,
};

// the last `final_bindings_id != null` guard:
Expand All @@ -579,6 +580,170 @@ impl CompilePart for RecordPattern {
}
}

impl CompilePart for ArrayPattern {
// Compilation of an array pattern.
//
// let value_len = %array_length% value_id in
//
// <if self.is_open()>
// if %typeof% value_id == 'Array && value_len >= <self.patterns.len()>
// <else>
// if %typeof% value_id == 'Array && value_len == <self.patterns.len()>
// <end if>
//
// let final_bindings_id =
// <fold (idx, elem_pat) in 0..self.patterns.len()
// - cont is the accumulator
// - initial accumulator is `bindings_id`
// >
//
// let local_bindings_id = cont in
// if local_bindings_id == null then
// null
// else
// let local_value_id = %array_access% <idx> value_id in
// <elem_pat.compile_part(local_value_id, local_bindings_id)>
//
// <end fold>
// in
//
// if final_bindings_id == null then
// null
// else
// <if self.tail is capture(rest)>
// %record_insert%
// <rest>
// final_bindings_id
// (%array_slice% <self.patterns.len()> value_len value_id)
// <else>
// final_bindings_id
// <end if>
// else
// null
fn compile_part(&self, value_id: LocIdent, bindings_id: LocIdent) -> RichTerm {
let value_len_id = LocIdent::fresh();
let pats_len = Term::Num(self.patterns.len().into());

// <fold (idx) in 0..self.patterns.len()
// - cont is the accumulator
// - initial accumulator is `bindings_id`
// >
//
// let local_bindings_id = cont in
// if local_bindings_id == null then
// null
// else
// let local_value_id = %array_access% <idx> value_id in
// <self.patterns[idx].compile_part(local_value_id, local_bindings_id)>
//
// <end fold>
let fold_block: RichTerm = self.patterns.iter().enumerate().fold(
Term::Var(bindings_id).into(),
|cont, (idx, elem_pat)| {
let local_bindings_id = LocIdent::fresh();
let local_value_id = LocIdent::fresh();

// <self.patterns[idx].compile_part(local_value_id, local_bindings_id)>
let updated_bindings_let = elem_pat.compile_part(local_value_id, local_bindings_id);

// %array_access% idx value_id
let extracted_value = make::op2(
BinaryOp::ArrayElemAt(),
Term::Var(value_id),
Term::Num(idx.into()),
);

// let local_value_id = <extracted_value> in <updated_bindings_let>
let inner_else_block =
make::let_in(local_value_id, extracted_value, updated_bindings_let);

// The innermost if:
//
// if local_bindings_id == null then
// null
// else
// <inner_else_block>
let inner_if = make::if_then_else(
make::op2(BinaryOp::Eq(), Term::Var(local_bindings_id), Term::Null),
Term::Null,
inner_else_block,
);

// let local_bindings_id = cont in <inner_if>
make::let_in(local_bindings_id, cont, inner_if)
},
);

// %typeof% value_id == 'Array
let is_array: RichTerm = make::op2(
BinaryOp::Eq(),
make::op1(UnaryOp::Typeof(), Term::Var(value_id)),
Term::Enum("Array".into()),
);

let comp_op = if self.is_open() {
BinaryOp::GreaterOrEq()
} else {
BinaryOp::Eq()
};

// <is_array> && value_len <comp_op> <self.patterns.len()>
let outer_check = mk_app!(
make::op1(UnaryOp::BoolAnd(), is_array),
make::op2(comp_op, Term::Var(value_len_id), pats_len.clone(),)
);

let final_bindings_id = LocIdent::fresh();

// the else block which depends on the tail of the record pattern
let tail_block = match self.tail {
// final_bindings_id
TailPattern::Empty | TailPattern::Open => make::var(final_bindings_id),
// %record_insert%
// <rest>
// final_bindings_id
// (%array_slice% <self.patterns.len()> value_len value_id)
TailPattern::Capture(rest) => mk_app!(
make::op2(
record_insert(),
Term::Str(rest.label().into()),
Term::Var(final_bindings_id),
),
make::opn(
NAryOp::ArraySlice(),
vec![pats_len, Term::Var(value_len_id), Term::Var(value_id)]
)
),
};

// the last `final_bindings_id != null` guard:
//
// if final_bindings_id == null then
// null
// else
// <tail_block>
let guard_tail_block = make::if_then_else(
make::op2(BinaryOp::Eq(), Term::Var(final_bindings_id), Term::Null),
Term::Null,
tail_block,
);

// The let enclosing the fold block and the let binding `final_bindings_id`:
// let final_bindings_id = <fold_block> in <tail_block>
let outer_let = make::let_in(final_bindings_id, fold_block, guard_tail_block);

// if <outer_check> then <outer_let> else null
let outer_if = make::if_then_else(outer_check, outer_let, Term::Null);

// finally, we need to bind `value_len_id` to the length of the array
make::let_in(
value_len_id,
make::op1(UnaryOp::ArrayLength(), Term::Var(value_id)),
outer_if,
)
}
}

impl CompilePart for EnumPattern {
fn compile_part(&self, value_id: LocIdent, bindings_id: LocIdent) -> RichTerm {
// %enum_get_tag% value_id == '<self.tag>
Expand Down
Loading
Loading