Skip to content

Commit

Permalink
Moved batch size checks compared to data size right after data set is…
Browse files Browse the repository at this point in the history
… loaded
  • Loading branch information
Eszter Varga-Umbrich committed Aug 16, 2024
1 parent 8ecfa97 commit 7bcf55e
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions mace/cli/run_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,18 +167,18 @@ def run(args: argparse.Namespace) -> None:
charges_key=args.charges_key,
keep_isolated_atoms=args.keep_isolated_atoms,
)
if len(collections.train)<args.batch_size:
logging.warning(f"Batch size ({args.batch_size}) is larger than the number of training data ({len(collections.train)})")
args.batch_size = int(len(collections.train)*0.1)
logging.warning(f"Batch size changed to {args.batch_size}")
if len(collections.train)<len(collections.valid):
logging.warning(f"Validation batch size ({args.valid_batch_size}) is larger than the number of validation data ({len(collections.valid)})")
args.valid_batch_size = int(len(collections.valid)*0.1)
logging.warning(f"Validation batch size changed to {args.valid_batch_size}")

else:
atomic_energies_dict = None

if len(collections.train)<args.batch_size:
logging.warning(f"Batch size ({args.batch_size}) is larger than the number of training data ({len(collections.train)})")
args.batch_size = int(len(collections.train)*0.1)
logging.warning(f"Batch size changed to {args.batch_size}")
if len(collections.train)<len(collections.valid):
logging.warning(f"Validation batch size ({args.valid_batch_size}) is larger than the number of validation data ({len(collections.valid)})")
args.valid_batch_size = int(len(collections.valid)*0.1)
logging.warning(f"Validation batch size changed to {args.valid_batch_size}")

# Atomic number table
# yapf: disable
Expand Down

0 comments on commit 7bcf55e

Please sign in to comment.