Skip to content

Commit

Permalink
chore: Add missing cases to arithmetic generics (#5841)
Browse files Browse the repository at this point in the history
# Description

## Problem\*

## Summary\*

In the initial arithmetic generics PR we only added the one specific
case for simplifying `(N + C1) - C2`.

Later in the associated types PR we added another case to simplify the
non-constant `(N + M) - M`

This PR fills in the missing cases for each other operator. It also has
somewhat better overflow handling by returning an `Option` in the
operator function and removing the wrapping operations.

## Additional Context



## Documentation\*

Check one:
- [ ] No documentation needed.
- [ ] Documentation included in this PR.
- [x] **[For Experimental Features]** Documentation to be submitted in a
separate PR.

# PR Checklist\*

- [ ] I have tested the changes locally.
- [ ] I have formatted the changes with [Prettier](https://prettier.io/)
and/or `cargo fmt` on default settings.

---------

Co-authored-by: Michael J Klein <[email protected]>
  • Loading branch information
jfecher and michaeljklein authored Aug 28, 2024
1 parent 58f855e commit c23463e
Show file tree
Hide file tree
Showing 8 changed files with 414 additions and 146 deletions.
9 changes: 7 additions & 2 deletions compiler/noirc_frontend/src/elaborator/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -445,14 +445,19 @@ impl<'context> Elaborator<'context> {
})
}
UnresolvedTypeExpression::Constant(int, _) => Type::Constant(int),
UnresolvedTypeExpression::BinaryOperation(lhs, op, rhs, _) => {
UnresolvedTypeExpression::BinaryOperation(lhs, op, rhs, span) => {
let (lhs_span, rhs_span) = (lhs.span(), rhs.span());
let lhs = self.convert_expression_type(*lhs);
let rhs = self.convert_expression_type(*rhs);

match (lhs, rhs) {
(Type::Constant(lhs), Type::Constant(rhs)) => {
Type::Constant(op.function(lhs, rhs))
if let Some(result) = op.function(lhs, rhs) {
Type::Constant(result)
} else {
self.push_err(ResolverError::OverflowInType { lhs, op, rhs, span });
Type::Error
}
}
(lhs, rhs) => {
if !self.enable_arithmetic_generics {
Expand Down
9 changes: 9 additions & 0 deletions compiler/noirc_frontend/src/hir/resolution/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,8 @@ pub enum ResolverError {
NamedTypeArgs { span: Span, item_kind: &'static str },
#[error("Associated constants may only be a field or integer type")]
AssociatedConstantsMustBeNumeric { span: Span },
#[error("Overflow in `{lhs} {op} {rhs}`")]
OverflowInType { lhs: u32, op: crate::BinaryTypeOperator, rhs: u32, span: Span },
}

impl ResolverError {
Expand Down Expand Up @@ -491,6 +493,13 @@ impl<'a> From<&'a ResolverError> for Diagnostic {
*span,
)
}
ResolverError::OverflowInType { lhs, op, rhs, span } => {
Diagnostic::simple_error(
format!("Overflow in `{lhs} {op} {rhs}`"),
"Overflow here".to_string(),
*span,
)
}
}
}
}
1 change: 1 addition & 0 deletions compiler/noirc_frontend/src/hir/type_check/generics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ fn fmt_trait_generics(
write!(f, "{} = {}", named.name, named.typ)?;
}
}
write!(f, ">")?;
}
Ok(())
}
291 changes: 148 additions & 143 deletions compiler/noirc_frontend/src/hir_def/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ use super::{
traits::NamedType,
};

#[derive(PartialEq, Eq, Clone, Hash, Ord, PartialOrd)]
mod arithmetic;

#[derive(Eq, Clone, Ord, PartialOrd)]
pub enum Type {
/// A primitive Field type
FieldElement,
Expand Down Expand Up @@ -1657,132 +1659,6 @@ impl Type {
}
}

