Skip to content

Commit

Permalink
Add initial support for arrays of records.
Browse files Browse the repository at this point in the history
Some of the error checking is a bit crude.
  • Loading branch information
athas committed Jul 25, 2024
1 parent 9c27f75 commit 6fde6ab
Show file tree
Hide file tree
Showing 4 changed files with 209 additions and 55 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ All user-visible changes are noted here.

* The array signatures now support an `index` function.

* Arrays of records are now supported.

## [1.4.0]

* Arrays modules now have `new_raw` and `values_raw` functions.
Expand Down
31 changes: 31 additions & 0 deletions smlfut.1
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,37 @@ with
.Li values ,
you must still also free the original Futhark object at some point.
.
.Sh Arrays of records
.
Each array of records is represented by a structure implementing the
signature below.
.Bd -literal -offset indent
signature FUTHARK_RECORD_ARRAY =
sig
include FUTHARK_OPAQUE
type shape
type index
type elem
type fields
val shape: t -> shape
val index: t -> index -> elem
val zip : fields -> t
end
.Ed
.Pp
The
.Li shape ,
.Li index ,
and
.Li elem
types are similar to the corresponding types in the
.Li FUTHARK_ARRAY
signature. The
.Li fields
type will be refined to a tuple of Futhark array types, which can be zipped together (using
.Li zip )
to produce an array of records.
.
.Sh Entry points
.
Each Futhark entry point becomes a function with two parameters: the
Expand Down
190 changes: 137 additions & 53 deletions src/smlfut.sml
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,19 @@ val sig_FUTHARK_SUM =
, "end"
]

val sig_FUTHARK_RECORD_ARRAY =
[ "signature FUTHARK_RECORD_ARRAY ="
, "sig"
, " include FUTHARK_OPAQUE"
, " type shape"
, " type index"
, " type elem"
, " type fields"
, " val shape: t -> shape"
, " val index: t -> index -> elem"
, " val zip : fields -> t"
, "end"
]
(* Actual logic. *)

fun gpuBackend "opencl" = true
Expand Down Expand Up @@ -182,7 +195,10 @@ fun blankRef manifest t =
| _ => raise Fail ("blankRef: " ^ t)


fun checkUseAfterFree free = "if !" ^ free ^ " then raise Free else ()"
fun checkUseAfterFree free =
"if !" ^ parens free ^ " then raise Free else ()"

fun intToInt64 x = apply "Int64.fromInt" [x]

fun boundsCheck 1 index shape =
"if " ^ index ^ " >= 0 andalso " ^ index ^ " < " ^ shape
Expand All @@ -196,7 +212,31 @@ fun boundsCheck 1 index shape =
"if " ^ check ^ " then () else raise Subscript"
end

fun intToInt64 x = apply "Int64.fromInt" [x]
(* Assuming 'i' is a valid index into an array with 'rank' dimensions,
produce a list of the indexes. Also works for shape arguments. *)
fun indexArgs 1 i =
[(intToInt64 i, "Int64.int")]
| indexArgs rank i =
List.tabulate (rank, fn j =>
(intToInt64 (project (j + 1) i), "Int64.int"))

fun mkShape fficall pointer shape_cfun rank v =
letbind
[ ( "shape_c"
, fficall shape_cfun [("ctx", "futhark_context"), (v, pointer)] pointer
)
, ("r", Int.toString rank)
, ("shape_ml", apply "Int64Array.array" [Int.toString rank, "0"])
, ( "_"
, fficall "smlfut_memcpy"
[ ("shape_ml", "Int64Array.array")
, ("shape_c", pointer)
, (intToInt64 "r*8", "Int64.int")
] pointer
)
]
[tuple_e (List.tabulate (rank, fn i =>
apply "Int64.toInt" [apply "Int64Array.sub" ["shape_ml", Int.toString i]]))]

signature TARGET =
sig
Expand Down Expand Up @@ -387,6 +427,9 @@ struct
fs
end

