Skip to content

Commit

Permalink
Glorious refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
edan-bainglass committed Sep 20, 2024
1 parent a201ea1 commit 9b6d792
Show file tree
Hide file tree
Showing 7 changed files with 226 additions and 239 deletions.
182 changes: 52 additions & 130 deletions src/aiidalab_qe/app/configuration/advanced.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,7 @@
"""

import ipywidgets as ipw
import numpy as np
import traitlets as tl

from aiida import orm
from aiida_quantumespresso.calculations.functions.create_kpoints_from_distance import (
create_kpoints_from_distance,
)
from aiida_quantumespresso.data.hubbard_structure import HubbardStructureData
from aiida_quantumespresso.workflows.pw.base import PwBaseWorkChain

from aiidalab_qe.common.panel import Panel
from aiidalab_qe.setup.pseudos import PseudoFamily

Expand All @@ -26,9 +18,6 @@
class AdvancedSettings(Panel):
identifier = "advanced"

protocol = tl.Unicode(allow_none=True)
input_structure = tl.Instance(orm.StructureData, allow_none=True)

def __init__(self, **kwargs):
from aiidalab_qe.common.widgets import LoadingWidget

Expand Down Expand Up @@ -63,7 +52,7 @@ def render(self):
layout=ipw.Layout(max_width="20px"),
)
ipw.link(
(model, "clean_workdir"),
(model.advanced, "clean_workdir"),
(self.clean_workdir, "value"),
)
# Override setting widget
Expand All @@ -73,7 +62,7 @@ def render(self):
layout=ipw.Layout(max_width="10%"),
)
ipw.link(
(model, "override"),
(model.advanced, "override"),
(self.override, "value"),
)
ipw.dlink(
Expand All @@ -95,15 +84,23 @@ def render(self):
style={"description_width": "initial"},
)
ipw.link(
(model, "kpoints_distance"),
(model.advanced, "kpoints_distance"),
(self.kpoints_distance, "value"),
)
self.mesh_grid = ipw.HTML()
ipw.dlink(
(self.override, "value"),
(self.kpoints_distance, "disabled"),
lambda override: not override,
)
self.kpoints_distance.observe(
self._on_kpoints_distance_change,
"value",
)
self.mesh_grid = ipw.HTML()
ipw.dlink(
(model.advanced, "mesh_grid"),
(self.mesh_grid, "value"),
)

# Hubbard setting widget
self.hubbard.render()
Expand All @@ -118,7 +115,7 @@ def render(self):
style={"description_width": "initial"},
)
ipw.link(
(model, "total_charge"),
(model.advanced, "total_charge"),
(self.total_charge, "value"),
)
ipw.dlink(
Expand All @@ -142,7 +139,7 @@ def render(self):
style={"description_width": "initial"},
)
ipw.link(
(model, "van_der_waals"),
(model.advanced, "van_der_waals"),
(self.van_der_waals, "value"),
)
ipw.dlink(
Expand All @@ -163,11 +160,11 @@ def render(self):
style={"description_width": "initial"},
)
ipw.link(
(model, "scf_conv_thr"),
(model.advanced, "scf_conv_thr"),
(self.scf_conv_thr, "value"),
)
ipw.dlink(
(model, "scf_conv_thr_step"),
(model.advanced, "scf_conv_thr_step"),
(self.scf_conv_thr, "step"),
)
ipw.dlink(
Expand All @@ -183,11 +180,11 @@ def render(self):
style={"description_width": "initial"},
)
ipw.link(
(model, "forc_conv_thr"),
(model.advanced, "forc_conv_thr"),
(self.forc_conv_thr, "value"),
)
ipw.dlink(
(model, "forc_conv_thr_step"),
(model.advanced, "forc_conv_thr_step"),
(self.forc_conv_thr, "step"),
)
ipw.dlink(
Expand All @@ -203,11 +200,11 @@ def render(self):
style={"description_width": "initial"},
)
ipw.link(
(model, "etot_conv_thr"),
(model.advanced, "etot_conv_thr"),
(self.etot_conv_thr, "value"),
)
ipw.dlink(
(model, "etot_conv_thr_step"),
(model.advanced, "etot_conv_thr_step"),
(self.etot_conv_thr, "step"),
)
ipw.dlink(
Expand All @@ -226,7 +223,7 @@ def render(self):
style={"description_width": "initial"},
)
ipw.link(
(model, "spin_orbit"),
(model.advanced, "spin_orbit"),
(self.spin_orbit, "value"),
)
ipw.dlink(
Expand All @@ -237,8 +234,6 @@ def render(self):

self.pseudos.render()

self.kpoints_distance.observe(self._display_mesh, "value")

self.children = [
ipw.HTML("""
<div style="padding-top: 0px; padding-bottom: 10px">
Expand Down Expand Up @@ -307,33 +302,16 @@ def render(self):
self.pseudos,
]