/// Try to canonicalize the representation of this type.
/// Currently the only type with a canonical representation is
/// `Type::Infix` where for each consecutive commutative operator
/// we sort the non-constant operands by `Type: Ord` and place all constant
/// operands at the end, constant folded.
///
/// For example:
/// - `canonicalize[((1 + N) + M) + 2] = (M + N) + 3`
/// - `canonicalize[A + 2 * B + 3 - 2] = A + (B * 2) + 3 - 2`
pub fn canonicalize(&self) -> Type {
match self.follow_bindings() {
Type::InfixExpr(lhs, op, rhs) => {
// evaluate_to_u32 also calls canonicalize so if we just called
// `self.evaluate_to_u32()` we'd get infinite recursion.
if let (Some(lhs), Some(rhs)) = (lhs.evaluate_to_u32(), rhs.evaluate_to_u32()) {
return Type::Constant(op.function(lhs, rhs));
}

let lhs = lhs.canonicalize();
let rhs = rhs.canonicalize();
if let Some(result) = Self::try_simplify_addition(&lhs, op, &rhs) {
return result;
}

if let Some(result) = Self::try_simplify_subtraction(&lhs, op, &rhs) {
return result;
}

if op.is_commutative() {
return Self::sort_commutative(&lhs, op, &rhs);
}

Type::InfixExpr(Box::new(lhs), op, Box::new(rhs))
}
other => other,
}
}

fn sort_commutative(lhs: &Type, op: BinaryTypeOperator, rhs: &Type) -> Type {
let mut queue = vec![lhs.clone(), rhs.clone()];

let mut sorted = BTreeSet::new();

let zero_value = if op == BinaryTypeOperator::Addition { 0 } else { 1 };
let mut constant = zero_value;

// Push each non-constant term to `sorted` to sort them. Recur on InfixExprs with the same operator.
while let Some(item) = queue.pop() {
match item.canonicalize() {
Type::InfixExpr(lhs, new_op, rhs) if new_op == op => {
queue.push(*lhs);
queue.push(*rhs);
}
Type::Constant(new_constant) => {
constant = op.function(constant, new_constant);
}
other => {
sorted.insert(other);
}
}
}

if let Some(first) = sorted.pop_first() {
let mut typ = first.clone();

for rhs in sorted {
typ = Type::InfixExpr(Box::new(typ), op, Box::new(rhs.clone()));
}

if constant != zero_value {
typ = Type::InfixExpr(Box::new(typ), op, Box::new(Type::Constant(constant)));
}

typ
} else {
// Every type must have been a constant
Type::Constant(constant)
}
}

/// Try to simplify an addition expression of `lhs + rhs`.
///
/// - Simplifies `(a - b) + b` to `a`.
fn try_simplify_addition(lhs: &Type, op: BinaryTypeOperator, rhs: &Type) -> Option<Type> {
use BinaryTypeOperator::*;
match lhs {
Type::InfixExpr(l_lhs, l_op, l_rhs) => {
if op == Addition && *l_op == Subtraction {
// TODO: Propagate type bindings. Can do in another PR, this one is large enough.
let unifies = l_rhs.try_unify(rhs, &mut TypeBindings::new());
if unifies.is_ok() {
return Some(l_lhs.as_ref().clone());
}
}
None
}
_ => None,
}
}

/// Try to simplify a subtraction expression of `lhs - rhs`.
///
/// - Simplifies `(a + C1) - C2` to `a + (C1 - C2)` if C1 and C2 are constants.
fn try_simplify_subtraction(lhs: &Type, op: BinaryTypeOperator, rhs: &Type) -> Option<Type> {
use BinaryTypeOperator::*;
match lhs {
Type::InfixExpr(l_lhs, l_op, l_rhs) => {
// Simplify `(N + 2) - 1`
if op == Subtraction && *l_op == Addition {
if let (Some(lhs_const), Some(rhs_const)) =
(l_rhs.evaluate_to_u32(), rhs.evaluate_to_u32())
{
if lhs_const > rhs_const {
let constant = Box::new(Type::Constant(lhs_const - rhs_const));
return Some(
Type::InfixExpr(l_lhs.clone(), *l_op, constant).canonicalize(),
);
}
}
}
None
}
_ => None,
}
}

