Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize std.contract.Equal using %record/split_pair% #1988

Merged
merged 1 commit into from
Jul 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions core/src/eval/operation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2576,7 +2576,7 @@ impl<R: ImportResolver, C: Cache> VirtualMachine<R, C> {

let Term::Record(record1) = t1 else {
return Err(mk_type_error!(
"record/full_difference",
"record/split_pair",
"Record",
1,
t1.into(),
Expand All @@ -2586,7 +2586,7 @@ impl<R: ImportResolver, C: Cache> VirtualMachine<R, C> {

let Term::Record(record2) = t2 else {
return Err(mk_type_error!(
"record/full_difference",
"record/split_pair",
"Record",
2,
t2.into(),
Expand Down
2 changes: 1 addition & 1 deletion core/src/term/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1848,7 +1848,7 @@ impl fmt::Display for BinaryOp {
RecordFieldIsDefined(RecordOpKind::ConsiderAllFields) => {
write!(f, "record/field_is_defined_with_opts")
}
Self::RecordSplitPair => write!(f, "record/full_difference"),
Self::RecordSplitPair => write!(f, "record/split_pair"),
Self::RecordDisjointMerge => write!(f, "record/disjoint_merge"),
ArrayConcat => write!(f, "array/concat"),
ArrayAt => write!(f, "array/at"),
Expand Down
49 changes: 12 additions & 37 deletions core/stdlib/std.ncl
Original file line number Diff line number Diff line change
Expand Up @@ -841,45 +841,14 @@
```
"%
=
let fields_diff
| doc m%"
Compute the difference between the fields of two records.
`fields_diff` isn't concerned with the actual values themselves, but
just with field names.

Return a record of type
`{extra : Array String, missing: Array String}`, relative to the
first argument `constant`.
"%
= fun constant value =>
let diff =
value
|> std.record.fields
|> std.array.fold_left
(
fun acc field =>
if std.record.has_field field acc.rest then
{
extra = acc.extra,
rest = std.record.remove field acc.rest,
}
else
{
extra = std.array.append field acc.extra,
rest = acc.rest,
}
)
{ extra = [], rest = constant }
in
{ extra = diff.extra, missing = std.record.fields diff.rest }
in
let blame_fields_differ = fun qualifier fields ctr_label =>
let plural = if %array/length% fields == 1 then "" else "s" in
ctr_label
|> label.with_message "%{qualifier} field%{plural} `%{std.string.join ", " fields}`"
|> label.append_note "`std.contract.Equal some_record` requires that the checked value is equal to the record `some_record`, but the sets of their fields differ."
|> blame
in

fun constant =>
let constant_type = %typeof% constant in
let check_typeof_eq = fun ctr_label value =>
Expand Down Expand Up @@ -916,11 +885,17 @@
let contract_map = %record/map% constant (fun _key => Equal) in
fun ctr_label value =>
let value = check_typeof_eq ctr_label value in
let diff = fields_diff constant value in
if %array/length% diff.extra != 0 then
blame_fields_differ "extra" diff.extra ctr_label
else if %array/length% diff.missing != 0 then
blame_fields_differ "missing" diff.missing ctr_label
let split_result = %record/split_pair% constant value in
if split_result.right_only != {} then
blame_fields_differ
"extra"
(%record/fields% split_result.right_only)
ctr_label
else if split_result.left_only != {} then
blame_fields_differ
"missing"
(%record/fields% split_result.left_only)
ctr_label
else
%contract/record_lazy_apply%
ctr_label
Expand Down