Skip to content

Commit

Permalink
Implement array patterns
Browse files Browse the repository at this point in the history
This commit adds a new calss of patterns: array patterns. This is a
natural extension, both syntactically and semantically, of existing data
structure patterns in Nickel (in particular of records). Similarly,
arrays pattern can also capture the rest of the pattern (the tail of the
array that hasn't been matched yet) and bind it to a variable.
  • Loading branch information
yannham committed May 13, 2024
1 parent d2fbf28 commit a7eacc3
Show file tree
Hide file tree
Showing 12 changed files with 498 additions and 34 deletions.
43 changes: 38 additions & 5 deletions core/src/parser/grammar.lalrpop
Original file line number Diff line number Diff line change
Expand Up @@ -562,6 +562,7 @@ PatternF<F>: Pattern = {
#[inline]
PatternDataF<F>: PatternData = {
RecordPattern => PatternData::Record(<>),
ArrayPattern => PatternData::Array(<>),
ConstantPattern => PatternData::Constant(<>),
EnumPatternF<F> => PatternData::Enum(<>),
Ident => PatternData::Any(<>),
Expand Down Expand Up @@ -597,15 +598,15 @@ RecordPattern: RecordPattern = {
let tail = match last {
Some(LastPattern::Normal(m)) => {
field_pats.push(*m);
RecordPatternTail::Empty
TailPattern::Empty
},
Some(LastPattern::Ellipsis(Some(captured))) => {
RecordPatternTail::Capture(captured)
TailPattern::Capture(captured)
}
Some(LastPattern::Ellipsis(None)) => {
RecordPatternTail::Open
TailPattern::Open
}
None => RecordPatternTail::Empty,
None => TailPattern::Empty,
};

let pattern = RecordPattern {
Expand All @@ -619,6 +620,32 @@ RecordPattern: RecordPattern = {
},
};

ArrayPattern: ArrayPattern = {
<start: @L> "[" <mut patterns: (<Pattern> ",")*> <last: LastElemPat?> "]" <end: @R> => {
let tail = match last {
Some(LastPattern::Normal(m)) => {
patterns.push(*m);
TailPattern::Empty
},
Some(LastPattern::Ellipsis(Some(captured))) => {
TailPattern::Capture(captured)
}
Some(LastPattern::Ellipsis(None)) => {
TailPattern::Open
}
None => TailPattern::Empty,
};

let pattern = ArrayPattern{
patterns,
tail,
pos: mk_pos(src_id, start, end)
};

pattern
},
};

EnumPatternF<F>: EnumPattern = {
<start: @L> <tag: EnumTag> <end: @R> => EnumPattern {
tag,
Expand Down Expand Up @@ -663,12 +690,18 @@ FieldPattern: FieldPattern = {
},
};

// Last field of a pattern
// Last field pattern of a record pattern
LastFieldPat: LastPattern<FieldPattern> = {
FieldPattern => LastPattern::Normal(Box::new(<>)),
".." <Ident?> => LastPattern::Ellipsis(<>),
};

// Last pattern of an array pattern
LastElemPat: LastPattern<Pattern> = {
Pattern => LastPattern::Normal(Box::new(<>)),
".." <Ident?> => LastPattern::Ellipsis(<>),
}

// A default annotation in a pattern.
DefaultAnnot: RichTerm = "?" <t: Term> => t;

Expand Down
37 changes: 34 additions & 3 deletions core/src/pretty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -588,6 +588,7 @@ where
PatternData::Wildcard => allocator.text("_"),
PatternData::Any(id) => allocator.as_string(id),
PatternData::Record(rp) => rp.pretty(allocator),
PatternData::Array(ap) => ap.pretty(allocator),
PatternData::Enum(evp) => evp.pretty(allocator),
PatternData::Constant(cp) => cp.pretty(allocator),
}
Expand Down Expand Up @@ -688,9 +689,9 @@ where
allocator.line()
),
match tail {
RecordPatternTail::Empty => allocator.nil(),
RecordPatternTail::Open => docs![allocator, allocator.line(), ".."],
RecordPatternTail::Capture(id) =>
TailPattern::Empty => allocator.nil(),
TailPattern::Open => docs![allocator, allocator.line(), ".."],
TailPattern::Capture(id) =>
docs![allocator, allocator.line(), "..", id.ident().to_string()],
},
]
Expand All @@ -701,6 +702,36 @@ where
}
}

