tidying up examples #50

Merged
merged 1 commit on Feb 22, 2021
2 changes: 1 addition & 1 deletion Project.toml
@@ -19,7 +19,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
CSV = "0.7, 0.8"
-ChemistryFeaturization = "0.2"
+ChemistryFeaturization = "0.2.2"
DataFrames = "0.21, 0.22"
Flux = "0.11"
LightGraphs = "1.3"
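The bump from `"0.2"` to `"0.2.2"` matters because Julia's Pkg treats a bare version in `[compat]` as a caret specifier: `"0.2.2"` allows any `0.2.x` release at or above `0.2.2` but excludes `0.3.0`. A quick way to sanity-check this, using Pkg's (semi-internal) `semver_spec` helper; a sketch, not part of the PR:

```julia
# Sketch, not part of this PR: how Pkg reads the new compat bound.
# "0.2.2" is a caret specifier, i.e. the version range [0.2.2, 0.3.0).
using Pkg

spec = Pkg.Types.semver_spec("0.2.2")
@show v"0.2.1" in spec  # false – earlier patches no longer satisfy compat
@show v"0.2.9" in spec  # true  – later 0.2.x patches are still allowed
@show v"0.3.0" in spec  # false – the next minor release stays excluded
```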
19 changes: 3 additions & 16 deletions examples/1_formation_energy/formation_energy.jl
@@ -1,8 +1,6 @@
#=
Train a simple network to predict formation energy per atom (downloaded from Materials Project).
=#
-#using Pkg
-#Pkg.activate("../../")
using CSV, DataFrames
using Random, Statistics
using Flux
@@ -14,7 +12,7 @@ using AtomicGraphNets
println("Setting things up...")

# data-related options
-num_pts = 100 # how many points to use? Up to 32530 in the formation energy case as of 2020/04/01
+num_pts = 100 # how many points to use?
train_frac = 0.8 # what fraction for training?
num_epochs = 5 # how many epochs to train?
num_train = Int32(round(train_frac * num_pts))
@@ -48,13 +46,8 @@ output = y[indices]

# next, make graphs and build input features (matrices of dimension (# features, # nodes))
println("Building graphs and feature vectors from structures...")
-#graphs = SimpleWeightedGraph{Int32, Float32}[]
-#element_lists = Array{String}[]
-#inputs = Tuple{Array{Float32,2},SparseArrays.SparseMatrixCSC{Float32,Int64}}[]
inputs = AtomGraph[]

-#TODO: this with bulk processing fcn

for r in eachrow(info)
cifpath = string(datadir, prop, "_cifs/", r[Symbol(id)], ".cif")
gr = build_graph(cifpath)
@@ -71,21 +64,15 @@ train_input = inputs[1:num_train]
test_input = inputs[num_train+1:end]
train_data = zip(train_input, train_output)

-# build the network (basically just copied from CGCNN.py for now): the convolutional layers, a mean pooling function, some dense layers, then fully connected output to one value for prediction

+# build the model
println("Building the network...")
-#model = Chain([AGNConv(num_features=>num_features) for i in 1:num_conv]..., AGNMeanPool(crys_fea_len, 0.1), [Dense(crys_fea_len, crys_fea_len, softplus) for i in 1:num_hidden_layers]..., Dense(crys_fea_len, 1))
model = Xie_model(num_features, num_conv=num_conv, atom_conv_feature_length=crys_fea_len, num_hidden_layers=1)

-# MaxPool might make more sense?

-# define loss function
+# define loss function and a callback to monitor progress
loss(x,y) = Flux.mse(model(x), y)
-# and a callback to see training progress
evalcb() = @show(mean(loss.(test_input, test_output)))
evalcb()

# train
println("Training!")
-#Flux.train!(loss, params(model), train_data, opt)
@epochs num_epochs Flux.train!(loss, params(model), train_data, opt, cb = Flux.throttle(evalcb, 5))
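For orientation, the long commented-out `Chain` this PR deletes spells out exactly what `Xie_model` wraps up. A sketch of that hand-built equivalent, reconstructed from the deleted line (the layer constructors and arguments are copied from that line, not verified against the current AtomicGraphNets API):

```julia
# Reconstructed from the deleted #model = Chain(...) line above;
# a sketch only, not checked against the current AtomicGraphNets API.
model = Chain(
    # num_conv graph convolutions that keep the feature length fixed
    [AGNConv(num_features => num_features) for i in 1:num_conv]...,
    # mean-pool per-atom features into one crystal-level feature vector
    AGNMeanPool(crys_fea_len, 0.1),
    # hidden dense layers with softplus activation
    [Dense(crys_fea_len, crys_fea_len, softplus) for i in 1:num_hidden_layers]...,
    # fully connected output down to a single predicted value
    Dense(crys_fea_len, 1),
)
```

Folding this into `Xie_model` keeps the example script focused on data loading and training rather than layer plumbing.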
4 changes: 3 additions & 1 deletion examples/2_qm9/README.md
@@ -1,3 +1,5 @@
# Example 2: QM9

In this example we will use the same network architecture, but with data from the QM9 dataset. Note that the .xyz files provided in the QM9 dataset are not directly parsable by ASE; you need to remove the last couple of lines, which is easy enough to script yourself, but I've included a small set of them here for demonstration purposes.

NB: the actual model performance on QM9 is not that great because we're currently not encoding a variety of important features for organic molecules. This is provided mainly to show the processing of a different dataset and demonstrate batch processing capabilities.
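The "remove the last couple lines" step is simple to script: QM9's .xyz files append extra property lines (harmonic frequencies, SMILES, InChI) after the per-atom records, which is what trips up ASE's parser. A hypothetical cleanup helper in Julia (the function name and the assumption of the standard QM9 layout are illustrative, not part of this repo):

```julia
# Hypothetical helper, not part of this repo: truncate a QM9 .xyz file to
# the plain xyz layout (atom count, comment line, one line per atom)
# so that ASE can read it.
function clean_qm9_xyz(inpath::String, outpath::String)
    lines = readlines(inpath)
    natoms = parse(Int, strip(lines[1]))  # first line gives the atom count
    open(outpath, "w") do io
        for line in lines[1:2 + natoms]   # keep header + atom records only
            println(io, line)
        end
    end
end
```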
5 changes: 0 additions & 5 deletions examples/3_deq/README.md

This file was deleted.

86 changes: 0 additions & 86 deletions examples/3_deq/deq.jl

This file was deleted.