Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix SWA start epoch bigger than max number of epochs #616

Merged: 7 commits merged on Oct 2, 2024
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions mace/calculators/mace.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,9 @@ class MACECalculator(Calculator):

def __init__(
self,
model_paths: Union[list, str] | None = None,
device: str | None = None,
models: Union[list[torch.nn.Module], torch.nn.Module] | None = None,
model_paths: Union[list, str, None] = None,
models: Union[list[torch.nn.Module], torch.nn.Module, None] = None,
device: str = "cpu",
energy_units_to_eV: float = 1.0,
length_units_to_A: float = 1.0,
default_dtype="",
Expand Down
13 changes: 12 additions & 1 deletion mace/cli/create_lammps_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,13 @@ def parse_args():
help="Head of the model to be converted to LAMMPS",
default=None,
)
parser.add_argument(
"--dtype",
type=str,
nargs="?",
help="Data type of the model to be converted to LAMMPS",
default="float64",
)
return parser.parse_args()


Expand Down Expand Up @@ -58,7 +65,11 @@ def main():
model_path,
map_location=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
)
model = model.double().to("cpu")
if args.dtype == "float64":
model = model.double().to("cpu")
elif args.dtype == "float32":
print("Converting model to float32, this may cause loss of precision.")
model = model.float().to("cpu")

if args.head is None:
head = select_head(model)
Expand Down
17 changes: 16 additions & 1 deletion mace/cli/preprocess_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,22 @@ def pool_compute_stats(inputs: List):
pool.join()

results = [r.get() for r in tqdm.tqdm(re)]
return np.average(results, axis=0)

if not results:
raise ValueError(
"No results were computed. Check if the input files exist and are readable."
)

# Separate avg_num_neighbors, mean, and std
avg_num_neighbors = np.mean([r[0] for r in results])
means = np.array([r[1] for r in results])
stds = np.array([r[2] for r in results])

# Compute averages
mean = np.mean(means, axis=0).item()
std = np.mean(stds, axis=0).item()

return avg_num_neighbors, mean, std


def split_array(a: np.ndarray, max_size: int):
Expand Down
9 changes: 5 additions & 4 deletions mace/cli/run_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,6 @@ def run(args: argparse.Namespace) -> None:
args.loss = "universal"
if (
args.foundation_model in ["small", "medium", "large"]
or "mp" in args.foundation_model
or args.pt_train_file is None
):
logging.info(
Expand Down Expand Up @@ -344,6 +343,7 @@ def run(args: argparse.Namespace) -> None:
atomic_energies_dict = {}
for head_config in head_configs:
if head_config.atomic_energies_dict is None or len(head_config.atomic_energies_dict) == 0:
assert head_config.E0s is not None, "Atomic energies must be provided"
if check_path_ase_read(head_config.train_file) and head_config.E0s.lower() != "foundation":
atomic_energies_dict[head_config.head_name] = get_atomic_energies(
head_config.E0s, head_config.collections.train, head_config.z_table
Expand Down Expand Up @@ -403,7 +403,10 @@ def run(args: argparse.Namespace) -> None:
# )
atomic_energies = dict_to_array(atomic_energies_dict, heads)
for head_config in head_configs:
logging.info(f"Atomic Energies used (z: eV) for head {head_config.head_name}: " + "{" + ", ".join([f"{z}: {atomic_energies_dict[head_config.head_name][z]}" for z in head_config.z_table.zs]) + "}")
try:
logging.info(f"Atomic Energies used (z: eV) for head {head_config.head_name}: " + "{" + ", ".join([f"{z}: {atomic_energies_dict[head_config.head_name][z]}" for z in head_config.z_table.zs]) + "}")
except KeyError as e:
raise KeyError(f"Atomic number {e} not found in atomic_energies_dict for head {head_config.head_name}, add E0s for this atomic number") from e


valid_sets = {head: [] for head in heads}
Expand Down Expand Up @@ -627,9 +630,7 @@ def run(args: argparse.Namespace) -> None:
stop_first_test = True
for head_config in head_configs:
if check_path_ase_read(head_config.train_file):
print(head_config.test_file)
for name, subset in head_config.collections.tests:
print(name)
test_sets[name] = [
data.AtomicData.from_config(
config, z_table=z_table, cutoff=args.r_max, heads=heads
Expand Down
10 changes: 6 additions & 4 deletions mace/modules/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,7 @@ def compute_statistics(
forces_list = []
num_neighbors = []
head_list = []
head_batch = []

for batch in data_loader:
head = batch.head
Expand All @@ -391,21 +392,22 @@ def compute_statistics(
) # {[n_graphs], }
forces_list.append(batch.forces) # {[n_graphs*n_atoms,3], }
head_list.append(head) # {[n_graphs], }

head_batch.append(head[batch.batch])
_, receivers = batch.edge_index
_, counts = torch.unique(receivers, return_counts=True)
num_neighbors.append(counts)

atom_energies = torch.cat(atom_energy_list, dim=0) # [total_n_graphs]
forces = torch.cat(forces_list, dim=0) # {[total_n_graphs*n_atoms,3], }
head = torch.cat(head_list, dim=0) # [total_n_graphs]
head_batch = torch.cat(head_batch, dim=0) # [total_n_graphs]

# mean = to_numpy(torch.mean(atom_energies)).item()
mean = to_numpy(scatter_mean(src=atom_energies, index=head, dim=0).squeeze(-1))
# do the mean for each head
# rms = to_numpy(torch.sqrt(torch.mean(torch.square(forces)))).item()
rms = to_numpy(
torch.sqrt(scatter_mean(src=torch.square(forces), index=head, dim=0))
torch.sqrt(
scatter_mean(src=torch.square(forces), index=head_batch, dim=0).mean(-1)
)
)

avg_num_neighbors = torch.mean(
Expand Down
3 changes: 2 additions & 1 deletion mace/tools/scripts_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,10 +449,11 @@ def get_swa(
if args.start_swa is None:
args.start_swa = max(1, args.max_num_epochs // 4 * 3)
else:
if args.start_swa > args.max_num_epochs:
if args.start_swa >= args.max_num_epochs:
logging.warning(
f"Start Stage Two must be less than max_num_epochs, got {args.start_swa} > {args.max_num_epochs}"
)
swas[-1] = False
if args.loss == "forces_only":
raise ValueError("Can not select Stage Two with forces only loss.")
if args.loss == "virials":
Expand Down
Loading
Loading