Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

change class names #45

Merged
merged 4 commits into from
Aug 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions abics/applications/latgas_abinitio_interface/aenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def from_XSF(input_string):
return s


class aenetSolver(SolverBase):
class AenetSolver(SolverBase):
"""
This class defines the aenet solver.
"""
Expand All @@ -138,10 +138,10 @@ def __init__(self, path_to_solver, ignore_species=None, run_scheme="subprocess")
path_to_solver : str
Path to the solver.
"""
super(aenetSolver, self).__init__(path_to_solver)
super(AenetSolver, self).__init__(path_to_solver)
self.path_to_solver = path_to_solver
self.input = aenetSolver.Input(ignore_species, run_scheme)
self.output = aenetSolver.Output()
self.input = AenetSolver.Input(ignore_species, run_scheme)
self.output = AenetSolver.Output()

def name(self):
return "aenet"
Expand Down
61 changes: 28 additions & 33 deletions abics/applications/latgas_abinitio_interface/default_observer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,17 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see http://www.gnu.org/licenses/.

from abics.mc import observer_base
from ...util import expand_path
from typing import Tuple

import os
import numpy as np

from abics import __version__
from abics.util import expand_path
from abics.mc import ObserverBase, MCAlgorithm

class default_observer(observer_base):

class DefaultObserver(ObserverBase):
"""
Default observer.

Expand All @@ -30,6 +33,7 @@ class default_observer(observer_base):
minE : float
Minimum of energy
"""

def __init__(self, comm, Lreload=False):
"""

Expand All @@ -40,7 +44,7 @@ def __init__(self, comm, Lreload=False):
Lreload: bool
Reload or not
"""
super(default_observer, self).__init__()
super(DefaultObserver, self).__init__()
self.minE = 100000.0
myrank = comm.Get_rank()
if Lreload:
Expand All @@ -49,42 +53,24 @@ def __init__(self, comm, Lreload=False):
with open(os.path.join(str(myrank), "obs.dat"), "r") as f:
self.lprintcount = int(f.readlines()[-1].split()[0]) + 1

def logfunc(self, calc_state):
"""

Parameters
----------
calc_state: MCalgo
Object of Monte Carlo algorithm
Returns
-------
calc_state.energy : float
Minimum energy
"""
def logfunc(self, calc_state: MCAlgorithm) -> Tuple[float]:
if calc_state.energy < self.minE:
self.minE = calc_state.energy
with open("minEfi.dat", "a") as f:
f.write(str(self.minE) + "\n")
calc_state.config.structure_norel.to(fmt="POSCAR", filename="minE.vasp")
return calc_state.energy
return (calc_state.energy,)

def writefile(self, calc_state):
"""

Parameters
----------
calc_state: MCalgo
Object of Monte Carlo algorithm

"""
def writefile(self, calc_state: MCAlgorithm) -> None:
calc_state.config.structure.to(
fmt="POSCAR", filename="structure." + str(self.lprintcount) + ".vasp"
)
calc_state.config.structure_norel.to(
fmt="POSCAR", filename="structure_norel." + str(self.lprintcount) + ".vasp"
)

class ensemble_error_observer(default_observer):

class EnsembleErrorObserver(DefaultObserver):
def __init__(self, comm, energy_calculators, Lreload=False):
"""