/// Try to unify a type variable to `self`.
/// This is a helper function factored out from try_unify.
fn try_unify_to_type_variable(
Expand Down Expand Up @@ -1926,7 +1802,7 @@ impl Type {
Type::InfixExpr(lhs, op, rhs) => {
let lhs = lhs.evaluate_to_u32()?;
let rhs = rhs.evaluate_to_u32()?;
Some(op.function(lhs, rhs))
op.function(lhs, rhs)
}
_ => None,
}
Expand Down Expand Up @@ -2030,17 +1906,13 @@ impl Type {
Type::Forall(typevars, typ) => {
assert_eq!(types.len() + implicit_generic_count, typevars.len(), "Turbofish operator used with incorrect generic count which was not caught by name resolution");

let bindings =
(0..implicit_generic_count).map(|_| interner.next_type_variable()).chain(types);

let replacements = typevars
.iter()
.enumerate()
.map(|(i, var)| {
let binding = if i < implicit_generic_count {
interner.next_type_variable()
} else {
types[i - implicit_generic_count].clone()
};
(var.id(), (var.clone(), binding))
})
.zip(bindings)
.map(|(var, binding)| (var.id(), (var.clone(), binding)))
.collect();

let instantiated = typ.substitute(&replacements);
Expand Down Expand Up @@ -2457,13 +2329,13 @@ fn convert_array_expression_to_slice(

impl BinaryTypeOperator {
/// Perform the actual rust numeric operation associated with this operator
pub fn function(self, a: u32, b: u32) -> u32 {
pub fn function(self, a: u32, b: u32) -> Option<u32> {
match self {
BinaryTypeOperator::Addition => a.wrapping_add(b),
BinaryTypeOperator::Subtraction => a.wrapping_sub(b),
BinaryTypeOperator::Multiplication => a.wrapping_mul(b),
BinaryTypeOperator::Division => a.wrapping_div(b),
BinaryTypeOperator::Modulo => a.wrapping_rem(b),
BinaryTypeOperator::Addition => a.checked_add(b),
BinaryTypeOperator::Subtraction => a.checked_sub(b),
BinaryTypeOperator::Multiplication => a.checked_mul(b),
BinaryTypeOperator::Division => a.checked_div(b),
BinaryTypeOperator::Modulo => a.checked_rem(b),
}
}

Expand Down Expand Up @@ -2681,3 +2553,136 @@ impl std::fmt::Debug for StructType {
write!(f, "{}", self.name)
}
}

impl std::hash::Hash for Type {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
if let Some(variable) = self.get_inner_type_variable() {
if let TypeBinding::Bound(typ) = &*variable.borrow() {
typ.hash(state);
return;
}
}

if !matches!(self, Type::TypeVariable(..) | Type::NamedGeneric(..)) {
std::mem::discriminant(self).hash(state);
}

match self {
Type::FieldElement | Type::Bool | Type::Unit | Type::Error => (),
Type::Array(len, elem) => {
len.hash(state);
elem.hash(state);
}
Type::Slice(elem) => elem.hash(state),
Type::Integer(sign, bits) => {
sign.hash(state);
bits.hash(state);
}
Type::String(len) => len.hash(state),
Type::FmtString(len, env) => {
len.hash(state);
env.hash(state);
}
Type::Tuple(elems) => elems.hash(state),
Type::Struct(def, args) => {
def.hash(state);
args.hash(state);
}
Type::Alias(alias, args) => {
alias.hash(state);
args.hash(state);
}
Type::TypeVariable(var, _) | Type::NamedGeneric(var, ..) => var.hash(state),
Type::TraitAsType(trait_id, _, args) => {
trait_id.hash(state);
args.hash(state);
}
Type::Function(args, ret, env, is_unconstrained) => {
args.hash(state);
ret.hash(state);
env.hash(state);
is_unconstrained.hash(state);
}
Type::MutableReference(elem) => elem.hash(state),
Type::Forall(vars, typ) => {
vars.hash(state);
typ.hash(state);
}
Type::Constant(value) => value.hash(state),
Type::Quoted(typ) => typ.hash(state),
Type::InfixExpr(lhs, op, rhs) => {
lhs.hash(state);
op.hash(state);
rhs.hash(state);
}
}
}
}

impl PartialEq for Type {
fn eq(&self, other: &Self) -> bool {
if let Some(variable) = self.get_inner_type_variable() {
if let TypeBinding::Bound(typ) = &*variable.borrow() {
return typ == other;
}
}

if let Some(variable) = other.get_inner_type_variable() {
if let TypeBinding::Bound(typ) = &*variable.borrow() {
return self == typ;
}
}

use Type::*;
match (self, other) {
(FieldElement, FieldElement) | (Bool, Bool) | (Unit, Unit) | (Error, Error) => true,
(Array(lhs_len, lhs_elem), Array(rhs_len, rhs_elem)) => {
lhs_len == rhs_len && lhs_elem == rhs_elem
}
(Slice(lhs_elem), Slice(rhs_elem)) => lhs_elem == rhs_elem,
(Integer(lhs_sign, lhs_bits), Integer(rhs_sign, rhs_bits)) => {
lhs_sign == rhs_sign && lhs_bits == rhs_bits
}
(String(lhs_len), String(rhs_len)) => lhs_len == rhs_len,
(FmtString(lhs_len, lhs_env), FmtString(rhs_len, rhs_env)) => {
lhs_len == rhs_len && lhs_env == rhs_env
}
(Tuple(lhs_types), Tuple(rhs_types)) => lhs_types == rhs_types,
(Struct(lhs_struct, lhs_generics), Struct(rhs_struct, rhs_generics)) => {
lhs_struct == rhs_struct && lhs_generics == rhs_generics
}
(Alias(lhs_alias, lhs_generics), Alias(rhs_alias, rhs_generics)) => {
lhs_alias == rhs_alias && lhs_generics == rhs_generics
}
(TraitAsType(lhs_trait, _, lhs_generics), TraitAsType(rhs_trait, _, rhs_generics)) => {
lhs_trait == rhs_trait && lhs_generics == rhs_generics
}
(
Function(lhs_args, lhs_ret, lhs_env, lhs_unconstrained),
Function(rhs_args, rhs_ret, rhs_env, rhs_unconstrained),
) => {
let args_and_ret_eq = lhs_args == rhs_args && lhs_ret == rhs_ret;
args_and_ret_eq && lhs_env == rhs_env && lhs_unconstrained == rhs_unconstrained
}
(MutableReference(lhs_elem), MutableReference(rhs_elem)) => lhs_elem == rhs_elem,
(Forall(lhs_vars, lhs_type), Forall(rhs_vars, rhs_type)) => {
lhs_vars == rhs_vars && lhs_type == rhs_type
}
(Constant(lhs), Constant(rhs)) => lhs == rhs,
(Quoted(lhs), Quoted(rhs)) => lhs == rhs,
(InfixExpr(l_lhs, l_op, l_rhs), InfixExpr(r_lhs, r_op, r_rhs)) => {
l_lhs == r_lhs && l_op == r_op && l_rhs == r_rhs
}
// Special case: we consider unbound named generics and type variables to be equal to each
// other if their type variable ids match. This is important for some corner cases in
// monomorphization where we call `replace_named_generics_with_type_variables` but
// still want them to be equal for canonicalization checks in arithmetic generics.
// Without this we'd fail the `serialize` test.
(
NamedGeneric(lhs_var, _, _) | TypeVariable(lhs_var, _),
NamedGeneric(rhs_var, _, _) | TypeVariable(rhs_var, _),
) => lhs_var.id() == rhs_var.id(),
_ => false,
}
}
}
Loading

0 comments on commit c23463e

Please sign in to comment.