From 5261339a3b4dda736b0c74523b568d2d9753f0b3 Mon Sep 17 00:00:00 2001 From: kata Date: Fri, 20 Sep 2024 14:12:12 +0800 Subject: [PATCH 1/7] fix: unnecessary GenericSizedArray type --- examples/array.no | 12 +++++++++++- src/type_checker/checker.rs | 17 ++++++++++++----- 2 files changed, 23 insertions(+), 6 deletions(-) diff --git a/examples/array.no b/examples/array.no index adb3f913f..84cca9b37 100644 --- a/examples/array.no +++ b/examples/array.no @@ -4,8 +4,18 @@ fn init() -> [Field; size] { // array type can depends on constant var return [4; size]; // array init with constant var } +fn init_concrete() -> [Field; 3] { + // as this function won't be monomorphized, + // this is to test this array is constructed as Array instead of GenericSizedArray. + let mut arr = [0; 3]; + for idx in 0..3 { + arr[idx] = idx + 1; + } + return arr; +} + fn main(pub public_input: [Field; 2]) { - let xx = [1, 2, 3]; + let xx = init_concrete(); assert_eq(public_input[0], xx[0]); assert_eq(public_input[1], xx[1]); diff --git a/src/type_checker/checker.rs b/src/type_checker/checker.rs index be2393f86..76accc207 100644 --- a/src/type_checker/checker.rs +++ b/src/type_checker/checker.rs @@ -538,11 +538,18 @@ impl TypeChecker { .expect("expected a typed size"); if is_numeric(&size_node.typ) { - // use generic array as the size node might include generic parameters or constant vars - let res = ExprTyInfo::new_anon(TyKind::GenericSizedArray( - Box::new(item_node.typ), - Symbolic::parse(size)?, - )); + let sym = Symbolic::parse(size)?; + let res = if let Symbolic::Concrete(size) = sym { + // if sym is a concrete variant, then just return concrete array type + ExprTyInfo::new_anon(TyKind::Array(Box::new(item_node.typ), size)) + } else { + // use generic array as the size node might include generic parameters or constant vars + ExprTyInfo::new_anon(TyKind::GenericSizedArray( + Box::new(item_node.typ), + sym, + )) + }; + Some(res) } else { return Err(self.error(ErrorKind::InvalidArraySize, expr.span)); From 1f7f80ce00ea887dc362b530f82bed50c3aa63d1 Mon Sep 17 00:00:00 2001 From: kata Date: Fri, 20 Sep 2024 13:47:42 +0800 Subject: [PATCH 2/7] fix: record generic function scopes to correctly decide if it needs to create a mast node --- src/mast/ast.rs | 17 +++++++++-------- src/mast/mod.rs | 8 ++++---- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/src/mast/ast.rs b/src/mast/ast.rs index a76c4dbc8..3dbde0da2 100644 --- a/src/mast/ast.rs +++ b/src/mast/ast.rs @@ -8,14 +8,15 @@ use super::MastCtx; impl Expr { /// Convert an expression to another expression, with the same span and a regenerated node id. pub fn to_mast(&self, ctx: &mut MastCtx, kind: &ExprKind) -> Expr { - if !ctx.in_generic_func { - return self.clone(); - } - - Expr { - node_id: ctx.next_node_id(), - kind: kind.clone(), - ..self.clone() + match ctx.generic_func_scope { + // not in any generic function scope + Some(0) => self.clone(), + // in a generic function scope + _ => Expr { + node_id: ctx.next_node_id(), + kind: kind.clone(), + ..self.clone() + }, } } } diff --git a/src/mast/mod.rs b/src/mast/mod.rs index c3accdfbc..2df9519fb 100644 --- a/src/mast/mod.rs +++ b/src/mast/mod.rs @@ -150,7 +150,7 @@ where B: Backend, { tast: TypeChecker, - in_generic_func: bool, + generic_func_scope: Option, // fully qualified function name functions_to_delete: Vec, // fully qualified struct name, method name @@ -161,7 +161,7 @@ impl MastCtx { pub fn new(tast: TypeChecker) -> Self { Self { tast, - in_generic_func: false, + generic_func_scope: Some(0), functions_to_delete: vec![], methods_to_delete: vec![], } @@ -174,11 +174,11 @@ impl MastCtx { } pub fn start_monomorphize_func(&mut self) { - self.in_generic_func = true; + self.generic_func_scope = Some(self.generic_func_scope.unwrap() + 1); } pub fn finish_monomorphize_func(&mut self) { - self.in_generic_func = false; + self.generic_func_scope = Some(self.generic_func_scope.unwrap() - 1); } pub fn add_monomorphized_fn( From 1c82c3536c378b7f952bd871e43a99e15f6abdcf Mon Sep 17 00:00:00 2001 From: kata Date: Sat, 21 Sep 2024 17:10:28 +0800 Subject: [PATCH 3/7] instantiate all functions --- src/mast/mod.rs | 67 ++++++++++++++++++++----------------------------- 1 file changed, 27 insertions(+), 40 deletions(-) diff --git a/src/mast/mod.rs b/src/mast/mod.rs index 2df9519fb..d323148a4 100644 --- a/src/mast/mod.rs +++ b/src/mast/mod.rs @@ -411,30 +411,23 @@ fn monomorphize_expr( .to_owned(); // monomorphize the function call - let (mexpr, typ) = if fn_info.sig().require_monomorphization() { - let (fn_info_mono, typ) = instantiate_fn_call(ctx, fn_info, &observed, expr.span)?; + let (fn_info_mono, typ) = instantiate_fn_call(ctx, fn_info, &observed, expr.span)?; - let args_mono = observed.clone().into_iter().map(|e| e.expr).collect(); + let args_mono = observed.clone().into_iter().map(|e| e.expr).collect(); - let fn_name_mono = &fn_info_mono.sig().name; - let mexpr = Expr { - kind: ExprKind::FnCall { - module: module.clone(), - fn_name: fn_name_mono.clone(), - args: args_mono, - }, - ..expr.clone() - }; - - let qualified = FullyQualified::new(module, &fn_name_mono.value); - ctx.add_monomorphized_fn(old_qualified, qualified, fn_info_mono); - - (mexpr, typ) - } else { - // otherwise, reuse the expression node and the computed type - (expr.clone(), ctx.tast.expr_type(expr).cloned()) + let fn_name_mono = &fn_info_mono.sig().name; + let mexpr = Expr { + kind: ExprKind::FnCall { + module: module.clone(), + fn_name: fn_name_mono.clone(), + args: args_mono, + }, + ..expr.clone() }; + let qualified = FullyQualified::new(module, &fn_name_mono.value); + ctx.add_monomorphized_fn(old_qualified, qualified, fn_info_mono); + // assume the function call won't return constant value ExprMonoInfo::new(mexpr, typ, None) } @@ -491,29 +484,23 @@ fn monomorphize_expr( } // monomorphize the function call - let (mexpr, typ) = if fn_info.sig().require_monomorphization() { - let (fn_info_mono, typ) = instantiate_fn_call(ctx, fn_info, &observed, expr.span)?; - - let fn_name_mono = &fn_info_mono.sig().name; - let mexpr = Expr { - kind: ExprKind::MethodCall { - lhs: Box::new(lhs_mono.expr), - method_name: fn_name_mono.clone(), - args: args_mono, - }, - ..expr.clone() - }; + let (fn_info_mono, typ) = instantiate_fn_call(ctx, fn_info, &observed, expr.span)?; - let fn_def = fn_info_mono.native(); - ctx.tast - .add_monomorphized_method(struct_qualified, &fn_name_mono.value, fn_def); - - (mexpr, typ) - } else { - // otherwise, reuse the expression node and the computed type - (expr.clone(), ctx.tast.expr_type(expr).cloned()) + let fn_name_mono = &fn_info_mono.sig().name; + let mexpr = Expr { + kind: ExprKind::MethodCall { + lhs: Box::new(lhs_mono.expr), + method_name: fn_name_mono.clone(), + args: args_mono, + }, + ..expr.clone() }; + let fn_def = fn_info_mono.native(); + ctx.tast + .add_monomorphized_method(struct_qualified, &fn_name_mono.value, fn_def); + + // assume the function call won't return constant value ExprMonoInfo::new(mexpr, typ, None) } From 39157c324b1b2a04e327117c4ae17e95d94b302f Mon Sep 17 00:00:00 2001 From: kata Date: Sat, 21 Sep 2024 17:12:50 +0800 Subject: [PATCH 4/7] fix: const folding based on expr node --- src/mast/mod.rs | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/src/mast/mod.rs b/src/mast/mod.rs index d323148a4..c97dbe857 100644 --- a/src/mast/mod.rs +++ b/src/mast/mod.rs @@ -1,4 +1,5 @@ use num_bigint::BigUint; +use num_traits::ToPrimitive; use std::collections::HashMap; use crate::{ @@ -553,8 +554,12 @@ fn monomorphize_expr( | Op2::BoolOr => lhs_mono.typ, }; - let cst = match (lhs_mono.constant, rhs_mono.constant) { - (Some(lhs), Some(rhs)) => match op { + let ExprMonoInfo { expr: lhs_expr, .. } = lhs_mono; + let ExprMonoInfo { expr: rhs_expr, .. } = rhs_mono; + + // fold constants + let cst = match (&lhs_expr.kind, &rhs_expr.kind) { + (ExprKind::BigUInt(lhs), ExprKind::BigUInt(rhs)) => match op { Op2::Addition => Some(lhs + rhs), Op2::Subtraction => Some(lhs - rhs), Op2::Multiplication => Some(lhs * rhs), @@ -566,9 +571,9 @@ fn monomorphize_expr( match cst { Some(v) => { - let mexpr = expr.to_mast(ctx, &ExprKind::BigUInt(BigUint::from(v))); + let mexpr = expr.to_mast(ctx, &ExprKind::BigUInt(v.clone())); - ExprMonoInfo::new(mexpr, typ, Some(v)) + ExprMonoInfo::new(mexpr, typ, v.to_u32()) } None => { let mexpr = expr.to_mast( @@ -576,8 +581,8 @@ fn monomorphize_expr( &ExprKind::BinaryOp { op: op.clone(), protected: *protected, - lhs: Box::new(lhs_mono.expr), - rhs: Box::new(rhs_mono.expr), + lhs: Box::new(lhs_expr), + rhs: Box::new(rhs_expr), }, ); From bff8d27023395f0f58ea6d636e8a531c9b56c4ba Mon Sep 17 00:00:00 2001 From: kata Date: Sat, 21 Sep 2024 17:13:37 +0800 Subject: [PATCH 5/7] fix: correctly remove unused functions --- src/mast/mod.rs | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/src/mast/mod.rs b/src/mast/mod.rs index c97dbe857..bde7cee70 100644 --- a/src/mast/mod.rs +++ b/src/mast/mod.rs @@ -188,8 +188,11 @@ impl MastCtx { new_qualified: FullyQualified, fn_info: FnInfo, ) { - self.tast.add_monomorphized_fn(new_qualified, fn_info); - self.functions_to_delete.push(old_qualified); + self.tast + .add_monomorphized_fn(new_qualified.clone(), fn_info); + if new_qualified != old_qualified { + self.functions_to_delete.push(old_qualified); + } } pub fn add_monomorphized_method( @@ -201,8 +204,11 @@ impl MastCtx { ) { self.tast .add_monomorphized_method(struct_qualified.clone(), method_name, fn_info); - self.methods_to_delete - .push((struct_qualified, old_method_name.to_string())); + + if method_name != old_method_name { + self.methods_to_delete + .push((struct_qualified, old_method_name.to_string())); + } } pub fn clear_generic_fns(&mut self) { From 1b027e8547d323b84e9c0eb5394830746ea9ea64 Mon Sep 17 00:00:00 2001 From: kata Date: Sat, 21 Sep 2024 17:14:10 +0800 Subject: [PATCH 6/7] add tests --- .../asm/kimchi/generic_nested_func.asm | 13 +++++++ .../asm/kimchi/generic_nested_method.asm | 13 +++++++ .../fixture/asm/r1cs/generic_nested_func.asm | 5 +++ .../asm/r1cs/generic_nested_method.asm | 5 +++ examples/generic_nested_func.no | 17 +++++++++ examples/generic_nested_method.no | 21 +++++++++++ src/tests/examples.rs | 36 +++++++++++++++++++ 7 files changed, 110 insertions(+) create mode 100644 examples/fixture/asm/kimchi/generic_nested_func.asm create mode 100644 examples/fixture/asm/kimchi/generic_nested_method.asm create mode 100644 examples/fixture/asm/r1cs/generic_nested_func.asm create mode 100644 examples/fixture/asm/r1cs/generic_nested_method.asm create mode 100644 examples/generic_nested_func.no create mode 100644 examples/generic_nested_method.no diff --git a/examples/fixture/asm/kimchi/generic_nested_func.asm b/examples/fixture/asm/kimchi/generic_nested_func.asm new file mode 100644 index 000000000..ea83d5c55 --- /dev/null +++ b/examples/fixture/asm/kimchi/generic_nested_func.asm @@ -0,0 +1,13 @@ +@ noname.0.7.0 + +DoubleGeneric<1> +DoubleGeneric<1> +DoubleGeneric<1> +DoubleGeneric<1> +DoubleGeneric<1,-1> +DoubleGeneric<1,-1> +DoubleGeneric<1,-1> +(0,0) -> (4,0) +(1,0) -> (5,0) +(2,0) -> (6,0) +(3,0) -> (4,1) -> (5,1) -> (6,1) diff --git a/examples/fixture/asm/kimchi/generic_nested_method.asm b/examples/fixture/asm/kimchi/generic_nested_method.asm new file mode 100644 index 000000000..ea83d5c55 --- /dev/null +++ b/examples/fixture/asm/kimchi/generic_nested_method.asm @@ -0,0 +1,13 @@ +@ noname.0.7.0 + +DoubleGeneric<1> +DoubleGeneric<1> +DoubleGeneric<1> +DoubleGeneric<1> +DoubleGeneric<1,-1> +DoubleGeneric<1,-1> +DoubleGeneric<1,-1> +(0,0) -> (4,0) +(1,0) -> (5,0) +(2,0) -> (6,0) +(3,0) -> (4,1) -> (5,1) -> (6,1) diff --git a/examples/fixture/asm/r1cs/generic_nested_func.asm b/examples/fixture/asm/r1cs/generic_nested_func.asm new file mode 100644 index 000000000..4d0e4621a --- /dev/null +++ b/examples/fixture/asm/r1cs/generic_nested_func.asm @@ -0,0 +1,5 @@ +@ noname.0.7.0 + +v_4 == (v_1) * (1) +v_4 == (v_2) * (1) +v_4 == (v_3) * (1) diff --git a/examples/fixture/asm/r1cs/generic_nested_method.asm b/examples/fixture/asm/r1cs/generic_nested_method.asm new file mode 100644 index 000000000..4d0e4621a --- /dev/null +++ b/examples/fixture/asm/r1cs/generic_nested_method.asm @@ -0,0 +1,5 @@ +@ noname.0.7.0 + +v_4 == (v_1) * (1) +v_4 == (v_2) * (1) +v_4 == (v_3) * (1) diff --git a/examples/generic_nested_func.no b/examples/generic_nested_func.no new file mode 100644 index 000000000..e12a84430 --- /dev/null +++ b/examples/generic_nested_func.no @@ -0,0 +1,17 @@ +fn nested_func(const LEN: Field) -> [Field; LEN] { + return [0; LEN]; +} + +fn mod_arr(val: Field) -> [Field; 3] { + // this generic function should be instantiated + let mut result = nested_func(3); + for idx in 0..3 { + result[idx] = val; + } + return result; +} + +fn main(pub val: Field) -> [Field; 3] { + let result = mod_arr(val); + return result; +} \ No newline at end of file diff --git a/examples/generic_nested_method.no b/examples/generic_nested_method.no new file mode 100644 index 000000000..b5d087b63 --- /dev/null +++ b/examples/generic_nested_method.no @@ -0,0 +1,21 @@ +struct Thing { + xx: Field, +} + +fn Thing.nested_func(const LEN: Field) -> [Field; LEN] { + return [0; LEN]; +} + +fn Thing.mod_arr(self) -> [Field; 3] { + // this generic function should be instantiated + let mut result = self.nested_func(3); + for idx in 0..3 { + result[idx] = self.xx; + } + return result; +} + +fn main(pub val: Field) -> [Field; 3] { + let thing = Thing {xx: val}; + return thing.mod_arr(); +} \ No newline at end of file diff --git a/src/tests/examples.rs b/src/tests/examples.rs index 34065e460..3aade9de0 100644 --- a/src/tests/examples.rs +++ b/src/tests/examples.rs @@ -664,3 +664,39 @@ fn test_generic_iterator(#[case] backend: BackendKind) -> miette::Result<()> { Ok(()) } + +#[rstest] +#[case::kimchi_vesta(BackendKind::KimchiVesta(KimchiVesta::new(false)))] +#[case::r1cs(BackendKind::R1csBls12_381(R1CS::new()))] +fn test_generic_nested_func(#[case] backend: BackendKind) -> miette::Result<()> { + let public_inputs = r#"{"val":"1"}"#; + let private_inputs = r#"{}"#; + + test_file( + "generic_nested_func", + public_inputs, + private_inputs, + vec!["1", "1", "1"], + backend, + )?; + + Ok(()) +} + +#[rstest] +#[case::kimchi_vesta(BackendKind::KimchiVesta(KimchiVesta::new(false)))] +#[case::r1cs(BackendKind::R1csBls12_381(R1CS::new()))] +fn test_generic_nested_method(#[case] backend: BackendKind) -> miette::Result<()> { + let public_inputs = r#"{"val":"1"}"#; + let private_inputs = r#"{}"#; + + test_file( + "generic_nested_method", + public_inputs, + private_inputs, + vec!["1", "1", "1"], + backend, + )?; + + Ok(()) +} From 8fa2dba05b5a3bbb913e463df5735317322df998 Mon Sep 17 00:00:00 2001 From: kata Date: Sat, 21 Sep 2024 17:18:04 +0800 Subject: [PATCH 7/7] fmt --- src/mast/mod.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/mast/mod.rs b/src/mast/mod.rs index bde7cee70..709086106 100644 --- a/src/mast/mod.rs +++ b/src/mast/mod.rs @@ -507,7 +507,6 @@ fn monomorphize_expr( ctx.tast .add_monomorphized_method(struct_qualified, &fn_name_mono.value, fn_def); - // assume the function call won't return constant value ExprMonoInfo::new(mexpr, typ, None) }