Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add DummyFeaturization #87

Closed
wants to merge 13 commits into from
Closed
5 changes: 3 additions & 2 deletions src/ChemistryFeaturization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module ChemistryFeaturization

using SimpleWeightedGraphs

# some functions that will get extended in various places
encodable_elements(a::Any) = throw(MethodError(encodable_elements, a))
decode(a::Any, encoded_features) = throw(MethodError(decode, a))

Expand All @@ -26,8 +27,8 @@ export AtomGraph

include("featurizations/featurizations.jl")
export Featurization
using .Featurization: GraphNodeFeaturization, featurize!
export GraphNodeFeaturization, featurize!
using .Featurization: GraphNodeFeaturization, DummyFeaturization, featurize!#, validate_features
export GraphNodeFeaturization, DummyFeaturization, featurize!#, validate_features

export encodable_elements, decode

Expand Down
8 changes: 8 additions & 0 deletions src/atoms/atomgraph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,14 @@ function Base.show(io::IO, ::MIME"text/plain", ag::AtomGraph)

end

function decode(ag::AtomGraph)
@assert !(any(isnothing.([ag.featurization, ag.encoded_features])))
decoded = decode(ag.featurization, ag.encoded_features)
for (k, v) in decoded
v["Symbol"] = ag.elements[k]
end
return decoded
end

"""
normalized_laplacian(graph)
Expand Down
7 changes: 7 additions & 0 deletions src/atoms/atoms.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
module Atoms

import ..ChemistryFeaturization.decode
import ..ChemistryFeaturization.AbstractType.AbstractAtoms

decode(a::AbstractAtoms) = decode(a.featurization, a.encoded_features)

export decode

include("atomgraph.jl")
export AtomGraph, visualize

Expand Down
2 changes: 0 additions & 2 deletions src/features/elementfeature.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,6 @@ function (ed::OneHotOneCold)(
end
end

output_shape(efd::ElementFeatureDescriptor) = output_shape(efd, efd.encoder_decoder)

function output_shape(efd::ElementFeatureDescriptor, ed::OneHotOneCold)
return efd.categorical ? length(unique(efd.lookup_table[:, Symbol(efd.name)])) :
ed.nbins
Expand Down
6 changes: 5 additions & 1 deletion src/features/features.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module FeatureDescriptor

using ..ChemistryFeaturization.AbstractType: AbstractFeatureDescriptor
using ..ChemistryFeaturization.AbstractType: AbstractFeatureDescriptor, AbstractCodec

import ..ChemistryFeaturization.encodable_elements
encodable_elements(fd::AbstractFeatureDescriptor) =
Expand All @@ -11,6 +11,10 @@ import ..ChemistryFeaturization.decode
decode(fd::AbstractFeatureDescriptor, encoded_feature) = throw(MethodError(decode, fd))
export decode

output_shape(fd::AbstractFeatureDescriptor, ed::AbstractCodec) = throw(MethodError(output_shape, (fd, ed)))
rkurchin marked this conversation as resolved.
Show resolved Hide resolved
output_shape(fd::AbstractFeatureDescriptor) = output_shape(fd, fd.encoder_decoder)
export output_shape

include("abstractfeatures.jl")

include("bondfeatures.jl")
Expand Down
3 changes: 2 additions & 1 deletion src/features/speciesfeature.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,5 @@ function (f::SpeciesFeatureDescriptor)(a::AbstractAtoms)
f.encode_f(a)
end

# TODO: some Weave stuff needed here?
# TODO: some Weave stuff needed here
# also output_shape dispatch
25 changes: 25 additions & 0 deletions src/featurizations/dummyfeaturization.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
export DummyFeaturization, featurize!
export encodable_elements, decode

using ..ChemistryFeaturization.AbstractType: AbstractFeaturization

"""
DummyFeaturization

