Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix function instantiation #187

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion examples/array.no
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,18 @@ fn init() -> [Field; size] { // array type can depends on constant var
return [4; size]; // array init with constant var
}

fn init_concrete() -> [Field; 3] {
// as this function won't be monomorphized,
// this is to test this array is constructed as Array instead of GenericSizedArray.
let mut arr = [0; 3];
for idx in 0..3 {
arr[idx] = idx + 1;
}
return arr;
}

fn main(pub public_input: [Field; 2]) {
let xx = [1, 2, 3];
let xx = init_concrete();

assert_eq(public_input[0], xx[0]);
assert_eq(public_input[1], xx[1]);
Expand Down
13 changes: 13 additions & 0 deletions examples/fixture/asm/kimchi/generic_nested_func.asm
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
@ noname.0.7.0

DoubleGeneric<1>
DoubleGeneric<1>
DoubleGeneric<1>
DoubleGeneric<1>
DoubleGeneric<1,-1>
DoubleGeneric<1,-1>
DoubleGeneric<1,-1>
(0,0) -> (4,0)
(1,0) -> (5,0)
(2,0) -> (6,0)
(3,0) -> (4,1) -> (5,1) -> (6,1)
13 changes: 13 additions & 0 deletions examples/fixture/asm/kimchi/generic_nested_method.asm
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
@ noname.0.7.0

DoubleGeneric<1>
DoubleGeneric<1>
DoubleGeneric<1>
DoubleGeneric<1>
DoubleGeneric<1,-1>
DoubleGeneric<1,-1>
DoubleGeneric<1,-1>
(0,0) -> (4,0)
(1,0) -> (5,0)
(2,0) -> (6,0)
(3,0) -> (4,1) -> (5,1) -> (6,1)
5 changes: 5 additions & 0 deletions examples/fixture/asm/r1cs/generic_nested_func.asm
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
@ noname.0.7.0

v_4 == (v_1) * (1)
v_4 == (v_2) * (1)
v_4 == (v_3) * (1)
5 changes: 5 additions & 0 deletions examples/fixture/asm/r1cs/generic_nested_method.asm
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
@ noname.0.7.0

v_4 == (v_1) * (1)
v_4 == (v_2) * (1)
v_4 == (v_3) * (1)
17 changes: 17 additions & 0 deletions examples/generic_nested_func.no
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
fn nested_func(const LEN: Field) -> [Field; LEN] {
return [0; LEN];
}

fn mod_arr(val: Field) -> [Field; 3] {
// this generic function should be instantiated
let mut result = nested_func(3);
for idx in 0..3 {
result[idx] = val;
}
return result;
}

fn main(pub val: Field) -> [Field; 3] {
let result = mod_arr(val);
return result;
}
21 changes: 21 additions & 0 deletions examples/generic_nested_method.no
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
struct Thing {
xx: Field,
}

fn Thing.nested_func(const LEN: Field) -> [Field; LEN] {
return [0; LEN];
}

fn Thing.mod_arr(self) -> [Field; 3] {
// this generic function should be instantiated
let mut result = self.nested_func(3);
for idx in 0..3 {
result[idx] = self.xx;
}
return result;
}

fn main(pub val: Field) -> [Field; 3] {
let thing = Thing {xx: val};
return thing.mod_arr();
}
17 changes: 9 additions & 8 deletions src/mast/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,15 @@ use super::MastCtx;
impl Expr {
/// Convert an expression to another expression, with the same span and a regenerated node id.
pub fn to_mast<B: Backend>(&self, ctx: &mut MastCtx<B>, kind: &ExprKind) -> Expr {
if !ctx.in_generic_func {
return self.clone();
}

Expr {
node_id: ctx.next_node_id(),
kind: kind.clone(),
..self.clone()
match ctx.generic_func_scope {
// not in any generic function scope
Some(0) => self.clone(),
// in a generic function scope
_ => Expr {
node_id: ctx.next_node_id(),
kind: kind.clone(),
..self.clone()
},
}
}
}
105 changes: 51 additions & 54 deletions src/mast/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use num_bigint::BigUint;
use num_traits::ToPrimitive;
use std::collections::HashMap;

use crate::{
Expand Down Expand Up @@ -150,7 +151,7 @@ where
B: Backend,
{
tast: TypeChecker<B>,
in_generic_func: bool,
generic_func_scope: Option<usize>,
// fully qualified function name
functions_to_delete: Vec<FullyQualified>,
// fully qualified struct name, method name
Expand All @@ -161,7 +162,7 @@ impl<B: Backend> MastCtx<B> {
pub fn new(tast: TypeChecker<B>) -> Self {
Self {
tast,
in_generic_func: false,
generic_func_scope: Some(0),
functions_to_delete: vec![],
methods_to_delete: vec![],
}
Expand All @@ -174,11 +175,11 @@ impl<B: Backend> MastCtx<B> {
}

pub fn start_monomorphize_func(&mut self) {
self.in_generic_func = true;
self.generic_func_scope = Some(self.generic_func_scope.unwrap() + 1);
}

pub fn finish_monomorphize_func(&mut self) {
self.in_generic_func = false;
self.generic_func_scope = Some(self.generic_func_scope.unwrap() - 1);
}

pub fn add_monomorphized_fn(
Expand All @@ -187,8 +188,11 @@ impl<B: Backend> MastCtx<B> {
new_qualified: FullyQualified,
fn_info: FnInfo<B>,
) {
self.tast.add_monomorphized_fn(new_qualified, fn_info);
self.functions_to_delete.push(old_qualified);
self.tast
.add_monomorphized_fn(new_qualified.clone(), fn_info);
if new_qualified != old_qualified {
self.functions_to_delete.push(old_qualified);
}
}

pub fn add_monomorphized_method(
Expand All @@ -200,8 +204,11 @@ impl<B: Backend> MastCtx<B> {
) {
self.tast
.add_monomorphized_method(struct_qualified.clone(), method_name, fn_info);
self.methods_to_delete
.push((struct_qualified, old_method_name.to_string()));

if method_name != old_method_name {
self.methods_to_delete
.push((struct_qualified, old_method_name.to_string()));
}
}

pub fn clear_generic_fns(&mut self) {
Expand Down Expand Up @@ -411,30 +418,23 @@ fn monomorphize_expr<B: Backend>(
.to_owned();

// monomorphize the function call
let (mexpr, typ) = if fn_info.sig().require_monomorphization() {
let (fn_info_mono, typ) = instantiate_fn_call(ctx, fn_info, &observed, expr.span)?;
let (fn_info_mono, typ) = instantiate_fn_call(ctx, fn_info, &observed, expr.span)?;

let args_mono = observed.clone().into_iter().map(|e| e.expr).collect();
let args_mono = observed.clone().into_iter().map(|e| e.expr).collect();

let fn_name_mono = &fn_info_mono.sig().name;
let mexpr = Expr {
kind: ExprKind::FnCall {
module: module.clone(),
fn_name: fn_name_mono.clone(),
args: args_mono,
},
..expr.clone()
};

let qualified = FullyQualified::new(module, &fn_name_mono.value);
ctx.add_monomorphized_fn(old_qualified, qualified, fn_info_mono);

(mexpr, typ)
} else {
// otherwise, reuse the expression node and the computed type
(expr.clone(), ctx.tast.expr_type(expr).cloned())
let fn_name_mono = &fn_info_mono.sig().name;
let mexpr = Expr {
kind: ExprKind::FnCall {
module: module.clone(),
fn_name: fn_name_mono.clone(),
args: args_mono,
},
..expr.clone()
};

let qualified = FullyQualified::new(module, &fn_name_mono.value);
ctx.add_monomorphized_fn(old_qualified, qualified, fn_info_mono);

// assume the function call won't return constant value
ExprMonoInfo::new(mexpr, typ, None)
}
Expand Down Expand Up @@ -491,29 +491,22 @@ fn monomorphize_expr<B: Backend>(
}

// monomorphize the function call
let (mexpr, typ) = if fn_info.sig().require_monomorphization() {
let (fn_info_mono, typ) = instantiate_fn_call(ctx, fn_info, &observed, expr.span)?;

let fn_name_mono = &fn_info_mono.sig().name;
let mexpr = Expr {
kind: ExprKind::MethodCall {
lhs: Box::new(lhs_mono.expr),
method_name: fn_name_mono.clone(),
args: args_mono,
},
..expr.clone()
};

let fn_def = fn_info_mono.native();
ctx.tast
.add_monomorphized_method(struct_qualified, &fn_name_mono.value, fn_def);
let (fn_info_mono, typ) = instantiate_fn_call(ctx, fn_info, &observed, expr.span)?;

(mexpr, typ)
} else {
// otherwise, reuse the expression node and the computed type
(expr.clone(), ctx.tast.expr_type(expr).cloned())
let fn_name_mono = &fn_info_mono.sig().name;
let mexpr = Expr {
kind: ExprKind::MethodCall {
lhs: Box::new(lhs_mono.expr),
method_name: fn_name_mono.clone(),
args: args_mono,
},
..expr.clone()
};

let fn_def = fn_info_mono.native();
ctx.tast
.add_monomorphized_method(struct_qualified, &fn_name_mono.value, fn_def);

// assume the function call won't return constant value
ExprMonoInfo::new(mexpr, typ, None)
}
Expand Down Expand Up @@ -566,8 +559,12 @@ fn monomorphize_expr<B: Backend>(
| Op2::BoolOr => lhs_mono.typ,
};

let cst = match (lhs_mono.constant, rhs_mono.constant) {
(Some(lhs), Some(rhs)) => match op {
let ExprMonoInfo { expr: lhs_expr, .. } = lhs_mono;
let ExprMonoInfo { expr: rhs_expr, .. } = rhs_mono;

// fold constants
let cst = match (&lhs_expr.kind, &rhs_expr.kind) {
(ExprKind::BigUInt(lhs), ExprKind::BigUInt(rhs)) => match op {
Op2::Addition => Some(lhs + rhs),
Op2::Subtraction => Some(lhs - rhs),
Op2::Multiplication => Some(lhs * rhs),
Expand All @@ -579,18 +576,18 @@ fn monomorphize_expr<B: Backend>(

match cst {
Some(v) => {
let mexpr = expr.to_mast(ctx, &ExprKind::BigUInt(BigUint::from(v)));
let mexpr = expr.to_mast(ctx, &ExprKind::BigUInt(v.clone()));

ExprMonoInfo::new(mexpr, typ, Some(v))
ExprMonoInfo::new(mexpr, typ, v.to_u32())
}
None => {
let mexpr = expr.to_mast(
ctx,
&ExprKind::BinaryOp {
op: op.clone(),
protected: *protected,
lhs: Box::new(lhs_mono.expr),
rhs: Box::new(rhs_mono.expr),
lhs: Box::new(lhs_expr),
rhs: Box::new(rhs_expr),
},
);

Expand Down
36 changes: 36 additions & 0 deletions src/tests/examples.rs
Original file line number Diff line number Diff line change
Expand Up @@ -664,3 +664,39 @@ fn test_generic_iterator(#[case] backend: BackendKind) -> miette::Result<()> {

Ok(())
}

#[rstest]
#[case::kimchi_vesta(BackendKind::KimchiVesta(KimchiVesta::new(false)))]
#[case::r1cs(BackendKind::R1csBls12_381(R1CS::new()))]
fn test_generic_nested_func(#[case] backend: BackendKind) -> miette::Result<()> {
let public_inputs = r#"{"val":"1"}"#;
let private_inputs = r#"{}"#;

test_file(
"generic_nested_func",
public_inputs,
private_inputs,
vec!["1", "1", "1"],
backend,
)?;

Ok(())
}

#[rstest]
#[case::kimchi_vesta(BackendKind::KimchiVesta(KimchiVesta::new(false)))]
#[case::r1cs(BackendKind::R1csBls12_381(R1CS::new()))]
fn test_generic_nested_method(#[case] backend: BackendKind) -> miette::Result<()> {
let public_inputs = r#"{"val":"1"}"#;
let private_inputs = r#"{}"#;

test_file(
"generic_nested_method",
public_inputs,
private_inputs,
vec!["1", "1", "1"],
backend,
)?;

Ok(())
}
17 changes: 12 additions & 5 deletions src/type_checker/checker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -538,11 +538,18 @@ impl<B: Backend> TypeChecker<B> {
.expect("expected a typed size");

if is_numeric(&size_node.typ) {
// use generic array as the size node might include generic parameters or constant vars
let res = ExprTyInfo::new_anon(TyKind::GenericSizedArray(
Box::new(item_node.typ),
Symbolic::parse(size)?,
));
let sym = Symbolic::parse(size)?;
let res = if let Symbolic::Concrete(size) = sym {
// if sym is a concrete variant, then just return concrete array type
ExprTyInfo::new_anon(TyKind::Array(Box::new(item_node.typ), size))
} else {
// use generic array as the size node might include generic parameters or constant vars
ExprTyInfo::new_anon(TyKind::GenericSizedArray(
Box::new(item_node.typ),
sym,
))
};

Some(res)
} else {
return Err(self.error(ErrorKind::InvalidArraySize, expr.span));
Expand Down
Loading