Skip to content

Commit

Permalink
Expose some GUI attrs as public
Browse files Browse the repository at this point in the history
  • Loading branch information
sdatkinson committed Sep 19, 2024
1 parent 625aa8a commit 38ad854
Showing 1 changed file with 50 additions and 46 deletions.
96 changes: 50 additions & 46 deletions nam/train/gui/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def _is_mac() -> bool:


@dataclass
class _AdvancedOptions(object):
class AdvancedOptions(object):
"""
:param architecture: Which architecture to use.
:param num_epochs: How many epochs to train for.
Expand Down Expand Up @@ -389,7 +389,13 @@ class _GUIWidgets(Enum):
UPDATE = "update"


class _GUI(object):
@dataclass
class Checkbox(object):
variable: tk.BooleanVar
check_button: tk.Checkbutton


class GUI(object):
def __init__(self):
self._root = tk.Tk()
self._root.title(f"NAM Trainer - v{__version__}")
Expand Down Expand Up @@ -457,7 +463,7 @@ def __init__(self):

# Advanced options for training
default_architecture = core.Architecture.STANDARD
self.advanced_options = _AdvancedOptions(
self.advanced_options = AdvancedOptions(
default_architecture,
_DEFAULT_NUM_EPOCHS,
_DEFAULT_DELAY,
Expand Down Expand Up @@ -526,11 +532,6 @@ def _get_additional_options_frame(self):
self._frame_checkboxes.pack(side=tk.LEFT)
row = 1

@dataclass
class Checkbox(object):
variable: tk.BooleanVar
check_button: tk.Checkbutton

def make_checkbox(
key: _CheckboxKeys, text: str, default_value: bool
) -> Checkbox:
Expand Down Expand Up @@ -567,7 +568,7 @@ def _open_advanced_options(self):
Open window for advanced options
"""

self._wait_while_func(lambda resume: _AdvancedOptionsGUI(resume, self))
self._wait_while_func(lambda resume: AdvancedOptionsGUI(resume, self))

def _open_metadata(self):
"""
Expand Down Expand Up @@ -991,16 +992,51 @@ def get(self):
return None


class _AdvancedOptionsGUI(object):
class AdvancedOptionsGUI(object):
"""
A window to hold advanced options (Architecture and number of epochs)
"""

def __init__(self, resume_main, parent: _GUI):
def __init__(self, resume_main, parent: GUI):
self._parent = parent
self._root = _TopLevelWithOk(self._apply, resume_main)
self._root = _TopLevelWithOk(self.apply, resume_main)
self._root.title("Advanced Options")

self.pack_options()

# "Ok": apply and destroy
self._frame_ok = tk.Frame(self._root)
self._frame_ok.pack()
self._button_ok = tk.Button(
self._frame_ok,
text="Ok",
width=_BUTTON_WIDTH,
height=_BUTTON_HEIGHT,
command=lambda: self._root.destroy(pressed_ok=True),
)
self._button_ok.pack()

def apply(self):
"""
Set values to parent and destroy this object
"""
self._parent.advanced_options.architecture = self._architecture.get()
epochs = self._epochs.get()
if epochs is not None:
self._parent.advanced_options.num_epochs = epochs
latency = self._latency.get()
# Value None is returned as "null" to disambiguate from non-set.
if latency is not None:
self._parent.advanced_options.latency = (
None if latency == "null" else latency
)
threshold_esr = self._threshold_esr.get()
if threshold_esr is not None:
self._parent.advanced_options.threshold_esr = (
None if threshold_esr == "null" else threshold_esr
)

def pack_options(self):
# Architecture: radio buttons
self._frame_architecture = tk.Frame(self._root)
self._frame_architecture.pack()
Expand Down Expand Up @@ -1043,44 +1079,12 @@ def __init__(self, resume_main, parent: _GUI):
type=_float_or_null,
)

# "Ok": apply and destroy
self._frame_ok = tk.Frame(self._root)
self._frame_ok.pack()
self._button_ok = tk.Button(
self._frame_ok,
text="Ok",
width=_BUTTON_WIDTH,
height=_BUTTON_HEIGHT,
command=lambda: self._root.destroy(pressed_ok=True),
)
self._button_ok.pack()

def _apply(self):
"""
Set values to parent and destroy this object
"""
self._parent.advanced_options.architecture = self._architecture.get()
epochs = self._epochs.get()
if epochs is not None:
self._parent.advanced_options.num_epochs = epochs
latency = self._latency.get()
# Value None is returned as "null" to disambiguate from non-set.
if latency is not None:
self._parent.advanced_options.latency = (
None if latency == "null" else latency
)
threshold_esr = self._threshold_esr.get()
if threshold_esr is not None:
self._parent.advanced_options.threshold_esr = (
None if threshold_esr == "null" else threshold_esr
)


class _UserMetadataGUI(object):
# Things that are auto-filled:
# Model date
# gain
def __init__(self, resume_main, parent: _GUI):
def __init__(self, resume_main, parent: GUI):
self._parent = parent
self._root = _TopLevelWithOk(self._apply, resume_main)
self._root.title("Metadata")
Expand Down Expand Up @@ -1184,7 +1188,7 @@ def _install_error():

def run():
if _install_is_valid:
_gui = _GUI()
_gui = GUI()
_gui.mainloop()
print("Shut down NAM trainer")
else:
Expand Down

0 comments on commit 38ad854

Please sign in to comment.