From 9a43f85a055f23e5746e6836fe11990f4c87bbdc Mon Sep 17 00:00:00 2001 From: jfecher Date: Thu, 30 Mar 2023 14:38:49 +0100 Subject: [PATCH] feat: Implement `std::unsafe::zeroed` (#1048) * Add zeroed builtin * Implement std::unsafe::zeroed * Fix merge conflict --- .../src/monomorphization/mod.rs | 128 +++++++++++++----- noir_stdlib/src/lib.nr | 1 + noir_stdlib/src/unsafe.nr | 5 + 3 files changed, 100 insertions(+), 34 deletions(-) create mode 100644 noir_stdlib/src/unsafe.nr diff --git a/crates/noirc_frontend/src/monomorphization/mod.rs b/crates/noirc_frontend/src/monomorphization/mod.rs index cb06be7c1c6..26009c0227e 100644 --- a/crates/noirc_frontend/src/monomorphization/mod.rs +++ b/crates/noirc_frontend/src/monomorphization/mod.rs @@ -700,12 +700,8 @@ impl<'interner> Monomorphizer<'interner> { let return_type = Self::convert_type(&return_type); let location = call.location; - self.try_evaluate_call(&func, &call.arguments).unwrap_or(ast::Expression::Call(ast::Call { - func, - arguments, - return_type, - location, - })) + self.try_evaluate_call(&func, &call.arguments, &return_type) + .unwrap_or(ast::Expression::Call(ast::Call { func, arguments, return_type, location })) } /// Try to evaluate certain builtin functions (currently only 'array_len' and field modulus methods) @@ -715,50 +711,47 @@ impl<'interner> Monomorphizer<'interner> { /// To fix this we need to evaluate on the identifier instead, which /// requires us to evaluate to a Lambda value which isn't in noir yet. fn try_evaluate_call( - &self, + &mut self, func: &ast::Expression, arguments: &[node_interner::ExprId], + result_type: &ast::Type, ) -> Option { - match func { - ast::Expression::Ident(ident) => match &ident.definition { - Definition::Builtin(opcode) if opcode == "array_len" => { + if let ast::Expression::Ident(ident) = func { + if let Definition::Builtin(opcode) = &ident.definition { + if opcode == "array_len" { let typ = self.interner.id_type(arguments[0]); let len = typ.evaluate_to_u64().unwrap(); - Some(ast::Expression::Literal(ast::Literal::Integer( + return Some(ast::Expression::Literal(ast::Literal::Integer( (len as u128).into(), ast::Type::Field, - ))) - } - Definition::Builtin(opcode) if opcode == "modulus_num_bits" => { - Some(ast::Expression::Literal(ast::Literal::Integer( + ))); + } else if opcode == "modulus_num_bits" { + return Some(ast::Expression::Literal(ast::Literal::Integer( (FieldElement::max_num_bits() as u128).into(), ast::Type::Field, - ))) + ))); + } else if opcode == "zeroed" { + return Some(self.zeroed_value_of_type(result_type)); } - Definition::Builtin(opcode) if opcode == "modulus_le_bits" => { - let modulus = FieldElement::modulus(); + + let modulus = FieldElement::modulus(); + + if opcode == "modulus_le_bits" { let bits = modulus.to_radix_le(2); - Some(self.modulus_array_literal(bits, 1)) - } - Definition::Builtin(opcode) if opcode == "modulus_be_bits" => { - let modulus = FieldElement::modulus(); + return Some(self.modulus_array_literal(bits, 1)); + } else if opcode == "modulus_be_bits" { let bits = modulus.to_radix_be(2); - Some(self.modulus_array_literal(bits, 1)) - } - Definition::Builtin(opcode) if opcode == "modulus_be_bytes" => { - let modulus = FieldElement::modulus(); + return Some(self.modulus_array_literal(bits, 1)); + } else if opcode == "modulus_be_bytes" { let bytes = modulus.to_bytes_be(); - Some(self.modulus_array_literal(bytes, 8)) - } - Definition::Builtin(opcode) if opcode == "modulus_le_bytes" => { - let modulus = FieldElement::modulus(); + return Some(self.modulus_array_literal(bytes, 8)); + } else if opcode == "modulus_le_bytes" { let bytes = modulus.to_bytes_le(); - Some(self.modulus_array_literal(bytes, 8)) + return Some(self.modulus_array_literal(bytes, 8)); } - _ => None, - }, - _ => None, + } } + None } fn modulus_array_literal(&self, bytes: Vec, arr_elem_bits: u32) -> ast::Expression { @@ -919,6 +912,73 @@ impl<'interner> Monomorphizer<'interner> { typ, }) } + + /// Implements std::unsafe::zeroed by returning an appropriate zeroed + /// ast literal or collection node for the given type. Note that for functions + /// there is no obvious zeroed value so this should be considered unsafe to use. + fn zeroed_value_of_type(&mut self, typ: &ast::Type) -> ast::Expression { + match typ { + ast::Type::Field | ast::Type::Integer(..) => { + ast::Expression::Literal(ast::Literal::Integer(0_u128.into(), typ.clone())) + } + ast::Type::Bool => ast::Expression::Literal(ast::Literal::Bool(false)), + // There is no unit literal currently. Replace it with 'false' since it should be ignored + // anyway. + ast::Type::Unit => ast::Expression::Literal(ast::Literal::Bool(false)), + ast::Type::Array(length, element_type) => { + let element = self.zeroed_value_of_type(element_type.as_ref()); + ast::Expression::Literal(ast::Literal::Array(ast::ArrayLiteral { + contents: vec![element; *length as usize], + element_type: element_type.as_ref().clone(), + })) + } + ast::Type::String(length) => { + ast::Expression::Literal(ast::Literal::Str("\0".repeat(*length as usize))) + } + ast::Type::Tuple(fields) => { + ast::Expression::Tuple(vecmap(fields, |field| self.zeroed_value_of_type(field))) + } + ast::Type::Function(parameter_types, ret_type) => { + self.create_zeroed_function(parameter_types, ret_type) + } + } + } + + // Creating a zeroed function value is almost always an error if it is used later, + // Hence why std::unsafe::zeroed is unsafe. + // + // To avoid confusing later passes, we arbitrarily choose to construct a function + // that satisfies the input type by discarding all its parameters and returning a + // zeroed value of the result type. + fn create_zeroed_function( + &mut self, + parameter_types: &[ast::Type], + ret_type: &ast::Type, + ) -> ast::Expression { + let lambda_name = "zeroed_lambda"; + + let parameters = vecmap(parameter_types, |parameter_type| { + (self.next_local_id(), false, "_".into(), parameter_type.clone()) + }); + + let body = self.zeroed_value_of_type(ret_type); + + let id = self.next_function_id(); + let return_type = ret_type.clone(); + let name = lambda_name.to_owned(); + + let unconstrained = false; + let function = ast::Function { id, name, parameters, body, return_type, unconstrained }; + self.push_function(id, function); + + ast::Expression::Ident(ast::Ident { + definition: Definition::Function(id), + mutable: false, + location: None, + name: lambda_name.to_owned(), + typ: ast::Type::Function(parameter_types.to_owned(), Box::new(ret_type.clone())), + }) + } } fn unwrap_tuple_type(typ: &HirType) -> Vec { diff --git a/noir_stdlib/src/lib.nr b/noir_stdlib/src/lib.nr index abdd56c4975..16383c2c704 100644 --- a/noir_stdlib/src/lib.nr +++ b/noir_stdlib/src/lib.nr @@ -8,6 +8,7 @@ mod sha256; mod sha512; mod field; mod ec; +mod unsafe; #[builtin(println)] fn println(_input : T) {} diff --git a/noir_stdlib/src/unsafe.nr b/noir_stdlib/src/unsafe.nr new file mode 100644 index 00000000000..a28549d5011 --- /dev/null +++ b/noir_stdlib/src/unsafe.nr @@ -0,0 +1,5 @@ +/// For any type, return an instance of that type by initializing +/// all of its fields to 0. This is considered to be unsafe since there +/// is no guarantee that all zeroes is a valid bit pattern for every type. +#[builtin(zeroed)] +fn zeroed() -> T {}