diff --git a/bin/main.ml b/bin/main.ml index a69c7ee..08fa5db 100644 --- a/bin/main.ml +++ b/bin/main.ml @@ -131,6 +131,7 @@ Supported options:|} Eurydice.Logging.log "Phase2.5" "%a" pfiles files; let files = Krml.Inlining.cross_call_analysis files in let files = Krml.Simplify.remove_unused files in + let files = Eurydice.Cleanup2.remove_array_from_fn#visit_files () files in (* Macros stemming from globals *) let files, macros = Eurydice.Cleanup2.build_macros files in diff --git a/lib/AstOfLlbc.ml b/lib/AstOfLlbc.ml index f20c3de..f096db2 100644 --- a/lib/AstOfLlbc.ml +++ b/lib/AstOfLlbc.ml @@ -294,11 +294,10 @@ let rec typ_of_ty (env: env) (ty: Charon.Types.ty): K.typ = | TAdt (TTuple, { types = args; const_generics; _ }) -> assert (const_generics = []); - if args = [] then - TUnit - else begin - assert (List.length args > 1); - TTuple (List.map (typ_of_ty env) args) + begin match args with + | [] -> TUnit + | [ t ] -> typ_of_ty env t (* happens with closures *) + | _ -> TTuple (List.map (typ_of_ty env) args) end | TAdt (TAssumed TArray, { types = [ t ]; const_generics = [ cg ]; _ }) -> @@ -587,6 +586,114 @@ let maybe_addrof (env: env) (ty: C.ty) (e: K.expr) = | _ -> K.(with_type (TBuf (e.typ, false)) (EAddrOf e)) +type lookup_result = { + name: K.lident; + n_type_args: int; (* just for a sanity check *) + cg_types: K.typ list; + arg_types: K.typ list; + ret_type: K.typ; + is_assumed: bool; +} + +let lookup_fun (env: env) (f: C.fn_ptr): lookup_result = + let open RustNames in + let matches p = Charon.NameMatcher.match_fn_ptr env.name_ctx RustNames.config p f in + let builtin b = + let { Builtin.name; typ; n_type_args; cg_args; _ } = b in + let ret_type, arg_types = Krml.Helpers.flatten_arrow typ in + { name; n_type_args; arg_types; ret_type; cg_types = cg_args; is_assumed = true } + in + match List.find_opt (fun (p, _) -> matches p) known_builtins with + | Some (_, b) -> + builtin b + | None -> + let regular f = + let { C.name; signature = { generics = { types = type_params; const_generics; _ }; inputs; output; _ }; _ } = env.get_nth_function f in + L.log "Calls" "--> name: %s" (string_of_name env name); + L.log "Calls" "--> args: %s, ret: %s" + (String.concat " ++ " (List.map (Charon.PrintTypes.ty_to_string env.format_env) inputs)) + (Charon.PrintTypes.ty_to_string env.format_env output); + let env = push_cg_binders env const_generics in + let env = push_type_binders env type_params in + { + name = lid_of_name env name; + n_type_args = List.length type_params; + cg_types = List.map (fun (v: C.const_generic_var) -> typ_of_literal_ty env v.ty) const_generics; + arg_types = List.map (typ_of_ty env) inputs; + ret_type = typ_of_ty env output; + is_assumed = false + } + in + match f.func with + | FunId (FRegular f) -> + regular f + + | FunId (FAssumed f) -> + Krml.Warn.fatal_error "unknown assumed function: %s" (C.show_assumed_fun_id f) + + | TraitMethod (trait_ref, method_name, _trait_opaque_signature) -> + match trait_ref.trait_id with + | TraitImpl id -> + let trait = env.get_nth_trait_impl id in + let f = List.assoc method_name trait.required_methods in + regular f + | _ -> + Krml.Warn.fatal_error "Error looking trait ref: %s %s" + (Charon.PrintTypes.trait_ref_to_string env.format_env trait_ref) method_name + +let expression_of_fn_ptr env (fn_ptr: C.fn_ptr) = + let { + C.generics = { types = type_args; const_generics = const_generic_args; trait_refs; _ }; + trait_and_method_generic_args; + _ + } = fn_ptr in + + (* General case for function calls and trait method calls. *) + L.log "Calls" "Visiting call: %s" (Charon.PrintExpressions.fn_ptr_to_string env.format_env fn_ptr); + L.log "Calls" "is_array_map: %b" (RustNames.is_array_map env fn_ptr); + L.log "Calls" "--> %d type_args, %d const_generics, %d trait_refs" + (List.length type_args) (List.length const_generic_args) (List.length trait_refs); + + let type_args, const_generic_args, trait_refs = + match trait_and_method_generic_args with + | None -> + type_args, const_generic_args, trait_refs + | Some { types; const_generics; trait_refs; _ } -> + types @ type_args, const_generics @ const_generic_args, trait_refs @ trait_refs + in + L.log "Calls" "--> %d type_args, %d const_generics, %d trait_refs" + (List.length type_args) (List.length const_generic_args) (List.length trait_refs); + L.log "Calls" "--> trait_refs: %s\n" + (String.concat " ++ " (List.map (Charon.PrintTypes.trait_ref_to_string env.format_env) trait_refs)); + L.log "Calls" "--> pattern: %s" (string_of_fn_ptr env fn_ptr); + + let type_args = List.map (typ_of_ty env) type_args in + let const_generic_args = List.map (expression_of_const_generic env) const_generic_args in + let { name; n_type_args = n_type_params; arg_types = inputs; ret_type = output; cg_types = cg_inputs; is_assumed } = + lookup_fun env fn_ptr + in + + let inputs = if inputs = [] && not is_assumed then [ K.TUnit ] else inputs in + if not (n_type_params = List.length type_args) then + Krml.Warn.fatal_error "%a: n_type_params %d != type_args %d" + Krml.PrintAst.Ops.plid name + n_type_params (List.length type_args); + let poly_t_sans_cgs = Krml.Helpers.fold_arrow inputs output in + let poly_t = Krml.Helpers.fold_arrow cg_inputs poly_t_sans_cgs in + let output, t = + Krml.DeBruijn.(subst_ctn (List.length env.binders) const_generic_args (subst_tn type_args output)), + Krml.DeBruijn.(subst_ctn (List.length env.binders) const_generic_args (subst_tn type_args poly_t_sans_cgs)) + in + let hd = + let hd = K.with_type poly_t (K.EQualified name) in + if type_args <> [] || const_generic_args <> [] then + K.with_type t (K.ETApp (hd, const_generic_args, type_args)) + else + hd + in + hd, is_assumed, output + + let expression_of_rvalue (env: env) (p: C.rvalue): K.expr = match p with | Use op -> @@ -632,8 +739,21 @@ let expression_of_rvalue (env: env) (p: C.rvalue): K.expr = end | Aggregate (AggregatedAdt (TAssumed _, _, _), _) -> failwith "unsupported: AggregatedAdt / TAssume" - | Aggregate (AggregatedClosure _, _) -> - failwith "unsupported: AggregatedClosure" + + | Aggregate (AggregatedClosure (func, generics), ops) -> + if ops <> [] then + failwith (Printf.sprintf "unsupported: AggregatedClosure (TODO: closure conversion): %d" (List.length ops)) + else + let fun_ptr = { C.func = C.FunId (FRegular func); generics; trait_and_method_generic_args = None } in + let e, _, _ = expression_of_fn_ptr env fun_ptr in + begin match e.typ with + | TArrow (TBuf (TUnit, _) as t_state, t) -> + (* Empty closure block, passed by address...? TBD *) + K.(with_type t (EApp (e, [ with_type t_state (EAddrOf Krml.Helpers.eunit) ]))) + | _ -> + assert false + end + | Aggregate (AggregatedArray (t, cg), ops) -> K.with_type (TArray (typ_of_ty env t, constant_of_scalar_value (assert_cg_scalar cg))) (K.EBufCreateL (Stack, List.map (expression_of_operand env) ops)) | Global id -> @@ -652,61 +772,6 @@ let expression_of_assertion (env: env) ({ cond; expected }: C.assertion): K.expr with_type TAny (EAbort (None, Some "assert failure")), Krml.Helpers.eunit))) -type lookup_result = { - name: K.lident; - n_type_args: int; (* just for a sanity check *) - cg_types: K.typ list; - arg_types: K.typ list; - ret_type: K.typ; - is_assumed: bool; -} - -let lookup_fun (env: env) (f: C.fn_ptr): lookup_result = - let open RustNames in - let matches p = Charon.NameMatcher.match_fn_ptr env.name_ctx RustNames.config p f in - let builtin b = - let { Builtin.name; typ; n_type_args; cg_args; _ } = b in - let ret_type, arg_types = Krml.Helpers.flatten_arrow typ in - { name; n_type_args; arg_types; ret_type; cg_types = cg_args; is_assumed = true } - in - match List.find_opt (fun (p, _) -> matches p) known_builtins with - | Some (_, b) -> - builtin b - | None -> - let regular f = - let { C.name; signature = { generics = { types = type_params; const_generics; _ }; inputs; output; _ }; _ } = env.get_nth_function f in - L.log "Calls" "--> name: %s" (string_of_name env name); - L.log "Calls" "--> args: %s, ret: %s" - (String.concat " ++ " (List.map (Charon.PrintTypes.ty_to_string env.format_env) inputs)) - (Charon.PrintTypes.ty_to_string env.format_env output); - let env = push_cg_binders env const_generics in - let env = push_type_binders env type_params in - { - name = lid_of_name env name; - n_type_args = List.length type_params; - cg_types = List.map (fun (v: C.const_generic_var) -> typ_of_literal_ty env v.ty) const_generics; - arg_types = List.map (typ_of_ty env) inputs; - ret_type = typ_of_ty env output; - is_assumed = false - } - in - match f.func with - | FunId (FRegular f) -> - regular f - - | FunId (FAssumed f) -> - Krml.Warn.fatal_error "unknown assumed function: %s" (C.show_assumed_fun_id f) - - | TraitMethod (trait_ref, method_name, _trait_opaque_signature) -> - match trait_ref.trait_id with - | TraitImpl id -> - let trait = env.get_nth_trait_impl id in - let f = List.assoc method_name trait.required_methods in - regular f - | _ -> - Krml.Warn.fatal_error "Error looking trait ref: %s %s" - (Charon.PrintTypes.trait_ref_to_string env.format_env trait_ref) method_name - let lesser t1 t2 = if t1 = K.TAny then t2 @@ -819,63 +884,19 @@ let rec expression_of_raw_statement (env: env) (ret_var: C.var_id) (s: C.raw_sta H.with_unit (K.EBufWrite (Krml.DeBruijn.lift 1 dest, i, K.with_type t (K.EBufRead (Krml.DeBruijn.lift 1 src, i)))))) - | Call { func = FnOpRegular ({ - func; - generics = { types = type_args; const_generics = const_generic_args; trait_refs; _ }; - trait_and_method_generic_args - } as fn_ptr); args; dest; _ } - -> - (* General case for function calls and trait method calls. *) - L.log "Calls" "Visiting call: %s" (Charon.PrintExpressions.fn_ptr_to_string env.format_env fn_ptr); - L.log "Calls" "is_array_map: %b" (RustNames.is_array_map env fn_ptr); - L.log "Calls" "--> %d type_args, %d const_generics, %d trait_refs" - (List.length type_args) (List.length const_generic_args) (List.length trait_refs); + | Call { func = FnOpRegular fn_ptr; args; dest; _ } -> (* For now, we take trait type arguments to be part of the code-gen *) - let type_args, const_generic_args, trait_refs = - match trait_and_method_generic_args with - | None -> - type_args, const_generic_args, trait_refs - | Some { types; const_generics; trait_refs; _ } -> - types @ type_args, const_generics @ const_generic_args, trait_refs @ trait_refs - in - L.log "Calls" "--> %d type_args, %d const_generics, %d trait_refs" - (List.length type_args) (List.length const_generic_args) (List.length trait_refs); - L.log "Calls" "--> trait_refs: %s\n" - (String.concat " ++ " (List.map (Charon.PrintTypes.trait_ref_to_string env.format_env) trait_refs)); - L.log "Calls" "--> pattern: %s" (string_of_fn_ptr env fn_ptr); - + let hd, is_assumed, output_t = expression_of_fn_ptr env fn_ptr in let dest, _ = expression_of_place env dest in let args = List.map (expression_of_operand env) args in - let original_type_args = type_args in - let type_args = List.map (typ_of_ty env) type_args in - let const_generic_args = List.map (expression_of_const_generic env) const_generic_args in - let { name; n_type_args = n_type_params; arg_types = inputs; ret_type = output; cg_types = cg_inputs; is_assumed } = - lookup_fun env fn_ptr - in let args = if args = [] && not is_assumed then [ Krml.Helpers.eunit ] else args in - let inputs = if inputs = [] && not is_assumed then [ K.TUnit ] else inputs in - if not (n_type_params = List.length type_args) then - Krml.Warn.fatal_error "%a: n_type_params %d != type_args %d" - Krml.PrintAst.Ops.plid name - n_type_params (List.length type_args); - let poly_t_sans_cgs = Krml.Helpers.fold_arrow inputs output in - let poly_t = Krml.Helpers.fold_arrow cg_inputs poly_t_sans_cgs in - let output, t = - Krml.DeBruijn.(subst_ctn (List.length env.binders) const_generic_args (subst_tn type_args output)), - Krml.DeBruijn.(subst_ctn (List.length env.binders) const_generic_args (subst_tn type_args poly_t_sans_cgs)) - in - let hd = - let hd = K.with_type poly_t (K.EQualified name) in - if type_args <> [] || const_generic_args <> [] then - K.with_type t (K.ETApp (hd, const_generic_args, type_args)) - else - hd - in - let rhs = K.with_type output (K.EApp (hd, args)) in + let rhs = K.with_type output_t (K.EApp (hd, args)) in (* This does something similar to maybe_addrof *) let rhs = - match func, original_type_args with + (* TODO: determine whether extra_types is necessary *) + let extra_types = match fn_ptr.trait_and_method_generic_args with Some { types; _ } -> types | None -> [] in + match fn_ptr.func, fn_ptr.generics.types @ extra_types with | FunId (FAssumed (SliceIndexShared | SliceIndexMut)), [ TAdt (TAssumed (TArray | TSlice), _) ] -> (* Will decay. See comment above maybe_addrof *) rhs @@ -966,7 +987,6 @@ let rec expression_of_raw_statement (env: env) (ret_var: C.var_id) (s: C.raw_sta K.(with_type TUnit (EWhile (Krml.Helpers.etrue, expression_of_raw_statement env ret_var s.content))) - (** Top-level declarations: orchestration *) let of_declaration_group (dg: 'id C.g_declaration_group) (f: 'id -> 'a): 'a list = @@ -1047,6 +1067,7 @@ let decls_of_declarations (env: env) (d: C.declaration_group): K.decl list = (Printexc.get_backtrace ()); None end + | Some { arg_count; locals; body; _ } -> if is_global_decl_body then None @@ -1061,15 +1082,19 @@ let decls_of_declarations (env: env) (d: C.declaration_group): K.decl list = let args = List.tl args in let return_type = typ_of_ty env return_var.var_ty in + + (* Note: Rust allows zero-argument functions but the krml internal + representation wants a unit there. *) + let t_unit = C.(TAdt (TTuple, { types = []; const_generics = []; regions = []; trait_refs = [] })) in + let v_unit = { C.index = Charon.Expressions.VarId.of_int max_int; name = None; var_ty = t_unit } in + let args = if args = [] then [ v_unit ] else args in + let arg_binders = List.map (fun (arg: C.const_generic_var) -> Krml.Helpers.fresh_binder ~mut:true arg.name (typ_of_literal_ty env arg.ty) ) signature.C.generics.const_generics @ List.map (fun (arg: C.var) -> let name = Option.value ~default:"_" arg.name in Krml.Helpers.fresh_binder ~mut:true name (typ_of_ty env arg.var_ty) ) args in - (* Note: Rust allows zero-argument functions but the krml internal - representation wants a unit there. *) - let arg_binders = if arg_binders = [] then [ Krml.Helpers.fresh_binder "dummy" K.TUnit ] else arg_binders in let env = push_binders env args in let body = with_locals env return_type (return_var :: locals) (fun env -> @@ -1077,6 +1102,7 @@ let decls_of_declarations (env: env) (d: C.declaration_group): K.decl list = in Some (K.DFunction (None, [], List.length signature.C.generics.const_generics, List.length signature.C.generics.types, return_type, name, arg_binders, body)) + ) | GlobalGroup id -> let global = env.get_nth_global id in diff --git a/lib/Cleanup2.ml b/lib/Cleanup2.ml index 85c0879..324711c 100644 --- a/lib/Cleanup2.ml +++ b/lib/Cleanup2.ml @@ -137,6 +137,42 @@ let remove_array_repeats = object(self) super#visit_ELet env b e1 e2 end +let remove_array_from_fn = object + inherit [_] map as super + + val mutable defs = Hashtbl.create 41 + + method! visit_DFunction _ cc flags n_cgs n t name bs e = + assert (n_cgs = 0 && n = 0); + match bs with + | [{ typ = TInt SizeT; _ }] -> + Hashtbl.add defs name e + | _ -> + () + ; ; + super#visit_DFunction () cc flags n_cgs n t name bs e + + method! visit_EApp env e es = + match e.node with + | ETApp ({ node = EQualified (["core"; "array"], "from_fn"); _ }, + [ len ], + [ t_elements; TArrow (t_index, t_elements') ]) -> + assert (t_elements' = t_elements); + assert (t_index = TInt SizeT); + assert (List.length es = 2); + let closure = Krml.Helpers.assert_elid (List.nth es 0).node in + assert (Hashtbl.mem defs closure); + let dst = List.nth es 1 in + EFor (Krml.Helpers.fresh_binder ~mut:true "i" H.usize, H.zero_usize (* i = 0 *), + H.mk_lt_usize (Krml.DeBruijn.lift 1 len) (* i < len *), + H.mk_incr_usize (* i++ *), + let i = with_type H.usize (EBound 0) in + Krml.Helpers.with_unit (EBufWrite (Krml.DeBruijn.lift 1 dst, i, Hashtbl.find defs closure))) + | _ -> + super#visit_EApp env e es +end + + let rewrite_slice_to_array = object(_self) inherit [_] map as super diff --git a/test/array/src/main.rs b/test/array/src/main.rs index a75d966..074c7b9 100644 --- a/test/array/src/main.rs +++ b/test/array/src/main.rs @@ -23,10 +23,25 @@ fn mk_foo2() -> Foo { mk_foo() } +// fn mk_incr2() -> [ u32; K ] { +// let j = 1; +// core::array::from_fn(|i| i as u32 + j) +// } + +fn mk_incr() -> [ u32; K ] { + core::array::from_fn(|i| i as u32) +} + fn main() { let Foo { x, y } = mk_foo2(); let expected = 0u32; mut_array(x); mut_foo(Foo { x, y }); assert_eq!(x[0], expected); + let a: [ u32; 10 ] = mk_incr(); + let expected = 9; + assert_eq!(a[9], expected); + // let a: [ u32; 10 ] = mk_incr2(); + // let expected = 10; + // assert_eq!(a[9], expected); }