impl<'a, D, A> Pretty<'a, D, A> for &ArrayPattern
where
D: NickelAllocatorExt<'a, A>,
D::Doc: Clone,
A: Clone + 'a,
{
fn pretty(self, allocator: &'a D) -> DocBuilder<'a, D, A> {
docs![
allocator,
allocator.intersperse(
self.patterns.iter(),
docs![allocator, ",", allocator.line()],
),
if !self.patterns.is_empty() && self.tail.is_open() {
docs![allocator, ",", allocator.line()]
} else {
allocator.nil()
},
match self.tail {
TailPattern::Empty => allocator.nil(),
TailPattern::Open => allocator.text(".."),
TailPattern::Capture(id) => docs![allocator, "..", id.ident().to_string()],
},
]
.nest(2)
.brackets()
.group()
}
}

impl<'a, D, A> Pretty<'a, D, A> for &RichTerm
where
D: NickelAllocatorExt<'a, A>,
Expand Down
177 changes: 171 additions & 6 deletions core/src/term/pattern/compile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ use super::*;
use crate::{
mk_app,
term::{
make, record::FieldMetadata, BinaryOp, MatchBranch, MatchData, RecordExtKind, RecordOpKind,
RichTerm, Term, UnaryOp,
make, record::FieldMetadata, BinaryOp, MatchBranch, MatchData, NAryOp, RecordExtKind,
RecordOpKind, RichTerm, Term, UnaryOp,
},
};

Expand Down Expand Up @@ -258,6 +258,7 @@ impl CompilePart for PatternData {
insert_binding(*id, value_id, bindings_id)
}
PatternData::Record(pat) => pat.compile_part(value_id, bindings_id),
PatternData::Array(pat) => pat.compile_part(value_id, bindings_id),
PatternData::Enum(pat) => pat.compile_part(value_id, bindings_id),
PatternData::Constant(pat) => pat.compile_part(value_id, bindings_id),
}
Expand Down Expand Up @@ -312,7 +313,7 @@ impl CompilePart for RecordPattern {
//
// We don't have tuples, and to avoid adding an indirection (by storing the current state
// as `{rest, bindings}` where bindings itself is a record), we store this rest alongside
// the bindings in a special field which is a freshly generated indentifier. This is an
// the bindings in a special field which is a freshly generated identifier. This is an
// implementation detail which isn't very hard to change, should we have to.
//
// if %typeof% value_id == 'Record
Expand Down Expand Up @@ -519,7 +520,7 @@ impl CompilePart for RecordPattern {
// null
// else
// %record_remove% "<REST>" final_bindings_id
RecordPatternTail::Empty => make::if_then_else(
TailPattern::Empty => make::if_then_else(
make::op1(
UnaryOp::BoolNot(),
make::op2(
Expand All @@ -539,7 +540,7 @@ impl CompilePart for RecordPattern {
// final_bindings_id
// (%static_access% <REST_FIELD> final_bindings_id)
// )
RecordPatternTail::Capture(rest) => make::op2(
TailPattern::Capture(rest) => make::op2(
BinaryOp::DynRemove(RecordOpKind::ConsiderAllFields),
Term::Str(rest_field.into()),
mk_app!(
Expand All @@ -555,7 +556,7 @@ impl CompilePart for RecordPattern {
),
),
// %record_remove% "<REST>" final_bindings_id
RecordPatternTail::Open => bindings_without_rest,
TailPattern::Open => bindings_without_rest,
};

// the last `final_bindings_id != null` guard:
Expand All @@ -579,6 +580,170 @@ impl CompilePart for RecordPattern {
}
}

impl CompilePart for ArrayPattern {
// Compilation of an array pattern.
//
// let value_len = %array_length% value_id in
//
// <if self.is_open()>
// if %typeof% value_id == 'Array && value_len >= <self.patterns.len()>
// <else>
// if %typeof% value_id == 'Array && value_len == <self.patterns.len()>
// <end if>
//
// let final_bindings_id =
// <fold (idx, elem_pat) in 0..self.patterns.len()
// - cont is the accumulator
// - initial accumulator is `bindings_id`
// >
//
// let local_bindings_id = cont in
// if local_bindings_id == null then
// null
// else
// let local_value_id = %array_access% <idx> value_id in
// <elem_pat.compile_part(local_value_id, local_bindings_id)>
//
// <end fold>
// in
//
// if final_bindings_id == null then
// null
// else
// <if self.tail is capture(rest)>
// %record_insert%
// <rest>
// final_bindings_id
// (%array_slice% <self.patterns.len()> value_len value_id)
// <else>
// final_bindings_id
// <end if>
// else
// null
fn compile_part(&self, value_id: LocIdent, bindings_id: LocIdent) -> RichTerm {
let value_len_id = LocIdent::fresh();
let pats_len = Term::Num(self.patterns.len().into());

// <fold (idx) in 0..self.patterns.len()
// - cont is the accumulator
// - initial accumulator is `bindings_id`
// >
//
// let local_bindings_id = cont in
// if local_bindings_id == null then
// null
// else
// let local_value_id = %array_access% <idx> value_id in
// <self.patterns[idx].compile_part(local_value_id, local_bindings_id)>
//
// <end fold>
let fold_block: RichTerm = self.patterns.iter().enumerate().fold(
Term::Var(bindings_id).into(),
|cont, (idx, elem_pat)| {
let local_bindings_id = LocIdent::fresh();
let local_value_id = LocIdent::fresh();

// <self.patterns[idx].compile_part(local_value_id, local_bindings_id)>
let updated_bindings_let = elem_pat.compile_part(local_value_id, local_bindings_id);

// %array_access% idx value_id
let extracted_value = make::op2(
BinaryOp::ArrayElemAt(),
Term::Var(value_id),
Term::Num(idx.into()),
);

// let local_value_id = <extracted_value> in <updated_bindings_let>
let inner_else_block =
make::let_in(local_value_id, extracted_value, updated_bindings_let);

// The innermost if:
//
// if local_bindings_id == null then
// null
// else
// <inner_else_block>
let inner_if = make::if_then_else(
make::op2(BinaryOp::Eq(), Term::Var(local_bindings_id), Term::Null),
Term::Null,
inner_else_block,
);

// let local_bindings_id = cont in <inner_if>
make::let_in(local_bindings_id, cont, inner_if)
},
);

// %typeof% value_id == 'Array
let is_array: RichTerm = make::op2(
BinaryOp::Eq(),
make::op1(UnaryOp::Typeof(), Term::Var(value_id)),
Term::Enum("Array".into()),
);

let comp_op = if self.is_open() {
BinaryOp::GreaterOrEq()
} else {
BinaryOp::Eq()
};

// <is_array> && value_len <comp_op> <self.patterns.len()>
let outer_check = mk_app!(
make::op1(UnaryOp::BoolAnd(), is_array),
make::op2(comp_op, Term::Var(value_len_id), pats_len.clone(),)
);

let final_bindings_id = LocIdent::fresh();

// the else block which depends on the tail of the record pattern
let tail_block = match self.tail {
// final_bindings_id
TailPattern::Empty | TailPattern::Open => make::var(final_bindings_id),
// %record_insert%
// <rest>
// final_bindings_id
// (%array_slice% <self.patterns.len()> value_len value_id)
TailPattern::Capture(rest) => mk_app!(
make::op2(
record_insert(),
Term::Str(rest.label().into()),
Term::Var(final_bindings_id),
),
make::opn(
NAryOp::ArraySlice(),
vec![pats_len, Term::Var(value_len_id), Term::Var(value_id)]
)
),
};

// the last `final_bindings_id != null` guard:
//
// if final_bindings_id == null then
// null
// else
// <tail_block>
let guard_tail_block = make::if_then_else(
make::op2(BinaryOp::Eq(), Term::Var(final_bindings_id), Term::Null),
Term::Null,
tail_block,
);

// The let enclosing the fold block and the let binding `final_bindings_id`:
// let final_bindings_id = <fold_block> in <tail_block>
let outer_let = make::let_in(final_bindings_id, fold_block, guard_tail_block);

// if <outer_check> then <outer_let> else null
let outer_if = make::if_then_else(outer_check, outer_let, Term::Null);

// finally, we need to bind `value_len_id` to the length of the array
make::let_in(
value_len_id,
make::op1(UnaryOp::ArrayLength(), Term::Var(value_id)),
outer_if,
)
}
}

impl CompilePart for EnumPattern {
fn compile_part(&self, value_id: LocIdent, bindings_id: LocIdent) -> RichTerm {
// %enum_get_tag% value_id == '<self.tag>
Expand Down
Loading

0 comments on commit a7eacc3

Please sign in to comment.