Expand All @@ -97,11 +83,11 @@ def __init__(self, comm, energy_calculators, Lreload=False):
Lreload: bool
Reload or not
"""
super(ensemble_error_observer, self).__init__(comm, Lreload)
super(EnsembleErrorObserver, self).__init__(comm, Lreload)
self.calculators = energy_calculators
self.comm = comm

def logfunc(self, calc_state):
def logfunc(self, calc_state: MCAlgorithm):
if calc_state.energy < self.minE:
self.minE = calc_state.energy
with open("minEfi.dat", "a") as f:
Expand All @@ -110,17 +96,21 @@ def logfunc(self, calc_state):
energies = [calc_state.energy]
npar = self.comm.Get_size()
if npar > 1:
assert(npar == len(self.calculators))
assert npar == len(self.calculators)
myrank = self.comm.Get_rank()
energy, _ = self.calculators[myrank].submit(calc_state.config.structure, os.path.join(os.getcwd(),"ensemble{}".format(myrank)))
energy, _ = self.calculators[myrank].submit(
calc_state.config.structure,
os.path.join(os.getcwd(), "ensemble{}".format(myrank)),
)
energies_tmp = self.comm.allgather(energy)
std = np.std(energies_tmp, ddof=1)

else:
energies_tmp = []
for i, calculator in enumerate(self.calculators):
energy, _ = calculator.submit(
calc_state.config.structure, os.path.join(os.getcwd(), "ensemble{}".format(i))
calc_state.config.structure,
os.path.join(os.getcwd(), "ensemble{}".format(i)),
)
energies_tmp.append(energy)
std = np.std(energies_tmp, ddof=1)
Expand All @@ -129,7 +119,6 @@ def logfunc(self, calc_state):
return np.asarray(energies)



class EnsembleParams:
@classmethod
def from_dict(cls, d):
Expand Down Expand Up @@ -186,3 +175,9 @@ def from_toml(cls, f):
import toml

return cls.from_dict(toml.load(f))


# For backward compatibility
if __version__ < "3":
default_observer = DefaultObserver
ensemble_error_observer = EnsembleErrorObserver
34 changes: 24 additions & 10 deletions abics/applications/latgas_abinitio_interface/defect.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,26 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see http://www.gnu.org/licenses/.

from pymatgen.core import Lattice
import typing
from typing import Any, Dict, List, MutableMapping

from .model_setup import config, defect_sublattice, base_structure
from pymatgen.core import Lattice, Structure

from .model_setup import Config, DefectSublattice, base_structure

from abics.exception import InputError
from abics.util import read_matrix


class DFTConfigParams:
def __init__(self, dconfig):

lat: Lattice
supercell: List[int]
base_structure: Structure
defect_sublattices: List[DefectSublattice]
num_defects: List[Dict[str, int]]

def __init__(self, dconfig: MutableMapping[str, Any]) -> None:
"""
Get information from dictionary

Expand All @@ -37,12 +47,16 @@ def __init__(self, dconfig):
dconfig = dconfig["config"]

if "unitcell" not in dconfig:
raise InputError('"unitcell" is not found in the "config" section.')
raise InputError('"config.unitcell" is not found in the "config" section.')
self.lat = Lattice(read_matrix(dconfig["unitcell"]))

self.supercell = dconfig.get("supercell", [1, 1, 1])
self.supercell = typing.cast(List[int], dconfig.get("supercell", [1, 1, 1]))
if not isinstance(self.supercell, list):
raise InputError('"config.supercell" should be a list of integers')

constraint_module = dconfig.get("constraint_module", False)
if not isinstance(constraint_module, bool):
raise InputError('"config.constraint_module" should be true or false')
if constraint_module:
import os, sys

Expand All @@ -60,19 +74,19 @@ def __init__(self, dconfig):
if "defect_structure" not in dconfig:
raise InputError('"defect_structure" is not found in the "config" section.')
self.defect_sublattices = [
defect_sublattice.from_dict(ds) for ds in dconfig["defect_structure"]
DefectSublattice.from_dict(ds) for ds in dconfig["defect_structure"]
]
self.num_defects = [
{g["name"]: g["num"] for g in ds["groups"]}
for ds in dconfig["defect_structure"]
]

@classmethod
def from_dict(cls, d):
def from_dict(cls, d: MutableMapping) -> "DFTConfigParams":
return cls(d)

@classmethod
def from_toml(cls, f):
def from_toml(cls, f: str) -> "DFTConfigParams":
"""

Get information from toml file.
Expand All @@ -92,7 +106,7 @@ def from_toml(cls, f):
return cls(toml.load(f))


def defect_config(cparam: DFTConfigParams):
def defect_config(cparam: DFTConfigParams) -> Config:
"""
Get configuration information

Expand All @@ -106,7 +120,7 @@ def defect_config(cparam: DFTConfigParams):
spinel_config: config object
spinel configure object
"""
spinel_config = config(
spinel_config = Config(
cparam.base_structure,
cparam.defect_sublattices,
cparam.num_defects,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def map2perflat(
perf_st = perf_st.copy()
N = perf_st.num_sites
seldyn = np.array(
perf_st.site_properties.get("seldyn", np.ones((N, 3))), dtype=np.float
perf_st.site_properties.get("seldyn", np.ones((N, 3))), dtype=np.float64
)
perf_st.replace_species(vac_spaceholder)
mapping = naive_mapping(st, perf_st)
Expand Down
Loading