diff --git a/cli/tests/snapshot/inputs/errors/enum_contract_shape_mismatch.ncl b/cli/tests/snapshot/inputs/errors/enum_contract_shape_mismatch.ncl new file mode 100644 index 000000000..1a1883894 --- /dev/null +++ b/cli/tests/snapshot/inputs/errors/enum_contract_shape_mismatch.ncl @@ -0,0 +1,3 @@ +# capture = 'stderr' +# command = ['eval'] +(std.function.id | forall r. [| 'Foo Number, 'Bar String, 'Qux; r |] -> [| 'Foo Number, 'Bar String, 'Qux; r |]) 'Foo diff --git a/cli/tests/snapshot/inputs/errors/enum_contract_shape_mismatch_rev.ncl b/cli/tests/snapshot/inputs/errors/enum_contract_shape_mismatch_rev.ncl new file mode 100644 index 000000000..10d78b4b1 --- /dev/null +++ b/cli/tests/snapshot/inputs/errors/enum_contract_shape_mismatch_rev.ncl @@ -0,0 +1,3 @@ +# capture = 'stderr' +# command = ['eval'] +(std.function.id | forall r. [| 'Foo, 'Bar String, 'Qux Dyn; r |] -> [| 'Foo, 'Bar String, 'Qux Dyn; r |]) ('Foo..(5)) diff --git a/cli/tests/snapshot/snapshots/snapshot__eval_stderr_enum_contract_shape_mismatch.ncl.snap b/cli/tests/snapshot/snapshots/snapshot__eval_stderr_enum_contract_shape_mismatch.ncl.snap new file mode 100644 index 000000000..2ccdaaed2 --- /dev/null +++ b/cli/tests/snapshot/snapshots/snapshot__eval_stderr_enum_contract_shape_mismatch.ncl.snap @@ -0,0 +1,29 @@ +--- +source: cli/tests/snapshot/main.rs +expression: err +--- +error: contract broken by the caller + shape mismatch for 'Foo + ┌─ [INPUTS_PATH]/errors/enum_contract_shape_mismatch.ncl:3:30 + │ +3 │ (std.function.id | forall r. 
[| 'Foo Number, 'Bar String, 'Qux; r |] -> [| 'Foo Number, 'Bar String, 'Qux; r |]) 'Foo + │ --------------------------------------- ---- evaluated to this expression + │ │ + │ expected type of the argument provided by the caller + │ + = Found an enum with tag `'Foo` which is indeed part of the expected enum type + = However, their shape differs: one is an enum variant that carries an argument while the other is a bare enum tag + +note: + ┌─ [INPUTS_PATH]/errors/enum_contract_shape_mismatch.ncl:3:1 + │ +3 │ (std.function.id | forall r. [| 'Foo Number, 'Bar String, 'Qux; r |] -> [| 'Foo Number, 'Bar String, 'Qux; r |]) 'Foo + │ --------------------------------------------------------------------------------------------------------------------- (1) calling + +note: + ┌─ [INPUTS_PATH]/errors/enum_contract_shape_mismatch.ncl:3:2 + │ +3 │ (std.function.id | forall r. [| 'Foo Number, 'Bar String, 'Qux; r |] -> [| 'Foo Number, 'Bar String, 'Qux; r |]) 'Foo + │ --------------- (2) calling + + diff --git a/cli/tests/snapshot/snapshots/snapshot__eval_stderr_enum_contract_shape_mismatch_rev.ncl.snap b/cli/tests/snapshot/snapshots/snapshot__eval_stderr_enum_contract_shape_mismatch_rev.ncl.snap new file mode 100644 index 000000000..2c47f32d8 --- /dev/null +++ b/cli/tests/snapshot/snapshots/snapshot__eval_stderr_enum_contract_shape_mismatch_rev.ncl.snap @@ -0,0 +1,29 @@ +--- +source: cli/tests/snapshot/main.rs +expression: err +--- +error: contract broken by the caller + shape mismatch for 'Foo + ┌─ [INPUTS_PATH]/errors/enum_contract_shape_mismatch_rev.ncl:3:30 + │ +3 │ (std.function.id | forall r. 
[| 'Foo, 'Bar String, 'Qux Dyn; r |] -> [| 'Foo, 'Bar String, 'Qux Dyn; r |]) ('Foo..(5)) + │ ------------------------------------ ----------- evaluated to this expression + │ │ + │ expected type of the argument provided by the caller + │ + = Found an enum with tag `'Foo` which is indeed part of the expected enum type + = However, their shape differs: one is an enum variant that carries an argument while the other is a bare enum tag + +note: + ┌─ [INPUTS_PATH]/errors/enum_contract_shape_mismatch_rev.ncl:3:1 + │ +3 │ (std.function.id | forall r. [| 'Foo, 'Bar String, 'Qux Dyn; r |] -> [| 'Foo, 'Bar String, 'Qux Dyn; r |]) ('Foo..(5)) + │ ---------------------------------------------------------------------------------------------------------------------- (1) calling + +note: + ┌─ [INPUTS_PATH]/errors/enum_contract_shape_mismatch_rev.ncl:3:2 + │ +3 │ (std.function.id | forall r. [| 'Foo, 'Bar String, 'Qux Dyn; r |] -> [| 'Foo, 'Bar String, 'Qux Dyn; r |]) ('Foo..(5)) + │ --------------- (2) calling + + diff --git a/core/src/eval/operation.rs b/core/src/eval/operation.rs index cfb07ec17..edc7cb61c 100644 --- a/core/src/eval/operation.rs +++ b/core/src/eval/operation.rs @@ -2300,7 +2300,7 @@ impl VirtualMachine { let Term::Array(array, _) = t1 else { return Err(mk_type_error!( "label_with_notes", - "Array String", + "Array", 1, t1.into(), pos1 diff --git a/core/src/stdlib.rs b/core/src/stdlib.rs index 67af652ed..bcec75345 100644 --- a/core/src/stdlib.rs +++ b/core/src/stdlib.rs @@ -64,29 +64,41 @@ pub mod internals { mk_term::var("$dyn") } + // `enum` is a reserved keyword in rust + pub fn enumeration() -> RichTerm { + mk_term::var("$enum") + } + generate_accessor!(num); generate_accessor!(bool); generate_accessor!(string); + generate_accessor!(fail); + generate_accessor!(array); generate_accessor!(array_dyn); + generate_accessor!(func); generate_accessor!(func_dom); generate_accessor!(func_codom); generate_accessor!(func_dyn); + generate_accessor!(forall_var); 
generate_accessor!(forall); - generate_accessor!(fail); - generate_accessor!(enums); + generate_accessor!(enum_fail); + generate_accessor!(enum_variant); + generate_accessor!(forall_enum_tail); + generate_accessor!(record); - generate_accessor!(dict_type); - generate_accessor!(dict_contract); - generate_accessor!(dict_dyn); generate_accessor!(record_extend); - generate_accessor!(forall_tail); + generate_accessor!(forall_record_tail); generate_accessor!(dyn_tail); generate_accessor!(empty_tail); - generate_accessor!(enum_variant); + + generate_accessor!(dict_type); + generate_accessor!(dict_contract); + generate_accessor!(dict_dyn); + generate_accessor!(stdlib_contract_equal); generate_accessor!(rec_default); diff --git a/core/src/typ.rs b/core/src/typ.rs index 80e3de31b..4c5a1433e 100644 --- a/core/src/typ.rs +++ b/core/src/typ.rs @@ -41,22 +41,21 @@ //! Conversely, any Nickel term seen as a contract corresponds to a type, which is opaque and can //! only be equated with itself. use crate::{ + environment::Environment, error::{EvalError, ParseError, ParseErrors, TypecheckError}, identifier::{Ident, LocIdent}, impl_display_from_pretty, label::Polarity, mk_app, mk_fun, position::TermPos, + stdlib::internals, term::{ array::Array, make as mk_term, record::RecordData, string::NickelString, IndexMap, MatchData, RichTerm, Term, Traverse, TraverseControl, TraverseOrder, }, }; -use std::{ - collections::{HashMap, HashSet}, - convert::Infallible, -}; +use std::{collections::HashSet, convert::Infallible}; /// A record row, mapping an identifier to a type. A record type is a dictionary mapping /// identifiers to Nickel type. Record types are represented as sequences of `RecordRowF`, ending @@ -773,10 +772,31 @@ impl<'a> Iterator for EnumRowsIterator<'a, Type, EnumRows> { } } +trait Subcontract { + /// Return the contract corresponding to a type component of a larger type. + /// + /// # Arguments + /// + /// - `vars` is an environment mapping type variables to contracts. 
Type variables are + /// introduced locally when opening a `forall`. Note that we don't need to keep separate + /// environments for different kinds of type variables, as by shadowing, one name can only + /// refer to one type of variable at any given time. + /// - `pol` is the current polarity, which is toggled when generating a contract for the + /// argument of an arrow type (see [`crate::label::Label`]). + /// - `sy` is a counter used to generate fresh symbols for `forall` contracts (see + /// [`crate::term::Term::Sealed`]). + fn subcontract( + &self, + vars: Environment, + pol: Polarity, + sy: &mut i32, + ) -> Result; +} + /// Retrieve the contract corresponding to a type variable occurrence in a type as a `RichTerm`. /// Helper used by the `subcontract` functions. fn get_var_contract( - vars: &HashMap, + vars: &Environment, sym: Ident, pos: TermPos, ) -> Result { @@ -786,73 +806,116 @@ fn get_var_contract( .clone()) } -impl EnumRows { - fn subcontract(&self) -> Result { - use crate::stdlib::internals; - use crate::term::pattern::{EnumPattern, Pattern, PatternData}; +impl Subcontract for Type { + fn subcontract( + &self, + mut vars: Environment, + pol: Polarity, + sy: &mut i32, + ) -> Result { + let ctr = match self.typ { + TypeF::Dyn => internals::dynamic(), + TypeF::Number => internals::num(), + TypeF::Bool => internals::bool(), + TypeF::String => internals::string(), + // Array Dyn is specialized to array_dyn, which is constant time + TypeF::Array(ref ty) if matches!(ty.typ, TypeF::Dyn) => internals::array_dyn(), + TypeF::Array(ref ty) => mk_app!(internals::array(), ty.subcontract(vars, pol, sy)?), + TypeF::Symbol => panic!("unexpected Symbol type during contract elaboration"), + // Similarly, any variant of `A -> B` where either `A` or `B` is `Dyn` gets specialized + // to the corresponding builtin contract. 
+ TypeF::Arrow(ref s, ref t) if matches!((&s.typ, &t.typ), (TypeF::Dyn, TypeF::Dyn)) => { + internals::func_dyn() + } + TypeF::Arrow(ref s, ref t) if matches!(s.typ, TypeF::Dyn) => { + mk_app!(internals::func_codom(), t.subcontract(vars, pol, sy)?) + } + TypeF::Arrow(ref s, ref t) if matches!(t.typ, TypeF::Dyn) => { + mk_app!( + internals::func_dom(), + s.subcontract(vars.clone(), pol.flip(), sy)? + ) + } + TypeF::Arrow(ref s, ref t) => mk_app!( + internals::func(), + s.subcontract(vars.clone(), pol.flip(), sy)?, + t.subcontract(vars, pol, sy)? + ), + TypeF::Flat(ref t) => t.clone(), + TypeF::Var(id) => get_var_contract(&vars, id, self.pos)?, + TypeF::Forall { + ref var, + ref body, + ref var_kind, + } => { + let sealing_key = Term::SealingKey(*sy); + let contract = match var_kind { + VarKind::Type => mk_app!(internals::forall_var(), sealing_key.clone()), + kind @ VarKind::RecordRows { excluded } + | kind @ VarKind::EnumRows { excluded } => { + let excluded_ncl: RichTerm = Term::Array( + Array::from_iter( + excluded + .iter() + .map(|id| Term::Str(NickelString::from(*id)).into()), + ), + Default::default(), + ) + .into(); - let mut branches = Vec::new(); - let mut has_tail = false; - let value_arg = LocIdent::from("x"); - let label_arg = LocIdent::from("l"); + let forall_contract = match kind { + VarKind::RecordRows { .. } => internals::forall_record_tail(), + VarKind::EnumRows { .. 
} => internals::forall_enum_tail(), + _ => unreachable!(), + }; - // TODO[adt]: actually implement the right contract for enum variants - for row in self.iter() { - match row { - EnumRowsIteratorItem::Row(row) => { - let pattern = Pattern { - data: PatternData::Enum(EnumPattern { - tag: row.id, - pattern: None, - pos: row.id.pos, - }), - alias: None, - pos: row.id.pos, - }; + mk_app!(forall_contract, sealing_key.clone(), excluded_ncl) + } + }; + vars.insert(var.ident(), contract); - branches.push((pattern, mk_term::var(value_arg))); - } - EnumRowsIteratorItem::TailVar(_) => { - has_tail = true; - break; - } + *sy += 1; + mk_app!( + internals::forall(), + sealing_key, + Term::from(pol), + body.subcontract(vars, pol, sy)? + ) } - } - - // If the enum type has a tail, the tail must be a universally quantified variable, - // and this means that the tag can be anything. - let case_body = if has_tail { - mk_term::var(value_arg) - } - // Otherwise, we build a match with all the tags as cases, which just returns the - // original argument, and a default case that blames. - // - // For example, for an enum type [| 'foo, 'bar, 'baz |], the `case` function looks - // like: - // - // ``` - // fun l x => - // x |> match { - // 'foo => x, - // 'bar => x, - // 'baz => x, - // _ => $enum_fail l - // } - // ``` - else { - mk_app!( - Term::Match(MatchData { - branches, - default: Some(mk_app!(internals::enum_fail(), mk_term::var(label_arg))), - }), - mk_term::var(value_arg) - ) + TypeF::Enum(ref erows) => erows.subcontract(vars, pol, sy)?, + TypeF::Record(ref rrows) => rrows.subcontract(vars, pol, sy)?, + // `{_: Dyn}` and `{_ | Dyn}` are equivalent, and both specialized to the constant-time + // `dict_dyn`. 
+ TypeF::Dict { + ref type_fields, + flavour: _, + } if matches!(type_fields.typ, TypeF::Dyn) => internals::dict_dyn(), + TypeF::Dict { + ref type_fields, + flavour: DictTypeFlavour::Contract, + } => { + mk_app!( + internals::dict_contract(), + type_fields.subcontract(vars, pol, sy)? + ) + } + TypeF::Dict { + ref type_fields, + flavour: DictTypeFlavour::Type, + } => { + mk_app!( + internals::dict_type(), + type_fields.subcontract(vars, pol, sy)? + ) + } + TypeF::Wildcard(_) => internals::dynamic(), }; - let case = mk_fun!(label_arg, value_arg, case_body); - Ok(mk_app!(internals::enums(), case)) + Ok(ctr) } +} +impl EnumRows { /// Find the row with the given identifier in the enum type. Return `None` if there is no such /// row. pub fn find_row(&self, id: Ident) -> Option { @@ -873,45 +936,120 @@ impl EnumRows { } } -impl RecordRows { - /// Construct the subcontract corresponding to a record type +impl Subcontract for EnumRows { fn subcontract( &self, - vars: HashMap, + vars: Environment, pol: Polarity, sy: &mut i32, ) -> Result { - use crate::stdlib::internals; + use crate::term::{ + pattern::{EnumPattern, Pattern, PatternData}, + BinaryOp, + }; - // We begin by building a record whose arguments are contracts - // derived from the types of the statically known fields. - let mut rrows = self; - let mut fcs = IndexMap::new(); + let mut branches = Vec::new(); + let mut tail_var = None; - while let RecordRowsF::Extend { - row: RecordRowF { id, typ: ty }, - tail, - } = &rrows.0 - { - fcs.insert(*id, ty.subcontract(vars.clone(), pol, sy)?); - rrows = tail + let value_arg = LocIdent::fresh(); + let label_arg = LocIdent::fresh(); + // We don't need to generate a different fresh variable for each match branch, as they have + // their own scope, so we use the same name instead. 
+ let variant_arg = LocIdent::fresh(); + + // We build a match where each row corresponds to a branch, such that: + // + // - if the row is a simple enum tag, we just return the original contract argument + // - if the row is an enum variant, we extract the argument and apply the corresponding + // contract to it + // + // For the default branch, depending on the tail: + // + // - if the tail is an enum type variable, we perform the required sealing/unsealing + // - otherwise, if the enum type is closed, we add a default case which blames + // + // For example, for an enum type [| 'foo, 'bar, 'Baz T |], the function looks like: + // + // ``` + // fun l x => + // x |> match { + // 'foo => x, + // 'bar => x, + // 'Baz variant_arg => %apply_contract% T label_arg variant_arg, + // _ => $enum_fail l + // } + // ``` + for row in self.iter() { + match row { + EnumRowsIteratorItem::Row(row) => { + let arg_pattern = row.typ.as_ref().map(|_| { + Box::new(Pattern { + data: PatternData::Any(variant_arg), + alias: None, + pos: TermPos::None, + }) + }); + + let body = if let Some(ty) = row.typ.as_ref() { + // %apply_contract% T label_arg variant_arg + mk_app!( + mk_term::op2( + BinaryOp::ApplyContract(), + ty.subcontract(vars.clone(), pol, sy)?, + mk_term::var(label_arg) + ), + mk_term::var(variant_arg) + ) + } else { + mk_term::var(value_arg) + }; + + let pattern = Pattern { + data: PatternData::Enum(EnumPattern { + tag: row.id, + pattern: arg_pattern, + pos: row.id.pos, + }), + alias: None, + pos: row.id.pos, + }; + + branches.push((pattern, body)); + } + EnumRowsIteratorItem::TailVar(var) => { + tail_var = Some(var); + } + } } - // Now that we've dealt with the row extends, we just need to - // work out the tail. 
- let tail = match &rrows.0 { - RecordRowsF::Empty => internals::empty_tail(), - RecordRowsF::TailDyn => internals::dyn_tail(), - RecordRowsF::TailVar(id) => get_var_contract(&vars, id.ident(), id.pos)?, - // Safety: the while above excludes that `tail` can have the form `Extend`. - RecordRowsF::Extend { .. } => unreachable!(), + let default = if let Some(var) = tail_var { + mk_app!( + mk_term::op2( + BinaryOp::ApplyContract(), + get_var_contract(&vars, var.ident(), var.pos)?, + mk_term::var(label_arg) + ), + mk_term::var(value_arg) + ) + } else { + mk_app!(internals::enum_fail(), mk_term::var(label_arg)) }; - let rec = RichTerm::from(Term::Record(RecordData::with_field_values(fcs))); + let match_expr = mk_app!( + Term::Match(MatchData { + branches, + default: Some(default) + }), + mk_term::var(value_arg) + ); - Ok(mk_app!(internals::record(), rec, tail)) + let case = mk_fun!(label_arg, value_arg, match_expr); + // println!("Generated case: {case}"); + Ok(mk_app!(internals::enumeration(), case)) } +} +impl RecordRows { /// Find a nested binding in a record row type. The nested field is given as a list of /// successive fields, that is, as a path. Return `None` if there is no such binding. /// @@ -969,6 +1107,43 @@ impl RecordRows { } } +impl Subcontract for RecordRows { + fn subcontract( + &self, + vars: Environment, + pol: Polarity, + sy: &mut i32, + ) -> Result { + // We begin by building a record whose arguments are contracts + // derived from the types of the statically known fields. + let mut rrows = self; + let mut fcs = IndexMap::new(); + + while let RecordRowsF::Extend { + row: RecordRowF { id, typ: ty }, + tail, + } = &rrows.0 + { + fcs.insert(*id, ty.subcontract(vars.clone(), pol, sy)?); + rrows = tail + } + + // Now that we've dealt with the row extends, we just need to + // work out the tail. 
+ let tail = match &rrows.0 { + RecordRowsF::Empty => internals::empty_tail(), + RecordRowsF::TailDyn => internals::dyn_tail(), + RecordRowsF::TailVar(id) => get_var_contract(&vars, id.ident(), id.pos)?, + // Safety: the while above excludes that `tail` can have the form `Extend`. + RecordRowsF::Extend { .. } => unreachable!(), + }; + + let rec = RichTerm::from(Term::Record(RecordData::with_field_values(fcs))); + + Ok(mk_app!(internals::record(), rec, tail)) + } +} + impl From, RecordRows, EnumRows>> for Type { fn from(typ: TypeF, RecordRows, EnumRows>) -> Self { Type { @@ -1025,74 +1200,87 @@ impl Type { /// - All positive occurrences of first order contracts (that is, anything but a function type) /// are turned to `Dyn` contracts. fn optimize_static(self) -> Self { - use crate::environment::Environment; // We use this environment as a shareable HashSet type VarsHashSet = Environment; - fn optimize_rrows( - rrows: RecordRows, - vars_elide: VarsHashSet, - polarity: Polarity, - ) -> RecordRows { - RecordRows(rrows.0.map( - |typ| Box::new(optimize(*typ, vars_elide.clone(), polarity)), - |rrows| Box::new(optimize_rrows(*rrows, vars_elide.clone(), polarity)), - )) + trait Optimize { + fn optimize(self, vars_elide: VarsHashSet, polarity: Polarity) -> Self; } - fn optimize(typ: Type, mut vars_elide: VarsHashSet, polarity: Polarity) -> Type { - let mut pos = typ.pos; - - let optimized = match typ.typ { - TypeF::Arrow(dom, codom) => TypeF::Arrow( - Box::new(optimize(*dom, vars_elide.clone(), polarity.flip())), - Box::new(optimize(*codom, vars_elide, polarity)), - ), - // TODO: don't optimize only VarKind::Type - TypeF::Forall { - var, - var_kind: VarKind::Type, - body, - } if polarity == Polarity::Positive => { - vars_elide.insert(var.ident(), ()); - let result = optimize(*body, vars_elide, polarity); - // we keep the position of the body, not the one of the forall - pos = result.pos; - result.typ + impl Optimize for Type { + fn optimize(self, mut vars_elide: VarsHashSet, 
polarity: Polarity) -> Type { + let mut pos = self.pos; + + let optimized = match self.typ { + TypeF::Arrow(dom, codom) => TypeF::Arrow( + Box::new(dom.optimize(vars_elide.clone(), polarity.flip())), + Box::new(codom.optimize(vars_elide, polarity)), + ), + // TODO: don't optimize only VarKind::Type + TypeF::Forall { + var, + var_kind: VarKind::Type, + body, + } if polarity == Polarity::Positive => { + vars_elide.insert(var.ident(), ()); + let result = body.optimize(vars_elide, polarity); + // we keep the position of the body, not the one of the forall + pos = result.pos; + result.typ + } + TypeF::Forall { + var, + var_kind, + body, + } => TypeF::Forall { + var, + var_kind, + body: Box::new(body.optimize(vars_elide, polarity)), + }, + TypeF::Var(id) if vars_elide.get(&id).is_some() => TypeF::Dyn, + v @ TypeF::Var(_) => v, + // Any first-order type on positive position can be elided + _ if matches!(polarity, Polarity::Positive) => TypeF::Dyn, + // Otherwise, we still recurse into non-primitive types + TypeF::Record(rrows) => TypeF::Record(rrows.optimize(vars_elide, polarity)), + TypeF::Enum(erows) => TypeF::Enum(erows.optimize(vars_elide, polarity)), + TypeF::Dict { + type_fields, + flavour, + } => TypeF::Dict { + type_fields: Box::new(type_fields.optimize(vars_elide, polarity)), + flavour, + }, + TypeF::Array(t) => TypeF::Array(Box::new(t.optimize(vars_elide, polarity))), + // All other types don't contain subtypes, it's a base case + t => t, + }; + Type { + typ: optimized, + pos, } - TypeF::Forall { - var, - var_kind, - body, - } => TypeF::Forall { - var, - var_kind, - body: Box::new(optimize(*body, vars_elide, polarity)), - }, - TypeF::Var(id) if vars_elide.get(&id).is_some() => TypeF::Dyn, - v @ TypeF::Var(_) => v, - // Any first-order type on positive position can be elided - _ if matches!(polarity, Polarity::Positive) => TypeF::Dyn, - // Otherwise, we still recurse into non-primitive types - TypeF::Record(rrows) => TypeF::Record(optimize_rrows(rrows, 
vars_elide, polarity)), - TypeF::Dict { - type_fields, - flavour, - } => TypeF::Dict { - type_fields: Box::new(optimize(*type_fields, vars_elide, polarity)), - flavour, - }, - TypeF::Array(t) => TypeF::Array(Box::new(optimize(*t, vars_elide, polarity))), - // All other types don't contain subtypes, it's a base case - t => t, - }; - Type { - typ: optimized, - pos, } } - optimize(self, VarsHashSet::new(), Polarity::Positive) + impl Optimize for RecordRows { + fn optimize(self, vars_elide: VarsHashSet, polarity: Polarity) -> RecordRows { + RecordRows(self.0.map( + |typ| Box::new(typ.optimize(vars_elide.clone(), polarity)), + |rrows| Box::new(rrows.optimize(vars_elide.clone(), polarity)), + )) + } + } + + impl Optimize for EnumRows { + fn optimize(self, vars_elide: VarsHashSet, polarity: Polarity) -> EnumRows { + EnumRows(self.0.map( + |typ| Box::new(typ.optimize(vars_elide.clone(), polarity)), + |erows| Box::new(erows.optimize(vars_elide.clone(), polarity)), + )) + } + } + + self.optimize(VarsHashSet::new(), Polarity::Positive) } /// Return the contract corresponding to a type which appears in a static type annotation. Said @@ -1103,14 +1291,14 @@ impl Type { pub fn contract_static(self) -> Result { let mut sy = 0; self.optimize_static() - .subcontract(HashMap::new(), Polarity::Positive, &mut sy) + .subcontract(Environment::new(), Polarity::Positive, &mut sy) } /// Return the contract corresponding to a type, either as a function or a record. Said /// contract must then be applied using the `ApplyContract` primitive operation. pub fn contract(&self) -> Result { let mut sy = 0; - self.subcontract(HashMap::new(), Polarity::Positive, &mut sy) + self.subcontract(Environment::new(), Polarity::Positive, &mut sy) } /// Returns true if this type is a function type, false otherwise. @@ -1122,117 +1310,6 @@ impl Type { } } - /// Return the contract corresponding to a subtype. - /// - /// # Arguments - /// - /// - `h` is an environment mapping type variables to contracts. 
Type variables are introduced - /// locally when opening a `forall`. - /// - `pol` is the current polarity, which is toggled when generating a contract for the - /// argument of an arrow type (see [`crate::label::Label`]). - /// - `sy` is a counter used to generate fresh symbols for `forall` contracts (see - /// [`crate::term::Term::Sealed`]). - fn subcontract( - &self, - mut vars: HashMap, - pol: Polarity, - sy: &mut i32, - ) -> Result { - use crate::stdlib::internals; - - let ctr = match self.typ { - TypeF::Dyn => internals::dynamic(), - TypeF::Number => internals::num(), - TypeF::Bool => internals::bool(), - TypeF::String => internals::string(), - // Array Dyn is specialized to array_dyn, which is constant time - TypeF::Array(ref ty) if matches!(ty.typ, TypeF::Dyn) => internals::array_dyn(), - TypeF::Array(ref ty) => mk_app!(internals::array(), ty.subcontract(vars, pol, sy)?), - TypeF::Symbol => panic!("Are you trying to check a Sym at runtime?"), - // Similarly, any variant of `A -> B` where either `A` or `B` is `Dyn` get specialized - // to the corresponding builtin contract. - TypeF::Arrow(ref s, ref t) if matches!((&s.typ, &t.typ), (TypeF::Dyn, TypeF::Dyn)) => { - internals::func_dyn() - } - TypeF::Arrow(ref s, ref t) if matches!(s.typ, TypeF::Dyn) => { - mk_app!(internals::func_codom(), t.subcontract(vars, pol, sy)?) - } - TypeF::Arrow(ref s, ref t) if matches!(t.typ, TypeF::Dyn) => { - mk_app!( - internals::func_dom(), - s.subcontract(vars.clone(), pol.flip(), sy)? - ) - } - TypeF::Arrow(ref s, ref t) => mk_app!( - internals::func(), - s.subcontract(vars.clone(), pol.flip(), sy)?, - t.subcontract(vars, pol, sy)? 
- ), - TypeF::Flat(ref t) => t.clone(), - TypeF::Var(id) => get_var_contract(&vars, id, self.pos)?, - TypeF::Forall { - ref var, - ref body, - ref var_kind, - } => { - let sealing_key = Term::SealingKey(*sy); - let contract = match var_kind { - VarKind::Type => mk_app!(internals::forall_var(), sealing_key.clone()), - VarKind::RecordRows { excluded } | VarKind::EnumRows { excluded } => { - let excluded_ncl: RichTerm = Term::Array( - Array::from_iter( - excluded - .iter() - .map(|id| Term::Str(NickelString::from(*id)).into()), - ), - Default::default(), - ) - .into(); - mk_app!(internals::forall_tail(), sealing_key.clone(), excluded_ncl) - } - }; - vars.insert(var.ident(), contract); - - *sy += 1; - mk_app!( - internals::forall(), - sealing_key, - Term::from(pol), - body.subcontract(vars, pol, sy)? - ) - } - TypeF::Enum(ref erows) => erows.subcontract()?, - TypeF::Record(ref rrows) => rrows.subcontract(vars, pol, sy)?, - // `{_: Dyn}` and `{_ | Dyn}` are equivalent, and both specialied to the constant-time - // `dict_dyn`. - TypeF::Dict { - ref type_fields, - flavour: _, - } if matches!(type_fields.typ, TypeF::Dyn) => internals::dict_dyn(), - TypeF::Dict { - ref type_fields, - flavour: DictTypeFlavour::Contract, - } => { - mk_app!( - internals::dict_contract(), - type_fields.subcontract(vars, pol, sy)? - ) - } - TypeF::Dict { - ref type_fields, - flavour: DictTypeFlavour::Type, - } => { - mk_app!( - internals::dict_type(), - type_fields.subcontract(vars, pol, sy)? - ) - } - TypeF::Wildcard(_) => internals::dynamic(), - }; - - Ok(ctr) - } - /// Determine if a type is an atom, that is a either a primitive type (`Dyn`, `Number`, etc.) or /// a type delimited by specific markers (such as a row type). Used in formatting to decide if /// parentheses need to be inserted during pretty pretting. 
diff --git a/core/stdlib/internals.ncl b/core/stdlib/internals.ncl index 9271a60c6..3ff75f612 100644 --- a/core/stdlib/internals.ncl +++ b/core/stdlib/internals.ncl @@ -80,23 +80,98 @@ if polarity == current_polarity then %unseal% sealing_key value (%blame% label) else - # Here, we know that this term should be sealed, but to give the right - # blame for the contract, we have to change the polarity to match the - # polarity of the `Forall`, because this is what's important for - # blaming polymorphic contracts. + # [^forall_chng_pol]: Blame assignment for polymorphic contracts + # should take into account the polarity at the point the forall was + # introduced, not the current polarity of the variable occurrence. Indeed, + # forall can never blame in a negative position (relative to the + # forall): the contract is entirely on the callee. + # + # Thus, for correct blame assignment, we want to set the polarity to the + # forall polarity (here `polarity`). Because we only have the `chng_pol` + # primop, and we know that in this branch they are unequal, flipping the + # current polarity will indeed give the original forall's polarity. %seal% sealing_key (%chng_pol% label) value, "$forall" = fun sealing_key polarity contract label value => contract (%insert_type_variable% sealing_key polarity label) value, - "$enums" = fun case label value => + "$enum" = fun case label value => if %typeof% value == 'Enum then %apply_contract% case label value else - %blame% (%label_with_message% "not an enum tag" label), + %blame% (%label_with_message% "expected an enum" label), "$enum_fail" = fun label => - %blame% (%label_with_message% "tag not included in the enum type" label), + %blame% (%label_with_message% "tag not in the enum type" label), + + # Contract for an enum variant with tag `'tag`, that is any value of the form + # `'tag exp`. 
+ "$enum_variant" = fun tag label value => + if %enum_is_variant% value then + let value_tag = %enum_get_tag% value in + + if value_tag == tag then + value + else + let msg = "expected `'%{%to_str% tag}`, got `'%{%to_str% value_tag}`" in + %blame% (%label_with_message% "tag mismatch: %{msg}" label) + else + %blame% (%label_with_message% "expected an enum variant" label), + + "$forall_enum_tail" = fun sealing_key constr label value => + # We check for conflicts only when the polarity of the foralls matches the + # current polarity ("negative" polarity relatively to the forall). In + # positive (relative) polarity, a value lying in the tail either is sealed + # and has thus already been checked for conflicts, or it is not sealed and + # it will fail the `unseal` check anyway. + let current_polarity = %polarity% label in + let polarity = (%lookup_type_variable% sealing_key label).polarity in + if polarity == current_polarity then + # [^enum-no-sealing]: Theoretically, we should seal/unseal values that + # are part of enum tail. However, we can't just do that, because then a + # match expression that is entirely legit, for example + # + # ``` + # match { 'Foo => 1, _ => 2 } : forall r. [| 'Foo; r|] -> Number + # ``` + # + # would fail on `'Bar`, if we just seal the latter naively. It looks like + # we should allow matches to see through sealed enum, but accept only if + # the catch-all case matches what's inside the sealed enum. + # + # This doesn't look trivial for now, and would actually break the stdlib, + # as parametricity hasn't been correctly enforced for enum types. One + # example is `std.string.from_enum`, which has contract + # `forall a. [|; a |] -> String` but actually violates parametricity (it + # actually looks inside its argument). + # + # While this might be an issue to investigate in the longer term, or for + # the next major version, we continue to just not enforce parametricity + # for enum types for now. 
+ value + else + let value_tag = %to_str% (%enum_get_tag% value) in + if std.array.elem value_tag constr then + %blame% + ( + %label_with_message% + "shape mismatch for '%{value_tag}" + ( + %label_with_notes% + ( + %force% + [ + "Found an enum with tag `'%{value_tag}` which is indeed part of the expected enum type", + "However, their shape differs: one is an enum variant that carries an argument while the other is a bare enum tag" + ] + ) + label + ) + ) + else + # Same as [^enum-no-sealing] above. We should theoretically seal here, + # but we don't for now. + value, "$record" = fun field_contracts tail_contract label value => if %typeof% value == 'Record then @@ -173,7 +248,7 @@ else %blame% (%label_with_message% "not a record" label), - "$forall_tail" = fun sealing_key constr acc label value => + "$forall_record_tail" = fun sealing_key constr acc label value => let current_polarity = %polarity% label in let polarity = (%lookup_type_variable% sealing_key label).polarity in let plural = fun list => if %length% list == 1 then "" else "s" in @@ -204,11 +279,7 @@ label ) else - # Note: in order to correctly attribute blame, the polarity of `l` - # must match the polarity of the `forall` which introduced the - # polymorphic contract (i.e. `pol`). Since we know in this branch - # that `pol` and `%polarity% l` differ, we swap `l`'s polarity before - # we continue. + # See [^forall_chng_pol] %record_seal_tail% sealing_key (%chng_pol% label) acc value, "$dyn_tail" = fun acc label value => acc & value, @@ -226,20 +297,6 @@ label ), - # Contract for an enum variant with tag `'tag`, that is any value of the form - # `'tag exp`. 
- "$enum_variant" = fun tag label value => - if %enum_is_variant% value then - let value_tag = %enum_get_tag% value in - - if value_tag == tag then - value - else - let msg = "expected `'%{%to_str% tag}`, got `'%{%to_str% value_tag}`" in - %blame% (%label_with_message% "tag mismatch: %{msg}" label) - else - %blame% (%label_with_message% "expected an enum variant" label), - # Recursive priorities operators "$rec_force" = fun value => %rec_force% (%force% value), diff --git a/core/tests/integration/inputs/contracts/contracts.ncl b/core/tests/integration/inputs/contracts/contracts.ncl index 5df478667..8c4959a4e 100644 --- a/core/tests/integration/inputs/contracts/contracts.ncl +++ b/core/tests/integration/inputs/contracts/contracts.ncl @@ -37,19 +37,19 @@ let {check, Assert, ..} = import "../lib/assert.ncl" in ('"bar:baz" | forall r. [| '"foo:baz", '"bar:baz" ; r |]) == '"bar:baz", # enums_complex - let f : forall r. [| 'foo, 'bar ; r |] -> Number = + let f | forall r. [| 'foo, 'bar ; r |] -> Number = match { 'foo => 1, 'bar => 2, _ => 3, } in f 'bar == 2, - let f : forall r. [| 'foo, 'bar ; r |] -> Number = + let f | forall r. [| 'foo, 'bar ; r |] -> Number = match { 'foo => 1, 'bar => 2, _ => 3, } in f 'boo == 3, - let f : forall r. [| 'foo, '"bar:baz" ; r |] -> Number = + let f | forall r. [| 'foo, '"bar:baz" ; r |] -> Number = fun x => match { 'foo => 1, '"bar:baz" => 2, _ => 3, } x in f '"bar:baz" == 2, - let f : forall r. [| 'foo, '"bar:baz" ; r |] -> Number = + let f | forall r. 
[| 'foo, '"bar:baz" ; r |] -> Number = fun x => match { 'foo => 1, '"bar:baz" => 2, _ => 3, } x in f '"boo,grr" == 3, @@ -128,6 +128,7 @@ let {check, Assert, ..} = import "../lib/assert.ncl" in let f | {foo | Number} -> {bar | Number} = fun r => {bar = r.foo} in (f {foo = 1}).bar == 1, + # user-written contract application let Extend = fun base label value => let derived = if std.is_record base then diff --git a/core/tests/integration/inputs/contracts/enum_variant_fail.ncl b/core/tests/integration/inputs/contracts/enum_variant_fail.ncl new file mode 100644 index 000000000..da1a85215 --- /dev/null +++ b/core/tests/integration/inputs/contracts/enum_variant_fail.ncl @@ -0,0 +1,5 @@ +# test.type = 'error' +# +# [test.metadata] +# error = 'EvalError::BlameError' +'Foo..(5) | [| 'Foo String, 'Bar Number, 'Barg |]