A dummy featurization that cannot actually encode or decode anything. For use primarily for populating output of model layers (e.g. AtomicGraphnets' AGNConv) that return an Atoms object, but with the encoded features transformed such that the original featurization is no longer applicable/valid.
"""
struct DummyFeaturization <: AbstractFeaturization end

encodable_elements(df::DummyFeaturization) = []

decode(fzn::DummyFeaturization, encoded_feature) = throw(
ArgumentError(
"This featurization is just a dummy, likely created by a model layer such as `AGNConv`, and cannot actually decode encoded features.",
),
)

featurize!(a::AbstractAtoms, df::DummyFeaturization) = throw(
ArgumentError(
"This featurization is just a dummy, likely created by a model layer such as `AGNConv`, and cannot actually encode features.",
),
)
14 changes: 12 additions & 2 deletions src/featurizations/featurizations.jl
Original file line number Diff line number Diff line change
@@ -1,17 +1,27 @@
module Featurization

using ..ChemistryFeaturization.AbstractType: AbstractFeaturization
using ..ChemistryFeaturization.AbstractType: AbstractFeaturization, AbstractAtoms, AbstractCodec
rkurchin marked this conversation as resolved.
Show resolved Hide resolved

import ..ChemistryFeaturization.encodable_elements
encodable_elements(fzn::AbstractFeaturization) = throw(MethodError(encodable_elements, fzn))
export encodable_elements

featurize!(a::AbstractAtoms, fzn::AbstractFeaturization) =
throw(MethodError(featurize!, (a, fzn)))

import ..ChemistryFeaturization.decode
decode(fzn::AbstractFeaturization, encoded_feature) = throw(MethodError(decode, fzn))
decode(fzn::AbstractFeaturization, encoded_feature) =
throw(MethodError(decode, (fzn, encoded_feature)))
include("graphnodefeaturization.jl")
export GraphNodeFeaturization, featurize!, decode

validate_features(fzn::AbstractFeaturization, ed::AbstractCodec, encoded) = throw(MethodError(validate_features, (fzn, ed, encoded)))
rkurchin marked this conversation as resolved.
Show resolved Hide resolved
export validate_features

include("weavefeaturization.jl")
export WeaveFeaturization

include("dummyfeaturization.jl")
export DummyFeaturization

end
9 changes: 0 additions & 9 deletions src/featurizations/graphnodefeaturization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,12 +133,3 @@ function decode(fzn::GraphNodeFeaturization, encoded::Matrix{<:Real})
end
return decoded
end

function decode(ag::AtomGraph)
@assert !(any(isnothing.([ag.featurization, ag.encoded_features])))
decoded = decode(ag.featurization, ag.encoded_features)
for (k, v) in decoded
v["Symbol"] = ag.elements[k]
end
return decoded
end
12 changes: 12 additions & 0 deletions test/featurizations/DummyFeaturization_tests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
using Test
using ChemistryFeaturization.Featurization
using ChemistryFeaturization.Atoms: AtomGraph

@testset "DummyFeaturization" begin
df = DummyFeaturization()
@test isempty(encodable_elements(df))
dummy_encoded = [0 1 0]
@test_throws ArgumentError decode(df, dummy_encoded)
F2 = AtomGraph(Float32.([0 1; 1 0]), ["F", "F"])
@test_throws ArgumentError featurize!(F2, df)
end
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
end
end

31 changes: 31 additions & 0 deletions test/module_tests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
using Test
using ChemistryFeaturization.AbstractType: AbstractFeaturization, AbstractFeatureDescriptor

@testset "modules and abstract methods" begin
F2 = AtomGraph(Float32.([0 1; 1 0]), ["F", "F"])

@testset "top-level module" begin
# testing on `nothing` as example of ::Any
@test_throws MethodError encodable_elements(nothing)
@test_throws MethodError decode(nothing, nothing)
end

@testset "featurizations module" begin
struct FakeFeaturization <: AbstractFeaturization end
ff = FakeFeaturization()
@test_throws MethodError encodable_elements(ff)
@test_throws MethodError featurize!(F2, ff)
@test_throws MethodError decode(ff, nothing)
end

# @testset "atoms module" begin
# TBD cleanest way to test generic decode(::AbstractAtoms) - either another "fake" class, or maybe the `invoke` function
# end

@testset "features module" begin
struct FakeFD <: AbstractFeatureDescriptor end
fd = FakeFD()
@test_throws MethodError encodable_elements(fd)
@test_throws MethodError decode(fd, nothing)
end
end
rkurchin marked this conversation as resolved.
Show resolved Hide resolved
4 changes: 3 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@ tests = [
"utils/ElementFeatureUtils_tests",
"utils/GraphBuilding_tests",
"atoms/AtomGraph_tests",
# TODO: add SpeciesFeature tests
"features/ElementFeature_tests",
"featurizations/GraphNodeFeaturization_tests",
"featurizations/DummyFeaturization_tests",
"module_tests"
rkurchin marked this conversation as resolved.
Show resolved Hide resolved
# TODO: add Weave stuff
# TODO: add SpeciesFeature tests
]

@testset "ChemistryFeaturization" begin
Expand Down