
Instructions for use of the benchmark datasets and metrics on custom generative models #10

Open
sgbaird opened this issue Jun 4, 2022 · 8 comments

Comments

@sgbaird

sgbaird commented Jun 4, 2022

Hi @txie-93, I'm enjoying digging into the manuscript, and congratulations on its acceptance to ICLR! It is really nice to see the comparison with FTCP and other methods, and CDVAE certainly has some impressive results.

Would you mind adding some instructions to the repository for using the benchmark datasets and metrics with a custom generative model? For example, what would this look like for FTCP or the slew of other generative models in this space (i.e., the general inverse-design ones)?

@kyonofx

kyonofx commented Jun 4, 2022

Hi Sterling, thank you for your interest.

Use our datasets on other models

Our datasets are CSV files in which each row contains one crystal stored as a CIF string. Both FTCP and Cond-DFC-VAE provide utilities and guidelines for configuring the model with a CIF data source. That may also be the case for other crystal generative models, given how commonly CIF is used for crystals.
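For models whose data loaders expect one CIF file per structure, a first step could be to write each row's CIF string to its own file. The sketch below is a hypothetical helper (export_cifs is not part of CDVAE) and assumes the CSV has columns named material_id and cif; check the actual header before using it. The demo CIF is a truncated placeholder, not real data.

```python
import csv
import tempfile
from pathlib import Path


def export_cifs(csv_path, out_dir, id_col="material_id", cif_col="cif"):
    """Write each row's CIF string to <out_dir>/<id>.cif and return the paths.

    id_col and cif_col are ASSUMED column names; adjust them to the real header.
    """
    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)
    paths = []
    # newline="" lets the csv module handle multi-line CIF strings in quoted fields.
    with open(csv_path, newline="") as f:
        for row in csv.DictReader(f):
            path = out / f"{row[id_col]}.cif"
            path.write_text(row[cif_col])
            paths.append(path)
    return paths


# Demo with a tiny temporary CSV (placeholder CIF text, hypothetical id).
demo_dir = Path(tempfile.mkdtemp())
demo_csv = demo_dir / "demo.csv"
demo_cif = "data_demo\n_cell_length_a 3.61\n"
with open(demo_csv, "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["material_id", "cif"])
    writer.writerow(["demo-1", demo_cif])

written = export_cifs(demo_csv, demo_dir / "cifs")
```

Reading with the csv module rather than splitting on newlines matters here, because each CIF string itself spans many lines inside a single quoted field.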

To use our data with G-SchNet, we converted the crystals to ASE Atoms objects. Here is my code:

import numpy as np
import pandas as pd

from pathlib import Path
from tqdm import tqdm
from ase import Atoms

from pymatgen.core.structure import Structure
from pymatgen.core.lattice import Lattice

def abs_cap(val, max_abs_val=1):
    """Clip val to the interval [-max_abs_val, max_abs_val]."""
    return max(min(val, max_abs_val), -max_abs_val)

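def lattice_params_to_matrix(a, b, c, alpha, beta, gamma):
    """Convert lattice lengths and angles (degrees) to a 3x3 matrix of row vectors.

    Uses the same convention as pymatgen's Lattice.from_parameters.
    """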
    angles_r = np.radians([alpha, beta, gamma])
    cos_alpha, cos_beta, cos_gamma = np.cos(angles_r)
    sin_alpha, sin_beta, sin_gamma = np.sin(angles_r)

    val = (cos_alpha * cos_beta - cos_gamma) / (sin_alpha * sin_beta)
    # Sometimes rounding errors result in values slightly > 1.
    val = abs_cap(val)
    gamma_star = np.arccos(val)

    vector_a = [a * sin_beta, 0.0, a * cos_beta]
    vector_b = [
        -b * sin_alpha * np.cos(gamma_star),
        b * sin_alpha * np.sin(gamma_star),
        b * cos_alpha,
    ]
    vector_c = [0.0, 0.0, float(c)]
    return np.array([vector_a, vector_b, vector_c])

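def build_crystal(crystal_str, niggli=True, primitive=False, supercell=False):
    """Build a canonical pymatgen Structure from a CIF string.

    Optionally reduces to the primitive and/or Niggli-reduced cell; supercell is unused.
    """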
    crystal = Structure.from_str(crystal_str, fmt='cif')
    if primitive:
        crystal = crystal.get_primitive_structure()
    if niggli:
        crystal = crystal.get_reduced_structure()
    canonical_crystal = Structure(
        lattice=Lattice.from_parameters(*crystal.lattice.parameters),
        species=crystal.species,
        coords=crystal.frac_coords,
        coords_are_cartesian=False,
    )
    return canonical_crystal

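def get_ase_atoms(cif):
    """Convert a CIF string into a periodic ASE Atoms object."""
    crystal = build_crystal(cif)
    # Rebuild the cell matrix from lengths and angles (pymatgen convention).
    lattice = lattice_params_to_matrix(*crystal.lattice.abc, *crystal.lattice.angles)
    at = Atoms(scaled_positions=crystal.frac_coords,
               numbers=np.array(crystal.atomic_numbers),
               cell=lattice, pbc=True)
    return at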

Then one can follow the G-SchNet instructions to build dataset objects from these Atoms objects.

Use our benchmark metrics

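A dictionary containing the following arrays is all you need for evaluation. Here num_evals indexes repeated evaluation runs, M is the number of crystals per run, and N is the total number of atoms summed over those M crystals (atoms of all crystals are concatenated):

frac_coords: fractional coordinates of each atom, shape (num_evals, N, 3)
atom_types: atomic number of each atom, shape (num_evals, N)
lengths: lattice lengths a, b, c of each crystal, shape (num_evals, M, 3)
angles: lattice angles alpha, beta, gamma of each crystal in degrees, shape (num_evals, M, 3)
num_atoms: number of atoms in each crystal, shape (num_evals, M)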
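For a model that outputs crystals one at a time, these arrays can be assembled by concatenating atoms across crystals and stacking the per-crystal lattice parameters. A minimal NumPy sketch, assuming a single evaluation run (num_evals = 1) and a hypothetical per-crystal dict format; pack_crystals is an illustrative helper, not part of CDVAE. The lattice values in the demo are illustrative only.

```python
import numpy as np


def pack_crystals(crystals):
    """Pack per-crystal dicts into the concatenated evaluation layout (num_evals = 1).

    Each crystal is a HYPOTHETICAL dict with keys 'frac_coords' (n_i, 3),
    'atom_types' (n_i,), 'lengths' (3,) and 'angles' (3,) in degrees.
    """
    frac_coords = np.concatenate(
        [np.asarray(c["frac_coords"], dtype=np.float32) for c in crystals])
    atom_types = np.concatenate(
        [np.asarray(c["atom_types"], dtype=np.int64) for c in crystals])
    lengths = np.stack([np.asarray(c["lengths"], dtype=np.float32) for c in crystals])
    angles = np.stack([np.asarray(c["angles"], dtype=np.float32) for c in crystals])
    num_atoms = np.array([len(c["atom_types"]) for c in crystals], dtype=np.int64)
    # [None] adds the leading num_evals axis of size 1.
    return {
        "frac_coords": frac_coords[None],  # (1, N, 3) with N = num_atoms.sum()
        "atom_types": atom_types[None],    # (1, N)
        "lengths": lengths[None],          # (1, M, 3)
        "angles": angles[None],            # (1, M, 3)
        "num_atoms": num_atoms[None],      # (1, M)
    }


crystals = [
    # CsCl-type cell, 2 atoms (illustrative values)
    {"frac_coords": [[0, 0, 0], [0.5, 0.5, 0.5]], "atom_types": [55, 17],
     "lengths": [4.12, 4.12, 4.12], "angles": [90, 90, 90]},
    # fcc Cu primitive cell, 1 atom (illustrative values)
    {"frac_coords": [[0, 0, 0]], "atom_types": [29],
     "lengths": [2.55, 2.55, 2.55], "angles": [60, 60, 60]},
]
batch = pack_crystals(crystals)
```

To produce the torch pickle file, the arrays could then be converted and saved, e.g. torch.save({k: torch.as_tensor(v) for k, v in batch.items()}, path). The file name and any extra keys that compute_metrics.py expects should be checked against the script itself.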

Any complete crystal generative model should produce these quantities.
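Some models (for example ones that emit ASE Atoms, as in the G-SchNet setup above) produce a Cartesian lattice matrix rather than lengths and angles. A minimal NumPy sketch of the inverse of lattice_params_to_matrix, assuming the matrix rows are the lattice vectors a, b, c; matrix_to_lattice_params is a hypothetical helper written for illustration.

```python
import numpy as np


def matrix_to_lattice_params(matrix):
    """Return (a, b, c, alpha, beta, gamma) from a 3x3 matrix of row lattice vectors.

    Lengths are in the matrix's units; angles are in degrees.
    """
    m = np.asarray(matrix, dtype=float)
    a_vec, b_vec, c_vec = m
    a, b, c = np.linalg.norm(m, axis=1)

    def angle(u, v):
        cos = np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v))
        # Clip guards against rounding pushing |cos| slightly above 1.
        return float(np.degrees(np.arccos(np.clip(cos, -1.0, 1.0))))

    alpha = angle(b_vec, c_vec)  # angle between b and c
    beta = angle(a_vec, c_vec)   # angle between a and c
    gamma = angle(a_vec, b_vec)  # angle between a and b
    return float(a), float(b), float(c), alpha, beta, gamma


# fcc primitive cell for a conventional lattice constant of 3.61 (illustrative)
a0 = 3.61
fcc = 0.5 * a0 * np.array([[0, 1, 1], [1, 0, 1], [1, 1, 0]])
params = matrix_to_lattice_params(fcc)
```

With ASE installed, atoms.cell.cellpar() returns the same six values in the same order and is a simpler alternative.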

Our evaluation scripts for computing metrics are independent of CDVAE. You just need to save these quantities as a torch pickle file and then run compute_metrics.py with that file as input. See https://github.com/txie-93/cdvae/blob/main/scripts/compute_metrics.py#L267 for how the saved crystals are loaded.
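When the saved file is loaded, the concatenated atom arrays have to be split back into individual crystals using num_atoms. The sketch below mirrors that idea in NumPy; split_crystals is a hypothetical helper written for illustration, not CDVAE's own get_crystals_list, and the demo values are illustrative.

```python
import numpy as np


def split_crystals(batch, eval_idx=0):
    """Split one evaluation run of the concatenated layout into per-crystal dicts."""
    num_atoms = np.asarray(batch["num_atoms"][eval_idx])
    frac_coords = np.asarray(batch["frac_coords"][eval_idx])
    atom_types = np.asarray(batch["atom_types"][eval_idx])
    if num_atoms.sum() != len(atom_types) or len(atom_types) != len(frac_coords):
        raise ValueError("num_atoms must sum to the number of atom rows")
    # Atom index at which each crystal after the first begins.
    offsets = np.cumsum(num_atoms)[:-1]
    return [
        {
            "frac_coords": fc,
            "atom_types": at,
            "lengths": np.asarray(batch["lengths"][eval_idx][i]),
            "angles": np.asarray(batch["angles"][eval_idx][i]),
        }
        for i, (fc, at) in enumerate(
            zip(np.split(frac_coords, offsets), np.split(atom_types, offsets)))
    ]


batch = {
    "frac_coords": np.array([[[0, 0, 0], [0.5, 0.5, 0.5], [0, 0, 0]]]),
    "atom_types": np.array([[55, 17, 29]]),
    "lengths": np.array([[[4.12] * 3, [2.55] * 3]]),
    "angles": np.array([[[90.0] * 3, [60.0] * 3]]),
    "num_atoms": np.array([[2, 1]]),
}
crystals = split_crystals(batch)
```

Checking that num_atoms sums to the number of atom rows catches the most common packing mistake early, before any metric is computed.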

Hope this helps.

@sgbaird
Author

sgbaird commented Jun 11, 2022

@kyonofx thank you! While browsing further, I also noticed the README in the data directory. I appreciate the extra clarification here.

@sgbaird sgbaird closed this as completed Jun 11, 2022
@sgbaird
Author

sgbaird commented Jun 16, 2022

"Our evaluation scripts for computing metrics are independent of CDVAE."

@kyonofx, while the scripts live in separate files/folders from the cdvae package, their imports still trace back to CDVAE:

from eval_utils import (
    smact_validity, structure_validity, CompScaler, get_fp_pdist,
    load_config, load_data, get_crystals_list, prop_model_eval, compute_cov)

from cdvae.common.constants import CompScalerMeans, CompScalerStds
from cdvae.common.data_utils import StandardScaler, chemical_symbols
from cdvae.pl_data.dataset import TensorCrystDataset
from cdvae.pl_data.datamodule import worker_init_fn

@sgbaird sgbaird reopened this Jun 16, 2022
@kyonofx

kyonofx commented Jun 16, 2022

Hi,

Yes, you still need to install the cdvae package, but evaluation can be run without training a CDVAE model.

@sgbaird
Author

sgbaird commented Jun 16, 2022

@kyonofx I'm planning to expose these metrics in their own package plus additional metric(s). Would you recommend that I try to splice out the functionality or package CDVAE as a whole onto PyPI and Anaconda? #14

@kyonofx

kyonofx commented Jun 16, 2022

Hi,

I think it would be easiest to splice out the evaluation code, since it makes up only a small fraction of the cdvae codebase.

@sgbaird
Author

sgbaird commented Jun 23, 2022

@kyonofx separating it out is turning out to be ☠️. I'm reconstructing most of the repository piece by piece, and it's not very straightforward, since the evaluation code accesses many files across the repository.

@sgbaird
Copy link
Author

sgbaird commented Aug 5, 2022

Note that this is mainly because some metric(s) require predictions from a CDVAE submodel (i.e., the property regressor).