with self.hold_trait_notifications():
ipw.dlink(
(model, "input_structure"),
(self, "input_structure"),
)
ipw.dlink(
(model, "protocol"),
(self, "protocol"),
)

self.rendered = True

def reset(self):
"""Reset the widget and the traitlets"""

with self.hold_trait_notifications():
model.protocol = model.traits()["protocol"].default_value
model.total_charge = model.traits()["total_charge"].default_value
model.van_der_waals = model.traits()["van_der_waals"].default_value
model.spin_orbit = model.traits()["spin_orbit"].default_value
model.override = False
model.advanced.reset()
self.smearing.reset()
self.hubbard.reset()
self.magnetization.reset()
self.pseudos.reset()
if model.input_structure is None:
model.mesh_grid = " "
model.update_from_protocol()

def get_panel_value(self):
# create the the initial_magnetic_moments as None (Default)
Expand All @@ -343,20 +321,20 @@ def get_panel_value(self):
"pw": {
"parameters": {
"SYSTEM": {
"tot_charge": model.total_charge,
"tot_charge": model.advanced.total_charge,
},
"CONTROL": {
"forc_conv_thr": model.forc_conv_thr,
"etot_conv_thr": model.etot_conv_thr,
"forc_conv_thr": model.advanced.forc_conv_thr,
"etot_conv_thr": model.advanced.etot_conv_thr,
},
"ELECTRONS": {
"conv_thr": model.scf_conv_thr,
"conv_thr": model.advanced.scf_conv_thr,
},
}
},
"clean_workdir": model.clean_workdir,
"clean_workdir": model.advanced.clean_workdir,
"pseudo_family": model.pseudos.family, # TODO check this
"kpoints_distance": model.kpoints_distance,
"kpoints_distance": model.advanced.kpoints_distance,
}

if model.hubbard.activate:
Expand All @@ -366,46 +344,48 @@ def get_panel_value(self):
{"starting_ns_eigenvalue": model.hubbard.eigenvalues}
)

if model.pseudos.dict:
parameters["pw"]["pseudos"] = model.pseudos.dict
if model.pseudos.dictionary:
parameters["pw"]["pseudos"] = model.pseudos.dictionary
parameters["pw"]["parameters"]["SYSTEM"]["ecutwfc"] = model.pseudos.ecutwfc
parameters["pw"]["parameters"]["SYSTEM"]["ecutrho"] = model.pseudos.ecutrho

if model.van_der_waals in ["none", "ts-vdw"]:
parameters["pw"]["parameters"]["SYSTEM"]["vdw_corr"] = model.van_der_waals
if model.advanced.van_der_waals in ["none", "ts-vdw"]:
parameters["pw"]["parameters"]["SYSTEM"]["vdw_corr"] = (
model.advanced.van_der_waals
)
else:
parameters["pw"]["parameters"]["SYSTEM"]["vdw_corr"] = "dft-d3"
parameters["pw"]["parameters"]["SYSTEM"]["dftd3_version"] = (
self.dftd3_version[model.van_der_waals]
self.dftd3_version[model.advanced.van_der_waals]
)

# there are two choose, use link or parent
if model.spin_type == "collinear":
if model.basic.spin_type == "collinear":
parameters["initial_magnetic_moments"] = model.magnetization.moments
if model.electronic_type == "metal":
if model.basic.electronic_type == "metal":
# smearing type setting
parameters["pw"]["parameters"]["SYSTEM"]["smearing"] = model.smearing.type
# smearing degauss setting
parameters["pw"]["parameters"]["SYSTEM"]["degauss"] = model.smearing.degauss

# Set tot_magnetization for collinear simulations.
if model.spin_type == "collinear":
if model.basic.spin_type == "collinear":
# Conditions for metallic systems. Select the magnetization type and set the value if override is True
if model.electronic_type == "metal" and model.override:
if model.basic.electronic_type == "metal" and model.advanced.override:
if model.magnetization.type == "tot_magnetization":
parameters["pw"]["parameters"]["SYSTEM"]["tot_magnetization"] = (
model.magnetization.total
)
else:
parameters["initial_magnetic_moments"] = model.magnetization.moments
# Conditions for insulator systems. Default value is 0.0
elif model.electronic_type == "insulator":
elif model.basic.electronic_type == "insulator":
parameters["pw"]["parameters"]["SYSTEM"]["tot_magnetization"] = (
model.magnetization.total
)