fun recordArrayZipType manifest (arr: record_array) =
tuple_t (map (typeToSMLInside manifest o #type_ o #2) (#fields arr))

fun generateTypeSpec manifest (name, FUTHARK_ARRAY info) =
origNameComment name @ futharkArrayStructSpec info
| generateTypeSpec manifest (name, FUTHARK_OPAQUE info) =
Expand Down Expand Up @@ -421,8 +464,12 @@ struct
, "where type ctx = ctx"
]
| SOME (OPAQUE_RECORD_ARRAY arr) =>
[ structspec (escapeName name) "FUTHARK_OPAQUE"
[ structspec (escapeName name) "FUTHARK_RECORD_ARRAY"
, "where type ctx = ctx"
, " and type shape = " ^ shapeTypeOfRank (#rank arr)
, " and type index = " ^ shapeTypeOfRank (#rank arr)
, " and type elem = " ^ typeToSMLInside manifest (#elemtype arr)
, " and type fields = " ^ recordArrayZipType manifest arr
])


Expand All @@ -431,7 +478,7 @@ struct
[ "ctx"
, fficall "smlfut_to_pointer"
[(apply "Word64Array.sub" [out, "0"], "Word64.word")] pointer
, "free"
, "ctx_free"
, "ref false"
]

Expand All @@ -440,16 +487,72 @@ struct
futharkArrayStructDef manifest info
| generateTypeDef manifest (name, FUTHARK_OPAQUE info) =
let
fun wrapCheck ls =
letbind
[ ("()", checkUseAfterFree "ctx_free")
, ("()", checkUseAfterFree "obj_free")
] ls
val freechecks =
[ ("()", checkUseAfterFree "ctx_free")
, ("()", checkUseAfterFree "obj_free")
]
fun wrapCheck ls = letbind freechecks ls
val more =
case #extra info of
NONE => []
| SOME (OPAQUE_ARRAY arr) => []
| SOME (OPAQUE_RECORD_ARRAY arr) => []
| SOME (OPAQUE_RECORD_ARRAY arr) =>
[ typedef "elem" [] (typeToSMLInside manifest (#elemtype arr))
, typedef "index" [] (shapeTypeOfRank (#rank arr))
, typedef "shape" [] (shapeTypeOfRank (#rank arr))
, typedef "fields" [] (recordArrayZipType manifest arr)
]
@
fundef "shape" ["(ctx,data,ctx_free,obj_free)"] (wrapCheck
(mkShape fficall pointer (#shape arr) (#rank arr) "data"))
@
fundef "index"
["(ctx,data,ctx_free,obj_free)", parens ("i: index")]
(letbind
(freechecks
@
[ ( "shape"
, apply "shape" ["(ctx,data,ctx_free,obj_free)"]
)
, ("()", boundsCheck (#rank arr) "i" "shape")
, ("out", mkOut manifest (#elemtype arr))
, ( "()"
, apply "error_check"
[ fficall (#index arr)
([ ("ctx", "futhark_context")
, ("out", outType (#elemtype arr))
, ("data", pointer)
] @ indexArgs (#rank arr) "i") "int"
, "ctx"
]
)
]) [valFromPtrArr "out"])
@
fundef "zip" [parens ("fs: fields")]
(letbind
([("()", checkUseAfterFree (project 3 (project 1 "fs")))]
@
List.tabulate (length (#fields arr), fn i =>
( "()"
, checkUseAfterFree (project 4 (project (i + 1) "fs"))
))
@
[ ("(ctx, _, ctx_free, _)", project 1 "fs")
, ("out", apply "Word64Array.array" ["1", "0w0"])
, ( "()"
, apply "error_check"
[ fficall (#zip arr)
([ ("ctx", "futhark_context")
, ("out", "Word64Array.array")
]
@
List.tabulate (length (#fields arr), fn i =>
(project 2 (project (i + 1) "fs"), pointer)))
"int"
, "ctx"
]
)
]) [valFromPtrArr "out"])
| SOME (OPAQUE_RECORD record) =>
let
val fields = recordFieldMap (#fields record)
Expand Down Expand Up @@ -499,9 +602,11 @@ struct
[record_e (map getField fields)])
@
fundef "new"
["{cfg,ctx,free}", record_e (map fieldParam fields)]
[ "{cfg,ctx,free=ctx_free}"
, record_e (map fieldParam fields)
]
(letbind
([("()", checkUseAfterFree "free")]
([("()", checkUseAfterFree "ctx_free")]
@ map fieldCheckFree fields
@ [("out", apply "Word64Array.array" ["1", "0w0"])])
[ apply "error_check"
Expand Down Expand Up @@ -585,9 +690,9 @@ struct
in
sumDef manifest name sum @ [typedef "sum" [] name]
@
fundef "new" ["{cfg,ctx,free}", "sum"]
(letbind [("()", checkUseAfterFree "free")] (case_e "sum"
(map mkCase (#variants sum))))
fundef "new" ["{cfg,ctx,free=ctx_free}", "sum"]
(letbind [("()", checkUseAfterFree "ctx_free")]
(case_e "sum" (map mkCase (#variants sum))))
@
fundef "values" ["(ctx,data,ctx_free,obj_free)"]
(wrapCheck
Expand Down Expand Up @@ -746,6 +851,18 @@ struct
, FUTHARK_OPAQUE {ctype, ops, extra = SOME (OPAQUE_SUM sum)}
) =
List.all (List.all isKnown o #payload o #2) (#variants sum)
| usesKnown
( name
, FUTHARK_OPAQUE
{ctype, ops, extra = SOME (OPAQUE_ARRAY arr)}
) =
isKnown (#elemtype arr)
| usesKnown
( name
, FUTHARK_OPAQUE
{ctype, ops, extra = SOME (OPAQUE_RECORD_ARRAY arr)}
) =
isKnown (#elemtype arr)
| usesKnown _ = true
val (ok, next) = List.partition usesKnown rs
in
Expand Down Expand Up @@ -1030,7 +1147,7 @@ struct
( unlines
(header @ sig_FUTHARK_ARRAY @ [""] @ sig_FUTHARK_OPAQUE @ [""]
@ sig_FUTHARK_RECORD @ [""] @ sig_FUTHARK_SUM @ [""]
@ sigdef sig_name specs)
@ sig_FUTHARK_RECORD_ARRAY @ [""] @ sigdef sig_name specs)
, unlines (header @ structdef struct_name (SOME sig_name) defs)
, unlines
([ "#include <stdint.h>"
Expand Down Expand Up @@ -1122,28 +1239,6 @@ local
[(mkProd (List.tabulate (#rank info, fn i =>
apply "Int64.toInt"
[apply "Int64Array.sub" ["shape_ml", Int.toString i]])))]


fun mkShape fficall pointer (info: array_info) v =
letbind
[ ( "shape_c"
, fficall (#shape (#ops info))
[("ctx", "futhark_context"), (v, pointer)] pointer
)
, ("r", Int.toString (#rank info))
, ("shape_ml", apply "Int64Array.array" [Int.toString (#rank info), "0"])
, ( "_"
, fficall "smlfut_memcpy"
[ ("shape_ml", "Int64Array.array")
, ("shape_c", pointer)
, (intToInt64 "r*8", "Int64.int")
] pointer
)
]
[tuple_e (List.tabulate (#rank info, fn i =>
apply "Int64.toInt"
[apply "Int64Array.sub" ["shape_ml", Int.toString i]]))]

in
fun futharkArrayStructDef fficall pointer null defs (data_t: string)
(manifest as MANIFEST {backend, ...}) (info as {ctype, rank, elemtype, ops}) =
Expand All @@ -1160,10 +1255,6 @@ in
[ ("()", checkUseAfterFree "ctx_free")
, ("()", checkUseAfterFree "arr_free")
] ls


val shape_args =
map (fn x => (apply "Int64.fromInt" [x], "Int64.int")) shape
in
structdef (futharkArrayStruct info) NONE
([ typedef "array" []
Expand All @@ -1188,7 +1279,7 @@ in
([ ("ctx", "futhark_context")
, ("arr", data_t)
, ("Int64.fromInt i", "Int64.int")
] @ shape_args) pointer
] @ indexArgs rank "shape") pointer
)
]
[ "if isnull arr"
Expand All @@ -1207,7 +1298,7 @@ in
])
@
fundef "shape" ["(ctx,data,ctx_free,arr_free)"] (wrapCheck
(mkShape fficall pointer info "data"))
(mkShape fficall pointer (#shape (#ops info)) (#rank info) "data"))
@
fundef "values_into" ["(ctx,data,ctx_free,arr_free)", "slice"]
(wrapCheck
Expand Down Expand Up @@ -1254,14 +1345,7 @@ in
([ ("ctx", "futhark_context")
, ("out", data_t)
, ("data", pointer)
]
@
(if rank = 1 then
[(intToInt64 "i", "Int64.int")]
else
List.tabulate (rank, fn i =>
(intToInt64 (project (i + 1) "i"), "Int64.int"))))
"int"
] @ indexArgs rank "i") "int"
, "ctx"
]
)
Expand Down Expand Up @@ -1296,7 +1380,7 @@ in
, ( "arr"
, fficall (#new_raw (#ops info))
([("ctx", "futhark_context"), ("data", pointer)]
@ shape_args) pointer
@ indexArgs rank "shape") pointer
)
]
[ "if isnull arr"
Expand Down
Loading

0 comments on commit 6fde6ab

Please sign in to comment.