diff --git a/emmet-builders/emmet/builders/vasp/materials.py b/emmet-builders/emmet/builders/vasp/materials.py index 2933ab46cf..4661d93375 100644 --- a/emmet-builders/emmet/builders/vasp/materials.py +++ b/emmet-builders/emmet/builders/vasp/materials.py @@ -102,17 +102,11 @@ def prechunk(self, number_splits: int) -> Iterable[Dict]: # pragma: no cover temp_query["tags"] = {"$in": self.settings.BUILD_TAGS} self.logger.info("Finding tasks to process") - all_tasks = list( - self.tasks.query(temp_query, [self.tasks.key, "formula_pretty"]) - ) + all_tasks = list(self.tasks.query(temp_query, [self.tasks.key, "formula_pretty"])) processed_tasks = set(self.materials.distinct("task_ids")) to_process_tasks = {d[self.tasks.key] for d in all_tasks} - processed_tasks - to_process_forms = { - d["formula_pretty"] - for d in all_tasks - if d[self.tasks.key] in to_process_tasks - } + to_process_forms = {d["formula_pretty"] for d in all_tasks if d[self.tasks.key] in to_process_tasks} N = ceil(len(to_process_forms) / number_splits) @@ -152,17 +146,11 @@ def get_items(self) -> Iterator[List[Dict]]: temp_query["tags"] = {"$in": self.settings.BUILD_TAGS} self.logger.info("Finding tasks to process") - all_tasks = list( - self.tasks.query(temp_query, [self.tasks.key, "formula_pretty"]) - ) + all_tasks = list(self.tasks.query(temp_query, [self.tasks.key, "formula_pretty"])) processed_tasks = set(self.materials.distinct("task_ids")) to_process_tasks = {d[self.tasks.key] for d in all_tasks} - processed_tasks - to_process_forms = { - d["formula_pretty"] - for d in all_tasks - if d[self.tasks.key] in to_process_tasks - } + to_process_forms = {d["formula_pretty"] for d in all_tasks if d[self.tasks.key] in to_process_tasks} self.logger.info(f"Found {len(to_process_tasks)} unprocessed tasks") self.logger.info(f"Found {len(to_process_forms)} unprocessed formulas") @@ -172,10 +160,7 @@ def get_items(self) -> Iterator[List[Dict]]: if self.task_validation: invalid_ids = { - doc[self.tasks.key] - for doc in self.task_validation.query( - {"valid": False}, [self.task_validation.key] - ) + doc[self.tasks.key] for doc in self.task_validation.query({"valid": False}, [self.task_validation.key]) } else: invalid_ids = set() @@ -199,7 +184,7 @@ def get_items(self) -> Iterator[List[Dict]]: "input.hubbards", "input.potcar_spec", # needed for transform deformation structure back for grouping - "transmuter", + "transformations", # misc info for materials doc "tags", ] @@ -207,9 +192,7 @@ def get_items(self) -> Iterator[List[Dict]]: for formula in to_process_forms: tasks_query = dict(temp_query) tasks_query["formula_pretty"] = formula - tasks = list( - self.tasks.query(criteria=tasks_query, properties=projected_fields) - ) + tasks = list(self.tasks.query(criteria=tasks_query, properties=projected_fields)) for t in tasks: t["is_valid"] = t[self.tasks.key] not in invalid_ids @@ -231,12 +214,12 @@ def process_item(self, items: List[Dict]) -> List[Dict]: formula = tasks[0].formula_pretty task_ids = [task.task_id for task in tasks] - # not all tasks contains transmuter - transmuters = [task.get("transmuter", None) for task in items] + # not all tasks contains transformation information + task_transformations = [task.get("transformations", None) for task in items] self.logger.debug(f"Processing {formula}: {task_ids}") - grouped_tasks = self.filter_and_group_tasks(tasks, transmuters) + grouped_tasks = self.filter_and_group_tasks(tasks, task_transformations) materials = [] for group in grouped_tasks: try: @@ -253,8 +236,7 @@ def process_item(self, items: List[Dict]) -> List[Dict]: doc.warnings.append(str(e)) materials.append(doc) self.logger.warn( - f"Failed making material for {failed_ids}." - f" Inserted as deprecated Material: {doc.material_id}" + f"Failed making material for {failed_ids}." f" Inserted as deprecated Material: {doc.material_id}" ) self.logger.debug(f"Produced {len(materials)} materials for {formula}") @@ -285,38 +267,29 @@ def update_targets(self, items: List[List[Dict]]): self.logger.info("No items to update") def filter_and_group_tasks( - self, tasks: List[TaskDocument], transmuters: List[Union[Dict, None]] + self, tasks: List[TaskDocument], task_transformations: List[Union[Dict, None]] ) -> Iterator[List[TaskDocument]]: """ Groups tasks by structure matching """ filtered_tasks = [] - filtered_transmuters = [] - for task, transmuter in zip(tasks, transmuters): - if any( - allowed_type == task.task_type - for allowed_type in self.settings.VASP_ALLOWED_VASP_TYPES - ): + filtered_transformations = [] + for task, transformations in zip(tasks, task_transformations): + if any(allowed_type == task.task_type for allowed_type in self.settings.VASP_ALLOWED_VASP_TYPES): filtered_tasks.append(task) - filtered_transmuters.append(transmuter) + filtered_transformations.append(transformations) structures = [] - for idx, (task, transmuter) in enumerate( - zip(filtered_tasks, filtered_transmuters) - ): + for idx, (task, transformations) in enumerate(zip(filtered_tasks, filtered_transformations)): if task.task_type == TaskType.Deformation: - if ( - transmuter is None - ): # Do not include deformed tasks without transmuter information + if transformations is None: # Do not include deformed tasks without transformation information self.logger.debug( - "Cannot find transmuter for deformation task {}. Excluding task.".format( - task.task_id - ) + "Cannot find transformation for deformation task {}. Excluding task.".format(task.task_id) ) continue else: - s = undeform_structure(task.input.structure, transmuter) + s = undeform_structure(task.input.structure, transformations) else: s = task.output.structure s.index: int = idx # type: ignore @@ -334,32 +307,26 @@ def filter_and_group_tasks( yield grouped_tasks -def undeform_structure(structure: Structure, transmuter: Dict) -> Structure: +def undeform_structure(structure: Structure, transformations: Dict) -> Structure: """ Get the undeformed structure by applying the transformations in a reverse order. Args: structure: deformed structure - transmuter: transformation that deforms the structure + transformation: transformation that deforms the structure Returns: undeformed structure """ - for trans, params in reversed( - list(zip(transmuter["transformations"], transmuter["transformation_params"])) - ): - # The transmuter only stores the transformation class and parameter, without - # module info and such. Therefore, there is no general way to reconstruct it, - # and has to do if else check. - if trans == "DeformStructureTransformation": - deform = Deformation(params["deformation"]) + for transformation in reversed(transformations.get("history", [])): + if transformation["@class"] == "DeformStructureTransformation": + deform = Deformation(transformation["deformation"]) dst = DeformStructureTransformation(deform.inv) structure = dst.apply_transformation(structure) else: raise RuntimeError( - "Expect transformation to be `DeformStructureTransformation`; " - f"got {trans}" + "Expect transformation to be `DeformStructureTransformation`; " f"got {transformation['@class']}" ) return structure