From 9fbc87b948d5cd86aa15dc1edd6cbc39a62d689c Mon Sep 17 00:00:00 2001 From: vondrakmar Date: Wed, 15 May 2024 12:24:26 +0000 Subject: [PATCH 1/3] removed build_default_arg_parser from preprocess_data file --- mace/cli/preprocess_data.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/mace/cli/preprocess_data.py b/mace/cli/preprocess_data.py index 9f695d96..544c90af 100644 --- a/mace/cli/preprocess_data.py +++ b/mace/cli/preprocess_data.py @@ -97,11 +97,10 @@ def main() -> None: This script loads an xyz dataset and prepares new hdf5 file that is ready for training with on-the-fly dataloading """ - args = tools.build_default_arg_parser().parse_args() - run(args) + run() -def run(args: argparse.Namespace) -> None: +def run() -> None: """ This script loads an xyz dataset and prepares new hdf5 file that is ready for training with on-the-fly dataloading From b4ea20ccc3f9c361873b99a3f6e7a5c266845ee5 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Wed, 15 May 2024 15:04:22 +0100 Subject: [PATCH 2/3] fix pylint warnings --- mace/calculators/foundations_models.py | 1 + mace/cli/preprocess_data.py | 6 +++--- mace/cli/run_train.py | 3 ++- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/mace/calculators/foundations_models.py b/mace/calculators/foundations_models.py index e18e7f33..1ee7659a 100644 --- a/mace/calculators/foundations_models.py +++ b/mace/calculators/foundations_models.py @@ -108,6 +108,7 @@ def mace_mp( mace_calc = MACECalculator( model_paths=model, device=device, default_dtype=default_dtype, **kwargs ) + d3_calc = None if dispersion: gh_url = "https://github.com/pfnet-research/torch-dftd" try: diff --git a/mace/cli/preprocess_data.py b/mace/cli/preprocess_data.py index 544c90af..64d8dee6 100644 --- a/mace/cli/preprocess_data.py +++ b/mace/cli/preprocess_data.py @@ -97,15 +97,15 @@ def main() -> None: This script loads an xyz dataset and prepares new hdf5 file that is ready for training with on-the-fly dataloading """ - run() + args = tools.build_preprocess_arg_parser().parse_args() + run(args) -def run() -> None: +def run(args: argparse.Namespace): """ This script loads an xyz dataset and prepares new hdf5 file that is ready for training with on-the-fly dataloading """ - args = tools.build_preprocess_arg_parser().parse_args() # Setup tools.set_seeds(args.seed) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 1aabb366..3a0b14c9 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -390,6 +390,7 @@ def run(args: argparse.Namespace) -> None: ) model_config_foundation["atomic_energies"] = atomic_energies args.model = "FoundationMACE" + model_config = model_config_foundation # pylint else: logging.info("Building model") if args.num_channels is not None and args.max_L is not None: @@ -584,7 +585,7 @@ def run(args: argparse.Namespace) -> None: args.start_swa = max(1, args.max_num_epochs // 4 * 3) logging.info(f"Setting start swa to {args.start_swa}") if args.loss == "forces_only": - logging.info("Can not select swa with forces only loss.") + raise ValueError("Can not select swa with forces only loss.") elif args.loss == "virials": loss_fn_energy = modules.WeightedEnergyForcesVirialsLoss( energy_weight=args.swa_energy_weight, From 3cc7c3d56060313b8e61b332c1dc16beea1ec9d7 Mon Sep 17 00:00:00 2001 From: Ilyes Batatia <48651863+ilyes319@users.noreply.github.com> Date: Wed, 15 May 2024 15:10:24 +0100 Subject: [PATCH 3/3] fix formatting --- mace/cli/run_train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 3a0b14c9..1e17ad97 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -586,7 +586,7 @@ def run(args: argparse.Namespace) -> None: logging.info(f"Setting start swa to {args.start_swa}") if args.loss == "forces_only": raise ValueError("Can not select swa with forces only loss.") - elif args.loss == "virials": + if args.loss == "virials": loss_fn_energy = modules.WeightedEnergyForcesVirialsLoss( energy_weight=args.swa_energy_weight, forces_weight=args.swa_forces_weight,