diff --git a/Project.toml b/Project.toml index b7a07370..bcadec5d 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ChemistryFeaturization" uuid = "6c925690-434a-421d-aea7-51398c5b007a" authors = ["Rachel Kurchin ", "Sean Sun"] -version = "0.3.1" +version = "0.3.2" [deps] CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" diff --git a/src/ChemistryFeaturization.jl b/src/ChemistryFeaturization.jl index 9b231f3f..622c941b 100644 --- a/src/ChemistryFeaturization.jl +++ b/src/ChemistryFeaturization.jl @@ -2,6 +2,7 @@ module ChemistryFeaturization using SimpleWeightedGraphs +# some functions that will get extended in various places encodable_elements(a::Any) = throw(MethodError(encodable_elements, a)) decode(a::Any, encoded_features) = throw(MethodError(decode, a)) @@ -21,13 +22,13 @@ export ElementFeatureDescriptor include("atoms/atoms.jl") export Atoms -using .Atoms: AtomGraph -export AtomGraph +using .Atoms: AtomGraph, visualize +export AtomGraph, visualize include("featurizations/featurizations.jl") export Featurization -using .Featurization: GraphNodeFeaturization, featurize! -export GraphNodeFeaturization, featurize! +using .Featurization: GraphNodeFeaturization, DummyFeaturization, featurize!#, validate_features +export GraphNodeFeaturization, DummyFeaturization, featurize!#, validate_features export encodable_elements, decode diff --git a/src/atoms/atomgraph.jl b/src/atoms/atomgraph.jl index 560146b1..e23cf7f3 100644 --- a/src/atoms/atomgraph.jl +++ b/src/atoms/atomgraph.jl @@ -103,6 +103,28 @@ AtomGraph( AtomGraph(adj::Array{R}, elements::Vector{String}, id = "") where {R<:Real} = AtomGraph(SimpleWeightedGraph(adj), elements, id) +""" + AtomGraph(input_file_path; id="", output_file_path=nothing, featurization=nothing, overwrite_file=false, use_voronoi=false, cutoff_radius=8.0, max_num_nbr=12, dist_decay_func=inverse_square, normalize_weights=true) + +Construct an AtomGraph object from a structure file. + +# Required Arguments +- `input_file_path::String`: path to file containing structure (must be readable by ASE.io.read) + +# Optional Arguments +- `id::String=""`: ID associated with structure (e.g. identifier from online database) +- `output_file_path=nothing`: If provided, structure will be serialized to file at this location +- `featurization`: If provided, features will be encoded using it +- `overwrite_file::Bool=false`: whether to overwrite an existing file at `output_file_path` +- `use_voronoi::Bool=false`: Whether to build neighbor lists using Voronoi decompositions +- `cutoff_radius::Real=8.0`: If not using Voronoi neighbor lists, longest allowable distance to a neighbor, in Angstroms +- `max_num_nbr::Integer=12`: If not using Voronoi neighbor lists, largest allowable number of neighbors +- `dist_decay_func=inverse_square`: Function by which to assign edge weights according to distance between neighbors +- `normalize_weights::Bool=true`: Whether to normalize weights such that the largest is 1.0 + +# Note +`max_num_nbr` is a "soft" limit – if multiple neighbors are at the same distance, the full neighbor list may be longer. +""" function AtomGraph( input_file_path::String, id::String = "", @@ -190,6 +212,14 @@ function Base.show(io::IO, ::MIME"text/plain", ag::AtomGraph) end +function decode(ag::AtomGraph) + @assert !(any(isnothing.([ag.featurization, ag.encoded_features]))) + decoded = decode(ag.featurization, ag.encoded_features) + for (k, v) in decoded + v["Symbol"] = ag.elements[k] + end + return decoded +end """ normalized_laplacian(graph) @@ -258,7 +288,7 @@ end "Visualize a given graph." function visualize(ag::AtomGraph) # gplot doesn't work on weighted graphs - sg = SimpleGraph(adjacency_matrix(ag)) + sg = SimpleGraph(adjacency_matrix(ag.graph)) plt = gplot( sg, nodefillc = graph_colors(ag.elements), diff --git a/src/atoms/atoms.jl b/src/atoms/atoms.jl index 0214a4b9..15224bd3 100644 --- a/src/atoms/atoms.jl +++ b/src/atoms/atoms.jl @@ -1,5 +1,12 @@ module Atoms +import ..ChemistryFeaturization.decode +import ..ChemistryFeaturization.AbstractType.AbstractAtoms + +decode(a::AbstractAtoms) = decode(a.featurization, a.encoded_features) + +export decode + include("atomgraph.jl") export AtomGraph, visualize diff --git a/src/features/elementfeature.jl b/src/features/elementfeature.jl index 17e4bc91..5818e4ac 100644 --- a/src/features/elementfeature.jl +++ b/src/features/elementfeature.jl @@ -91,8 +91,6 @@ function (ed::OneHotOneCold)( end end -output_shape(efd::ElementFeatureDescriptor) = output_shape(efd, efd.encoder_decoder) - function output_shape(efd::ElementFeatureDescriptor, ed::OneHotOneCold) return efd.categorical ? length(unique(efd.lookup_table[:, Symbol(efd.name)])) : ed.nbins diff --git a/src/features/features.jl b/src/features/features.jl index a9eb42ce..665b0cfa 100644 --- a/src/features/features.jl +++ b/src/features/features.jl @@ -1,6 +1,6 @@ module FeatureDescriptor -using ..ChemistryFeaturization.AbstractType: AbstractFeatureDescriptor +using ..ChemistryFeaturization.AbstractType: AbstractFeatureDescriptor, AbstractCodec import ..ChemistryFeaturization.encodable_elements encodable_elements(fd::AbstractFeatureDescriptor) = @@ -11,6 +11,11 @@ import ..ChemistryFeaturization.decode decode(fd::AbstractFeatureDescriptor, encoded_feature) = throw(MethodError(decode, fd)) export decode +output_shape(fd::AbstractFeatureDescriptor, ed::AbstractCodec) = + throw(MethodError(output_shape, (fd, ed))) +output_shape(fd::AbstractFeatureDescriptor) = output_shape(fd, fd.encoder_decoder) +export output_shape + include("abstractfeatures.jl") include("bondfeatures.jl") diff --git a/src/features/speciesfeature.jl b/src/features/speciesfeature.jl index 184b049f..dcdbe466 100644 --- a/src/features/speciesfeature.jl +++ b/src/features/speciesfeature.jl @@ -41,4 +41,5 @@ function (f::SpeciesFeatureDescriptor)(a::AbstractAtoms) f.encode_f(a) end -# TODO: some Weave stuff needed here? +# TODO: some Weave stuff needed here +# also output_shape dispatch diff --git a/src/featurizations/dummyfeaturization.jl b/src/featurizations/dummyfeaturization.jl new file mode 100644 index 00000000..90be3bb4 --- /dev/null +++ b/src/featurizations/dummyfeaturization.jl @@ -0,0 +1,25 @@ +export DummyFeaturization, featurize! +export encodable_elements, decode + +using ..ChemistryFeaturization.AbstractType: AbstractFeaturization + +""" + DummyFeaturization + +A dummy featurization that cannot actually encode or decode anything. For use primarily for populating output of model layers (e.g. AtomicGraphnets' AGNConv) that return an Atoms object, but with the encoded features transformed such that the original featurization is no longer applicable/valid. +""" +struct DummyFeaturization <: AbstractFeaturization end + +encodable_elements(df::DummyFeaturization) = [] + +decode(fzn::DummyFeaturization, encoded_feature) = throw( + ArgumentError( + "This featurization is just a dummy, likely created by a model layer such as `AGNConv`, and cannot actually decode encoded features.", + ), +) + +featurize!(a::AbstractAtoms, df::DummyFeaturization) = throw( + ArgumentError( + "This featurization is just a dummy, likely created by a model layer such as `AGNConv`, and cannot actually encode features.", + ), +) diff --git a/src/featurizations/featurizations.jl b/src/featurizations/featurizations.jl index e7478343..e014848b 100644 --- a/src/featurizations/featurizations.jl +++ b/src/featurizations/featurizations.jl @@ -1,17 +1,29 @@ module Featurization -using ..ChemistryFeaturization.AbstractType: AbstractFeaturization +using ..ChemistryFeaturization.AbstractType: + AbstractFeaturization, AbstractAtoms, AbstractCodec import ..ChemistryFeaturization.encodable_elements encodable_elements(fzn::AbstractFeaturization) = throw(MethodError(encodable_elements, fzn)) export encodable_elements +featurize!(a::AbstractAtoms, fzn::AbstractFeaturization) = + throw(MethodError(featurize!, (a, fzn))) + import ..ChemistryFeaturization.decode -decode(fzn::AbstractFeaturization, encoded_feature) = throw(MethodError(decode, fzn)) +decode(fzn::AbstractFeaturization, encoded_feature) = + throw(MethodError(decode, (fzn, encoded_feature))) include("graphnodefeaturization.jl") export GraphNodeFeaturization, featurize!, decode +validate_features(fzn::AbstractFeaturization, ed::AbstractCodec, encoded) = + throw(MethodError(validate_features, (fzn, ed, encoded))) +export validate_features + include("weavefeaturization.jl") export WeaveFeaturization +include("dummyfeaturization.jl") +export DummyFeaturization + end diff --git a/src/featurizations/graphnodefeaturization.jl b/src/featurizations/graphnodefeaturization.jl index d94d3b28..0c673a26 100644 --- a/src/featurizations/graphnodefeaturization.jl +++ b/src/featurizations/graphnodefeaturization.jl @@ -133,12 +133,3 @@ function decode(fzn::GraphNodeFeaturization, encoded::Matrix{<:Real}) end return decoded end - -function decode(ag::AtomGraph) - @assert !(any(isnothing.([ag.featurization, ag.encoded_features]))) - decoded = decode(ag.featurization, ag.encoded_features) - for (k, v) in decoded - v["Symbol"] = ag.elements[k] - end - return decoded -end diff --git a/test/featurizations/DummyFeaturization_tests.jl b/test/featurizations/DummyFeaturization_tests.jl new file mode 100644 index 00000000..6dcb491a --- /dev/null +++ b/test/featurizations/DummyFeaturization_tests.jl @@ -0,0 +1,12 @@ +using Test +using ChemistryFeaturization.Featurization +using ChemistryFeaturization.Atoms: AtomGraph + +@testset "DummyFeaturization" begin + df = DummyFeaturization() + @test isempty(encodable_elements(df)) + dummy_encoded = [0 1 0] + @test_throws ArgumentError decode(df, dummy_encoded) + F2 = AtomGraph(Float32.([0 1; 1 0]), ["F", "F"]) + @test_throws ArgumentError featurize!(F2, df) +end diff --git a/test/module_tests.jl b/test/module_tests.jl new file mode 100644 index 00000000..f7bcd453 --- /dev/null +++ b/test/module_tests.jl @@ -0,0 +1,31 @@ +using Test +using ChemistryFeaturization.AbstractType: AbstractFeaturization, AbstractFeatureDescriptor + +@testset "modules and abstract methods" begin + F2 = AtomGraph(Float32.([0 1; 1 0]), ["F", "F"]) + + @testset "top-level module" begin + # testing on `nothing` as example of ::Any + @test_throws MethodError encodable_elements(nothing) + @test_throws MethodError decode(nothing, nothing) + end + + @testset "featurizations module" begin + struct FakeFeaturization <: AbstractFeaturization end + ff = FakeFeaturization() + @test_throws MethodError encodable_elements(ff) + @test_throws MethodError featurize!(F2, ff) + @test_throws MethodError decode(ff, nothing) + end + + # @testset "atoms module" begin + # TBD cleanest way to test generic decode(::AbstractAtoms) - either another "fake" class, or maybe the `invoke` function + # end + + @testset "features module" begin + struct FakeFD <: AbstractFeatureDescriptor end + fd = FakeFD() + @test_throws MethodError encodable_elements(fd) + @test_throws MethodError decode(fd, nothing) + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 942f1bca..2e60fc2b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -7,10 +7,12 @@ tests = [ "utils/ElementFeatureUtils_tests", "utils/GraphBuilding_tests", "atoms/AtomGraph_tests", + # TODO: add SpeciesFeature tests "features/ElementFeature_tests", "featurizations/GraphNodeFeaturization_tests", + "featurizations/DummyFeaturization_tests", + "module_tests", # TODO: add Weave stuff - # TODO: add SpeciesFeature tests ] @testset "ChemistryFeaturization" begin