Skip to content

Commit

Permalink
Fix remaining elasticity transformation issues
Browse files (browse the repository at this point in the history)
  • Loading branch information
munrojm committed Oct 18, 2023
1 parent cbcda4e commit b2cedcf
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 25 deletions.
44 changes: 33 additions & 11 deletions emmet-builders/emmet/builders/materials/elasticity.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,9 @@ def get_items(

self.ensure_index()

cursor = self.materials.query(criteria=self.query, properties=["material_id", "calc_types", "task_ids"])
cursor = self.materials.query(
criteria=self.query, properties=["material_id", "calc_types", "task_ids"]
)

# query for tasks
# query = self.query.copy()
Expand Down Expand Up @@ -126,7 +128,9 @@ def get_items(

yield material_id, calc_types, tasks

def process_item(self, item: Tuple[MPID, Dict[str, str], List[Dict]]) -> Union[Dict, None]:
def process_item(
self, item: Tuple[MPID, Dict[str, str], List[Dict]]
) -> Union[Dict, None]:
"""
Process all tasks belonging to the same material into an elasticity doc.
Expand Down Expand Up @@ -165,17 +169,23 @@ def process_item(self, item: Tuple[MPID, Dict[str, str], List[Dict]]) -> Union[D

# select one task for each set of optimization tasks with the same lattice
opt_grouped_tmp = group_by_parent_lattice(opt_tasks, mode="opt")
opt_grouped = [(lattice, filter_opt_tasks_by_time(tasks, self.logger)) for lattice, tasks in opt_grouped_tmp]
opt_grouped = [
(lattice, filter_opt_tasks_by_time(tasks, self.logger))
for lattice, tasks in opt_grouped_tmp
]

# for deformed tasks with the same lattice, select one if there are multiple
# tasks with the same deformation
deform_grouped = group_by_parent_lattice(deform_tasks, mode="deform")
deform_grouped = [
(lattice, filter_deform_tasks_by_time(tasks, logger=self.logger)) for lattice, tasks in deform_grouped
(lattice, filter_deform_tasks_by_time(tasks, logger=self.logger))
for lattice, tasks in deform_grouped
]

# select opt and deform tasks for fitting
final_opt, final_deform = select_final_opt_deform_tasks(opt_grouped, deform_grouped, self.logger)
final_opt, final_deform = select_final_opt_deform_tasks(
opt_grouped, deform_grouped, self.logger
)
if final_opt is None or final_deform is None:
return None

Expand All @@ -185,7 +195,9 @@ def process_item(self, item: Tuple[MPID, Dict[str, str], List[Dict]]) -> Union[D
deform_task_ids = []
deform_dir_names = []
for doc in final_deform:
deforms.append(Deformation(doc["transformations"]["history"][0]["deformation"]))
deforms.append(
Deformation(doc["transformations"]["history"][0]["deformation"])
)
# 0.1 to convert to GPa from kBar, and the minus sign to flip the stress
# direction from compressive as positive (in vasp) to tensile as positive
stresses.append(-0.1 * Stress(doc["output"]["stress"]))
Expand Down Expand Up @@ -237,7 +249,7 @@ def filter_opt_tasks(
def filter_deform_tasks(
tasks: List[Dict],
calc_types: Dict[str, str],
target_calc_type: str = CalcType.GGA_Structure_Optimization,
target_calc_type: str = CalcType.GGA_Deformation,
) -> List[Dict]:
"""
Filter deformation tasks, by
Expand All @@ -249,13 +261,18 @@ def filter_deform_tasks(
for t in tasks:
if calc_types[str(t["task_id"])] == target_calc_type:
transforms = t.get("transformations", {}).get("history", [])
if len(transforms) == 1 and transforms[0]["@class"] == "DeformStructureTransformation":
if (
len(transforms) == 1
and transforms[0]["@class"] == "DeformStructureTransformation"
):
deform_tasks.append(t)

return deform_tasks


def filter_by_incar_settings(tasks: List[Dict], incar_settings: Optional[Dict[str, Any]] = None) -> List[Dict]:
def filter_by_incar_settings(
tasks: List[Dict], incar_settings: Optional[Dict[str, Any]] = None
) -> List[Dict]:
"""
Filter tasks by incar parameters.
"""
Expand Down Expand Up @@ -312,7 +329,9 @@ def filter_opt_tasks_by_time(tasks: List[Dict], logger) -> Dict:
return _filter_tasks_by_time(tasks, "optimization", logger)


def filter_deform_tasks_by_time(tasks: List[Dict], deform_comp_tol: float = 1e-5, logger=None) -> List[Dict]:
def filter_deform_tasks_by_time(
tasks: List[Dict], deform_comp_tol: float = 1e-5, logger=None
) -> List[Dict]:
"""
For deformation tasks with the same deformation, select the latest completed one.
Expand Down Expand Up @@ -412,7 +431,10 @@ def select_final_opt_deform_tasks(
tasks.extend(pair[1])

ids = [t["task_id"] for t in tasks]
logger.warning(f"Cannot find optimization and deformation tasks that match by lattice " f"for tasks {ids}")
logger.warning(
f"Cannot find optimization and deformation tasks that match by lattice "
f"for tasks {ids}"
)

final_opt_task = None
final_deform_tasks = None
Expand Down
53 changes: 41 additions & 12 deletions emmet-builders/emmet/builders/vasp/materials.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,11 +102,17 @@ def prechunk(self, number_splits: int) -> Iterable[Dict]: # pragma: no cover
temp_query["tags"] = {"$in": self.settings.BUILD_TAGS}

self.logger.info("Finding tasks to process")
all_tasks = list(self.tasks.query(temp_query, [self.tasks.key, "formula_pretty"]))
all_tasks = list(
self.tasks.query(temp_query, [self.tasks.key, "formula_pretty"])
)

processed_tasks = set(self.materials.distinct("task_ids"))
to_process_tasks = {d[self.tasks.key] for d in all_tasks} - processed_tasks
to_process_forms = {d["formula_pretty"] for d in all_tasks if d[self.tasks.key] in to_process_tasks}
to_process_forms = {
d["formula_pretty"]
for d in all_tasks
if d[self.tasks.key] in to_process_tasks
}

N = ceil(len(to_process_forms) / number_splits)

Expand Down Expand Up @@ -146,11 +152,17 @@ def get_items(self) -> Iterator[List[Dict]]:
temp_query["tags"] = {"$in": self.settings.BUILD_TAGS}

self.logger.info("Finding tasks to process")
all_tasks = list(self.tasks.query(temp_query, [self.tasks.key, "formula_pretty"]))
all_tasks = list(
self.tasks.query(temp_query, [self.tasks.key, "formula_pretty"])
)

processed_tasks = set(self.materials.distinct("task_ids"))
to_process_tasks = {d[self.tasks.key] for d in all_tasks} - processed_tasks
to_process_forms = {d["formula_pretty"] for d in all_tasks if d[self.tasks.key] in to_process_tasks}
to_process_forms = {
d["formula_pretty"]
for d in all_tasks
if d[self.tasks.key] in to_process_tasks
}

self.logger.info(f"Found {len(to_process_tasks)} unprocessed tasks")
self.logger.info(f"Found {len(to_process_forms)} unprocessed formulas")
Expand All @@ -160,7 +172,10 @@ def get_items(self) -> Iterator[List[Dict]]:

if self.task_validation:
invalid_ids = {
doc[self.tasks.key] for doc in self.task_validation.query({"valid": False}, [self.task_validation.key])
doc[self.tasks.key]
for doc in self.task_validation.query(
{"valid": False}, [self.task_validation.key]
)
}
else:
invalid_ids = set()
Expand Down Expand Up @@ -192,7 +207,9 @@ def get_items(self) -> Iterator[List[Dict]]:
for formula in to_process_forms:
tasks_query = dict(temp_query)
tasks_query["formula_pretty"] = formula
tasks = list(self.tasks.query(criteria=tasks_query, properties=projected_fields))
tasks = list(
self.tasks.query(criteria=tasks_query, properties=projected_fields)
)
for t in tasks:
t["is_valid"] = t[self.tasks.key] not in invalid_ids

Expand Down Expand Up @@ -236,7 +253,8 @@ def process_item(self, items: List[Dict]) -> List[Dict]:
doc.warnings.append(str(e))
materials.append(doc)
self.logger.warn(
f"Failed making material for {failed_ids}." f" Inserted as deprecated Material: {doc.material_id}"
f"Failed making material for {failed_ids}."
f" Inserted as deprecated Material: {doc.material_id}"
)

self.logger.debug(f"Produced {len(materials)} materials for {formula}")
Expand Down Expand Up @@ -276,20 +294,30 @@ def filter_and_group_tasks(
filtered_tasks = []
filtered_transformations = []
for task, transformations in zip(tasks, task_transformations):
if any(allowed_type == task.task_type for allowed_type in self.settings.VASP_ALLOWED_VASP_TYPES):
if any(
allowed_type == task.task_type
for allowed_type in self.settings.VASP_ALLOWED_VASP_TYPES
):
filtered_tasks.append(task)
filtered_transformations.append(transformations)

structures = []
for idx, (task, transformations) in enumerate(zip(filtered_tasks, filtered_transformations)):
for idx, (task, transformations) in enumerate(
zip(filtered_tasks, filtered_transformations)
):
if task.task_type == TaskType.Deformation:
if transformations is None: # Do not include deformed tasks without transformation information
if (
transformations is None
): # Do not include deformed tasks without transformation information
self.logger.debug(
"Cannot find transformation for deformation task {}. Excluding task.".format(task.task_id)
"Cannot find transformation for deformation task {}. Excluding task.".format(
task.task_id
)
)
continue
else:
s = undeform_structure(task.input.structure, transformations)

else:
s = task.output.structure
s.index: int = idx # type: ignore
Expand Down Expand Up @@ -326,7 +354,8 @@ def undeform_structure(structure: Structure, transformations: Dict) -> Structure
structure = dst.apply_transformation(structure)
else:
raise RuntimeError(
"Expect transformation to be `DeformStructureTransformation`; " f"got {transformation['@class']}"
"Expect transformation to be `DeformStructureTransformation`; "
f"got {transformation['@class']}"
)

return structure
4 changes: 2 additions & 2 deletions emmet-builders/tests/test_elasticity.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ def test_elasticity_builder(tasks_store, materials_store, elasticity_store):
)
builder.run()

assert elasticity_store.count() == 3
assert elasticity_store.count({"deprecated": False}) == 3
assert elasticity_store.count() == 6
assert elasticity_store.count({"deprecated": False}) == 6


def test_serialization(tmpdir):
Expand Down
Binary file modified test_files/elasticity/SiC_tasks.json.gz
Binary file not shown.

0 comments on commit b2cedcf

Please sign in to comment.