diff --git a/mace/calculators/foundations_models.py b/mace/calculators/foundations_models.py index e18e7f33..1ee7659a 100644 --- a/mace/calculators/foundations_models.py +++ b/mace/calculators/foundations_models.py @@ -108,6 +108,7 @@ def mace_mp( mace_calc = MACECalculator( model_paths=model, device=device, default_dtype=default_dtype, **kwargs ) + d3_calc = None if dispersion: gh_url = "https://github.com/pfnet-research/torch-dftd" try: diff --git a/mace/cli/preprocess_data.py b/mace/cli/preprocess_data.py index 9f695d96..64d8dee6 100644 --- a/mace/cli/preprocess_data.py +++ b/mace/cli/preprocess_data.py @@ -97,16 +97,15 @@ def main() -> None: This script loads an xyz dataset and prepares new hdf5 file that is ready for training with on-the-fly dataloading """ - args = tools.build_default_arg_parser().parse_args() + args = tools.build_preprocess_arg_parser().parse_args() run(args) -def run(args: argparse.Namespace) -> None: +def run(args: argparse.Namespace): """ This script loads an xyz dataset and prepares new hdf5 file that is ready for training with on-the-fly dataloading """ - args = tools.build_preprocess_arg_parser().parse_args() # Setup tools.set_seeds(args.seed) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 1aabb366..1e17ad97 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -390,6 +390,7 @@ def run(args: argparse.Namespace) -> None: ) model_config_foundation["atomic_energies"] = atomic_energies args.model = "FoundationMACE" + model_config = model_config_foundation # pylint else: logging.info("Building model") if args.num_channels is not None and args.max_L is not None: @@ -584,8 +585,8 @@ def run(args: argparse.Namespace) -> None: args.start_swa = max(1, args.max_num_epochs // 4 * 3) logging.info(f"Setting start swa to {args.start_swa}") if args.loss == "forces_only": - logging.info("Can not select swa with forces only loss.") - elif args.loss == "virials": + raise ValueError("Can not select swa with forces only loss.") + if args.loss == "virials": loss_fn_energy = modules.WeightedEnergyForcesVirialsLoss( energy_weight=args.swa_energy_weight, forces_weight=args.swa_forces_weight,