# Spin-Orbit calculation
if model.spin_orbit == "soc":
if model.advanced.spin_orbit == "soc":
parameters["pw"]["parameters"]["SYSTEM"]["lspinorb"] = True
parameters["pw"]["parameters"]["SYSTEM"]["noncolin"] = True
parameters["pw"]["parameters"]["SYSTEM"]["nspin"] = 4
Expand All @@ -426,37 +406,37 @@ def set_panel_value(self, parameters):
model.pseudos.ecutwfc = parameters["pw"]["parameters"]["SYSTEM"]["ecutwfc"]
model.pseudos.ecutrho = parameters["pw"]["parameters"]["SYSTEM"]["ecutrho"]
#
model.kpoints_distance = parameters.get("kpoints_distance", 0.15)
model.advanced.kpoints_distance = parameters.get("kpoints_distance", 0.15)
if parameters.get("pw") is not None:
system = parameters["pw"]["parameters"]["SYSTEM"]
if "degauss" in system:
model.smearing.degauss = system["degauss"]
if "smearing" in system:
model.smearing.type = system["smearing"]
model.total_charge = parameters["pw"]["parameters"]["SYSTEM"].get(
model.advanced.total_charge = parameters["pw"]["parameters"]["SYSTEM"].get(
"tot_charge", 0
)
model.spin_orbit = "soc" if "lspinorb" in system else "wo_soc"
model.advanced.spin_orbit = "soc" if "lspinorb" in system else "wo_soc"
# van der waals correction
model.van_der_waals = self.dftd3_version.get(
model.advanced.van_der_waals = self.dftd3_version.get(
system.get("dftd3_version"),
parameters["pw"]["parameters"]["SYSTEM"].get("vdw_corr", "none"),
)

# convergence threshold setting
model.forc_conv_thr = (
model.advanced.forc_conv_thr = (
parameters.get("pw", {})
.get("parameters", {})
.get("CONTROL", {})
.get("forc_conv_thr", 0.0)
)
model.etot_conv_thr = (
model.advanced.etot_conv_thr = (
parameters.get("pw", {})
.get("parameters", {})
.get("CONTROL", {})
.get("etot_conv_thr", 0.0)
)
model.scf_conv_thr = (
model.advanced.scf_conv_thr = (
parameters.get("pw", {})
.get("parameters", {})
.get("ELECTRONS", {})
Expand Down Expand Up @@ -488,67 +468,9 @@ def set_panel_value(self, parameters):
model.hubbard.eigenvalues_label = True
model.hubbard.eigenvalues = starting_ns_eigenvalue

@tl.observe("protocol")
def _on_protocol_change(self, _=None):
self._update_model()

@tl.observe("input_structure")
def _on_input_structure_change(self, _):
if model.input_structure:
self._update_model()
self._display_mesh()
if isinstance(model.input_structure, HubbardStructureData):
model.override = True

def _on_override_change(self, change):
if not change["new"]:
self.reset()

def _update_model(self):
"""Update the model from the given protocol."""

def set_value_and_step(attribute, value):
"""Sets the value and step size.
Parameters:
attribute (str):
The attribute whose values are to be set (e.g., self.etot_conv_thr).
value (float):
The numerical value to set.
"""
setattr(model, attribute, value)
if value != 0:
order_of_magnitude = np.floor(np.log10(abs(value)))
setattr(model, f"{attribute}_step", 10 ** (order_of_magnitude - 1))
else:
setattr(model, f"{attribute}_step", 0.1)

parameters = PwBaseWorkChain.get_protocol_inputs(model.protocol)

model.kpoints_distance = parameters["kpoints_distance"]

num_atoms = len(model.input_structure.sites) if model.input_structure else 1

etot_value = num_atoms * parameters["meta_parameters"]["etot_conv_thr_per_atom"]
set_value_and_step("etot_conv_thr", etot_value)

scf_value = num_atoms * parameters["meta_parameters"]["conv_thr_per_atom"]
set_value_and_step("scf_conv_thr", scf_value)

forc_value = parameters["pw"]["parameters"]["CONTROL"]["forc_conv_thr"]
set_value_and_step("forc_conv_thr", forc_value)

def _display_mesh(self, _=None):
if model.input_structure is None:
return
if model.kpoints_distance > 0:
# To avoid creating an aiida node every time we change the kpoints_distance,
# we use the function itself instead of the decorated calcfunction.
mesh = create_kpoints_from_distance.process_class._func(
model.input_structure,
orm.Float(model.kpoints_distance),
orm.Bool(False),
)
model.mesh_grid = "Mesh " + str(mesh.get_kpoints_mesh()[0])
else:
model.mesh_grid = "Please select a number higher than 0.0"
def _on_kpoints_distance_change(self, change):
model.advanced.update_kpoints_mesh(model.input_structure, change["new"])
2 changes: 1 addition & 1 deletion src/aiidalab_qe/app/configuration/hubbard.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def render(self):
(self.activate_hubbard, "value"),
)
ipw.dlink(
(model, "override"),
(model.advanced, "override"),
(self.activate_hubbard, "disabled"),
lambda override: not override,
)
Expand Down
Loading

0 comments on commit 9b6d792

Please sign in to comment.