diff --git a/qe.ipynb b/qe.ipynb
index 0ec99c175..6f3e977e0 100644
--- a/qe.ipynb
+++ b/qe.ipynb
@@ -73,8 +73,6 @@
"metadata": {},
"outputs": [],
"source": [
- "import urllib.parse as urlparse\n",
- "\n",
"from aiidalab_qe.app.main import App\n",
"from aiidalab_widgets_base.bug_report import (\n",
" install_create_github_issue_exception_handler,\n",
@@ -86,17 +84,23 @@
" labels=(\"bug\", \"automated-report\"),\n",
")\n",
"\n",
+ "app = App(qe_auto_setup=True)\n",
+ "\n",
+ "view.main.children = [app]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import urllib.parse as urlparse\n",
+ "\n",
"url = urlparse.urlsplit(jupyter_notebook_url) # noqa F821\n",
"query = urlparse.parse_qs(url.query)\n",
- "\n",
- "app = App(qe_auto_setup=True)\n",
- "# if a pk is provided in the query string, set it as the process of the app\n",
"if \"pk\" in query:\n",
- " pk = query[\"pk\"][0]\n",
- " app.process = pk\n",
- "\n",
- "view.main.children = [app]\n",
- "view.app = app"
+ " app.process = query[\"pk\"][0]"
]
},
{
@@ -111,7 +115,7 @@
],
"metadata": {
"kernelspec": {
- "display_name": "Python 3 (ipykernel)",
+ "display_name": "base",
"language": "python",
"name": "python3"
},
@@ -126,11 +130,6 @@
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
- },
- "vscode": {
- "interpreter": {
- "hash": "d4d1e4263499bec80672ea0156c357c1ee493ec2b1c70f0acce89fc37c4a6abe"
- }
}
},
"nbformat": 4,
diff --git a/src/aiidalab_qe/app/configuration/__init__.py b/src/aiidalab_qe/app/configuration/__init__.py
index 39f5f07ae..1b8e25418 100644
--- a/src/aiidalab_qe/app/configuration/__init__.py
+++ b/src/aiidalab_qe/app/configuration/__init__.py
@@ -8,91 +8,82 @@
import ipywidgets as ipw
import traitlets as tl
-from aiida import orm
-from aiidalab_qe.app.utils import get_entry_items
+from aiidalab_qe.app.parameters import DEFAULT_PARAMETERS
+from aiidalab_qe.common.panel import SettingsPanel
from aiidalab_widgets_base import WizardAppWidgetStep
from .advanced import AdvancedSettings
+from .model import ConfigurationModel
from .workflow import WorkChainSettings
+DEFAULT: dict = DEFAULT_PARAMETERS # type: ignore
+
class ConfigureQeAppWorkChainStep(ipw.VBox, WizardAppWidgetStep):
- confirmed = tl.Bool()
previous_step_state = tl.UseEnum(WizardAppWidgetStep.State)
- input_structure = tl.Instance(orm.StructureData, allow_none=True)
-
- # output dictionary
- configuration_parameters = tl.Dict()
- def __init__(self, **kwargs):
- self.workchain_settings = WorkChainSettings()
- self.advanced_settings = AdvancedSettings()
+ def __init__(self, model: ConfigurationModel, **kwargs):
+ from aiidalab_qe.common.widgets import LoadingWidget
- ipw.dlink(
- (self.workchain_settings.workchain_protocol, "value"),
- (self.advanced_settings, "protocol"),
- )
- ipw.dlink(
- (self.workchain_settings.spin_type, "value"),
- (self.advanced_settings, "spin_type"),
+ super().__init__(
+ children=[LoadingWidget("Loading workflow configuration panel")],
+ **kwargs,
)
- ipw.dlink(
- (self.workchain_settings.electronic_type, "value"),
- (self.advanced_settings, "electronic_type"),
+
+ self._model = model
+ self._model.observe(
+ self._on_confirmation_change,
+ "confirmed",
)
- ipw.dlink(
- (self, "input_structure"),
- (self.advanced_settings, "input_structure"),
+ self._model.observe(
+ self._on_input_structure_change,
+ "input_structure",
)
- #
- ipw.dlink(
- (self, "input_structure"),
- (self.workchain_settings, "input_structure"),
+ self._model.workchain.observe(
+ self._on_protocol_change,
+ "protocol",
)
- #
+
+ self.missing_structure_message = """
+
+ Please set the input structure first.
+
+ """
+ self.structure_set_message = ipw.HTML(self.missing_structure_message)
+
+ self.workchain_settings = WorkChainSettings(config_model=model)
+ self.advanced_settings = AdvancedSettings(config_model=model)
+
self.built_in_settings = [
self.workchain_settings,
self.advanced_settings,
]
- self.tab = ipw.Tab(
- children=self.built_in_settings,
- layout=ipw.Layout(min_height="250px"),
- )
- self.tab.set_title(0, "Basic settings")
- self.tab.set_title(1, "Advanced settings")
-
- # store the property identifier and setting panel for all plugins
- # only show the setting panel when the corresponding property is selected
- # first add the built-in settings
- self.settings = {
+ self.settings: dict[str, SettingsPanel] = {
"workchain": self.workchain_settings,
"advanced": self.advanced_settings,
}
- # list of trailets to link
- # if new trailets are added to the settings, they need to be added here
- trailets_list = ["input_structure", "protocol", "electronic_type", "spin_type"]
-
- # then add plugin specific settings
- entries = get_entry_items("aiidalab_qe.properties", "setting")
- for identifier, entry_point in entries.items():
- self.settings[identifier] = entry_point(parent=self)
- self.settings[identifier].identifier = identifier
- # link basic protocol to all plugin specific protocols
- if identifier in self.workchain_settings.properties:
- self.workchain_settings.properties[identifier].run.observe(
- self._update_panel, "value"
- )
- # link the trailets if they exist in the plugin specific settings
- for trailet in trailets_list:
- if hasattr(self.settings[identifier], trailet):
- ipw.dlink(
- (self.advanced_settings, trailet),
- (self.settings[identifier], trailet),
- )
-
- self._submission_blocker_messages = ipw.HTML()
+ self.workchain_settings.fetch_setting_entries(
+ register_setting_callback=self._register_setting,
+ update_tabs_callback=self._update_tabs,
+ )
+
+ self.rendered = False
+
+ def render(self):
+ if self.rendered:
+ return
+
+ self.tab = ipw.Tab(
+ layout=ipw.Layout(min_height="250px"),
+ selected_index=None,
+ )
+ self.tab.observe(
+ self._on_tab_change,
+ "selected_index",
+ )
+ self._update_tabs()
self.confirm_button = ipw.Button(
description="Confirm",
@@ -102,84 +93,91 @@ def __init__(self, **kwargs):
disabled=True,
layout=ipw.Layout(width="auto"),
)
-
+ ipw.dlink(
+ (self, "state"),
+ (self.confirm_button, "disabled"),
+ lambda state: state != self.State.CONFIGURED,
+ )
self.confirm_button.on_click(self.confirm)
- super().__init__(
- children=[
- self.tab,
- self._submission_blocker_messages,
- self.confirm_button,
- ],
- **kwargs,
- )
+ self.children = [
+ self.structure_set_message,
+ self.tab,
+ self.confirm_button,
+ ]
+
+ self.rendered = True
+
+ def is_saved(self):
+ return self._model.confirmed
+
+ def confirm(self, _=None):
+ self._model.confirmed = True
+
+ def reset(self):
+ self._model.reset()
+ if self.rendered:
+ self.tab.selected_index = 0
+ for _, settings in self.settings.items():
+ settings.reset()
@tl.observe("previous_step_state")
- def _observe_previous_step_state(self, _change):
+ def _on_previous_step_state_change(self, _):
self._update_state()
- def get_configuration_parameters(self):
- """Get the parameters of the configuration step."""
+ def _on_tab_change(self, change):
+ if (tab_index := change["new"]) is None:
+ return
+ tab: SettingsPanel = self.tab.children[tab_index] # type: ignore
+ tab.render()
+ tab.update()
- return {s.identifier: s.get_panel_value() for s in self.tab.children}
+ def _on_input_structure_change(self, _):
+ # TODO model updates should be done more generally (plugin-approach)
+ self._model.workchain.update()
+ self._update_missing_structure_warning()
+ self.reset()
- def set_configuration_parameters(self, parameters):
- """Set the inputs in the GUI based on a set of parameters."""
+ def _on_protocol_change(self, _):
+ self._model.advanced.update()
- with self.hold_trait_notifications():
- for identifier, settings in self.settings.items():
- if parameters.get(identifier):
- settings.set_panel_value(parameters[identifier])
+ def _on_confirmation_change(self, _):
+ self._update_state()
+
+ def _register_setting(self, identifier, setting):
+ self.settings[identifier] = setting(
+ parent=self,
+ identifier=identifier,
+ config_model=self._model,
+ )
+
+ def _update_missing_structure_warning(self):
+ self.structure_set_message.value = (
+ self.missing_structure_message
+ if self._model.input_structure is None
+ else ""
+ )
+
+ def _update_tabs(self):
+ children = []
+ titles = []
+ for identifier, model in self._model.get_models():
+ if model.include:
+ setting = self.settings[identifier]
+ titles.append(setting.title)
+ children.append(setting)
+ if hasattr(self, "tab"):
+ self.tab.children = children
+ for i, title in enumerate(titles):
+ self.tab.set_title(i, title)
+ self.tab.selected_index = 0
def _update_state(self, _=None):
- if self.previous_step_state == self.State.SUCCESS:
- self.confirm_button.disabled = False
- self._submission_blocker_messages.value = ""
+ if self._model.confirmed:
+ self.state = self.State.SUCCESS
+ elif self.previous_step_state is self.State.SUCCESS:
self.state = self.State.CONFIGURED
- # update plugin specific settings
- for _, settings in self.settings.items():
- settings._update_state()
- elif self.previous_step_state == self.State.FAIL:
+ elif self.previous_step_state is self.State.FAIL:
self.state = self.State.FAIL
else:
- self.confirm_button.disabled = True
self.state = self.State.INIT
- self.reset()
-
- def confirm(self, _=None):
- self.configuration_parameters = self.get_configuration_parameters()
- self.confirm_button.disabled = False
- self.state = self.State.SUCCESS
-
- def is_saved(self):
- """Check if the current step is saved.
- That all changes are confirmed.
- """
- new_parameters = self.get_configuration_parameters()
- return new_parameters == self.configuration_parameters
-
- @tl.default("state")
- def _default_state(self):
- return self.State.INIT
-
- def reset(self):
- """Reset the widgets in all settings to their initial states."""
- with self.hold_trait_notifications():
- self.input_structure = None
- for _, settings in self.settings.items():
- settings.reset()
-
- def _update_panel(self, _=None):
- """Dynamic add/remove the panel based on the selected properties."""
- # only keep basic and advanced settings
- self.tab.children = self.built_in_settings
- # add plugin specific settings
- for identifier in self.workchain_settings.properties:
- if (
- identifier in self.settings
- and self.workchain_settings.properties[identifier].run.value
- ):
- self.tab.children += (self.settings[identifier],)
- self.tab.set_title(
- len(self.tab.children) - 1, self.settings[identifier].title
- )
diff --git a/src/aiidalab_qe/app/configuration/advanced.py b/src/aiidalab_qe/app/configuration/advanced.py
index 5df0e76cd..cba889fed 100644
--- a/src/aiidalab_qe/app/configuration/advanced.py
+++ b/src/aiidalab_qe/app/configuration/advanced.py
@@ -3,135 +3,122 @@
Authors: AiiDAlab team
"""
-import os
-
import ipywidgets as ipw
-import numpy as np
-import traitlets as tl
-from IPython.display import clear_output, display
-from aiida import orm
-from aiida_quantumespresso.calculations.functions.create_kpoints_from_distance import (
- create_kpoints_from_distance,
-)
-from aiida_quantumespresso.data.hubbard_structure import HubbardStructureData
-from aiida_quantumespresso.workflows.pw.base import PwBaseWorkChain
-from aiidalab_qe.app.parameters import DEFAULT_PARAMETERS
-from aiidalab_qe.common.panel import Panel
-from aiidalab_qe.common.widgets import HubbardWidget
-from aiidalab_qe.setup.pseudos import PseudoFamily
+from aiidalab_qe.common.panel import SettingsPanel
-from .pseudos import PseudoFamilySelector, PseudoSetter
+from .hubbard import HubbardSettings
+from .magnetization import MagnetizationSettings
+from .model import ConfigurationModel
+from .pseudos import PseudoSettings
+from .smearing import SmearingSettings
-class AdvancedSettings(Panel):
+class AdvancedSettings(SettingsPanel):
+ title = "Advanced Settings"
identifier = "advanced"
- title = ipw.HTML(
- """
-
Advanced Settings
"""
- )
- pw_adv_description = ipw.HTML(
- """Select the advanced settings for the pw.x code."""
- )
- kpoints_description = ipw.HTML(
- """
- The k-points mesh density of the SCF calculation is set by the protocol.
- The value below represents the maximum distance between the k-points in each direction of reciprocal space.
- Tick the box to override the default, smaller is more accurate and costly.
"""
- )
+ def __init__(self, config_model: ConfigurationModel, **kwargs):
+ super().__init__(
+ config_model=config_model,
+ layout={"justify_content": "space-between", **kwargs.get("layout", {})},
+ **kwargs,
+ )
- dftd3_version = {
- "dft-d3": 3,
- "dft-d3bj": 4,
- "dft-d3m": 5,
- "dft-d3mbj": 6,
- }
- # protocol interface
- protocol = tl.Unicode(allow_none=True)
- input_structure = tl.Instance(orm.StructureData, allow_none=True)
- spin_type = tl.Unicode()
- electronic_type = tl.Unicode()
+ self._config_model.advanced.observe(
+ self._on_override_change,
+ "override",
+ )
+ self._config_model.advanced.observe(
+ self._on_kpoints_distance_change,
+ "kpoints_distance",
+ )
- # output dictionary
- value = tl.Dict()
+ self.smearing = SmearingSettings(model=config_model)
+ self.magnetization = MagnetizationSettings(model=config_model)
+ self.hubbard = HubbardSettings(model=config_model)
+ self.pseudos = PseudoSettings(model=config_model)
- def __init__(self, default_protocol=None, **kwargs):
- self._default_protocol = (
- default_protocol or DEFAULT_PARAMETERS["workchain"]["protocol"]
- )
+ def render(self):
+ if self.rendered:
+ return
# clean-up workchain settings
self.clean_workdir = ipw.Checkbox(
description="",
indent=False,
- value=False,
layout=ipw.Layout(max_width="20px"),
)
- self.clean_workdir_description = ipw.HTML(
- """
- Tick to clean-up the work directory after the calculation is finished.
"""
+ ipw.link(
+ (self._model, "clean_workdir"),
+ (self.clean_workdir, "value"),
)
-
# Override setting widget
- self.override_prompt = ipw.HTML(" Override ")
self.override = ipw.Checkbox(
description="",
indent=False,
- value=False,
layout=ipw.Layout(max_width="10%"),
)
- self.override.observe(self._override_changed, "value")
-
- self.override_widget = ipw.HBox(
- [self.override_prompt, self.override],
- layout=ipw.Layout(max_width="20%"),
- )
- # Smearing setting widget
- self.smearing = SmearingSettings()
- ipw.dlink(
+ ipw.link(
+ (self._model, "override"),
(self.override, "value"),
- (self.smearing, "disabled"),
- lambda override: not override,
)
- self.smearing.observe(
- self._callback_value_set, ["degauss_value", "smearing_value"]
+ ipw.dlink(
+ (self._config_model, "input_structure"),
+ (self.override, "disabled"),
+ lambda structure: structure is None,
)
+ # Smearing setting widget
+ self.smearing.render()
+
# Kpoints setting widget
self.kpoints_distance = ipw.BoundedFloatText(
min=0.0,
step=0.05,
description="K-points distance (1/Å):",
- disabled=False,
style={"description_width": "initial"},
)
- self.mesh_grid = ipw.HTML()
- self.create_kpoints_distance_link()
- self.kpoints_distance.observe(self._callback_value_set, "value")
-
- # Hubbard setting widget
- self.hubbard_widget = HubbardWidget()
+ ipw.link(
+ (self._model, "kpoints_distance"),
+ (self.kpoints_distance, "value"),
+ )
ipw.dlink(
(self.override, "value"),
- (self.hubbard_widget.activate_hubbard, "disabled"),
- lambda override: not override,
+ (self.kpoints_distance, "disabled"),
+ lambda override: not (override and self._config_model.has_pbc),
+ )
+ ipw.dlink(
+ (self._config_model, "has_pbc"),
+ (self.kpoints_distance, "disabled"),
+ lambda periodic: not (self._model.override and periodic),
+ )
+ self.mesh_grid = ipw.HTML()
+ ipw.dlink(
+ (self._model, "mesh_grid"),
+ (self.mesh_grid, "value"),
)
+
+ # Hubbard setting widget
+ self.hubbard.render()
+
# Total change setting widget
self.total_charge = ipw.BoundedFloatText(
min=-3,
max=3,
step=0.01,
- disabled=False,
description="Total charge:",
style={"description_width": "initial"},
)
+ ipw.link(
+ (self._model, "total_charge"),
+ (self.total_charge, "value"),
+ )
ipw.dlink(
- (self.override, "value"),
+ (self._model, "override"),
(self.total_charge, "disabled"),
lambda override: not override,
)
- self.total_charge.observe(self._callback_value_set, "value")
# Van der Waals setting widget
self.van_der_waals = ipw.Dropdown(
@@ -144,70 +131,95 @@ def __init__(self, default_protocol=None, **kwargs):
("Tkatchenko-Scheffler", "ts-vdw"),
],
description="Van der Waals correction:",
- value="none",
- disabled=False,
style={"description_width": "initial"},
)
-
+ ipw.link(
+ (self._model, "van_der_waals"),
+ (self.van_der_waals, "value"),
+ )
ipw.dlink(
- (self.override, "value"),
+ (self._model, "override"),
(self.van_der_waals, "disabled"),
lambda override: not override,
)
- self.magnetization = MagnetizationSettings()
- ipw.dlink(
- (self.override, "value"),
- (self.magnetization, "disabled"),
- lambda override: not override,
- )
+ # Magnetization settings
+ self.magnetization.render()
# Convergence Threshold settings
self.scf_conv_thr = ipw.BoundedFloatText(
min=1e-15,
max=1.0,
- step=1e-10,
description="SCF conv.:",
- disabled=False,
style={"description_width": "initial"},
)
- self.scf_conv_thr.observe(self._callback_value_set, "value")
+ ipw.link(
+ (self._model, "scf_conv_thr"),
+ (self.scf_conv_thr, "value"),
+ )
ipw.dlink(
- (self.override, "value"),
+ (self._model, "scf_conv_thr_step"),
+ (self.scf_conv_thr, "step"),
+ )
+ ipw.dlink(
+ (self._model, "override"),
(self.scf_conv_thr, "disabled"),
lambda override: not override,
)
self.forc_conv_thr = ipw.BoundedFloatText(
min=1e-15,
max=1.0,
- step=0.0001,
description="Force conv.:",
- disabled=False,
style={"description_width": "initial"},
)
- self.forc_conv_thr.observe(self._callback_value_set, "value")
+ ipw.link(
+ (self._model, "forc_conv_thr"),
+ (self.forc_conv_thr, "value"),
+ )
ipw.dlink(
- (self.override, "value"),
+ (self._model, "forc_conv_thr_step"),
+ (self.forc_conv_thr, "step"),
+ )
+ ipw.dlink(
+ (self._model, "override"),
(self.forc_conv_thr, "disabled"),
lambda override: not override,
)
self.etot_conv_thr = ipw.BoundedFloatText(
min=1e-15,
max=1.0,
- step=0.00001,
description="Energy conv.:",
- disabled=False,
style={"description_width": "initial"},
)
- self.etot_conv_thr.observe(self._callback_value_set, "value")
+ ipw.link(
+ (self._model, "etot_conv_thr"),
+ (self.etot_conv_thr, "value"),
+ )
ipw.dlink(
- (self.override, "value"),
+ (self._model, "etot_conv_thr_step"),
+ (self.etot_conv_thr, "step"),
+ )
+ ipw.dlink(
+ (self._model, "override"),
(self.etot_conv_thr, "disabled"),
lambda override: not override,
)
-
- # Max electron SCF steps widget
- self._create_electron_maxstep_widgets()
+ self.electron_maxstep = ipw.BoundedIntText(
+ min=20,
+ max=1000,
+ step=1,
+ description="Max. electron steps:",
+ style={"description_width": "initial"},
+ )
+ ipw.link(
+ (self._model, "electron_maxstep"),
+ (self.electron_maxstep, "value"),
+ )
+ ipw.dlink(
+ (self._model, "override"),
+ (self.electron_maxstep, "disabled"),
+ lambda override: not override,
+ )
# Spin-Orbit calculation
self.spin_orbit = ipw.ToggleButtons(
@@ -216,727 +228,103 @@ def __init__(self, default_protocol=None, **kwargs):
("On", "soc"),
],
description="Spin-Orbit:",
- value="wo_soc",
style={"description_width": "initial"},
)
+ ipw.link(
+ (self._model, "spin_orbit"),
+ (self.spin_orbit, "value"),
+ )
ipw.dlink(
- (self.override, "value"),
+ (self._model, "override"),
(self.spin_orbit, "disabled"),
lambda override: not override,
)
- self.pseudo_family_selector = PseudoFamilySelector()
- self.pseudo_setter = PseudoSetter()
- ipw.dlink(
- (self.pseudo_family_selector, "value"),
- (self.pseudo_setter, "pseudo_family"),
- )
- self.kpoints_distance.observe(self._display_mesh, "value")
+ self.pseudos.render()
- # Link with PseudoWidget
- ipw.dlink(
- (self.spin_orbit, "value"),
- (self.pseudo_family_selector, "spin_orbit"),
- )
self.children = [
- self.title,
+ ipw.HTML("""
+
+
Advanced Settings
+
+ """),
ipw.HBox(
- [self.clean_workdir, self.clean_workdir_description],
+ children=[
+ self.clean_workdir,
+ ipw.HTML("""
+
+ Tick to clean-up the work directory after the calculation is finished.
+
+ """),
+ ],
layout=ipw.Layout(height="50px", justify_content="flex-start"),
),
ipw.HBox(
- [self.pw_adv_description, self.override_widget],
+ children=[
+ ipw.HTML("""
+ Select the advanced settings for the pw.x code.
+ """),
+ ipw.HBox(
+ children=[
+ ipw.HTML(
+ value="Override",
+ layout=ipw.Layout(margin="0 5px 0 0"),
+ ),
+ self.override,
+ ],
+ layout=ipw.Layout(max_width="20%"),
+ ),
+ ],
layout=ipw.Layout(height="50px", justify_content="space-between"),
),
- # total charge setting widget
self.total_charge,
- # van der waals setting widget
self.van_der_waals,
- # magnetization setting widget
self.magnetization,
- # convergence threshold setting widget
ipw.HTML("Convergence Thresholds:"),
ipw.HBox(
- [self.forc_conv_thr, self.etot_conv_thr, self.scf_conv_thr],
+ children=[
+ self.forc_conv_thr,
+ self.etot_conv_thr,
+ self.scf_conv_thr,
+ ],
layout=ipw.Layout(height="50px", justify_content="flex-start"),
),
- # Max electron SCF steps widget
self.electron_maxstep,
- # smearing setting widget
self.smearing,
- # Kpoints setting widget
- self.kpoints_description,
- ipw.HBox([self.kpoints_distance, self.mesh_grid]),
- self.hubbard_widget,
- # Spin-Orbit calculation
+ ipw.HTML("""
+
+ The k-points mesh density of the SCF calculation is set by the
+ protocol. The value below represents the maximum distance
+ between the k-points in each direction of reciprocal space. Tick
+ the box to override the default, smaller is more accurate and
+ costly.
+
+ """),
+ ipw.HBox(
+ children=[
+ self.kpoints_distance,
+ self.mesh_grid,
+ ]
+ ),
+ self.hubbard,
self.spin_orbit,
- self.pseudo_family_selector,
- self.pseudo_setter,
+ self.pseudos,
]
- super().__init__(
- layout=ipw.Layout(justify_content="space-between"),
- **kwargs,
- )
-
- # Default settings to trigger the callback
- self.reset()
-
- def create_kpoints_distance_link(self):
- """Create the dlink for override and kpoints_distance."""
- self.kpoints_distance_link = ipw.dlink(
- (self.override, "value"),
- (self.kpoints_distance, "disabled"),
- lambda override: not override,
- )
-
- def remove_kpoints_distance_link(self):
- """Remove the kpoints_distance_link."""
- if hasattr(self, "kpoints_distance_link"):
- self.kpoints_distance_link.unlink()
- del self.kpoints_distance_link
-
- def _create_electron_maxstep_widgets(self):
- self.electron_maxstep = ipw.BoundedIntText(
- min=20,
- max=1000,
- step=1,
- value=80,
- description="Max. electron steps:",
- style={"description_width": "initial"},
- )
- ipw.dlink(
- (self.override, "value"),
- (self.electron_maxstep, "disabled"),
- lambda override: not override,
- )
- self.electron_maxstep.observe(self._callback_value_set, "value")
-
- def set_value_and_step(self, attribute, value):
- """
- Sets the value and adjusts the step based on the order of magnitude of the value.
- This is used for the thresolds values (etot_conv_thr, scf_conv_thr, forc_conv_thr).
- Parameters:
- attribute: The attribute whose values are to be set (e.g., self.etot_conv_thr).
- value: The numerical value to set.
- """
- attribute.value = value
- if value != 0:
- order_of_magnitude = np.floor(np.log10(abs(value)))
- attribute.step = 10 ** (order_of_magnitude - 1)
- else:
- attribute.step = 0.1 # Default step if value is zero
-
- def _override_changed(self, change):
- """Callback function to set the override value"""
- if change["new"] is False:
- # When override is set to False, reset the widget
- self.reset()
-
- @tl.observe("input_structure")
- def _update_input_structure(self, change):
- if self.input_structure is not None:
- self.magnetization._update_widget(change)
- self.pseudo_setter.structure = change["new"]
- self._update_settings_from_protocol(self.protocol)
- self._display_mesh()
- self.hubbard_widget.update_widgets(change["new"])
- if isinstance(self.input_structure, HubbardStructureData):
- self.override.value = True
- if self.input_structure.pbc == (False, False, False):
- self.kpoints_distance.value = 100.0
- self.kpoints_distance.disabled = True
- if hasattr(self, "kpoints_distance_link"):
- self.remove_kpoints_distance_link()
- else:
- # self.kpoints_distance.disabled = False
- if not hasattr(self, "kpoints_distance_link"):
- self.create_kpoints_distance_link()
- else:
- self.magnetization.input_structure = None
- self.pseudo_setter.structure = None
- self.hubbard_widget.update_widgets(None)
- self.kpoints_distance.disabled = False
- if not hasattr(self, "kpoints_distance_link"):
- self.create_kpoints_distance_link()
-
- @tl.observe("electronic_type")
- def _electronic_type_changed(self, change):
- """Input electronic_type changed, update the widget values."""
- self.magnetization.electronic_type = change["new"]
-
- @tl.observe("protocol")
- def _protocol_changed(self, _):
- """Input protocol changed, update the widget values."""
- self._update_settings_from_protocol(self.protocol)
-
- def _update_settings_from_protocol(self, protocol):
- """Update the values of sub-widgets from the given protocol, this will
- trigger the callback of the sub-widget if it is exist.
- """
- self.smearing.protocol = protocol
- self.pseudo_family_selector.protocol = protocol
-
- parameters = PwBaseWorkChain.get_protocol_inputs(protocol)
-
- if self.input_structure:
- if self.input_structure.pbc == (False, False, False):
- self.kpoints_distance.value = 100.0
- self.kpoints_distance.disabled = True
- else:
- self.kpoints_distance.value = parameters["kpoints_distance"]
- else:
- self.kpoints_distance.value = parameters["kpoints_distance"]
- num_atoms = len(self.input_structure.sites) if self.input_structure else 1
-
- etot_value = num_atoms * parameters["meta_parameters"]["etot_conv_thr_per_atom"]
- self.set_value_and_step(self.etot_conv_thr, etot_value)
-
- # Set SCF conversion threshold
- scf_value = num_atoms * parameters["meta_parameters"]["conv_thr_per_atom"]
- self.set_value_and_step(self.scf_conv_thr, scf_value)
-
- # Set force conversion threshold
- forc_value = parameters["pw"]["parameters"]["CONTROL"]["forc_conv_thr"]
- self.set_value_and_step(self.forc_conv_thr, forc_value)
-
- # The pseudo_family read from the protocol (aiida-quantumespresso plugin settings)
- # we override it with the value from the pseudo_family_selector widget
- parameters["pseudo_family"] = self.pseudo_family_selector.value
-
- def _callback_value_set(self, _=None):
- """Callback function to set the parameters"""
- settings = {
- "kpoints_distance": self.kpoints_distance.value,
- "total_charge": self.total_charge.value,
- "degauss": self.smearing.degauss_value,
- "smearing": self.smearing.smearing_value,
- }
-
- self.update_settings(**settings)
-
- def update_settings(self, **kwargs):
- """Set the output dict from the given keyword arguments.
- This function will only update the traitlets but not the widget value.
-
- This function can also be used to set values directly for testing purpose.
- """
- self.value = kwargs
-
- def get_panel_value(self):
- # create the the initial_magnetic_moments as None (Default)
- # XXX: start from parameters = {} and then bundle the settings by purposes (e.g. pw, bands, etc.)
- parameters = {
- "initial_magnetic_moments": None,
- "pw": {
- "parameters": {
- "SYSTEM": {},
- "CONTROL": {},
- "ELECTRONS": {},
- }
- },
- "clean_workdir": self.clean_workdir.value,
- "pseudo_family": self.pseudo_family_selector.value,
- "kpoints_distance": self.value.get("kpoints_distance"),
- }
-
- # Set total charge
- parameters["pw"]["parameters"]["SYSTEM"]["tot_charge"] = self.total_charge.value
-
- if self.hubbard_widget.activate_hubbard.value:
- parameters["hubbard_parameters"] = self.hubbard_widget.hubbard_dict
- if self.hubbard_widget.eigenvalues_label.value:
- parameters["pw"]["parameters"]["SYSTEM"].update(
- self.hubbard_widget.eigenvalues_dict
- )
-
- # add clean_workdir to the parameters
- parameters["clean_workdir"] = self.clean_workdir.value
-
- # add the pseudo_family to the parameters
- parameters["pseudo_family"] = self.pseudo_family_selector.value
- if self.pseudo_setter.pseudos:
- parameters["pw"]["pseudos"] = self.pseudo_setter.pseudos
- parameters["pw"]["parameters"]["SYSTEM"]["ecutwfc"] = (
- self.pseudo_setter.ecutwfc
- )
- parameters["pw"]["parameters"]["SYSTEM"]["ecutrho"] = (
- self.pseudo_setter.ecutrho
- )
-
- if self.van_der_waals.value in ["none", "ts-vdw"]:
- parameters["pw"]["parameters"]["SYSTEM"]["vdw_corr"] = (
- self.van_der_waals.value
- )
- else:
- parameters["pw"]["parameters"]["SYSTEM"]["vdw_corr"] = "dft-d3"
- parameters["pw"]["parameters"]["SYSTEM"]["dftd3_version"] = (
- self.dftd3_version[self.van_der_waals.value]
- )
-
- # there are two choose, use link or parent
- if self.spin_type == "collinear":
- parameters["initial_magnetic_moments"] = (
- self.magnetization.get_magnetization()
- )
- parameters["kpoints_distance"] = self.value.get("kpoints_distance")
- if self.electronic_type == "metal":
- # smearing type setting
- parameters["pw"]["parameters"]["SYSTEM"]["smearing"] = (
- self.smearing.smearing_value
- )
- # smearing degauss setting
- parameters["pw"]["parameters"]["SYSTEM"]["degauss"] = (
- self.smearing.degauss_value
- )
-
- # Set tot_magnetization for collinear simulations.
- if self.spin_type == "collinear":
- # Conditions for metallic systems. Select the magnetization type and set the value if override is True
- if self.electronic_type == "metal" and self.override.value is True:
- self.set_metallic_magnetization(parameters)
- # Conditions for insulator systems. Default value is 0.0
- elif self.electronic_type == "insulator":
- self.set_insulator_magnetization(parameters)
-
- # convergence threshold setting
- parameters["pw"]["parameters"]["CONTROL"]["forc_conv_thr"] = (
- self.forc_conv_thr.value
- )
- parameters["pw"]["parameters"]["ELECTRONS"]["conv_thr"] = (
- self.scf_conv_thr.value
- )
- parameters["pw"]["parameters"]["CONTROL"]["etot_conv_thr"] = (
- self.etot_conv_thr.value
- )
-
- # Max electron SCF steps
- parameters["pw"]["parameters"]["ELECTRONS"]["electron_maxstep"] = (
- self.electron_maxstep.value
- )
-
- # Spin-Orbit calculation
- if self.spin_orbit.value == "soc":
- parameters["pw"]["parameters"]["SYSTEM"]["lspinorb"] = True
- parameters["pw"]["parameters"]["SYSTEM"]["noncolin"] = True
- parameters["pw"]["parameters"]["SYSTEM"]["nspin"] = 4
-
- return parameters
-
- def set_insulator_magnetization(self, parameters):
- """Set the parameters for collinear insulator calculation. Total magnetization."""
- parameters["pw"]["parameters"]["SYSTEM"]["tot_magnetization"] = (
- self.magnetization.tot_magnetization.value
- )
-
- def set_metallic_magnetization(self, parameters):
- """Set the parameters for magnetization calculation in metals"""
- magnetization_type = self.magnetization.magnetization_type.value
- if magnetization_type == "tot_magnetization":
- parameters["pw"]["parameters"]["SYSTEM"]["tot_magnetization"] = (
- self.magnetization.tot_magnetization.value
- )
- else:
- parameters["initial_magnetic_moments"] = (
- self.magnetization.get_magnetization()
- )
-
- def set_panel_value(self, parameters):
- """Set the panel value from the given parameters."""
-
- if "pseudo_family" in parameters:
- pseudo_family_string = parameters["pseudo_family"]
- self.pseudo_family_selector.load_from_pseudo_family(
- PseudoFamily.from_string(pseudo_family_string)
- )
- if "pseudos" in parameters["pw"]:
- self.pseudo_setter.set_pseudos(parameters["pw"]["pseudos"], {})
- self.pseudo_setter.ecutwfc_setter.value = parameters["pw"]["parameters"][
- "SYSTEM"
- ]["ecutwfc"]
- self.pseudo_setter.ecutrho_setter.value = parameters["pw"]["parameters"][
- "SYSTEM"
- ]["ecutrho"]
- #
- self.kpoints_distance.value = parameters.get("kpoints_distance", 0.15)
- if parameters.get("pw") is not None:
- system = parameters["pw"]["parameters"]["SYSTEM"]
- if "degauss" in system:
- self.smearing.degauss.value = system["degauss"]
- if "smearing" in system:
- self.smearing.smearing.value = system["smearing"]
- self.total_charge.value = parameters["pw"]["parameters"]["SYSTEM"].get(
- "tot_charge", 0
- )
- if "lspinorb" in system:
- self.spin_orbit.value = "soc"
- else:
- self.spin_orbit.value = "wo_soc"
- # van der waals correction
- self.van_der_waals.value = self.dftd3_version.get(
- system.get("dftd3_version"),
- parameters["pw"]["parameters"]["SYSTEM"].get("vdw_corr", "none"),
- )
-
- # convergence threshold setting
- self.forc_conv_thr.value = (
- parameters.get("pw", {})
- .get("parameters", {})
- .get("CONTROL", {})
- .get("forc_conv_thr", 0.0)
- )
- self.etot_conv_thr.value = (
- parameters.get("pw", {})
- .get("parameters", {})
- .get("CONTROL", {})
- .get("etot_conv_thr", 0.0)
- )
- self.scf_conv_thr.value = (
- parameters.get("pw", {})
- .get("parameters", {})
- .get("ELECTRONS", {})
- .get("conv_thr", 0.0)
- )
-
- # Max electron SCF steps
- self.electron_maxstep.value = (
- parameters.get("pw", {})
- .get("parameters", {})
- .get("ELECTRONS", {})
- .get("electron_maxstep", 80)
- )
-
- # Logic to set the magnetization
- if parameters.get("initial_magnetic_moments"):
- self.magnetization._set_magnetization_values(
- parameters.get("initial_magnetic_moments")
- )
-
- if "tot_magnetization" in parameters["pw"]["parameters"]["SYSTEM"]:
- self.magnetization.magnetization_type.value = "tot_magnetization"
- self.magnetization._set_tot_magnetization(
- parameters["pw"]["parameters"]["SYSTEM"]["tot_magnetization"]
- )
-
- if parameters.get("hubbard_parameters"):
- self.hubbard_widget.activate_hubbard.value = True
- self.hubbard_widget.set_hubbard_widget(
- parameters["hubbard_parameters"]["hubbard_u"]
- )
- starting_ns_eigenvalue = (
- parameters.get("pw", {})
- .get("parameters", {})
- .get("SYSTEM", {})
- .get("starting_ns_eigenvalue")
- )
-
- if starting_ns_eigenvalue is not None:
- self.hubbard_widget.eigenvalues_label.value = True
- self.hubbard_widget.set_eigenvalues_widget(starting_ns_eigenvalue)
+ self.rendered = True
def reset(self):
- """Reset the widget and the traitlets"""
-
with self.hold_trait_notifications():
- # Reset protocol dependent settings
- self._update_settings_from_protocol(self.protocol)
-
- # reset the pseudo family
- self.pseudo_family_selector.reset()
-
- # reset total charge
- self.total_charge.value = DEFAULT_PARAMETERS["advanced"]["tot_charge"]
-
- # reset the van der waals correction
- self.van_der_waals.value = DEFAULT_PARAMETERS["advanced"]["vdw_corr"]
-
- # reset the override checkbox
- self.override.value = False
+ self._model.reset()
self.smearing.reset()
- # reset the pseudo setter
- if self.input_structure is None:
- self.pseudo_setter.structure = None
- self.pseudo_setter._reset()
- else:
- self.pseudo_setter._reset()
- if self.input_structure.pbc == (False, False, False):
- self.kpoints_distance.value = 100.0
- self.kpoints_distance.disabled = True
-
- # reset the magnetization
+ self.hubbard.reset()
self.magnetization.reset()
- # reset the hubbard widget
- self.hubbard_widget.reset()
- # reset mesh grid
- if self.input_structure is None:
- self.mesh_grid.value = " "
-
- def _display_mesh(self, _=None):
- if self.input_structure is None:
- return
- if self.kpoints_distance.value > 0:
- # To avoid creating an aiida node every time we change the kpoints_distance,
- # we use the function itself instead of the decorated calcfunction.
- mesh = create_kpoints_from_distance.process_class._func(
- self.input_structure,
- orm.Float(self.kpoints_distance.value),
- orm.Bool(False),
- )
- self.mesh_grid.value = "Mesh " + str(mesh.get_kpoints_mesh()[0])
- else:
- self.mesh_grid.value = "Please select a number higher than 0.0"
-
-
-class MagnetizationSettings(ipw.VBox):
- """Widget to set the type of magnetization used in the calculation:
- 1) Tot_magnetization: Total majority spin charge - minority spin charge.
- 2) Starting magnetization: Starting spin polarization on atomic type 'i' in a spin polarized (LSDA or noncollinear/spin-orbit) calculation.
-
- For Starting magnetization you can set each kind names defined in the StructureData (StructureDtaa.get_kind_names())
- Usually these are the names of the elements in the StructureData
- (For example 'C' , 'N' , 'Fe' . However the StructureData can have defined kinds like 'Fe1' and 'Fe2')
- The widget generate a dictionary that can be used to set initial_magnetic_moments in the builder of PwBaseWorkChain
-
- Attributes:
- input_structure(StructureData): trait that containes the input_strucgure (confirmed structure from previous step)
- """
+ self.pseudos.reset()
+ self._model.update()
- input_structure = tl.Instance(orm.StructureData, allow_none=True)
- electronic_type = tl.Unicode()
- disabled = tl.Bool()
- _DEFAULT_TOT_MAGNETIZATION = 0.0
- _DEFAULT_DESCRIPTION = "Magnetization: Input structure not confirmed"
+ def _on_kpoints_distance_change(self, _=None):
+ self._model.update_kpoints_mesh()
- def __init__(self, **kwargs):
- self.input_structure = orm.StructureData()
- self.input_structure_labels = []
- self.tot_magnetization = ipw.BoundedIntText(
- min=0,
- max=100,
- step=1,
- value=self._DEFAULT_TOT_MAGNETIZATION,
- disabled=True,
- description="Total magnetization:",
- style={"description_width": "initial"},
- )
- self.magnetization_type = ipw.ToggleButtons(
- options=[
- ("Starting Magnetization", "starting_magnetization"),
- ("Tot. Magnetization", "tot_magnetization"),
- ],
- value="starting_magnetization",
- style={"description_width": "initial"},
- )
- self.description = ipw.HTML(self._DEFAULT_DESCRIPTION)
- self.kinds = self.create_kinds_widget()
- self.kinds_widget_out = ipw.Output()
- self.magnetization_out = ipw.Output()
- self.magnetization_type.observe(self._render, "value")
- super().__init__(
- children=[
- self.description,
- self.magnetization_out,
- self.kinds_widget_out,
- ],
- layout=ipw.Layout(justify_content="space-between"),
- **kwargs,
- )
-
- @tl.observe("disabled")
- def _disabled_changed(self, _):
- """Disable the widget"""
- if hasattr(self.kinds, "children") and self.kinds.children:
- for i in range(len(self.kinds.children)):
- self.kinds.children[i].disabled = self.disabled
- self.tot_magnetization.disabled = self.disabled
- self.magnetization_type.disabled = self.disabled
-
- def reset(self):
- self.disabled = True
- self.tot_magnetization.value = self._DEFAULT_TOT_MAGNETIZATION
- #
- if self.input_structure is None:
- self.description.value = self._DEFAULT_DESCRIPTION
- self.kinds = None
- else:
- self.description.value = "Magnetization"
- self.kinds = self.create_kinds_widget()
-
- def create_kinds_widget(self):
- if self.input_structure_labels:
- widgets_list = []
- for kind_label in self.input_structure_labels:
- kind_widget = ipw.BoundedFloatText(
- description=kind_label,
- min=-4,
- max=4,
- step=0.1,
- value=0.0,
- disabled=True,
- )
- widgets_list.append(kind_widget)
- kinds_widget = ipw.VBox(widgets_list)
- else:
- kinds_widget = None
-
- return kinds_widget
-
- @tl.observe("electronic_type")
- def _electronic_type_changed(self, change):
- with self.magnetization_out:
- clear_output()
- if change["new"] == "metal":
- display(self.magnetization_type)
- self._render({"new": self.magnetization_type.value})
- else:
- display(self.tot_magnetization)
- with self.kinds_widget_out:
- clear_output()
-
- def update_kinds_widget(self):
- self.input_structure_labels = self.input_structure.get_kind_names()
- self.kinds = self.create_kinds_widget()
- self.description.value = "Magnetization"
-
- def _render(self, value):
- if value["new"] == "tot_magnetization":
- with self.kinds_widget_out:
- clear_output()
- display(self.tot_magnetization)
- else:
- self.display_kinds()
-
- def display_kinds(self):
- if "PYTEST_CURRENT_TEST" not in os.environ and self.kinds:
- with self.kinds_widget_out:
- clear_output()
- display(self.kinds)
-
- def _update_widget(self, change):
- self.input_structure = change["new"]
- self.update_kinds_widget()
- self.display_kinds()
-
- def get_magnetization(self):
- """Method to generate the dictionary with the initial magnetic moments"""
- magnetization = {}
- for i in range(len(self.kinds.children)):
- magnetization[self.input_structure_labels[i]] = self.kinds.children[i].value
- return magnetization
-
- def _set_magnetization_values(self, magnetic_moments):
- """Set magnetization"""
- # self.override.value = True
- with self.hold_trait_notifications():
- for i in range(len(self.kinds.children)):
- if isinstance(magnetic_moments, dict):
- self.kinds.children[i].value = magnetic_moments.get(
- self.kinds.children[i].description, 0.0
- )
- else:
- self.kinds.children[i].value = magnetic_moments
-
- def _set_tot_magnetization(self, tot_magnetization):
- """Set the total magnetization"""
- self.tot_magnetization.value = tot_magnetization
-
-
-class SmearingSettings(ipw.VBox):
- # accept protocol as input and set the values
- protocol = tl.Unicode(allow_none=True)
-
- # The output of the widget is a dictionary with the values of smearing and degauss
- degauss_value = tl.Float()
- smearing_value = tl.Unicode()
-
- smearing_description = ipw.HTML(
- """
- The smearing type and width is set by the chosen protocol.
- Tick the box to override the default, not advised unless you've mastered smearing effects (click here for a discussion).
-
"""
- )
- disabled = tl.Bool()
-
- def __init__(self, default_protocol=None, **kwargs):
- self._default_protocol = (
- default_protocol or DEFAULT_PARAMETERS["workchain"]["protocol"]
- )
-
- self.smearing = ipw.Dropdown(
- options=["cold", "gaussian", "fermi-dirac", "methfessel-paxton"],
- description="Smearing type:",
- disabled=False,
- style={"description_width": "initial"},
- )
- self.degauss = ipw.FloatText(
- step=0.005,
- description="Smearing width (Ry):",
- disabled=False,
- style={"description_width": "initial"},
- )
- ipw.dlink(
- (self, "disabled"),
- (self.degauss, "disabled"),
- )
- ipw.dlink(
- (self, "disabled"),
- (self.smearing, "disabled"),
- )
- self.degauss.observe(self._callback_value_set, "value")
- self.smearing.observe(self._callback_value_set, "value")
-
- super().__init__(
- children=[
- self.smearing_description,
- ipw.HBox([self.smearing, self.degauss]),
- ],
- layout=ipw.Layout(justify_content="space-between"),
- **kwargs,
- )
-
- # Default settings to trigger the callback
- self.protocol = self._default_protocol
-
- @tl.default("disabled")
- def _default_disabled(self):
- return False
-
- @tl.observe("protocol")
- def _protocol_changed(self, _):
- """Input protocol changed, update the widget values."""
- self._update_settings_from_protocol(self.protocol)
-
- def _update_settings_from_protocol(self, protocol):
- """Update the widget values from the given protocol, and trigger the callback."""
- parameters = PwBaseWorkChain.get_protocol_inputs(protocol)["pw"]["parameters"][
- "SYSTEM"
- ]
-
- with self.hold_trait_notifications():
- # This changes will trigger callbacks
- self.degauss.value = parameters["degauss"]
- self.smearing.value = parameters["smearing"]
-
- def _callback_value_set(self, _=None):
- """callback function to set the smearing and degauss values"""
- settings = {
- "degauss": self.degauss.value,
- "smearing": self.smearing.value,
- }
- self.update_settings(**settings)
-
- def update_settings(self, **kwargs):
- """Set the output dict from the given keyword arguments.
- This function will only update the traitlets but not the widget value.
- """
- self.degauss_value = kwargs.get("degauss")
- self.smearing_value = kwargs.get("smearing")
-
- def reset(self):
- """Reset the widget and the traitlets"""
- self.protocol = self._default_protocol
-
- with self.hold_trait_notifications():
- self._update_settings_from_protocol(self.protocol)
- self.disabled = True
+ def _on_override_change(self, change):
+ if not change["new"]:
+ self.reset()
diff --git a/src/aiidalab_qe/app/configuration/hubbard.py b/src/aiidalab_qe/app/configuration/hubbard.py
new file mode 100644
index 000000000..4e7afe3f3
--- /dev/null
+++ b/src/aiidalab_qe/app/configuration/hubbard.py
@@ -0,0 +1,265 @@
+import ipywidgets as ipw
+
+from aiida_quantumespresso.data.hubbard_structure import HubbardStructureData
+from aiidalab_qe.common.widgets import LoadingWidget
+
+from .model import ConfigurationModel
+
+
+class HubbardSettings(ipw.VBox):
+ """Widget for setting up Hubbard parameters."""
+
+ def __init__(self, model: ConfigurationModel, **kwargs):
+ self.loading_message = LoadingWidget("Loading Hubbard settings")
+
+ super().__init__(
+ children=[self.loading_message],
+ **kwargs,
+ )
+
+ self._model = model
+ self._model.observe(
+ self._on_input_structure_change,
+ "input_structure",
+ )
+ self._model.advanced.hubbard.observe(
+ self._on_hubbard_activation,
+ "is_active",
+ )
+ self._model.advanced.hubbard.observe(
+ self._on_eigenvalues_definition,
+ "has_eigenvalues",
+ )
+
+ self.links = []
+ self.eigenvalues_widget_links = []
+
+ self.rendered = False
+ self.updated = False
+
+ def render(self):
+ if self.rendered:
+ return
+
+ self.activate_hubbard_checkbox = ipw.Checkbox(
+ description="",
+ indent=False,
+ layout=ipw.Layout(max_width="10%"),
+ )
+ ipw.link(
+ (self._model.advanced.hubbard, "is_active"),
+ (self.activate_hubbard_checkbox, "value"),
+ )
+ ipw.dlink(
+ (self._model.advanced, "override"),
+ (self.activate_hubbard_checkbox, "disabled"),
+ lambda override: not override,
+ )
+
+ self.eigenvalues_help = ipw.HTML(
+ value="For transition metals and lanthanoids, the starting eigenvalues can be defined (Magnetic calculation).",
+ layout=ipw.Layout(width="auto"),
+ )
+ self.define_eigenvalues_checkbox = ipw.Checkbox(
+ description="Define eigenvalues",
+ indent=False,
+ layout=ipw.Layout(max_width="30%"),
+ )
+ ipw.link(
+ (self._model.advanced.hubbard, "has_eigenvalues"),
+ (self.define_eigenvalues_checkbox, "value"),
+ )
+
+ self.hubbard_widget = ipw.VBox()
+ self.eigenvalues_widget = ipw.VBox()
+
+ self.container = ipw.VBox()
+
+ self.children = [
+ ipw.HBox(
+ children=[
+ ipw.HTML("Hubbard (DFT+U)"),
+ self.activate_hubbard_checkbox,
+ ]
+ ),
+ self.container,
+ ]
+
+ self.rendered = True
+
+ self.update()
+
+ def update(self):
+ if not self.updated:
+ self._update()
+ self.updated = True
+
+ def reset(self):
+ if not self._model.input_structure:
+ self._unsubscribe()
+ self._model.advanced.hubbard.reset()
+ self.updated = False
+
+ def _on_input_structure_change(self, _):
+ self._unsubscribe()
+ self._model.advanced.hubbard.update()
+ self._build_hubbard_widget(rebuild=True)
+ if isinstance(self._model.input_structure, HubbardStructureData):
+ self._model.advanced.hubbard.set_parameters_from_hubbard_structure()
+
+ def _on_hubbard_activation(self, _):
+ self._model.advanced.hubbard.update()
+ self._build_hubbard_widget()
+ self._toggle_hubbard_widget()
+
+ def _on_eigenvalues_definition(self, _):
+ self._toggle_eigenvalues_widget()
+
+ def _update(self, rebuild=False):
+ self._unsubscribe()
+ self._show_loading()
+ self._model.advanced.hubbard.update()
+ self._build_hubbard_widget(rebuild=rebuild)
+ self._toggle_hubbard_widget()
+ self._toggle_eigenvalues_widget()
+
+ def _show_loading(self):
+ if self.rendered:
+ self.hubbard_widget.children = [self.loading_message]
+
+ def _build_hubbard_widget(self, rebuild=False):
+ if not self.rendered or len(self.hubbard_widget.children) > 1 and not rebuild:
+ return
+
+ children = []
+
+ if self._model.input_structure and self._model.advanced.hubbard.is_active:
+ children.append(ipw.HTML("Define U value [eV] "))
+
+ for label in self._model.advanced.hubbard.orbital_labels:
+ float_widget = ipw.BoundedFloatText(
+ description=label,
+ min=0,
+ max=20,
+ step=0.1,
+ layout={"width": "160px"},
+ )
+ link = ipw.link(
+ (self._model.advanced.hubbard, "parameters"),
+ (float_widget, "value"),
+ [
+ lambda p, label=label: p.get(label, 0.0),
+ lambda v, label=label: {
+ **self._model.advanced.hubbard.parameters,
+ label: v,
+ },
+ ],
+ )
+ self.links.append(link)
+ children.append(float_widget)
+
+ if self._model.advanced.hubbard.needs_eigenvalues_widget:
+ children.extend(
+ [
+ self.eigenvalues_help,
+ self.define_eigenvalues_checkbox,
+ ]
+ )
+
+ self.hubbard_widget.children = children
+
+ if self._model.advanced.hubbard.needs_eigenvalues_widget:
+ self._build_eigenvalues_widget()
+ else:
+ self.eigenvalues_widget.children = []
+
+ def _build_eigenvalues_widget(self):
+ def update(index, spin, state, symbol, value):
+ """Update the eigenvalues list."""
+ eigenvalues = [*self._model.advanced.hubbard.eigenvalues]
+ eigenvalues[index][spin][state] = [state + 1, spin, symbol, value]
+ return eigenvalues
+
+ children = []
+
+ for ei, element in enumerate(self._model.advanced.hubbard.applicable_elements):
+ es = element.symbol
+ num_states = 5 if element.is_transition_metal else 7 # d or f states
+
+ label_layout = ipw.Layout(justify_content="flex-start", width="50px")
+ spin_up_row = ipw.HBox([ipw.Label("Up:", layout=label_layout)])
+ spin_down_row = ipw.HBox([ipw.Label("Down:", layout=label_layout)])
+
+ for si in range(num_states):
+ eigenvalues_up = ipw.Dropdown(
+ description=f"{si+1}",
+ options=["-1", "0", "1"],
+ layout=ipw.Layout(width="65px"),
+ style={"description_width": "initial"},
+ )
+ link = ipw.link(
+ (self._model.advanced.hubbard, "eigenvalues"),
+ (eigenvalues_up, "value"),
+ [
+ lambda evs, ei=ei, si=si: str(evs[ei][0][si][-1]),
+ lambda v, ei=ei, si=si, es=es: update(ei, 0, si, es, float(v)),
+ ],
+ )
+ self.links.append(link)
+ spin_up_row.children += (eigenvalues_up,)
+
+ eigenvalues_down = ipw.Dropdown(
+ description=f"{si+1}",
+ options=["-1", "0", "1"],
+ layout=ipw.Layout(width="65px"),
+ style={"description_width": "initial"},
+ )
+ link = ipw.link(
+ (self._model.advanced.hubbard, "eigenvalues"),
+ (eigenvalues_down, "value"),
+ [
+ lambda evs, ei=ei, si=si: str(evs[ei][1][si][-1]),
+ lambda v, ei=ei, si=si, es=es: update(ei, 1, si, es, float(v)),
+ ],
+ )
+ self.links.append(link)
+ spin_down_row.children += (eigenvalues_down,)
+
+ children.append(
+ ipw.HBox(
+ [
+ ipw.Label(element.symbol, layout=label_layout),
+ ipw.VBox(
+ children=[
+ spin_up_row,
+ spin_down_row,
+ ]
+ ),
+ ]
+ )
+ )
+
+ self.eigenvalues_widget.children = children
+
+ def _toggle_hubbard_widget(self):
+ if not self.rendered:
+ return
+ widget = [self.hubbard_widget] if self._model.advanced.hubbard.is_active else []
+ self.container.children = widget
+
+ def _toggle_eigenvalues_widget(self):
+ if not self.rendered:
+ return
+ self.hubbard_widget.children = (
+ [
+ *self.hubbard_widget.children,
+ self.eigenvalues_widget,
+ ]
+ if self._model.advanced.hubbard.has_eigenvalues
+ else [*self.hubbard_widget.children][:-1]
+ )
+
+ def _unsubscribe(self):
+ for link in self.links:
+ link.unlink()
+ self.links.clear()
diff --git a/src/aiidalab_qe/app/configuration/magnetization.py b/src/aiidalab_qe/app/configuration/magnetization.py
new file mode 100644
index 000000000..d7f55a491
--- /dev/null
+++ b/src/aiidalab_qe/app/configuration/magnetization.py
@@ -0,0 +1,212 @@
+import ipywidgets as ipw
+
+from aiidalab_qe.common.widgets import LoadingWidget
+
+from .model import ConfigurationModel
+
+
+class MagnetizationSettings(ipw.VBox):
+ """Widget to set the type of magnetization used in the calculation:
+ 1) Tot_magnetization: Total majority spin charge - minority spin charge.
+ 2) Starting magnetization: Starting spin polarization on atomic type 'i' in a spin polarized (LSDA or noncollinear/spin-orbit) calculation.
+
+ For Starting magnetization you can set each kind names defined in the StructureData (StructureData.get_kind_names())
+ Usually these are the names of the elements in the StructureData
+ (For example 'C' , 'N' , 'Fe' . However the StructureData can have defined kinds like 'Fe1' and 'Fe2')
+ The widget generate a dictionary that can be used to set initial_magnetic_moments in the builder of PwBaseWorkChain
+
+ Attributes:
+ input_structure(StructureData): trait that contains the input_structure (confirmed structure from previous step)
+ """
+
+ def __init__(self, model: ConfigurationModel, **kwargs):
+ self.loading_message = LoadingWidget("Loading magnetization settings")
+
+ super().__init__(
+ layout={"justify_content": "space-between", **kwargs.get("layout", {})},
+ children=[self.loading_message],
+ **kwargs,
+ )
+
+ self._model = model
+ self._model.observe(
+ self._on_input_structure_change,
+ "input_structure",
+ )
+ self._model.workchain.observe(
+ self._on_electronic_type_change,
+ "electronic_type",
+ )
+ self._model.workchain.observe(
+ self._on_spin_type_change,
+ "spin_type",
+ )
+ self._model.advanced.magnetization.observe(
+ self._on_magnetization_type_change,
+ "type",
+ )
+
+ self.links = []
+
+ self.rendered = False
+ self.updated = False
+
+ def render(self):
+ if self.rendered:
+ return
+
+ self.description = ipw.HTML("Magnetization:")
+
+ self.magnetization_type_toggle = ipw.ToggleButtons(
+ options=[
+ ("Starting Magnetization", "starting_magnetization"),
+ ("Tot. Magnetization", "tot_magnetization"),
+ ],
+ style={"description_width": "initial"},
+ )
+ ipw.link(
+ (self._model.advanced.magnetization, "type"),
+ (self.magnetization_type_toggle, "value"),
+ )
+ ipw.dlink(
+ (self._model.advanced, "override"),
+ (self.magnetization_type_toggle, "disabled"),
+ lambda override: not override,
+ )
+
+ self.tot_magnetization = ipw.BoundedIntText(
+ min=0,
+ max=100,
+ step=1,
+ disabled=True,
+ description="Total magnetization:",
+ style={"description_width": "initial"},
+ )
+ ipw.link(
+ (self._model.advanced.magnetization, "total"),
+ (self.tot_magnetization, "value"),
+ )
+ ipw.dlink(
+ (self._model.advanced, "override"),
+ (self.tot_magnetization, "disabled"),
+ lambda override: not override,
+ )
+
+ self.kinds_widget = ipw.VBox()
+
+ self.container = ipw.VBox(
+ children=[
+ self.tot_magnetization,
+ ]
+ )
+
+ self.children = [self.description]
+
+ self.rendered = True
+
+ self.update()
+
+ def update(self):
+ if not self.updated:
+ self._update()
+ self.updated = True
+
+ def reset(self):
+ if not self._model.input_structure:
+ self._unsubscribe()
+ self._model.advanced.magnetization.reset()
+ self.updated = False
+
+ def _on_input_structure_change(self, _):
+ self._model.advanced.magnetization.update()
+ self._update(rebuild=True)
+
+ def _on_electronic_type_change(self, _):
+ self._switch_widgets()
+
+ def _on_spin_type_change(self, _):
+ self._update()
+
+ def _on_magnetization_type_change(self, _):
+ self._toggle_widgets()
+
+ def _update(self, rebuild=False):
+ self._unsubscribe()
+ self._show_loading()
+ self._model.advanced.magnetization.update()
+ self._build_kinds_widget(rebuild=rebuild)
+ self._switch_widgets()
+ self._toggle_widgets()
+
+ def _show_loading(self):
+ if self.rendered:
+ self.kinds_widget.children = [self.loading_message]
+
+ def _build_kinds_widget(self, rebuild=False):
+ if not self.rendered or len(self.kinds_widget.children) > 0 and not rebuild:
+ return
+
+ children = []
+
+ if (
+ self._model.workchain.spin_type == "none"
+ or self._model.input_structure is None
+ ):
+ labels = []
+ else:
+ labels = self._model.input_structure.get_kind_names()
+
+ for label in labels:
+ kind_widget = ipw.BoundedFloatText(
+ description=label,
+ min=-4,
+ max=4,
+ step=0.1,
+ disabled=True,
+ )
+ link = ipw.link(
+ (self._model.advanced.magnetization, "moments"),
+ (kind_widget, "value"),
+ [
+ lambda d, label=label: d.get(label, 0.0),
+ lambda v, label=label: {
+ **self._model.advanced.magnetization.moments,
+ label: v,
+ },
+ ],
+ )
+ self.links.append(link)
+ ipw.dlink(
+ (self._model.advanced, "override"),
+ (kind_widget, "disabled"),
+ lambda override: not override,
+ )
+ children.append(kind_widget)
+
+ self.kinds_widget.children = children
+
+ def _switch_widgets(self):
+ if not self.rendered:
+ return
+ if self._model.workchain.spin_type == "none":
+ children = []
+ else:
+ children = [self.description]
+ if self._model.workchain.electronic_type == "metal":
+ children.extend([self.magnetization_type_toggle, self.container])
+ else:
+ children.append(self.tot_magnetization)
+ self.children = children
+
+ def _toggle_widgets(self):
+ if self._model.workchain.spin_type == "none" or not self.rendered:
+ return
+ if self._model.advanced.magnetization.type == "tot_magnetization":
+ self.container.children = [self.tot_magnetization]
+ else:
+ self.container.children = [self.kinds_widget]
+
+ def _unsubscribe(self):
+ for link in self.links:
+ link.unlink()
+ self.links.clear()
diff --git a/src/aiidalab_qe/app/configuration/model.py b/src/aiidalab_qe/app/configuration/model.py
new file mode 100644
index 000000000..984a4c168
--- /dev/null
+++ b/src/aiidalab_qe/app/configuration/model.py
@@ -0,0 +1,1075 @@
+from copy import deepcopy
+
+import ipywidgets as ipw
+import numpy as np
+import traitlets as tl
+from aiida_pseudo.common.units import U
+from pymatgen.core.periodic_table import Element
+
+from aiida import orm
+from aiida.common import exceptions
+from aiida.plugins import GroupFactory
+from aiida_quantumespresso.calculations.functions.create_kpoints_from_distance import (
+ create_kpoints_from_distance,
+)
+from aiida_quantumespresso.common.types import RelaxType
+from aiida_quantumespresso.data.hubbard_structure import HubbardStructureData
+from aiida_quantumespresso.workflows.pw.base import PwBaseWorkChain
+from aiidalab_qe.app.parameters import DEFAULT_PARAMETERS
+from aiidalab_qe.common.panel import SettingsModel
+from aiidalab_qe.setup.pseudos import PSEUDODOJO_VERSION, SSSP_VERSION, PseudoFamily
+
+SsspFamily = GroupFactory("pseudo.family.sssp")
+PseudoDojoFamily = GroupFactory("pseudo.family.pseudo_dojo")
+CutoffsPseudoPotentialFamily = GroupFactory("pseudo.family.cutoffs")
+
+DEFAULT: dict = DEFAULT_PARAMETERS # type: ignore
+
+
+class WorkChainModel(SettingsModel):
+ input_structure = tl.Union(
+ [
+ tl.Instance(orm.StructureData),
+ tl.Instance(HubbardStructureData),
+ ],
+ allow_none=True,
+ )
+
+ protocol = tl.Unicode(DEFAULT["workchain"]["protocol"])
+ relax_type_help = tl.Unicode()
+ relax_type_options = tl.List([""])
+ relax_type = tl.Unicode("")
+ spin_type = tl.Unicode(DEFAULT["workchain"]["spin_type"])
+ electronic_type = tl.Unicode(DEFAULT["workchain"]["electronic_type"])
+
+ _defaults = {}
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.include = True
+
+ self.relax_type_help_template = """
+
+ You have {option_count} options:
+
+ (1) Structure as is: perform a self consistent calculation using
+ the structure provided as input.
+
+ (2) Atomic positions: perform a full relaxation of the internal
+ atomic coordinates.
+ {full_relaxation_option}
+
+ """
+
+ def update(self):
+ self._update_defaults()
+ self.relax_type_help = self._get_default_relax_type_help()
+ self.relax_type_options = self._get_default_relax_type_options()
+ self.relax_type = self._get_default_relax_type()
+
+ def get_model_state(self):
+ return {
+ "protocol": self.protocol,
+ "relax_type": self.relax_type,
+ "spin_type": self.spin_type,
+ "electronic_type": self.electronic_type,
+ }
+
+ def set_model_state(self, parameters):
+ """Update the settings based on the given dict."""
+ for key in [
+ "relax_type",
+ "spin_type",
+ "electronic_type",
+ ]:
+ if key in parameters:
+ setattr(self, key, parameters[key])
+ if "protocol" in parameters:
+ self.protocol = parameters["protocol"]
+
+ def reset(self):
+ with self.hold_trait_notifications():
+ self.protocol = self.traits()["protocol"].default_value
+ self.relax_type_help = self._get_default_relax_type_help()
+ self.relax_type_options = self._get_default_relax_type_options()
+ self.relax_type = self._get_default_relax_type()
+ self.spin_type = self.traits()["spin_type"].default_value
+ self.electronic_type = self.traits()["electronic_type"].default_value
+
+ def _update_defaults(self):
+ structure = self.input_structure
+ if structure is None or any(structure.pbc):
+ relax_type_help = self.relax_type_help_template.format(
+ option_count="three",
+ full_relaxation_option=(
+ """
+
+ (3) Full geometry: perform a full relaxation of the internal atomic
+ coordinates and the cell parameters.
+ """
+ ),
+ )
+ relax_type_options = [
+ ("Structure as is", "none"),
+ ("Atomic positions", "positions"),
+ ("Full geometry", "positions_cell"),
+ ]
+ else:
+ relax_type_help = self.relax_type_help_template.format(
+ option_count="two",
+ full_relaxation_option="",
+ )
+ relax_type_options = [
+ ("Structure as is", "none"),
+ ("Atomic positions", "positions"),
+ ]
+ self._defaults = {
+ "relax_type_help": relax_type_help,
+ "relax_type_options": relax_type_options,
+ "relax_type": relax_type_options[-1][-1],
+ }
+
+ def _get_default_relax_type_help(self):
+ return self._defaults.get("relax_type_help", "")
+
+ def _get_default_relax_type_options(self):
+ return self._defaults.get("relax_type_options", [""])
+
+ def _get_default_relax_type(self):
+ return self._defaults.get("relax_type", "")
+
+
+class SmearingModel(tl.HasTraits):
+ protocol = tl.Unicode()
+ override = tl.Bool()
+
+ type = tl.Unicode()
+ degauss = tl.Float()
+
+ _defaults = {}
+
+ def update(self):
+ with self.hold_trait_notifications():
+ self._update_defaults()
+ self.type = self._get_default_type()
+ self.degauss = self._get_default_degauss()
+
+ def reset(self):
+ with self.hold_trait_notifications():
+ self.type = self._get_default_type()
+ self.degauss = self._get_default_degauss()
+
+ @tl.default("type")
+ def _get_default_type(self):
+ return self._defaults["type"]
+
+ @tl.default("degauss")
+ def _get_default_degauss(self):
+ return self._defaults["degauss"]
+
+ def _update_defaults(self):
+ parameters = (
+ PwBaseWorkChain.get_protocol_inputs(self.protocol)
+ .get("pw", {})
+ .get("parameters", {})
+ .get("SYSTEM", {})
+ )
+ self._defaults = {
+ "type": parameters["smearing"],
+ "degauss": parameters["degauss"],
+ }
+
+
+class MagnetizationModel(tl.HasTraits):
+ input_structure = tl.Union(
+ [
+ tl.Instance(orm.StructureData),
+ tl.Instance(HubbardStructureData),
+ ],
+ allow_none=True,
+ )
+ electronic_type = tl.Unicode()
+ spin_type = tl.Unicode()
+ override = tl.Bool()
+
+ type = tl.Unicode("starting_magnetization")
+ total = tl.Float(0.0)
+ moments = tl.Dict(
+ key_trait=tl.Unicode(), # element symbol
+ value_trait=tl.Float(), # magnetic moment
+ default_value={},
+ )
+
+ _default_moments = {}
+
+ def update(self):
+ with self.hold_trait_notifications():
+ self._update_defaults()
+ self.moments = self._get_default_moments()
+
+ def reset(self):
+ with self.hold_trait_notifications():
+ self.type = self.traits()["type"].default_value
+ self.total = self.traits()["total"].default_value
+ self.moments = self._get_default_moments()
+
+ def _update_defaults(self):
+ if self.spin_type == "none" or self.input_structure is None:
+ self._default_moments = {}
+ else:
+ self._default_moments = {
+ kind.symbol: 0.0 for kind in self.input_structure.kinds
+ }
+
+ def _get_default_moments(self):
+ return deepcopy(self._default_moments)
+
+
+class HubbardModel(tl.HasTraits):
+ input_structure = tl.Union(
+ [
+ tl.Instance(orm.StructureData),
+ tl.Instance(HubbardStructureData),
+ ],
+ allow_none=True,
+ )
+ override = tl.Bool()
+
+ is_active = tl.Bool(False)
+ has_eigenvalues = tl.Bool(False)
+ parameters = tl.Dict(
+ key_trait=tl.Unicode(), # element symbol
+ value_trait=tl.Float(), # U value
+ default_value={},
+ )
+ eigenvalues = tl.List(
+ trait=tl.List(), # [[[[state, spin, kind, eigenvalue] # state] # spin] # kind]
+ default_value=[],
+ )
+
+ applicable_elements = []
+ orbital_labels = []
+ _default_parameters = {}
+ _default_eigenvalues = []
+
+ def update(self):
+ with self.hold_trait_notifications():
+ self._update_defaults()
+ self.parameters = self._get_default_parameters()
+ self.eigenvalues = self._get_default_eigenvalues()
+ self.needs_eigenvalues_widget = len(self.applicable_elements) > 0
+
+ def get_active_eigenvalues(self):
+ return [
+ orbital_eigenvalue
+ for element_eigenvalues in self.eigenvalues
+ for spin_row in element_eigenvalues
+ for orbital_eigenvalue in spin_row
+ if orbital_eigenvalue[-1] != -1
+ ]
+
+ def set_parameters_from_hubbard_structure(self):
+ hubbard_parameters = self.input_structure.hubbard.dict()["parameters"]
+ sites = self.input_structure.sites
+ parameters = {
+ f"{sites[hp['atom_index']].kind_name} - {hp['atom_manifold']}": hp["value"]
+ for hp in hubbard_parameters
+ }
+ with self.hold_trait_notifications():
+ self.parameters = parameters
+ self.is_active = True
+
+ def reset(self):
+ with self.hold_trait_notifications():
+ self.is_active = False
+ self.has_eigenvalues = False
+ self.parameters = self._get_default_parameters()
+ self.eigenvalues = self._get_default_eigenvalues()
+
+ def _update_defaults(self):
+ if self.input_structure is None:
+ self.applicable_elements = []
+ self.orbital_labels = []
+ self._default_parameters = {}
+ self._default_eigenvalues = []
+ elif self.is_active:
+ self.orbital_labels = self._get_labels()
+ self._default_parameters = {label: 0.0 for label in self.orbital_labels}
+ self.applicable_elements = [
+ *filter(
+ lambda element: (
+ element.is_transition_metal
+ or element.is_lanthanoid
+ or element.is_actinoid
+ ),
+ [Element(kind.symbol) for kind in self.input_structure.kinds],
+ )
+ ]
+ self._default_eigenvalues = [
+ [
+ [
+ [state + 1, spin, element.symbol, -1] # default eigenvalue
+ for state in range(5 if element.is_transition_metal else 7)
+ ]
+ for spin in range(2) # spin up and down
+ ]
+ for element in self.applicable_elements # transition metals and lanthanoids
+ ]
+
+ def _get_default_parameters(self):
+ return deepcopy(self._default_parameters)
+
+ def _get_default_eigenvalues(self):
+ return deepcopy(self._default_eigenvalues)
+
+ def _get_labels(self):
+ kind_list = self.input_structure.get_kind_names()
+ hubbard_manifold_list = [
+ self._get_manifold(Element(kind.symbol))
+ for kind in self.input_structure.kinds
+ ]
+ return [
+ f"{kind} - {manifold}"
+ for kind, manifold in zip(kind_list, hubbard_manifold_list)
+ ]
+
+ def _get_manifold(self, element):
+ valence = [
+ orbital
+ for orbital in element.electronic_structure.split(".")
+ if "[" not in orbital
+ ]
+ orbital_shells = [shell[:2] for shell in valence]
+
+ def is_condition_met(shell):
+ return condition and condition in shell
+
+ # Conditions for determining the Hubbard manifold
+ # to be selected from the electronic structure
+ conditions = {
+ element.is_transition_metal: "d",
+ element.is_lanthanoid or element.is_actinoid: "f",
+ element.is_post_transition_metal
+ or element.is_metalloid
+ or element.is_halogen
+ or element.is_chalcogen
+ or element.symbol in ["C", "N", "P"]: "p",
+ element.is_alkaline or element.is_alkali or element.is_noble_gas: "s",
+ }
+
+ condition = next(
+ (shell for condition, shell in conditions.items() if condition), None
+ )
+
+ hubbard_manifold = next(
+ (shell for shell in orbital_shells if is_condition_met(shell)), None
+ )
+
+ return hubbard_manifold
+
+
+class PseudosModel(tl.HasTraits):
+ input_structure = tl.Union(
+ [
+ tl.Instance(orm.StructureData),
+ tl.Instance(HubbardStructureData),
+ ],
+ allow_none=True,
+ )
+ protocol = tl.Unicode()
+ spin_orbit = tl.Unicode()
+ override = tl.Bool()
+
+ dictionary = tl.Dict(
+ key_trait=tl.Unicode(), # element symbol
+ value_trait=tl.Unicode(), # pseudopotential node uuid
+ default_value={},
+ )
+ family = tl.Unicode(
+ "/".join(
+ [
+ DEFAULT["advanced"]["pseudo_family"]["library"],
+ str(DEFAULT["advanced"]["pseudo_family"]["version"]),
+ DEFAULT["advanced"]["pseudo_family"]["functional"],
+ DEFAULT["advanced"]["pseudo_family"]["accuracy"],
+ ]
+ )
+ )
+ functional = tl.Unicode(DEFAULT["advanced"]["pseudo_family"]["functional"])
+ functional_options = tl.List(
+ trait=tl.Unicode(),
+ default_value=[
+ "PBE",
+ "PBEsol",
+ ],
+ )
+ library = tl.Unicode(
+ " ".join(
+ [
+ DEFAULT["advanced"]["pseudo_family"]["library"],
+ DEFAULT["advanced"]["pseudo_family"]["accuracy"],
+ ]
+ )
+ )
+ library_options = tl.List(
+ trait=tl.Unicode(),
+ default_value=[
+ "SSSP efficiency",
+ "SSSP precision",
+ "PseudoDojo standard",
+ "PseudoDojo stringent",
+ ],
+ )
+ cutoffs = tl.List(
+ trait=tl.List(tl.Float()), # [[ecutwfc values], [ecutrho values]]
+ default_value=[[0.0], [0.0]],
+ )
+ ecutwfc = tl.Float()
+ ecutrho = tl.Float()
+ status_message = tl.Unicode()
+ family_help_message = tl.Unicode()
+
+ _default_dictionary = {}
+ _default_cutoffs = []
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ ipw.dlink(
+ (self, "cutoffs"),
+ (self, "ecutwfc"),
+ lambda cutoffs: max(cutoffs[0]),
+ )
+ ipw.dlink(
+ (self, "cutoffs"),
+ (self, "ecutrho"),
+ lambda cutoffs: max(cutoffs[1]),
+ )
+
+ self.PSEUDO_HELP_SOC = """
+
+ Spin-orbit coupling (SOC) calculations are supported exclusively with
+ PseudoDojo pseudopotentials. PseudoDojo offers these pseudopotentials
+ in two versions: standard and stringent. Here, we utilize the FR
+ (fully relativistic) type from PseudoDojo. Please ensure you choose
+ appropriate cutoff values for your calculations.
+
+ """
+
+ self.PSEUDO_HELP_WO_SOC = """
+
+ If you are unsure, select 'SSSP efficiency', which for most
+ calculations will produce sufficiently accurate results at
+ comparatively small computational costs. If your calculations require a
+ higher accuracy, select 'SSSP accuracy' or 'PseudoDojo stringent',
+ which will be computationally more expensive. SSSP is the standard
+ solid-state pseudopotentials. The PseudoDojo used here has the SR
+ relativistic type.
+
+ """
+
+ self.family_help_message = self.PSEUDO_HELP_WO_SOC
+
+ def update(self):
+ with self.hold_trait_notifications():
+ self._update_defaults()
+ self.update_family_parameters()
+
+ def update_default_pseudos(self):
+ try:
+ pseudo_family = self._get_pseudo_family_from_database()
+ pseudos = pseudo_family.get_pseudos(structure=self.input_structure)
+ except ValueError as exception:
+ self._status_message = f"""
+
+ ERROR: {exception!s}
+
+ """
+ return
+
+ self._default_dictionary = {
+ kind: pseudo.uuid for kind, pseudo in pseudos.items()
+ }
+ self.dictionary = self._get_default_dictionary()
+
+ def update_default_cutoffs(self):
+ """Update wavefunction and density cutoffs from pseudo family."""
+ try:
+ pseudo_family = self._get_pseudo_family_from_database()
+ current_unit = pseudo_family.get_cutoffs_unit()
+ cutoff_dict = pseudo_family.get_cutoffs()
+ except exceptions.NotExistent:
+ self._status_message = f"""
+
+ ERROR: required pseudo family `{self.family}` is
+ not installed. Please use `aiida-pseudo install` to install
+ it."
+
+ """
+ except ValueError as exception:
+ self._status_message = f"""
+
+ ERROR: failed to obtain recommended cutoffs for pseudos
+ `{pseudo_family}`: {exception}
+
+ """
+ else:
+ kind_names = (
+ self.input_structure.get_kind_names() if self.input_structure else []
+ )
+
+ ecutwfc_list = []
+ ecutrho_list = []
+ for kind in kind_names:
+ cutoff = cutoff_dict.get(kind, {})
+ ecutrho, ecutwfc = (
+ U.Quantity(v, current_unit).to("Ry").to_tuple()[0]
+ for v in cutoff.values()
+ )
+ ecutwfc_list.append(ecutwfc)
+ ecutrho_list.append(ecutrho)
+
+ self._default_cutoffs = [ecutwfc_list or [0.0], ecutrho_list or [0.0]]
+ self.cutoffs = self._get_default_cutoffs()
+
+ def update_library_options(self):
+ if self.spin_orbit == "soc":
+ self.library_options = [
+ "PseudoDojo standard",
+ "PseudoDojo stringent",
+ ]
+ self.family_help_message = self.PSEUDO_HELP_SOC
+ else:
+ self.library_options = [
+ "SSSP efficiency",
+ "SSSP precision",
+ "PseudoDojo standard",
+ "PseudoDojo stringent",
+ ]
+ self.family_help_message = self.PSEUDO_HELP_WO_SOC
+
+ self.update_family_parameters()
+
+ def update_family_parameters(self):
+ if self.spin_orbit == "soc":
+ if self.protocol in ["fast", "moderate"]:
+ pseudo_family_string = "PseudoDojo/0.4/PBE/FR/standard/upf"
+ else:
+ pseudo_family_string = "PseudoDojo/0.4/PBE/FR/stringent/upf"
+ else:
+ pseudo_family_string = PwBaseWorkChain.get_protocol_inputs(self.protocol)[
+ "pseudo_family"
+ ]
+
+ pseudo_family = PseudoFamily.from_string(pseudo_family_string)
+
+ with self.hold_trait_notifications():
+ self.library = f"{pseudo_family.library} {pseudo_family.accuracy}"
+ self.functional = pseudo_family.functional
+
+ def update_family(self):
+ library, accuracy = self.library.split()
+ functional = self.functional
+ # XXX (jusong.yu): a validator is needed to check the family string is
+ # consistent with the list of pseudo families defined in the setup_pseudos.py
+ if library == "PseudoDojo":
+ if self.spin_orbit == "soc":
+ pseudo_family_string = (
+ f"PseudoDojo/{PSEUDODOJO_VERSION}/{functional}/FR/{accuracy}/upf"
+ )
+ else:
+ pseudo_family_string = (
+ f"PseudoDojo/{PSEUDODOJO_VERSION}/{functional}/SR/{accuracy}/upf"
+ )
+ elif library == "SSSP":
+ pseudo_family_string = f"SSSP/{SSSP_VERSION}/{functional}/{accuracy}"
+ else:
+ raise ValueError(
+ f"Unknown pseudo family parameters: {library} | {accuracy}"
+ )
+
+ self.family = pseudo_family_string
+
+ def reset(self):
+ with self.hold_trait_notifications():
+ self.dictionary = self._get_default_dictionary()
+ self.cutoffs = self._get_default_cutoffs()
+ self.family = self.traits()["family"].default_value
+ self.library = self.traits()["library"].default_value
+ self.functional = self.traits()["functional"].default_value
+ self.family_help_message = self.PSEUDO_HELP_WO_SOC
+ self.status_message = ""
+
+ def _update_defaults(self):
+ if self.input_structure is None:
+ self._default_dictionary = {}
+ self._default_cutoffs = [[0.0], [0.0]]
+ else:
+ self.update_default_pseudos()
+ self.update_default_cutoffs()
+
+ def _get_pseudo_family_from_database(self):
+ """Get the pseudo family from the database."""
+ return (
+ orm.QueryBuilder()
+ .append(
+ (
+ PseudoDojoFamily,
+ SsspFamily,
+ CutoffsPseudoPotentialFamily,
+ ),
+ filters={"label": self.family},
+ )
+ .one()[0]
+ )
+
+ def _get_default_dictionary(self):
+ return deepcopy(self._default_dictionary)
+
+ def _get_default_cutoffs(self):
+ return deepcopy(self._default_cutoffs)
+
+
+class AdvancedModel(SettingsModel):
+ input_structure = tl.Union(
+ [
+ tl.Instance(orm.StructureData),
+ tl.Instance(HubbardStructureData),
+ ],
+ allow_none=True,
+ )
+ protocol = tl.Unicode()
+ spin_type = tl.Unicode()
+ electronic_type = tl.Unicode()
+
+ clean_workdir = tl.Bool(False)
+ override = tl.Bool(False)
+ total_charge = tl.Float(DEFAULT["advanced"]["tot_charge"])
+ van_der_waals = tl.Unicode(DEFAULT["advanced"]["vdw_corr"])
+ spin_orbit = tl.Unicode("wo_soc")
+ forc_conv_thr = tl.Float(0.0)
+ forc_conv_thr_step = tl.Float(1e-4)
+ etot_conv_thr = tl.Float(0.0)
+ etot_conv_thr_step = tl.Float(1e-5)
+ scf_conv_thr = tl.Float(0.0)
+ scf_conv_thr_step = tl.Float(1e-10)
+ electron_maxstep = tl.Int(80)
+ kpoints_distance = tl.Float(0.0)
+ mesh_grid = tl.Unicode("")
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.include = True
+
+ self.dftd3_version = {
+ "dft-d3": 3,
+ "dft-d3bj": 4,
+ "dft-d3m": 5,
+ "dft-d3mbj": 6,
+ }
+
+ self.smearing = SmearingModel()
+ ipw.dlink(
+ (self, "protocol"),
+ (self.smearing, "protocol"),
+ )
+ ipw.dlink(
+ (self, "override"),
+ (self.smearing, "override"),
+ )
+ self.smearing.observe(
+ self._unconfirm,
+ tl.All,
+ )
+
+ self.magnetization = MagnetizationModel()
+ ipw.dlink(
+ (self, "input_structure"),
+ (self.magnetization, "input_structure"),
+ )
+ ipw.dlink(
+ (self, "electronic_type"),
+ (self.magnetization, "electronic_type"),
+ )
+ ipw.dlink(
+ (self, "spin_type"),
+ (self.magnetization, "spin_type"),
+ )
+ ipw.dlink(
+ (self, "override"),
+ (self.magnetization, "override"),
+ )
+ self.magnetization.observe(
+ self._unconfirm,
+ tl.All,
+ )
+
+ self.hubbard = HubbardModel()
+ ipw.dlink(
+ (self, "input_structure"),
+ (self.hubbard, "input_structure"),
+ )
+ ipw.dlink(
+ (self, "override"),
+ (self.hubbard, "override"),
+ )
+ self.hubbard.observe(
+ self._unconfirm,
+ tl.All,
+ )
+
+ self.pseudos = PseudosModel()
+ ipw.dlink(
+ (self, "input_structure"),
+ (self.pseudos, "input_structure"),
+ )
+ ipw.dlink(
+ (self, "protocol"),
+ (self.pseudos, "protocol"),
+ )
+ ipw.dlink(
+ (self, "spin_orbit"),
+ (self.pseudos, "spin_orbit"),
+ )
+ ipw.dlink(
+ (self, "override"),
+ (self.pseudos, "override"),
+ )
+ self.pseudos.observe(
+ self._unconfirm,
+ tl.All,
+ )
+
+ self.observe(
+ self._unconfirm,
+ tl.All,
+ )
+
+ self.update()
+
+ def update(self):
+ parameters = PwBaseWorkChain.get_protocol_inputs(self.protocol)
+
+ self._update_kpoints_distance(parameters)
+ self.update_kpoints_mesh()
+
+ num_atoms = len(self.input_structure.sites) if self.input_structure else 1
+
+ etot_value = num_atoms * parameters["meta_parameters"]["etot_conv_thr_per_atom"]
+ self._set_value_and_step("etot_conv_thr", etot_value)
+
+ scf_value = num_atoms * parameters["meta_parameters"]["conv_thr_per_atom"]
+ self._set_value_and_step("scf_conv_thr", scf_value)
+
+ forc_value = parameters["pw"]["parameters"]["CONTROL"]["forc_conv_thr"]
+ self._set_value_and_step("forc_conv_thr", forc_value)
+
+ self.smearing.update()
+ self.magnetization.update()
+ self.hubbard.update()
+ self.pseudos.update()
+
+ def update_kpoints_mesh(self, _=None):
+ if self.input_structure is None:
+ self.mesh_grid = ""
+ elif self.kpoints_distance > 0:
+ # To avoid creating an aiida node every time we change the kpoints_distance,
+ # we use the function itself instead of the decorated calcfunction.
+ mesh = create_kpoints_from_distance.process_class._func(
+ self.input_structure,
+ orm.Float(self.kpoints_distance),
+ orm.Bool(False),
+ )
+ self.mesh_grid = f"Mesh {mesh.get_kpoints_mesh()[0]!s}"
+ else:
+ self.mesh_grid = "Please select a number higher than 0.0"
+
+ def get_model_state(self):
+ parameters = {
+ "initial_magnetic_moments": None,
+ "pw": {
+ "parameters": {
+ "SYSTEM": {
+ "tot_charge": self.total_charge,
+ },
+ "CONTROL": {
+ "forc_conv_thr": self.forc_conv_thr,
+ "etot_conv_thr": self.etot_conv_thr,
+ },
+ "ELECTRONS": {
+ "conv_thr": self.scf_conv_thr,
+ "electron_maxstep": self.electron_maxstep,
+ },
+ }
+ },
+ "clean_workdir": self.clean_workdir,
+ "pseudo_family": self.pseudos.family,
+ "kpoints_distance": self.kpoints_distance,
+ }
+
+ if self.hubbard.is_active:
+ parameters["hubbard_parameters"] = {"hubbard_u": self.hubbard.parameters}
+ if self.hubbard.has_eigenvalues:
+ parameters["pw"]["parameters"]["SYSTEM"].update(
+ {"starting_ns_eigenvalue": self.hubbard.get_active_eigenvalues()}
+ )
+
+ if self.pseudos.dictionary:
+ parameters["pw"]["pseudos"] = self.pseudos.dictionary
+ parameters["pw"]["parameters"]["SYSTEM"]["ecutwfc"] = self.pseudos.ecutwfc
+ parameters["pw"]["parameters"]["SYSTEM"]["ecutrho"] = self.pseudos.ecutrho
+
+ if self.van_der_waals in ["none", "ts-vdw"]:
+ parameters["pw"]["parameters"]["SYSTEM"]["vdw_corr"] = self.van_der_waals
+ else:
+ parameters["pw"]["parameters"]["SYSTEM"]["vdw_corr"] = "dft-d3"
+ parameters["pw"]["parameters"]["SYSTEM"]["dftd3_version"] = (
+ self.dftd3_version[self.van_der_waals]
+ )
+
+ # there are two choose, use link or parent
+ if self.spin_type == "collinear":
+ parameters["initial_magnetic_moments"] = self.magnetization.moments
+ if self.electronic_type == "metal":
+ # smearing type setting
+ parameters["pw"]["parameters"]["SYSTEM"]["smearing"] = self.smearing.type
+ # smearing degauss setting
+ parameters["pw"]["parameters"]["SYSTEM"]["degauss"] = self.smearing.degauss
+
+ # Set tot_magnetization for collinear simulations.
+ if self.spin_type == "collinear":
+ # Conditions for metallic systems.
+ # Select the magnetization type and set the value if override is True
+ if self.electronic_type == "metal" and self.override:
+ if self.magnetization.type == "tot_magnetization":
+ parameters["pw"]["parameters"]["SYSTEM"]["tot_magnetization"] = (
+ self.magnetization.total
+ )
+ else:
+ parameters["initial_magnetic_moments"] = self.magnetization.moments
+ # Conditions for insulator systems. Default value is 0.0
+ elif self.electronic_type == "insulator":
+ parameters["pw"]["parameters"]["SYSTEM"]["tot_magnetization"] = (
+ self.magnetization.total
+ )
+
+ # Spin-Orbit calculation
+ if self.spin_orbit == "soc":
+ parameters["pw"]["parameters"]["SYSTEM"]["lspinorb"] = True
+ parameters["pw"]["parameters"]["SYSTEM"]["noncolin"] = True
+ parameters["pw"]["parameters"]["SYSTEM"]["nspin"] = 4
+
+ return parameters
+
+ def set_model_state(self, parameters):
+ if "pseudo_family" in parameters:
+ pseudo_family = PseudoFamily.from_string(parameters["pseudo_family"])
+ library = pseudo_family.library
+ accuracy = pseudo_family.accuracy
+ self.pseudos.library = f"{library} {accuracy}"
+ self.pseudos.functional = pseudo_family.functional
+
+ if "pseudos" in parameters["pw"]:
+ self.pseudos.dict = parameters["pw"]["pseudos"]
+ self.pseudos.ecutwfc = parameters["pw"]["parameters"]["SYSTEM"]["ecutwfc"]
+ self.pseudos.ecutrho = parameters["pw"]["parameters"]["SYSTEM"]["ecutrho"]
+
+ self.kpoints_distance = parameters.get("kpoints_distance", 0.15)
+
+ if (pw_parameters := parameters.get("pw", {}).get("parameters")) is not None:
+ self._set_pw_parameters(pw_parameters)
+
+ if magnetic_moments := parameters.get("initial_magnetic_moments"):
+ if isinstance(magnetic_moments, (int, float)):
+ magnetic_moments = [magnetic_moments]
+ if isinstance(magnetic_moments, list):
+ magnetic_moments = dict(
+ zip(
+ self.input_structure.get_kind_names(),
+ magnetic_moments,
+ )
+ )
+ self.magnetization.moments = magnetic_moments
+
+ if parameters.get("hubbard_parameters"):
+ self.hubbard.is_active = True
+ self.hubbard.parameters = parameters["hubbard_parameters"]["hubbard_u"]
+ starting_ns_eigenvalue = (
+ parameters.get("pw", {})
+ .get("parameters", {})
+ .get("SYSTEM", {})
+ .get("starting_ns_eigenvalue")
+ )
+ if starting_ns_eigenvalue is not None:
+ self.hubbard.has_eigenvalues = True
+ self.hubbard.eigenvalues = starting_ns_eigenvalue
+
+ def reset(self):
+ with self.hold_trait_notifications():
+ self.total_charge = self.traits()["total_charge"].default_value
+ self.van_der_waals = self.traits()["van_der_waals"].default_value
+ self.forc_conv_thr = self.traits()["forc_conv_thr"].default_value
+ self.forc_conv_thr_step = self.traits()["forc_conv_thr_step"].default_value
+ self.etot_conv_thr = self.traits()["etot_conv_thr"].default_value
+ self.etot_conv_thr_step = self.traits()["etot_conv_thr_step"].default_value
+ self.scf_conv_thr = self.traits()["scf_conv_thr"].default_value
+ self.scf_conv_thr_step = self.traits()["scf_conv_thr_step"].default_value
+ self.electron_maxstep = self.traits()["electron_maxstep"].default_value
+ self.spin_orbit = self.traits()["spin_orbit"].default_value
+ self.kpoints_distance = self.traits()["kpoints_distance"].default_value
+ self.override = self.traits()["override"].default_value
+
+ def _update_kpoints_distance(self, parameters):
+ if self.input_structure is None or any(self.input_structure.pbc):
+ self.kpoints_distance = parameters["kpoints_distance"]
+ else:
+ self.kpoints_distance = 100.0
+
+ def _set_value_and_step(self, attribute, value):
+ setattr(self, attribute, value)
+ if value != 0:
+ order_of_magnitude = np.floor(np.log10(abs(value)))
+ setattr(self, f"{attribute}_step", 10 ** (order_of_magnitude - 1))
+ else:
+ setattr(self, f"{attribute}_step", 0.1)
+
+ def _set_pw_parameters(self, pw_parameters):
+ system_params = pw_parameters.get("SYSTEM", {})
+ control_params = pw_parameters.get("CONTROL", {})
+ electron_params = pw_parameters.get("ELECTRONS", {})
+
+ self.forc_conv_thr = control_params.get("forc_conv_thr", 0.0)
+ self.etot_conv_thr = control_params.get("etot_conv_thr", 0.0)
+ self.scf_conv_thr = electron_params.get("conv_thr", 0.0)
+ self.electron_maxstep = electron_params.get("electron_maxstep", 80)
+
+ self.total_charge = system_params.get("tot_charge", 0)
+ self.spin_orbit = "soc" if "lspinorb" in system_params else "wo_soc"
+
+ self.van_der_waals = self.dftd3_version.get(
+ system_params.get("dftd3_version"),
+ system_params.get("vdw_corr", "none"),
+ )
+
+ if "degauss" in system_params:
+ self.smearing.degauss = system_params["degauss"]
+
+ if "smearing" in system_params:
+ self.smearing.type = system_params["smearing"]
+
+ if "tot_magnetization" in system_params:
+ self.magnetization.type = "tot_magnetization"
+
+
+class ConfigurationModel(SettingsModel):
+ input_structure = tl.Union(
+ [
+ tl.Instance(orm.StructureData),
+ tl.Instance(HubbardStructureData),
+ ],
+ allow_none=True,
+ )
+ has_pbc = tl.Bool()
+ confirmed = tl.Bool(False)
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ ipw.dlink(
+ (self, "input_structure"),
+ (self, "has_pbc"),
+ lambda structure: structure is None or any(structure.pbc),
+ )
+
+ self.workchain = WorkChainModel()
+ ipw.dlink(
+ (self, "input_structure"),
+ (self.workchain, "input_structure"),
+ )
+ ipw.link(
+ (self, "confirmed"),
+ (self.workchain, "confirmed"),
+ )
+
+ self.advanced = AdvancedModel()
+ ipw.dlink(
+ (self, "input_structure"),
+ (self.advanced, "input_structure"),
+ )
+ ipw.dlink(
+ (self.workchain, "protocol"),
+ (self.advanced, "protocol"),
+ )
+ ipw.dlink(
+ (self.workchain, "spin_type"),
+ (self.advanced, "spin_type"),
+ )
+ ipw.dlink(
+ (self.workchain, "electronic_type"),
+ (self.advanced, "electronic_type"),
+ )
+ ipw.link(
+ (self, "confirmed"),
+ (self.advanced, "confirmed"),
+ )
+
+ self._models: dict[str, SettingsModel] = {
+ "workchain": self.workchain,
+ "advanced": self.advanced,
+ }
+
+ self._default_models = set(self._models.keys())
+
+ def add_model(self, identifier, model):
+ self._models[identifier] = model
+ ipw.link(
+ (self, "confirmed"),
+ (model, "confirmed"),
+ )
+
+ def get_models(self):
+ return self._models.items()
+
+ def get_model(self, identifier) -> SettingsModel:
+ if identifier in self._models:
+ return self._models[identifier]
+ raise ValueError(f"Model with identifier '{identifier}' not found.")
+
+ def get_model_state(self):
+ parameters = {
+ identifier: model.get_model_state()
+ for identifier, model in self._models.items()
+ if model.include
+ }
+ parameters["workchain"].update({"properties": self._get_properties()})
+ return parameters
+
+ def set_model_state(self, parameters):
+ with self.hold_trait_notifications():
+ properties = set(parameters.get("workchain", {}).get("properties", []))
+ for identifier, model in self._models.items():
+ if parameters.get(identifier):
+ model.set_model_state(parameters[identifier])
+ model.include = identifier in self._default_models | properties
+
+ def reset(self):
+ self.confirmed = False
+ for identifier, model in self._models.items():
+ if identifier not in self._default_models:
+ model.include = False
+
+ def _get_properties(self):
+ properties = []
+ run_bands = False
+ run_pdos = False
+ for identifier, model in self._models.items():
+ if identifier in self._default_models:
+ continue
+ if model.include:
+ properties.append(identifier)
+ if identifier in ("bands", "pdos"):
+ run_bands = True
+ if RelaxType(self.workchain.relax_type) is not RelaxType.NONE or not (
+ run_bands or run_pdos
+ ):
+ properties.append("relax")
+ return properties
diff --git a/src/aiidalab_qe/app/configuration/pseudos.py b/src/aiidalab_qe/app/configuration/pseudos.py
index 1d969779f..70349300c 100644
--- a/src/aiidalab_qe/app/configuration/pseudos.py
+++ b/src/aiidalab_qe/app/configuration/pseudos.py
@@ -1,541 +1,415 @@
from __future__ import annotations
import io
-import re
import ipywidgets as ipw
import traitlets as tl
from aiida import orm
-from aiida.common import exceptions
from aiida.plugins import DataFactory, GroupFactory
-from aiida_quantumespresso.workflows.pw.base import PwBaseWorkChain
-from aiidalab_qe.app.parameters import DEFAULT_PARAMETERS
-from aiidalab_qe.setup.pseudos import (
- PSEUDODOJO_VERSION,
- SSSP_VERSION,
- PseudoFamily,
-)
+from aiidalab_qe.common.widgets import LoadingWidget
from aiidalab_widgets_base.utils import StatusHTML
+from .model import ConfigurationModel
+
UpfData = DataFactory("pseudo.upf")
SsspFamily = GroupFactory("pseudo.family.sssp")
PseudoDojoFamily = GroupFactory("pseudo.family.pseudo_dojo")
CutoffsPseudoPotentialFamily = GroupFactory("pseudo.family.cutoffs")
-class PseudoFamilySelector(ipw.VBox):
- title = ipw.HTML(
- """
-
Accuracy and precision
"""
- )
- PSEUDO_HELP_SOC = """
- Spin-orbit coupling (SOC) calculations are supported exclusively with PseudoDojo pseudopotentials.
- PseudoDojo offers these pseudopotentials in two versions: standard and stringent.
- Here, we utilize the FR (fully relativistic) type from PseudoDojo.
- Please ensure you choose appropriate cutoff values for your calculations.
-
"""
-
- PSEUDO_HELP_WO_SOC = """
- If you are unsure, select 'SSSP efficiency', which for
- most calculations will produce sufficiently accurate results at
- comparatively small computational costs. If your calculations require a
- higher accuracy, select 'SSSP accuracy' or 'PseudoDojo stringent', which will be computationally
- more expensive. SSSP is the standard solid-state pseudopotentials.
- The PseudoDojo used here has the SR relativistic type.
"""
-
- description = ipw.HTML(
- """
- The exchange-correlation functional and pseudopotential library is set by
- the protocol configured in the "Workflow" tab. Here you can
- override the defaults if desired.
""",
- layout=ipw.Layout(max_width="60%"),
- )
-
- # XXX: the link is not correct after add pseudo dojo
- pseudo_family_prompt = ipw.HTML(
- """"""
- )
- pseudo_family_help = ipw.HTML(PSEUDO_HELP_WO_SOC)
-
- dft_functional_prompt = ipw.HTML(
- """
-
- Exchange-correlation functional
"""
- )
- dft_functional_help = ipw.HTML(
- """
- The exchange-correlation energy is calculated using this functional. We currently provide support for two
- well-established generalised gradient approximation (GGA) functionals:
- PBE and PBEsol.
"""
- )
- protocol = tl.Unicode(allow_none=True)
- disabled = tl.Bool()
- spin_orbit = tl.Unicode()
-
- # output pseudo family widget which is the string of the pseudo family (of the AiiDA group).
- value = tl.Unicode(allow_none=True)
-
- def __init__(self, **kwargs):
- # Enable manual setting of the pseudopotential family
- self.set_pseudo_family_prompt = ipw.HTML(" Override ")
- self.override = ipw.Checkbox(
- description="",
- indent=False,
- layout=ipw.Layout(max_width="10%"),
+class PseudoSettings(ipw.VBox):
+ """Widget to set the pseudopotentials for the calculation."""
+
+ def __init__(self, model: ConfigurationModel, **kwargs):
+ self.loading_message = LoadingWidget("Loading pseudopotential settings")
+
+ super().__init__(
+ children=[self.loading_message],
+ **kwargs,
)
- self.set_pseudo_family_box = ipw.HBox(
- [self.set_pseudo_family_prompt, self.override],
- layout=ipw.Layout(max_width="20%"),
+
+ self._model = model
+ self._model.observe(
+ self._on_input_structure_change,
+ "input_structure",
)
- self.show_ui = ipw.Valid(value=True)
- self.override.observe(self.set_show_ui, "value")
- self.override.observe(self.set_text_color, "value")
- self.override.observe(self.set_value, "value")
-
- # the widget for DFT functional selection
- self.dft_functional = ipw.Dropdown(
- options=["PBE", "PBEsol"],
- style={"description_width": "initial"},
+ self._model.advanced.observe(
+ self._on_spin_orbit_change,
+ "spin_orbit",
)
- self.dft_functional.observe(self.set_value, "value")
- self.library_selection = ipw.ToggleButtons(
- options=[
- "SSSP efficiency",
- "SSSP precision",
- "PseudoDojo standard",
- "PseudoDojo stringent",
- ],
- layout=ipw.Layout(max_width="80%"),
+ self._model.advanced.observe(
+ self._on_override_change,
+ "override",
)
- self.library_selection.observe(self.set_value, "value")
-
- self.dft_functional_box = ipw.VBox(
- children=[
- self.dft_functional_prompt,
- self.dft_functional,
- self.dft_functional_help,
- ],
- layout=ipw.Layout(max_width="40%"),
+ self._model.advanced.pseudos.observe(
+ self._on_family_parameters_change,
+ ["library", "functional"],
)
- self.pseudo_setup_box = ipw.VBox(
- children=[
- self.pseudo_family_prompt,
- self.library_selection,
- self.pseudo_family_help,
- ],
- layout=ipw.Layout(max_width="60%"),
- **kwargs,
+ self._model.advanced.pseudos.observe(
+ self._on_family_change,
+ "family",
)
- ipw.dlink((self.show_ui, "value"), (self.library_selection, "disabled"))
- ipw.dlink((self.show_ui, "value"), (self.dft_functional, "disabled"))
- super().__init__(
- children=[
- self.title,
- ipw.HBox(
- [self.description, self.set_pseudo_family_box],
- layout=ipw.Layout(height="50px", justify_content="space-between"),
- ),
- ipw.HBox([self.dft_functional_box, self.pseudo_setup_box]),
- ]
- )
- # after the initialization, the protocol is set to the default
- # this will trigger the callback to set the value of widgets to the default
- self._default_protocol = DEFAULT_PARAMETERS["workchain"]["protocol"]
- self.protocol = self._default_protocol
- self.override.value = False
-
- def set_value(self, _=None):
- """The callback when the selection of pseudo family or dft functional is changed.
- Also triggered when the override checkbox is changed.
- This is the only method to set the value of the widget.
- """
- library, accuracy = self.library_selection.value.split()
- functional = self.dft_functional.value
- # XXX (jusong.yu): a validator is needed to check the family string is consistent with the list of pseudo families defined in the setup_pseudos.py
- if library == "PseudoDojo":
- if self.spin_orbit == "soc":
- pseudo_family_string = (
- f"PseudoDojo/{PSEUDODOJO_VERSION}/{functional}/FR/{accuracy}/upf"
- )
- else:
- pseudo_family_string = (
- f"PseudoDojo/{PSEUDODOJO_VERSION}/{functional}/SR/{accuracy}/upf"
- )
- elif library == "SSSP":
- pseudo_family_string = f"SSSP/{SSSP_VERSION}/{functional}/{accuracy}"
- else:
- raise ValueError(
- f"Unknown pseudo family {self.override_protocol_pseudo_family.value}"
- )
+ self.links = []
- self.value = pseudo_family_string
-
- def set_show_ui(self, change):
- self.show_ui.value = not change.new
-
- def set_text_color(self, change):
- opacity = 1.0 if change.new else 0.5
-
- for html in (
- self.pseudo_family_prompt,
- self.pseudo_family_help,
- self.dft_functional_help,
- self.dft_functional_prompt,
- ):
- old_opacity = re.match(
- r"[\s\S]+opacity:([\S]+);[\S\s]+", html.value
- ).groups()[0]
- html.value = html.value.replace(
- f"opacity:{old_opacity};", f"opacity:{opacity};"
- )
+ self.rendered = False
+ self.updated = False
- def reset(self):
- """Reset the widget to the initial state by reset protocol to default."""
- self.protocol = self._default_protocol
-
- # in case the protocol is not changed, the callback is not triggered
- # so we trigger it explicitly. This will happened when the protocol
- # stay the same while xc selection is changed.
- self._update_settings_from_protocol(self.protocol)
-
- @tl.observe("spin_orbit")
- def _update_library_selection(self, _):
- """Update the library selection according to the spin orbit value."""
- if self.spin_orbit == "soc":
- self.library_selection.options = [
- "PseudoDojo standard",
- "PseudoDojo stringent",
- ]
- self.pseudo_family_help.value = self.PSEUDO_HELP_SOC
- else:
- self.library_selection.options = [
- "SSSP efficiency",
- "SSSP precision",
- "PseudoDojo standard",
- "PseudoDojo stringent",
- ]
- self.pseudo_family_help.value = self.PSEUDO_HELP_WO_SOC
+ def render(self):
+ if self.rendered:
+ return
- @tl.observe("protocol")
- def _protocol_changed(self, _):
- """Input protocol changed, update the value of widgets."""
- self._update_settings_from_protocol(self.protocol)
+ self.family_prompt = ipw.HTML()
- def _update_settings_from_protocol(self, protocol):
- """Update the widget values from the given protocol, and trigger the callback."""
- # FIXME: this rely on the aiida-quantumespresso, which is not ideal
+ self.family_help = ipw.HTML()
+ ipw.dlink(
+ (self._model.advanced.pseudos, "family_help_message"),
+ (self.family_help, "value"),
+ )
- if self.spin_orbit == "soc":
- if protocol in ["fast", "moderate"]:
- pseudo_family_string = "PseudoDojo/0.4/PBE/FR/standard/upf"
- else:
- pseudo_family_string = "PseudoDojo/0.4/PBE/FR/stringent/upf"
- else:
- pseudo_family_string = PwBaseWorkChain.get_protocol_inputs(protocol)[
- "pseudo_family"
- ]
+ self.functional_prompt = ipw.HTML("""
+
+ Exchange-correlation functional
+
+ """)
+
+ self.functional_help = ipw.HTML("""
+
+ The exchange-correlation energy is calculated using this functional. We
+ currently provide support for two well-established generalized gradient
+ approximation (GGA) functionals: PBE and PBEsol.
+
+ """)
+
+ self.functional = ipw.Dropdown(style={"description_width": "initial"})
+ ipw.dlink(
+ (self._model.advanced.pseudos, "functional_options"),
+ (self.functional, "options"),
+ )
+ ipw.link(
+ (self._model.advanced.pseudos, "functional"),
+ (self.functional, "value"),
+ )
+ ipw.dlink(
+ (self._model.advanced, "override"),
+ (self.functional, "disabled"),
+ lambda override: not override,
+ )
- pseudo_family = PseudoFamily.from_string(pseudo_family_string)
+ self.library = ipw.ToggleButtons(layout=ipw.Layout(max_width="80%"))
+ ipw.dlink(
+ (self._model.advanced.pseudos, "library_options"),
+ (self.library, "options"),
+ )
+ ipw.link(
+ (self._model.advanced.pseudos, "library"),
+ (self.library, "value"),
+ )
+ ipw.dlink(
+ (self._model.advanced, "override"),
+ (self.library, "disabled"),
+ lambda override: not override,
+ )
- self.load_from_pseudo_family(pseudo_family)
+ self.setter_widget_helper = ipw.HTML("""
+
+ The pseudopotential for each kind of atom in the structure can be
+ custom set. The default pseudopotential and cutoffs are get from
+ the pseudo family. The cutoffs used for the calculation are the
+ maximum of the default from all pseudopotentials and can be custom
+ set.
+
+ """)
- def load_from_pseudo_family(self, pseudo_family: PseudoFamily):
- """Reload the widget from the given pseudo family string."""
- with self.hold_trait_notifications():
- # will trigger the callback to set the value of widgets
- self.library_selection.value = (
- f"{pseudo_family.library} {pseudo_family.accuracy}"
- )
- self.dft_functional.value = pseudo_family.functional
-
-
-class PseudoSetter(ipw.VBox):
- structure = tl.Instance(klass=orm.StructureData, allow_none=True)
- pseudo_family = tl.Unicode(allow_none=True)
-
- # output pseudos
- pseudos = tl.Dict()
-
- # output cutoffs
- ecutwfc = tl.Float()
- ecutrho = tl.Float()
-
- _default_pseudo_setter_helper_text = """
- Input structure is not set. Please set the structure first.
-
"""
- _update_pseudo_setter_helper_text = """
- The pseudopotential for each kind of atom in the structure can be set customly.
- The default pseudopotential and cutoffs are get from the pseudo family.
- The cutoffs used for the calculation are the maximum of the default from all pseudopotentials
- and can be set customly.
-
"""
-
- _cutoff_setter_helper_text = """
- Please set the cutoffs for the calculation. The default cutoffs are get from the pseudo family.
-
"""
-
- def __init__(
- self,
- structure: orm.StructureData | None = None,
- pseudo_family: str | None = None,
- **kwargs,
- ):
- self.pseudo_setting_widgets = ipw.VBox()
- self._status_message = StatusHTML(clear_after=20)
+ self.setter_widget = ipw.VBox()
- # the initial cutoffs are set to 0
- self.ecutwfc = 0
- self.ecutrho = 0
+ self._status_message = StatusHTML(clear_after=20)
+ ipw.dlink(
+ (self._model.advanced.pseudos, "status_message"),
+ (self._status_message, "message"),
+ )
- self.pseudo_setter_helper = ipw.HTML(self._default_pseudo_setter_helper_text)
- self.cutoff_setter_helper = ipw.HTML(self._cutoff_setter_helper_text)
- self.ecutwfc_setter = ipw.FloatText(
+ self.cutoff_helper = ipw.HTML("""
+
+ Please set the cutoffs for the calculation. The default cutoffs are get
+ from the pseudo family.
+
+ """)
+ self.ecutwfc = ipw.FloatText(
description="Wavefunction cutoff (Ry)",
style={"description_width": "initial"},
)
- self.ecutrho_setter = ipw.FloatText(
+ ipw.link(
+ (self._model.advanced.pseudos, "ecutwfc"),
+ (self.ecutwfc, "value"),
+ )
+ ipw.dlink(
+ (self._model.advanced, "override"),
+ (self.ecutwfc, "disabled"),
+ lambda override: not override,
+ )
+ self.ecutrho = ipw.FloatText(
description="Charge density cutoff (Ry)",
style={"description_width": "initial"},
)
- self.ecutwfc_setter.observe(self._on_cutoff_change, names="value")
- self.ecutrho_setter.observe(self._on_cutoff_change, names="value")
-
- super().__init__(
- children=[
- self.pseudo_setter_helper,
- self.pseudo_setting_widgets,
- self.cutoff_setter_helper,
- ipw.HBox(
- children=[
- self.ecutwfc_setter,
- self.ecutrho_setter,
- ],
- ),
- self._status_message,
- ],
- **kwargs,
+ ipw.link(
+ (self._model.advanced.pseudos, "ecutrho"),
+ (self.ecutrho, "value"),
+ )
+ ipw.dlink(
+ (self._model.advanced, "override"),
+ (self.ecutrho, "disabled"),
+ lambda override: not override,
)
- self._reset()
- with self.hold_trait_notifications():
- self.structure = structure
- self.pseudo_family = pseudo_family
-
- def _on_cutoff_change(self, _=None):
- """Update the cutoffs according to the cutoff widgets"""
- self.ecutwfc = self.ecutwfc_setter.value
- self.ecutrho = self.ecutrho_setter.value
-
- def _reset_cutoff_widgets(self):
- """Reset the cutoff widgets to 0"""
- self.ecutwfc_setter.value = 0
- self.ecutrho_setter.value = 0
- def _reset_traitlets(self):
- """Reset the traitlets to the initial state"""
- self.ecutwfc = 0
- self.ecutrho = 0
- self.pseudos = {}
+ self.container = ipw.VBox()
+
+ self.children = [
+ ipw.HTML("""
+
+
Accuracy and precision
+
+ """),
+ ipw.HBox(
+ children=[
+ ipw.HTML(
+ """
+
+ The exchange-correlation functional and pseudopotential
+ library is set by the protocol configured in the
+ "Workflow" tab. Here you can override the defaults if
+ desired.
+
+ """,
+ layout=ipw.Layout(max_width="60%"),
+ ),
+ ],
+ layout=ipw.Layout(height="50px", justify_content="space-between"),
+ ),
+ ipw.HBox(
+ [
+ ipw.VBox(
+ children=[
+ self.functional_prompt,
+ self.functional,
+ self.functional_help,
+ ],
+ layout=ipw.Layout(max_width="40%"),
+ ),
+ ipw.VBox(
+ children=[
+ self.family_prompt,
+ self.library,
+ self.family_help,
+ ],
+ layout=ipw.Layout(max_width="60%"),
+ ),
+ ]
+ ),
+ self.container,
+ self.cutoff_helper,
+ ipw.HBox(
+ children=[
+ self.ecutwfc,
+ self.ecutrho,
+ ],
+ ),
+ self._status_message,
+ ]
+
+ self.rendered = True
+
+ self.update()
+
+ def update(self):
+ if not self.updated:
+ self._update()
+ self.updated = True
- def _reset(self):
- """Reset the pseudo setting widgets according to the structure
- by default the pseudos are get from the pseudo family
- """
- if self.structure is None:
- self._reset_cutoff_widgets()
- self._reset_traitlets()
- self.pseudo_setting_widgets.children = ()
- self.pseudo_setter_helper.value = self._default_pseudo_setter_helper_text
+ def reset(self):
+ if not self._model.input_structure:
+ self._unsubscribe()
+ self._model.advanced.pseudos.reset()
+ self.updated = False
+
+ def _on_input_structure_change(self, _):
+ self._update(rebuild=True)
+
+ def _on_spin_orbit_change(self, _):
+ self._model.advanced.pseudos.update_library_options()
+
+ def _on_override_change(self, _):
+ self._toggle_setter_widgets()
+
+ def _on_family_parameters_change(self, _):
+ self._model.advanced.pseudos.update_family()
+
+ def _on_family_change(self, _):
+ self._update_family_link()
+ self._model.advanced.pseudos.update_default_pseudos()
+ self._model.advanced.pseudos.update_default_cutoffs()
+
+ def _update(self, rebuild=False):
+ self._unsubscribe()
+ self._show_loading()
+ self._model.advanced.pseudos.update()
+ self._build_setter_widgets(rebuild=rebuild)
+ self._toggle_setter_widgets()
+ self._model.advanced.pseudos.update_library_options()
+ self._update_family_link()
+
+ def _update_family_link(self):
+ if not self.rendered:
return
- if self.pseudo_family is None:
- # this happened from the beginning when the widget is initialized
- # but also for the case when pseudo family is not provided which
- # won't happened for the real use but may happen for the test
- # so we still generate the pseudo setting widgets
- kinds = self.structure.get_kind_names()
-
- # Reset the traitlets, so the interface is clear setup
- self.pseudo_setting_widgets.children = ()
- self._reset_traitlets()
-
- # loop over the kinds and create the pseudo setting widget
- # (initialized with the pseudo from the family)
- for kind in kinds:
- pseudo_upload_widget = PseudoUploadWidget(kind=kind)
+ library, accuracy = self._model.advanced.pseudos.library.split()
+ if library == "SSSP":
+ pseudo_family_link = (
+ f"https://www.materialscloud.org/discover/sssp/table/{accuracy}"
+ )
+ else:
+ pseudo_family_link = "http://www.pseudo-dojo.org/"
+
+ self.family_prompt.value = f"""
+
+ """
- # keep track of the changing of pseudo setting of each kind
- pseudo_upload_widget.observe(self._update_pseudos, ["pseudo"])
- self.pseudo_setting_widgets.children += (pseudo_upload_widget,)
+ def _show_loading(self):
+ if self.rendered:
+ self.setter_widget.children = [self.loading_message]
+ def _build_setter_widgets(self, rebuild=False):
+ """Build the pseudo setter widgets."""
+ if not self.rendered or len(self.setter_widget.children) > 1 and not rebuild:
return
- try:
- pseudo_family = self._get_pseudos_family(self.pseudo_family)
- except exceptions.NotExistent as exception:
- self._status_message.message = (
- f""" ERROR: {exception!s}
"""
+ children = []
+
+ if self._model.input_structure is None:
+ kinds = []
+ else:
+ kinds = self._model.input_structure.kinds
+
+ for index, kind in enumerate(kinds):
+ symbol = kind.name
+ upload_widget = PseudoUploadWidget(kind=symbol)
+ pseudo_link = ipw.link(
+ (self._model.advanced.pseudos, "dictionary"),
+ (upload_widget, "pseudo"),
+ [
+ lambda d, symbol=symbol: orm.load_node(d.get(symbol)),
+ lambda v, symbol=symbol: {
+ **self._model.advanced.pseudos.dictionary,
+ symbol: v.uuid,
+ },
+ ],
+ )
+ cutoffs_link = ipw.dlink(
+ (self._model.advanced.pseudos, "cutoffs"),
+ (upload_widget, "cutoffs"),
+ lambda c, i=index: [c[0][i], c[1][i]] if len(c[0]) > i else [0.0, 0.0],
+ )
+ upload_widget.render()
+
+ self.links.extend(
+ [
+ pseudo_link,
+ cutoffs_link,
+ *upload_widget.links,
+ ]
)
- return
- try:
- pseudos = pseudo_family.get_pseudos(structure=self.structure)
- # get cutoffs dict of all elements
- cutoffs = self._get_cutoffs(pseudo_family)
- except ValueError as exception:
- self._status_message.message = f""" ERROR: failed to obtain recommended cutoffs for pseudos `{pseudo_family}`: {exception}
"""
- return
+ children.append(upload_widget)
- # success get family and cutoffs, set the traitlets accordingly
- # set the recommended cutoffs
- self.pseudos = {kind: pseudo.uuid for kind, pseudo in pseudos.items()}
- self.set_pseudos(self.pseudos, cutoffs)
-
- def _get_pseudos_family(self, pseudo_family: str) -> orm.Group:
- """Get the pseudo family from the database."""
- try:
- pseudo_set = (PseudoDojoFamily, SsspFamily, CutoffsPseudoPotentialFamily)
- pseudo_family = (
- orm.QueryBuilder()
- .append(pseudo_set, filters={"label": pseudo_family})
- .one()[0]
- )
- except exceptions.NotExistent as exception:
- raise exceptions.NotExistent(
- f"required pseudo family `{pseudo_family}` is not installed. Please use `aiida-pseudo install` to"
- "install it."
- ) from exception
-
- return pseudo_family
-
- def _get_cutoffs(self, pseudo_family):
- """Get the cutoffs from the pseudo family."""
- from aiida_pseudo.common.units import U
-
- try:
- cutoffs = pseudo_family.get_cutoffs()
- except ValueError as exception:
- self._status_message.message = f""" ERROR: failed to obtain recommended cutoffs for pseudos `{pseudo_family}`: {exception}
"""
- return
+ self.setter_widget.children = children
- current_unit = pseudo_family.get_cutoffs_unit()
- for element, cutoff in cutoffs.items():
- cutoffs[element] = {
- k: U.Quantity(v, current_unit).to("Ry").to_tuple()[0]
- for k, v in cutoff.items()
- }
-
- return cutoffs
-
- def _create_pseudo_widget(self, kind):
- """The sigle line of pseudo setter widget"""
- return PseudoUploadWidget(kind=kind)
-
- @tl.observe("structure")
- def _structure_change(self, _):
- self._reset()
- self._update_pseudos()
-
- @tl.observe("pseudo_family")
- def _pseudo_family_change(self, _):
- self._reset()
- self._update_pseudos()
-
- def _update_pseudos(self, _=None):
- """Update the pseudos according to the pseudo setting widgets"""
- self._reset_cutoff_widgets()
- for w in self.pseudo_setting_widgets.children:
- if w.error_message is not None:
- self._status_message.message = w.error_message
- return
-
- if w.pseudo is not None:
- self.pseudos[w.kind] = w.pseudo.uuid
- self.pseudo_setter_helper.value = self._update_pseudo_setter_helper_text
-
- with self.hold_trait_notifications():
- self.ecutwfc_setter.value = max(self.ecutwfc, w.ecutwfc)
- self.ecutrho_setter.value = max(self.ecutrho, w.ecutrho)
-
- def set_pseudos(self, pseudos, cutoffs):
- # Reset the traitlets, so the interface is clear setup
- self.pseudo_setting_widgets.children = ()
- self._reset_traitlets()
-
- # loop over the kinds and create the pseudo setting widget
- # (initialized with the pseudo from the family)
- for kind in self.structure.kinds:
- element = kind.symbol
- pseudo = orm.load_node(pseudos.get(kind.name, None))
- _cutoffs = cutoffs.get(element, None) # cutoffs for each element
- pseudo_upload_widget = PseudoUploadWidget(
- kind=kind.name, pseudo=pseudo, cutoffs=_cutoffs
- )
+ def _toggle_setter_widgets(self):
+ if not self.rendered:
+ return
+ if self._model.advanced.override:
+ self.container.children = [
+ self.setter_widget_helper,
+ self.setter_widget,
+ ]
+ else:
+ self.container.children = []
- # keep track of the changing of pseudo setting of each kind
- pseudo_upload_widget.observe(
- self._update_pseudos, ["pseudo", "ecutwfc", "ecutrho"]
- )
- self.pseudo_setting_widgets.children += (pseudo_upload_widget,)
- self._update_pseudos()
+ def _unsubscribe(self):
+ for link in self.links:
+ link.unlink()
+ self.links.clear()
+# TODO implement/improve MVC in this widget
class PseudoUploadWidget(ipw.HBox):
"""Class that allows to upload pseudopotential from user's computer."""
- pseudo = tl.Instance(klass=UpfData, allow_none=True)
- kind = tl.Unicode()
- ecutwfc = tl.Float(allow_none=True)
- ecutrho = tl.Float(allow_none=True)
+ pseudo = tl.Instance(UpfData, allow_none=True)
+ cutoffs = tl.List(tl.Float(), [])
error_message = tl.Unicode(allow_none=True)
- cutoffs_message_template = """
- The recommened ecutwfc: {ecutwfc} Ry
- for ecutrho: {ecutrho} Ry
-
"""
-
- def __init__(
- self,
- kind: str = "",
- pseudo: UpfData | None = None,
- cutoffs: dict | None = None,
- **kwargs,
- ):
+ def __init__(self, kind, **kwargs):
+ super().__init__(
+ children=[LoadingWidget("Loading pseudopotential uploader")],
+ **kwargs,
+ )
+
self.kind = kind
+
+ self.rendered = False
+
+ def render(self):
+ if self.rendered:
+ return
+
self.file_upload = ipw.FileUpload(
description="Upload",
multiple=False,
)
- self.pseudo_text = ipw.Text(description=kind)
- self.file_upload.observe(self._on_file_upload, names="value")
+ self.pseudo_text = ipw.Text(description=self.kind)
+ self.file_upload.observe(self._on_file_upload, "value")
- self._cutoff_message = ipw.HTML(
- self.cutoffs_message_template.format(ecutwfc=0, ecutrho=0)
+ cutoffs_message_template = """
+
+ Recommended ecutwfc: {ecutwfc} Ry ecutrho: {ecutrho} Ry
+
+ """
+
+ self.cutoff_message = ipw.HTML()
+
+ pseudo_link = ipw.dlink(
+ (self, "pseudo"),
+ (self.pseudo_text, "value"),
+ lambda p: p.filename if p else "",
)
- if pseudo is not None:
- self.pseudo = pseudo
- self.pseudo_text.value = pseudo.filename
+ cutoff_link = ipw.dlink(
+ (self, "cutoffs"),
+ (self.cutoff_message, "value"),
+ lambda c: cutoffs_message_template.format(
+ ecutwfc=c[0] if len(c) else "not set",
+ ecutrho=c[1] if len(c) else "not set",
+ ),
+ )
+
+ self.links = [pseudo_link, cutoff_link]
self.error_message = None
- super().__init__(
- children=[
- self.pseudo_text,
- self.file_upload,
- self._cutoff_message,
- ],
- **kwargs,
- )
- # set the widget directly to trigger the traitlets set
- if cutoffs is not None:
- self.ecutwfc = cutoffs.get("cutoff_wfc", None)
- self.ecutrho = cutoffs.get("cutoff_rho", None)
- self._cutoff_message.value = self.cutoffs_message_template.format(
- ecutwfc=self.ecutwfc or "not set", ecutrho=self.ecutrho or "not set"
- )
+
+ self.children = [
+ self.pseudo_text,
+ self.file_upload,
+ self.cutoff_message,
+ ]
+
+ self.rendered = True
def _on_file_upload(self, change=None):
"""When file upload button is pressed."""
@@ -559,5 +433,4 @@ def _on_file_upload(self, change=None):
def _reset(self):
"""Reset the widget to the initial state."""
self.pseudo = None
- self.ecutrho = None
- self.ecutwfc = None
+ self.cutoffs = []
diff --git a/src/aiidalab_qe/app/configuration/smearing.py b/src/aiidalab_qe/app/configuration/smearing.py
new file mode 100644
index 000000000..55a517e49
--- /dev/null
+++ b/src/aiidalab_qe/app/configuration/smearing.py
@@ -0,0 +1,77 @@
+import ipywidgets as ipw
+
+from .model import ConfigurationModel
+
+
+class SmearingSettings(ipw.VBox):
+ def __init__(self, model: ConfigurationModel, **kwargs):
+ from aiidalab_qe.common.widgets import LoadingWidget
+
+ super().__init__(
+ layout={"justify_content": "space-between", **kwargs.get("layout", {})},
+ children=[LoadingWidget("Loading smearing settings widget")],
+ **kwargs,
+ )
+
+ self._model = model
+
+ self.rendered = False
+
+ def render(self):
+ if self.rendered:
+ return
+
+ self.smearing = ipw.Dropdown(
+ options=["cold", "gaussian", "fermi-dirac", "methfessel-paxton"],
+ description="Smearing type:",
+ disabled=False,
+ style={"description_width": "initial"},
+ )
+ ipw.link(
+ (self._model.advanced.smearing, "type"),
+ (self.smearing, "value"),
+ )
+ ipw.dlink(
+ (self._model.advanced, "override"),
+ (self.smearing, "disabled"),
+ lambda override: not override,
+ )
+
+ self.degauss = ipw.FloatText(
+ step=0.005,
+ description="Smearing width (Ry):",
+ disabled=False,
+ style={"description_width": "initial"},
+ )
+ ipw.link(
+ (self._model.advanced.smearing, "degauss"),
+ (self.degauss, "value"),
+ )
+ ipw.dlink(
+ (self._model.advanced, "override"),
+ (self.degauss, "disabled"),
+ lambda override: not override,
+ )
+
+ self.children = [
+ ipw.HTML("""
+
+ The smearing type and width is set by the chosen protocol.
+ Tick the box to override the default, not advised unless you've
+ mastered smearing effects (click
+ here for a discussion).
+
+ """),
+ ipw.HBox(
+ children=[
+ self.smearing,
+ self.degauss,
+ ]
+ ),
+ ]
+
+ self.rendered = True
+
+ def reset(self):
+ self._model.advanced.smearing.reset()
diff --git a/src/aiidalab_qe/app/configuration/workflow.py b/src/aiidalab_qe/app/configuration/workflow.py
index cba0df521..ea73f8f21 100644
--- a/src/aiidalab_qe/app/configuration/workflow.py
+++ b/src/aiidalab_qe/app/configuration/workflow.py
@@ -3,120 +3,148 @@
Authors: AiiDAlab team
"""
+import typing as t
+
import ipywidgets as ipw
-import traitlets as tl
-from aiida import orm
-from aiida_quantumespresso.common.types import RelaxType
-from aiidalab_qe.app.parameters import DEFAULT_PARAMETERS
from aiidalab_qe.app.utils import get_entry_items
-from aiidalab_qe.common.panel import Panel
+from aiidalab_qe.common.panel import SettingsModel, SettingsPanel
-class WorkChainSettings(Panel):
+class WorkChainSettings(SettingsPanel):
+ title = "Basic Settings"
identifier = "workchain"
- structure_title = ipw.HTML(
- """
-
Structure
"""
- )
- structure_help = ipw.HTML(
- """
- You have three options:
- (1) Structure as is: perform a self consistent calculation using the structure provided as input.
- (2) Atomic positions: perform a full relaxation of the internal atomic coordinates.
- (3) Full geometry: perform a full relaxation for both the internal atomic coordinates and the cell vectors.
"""
- )
- materials_help = ipw.HTML(
- """
- Below you can indicate both if the material should be treated as an insulator
- or a metal (if in doubt, choose "Metal"),
- and if it should be studied with magnetization/spin polarization,
- switch magnetism On or Off (On is at least twice more costly).
-
"""
- )
-
- properties_title = ipw.HTML(
- """
-
Properties
"""
- )
- protocol_title = ipw.HTML(
- """
-
Protocol
"""
- )
- protocol_help = ipw.HTML(
- """
- The "moderate" protocol represents a trade-off between
- accuracy and speed. Choose the "fast" protocol for a faster calculation
- with less precision and the "precise" protocol to aim at best accuracy (at the price of longer/costlier calculations).
"""
- )
-
- input_structure = tl.Instance(orm.StructureData, allow_none=True)
-
- def __init__(self, **kwargs):
+ def fetch_setting_entries(
+ self,
+ register_setting_callback: t.Callable[[str, SettingsPanel], None],
+ update_tabs_callback: t.Callable[[], None],
+ ):
+ self.property_children = [ipw.HTML("Select which properties to calculate:")]
+
+ outlines = get_entry_items("aiidalab_qe.properties", "outline")
+ models = get_entry_items("aiidalab_qe.properties", "model")
+ settings = get_entry_items("aiidalab_qe.properties", "setting")
+ for identifier in settings:
+ model: SettingsModel = models[identifier]()
+ self._config_model.add_model(identifier, model)
+
+ outline = outlines[identifier]()
+ info = ipw.HTML()
+ ipw.link(
+ (model, "include"),
+ (outline.include, "value"),
+ )
+
+ if identifier == "bands":
+ ipw.dlink(
+ (self._config_model, "has_pbc"),
+ (outline.include, "disabled"),
+ lambda periodic: not periodic,
+ )
+
+ def toggle_plugin(change, identifier=identifier, info=info):
+ if change["new"]:
+ info.value = (
+ f"Customize {identifier} settings in the panel above if needed"
+ )
+ else:
+ info.value = ""
+ update_tabs_callback()
+
+ model.observe(
+ toggle_plugin,
+ "include",
+ )
+
+ self.property_children.append(
+ ipw.HBox(
+ children=[
+ outline,
+ info,
+ ]
+ )
+ )
+
+ register_setting_callback(identifier, settings[identifier])
+
+ def update(self):
+ if self.updated:
+ return
+ self._model.update()
+ self.updated = True
+
+ def render(self):
+ if self.rendered:
+ return
+
# RelaxType: degrees of freedom in geometry optimization
+ self.relax_type_help = ipw.HTML()
+ ipw.dlink(
+ (self._model, "relax_type_help"),
+ (self.relax_type_help, "value"),
+ )
self.relax_type = ipw.ToggleButtons(
options=[
("Structure as is", "none"),
("Atomic positions", "positions"),
("Full geometry", "positions_cell"),
],
- value="positions_cell",
+ )
+ ipw.dlink(
+ (self._model, "relax_type_options"),
+ (self.relax_type, "options"),
+ )
+ ipw.link(
+ (self._model, "relax_type"),
+ (self.relax_type, "value"),
)
# SpinType: magnetic properties of material
self.spin_type = ipw.ToggleButtons(
options=[("Off", "none"), ("On", "collinear")],
- value=DEFAULT_PARAMETERS["workchain"]["spin_type"],
style={"description_width": "initial"},
)
+ ipw.link(
+ (self._model, "spin_type"),
+ (self.spin_type, "value"),
+ )
# ElectronicType: electronic properties of material
self.electronic_type = ipw.ToggleButtons(
options=[("Metal", "metal"), ("Insulator", "insulator")],
- value=DEFAULT_PARAMETERS["workchain"]["electronic_type"],
style={"description_width": "initial"},
)
+ ipw.link(
+ (self._model, "electronic_type"),
+ (self.electronic_type, "value"),
+ )
# Work chain protocol
- self.workchain_protocol = ipw.ToggleButtons(
+ self.protocol = ipw.ToggleButtons(
options=["fast", "moderate", "precise"],
- value="moderate",
)
- self.properties = {}
- self.reminder_info = {}
- self.property_children = [
- self.properties_title,
- ipw.HTML("Select which properties to calculate:"),
- ]
- entries = get_entry_items("aiidalab_qe.properties", "outline")
- setting_entries = get_entry_items("aiidalab_qe.properties", "setting")
- for name, entry_point in entries.items():
- self.properties[name] = entry_point()
- self.reminder_info[name] = ipw.HTML()
- self.property_children.append(
- ipw.HBox([self.properties[name], self.reminder_info[name]])
- )
-
- # observer change to update the reminder text
- def update_reminder_info(change, name=name):
- if change["new"]:
- self.reminder_info[
- name
- ].value = (
- f"""Customize {name} settings in the panel above if needed."""
- )
- else:
- self.reminder_info[name].value = ""
-
- if name in setting_entries:
- self.properties[name].run.observe(update_reminder_info, "value")
+ ipw.link(
+ (self._model, "protocol"),
+ (self.protocol, "value"),
+ )
self.children = [
- self.structure_title,
- self.structure_help,
+ ipw.HTML("""
+
+
Structure
+
+ """),
+ self.relax_type_help,
self.relax_type,
- self.materials_help,
+ ipw.HTML("""
+
+ Below you can indicate both if the material should be treated as an
+ insulator or a metal (if in doubt, choose "Metal"), and if it
+ should be studied with magnetization/spin polarization, switch
+ magnetism On or Off (On is at least twice more costly).
+
+ """),
ipw.HBox(
children=[
ipw.Label(
@@ -136,100 +164,25 @@ def update_reminder_info(change, name=name):
]
),
*self.property_children,
- self.protocol_title,
+ ipw.HTML("""
+
+
Protocol
+
+ """),
ipw.HTML("Select the protocol:", layout=ipw.Layout(flex="1 1 auto")),
- self.workchain_protocol,
- self.protocol_help,
+ self.protocol,
+ ipw.HTML("""
+
+ The "moderate" protocol represents a trade-off between accuracy and
+ speed. Choose the "fast" protocol for a faster calculation with
+ less precision and the "precise" protocol to aim at best accuracy
+ (at the price of longer/costlier calculations).
+
+ """),
]
- super().__init__(
- **kwargs,
- )
- @tl.observe("input_structure")
- def _on_input_structure_change(self, change):
- """Update the relax type options based on the input structure."""
- structure = change["new"]
- if structure is None or structure.pbc != (False, False, False):
- self.relax_type.options = [
- ("Structure as is", "none"),
- ("Atomic positions", "positions"),
- ("Full geometry", "positions_cell"),
- ]
- # Ensure the value is in the options
- if self.relax_type.value not in [
- option[1] for option in self.relax_type.options
- ]:
- self.relax_type.value = "positions_cell"
-
- self.properties["bands"].run.disabled = False
- elif structure.pbc == (False, False, False):
- self.relax_type.options = [
- ("Structure as is", "none"),
- ("Atomic positions", "positions"),
- ]
- # Ensure the value is in the options
- if self.relax_type.value not in [
- option[1] for option in self.relax_type.options
- ]:
- self.relax_type.value = "positions"
-
- self.properties["bands"].run.value = False
- self.properties["bands"].run.disabled = True
-
- def get_panel_value(self):
- # Work chain settings
- relax_type = self.relax_type.value
- electronic_type = self.electronic_type.value
- spin_type = self.spin_type.value
-
- protocol = self.workchain_protocol.value
-
- properties = []
-
- # add plugin specific settings
- run_bands = False
- run_pdos = False
- for name in self.properties:
- if self.properties[name].run.value:
- properties.append(name)
- if name == "bands":
- run_bands = True
- elif name == "pdos":
- run_bands = True
-
- if RelaxType(relax_type) is not RelaxType.NONE or not (run_bands or run_pdos):
- properties.append("relax")
- return {
- "protocol": protocol,
- "relax_type": relax_type,
- "properties": properties,
- "spin_type": spin_type,
- "electronic_type": electronic_type,
- }
-
- def set_panel_value(self, parameters):
- """Update the settings based on the given dict."""
- for key in [
- "relax_type",
- "spin_type",
- "electronic_type",
- ]:
- if key in parameters:
- getattr(self, key).value = parameters[key]
- if "protocol" in parameters:
- self.workchain_protocol.value = parameters["protocol"]
- properties = parameters.get("properties", [])
- for name in self.properties:
- if name in properties:
- self.properties[name].run.value = True
- else:
- self.properties[name].run.value = False
+ self.rendered = True
def reset(self):
- """Reset the panel to the default value."""
- self.input_structure = None
- for key in ["relax_type", "spin_type", "electronic_type"]:
- getattr(self, key).value = DEFAULT_PARAMETERS["workchain"][key]
- self.workchain_protocol.value = DEFAULT_PARAMETERS["workchain"]["protocol"]
- for key, p in self.properties.items():
- p.run.value = key in DEFAULT_PARAMETERS["workchain"]["properties"]
+ self._model.reset()
+ self.updated = False
diff --git a/src/aiidalab_qe/app/main.py b/src/aiidalab_qe/app/main.py
index 471cba3f0..c56359a59 100644
--- a/src/aiidalab_qe/app/main.py
+++ b/src/aiidalab_qe/app/main.py
@@ -8,10 +8,16 @@
from IPython.display import Javascript, display
from aiida.orm import load_node
+from aiida.orm.utils.serialize import deserialize_unsafe
from aiidalab_qe.app.configuration import ConfigureQeAppWorkChainStep
+from aiidalab_qe.app.configuration.model import ConfigurationModel
from aiidalab_qe.app.result import ViewQeAppWorkChainStatusAndResultsStep
+from aiidalab_qe.app.result.model import ResultsModel
from aiidalab_qe.app.structure import StructureSelectionStep
+from aiidalab_qe.app.structure.model import StructureModel
from aiidalab_qe.app.submission import SubmitQeAppWorkChainStep
+from aiidalab_qe.app.submission.model import SubmissionModel
+from aiidalab_qe.common.widgets import LoadingWidget
from aiidalab_widgets_base import WizardAppWidget, WizardAppWidgetStep
@@ -22,40 +28,50 @@ class App(ipw.VBox):
process = tl.Union([tl.Unicode(), tl.Int()], allow_none=True)
def __init__(self, qe_auto_setup=True):
+ # Initialize the models
+ self.struct_model = StructureModel()
+ self.config_model = ConfigurationModel()
+ self.submit_model = SubmissionModel()
+ self.results_model = ResultsModel()
+
# Create the application steps
- self.structure_step = StructureSelectionStep(auto_advance=True)
- self.structure_step.observe(self._observe_structure_selection, "structure")
- self.configure_step = ConfigureQeAppWorkChainStep(auto_advance=True)
+ self.structure_step = StructureSelectionStep(
+ model=self.struct_model,
+ auto_advance=True,
+ )
+ self.configure_step = ConfigureQeAppWorkChainStep(
+ model=self.config_model,
+ auto_advance=True,
+ )
self.submit_step = SubmitQeAppWorkChainStep(
+ model=self.submit_model,
auto_advance=True,
qe_auto_setup=qe_auto_setup,
)
- self.results_step = ViewQeAppWorkChainStatusAndResultsStep()
+ self.results_step = ViewQeAppWorkChainStatusAndResultsStep(
+ model=self.results_model,
+ )
- # Link the application steps
+ # Wizard step observations
ipw.dlink(
(self.structure_step, "state"),
(self.configure_step, "previous_step_state"),
)
- ipw.dlink(
- (self.structure_step, "confirmed_structure"),
- (self.submit_step, "input_structure"),
- )
- ipw.dlink(
- (self.structure_step, "confirmed_structure"),
- (self.configure_step, "input_structure"),
+ self.struct_model.observe(
+ self._on_structure_confirmation_change,
+ "confirmed",
)
ipw.dlink(
(self.configure_step, "state"),
(self.submit_step, "previous_step_state"),
)
- ipw.dlink(
- (self.configure_step, "configuration_parameters"),
- (self.submit_step, "input_parameters"),
+ self.config_model.observe(
+ self._on_configuration_confirmation_change,
+ "confirmed",
)
ipw.dlink(
- (self.submit_step, "process"),
- (self.results_step, "process"),
+ (self.submit_model, "process"),
+ (self.results_model, "process"),
transform=lambda node: node.uuid if node is not None else None,
)
@@ -68,93 +84,113 @@ def __init__(self, qe_auto_setup=True):
("Status & Results", self.results_step),
]
)
- # hide the header
- self._wizard_app_widget.children[0].layout.display = "none"
- self._wizard_app_widget.observe(self._observe_selected_index, "selected_index")
+ self._wizard_app_widget.observe(
+ self._on_step_change,
+ "selected_index",
+ )
+
+ # Hide the header
+ self._wizard_app_widget.children[0].layout.display = "none" # type: ignore
# Add a button to start a new calculation
- self.new_work_chains_button = ipw.Button(
- description="Start New Calculation",
- tooltip="Open a new page to start a separate calculation",
+ self.new_workchain_button = ipw.Button(
+ layout=ipw.Layout(width="auto"),
button_style="success",
icon="plus-circle",
- layout=ipw.Layout(width="30%"),
+ description="Start New Calculation",
+ tooltip="Open a new page to start a separate calculation",
)
- def on_button_click(_):
- display(Javascript("window.open('./qe.ipynb', '_blank')"))
+ self.new_workchain_button.on_click(self._on_new_workchain_button_click)
- self.new_work_chains_button.on_click(on_button_click)
+ self._process_loading_message = LoadingWidget(
+ message="Loading process",
+ layout=ipw.Layout(display="none"),
+ )
super().__init__(
children=[
- self.new_work_chains_button,
+ self.new_workchain_button,
+ self._process_loading_message,
self._wizard_app_widget,
]
)
+ self._wizard_app_widget.selected_index = None
+
+ self._update_blockers()
+
@property
def steps(self):
return self._wizard_app_widget.steps
- # Reset the confirmed_structure in case that a new structure is selected
- def _observe_structure_selection(self, change):
- with self.structure_step.hold_sync():
- if (
- self.structure_step.confirmed_structure is not None
- and self.structure_step.confirmed_structure != change["new"]
- ):
- self.structure_step.confirmed_structure = None
-
- def _observe_selected_index(self, change):
- """Check unsaved change in the step when leaving the step."""
- # no accordion tab is selected
- if not change["new"]:
- return
- new_idx = change["new"]
- # only when entering the submit step, check and udpate the blocker messages
- # steps[new_idx][0] is the title of the step
- if self.steps[new_idx][1] is not self.submit_step:
- return
- blockers = []
- # Loop over all steps before the submit step
- for title, step in self.steps[:new_idx]:
- # check if the step is saved
- if not step.is_saved():
- step.state = WizardAppWidgetStep.State.CONFIGURED
- blockers.append(
- f"Unsaved changes in the {title} step. Please save the changes before submitting."
- )
- self.submit_step.external_submission_blockers = blockers
-
@tl.observe("process")
- def _observe_process(self, change):
- from aiida.orm.utils.serialize import deserialize_unsafe
+ def _on_process_change(self, change):
+ self._update_from_process(change["new"])
+
+ def _on_new_workchain_button_click(self, _):
+ display(Javascript("window.open('./qe.ipynb', '_blank')"))
+
+ def _on_step_change(self, change):
+ if (step_index := change["new"]) is not None:
+ self._render_step(step_index)
- if change["old"] == change["new"]:
- return
- pk = change["new"]
+ def _on_structure_confirmation_change(self, _):
+ if self.struct_model.confirmed:
+ self.config_model.input_structure = self.struct_model.structure
+ else:
+ self.config_model.input_structure = None
+ self._update_blockers()
+
+ def _on_configuration_confirmation_change(self, _):
+ if self.config_model.confirmed:
+ self.submit_model.input_structure = self.struct_model.structure
+ self.submit_model.input_parameters = self.config_model.get_model_state()
+ else:
+ self.submit_model.input_structure = None
+ self.submit_model.input_parameters = {}
+ self._update_blockers()
+
+ def _render_step(self, step_index):
+ step = self.steps[step_index][1]
+ step.render()
+ if step is self.structure_step:
+ # HACK to fix the rendering of the ngl viewer
+ # Reason: If a process is loaded prior to the rendering of the
+ # structure step, the ngl viewer will assume a size of zero.
+ # This code will reset the size and render the selected structure.
+ self.structure_step.manager.viewer._viewer.handle_resize()
+ self.structure_step.manager.viewer._viewer.control.zoom(0)
+
+ def _update_blockers(self):
+ self.submit_model.external_submission_blockers = [
+ f"Unsaved changes in the {title} step. Please confirm the changes before submitting."
+ for title, step in self.steps[:2]
+ if not step.is_saved()
+ ]
+
+ def _update_from_process(self, pk):
if pk is None:
self._wizard_app_widget.reset()
self._wizard_app_widget.selected_index = 0
else:
+ self._show_process_loading_message()
process = load_node(pk)
- with self.structure_step.manager.hold_sync():
- with self.structure_step.hold_sync():
- self._wizard_app_widget.selected_index = 3
- self.structure_step.manager.viewer.structure = (
- process.inputs.structure.get_ase()
- )
- self.structure_step.structure = process.inputs.structure
- self.structure_step.confirm()
- self.submit_step.process = process
-
- # set ui_parameters
- # print out error message if yaml format ui_parameters is not reachable
- ui_parameters = process.base.extras.get("ui_parameters", {})
- if ui_parameters and isinstance(ui_parameters, str):
- ui_parameters = deserialize_unsafe(ui_parameters)
- self.configure_step.set_configuration_parameters(ui_parameters)
- self.configure_step.confirm()
- self.submit_step.set_submission_parameters(ui_parameters)
- self.submit_step.state = self.submit_step.State.SUCCESS
+ self._wizard_app_widget.selected_index = 3
+ self.struct_model.structure = process.inputs.structure
+ self.struct_model.confirmed = True
+ parameters = process.base.extras.get("ui_parameters", {})
+ if parameters and isinstance(parameters, str):
+ parameters = deserialize_unsafe(parameters)
+ self.config_model.set_model_state(parameters)
+ self.config_model.confirmed = True
+ self.submit_model.process = process
+ self.submit_model.set_model_state(parameters)
+ self.submit_step.state = WizardAppWidgetStep.State.SUCCESS
+ self._hide_process_loading_message()
+
+ def _show_process_loading_message(self):
+ self._process_loading_message.layout.display = "flex"
+
+ def _hide_process_loading_message(self):
+ self._process_loading_message.layout.display = "none"
diff --git a/src/aiidalab_qe/app/result/__init__.py b/src/aiidalab_qe/app/result/__init__.py
index a98bed779..ee6944c67 100644
--- a/src/aiidalab_qe/app/result/__init__.py
+++ b/src/aiidalab_qe/app/result/__init__.py
@@ -1,9 +1,7 @@
import ipywidgets as ipw
import traitlets as tl
-from aiida import orm
from aiida.engine import ProcessState
-from aiida.engine.processes import control
from aiidalab_widgets_base import (
AiidaNodeViewWidget,
ProcessMonitor,
@@ -11,6 +9,8 @@
WizardAppWidgetStep,
)
+from .model import ResultsModel
+
# trigger registration of the viewer widget:
from .workchain_viewer import WorkChainViewer # noqa: F401
@@ -20,12 +20,29 @@
class ViewQeAppWorkChainStatusAndResultsStep(ipw.VBox, WizardAppWidgetStep):
- process = tl.Unicode(allow_none=True)
+ def __init__(self, model: ResultsModel, **kwargs):
+ from aiidalab_qe.common.widgets import LoadingWidget
+
+ super().__init__(
+ children=[LoadingWidget("Loading results panel")],
+ **kwargs,
+ )
+
+ self._model = model
+ self._model.observe(
+ self._on_process_change,
+ "process",
+ )
+
+ self.rendered = False
+
+ def render(self):
+ if self.rendered:
+ return
- def __init__(self, **kwargs):
self.process_tree = ProcessNodesTreeWidget()
ipw.dlink(
- (self, "process"),
+ (self._model, "process"),
(self.process_tree, "value"),
)
@@ -35,9 +52,14 @@ def __init__(self, **kwargs):
(self.node_view, "node"),
transform=lambda nodes: nodes[0] if nodes else None,
)
- self.process_status = ipw.VBox(children=[self.process_tree, self.node_view])
- # Setup process monitor
+ self.process_status = ipw.VBox(
+ children=[
+ self.process_tree,
+ self.node_view,
+ ],
+ )
+
self.process_monitor = ProcessMonitor(
timeout=0.2,
callbacks=[
@@ -45,173 +67,144 @@ def __init__(self, **kwargs):
self._update_state,
],
)
- ipw.dlink((self, "process"), (self.process_monitor, "value"))
+ ipw.dlink(
+ (self._model, "process"),
+ (self.process_monitor, "value"),
+ )
self.kill_button = ipw.Button(
description="Kill workchain",
tooltip="Kill the below workchain.",
button_style="danger",
icon="stop",
- layout=ipw.Layout(width="120px", display="none", margin="0px 20px 0px 0px"),
+ layout=ipw.Layout(width="auto", display="none"),
+ )
+ ipw.dlink(
+ (self, "state"),
+ (self.kill_button, "disabled"),
+ lambda state: state is not self.State.ACTIVE,
)
- self.kill_button.on_click(self._on_click_kill_button)
+ self.kill_button.on_click(self._on_kill_button_click)
+
+ self.update_results_button = ipw.Button(
+ description="Update results",
+ tooltip="Trigger the update of the results.",
+ button_style="success",
+ icon="refresh",
+ layout=ipw.Layout(width="auto", display="block"),
+ )
+ self.update_results_button.on_click(self._on_update_results_button_click)
self.clean_scratch_button = ipw.Button(
description="Clean remote data",
tooltip="Clean the remote folders of the workchain.",
button_style="danger",
icon="trash",
- layout=ipw.Layout(width="150px", display="none", margin="0px 20px 0px 0px"),
+ layout=ipw.Layout(width="auto", display="none"),
)
- self.clean_scratch_button.on_click(self._on_click_clean_scratch_button)
- self.update_result_button = ipw.Button(
- description="Update results tabs",
- tooltip="Trigger the update of the results tabs.",
- button_style="success",
- icon="refresh",
- layout=ipw.Layout(
- width="150px", display="block", margin="0px 20px 0px 0px"
- ),
+ ipw.dlink(
+ (self._model, "process_remote_folder_is_clean"),
+ (self.clean_scratch_button, "disabled"),
)
- self.update_result_button.on_click(self._on_click_update_result_button)
+ self.clean_scratch_button.on_click(self._on_clean_scratch_button_click)
self.process_info = ipw.HTML()
-
- super().__init__(
- [
- self.process_info,
- ipw.HBox(
- children=[
- self.kill_button,
- self.update_result_button,
- self.clean_scratch_button,
- ]
- ),
- self.process_status,
- ],
- **kwargs,
+ ipw.dlink(
+ (self._model, "process_info"),
+ (self.process_info, "value"),
)
+ self.children = [
+ self.process_info,
+ ipw.HBox(
+ children=[
+ self.kill_button,
+ self.update_results_button,
+ self.clean_scratch_button,
+ ],
+ layout=ipw.Layout(margin="0 3px"),
+ ),
+ self.process_status,
+ ]
+
+ self.rendered = True
+
self._update_kill_button_layout()
+ self._update_clean_scratch_button_layout()
def can_reset(self):
- "Do not allow reset while process is running."
return self.state is not self.State.ACTIVE
def reset(self):
- self.process = None
-
- def _update_state(self):
- """Based on the process state, update the state of the step."""
- if self.process is None:
- self.state = self.State.INIT
- else:
- process = orm.load_node(self.process)
- process_state = process.process_state
- if process_state in (
- ProcessState.CREATED,
- ProcessState.RUNNING,
- ProcessState.WAITING,
- ):
- self.state = self.State.ACTIVE
- self.process_info.value = PROCESS_RUNNING
- elif (
- process_state in (ProcessState.EXCEPTED, ProcessState.KILLED)
- or process.is_failed
- ):
- self.state = self.State.FAIL
- self.process_info.value = PROCESS_EXCEPTED
- elif process.is_finished_ok:
- self.state = self.State.SUCCESS
- self.process_info.value = PROCESS_COMPLETED
- # trigger the update of kill and clean button.
- if self.state in [self.State.SUCCESS, self.State.FAIL]:
- self._update_kill_button_layout()
- self._update_clean_scratch_button_layout()
+ self._model.reset()
- def _update_kill_button_layout(self):
- """Update the layout of the kill button."""
- # If no process is selected, hide the button.
- if self.process is None or self.process == "":
- self.kill_button.layout.display = "none"
- else:
- process = orm.load_node(self.process)
- # If the process is terminated, hide the button.
- if process.is_terminated:
- self.kill_button.layout.display = "none"
- else:
- self.kill_button.layout.display = "block"
-
- # If the step is not activated, no point to click the button, so disable it.
- # Only enable it if the process is on (RUNNING, CREATED, WAITING).
- if self.state is self.State.ACTIVE:
- self.kill_button.disabled = False
- else:
- self.kill_button.disabled = True
+ @tl.observe("state")
+ def _on_state_change(self, _):
+ self._update_kill_button_layout()
- def _update_clean_scratch_button_layout(self):
- """Update the layout of the kill button."""
- # The button is hidden by default, but if we load a new process, we hide again.
- if not self.process:
- self.clean_scratch_button.layout.display = "none"
- else:
- process = orm.load_node(self.process)
- # If the process is terminated, show the button.
- if process.is_terminated:
- self.clean_scratch_button.layout.display = "block"
- else:
- self.clean_scratch_button.layout.display = "none"
-
- # If the scratch is already empty, we should deactivate the button.
- # not sure about the performance if descendants are several.
- cleaned_bool = []
- for called_descendant in process.called_descendants:
- if isinstance(called_descendant, orm.CalcJobNode):
- try:
- cleaned_bool.append(
- called_descendant.outputs.remote_folder.is_empty
- )
- except Exception:
- pass
- self.clean_scratch_button.disabled = all(cleaned_bool)
-
- def _on_click_kill_button(self, _=None):
- """callback for the kill button.
- First kill the process, then update the kill button layout.
- """
- workchain = [orm.load_node(self.process)]
- control.kill_processes(workchain)
-
- # update the kill button layout
+ def _on_process_change(self, _):
+ self._model.update()
+ self._update_state()
self._update_kill_button_layout()
+ self._update_clean_scratch_button_layout()
- def _on_click_clean_scratch_button(self, _=None):
- """callback for the clean scratch button.
- First clean the remote folders, then update the clean button layout.
- """
- process = orm.load_node(self.process)
+ def _on_kill_button_click(self, _):
+ self._model.kill_process()
+ self._update_kill_button_layout()
- for called_descendant in process.called_descendants:
- if isinstance(called_descendant, orm.CalcJobNode):
- try:
- called_descendant.outputs.remote_folder._clean()
- except Exception:
- pass
+ def _on_update_results_button_click(self, _):
+ self.node_view.node = None
+ self.node_view.node = self._model.process_node
- # update the kill button layout
+ def _on_clean_scratch_button_click(self, _):
+ self._model.clean_remote_data()
self._update_clean_scratch_button_layout()
- def _on_click_update_result_button(self, _=None):
- """Trigger the update of the results tabs."""
- # change the node to trigger the update of the view.
- self.node_view.node = None
- self.node_view.node = orm.load_node(self.process)
+ def _update_kill_button_layout(self):
+ if not self.rendered:
+ return
+ if (
+ not self._model.process
+ or self._model.process_node.is_finished
+ or self._model.process_node.is_excepted
+ or self.state
+ in (
+ self.State.SUCCESS,
+ self.State.FAIL,
+ )
+ ):
+ self.kill_button.layout.display = "none"
+ else:
+ self.kill_button.layout.display = "block"
- @tl.observe("process")
- def _observe_process(self, _):
- """Callback for when the process is changed."""
- # The order of the following calls matters,
- # as the self.state is updated in the _update_state method.
- self._update_state()
- self._update_kill_button_layout()
- self._update_clean_scratch_button_layout()
+ def _update_clean_scratch_button_layout(self):
+ if not self.rendered:
+ return
+ if self._model.process and self._model.process_node.is_terminated:
+ self.clean_scratch_button.layout.display = "block"
+ else:
+ self.clean_scratch_button.layout.display = "none"
+
+ def _update_state(self):
+ if not self._model.process:
+ self.state = self.State.INIT
+ elif self._model.process_node.process_state in (
+ ProcessState.CREATED,
+ ProcessState.RUNNING,
+ ProcessState.WAITING,
+ ):
+ self.state = self.State.ACTIVE
+ self._model.process_info = PROCESS_RUNNING
+ elif (
+ self._model.process_node.process_state
+ in (
+ ProcessState.EXCEPTED,
+ ProcessState.KILLED,
+ )
+ or self._model.process_node.is_failed
+ ):
+ self.state = self.State.FAIL
+ self._model.process_info = PROCESS_EXCEPTED
+ elif self._model.process_node.is_finished_ok:
+ self.state = self.State.SUCCESS
+ self._model.process_info = PROCESS_COMPLETED
diff --git a/src/aiidalab_qe/app/result/model.py b/src/aiidalab_qe/app/result/model.py
new file mode 100644
index 000000000..b30d5c0ab
--- /dev/null
+++ b/src/aiidalab_qe/app/result/model.py
@@ -0,0 +1,56 @@
+from __future__ import annotations
+
+import contextlib
+
+import traitlets as tl
+
+from aiida import orm
+from aiida.engine.processes import control
+
+
+class ResultsModel(tl.HasTraits):
+ process = tl.Unicode(allow_none=True)
+
+ process_info = tl.Unicode("")
+ process_remote_folder_is_clean = tl.Bool(False)
+
+ _process_node: orm.ProcessNode | None = None
+
+ @property
+ def process_node(self):
+ if self._process_node is None:
+ self._process_node = self._get_process_node()
+ return self._process_node
+
+ def update(self):
+ self._update_process_remote_folder_state()
+
+ def kill_process(self):
+ if process := self._get_process_node():
+ control.kill_processes(process)
+
+ def clean_remote_data(self):
+ if self.process_node is None:
+ return
+ for called_descendant in self.process_node.called_descendants:
+ if isinstance(called_descendant, orm.CalcJobNode):
+ with contextlib.suppress(Exception):
+ called_descendant.outputs.remote_folder._clean()
+ self.process_remote_folder_is_clean = True
+
+ def reset(self):
+ self.process = None
+ self.process_info = ""
+
+ def _get_process_node(self):
+ return orm.load_node(self.process) if self.process else None
+
+ def _update_process_remote_folder_state(self):
+ if self.process_node is None:
+ return
+ cleaned = []
+ for called_descendant in self.process_node.called_descendants:
+ if isinstance(called_descendant, orm.CalcJobNode):
+ with contextlib.suppress(Exception):
+ cleaned.append(called_descendant.outputs.remote_folder.is_empty)
+ self.process_remote_folder_is_clean = all(cleaned)
diff --git a/src/aiidalab_qe/app/result/workchain_viewer.py b/src/aiidalab_qe/app/result/workchain_viewer.py
index 48cf9f734..33fb0bd36 100644
--- a/src/aiidalab_qe/app/result/workchain_viewer.py
+++ b/src/aiidalab_qe/app/result/workchain_viewer.py
@@ -91,13 +91,6 @@ def toggle_camera():
children=[self.title, self.result_tabs],
**kwargs,
)
- # self.process_monitor = ProcessMonitor(
- # timeout=1.0,
- # on_sealed=[
- # self._update_view,
- # ],
- # )
- # ipw.dlink((self, "process_uuid"), (self.process_monitor, "value"))
@property
def node(self):
diff --git a/src/aiidalab_qe/app/static/styles/custom.css b/src/aiidalab_qe/app/static/styles/custom.css
index 0cf6e78a3..053cf8b92 100644
--- a/src/aiidalab_qe/app/static/styles/custom.css
+++ b/src/aiidalab_qe/app/static/styles/custom.css
@@ -38,9 +38,21 @@
margin-bottom: 0.5em;
}
-#loading {
- text-align: center;
+.loading {
+ margin: 0 auto;
+ padding: 5px;
font-size: large;
+ justify-content: center;
+}
+
+.warning {
+ color: red;
+ font-weight: bold;
+}
+
+.pseudo-text {
+ line-height: 140%;
+ padding: 5px 0;
}
footer {
diff --git a/src/aiidalab_qe/app/static/templates/about.jinja b/src/aiidalab_qe/app/static/templates/about.jinja
index 0b617a231..6203b2124 100644
--- a/src/aiidalab_qe/app/static/templates/about.jinja
+++ b/src/aiidalab_qe/app/static/templates/about.jinja
@@ -1,10 +1,9 @@
- The Quantum ESPRESSO app
- (or QE app for short) is a graphical front end for calculating materials properties using
- Quantum ESPRESSO (QE). Each property is calculated by workflows powered by the
- AiiDA engine, and maintained in the
- aiida-quantumespresso plugin and many other plugins developed by the community.
- for AiiDA.
+ The Quantum ESPRESSO app (or QE app for short) is a graphical front end for calculating materials properties using
+ Quantum ESPRESSO (QE).
+ Each property is calculated by workflows powered by the AiiDA
+ engine, and maintained in the aiida-quantumespresso plugin and many other plugins developed by the AiiDA community.
diff --git a/src/aiidalab_qe/app/static/templates/guide.jinja b/src/aiidalab_qe/app/static/templates/guide.jinja
index 14b10f914..d72d7beef 100644
--- a/src/aiidalab_qe/app/static/templates/guide.jinja
+++ b/src/aiidalab_qe/app/static/templates/guide.jinja
@@ -23,7 +23,11 @@
- Completed workflows can be selected at the top of the app.
+ Completed workflows can be viewed in the Job History section.
+
+
+
+ To start a new calculation in a separate tab, click the Start New Calculation button.
diff --git a/src/aiidalab_qe/app/structure/__init__.py b/src/aiidalab_qe/app/structure/__init__.py
index cb5940ccf..ae11052a4 100644
--- a/src/aiidalab_qe/app/structure/__init__.py
+++ b/src/aiidalab_qe/app/structure/__init__.py
@@ -6,16 +6,15 @@
import pathlib
import ipywidgets as ipw
-import traitlets as tl
-import aiida
+from aiida import orm
from aiida_quantumespresso.data.hubbard_structure import HubbardStructureData
+from aiidalab_qe.app.structure.model import StructureModel
from aiidalab_qe.app.utils import get_entry_items
-from aiidalab_qe.common import AddingTagsEditor
+from aiidalab_qe.common import AddingTagsEditor, OptimadeWrapper
from aiidalab_widgets_base import (
BasicCellEditor,
BasicStructureEditor,
- OptimadeQueryWidget,
StructureBrowserWidget,
StructureExamplesWidget,
StructureManagerWidget,
@@ -47,51 +46,92 @@ class StructureSelectionStep(ipw.VBox, WizardAppWidgetStep):
structure importers and the structure editors can be extended by plugins.
"""
- structure = tl.Instance(aiida.orm.StructureData, allow_none=True)
- confirmed_structure = tl.Instance(aiida.orm.StructureData, allow_none=True)
+ def __init__(self, model: StructureModel, **kwargs):
+ from aiidalab_qe.common.widgets import LoadingWidget
+
+ super().__init__(
+ children=[LoadingWidget("Loading structure selection panel")],
+ **kwargs,
+ )
+
+ self._model = model
+ self._model.observe(
+ self._on_confirmation_change,
+ "confirmed",
+ )
+ self._model.observe(
+ self._on_structure_change,
+ "structure",
+ )
+
+ self.rendered = False
+
+ def render(self):
+ """docstring"""
+ if self.rendered:
+ return
- def __init__(self, description=None, **kwargs):
importers = [
StructureUploadWidget(title="Upload file"),
- OptimadeQueryWidget(embedded=False),
+ OptimadeWrapper(embedded=False),
StructureBrowserWidget(
title="AiiDA database",
query_types=(
- aiida.orm.StructureData,
- aiida.orm.CifData,
+ orm.StructureData,
+ orm.CifData,
HubbardStructureData,
),
),
StructureExamplesWidget(title="From Examples", examples=Examples),
]
- # add plugin specific structure importers
- entries = get_entry_items("aiidalab_qe.properties", "importer")
- importers.extend([entry_point() for entry_point in entries.values()])
- # add plugin specific structure editors
+
+ plugin_importers = get_entry_items("aiidalab_qe.properties", "importer")
+ importers.extend([importer() for importer in plugin_importers.values()])
+
editors = [
BasicCellEditor(title="Edit cell"),
BasicStructureEditor(title="Edit structure"),
AddingTagsEditor(title="Edit StructureData"),
]
- entries = get_entry_items("aiidalab_qe.properties", "editor")
- editors.extend([entry_point() for entry_point in entries.values()])
- #
+
+ plugin_editors = get_entry_items("aiidalab_qe.properties", "editor")
+ editors.extend([editor() for editor in plugin_editors.values()])
+
+ # HACK structure manager resets the structure node on initialization,
+ # causing the structure in the model (if exists) to reset. To avoid
+ # this issue, we store the structure in a variable and reassign it
+ # to the model after the initialization of the structure manager.
+ # TODO fix this issue in the structure manager and remove this hack!
+
+ structure = self._model.structure
+
self.manager = StructureManagerWidget(
importers=importers,
editors=editors,
node_class="StructureData",
storable=False,
- configuration_tabs=["Cell", "Selection", "Appearance", "Download"],
+ configuration_tabs=[
+ "Cell",
+ "Selection",
+ "Appearance",
+ "Download",
+ ],
+ )
+ ipw.dlink(
+ (self._model, "structure"),
+ (self.manager, "structure"),
+ lambda structure: structure.get_ase() if structure else None,
+ )
+ ipw.dlink(
+ (self.manager, "structure_node"),
+ (self._model, "structure"),
+ )
+ ipw.link(
+ (self._model, "manager_output"),
+ (self.manager.output, "value"),
)
- if description is None:
- description = ipw.HTML(
- """
-
Select a structure from one of the following sources and then click
- "Confirm" to go to the next step.
- """
- )
- self.description = description
+ self._model.structure = structure
self.structure_name_text = ipw.Text(
placeholder="[No structure selected]",
@@ -99,88 +139,71 @@ def __init__(self, description=None, **kwargs):
disabled=True,
layout=ipw.Layout(width="auto", flex="1 1 auto"),
)
+ ipw.dlink(
+ (self._model, "structure_name"),
+ (self.structure_name_text, "value"),
+ )
self.confirm_button = ipw.Button(
description="Confirm",
tooltip="Confirm the currently selected structure and go to the next step.",
button_style="success",
icon="check-circle",
- disabled=True,
layout=ipw.Layout(width="auto"),
)
+ ipw.dlink(
+ (self, "state"),
+ (self.confirm_button, "disabled"),
+ lambda state: state != self.State.CONFIGURED,
+ )
self.confirm_button.on_click(self.confirm)
- self.message_area = ipw.HTML()
-
- # Create directional link from the (read-only) 'structure_node' traitlet of the
- # structure manager to our 'structure' traitlet:
- ipw.dlink((self.manager, "structure_node"), (self, "structure"))
- super().__init__(
- children=[
- self.description,
- self.manager,
- self.structure_name_text,
- self.message_area,
- self.confirm_button,
- ],
- **kwargs,
+ self.message_area = ipw.HTML()
+ ipw.dlink(
+ (self._model, "message_area"),
+ (self.message_area, "value"),
)
- @tl.default("state")
- def _default_state(self):
- return self.State.INIT
+ self.children = [
+ ipw.HTML("""
+
+ Select a structure from one of the following sources and then
+ click "Confirm" to go to the next step.
+
+ """),
+ self.manager,
+ self.structure_name_text,
+ self.message_area,
+ self.confirm_button,
+ ]
- def _update_state(self):
- if self.structure is None:
- if self.confirmed_structure is None:
- self.state = self.State.READY
- else:
- self.state = self.State.SUCCESS
- else:
- if self.confirmed_structure is None:
- self.state = self.State.CONFIGURED
- else:
- self.state = self.State.SUCCESS
-
- @tl.observe("structure")
- def _observe_structure(self, change):
- structure = change["new"]
- with self.hold_trait_notifications():
- if structure is None:
- self.structure_name_text.value = ""
- self.message_area.value = ""
- else:
- self.structure_name_text.value = str(self.structure.get_formula())
- self._update_state()
-
- @tl.observe("confirmed_structure")
- def _observe_confirmed_structure(self, _):
- with self.hold_trait_notifications():
- self._update_state()
-
- @tl.observe("state")
- def _observe_state(self, change):
- with self.hold_trait_notifications():
- state = change["new"]
- self.confirm_button.disabled = state != self.State.CONFIGURED
- self.manager.disabled = state is self.State.SUCCESS
+ self.rendered = True
+
+ def is_saved(self):
+ return self._model.confirmed
def confirm(self, _=None):
self.manager.store_structure()
- self.confirmed_structure = self.structure
- self.message_area.value = ""
-
- def is_saved(self):
- """Check if the current structure is saved.
- That all changes are confirmed."""
- return self.confirmed_structure == self.structure
+ self._model.message_area = ""
+ self._model.confirmed = True
def can_reset(self):
- return self.confirmed_structure is not None
-
- def reset(self): # unconfirm
- """Reset the widget to its initial state."""
- self.confirmed_structure = None
- self.manager.structure = None
- self.manager.viewer.structure = None
- self.manager.output.value = ""
+ return self._model.confirmed
+
+ def reset(self):
+ self._model.reset()
+
+ def _on_structure_change(self, _):
+ self._model.update_widget_text()
+ self._update_state()
+
+ def _on_confirmation_change(self, _):
+ self._update_state()
+
+ def _update_state(self):
+ if self._model.confirmed:
+ self.state = self.State.SUCCESS
+ elif self._model.structure is None:
+ self.state = self.State.READY
+ else:
+ self.state = self.State.CONFIGURED
diff --git a/src/aiidalab_qe/app/structure/model.py b/src/aiidalab_qe/app/structure/model.py
new file mode 100644
index 000000000..191a60f4f
--- /dev/null
+++ b/src/aiidalab_qe/app/structure/model.py
@@ -0,0 +1,41 @@
+import traitlets as tl
+
+from aiida import orm
+
+
+class StructureModel(tl.HasTraits):
+ structure = tl.Instance(
+ orm.StructureData,
+ allow_none=True,
+ )
+ structure_name = tl.Unicode("")
+ manager_output = tl.Unicode("")
+ message_area = tl.Unicode("")
+ confirmed = tl.Bool(False)
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.observe(
+ self._unconfirm,
+ "structure",
+ )
+
+ def update_widget_text(self):
+ if self.structure is None:
+ self.structure_name = ""
+ self.message_area = ""
+ else:
+ self.manager_output = ""
+ self.structure_name = str(self.structure.get_formula())
+
+ def reset(self):
+ self.structure = None
+ self.structure_name = ""
+ self.manager_output = ""
+ self.message_area = ""
+
+ def _unconfirm(self, _):
+ self.confirmed = False
+
+
+struct_model = StructureModel()
diff --git a/src/aiidalab_qe/app/submission/__init__.py b/src/aiidalab_qe/app/submission/__init__.py
index de0272b63..f58559172 100644
--- a/src/aiidalab_qe/app/submission/__init__.py
+++ b/src/aiidalab_qe/app/submission/__init__.py
@@ -5,98 +5,127 @@
from __future__ import annotations
-import os
-
import ipywidgets as ipw
import traitlets as tl
-from aiida import orm
-from aiida.common import NotExistent
-from aiida.engine import ProcessBuilderNamespace, submit
from aiidalab_qe.app.parameters import DEFAULT_PARAMETERS
from aiidalab_qe.app.utils import get_entry_items
from aiidalab_qe.common.setup_codes import QESetupWidget
from aiidalab_qe.common.setup_pseudos import PseudosInstallWidget
-from aiidalab_qe.common.widgets import (
- PwCodeResourceSetupWidget,
- QEAppComputationalResourcesWidget,
-)
-from aiidalab_qe.workflows import QeAppWorkChain
+from aiidalab_qe.common.widgets import LoadingWidget, PwCodeResourceSetupWidget
from aiidalab_widgets_base import WizardAppWidgetStep
+from .code import CodeModel, PluginCodes
+from .model import SubmissionModel
+
+DEFAULT: dict = DEFAULT_PARAMETERS # type: ignore
+
class SubmitQeAppWorkChainStep(ipw.VBox, WizardAppWidgetStep):
"""Step for submission of a bands workchain."""
- codes_title = ipw.HTML(
- """
-
Codes
"""
- )
- codes_help = ipw.HTML(
- """ Select the code to use for running the calculations. The codes
- on the local machine (localhost) are installed by default, but you can
- configure new ones on potentially more powerful machines by clicking on
- "Setup new code".
"""
- )
- process_label_help = ipw.HTML(
- """
-
Labeling Your Job
-
Label your job and provide a brief description. These details help identify the job later and make the search process easier. While optional, adding a description is recommended for better clarity.
-
"""
- )
-
- # This number provides a rough estimate for how many MPI tasks are needed
- # for a given structure.
- NUM_SITES_PER_MPI_TASK_DEFAULT = 6
-
- # Warn the user if they are trying to run calculations for a large
- # structure on localhost.
- RUN_ON_LOCALHOST_NUM_SITES_WARN_THRESHOLD = 10
- RUN_ON_LOCALHOST_VOLUME_WARN_THRESHOLD = 1000 # \AA^3
-
- # Put a limit on how many MPI tasks you want to run per k-pool by default
- MAX_MPI_PER_POOL = 20
-
- input_structure = tl.Instance(orm.StructureData, allow_none=True)
- process = tl.Instance(orm.WorkChainNode, allow_none=True)
previous_step_state = tl.UseEnum(WizardAppWidgetStep.State)
- input_parameters = tl.Dict()
- internal_submission_blockers = tl.List(tl.Unicode())
- external_submission_blockers = tl.List(tl.Unicode())
- def __init__(self, qe_auto_setup=True, **kwargs):
- self._submission_blocker_messages = ipw.HTML()
- self._submission_warning_messages = ipw.HTML()
+ def __init__(self, model: SubmissionModel, qe_auto_setup=True, **kwargs):
+ super().__init__(
+ children=[LoadingWidget("Loading workflow submission panel")],
+ **kwargs,
+ )
- self.pw_code = PwCodeResourceSetupWidget(
- description="pw.x:", default_calc_job_plugin="quantumespresso.pw"
+ self._model = model
+ self._model.observe(
+ self._on_input_structure_change,
+ "input_structure",
+ )
+ self._model.observe(
+ self._on_input_parameters_change,
+ "input_parameters",
+ )
+ self._model.observe(
+ self._on_process_change,
+ "process",
+ )
+ self._model.observe(
+ self._on_submission_blockers_change,
+ [
+ "internal_submission_blockers",
+ "external_submission_blockers",
+ ],
+ )
+ self._model.observe(
+ self._on_installation_change,
+ ["installing_sssp", "sssp_installed"],
+ )
+ self._model.observe(
+ self._on_sssp_installed,
+ "sssp_installed",
+ )
+ self._model.observe(
+ self._on_installation_change,
+ ["installing_qe", "qe_installed"],
+ )
+ self._model.observe(
+ self._on_qe_installed,
+ "qe_installed",
)
- self.pw_code.observe(self._update_state, "value")
+ # # TODO for testing only - remove in PR
+ # self._model.observe(
+ # lambda change: print(change["new"]),
+ # "input_parameters",
+ # )
+
+ self.qe_auto_setup = qe_auto_setup
+
+ plugin_codes: PluginCodes = get_entry_items("aiidalab_qe.properties", "code")
+ plugin_codes.update(
+ {
+ "dft": {
+ "pw": CodeModel(
+ description="pw.x:",
+ default_calc_job_plugin="quantumespresso.pw",
+ setup_widget_class=PwCodeResourceSetupWidget,
+ ),
+ },
+ }
+ )
+ for identifier, codes in plugin_codes.items():
+ for name, code in codes.items():
+ self._model.add_code(identifier, name, code)
+ code.observe(
+ self._on_code_activation_change,
+ "is_active",
+ )
+ code.observe(
+ self._on_code_selection_change,
+ "selected",
+ )
+
+ self.rendered = False
+
+ def render(self):
+ if self.rendered:
+ return
+
+ self.code_widgets_container = ipw.VBox()
- # add plugin's entry points
- self.codes = {"pw": self.pw_code}
- self.code_children = [
- self.codes_title,
- self.codes_help,
- self.pw_code,
- ]
- self.code_entries = get_entry_items("aiidalab_qe.properties", "code")
- for _, entry_point in self.code_entries.items():
- for name, code in entry_point.items():
- self.codes[name] = code
- code.observe(self._update_state, "value")
- self.code_children.append(self.codes[name])
- # set process label and description
self.process_label = ipw.Text(
- description="Label:", layout=ipw.Layout(width="auto", indent="0px")
+ description="Label:",
+ layout=ipw.Layout(width="auto", indent="0px"),
+ )
+ ipw.link(
+ (self._model, "process_label"),
+ (self.process_label, "value"),
)
self.process_description = ipw.Textarea(
- description="Description", layout=ipw.Layout(width="auto", indent="0px")
+ description="Description",
+ layout=ipw.Layout(width="auto", indent="0px"),
+ )
+ ipw.link(
+ (self._model, "process_description"),
+ (self.process_description, "value"),
)
- #
+
self.submit_button = ipw.Button(
description="Submit",
tooltip="Submit the calculation with the selected parameters.",
@@ -105,440 +134,212 @@ def __init__(self, qe_auto_setup=True, **kwargs):
layout=ipw.Layout(width="auto", flex="1 1 auto"),
disabled=True,
)
-
- self.submit_button.on_click(self._on_submit_button_clicked)
-
- # The SSSP installation status widget shows the installation status of
- # the SSSP pseudo potentials and triggers the installation in case that
- # they are not yet installed. The widget will remain in a "busy" state
- # in case that the installation was already triggered elsewhere, e.g.,
- # by the start up scripts. The submission is blocked while the
- # potentials are not yet installed.
- self.sssp_installation_status = PseudosInstallWidget(auto_start=qe_auto_setup)
- self.sssp_installation_status.observe(self._update_state, ["busy", "installed"])
- self.sssp_installation_status.observe(self._toggle_install_widgets, "installed")
-
- # The QE setup widget checks whether there are codes that match specific
- # expected labels (e.g. "pw-7.2@localhost") and triggers both the
- # installation of QE into a dedicated conda environment and the setup of
- # the codes in case that they are not already configured.
- self.qe_setup_status = QESetupWidget(auto_start=qe_auto_setup)
- self.qe_setup_status.observe(self._update_state, "busy")
- self.qe_setup_status.observe(self._toggle_install_widgets, "installed")
- self.qe_setup_status.observe(self._auto_select_code, "installed")
- self.ui_parameters = {}
-
- super().__init__(
- children=[
- *self.code_children,
- self.sssp_installation_status,
- self.qe_setup_status,
- self._submission_blocker_messages,
- self._submission_warning_messages,
- self.process_label_help,
- self.process_label,
- self.process_description,
- self.submit_button,
- ],
- **kwargs,
+ ipw.dlink(
+ (self, "state"),
+ (self.submit_button, "disabled"),
+ lambda state: state != self.State.CONFIGURED,
)
- # set default codes
- self.set_selected_codes(DEFAULT_PARAMETERS["codes"])
-
- # observe these two for the resource checking:
- self.pw_code.num_cpus.observe(self._check_resources, "value")
- self.pw_code.num_nodes.observe(self._check_resources, "value")
-
- @tl.observe("internal_submission_blockers", "external_submission_blockers")
- def _observe_submission_blockers(self, _change):
- """Observe the submission blockers and update the message area."""
- blockers = self.internal_submission_blockers + self.external_submission_blockers
- if any(blockers):
- fmt_list = "\n".join(f"{item}" for item in sorted(blockers))
- self._submission_blocker_messages.value = f"""
-
-
The submission is blocked, due to the following reason(s):
-
"""
- else:
- self._submission_blocker_messages.value = ""
-
- def _identify_submission_blockers(self):
- """Validate the resource inputs and identify blockers for the submission."""
- # Do not submit while any of the background setup processes are running.
- if self.qe_setup_status.busy or self.sssp_installation_status.busy:
- yield "Background setup processes must finish."
-
- # No pw code selected (this is ignored while the setup process is running).
- if self.pw_code.value is None and not self.qe_setup_status.busy:
- yield ("No pw code selected")
- # code related to the selected property is not installed
- properties = self.input_parameters.get("workchain", {}).get("properties", [])
- for identifer in properties:
- for name, code in self.code_entries.get(identifer, {}).items():
- if code.value is None:
- yield f"Calculating the {identifer} property requires code {name} to be set."
- # SSSP library not installed
- if not self.sssp_installation_status.installed:
- yield "The SSSP library is not installed."
-
- # check if the QEAppComputationalResourcesWidget is used
- for name, code in self.codes.items():
- # skip if the code is not displayed, convenient for the plugin developer
- if code.layout.display == "none":
- continue
- if not isinstance(code, QEAppComputationalResourcesWidget):
- yield (
- f"Error: hi, plugin developer, please use the QEAppComputationalResourcesWidget from aiidalab_qe.common.widgets for code {name}."
- )
-
- def _update_state(self, _=None):
- # If the previous step has failed, this should fail as well.
- if self.previous_step_state is self.State.FAIL:
- self.state = self.State.FAIL
- return
- # Do not interact with the user if they haven't successfully completed the previous step.
- elif self.previous_step_state is not self.State.SUCCESS:
- self.state = self.State.INIT
- return
+ self.submit_button.on_click(self._on_submission)
- # Process is already running.
- if self.process is not None:
- self.state = self.State.SUCCESS
- return
-
- blockers = list(self._identify_submission_blockers())
- if any(blockers):
- self.internal_submission_blockers = blockers
- self.state = self.State.READY
- return
-
- self.internal_submission_blockers = []
- self.state = self.state.CONFIGURED
-
- def _toggle_install_widgets(self, change):
- if change["new"]:
- self.children = [
- child for child in self.children if child is not change["owner"]
- ]
+ self.sssp_installation = PseudosInstallWidget()
+ ipw.dlink(
+ (self.sssp_installation, "busy"),
+ (self._model, "installing_sssp"),
+ )
+ ipw.dlink(
+ (self.sssp_installation, "installed"),
+ (self._model, "installing_sssp"),
+ lambda installed: not installed,
+ )
+ ipw.dlink(
+ (self.sssp_installation, "installed"),
+ (self._model, "sssp_installed"),
+ )
+ if self.qe_auto_setup:
+ self.sssp_installation.refresh()
- def _auto_select_code(self, change):
- if change["new"] and not change["old"]:
- self.set_selected_codes(DEFAULT_PARAMETERS["codes"])
+ self.qe_setup = QESetupWidget()
+ ipw.dlink(
+ (self.qe_setup, "busy"),
+ (self._model, "installing_qe"),
+ )
+ ipw.dlink(
+ (self.qe_setup, "installed"),
+ (self._model, "installing_qe"),
+ lambda installed: not installed,
+ )
+ ipw.dlink(
+ (self.qe_setup, "installed"),
+ (self._model, "qe_installed"),
+ )
+ if self.qe_auto_setup:
+ self.qe_setup.refresh()
- _ALERT_MESSAGE = """
- """
+ self.submission_blocker_messages = ipw.HTML()
+ ipw.dlink(
+ (self._model, "submission_blocker_messages"),
+ (self.submission_blocker_messages, "value"),
+ )
- def _show_alert_message(self, message, alert_class="info"):
- self._submission_warning_messages.value = self._ALERT_MESSAGE.format(
- alert_class=alert_class, message=message
+ self.submission_warning_messages = ipw.HTML()
+ ipw.dlink(
+ (self._model, "submission_warning_messages"),
+ (self.submission_warning_messages, "value"),
)
- @tl.observe("input_structure")
- def _check_resources(self, _change=None):
- """Check whether the currently selected resources will be sufficient and warn if not."""
- if not self.pw_code.value or not self.input_structure:
- return # No code selected or no structure, so nothing to do.
+ self.children = [
+ ipw.HTML("""
+
+
Codes
+
+ """),
+ ipw.HTML("""
+
+ Select the code to use for running the calculations. The codes on
+ the local machine (localhost) are installed by default, but you can
+ configure new ones on potentially more powerful machines by clicking
+ on "Setup new code".
+
+ """),
+ self.code_widgets_container,
+ self.sssp_installation,
+ self.qe_setup,
+ self.submission_blocker_messages,
+ self.submission_warning_messages,
+ ipw.HTML("""
+
+
Labeling Your Job
+
+ Label your job and provide a brief description. These details
+ help identify the job later and make the search process easier.
+ While optional, adding a description is recommended for better
+ clarity.
+
+
+ """),
+ self.process_label,
+ self.process_description,
+ self.submit_button,
+ ]
+
+ self.rendered = True
- num_cpus = self.pw_code.num_cpus.value * self.pw_code.num_nodes.value
- on_localhost = (
- orm.load_node(self.pw_code.value).computer.hostname == "localhost"
+ # Render and set up default PW code
+ pw_code = self._model.get_code("dft", "pw")
+ pw_code.activate()
+ pw_code_widget = pw_code.get_setup_widget()
+ pw_code_widget.num_cpus.observe(
+ self._on_pw_code_resource_change,
+ "value",
)
- num_sites = len(self.input_structure.sites)
- volume = self.input_structure.get_cell_volume()
- try:
- localhost_cpus = len(os.sched_getaffinity(0))
- except (
- Exception
- ): # fallback, in some OS os.sched_getaffinity(0) is not supported
- localhost_cpus = os.cpu_count() # however, not so realiable in containers.
-
- large_system = (
- num_sites > self.RUN_ON_LOCALHOST_NUM_SITES_WARN_THRESHOLD
- or volume > self.RUN_ON_LOCALHOST_VOLUME_WARN_THRESHOLD
+ pw_code_widget.num_nodes.observe(
+ self._on_pw_code_resource_change,
+ "value",
)
- estimated_CPUs = self._estimate_min_cpus(
- num_sites, volume
- ) # estimated number of CPUs for a run less than 12 hours.
-
- # List of possible suggestions for warnings:
- suggestions = {
- "more_resources": f"Increase the resources (total number of CPUs should be equal or more than {min(100,estimated_CPUs)}, if possible) ",
- "change_configuration": "Review the configuration (e.g. choosing fast protocol - this will affect precision) ",
- "go_remote": "Select a code that runs on a larger machine",
- "avoid_overloading": "Reduce the number of CPUs to avoid the overloading of the local machine ",
- }
-
- alert_message = ""
- if large_system and estimated_CPUs > num_cpus:
- # This part is in common between Warnings 1 (2): (not) on localhost, big system and few cpus
- warnings_1_2 = (
- f"⚠ Warning: The selected structure is large, with {num_sites} atoms "
- f"and a volume of {int(volume)} Å3, "
- "making it computationally demanding "
- "to run at the localhost. Consider the following: "
- if on_localhost
- else "to run in a reasonable amount of time. Consider the following: "
- )
-
- # Warning 1: on localhost, big system and few cpus
- if on_localhost:
- alert_message += (
- warnings_1_2
- + ""
- + suggestions["more_resources"]
- + suggestions["change_configuration"]
- + "
"
- )
- # Warning 2: not on localhost, big system and few cpus
- else:
- alert_message += (
- warnings_1_2
- + ""
- + suggestions["go_remote"]
- + suggestions["more_resources"]
- + suggestions["change_configuration"]
- + "
"
- )
- if on_localhost and num_cpus / localhost_cpus > 0.8:
- # Warning-3: on localhost, more than half of the available cpus
- alert_message += (
- "⚠ Warning: the selected pw.x code will run locally, but "
- f"the number of requested CPUs ({num_cpus}) is larger than the 80% of the available resources ({localhost_cpus}). "
- "Please be sure that your local "
- "environment has enough free CPUs for the calculation. Consider the following: "
- ""
- + suggestions["avoid_overloading"]
- + suggestions["go_remote"]
- + "
"
- )
-
- if not (on_localhost and num_cpus / localhost_cpus) > 0.8 and not (
- large_system and estimated_CPUs > num_cpus
- ):
- self._submission_warning_messages.value = ""
- else:
- self._show_alert_message(
- message=alert_message,
- alert_class="warning",
- )
+ # Render any other active codes
+ self._toggle_code(pw_code)
+ for _, code in self._model.get_codes(flat=True):
+ if code is not pw_code and code.is_active:
+ self._toggle_code(code)
- @tl.observe("state")
- def _observe_state(self, change):
+ def reset(self):
with self.hold_trait_notifications():
- self.submit_button.disabled = change["new"] != self.State.CONFIGURED
+ self._model.reset()
+ self._model.set_selected_codes()
- @tl.observe("previous_step_state", "input_parameters")
- def _observe_input_structure(self, _):
+ @tl.observe("previous_step_state")
+ def _on_previous_step_state_change(self, _):
self._update_state()
- self.update_codes_display()
- self._update_process_label()
- @tl.observe("process")
- def _observe_process(self, change):
+ def _on_input_structure_change(self, _):
+ self._model.check_resources()
+
+ def _on_input_parameters_change(self, _):
+ self._model.update_active_codes()
+ self._model.update_process_label()
+ self._model.update_submission_blockers()
+
+ def _on_process_change(self, _):
with self.hold_trait_notifications():
- process_node = change["new"]
- if process_node is not None:
- self.input_structure = process_node.inputs.structure
+ # TODO why here? Do we not populate traits earlier that would cover this?
+ if self._model.process is not None:
+ self._model.input_structure = self._model.process.inputs.structure
self._update_state()
- def _on_submit_button_clicked(self, _):
- self.submit_button.disabled = True
- self.submit()
-
- def get_selected_codes(self):
- """Get the codes selected in the GUI.
-
- return: A dict with the code names as keys and the code UUIDs as values.
- """
- codes = {
- key: code.parameters
- for key, code in self.codes.items()
- if code.layout.display != "none"
- }
- return codes
-
- def set_selected_codes(self, code_data):
- """Set the inputs in the GUI based on a set of codes."""
-
- # Codes
- def _get_code_uuid(code):
- if code is not None:
- try:
- return orm.load_code(code).uuid
- except NotExistent:
- return None
+ def _on_submission_blockers_change(self, _):
+ self._model.update_submission_blocker_message()
+ self._update_state()
- with self.hold_trait_notifications():
- for name, code in self.codes.items():
- if name not in code_data:
- continue
- # check if the code is installed and usable
- # note: if code is imported from another user, it is not usable and thus will not be
- # treated as an option in the ComputationalResourcesWidget.
- code_options = [
- o[1] for o in code.code_selection.code_select_dropdown.options
- ]
- if _get_code_uuid(code_data.get(name)["code"]) in code_options:
- # get code uuid from code label in case of using DEFAULT_PARAMETERS
- code_data.get(name)["code"] = _get_code_uuid(
- code_data.get(name)["code"]
- )
- code.parameters = code_data.get(name)
-
- def update_codes_display(self):
- """Hide code if no related property is selected."""
- # hide all codes except pw
- for name, code in self.codes.items():
- if name == "pw":
- continue
- code.layout.display = "none"
- properties = self.input_parameters.get("workchain", {}).get("properties", [])
- # show the code if the related property is selected.
- for identifer in properties:
- for code in self.code_entries.get(identifer, {}).values():
- code.layout.display = "block"
-
- def submit(self, _=None):
- """Submit the work chain with the current inputs."""
- from aiida.orm.utils.serialize import serialize
-
- builder = self._create_builder()
+ def _on_installation_change(self, _):
+ self._model.update_submission_blockers()
- with self.hold_trait_notifications():
- process = submit(builder)
-
- process.label = self.process_label.value
- process.description = self.process_description.value
- # since AiiDA data node may exist in the ui_parameters,
- # we serialize it to yaml
- process.base.extras.set("ui_parameters", serialize(self.ui_parameters))
- # store the workchain name in extras, this will help to filter the workchain in the future
- process.base.extras.set("workchain", self.ui_parameters["workchain"])
- process.base.extras.set("structure", self.input_structure.get_formula())
- self.process = process
+ def _on_qe_installed(self, _):
+ self._toggle_qe_installation_widget()
+
+ def _on_sssp_installed(self, _):
+ self._toggle_sssp_installation_widget()
+ def _on_code_activation_change(self, change):
+ self._toggle_code(change["owner"])
+
+ def _on_code_selection_change(self, _):
+ self._model.update_submission_blockers()
+
+ def _on_pw_code_resource_change(self, _):
+ self._model.check_resources()
+
+ def _on_submission(self, _):
+ self._model.submit()
self._update_state()
- def _update_process_label(self) -> dict:
- """Generate a label for the work chain based on the input parameters."""
- if not self.input_structure:
- return ""
- structure_label = (
- self.input_structure.label
- if len(self.input_structure.label) > 0
- else self.input_structure.get_formula()
- )
- workchain_data = self.input_parameters.get("workchain", {"properties": []})
- properties = [p for p in workchain_data["properties"] if p != "relax"]
- # relax_info
- relax_type = workchain_data.get("relax_type", "none")
- relax_info = "unrelaxed"
- if relax_type != "none":
- relax_info = (
- "relax: atoms+cell" if "cell" in relax_type else "relax: atoms only"
- )
- # protocol_info
- protocol_and_magnetic_info = f"{workchain_data['protocol']} protocol"
- # magnetic_info
- if workchain_data["spin_type"] != "none":
- protocol_and_magnetic_info += ", magnetic"
- # properties_info
- properties_info = ""
- if properties:
- properties_info = f"→ {', '.join(properties)}"
-
- label = f"{structure_label} [{relax_info}, {protocol_and_magnetic_info}] {properties_info}".strip()
- self.process_label.value = label
-
- def _create_builder(self) -> ProcessBuilderNamespace:
- """Create the builder for the `QeAppWorkChain` submit."""
- from copy import deepcopy
-
- self.ui_parameters = deepcopy(self.input_parameters)
- # add codes and resource info into ui_parameters
- submission_parameters = self.get_submission_parameters()
- self.ui_parameters.update(submission_parameters)
- builder = QeAppWorkChain.get_builder_from_protocol(
- structure=self.input_structure,
- parameters=deepcopy(self.ui_parameters),
- )
+ def _toggle_sssp_installation_widget(self):
+ sssp_installation_display = "none" if self._model.sssp_installed else "block"
+ self.sssp_installation.layout.display = sssp_installation_display
- self._update_builder(builder, submission_parameters["codes"])
-
- return builder
-
- def _update_builder(self, builder, codes):
- """Update the resources and parallelization of the ``relax`` builder."""
- # update resources
- builder.relax.base.pw.metadata.options.resources = {
- "num_machines": codes.get("pw")["nodes"],
- "num_mpiprocs_per_machine": codes.get("pw")["ntasks_per_node"],
- "num_cores_per_mpiproc": codes.get("pw")["cpus_per_task"],
- }
- builder.relax.base.pw.metadata.options["max_wallclock_seconds"] = codes.get(
- "pw"
- )["max_wallclock_seconds"]
- builder.relax.base.pw.parallelization = orm.Dict(
- dict=codes["pw"]["parallelization"]
- )
+ def _toggle_qe_installation_widget(self):
+ qe_installation_display = "none" if self._model.qe_installed else "block"
+ self.qe_setup.layout.display = qe_installation_display
- def _estimate_min_cpus(
- self, n, v, n0=9, v0=117, num_cpus0=4, t0=129.6, tmax=12 * 60 * 60, scf_cycles=5
- ):
- """
- Estimate the minimum number of CPUs required to complete a task within a given time limit.
- Parameters:
- n (int): The number of atoms in the system.
- v (float): The volume of the system.
- n0 (int, optional): Reference number of atoms. Default is 9.
- v0 (float, optional): Reference volume. Default is 117.
- num_cpus0 (int, optional): Reference number of CPUs. Default is 4.
- scf_cycles (int, optional): Reference number of SCF cycles in a relaxation. Default is 5.
-
- NB: Defaults (a part scf_cycles) are taken from a calculation done for SiO2. This is just a dummy
- and not well tested estimation, placeholder for a more rigourous one.
- """
- import numpy as np
-
- return int(
- np.ceil(
- scf_cycles * num_cpus0 * (n / n0) ** 3 * (v / v0) ** 1.5 * t0 / tmax
- )
+ def _toggle_code(self, code: CodeModel):
+ if not self.rendered:
+ return
+ if not code.is_rendered:
+ loading_message = LoadingWidget(f"Loading {code.name} code")
+ self.code_widgets_container.children += (loading_message,)
+ code_widget = code.get_setup_widget()
+ code_widget.layout.display = "block" if code.is_active else "none"
+ if not code.is_rendered:
+ self._render_code_widget(code, code_widget)
+
+ def _render_code_widget(self, code, code_widget):
+ ipw.dlink(
+ (code_widget.code_selection.code_select_dropdown, "options"),
+ (code, "options"),
)
+ ipw.dlink(
+ (code_widget, "value"),
+ (code, "selected"),
+ lambda value: value is not None,
+ )
+ code.observe(
+ lambda change: setattr(code_widget, "parameters", change["new"]),
+ "parameters",
+ )
+ code_widgets = self.code_widgets_container.children[:-1] # type: ignore
+ self.code_widgets_container.children = [*code_widgets, code_widget]
+ self._model.code_widgets[code.name] = code_widget
+ self._model.set_selected_codes() # TODO check logic
+ code.is_rendered = True
- def set_submission_parameters(self, parameters):
- # backward compatibility for v2023.11
- # which have a separate "resources" section for pw code
- if "resources" in parameters:
- parameters["codes"] = {
- key: {"code": value} for key, value in parameters["codes"].items()
- }
- parameters["codes"]["pw"]["nodes"] = parameters["resources"]["num_machines"]
- parameters["codes"]["pw"]["cpus"] = parameters["resources"][
- "num_mpiprocs_per_machine"
- ]
- parameters["codes"]["pw"]["parallelization"] = {
- "npool": parameters["resources"]["npools"]
- }
- self.set_selected_codes(parameters["codes"])
- # label and description are not stored in the parameters, but in the process directly
- if self.process:
- self.process_label.value = self.process.label
- self.process_description.value = self.process.description
-
- def get_submission_parameters(self):
- """Get the parameters for the submission step."""
- return {
- "codes": self.get_selected_codes(),
- }
-
- def reset(self):
- """Reset the widget to its initial state."""
- with self.hold_trait_notifications():
- self.process = None
- self.input_structure = None
- self.set_selected_codes(DEFAULT_PARAMETERS["codes"])
+ def _update_state(self, _=None):
+ if self.previous_step_state is self.State.FAIL:
+ self.state = self.State.FAIL
+ elif self.previous_step_state is not self.State.SUCCESS:
+ self.state = self.State.INIT
+ elif self._model.process is not None:
+ self.state = self.State.SUCCESS
+ elif self._model.is_blocked:
+ self.state = self.State.READY
+ else:
+ self.state = self.state.CONFIGURED
diff --git a/src/aiidalab_qe/app/submission/code/__init__.py b/src/aiidalab_qe/app/submission/code/__init__.py
new file mode 100644
index 000000000..117e94cdf
--- /dev/null
+++ b/src/aiidalab_qe/app/submission/code/__init__.py
@@ -0,0 +1,7 @@
+from .model import CodeModel, CodesDict, PluginCodes
+
+__all__ = [
+ "CodeModel",
+ "CodesDict",
+ "PluginCodes",
+]
diff --git a/src/aiidalab_qe/app/submission/code/model.py b/src/aiidalab_qe/app/submission/code/model.py
new file mode 100644
index 000000000..eb87141ce
--- /dev/null
+++ b/src/aiidalab_qe/app/submission/code/model.py
@@ -0,0 +1,51 @@
+import traitlets as tl
+
+from aiidalab_qe.common.widgets import QEAppComputationalResourcesWidget
+
+
+class CodeModel(tl.HasTraits):
+ is_active = tl.Bool(False)
+ selected = tl.Bool(False)
+ options = tl.List(
+ trait=tl.Tuple(tl.Unicode(), tl.Unicode()), # code option (label, uuid)
+ default_value=[],
+ )
+ parameters = tl.Dict()
+
+ def __init__(
+ self,
+ *,
+ name="pw",
+ description,
+ default_calc_job_plugin,
+ setup_widget_class=QEAppComputationalResourcesWidget,
+ ):
+ self.name = name
+ self.description = description
+ self.default_calc_job_plugin = default_calc_job_plugin
+ self.setup_widget_class = setup_widget_class
+ self.is_rendered = False
+ self.is_loaded = False
+
+ @property
+ def is_ready(self):
+ return self.is_loaded and self.is_active and self.selected
+
+ def activate(self):
+ self.is_active = True
+
+ def deactivate(self):
+ self.is_active = False
+
+ def get_setup_widget(self) -> QEAppComputationalResourcesWidget:
+ if not self.is_loaded:
+ self._setup_widget = self.setup_widget_class(
+ description=self.description,
+ default_calc_job_plugin=self.default_calc_job_plugin,
+ )
+ self.is_loaded = True
+ return self._setup_widget
+
+
+CodesDict = dict[str, CodeModel]
+PluginCodes = dict[str, CodesDict]
diff --git a/src/aiidalab_qe/app/submission/model.py b/src/aiidalab_qe/app/submission/model.py
new file mode 100644
index 000000000..2789e6565
--- /dev/null
+++ b/src/aiidalab_qe/app/submission/model.py
@@ -0,0 +1,424 @@
+from __future__ import annotations
+
+import os
+import typing as t
+from copy import deepcopy
+
+import traitlets as tl
+
+from aiida import orm
+from aiida.common import NotExistent
+from aiida.engine import ProcessBuilderNamespace
+from aiida.engine import submit as aiida_submit
+from aiida.orm.utils.serialize import serialize
+from aiida_quantumespresso.data.hubbard_structure import HubbardStructureData
+from aiidalab_qe.app.parameters import DEFAULT_PARAMETERS
+from aiidalab_qe.common.widgets import QEAppComputationalResourcesWidget
+from aiidalab_qe.workflows import QeAppWorkChain
+
+from .code import CodeModel, CodesDict
+
+DEFAULT: dict = DEFAULT_PARAMETERS # type: ignore
+
+
+class SubmissionModel(tl.HasTraits):
+ input_structure = tl.Union(
+ [
+ tl.Instance(orm.StructureData),
+ tl.Instance(HubbardStructureData),
+ ],
+ allow_none=True,
+ )
+ input_parameters = tl.Dict()
+
+ process = tl.Instance(orm.WorkChainNode, allow_none=True)
+ process_label = tl.Unicode("")
+ process_description = tl.Unicode("")
+
+ submission_blocker_messages = tl.Unicode("")
+ submission_warning_messages = tl.Unicode("")
+
+ installing_qe = tl.Bool(False)
+ installing_sssp = tl.Bool(False)
+ qe_installed = tl.Bool(allow_none=True)
+ sssp_installed = tl.Bool(allow_none=True)
+
+ codes = tl.Dict(
+ key_trait=tl.Unicode(), # plugin identifier
+ value_trait=tl.Dict( # plugin codes
+ key_trait=tl.Unicode(), # code name
+ value_trait=tl.Instance(CodeModel), # code metadata
+ ),
+ default_value={},
+ )
+
+ internal_submission_blockers = tl.List(tl.Unicode())
+ external_submission_blockers = tl.List(tl.Unicode())
+
+ code_widgets: dict[str, QEAppComputationalResourcesWidget] = {}
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ self._RUN_ON_LOCALHOST_NUM_SITES_WARN_THRESHOLD = 10
+ self._RUN_ON_LOCALHOST_VOLUME_WARN_THRESHOLD = 1000 # \AA^3
+
+ self._ALERT_MESSAGE = """
+
+ """
+
+ @property
+ def is_blocked(self):
+ return any(
+ [
+ *self.internal_submission_blockers,
+ *self.external_submission_blockers,
+ ]
+ )
+
+ def submit(self):
+ parameters = self._get_submission_parameters()
+ builder = self._create_builder(parameters)
+
+ with self.hold_trait_notifications():
+ process = aiida_submit(builder)
+
+ process.label = self.process_label
+ process.description = self.process_description
+ # since AiiDA data node may exist in the ui_parameters,
+ # we serialize it to yaml
+ process.base.extras.set("ui_parameters", serialize(parameters))
+ # store the workchain name in extras, this will help to filter the workchain in the future
+ process.base.extras.set("workchain", parameters["workchain"]) # type: ignore
+ process.base.extras.set(
+ "structure",
+ self.input_structure.get_formula(),
+ )
+ self.process = process
+
+ def check_resources(self):
+ pw_code_model = self.get_code("dft", "pw")
+
+ if not self.input_structure or not pw_code_model.selected:
+ return # No code selected or no structure, so nothing to do
+
+ pw_code = pw_code_model.get_setup_widget()
+ num_cpus = pw_code.num_cpus.value * pw_code.num_nodes.value
+ on_localhost = orm.load_node(pw_code.value).computer.hostname == "localhost"
+ num_sites = len(self.input_structure.sites)
+ volume = self.input_structure.get_cell_volume()
+
+ try:
+ localhost_cpus = len(os.sched_getaffinity(0))
+ except Exception:
+ # Fallback, in some OS os.sched_getaffinity(0) is not supported
+ # However, not so reliable in containers
+ localhost_cpus = os.cpu_count()
+
+ large_system = (
+ num_sites > self._RUN_ON_LOCALHOST_NUM_SITES_WARN_THRESHOLD
+ or volume > self._RUN_ON_LOCALHOST_VOLUME_WARN_THRESHOLD
+ )
+
+ # Estimated number of CPUs for a run less than 12 hours.
+ estimated_CPUs = self._estimate_min_cpus(num_sites, volume)
+
+ # List of possible suggestions for warnings:
+ suggestions = {
+ "more_resources": f"Increase the resources (total number of CPUs should be equal or more than {min(100,estimated_CPUs)}, if possible) ",
+ "change_configuration": "Review the configuration (e.g. choosing fast protocol - this will affect precision) ",
+ "go_remote": "Select a code that runs on a larger machine",
+ "avoid_overloading": "Reduce the number of CPUs to avoid the overloading of the local machine ",
+ }
+
+ alert_message = ""
+ if large_system and estimated_CPUs > num_cpus:
+ # This part is in common between Warnings 1 (2):
+ # (not) on localhost, big system and few cpus
+ warnings_1_2 = (
+ f"⚠ Warning: The selected structure is large, with {num_sites} atoms "
+ f"and a volume of {int(volume)} Å3, "
+ "making it computationally demanding "
+ "to run at the localhost. Consider the following: "
+ if on_localhost
+ else "to run in a reasonable amount of time. Consider the following: "
+ )
+ # Warning 1: on localhost, big system and few cpus
+ alert_message += (
+ f"{warnings_1_2}"
+ + suggestions["more_resources"]
+ + suggestions["change_configuration"]
+ + "
"
+ if on_localhost
+ else f"{warnings_1_2}"
+ + suggestions["go_remote"]
+ + suggestions["more_resources"]
+ + suggestions["change_configuration"]
+ + "
"
+ )
+ if on_localhost and num_cpus / localhost_cpus > 0.8:
+ # Warning-3: on localhost, more than half of the available cpus
+ alert_message += (
+ "⚠ Warning: the selected pw.x code will run locally, but "
+ f"the number of requested CPUs ({num_cpus}) is larger than the 80% of the available resources ({localhost_cpus}). "
+ "Please be sure that your local "
+ "environment has enough free CPUs for the calculation. Consider the following: "
+ ""
+ + suggestions["avoid_overloading"]
+ + suggestions["go_remote"]
+ + "
"
+ )
+
+ self.submission_warning_messages = (
+ ""
+ if (on_localhost and num_cpus / localhost_cpus) <= 0.8
+ and (not large_system or estimated_CPUs <= num_cpus)
+ else self._ALERT_MESSAGE.format(
+ alert_class="warning",
+ message=alert_message,
+ )
+ )
+
+ def update_active_codes(self):
+ for name, code in self.get_codes(flat=True):
+ if name != "pw":
+ code.deactivate()
+ properties = self.get_properties()
+ for identifier, codes in self.get_codes():
+ if identifier in properties:
+ for code in codes.values():
+ code.activate()
+
+ def update_process_label(self):
+ if not self.input_structure:
+ self.process_label = ""
+ return
+ structure_label = (
+ self.input_structure.label
+ if len(self.input_structure.label) > 0
+ else self.input_structure.get_formula()
+ )
+ workchain_data = self.input_parameters.get(
+ "workchain",
+ {"properties": []},
+ )
+ properties = [p for p in workchain_data["properties"] if p != "relax"]
+ relax_type = workchain_data.get("relax_type", "none")
+ relax_info = "unrelaxed"
+ if relax_type != "none":
+ relax_info = (
+ "relax: atoms+cell" if "cell" in relax_type else "relax: atoms only"
+ )
+ protocol_and_magnetic_info = f"{workchain_data['protocol']} protocol"
+ if workchain_data["spin_type"] != "none":
+ protocol_and_magnetic_info += ", magnetic"
+ properties_info = f"→ {', '.join(properties)}" if properties else ""
+ label = f"{structure_label} [{relax_info}, {protocol_and_magnetic_info}] {properties_info}".strip()
+ self.process_label = label
+
+ def update_submission_blockers(self):
+ self.internal_submission_blockers = list(self._check_submission_blockers())
+
+ def update_submission_blocker_message(self):
+ blockers = self.internal_submission_blockers + self.external_submission_blockers
+ if any(blockers):
+ fmt_list = "\n".join(f"{item}" for item in sorted(blockers))
+ self.submission_blocker_messages = f"""
+
+
The submission is blocked due to the following reason(s):
+
+
+ """
+ else:
+ self.submission_blocker_messages = ""
+
+ def get_properties(self) -> list[str]:
+ return self.input_parameters.get("workchain", {}).get("properties", [])
+
+ def get_model_state(self) -> dict[str, dict[str, dict]]:
+ return {
+ "codes": self.get_selected_codes(),
+ }
+
+ def set_model_state(self, parameters):
+ if "resources" in parameters:
+ parameters["codes"] = {
+ key: {"code": value} for key, value in parameters["codes"].items()
+ }
+ parameters["codes"]["pw"]["nodes"] = parameters["resources"]["num_machines"]
+ parameters["codes"]["pw"]["cpus"] = parameters["resources"][
+ "num_mpiprocs_per_machine"
+ ]
+ parameters["codes"]["pw"]["parallelization"] = {
+ "npool": parameters["resources"]["npools"]
+ }
+ self.set_selected_codes(parameters["codes"])
+ if self.process:
+ self.process_label = self.process.label
+ self.process_description = self.process.description
+
+ def add_code(self, identifier, name, code):
+ code.name = name
+ if identifier not in self.codes:
+ self.codes[identifier] = {} # type: ignore
+ self.codes[identifier][name] = code # type: ignore
+
+ def get_code(self, identifier, name) -> CodeModel | None:
+ if identifier in self.codes and name in self.codes[identifier]: # type: ignore
+ return self.codes[identifier][name] # type: ignore
+
+ def get_codes(self, flat=False) -> t.Iterator[tuple[str, CodesDict | CodeModel]]:
+ if flat:
+ for codes in self.codes.values():
+ yield from codes.items()
+ else:
+ yield from self.codes.items()
+
+ def get_selected_codes(self) -> dict[str, dict]:
+ return {
+ name: code.parameters
+ for name, code in self.get_codes(flat=True)
+ if code.is_ready
+ } # type: ignore
+
+ def set_selected_codes(self, code_data=DEFAULT["codes"]):
+ def get_code_uuid(code):
+ if code is not None:
+ try:
+ return orm.load_code(code).uuid
+ except NotExistent:
+ return None
+
+ with self.hold_trait_notifications():
+ for name, code in self.get_codes(flat=True):
+ if name in code_data:
+ parameters = code_data[name]
+ code_uuid = get_code_uuid(parameters["code"])
+ if code_uuid in [opt[1] for opt in code.options]:
+ parameters["code"] = code_uuid
+ code.parameters = parameters
+
+ def reset(self):
+ with self.hold_trait_notifications():
+ self.input_structure = None
+ self.input_parameters = {}
+ self.process = None
+
+ def _create_builder(self, parameters) -> ProcessBuilderNamespace:
+ builder = QeAppWorkChain.get_builder_from_protocol(
+ structure=self.input_structure,
+ parameters=deepcopy(parameters), # TODO why deepcopy again?
+ )
+
+ codes = parameters["codes"]
+
+ builder.relax.base.pw.metadata.options.resources = {
+ "num_machines": codes.get("pw")["nodes"],
+ "num_mpiprocs_per_machine": codes.get("pw")["ntasks_per_node"],
+ "num_cores_per_mpiproc": codes.get("pw")["cpus_per_task"],
+ }
+ mws = codes.get("pw")["max_wallclock_seconds"]
+ builder.relax.base.pw.metadata.options["max_wallclock_seconds"] = mws
+ parallelization = codes["pw"]["parallelization"]
+ builder.relax.base.pw.parallelization = orm.Dict(dict=parallelization)
+
+ return builder
+
+ def _get_submission_parameters(self) -> dict:
+ submission_parameters = self.get_model_state()
+ for name, code_widget in self.code_widgets.items():
+ if name in submission_parameters["codes"]:
+ for key, value in code_widget.parameters.items():
+ if key != "code":
+ submission_parameters["codes"][name][key] = value
+ parameters = deepcopy(self.input_parameters)
+ parameters.update(submission_parameters)
+ return parameters # type: ignore
+
+ def _check_submission_blockers(self):
+ # Do not submit while any of the background setup processes are running.
+ if self.installing_qe or self.installing_sssp:
+ yield "Background setup processes must finish."
+
+ # SSSP library not installed
+ if not self.sssp_installed:
+ yield "The SSSP library is not installed."
+
+ # No pw code selected (this is ignored while the setup process is running).
+ pw_code = self.get_code(identifier="dft", name="pw")
+ if pw_code and not pw_code.selected and not self.installing_qe:
+ yield ("No pw code selected")
+
+ # code related to the selected property is not installed
+ properties = self.get_properties()
+ message = "Calculating the {property} property requires code {code} to be set."
+ for identifier, codes in self.get_codes():
+ if identifier in properties:
+ for code in codes.values():
+ if not code.is_ready:
+ yield message.format(property=identifier, code=code.description)
+
+ # check if the QEAppComputationalResourcesWidget is used
+ for name, code in self.get_codes(flat=True):
+ # skip if the code is not displayed, convenient for the plugin developer
+ if not code.is_ready:
+ continue
+ if not issubclass(
+ code.setup_widget_class, QEAppComputationalResourcesWidget
+ ):
+ yield (
+ f"Error: hi, plugin developer, please use the QEAppComputationalResourcesWidget from aiidalab_qe.common.widgets for code {name}."
+ )
+
+ def _estimate_min_cpus(
+ self,
+ n,
+ v,
+ n0=9,
+ v0=117,
+ num_cpus0=4,
+ t0=129.6,
+ tmax=12 * 60 * 60,
+ scf_cycles=5,
+ ):
+ """Estimate the minimum number of CPUs required to
+ complete a task within a given time limit.
+
+ Parameters
+ ----------
+ `n` : `int`
+ The number of atoms in the system.
+ `v` : `float`
+ The volume of the system.
+ `n0` : `int`, optional
+ Reference number of atoms. Default is 9.
+ `v0` : `float`, optional
+ Reference volume. Default is 117.
+ `num_cpus0` : `int`, optional
+ Reference number of CPUs. Default is 4.
+ `t0` : `float`, optional
+ Reference time. Default is 129.6.
+ `tmax` : `float`, optional
+ Maximum time limit. Default is 12 hours.
+ `scf_cycles` : `int`, optional
+ Reference number of SCF cycles in a relaxation. Default is 5.
+
+ Returns
+ -------
+ `int`
+ The estimated minimum number of CPUs required.
+ """
+ import numpy as np
+
+ return int(
+ np.ceil(
+ scf_cycles * num_cpus0 * (n / n0) ** 3 * (v / v0) ** 1.5 * t0 / tmax
+ )
+ )
diff --git a/src/aiidalab_qe/app/utils/__init__.py b/src/aiidalab_qe/app/utils/__init__.py
index a62d8c148..ca61694f6 100644
--- a/src/aiidalab_qe/app/utils/__init__.py
+++ b/src/aiidalab_qe/app/utils/__init__.py
@@ -15,12 +15,14 @@ def print_error(entry_point, e):
# load entry points
def get_entries(entry_point_name="aiidalab_qe.properties"):
- from importlib.metadata import entry_points
+ from importlib_metadata import entry_points
entries = {}
- for entry_point in entry_points().get(entry_point_name, []):
+ for entry_point in entry_points(group=entry_point_name):
try:
# Attempt to load the entry point
+ if entry_point.name in entries:
+ continue
loaded_entry_point = entry_point.load()
entries[entry_point.name] = loaded_entry_point
except Exception as e:
diff --git a/src/aiidalab_qe/app/utils/search_jobs.py b/src/aiidalab_qe/app/utils/search_jobs.py
index 4288630a8..7c76d0c1e 100644
--- a/src/aiidalab_qe/app/utils/search_jobs.py
+++ b/src/aiidalab_qe/app/utils/search_jobs.py
@@ -1,18 +1,20 @@
+import ipywidgets as ipw
+import pandas as pd
+from IPython.display import display
+
+from aiida.orm import QueryBuilder
+
+
class QueryInterface:
def __init__(self):
pass
def setup_table(self):
- import ipywidgets as ipw
-
self.df = self.load_data()
self.table = ipw.HTML()
self.setup_widgets()
def load_data(self):
- import pandas as pd
-
- from aiida.orm import QueryBuilder
from aiidalab_qe.workflows import QeAppWorkChain
projections = [
@@ -72,8 +74,6 @@ def load_data(self):
]
def setup_widgets(self):
- import ipywidgets as ipw
-
self.css_style = """