Skip to content

Commit

Permalink
Merge pull request #110 from JuliaAI/dev
Browse files Browse the repository at this point in the history
For a 1.2 release
  • Loading branch information
ablaom authored Aug 17, 2021
2 parents 9a7b6ba + fc7191c commit 6288bac
Show file tree
Hide file tree
Showing 4 changed files with 143 additions and 3 deletions.
6 changes: 3 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
name = "MLJModelInterface"
uuid = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
authors = ["Thibaut Lienart and Anthony Blaom"]
version = "1.1.3"
version = "1.2.0"

[deps]
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ScientificTypesBase = "30f210dd-8aff-4c5f-94ba-8e64358c1161"
StatisticalTraits = "64bff920-2084-43da-a3e6-9bb72801c0c9"

[compat]
ScientificTypesBase = "1, 2"
StatisticalTraits = "2"
ScientificTypesBase = "2.1"
StatisticalTraits = "2.1"
julia = "1"

[extras]
Expand Down
5 changes: 5 additions & 0 deletions src/MLJModelInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ const MODEL_TRAITS = [
:input_scitype,
:output_scitype,
:target_scitype,
:fit_data_scitype,
:predict_scitype,
:transform_scitype,
:inverse_transform_scitype,
:is_pure_julia,
:package_name,
:package_license,
Expand All @@ -18,6 +22,7 @@ const MODEL_TRAITS = [
:name,
:is_supervised,
:prediction_type,
:abstract_type,
:implemented_methods,
:hyperparameters,
:hyperparameter_types,
Expand Down
71 changes: 71 additions & 0 deletions src/model_traits.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,74 @@ StatisticalTraits.prediction_type(::Type{<:Interval}) = :interval
implemented_methods(M::Type) = implemented_methods(get_interface_mode(), M)
implemented_methods(model) = implemented_methods(typeof(model))
implemented_methods(::LightInterface, M) = errlight("implemented_methods")

for M in ABSTRACT_MODEL_SUBTYPES
@eval(StatisticalTraits.abstract_type(::Type{<:$M}) = $M)
end

StatisticalTraits.fit_data_scitype(M::Type{<:Unsupervised}) =
Tuple{input_scitype(M)}
StatisticalTraits.fit_data_scitype(::Type{<:Static}) = Tuple{}
function StatisticalTraits.fit_data_scitype(M::Type{<:Supervised})
I = input_scitype(M)
T = target_scitype(M)
ret = Tuple{I,T}
if supports_weights(M)
W = AbstractVector{Union{Continuous,Count}} # weight scitype
return Union{ret,Tuple{I,T,W}}
elseif supports_class_weights(M)
W = AbstractDict{Finite,Union{Continuous,Count}}
return Union{ret,Tuple{I,T,W}}
end
return ret
end

StatisticalTraits.transform_scitype(M::Type{<:Unsupervised}) =
output_scitype(M)

StatisticalTraits.inverse_transform_scitype(M::Type{<:Unsupervised}) =
input_scitype(M)

StatisticalTraits.predict_scitype(M::Type{<:Deterministic}) = target_scitype(M)


## FALLBACKS FOR `predict_scitype` FOR `Probabilistic` MODELS

# This seems less than ideal but should reduce the number of `Unknown`
# in `prediction_type` for models which, historically, have not
# implemented the trait.

StatisticalTraits.predict_scitype(M::Type{<:Probabilistic}) =
_density(target_scitype(M))

_density(::Any) = Unknown
for T in [:Continuous, :Count, :Textual]
eval(quote
_density(::Type{AbstractArray{$T,D}}) where D =
AbstractArray{Density{$T},D}
end)
end

for T in [:Finite,
:Multiclass,
:OrderedFactor,
:Infinite,
:Continuous,
:Count,
:Textual]
eval(quote
_density(::Type{AbstractArray{<:$T,D}}) where D =
AbstractArray{Density{<:$T},D}
_density(::Type{Table($T)}) = Table(Density{$T})
end)
end

for T in [:Finite, :Multiclass, :OrderedFactor]
eval(quote
_density(::Type{AbstractArray{<:$T{N},D}}) where {N,D} =
AbstractArray{Density{<:$T{N}},D}
_density(::Type{AbstractArray{$T{N},D}}) where {N,D} =
AbstractArray{Density{$T{N}},D}
_density(::Type{Table($T{N})}) where N = Table(Density{$T{N}})
end)
end
64 changes: 64 additions & 0 deletions test/model_traits.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,67 @@ import .Fruit
@test docstring(Float64) == "Float64"
@test docstring(Fruit.Banana) == "Banana"
end

@testset "`_density` - helper for predict_scitype fallback" begin
for T in [Continuous, Count, Textual]
@test M._density(AbstractArray{T,3}) ==
AbstractArray{Density{T},3}
end

for T in [Finite,
Multiclass,
OrderedFactor,
Infinite,
Continuous,
Count,
Textual]
@test M._density(AbstractVector{<:T}) ==
AbstractVector{Density{<:T}}
@test M._density(Table(T)) == Table(Density{T})
end

for T in [Finite, Multiclass, OrderedFactor]
@test M._density(AbstractArray{<:T{2},3}) ==
AbstractArray{Density{<:T{2}},3}
@test M._density(AbstractArray{T{2},3}) ==
AbstractArray{Density{T{2}},3}
@test M._density(Table(T{2})) == Table(Density{T{2}})
end
end

@mlj_model mutable struct P2 <: Probabilistic end
M.target_scitype(::Type{<:P2}) = AbstractVector{<:Multiclass}
M.input_scitype(::Type{<:P2}) = Table(Continuous)

@mlj_model mutable struct U2 <: Unsupervised end
M.output_scitype(::Type{<:U2}) = AbstractVector{<:Multiclass}
M.input_scitype(::Type{<:U2}) = Table(Continuous)

@mlj_model mutable struct S2 <: Static end
M.output_scitype(::Type{<:S2}) = AbstractVector{<:Multiclass}
M.input_scitype(::Type{<:S2}) = Table(Continuous)

@testset "operation scitypes" begin
@test predict_scitype(P2()) == AbstractVector{Density{<:Multiclass}}
@test transform_scitype(P2()) == Unknown
@test transform_scitype(U2()) == AbstractVector{<:Multiclass}
@test inverse_transform_scitype(U2()) == Table(Continuous)
@test predict_scitype(U2()) == Unknown
@test transform_scitype(S2()) == AbstractVector{<:Multiclass}
@test inverse_transform_scitype(S2()) == Table(Continuous)
end

@testset "abstract_type, fit_data_scitype" begin
@test abstract_type(P2()) == Probabilistic
@test abstract_type(S1()) == Supervised
@test abstract_type(U1()) == Unsupervised
@test abstract_type(D1()) == Deterministic
@test abstract_type(P1()) == Probabilistic

@test fit_data_scitype(P2()) ==
Tuple{Table(Continuous),AbstractVector{<:Multiclass}}
@test fit_data_scitype(U2()) == Tuple{Table(Continuous)}
@test fit_data_scitype(S2()) == Tuple{}
end

true

0 comments on commit 6288bac

Please sign in to comment.