Skip to content

Commit

Permalink
using NamedTuple instead of a Dict in flat_params
Browse files Browse the repository at this point in the history
  • Loading branch information
pebeto committed Aug 14, 2023
1 parent c3f0f99 commit 121e9f3
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 16 deletions.
2 changes: 1 addition & 1 deletion src/MLJModelInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ end
export UnivariateFinite

# parameter_inspection:
export params, flat_params
export params

# model constructor + metadata
export @mlj_model, metadata_pkg, metadata_model
Expand Down
2 changes: 1 addition & 1 deletion src/parameter_inspection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ not a hard requirement.
"""
flat_params(m; prefix="") = flat_params(m, Val(isamodel(m)); prefix=prefix)
flat_params(m, ::Val{false}; prefix="") = Dict(prefix=>m)
flat_params(m, ::Val{false}; prefix="") = NamedTuple{(Symbol(prefix),), Tuple{Any}}((m,))
function flat_params(m, ::Val{true}; prefix="")
fields = propertynames(m)
prefix = prefix == "" ? "" : prefix * "__"
Expand Down
45 changes: 31 additions & 14 deletions test/parameter_inspection.jl
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
using Test
using MLJModelInterface

struct Opaque <: Model
struct Opaque
a::Int
end

struct Transparent <: Model
struct Transparent
A::Int
B::Opaque
end

MLJModelInterface.istransparent(::Transparent) = true

struct Dummy <: Model
struct Dummy <: MLJType
t::Transparent
o::Opaque
n::Integer
Expand All @@ -27,20 +27,37 @@ end
@test params(m) == (
t = (
A = 6,
B = (
a = 5,
)
),
o = (
a = 7,
B = Opaque(5)
),
o = Opaque(7),
n = 42
)
@test flat_params(m) == Dict(
"o__a" => 7,
"t__A" => 6,
"t__B__a" => 5,
"n" => 42
end

struct ChildModel <: Model
x::Int
y::String
end

struct ParentModel <: Model
x::Int
y::String
first_child::ChildModel
second_child::ChildModel
end

@testset "flat_params method" begin

m = ParentModel(1, "parent", ChildModel(2, "child1"),
ChildModel(3, "child2"))

@test MLJModelInterface.flat_params(m) == (
x = 1,
y = "parent",
first_child__x = 2,
first_child__y = "child1",
second_child__x = 3,
second_child__y = "child2"
)
end
true

0 comments on commit 121e9f3

Please sign in to comment.