Skip to content

Commit

Permalink
[FEATURE] Automatically check for updates (#467)
Browse files Browse the repository at this point in the history
* Button for update when available

* Rearrange methods, add shut down print statement

* Cleanup imports

* Function for getting current version

* Handle when there's no internet
  • Loading branch information
sdatkinson authored Sep 19, 2024
1 parent 26fdad7 commit 625aa8a
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 6 deletions.
13 changes: 13 additions & 0 deletions nam/train/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,20 @@
Version utility
"""

from .._version import __version__


class Version:
def __init__(self, major: int, minor: int, patch: int):
self.major = major
self.minor = minor
self.patch = patch

@classmethod
def from_string(cls, s: str):
major, minor, patch = [int(x) for x in s.split(".")]
return cls(major, minor, patch)

def __eq__(self, other) -> bool:
return (
self.major == other.major
Expand All @@ -21,6 +28,8 @@ def __eq__(self, other) -> bool:
)

def __lt__(self, other) -> bool:
if self == other:
return False
if self.major != other.major:
return self.major < other.major
if self.minor != other.minor:
Expand All @@ -33,3 +42,7 @@ def __str__(self) -> str:


PROTEUS_VERSION = Version(4, 0, 0)


def get_current_version() -> Version:
return Version.from_string(__version__)
101 changes: 98 additions & 3 deletions nam/train/gui/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,17 @@
"""

import re
import requests
import tkinter as tk
import subprocess
import sys
import webbrowser
from dataclasses import dataclass
from enum import Enum
from functools import partial
from pathlib import Path
from tkinter import filedialog
from typing import Callable, Dict, Optional, Sequence
from typing import Callable, Dict, NamedTuple, Optional, Sequence

try: # 3rd-party and 1st-party imports
import torch
Expand All @@ -33,7 +35,7 @@
# Ok private access here--this is technically allowed access
from nam.train import metadata
from nam.train._names import INPUT_BASENAMES, LATEST_VERSION
from nam.train.metadata import TRAINING_KEY
from nam.train._version import Version, get_current_version

_install_is_valid = True
_HAVE_ACCELERATOR = torch.cuda.is_available() or torch.backends.mps.is_available()
Expand Down Expand Up @@ -384,6 +386,7 @@ class _GUIWidgets(Enum):
METADATA = "metadata"
ADVANCED_OPTIONS = "advanced_options"
TRAIN = "train"
UPDATE = "update"


class _GUI(object):
Expand Down Expand Up @@ -446,7 +449,9 @@ def __init__(self):
# Last frames: avdanced options & train in the SE corner:
self._frame_advanced_options = tk.Frame(self._root)
self._frame_train = tk.Frame(self._root)
# Pack train first so that it's on bottom.
self._frame_update = tk.Frame(self._root)
# Pack must be in reverse order
self._frame_update.pack(side=tk.BOTTOM, anchor="e")
self._frame_train.pack(side=tk.BOTTOM, anchor="e")
self._frame_advanced_options.pack(side=tk.BOTTOM, anchor="e")

Expand Down Expand Up @@ -481,6 +486,8 @@ def __init__(self):
)
self._widgets[_GUIWidgets.TRAIN].pack()

self._pack_update_button_if_update_is_available()

self._check_button_states()

def get_mrstft_fit(self) -> bool:
Expand Down Expand Up @@ -569,6 +576,93 @@ def _open_metadata(self):

self._wait_while_func(lambda resume: _UserMetadataGUI(resume, self))

def _pack_update_button(self, version_from: Version, version_to: Version):
"""
Pack a button that a user can click to update
"""

def update_nam():
result = subprocess.run(
[
f"{sys.executable}",
"-m",
"pip",
"install",
"--upgrade",
"neural-amp-modeler",
]
)
if result.returncode == 0:
self._wait_while_func(
(lambda resume, *args, **kwargs: _OkModal(resume, *args, **kwargs)),
"Update complete! Restart NAM for changes to take effect.",
)
else:
self._wait_while_func(
(lambda resume, *args, **kwargs: _OkModal(resume, *args, **kwargs)),
"Update failed! See logs.",
)

self._widgets[_GUIWidgets.UPDATE] = tk.Button(
self._frame_update,
text=f"Update ({str(version_from)} -> {str(version_to)})",
width=_BUTTON_WIDTH,
height=_BUTTON_HEIGHT,
command=update_nam,
)
self._widgets[_GUIWidgets.UPDATE].pack()

def _pack_update_button_if_update_is_available(self):
class UpdateInfo(NamedTuple):
available: bool
current_version: Version
new_version: Optional[Version]

def get_info() -> UpdateInfo:
# TODO error handling
url = f"https://api.github.com/repos/sdatkinson/neural-amp-modeler/releases"
current_version = get_current_version()
try:
response = requests.get(url)
except requests.exceptions.ConnectionError:
print("WARNING: Failed to reach the server to check for updates")
return UpdateInfo(
available=False, current_version=current_version, new_version=None
)
if response.status_code != 200:
print(f"Failed to fetch releases. Status code: {response.status_code}")
return UpdateInfo(
available=False, current_version=current_version, new_version=None
)
else:
releases = response.json()
latest_version = None
if releases:
for release in releases:
tag = release["tag_name"]
if not tag.startswith("v"):
print(f"Found invalid version {tag}")
else:
this_version = Version.from_string(tag[1:])
if latest_version is None or this_version > latest_version:
latest_version = this_version
else:
print("No releases found for this repository.")
update_available = (
latest_version is not None and latest_version > current_version
)
return UpdateInfo(
available=update_available,
current_version=current_version,
new_version=latest_version,
)

update_info = get_info()
if update_info.available:
self._pack_update_button(
update_info.current_version, update_info.new_version
)

def _resume(self):
self._set_all_widget_states_to(tk.NORMAL)
self._check_button_states()
Expand Down Expand Up @@ -1092,6 +1186,7 @@ def run():
if _install_is_valid:
_gui = _GUI()
_gui.mainloop()
print("Shut down NAM trainer")
else:
_install_error()

Expand Down
3 changes: 0 additions & 3 deletions tests/test_nam/test_train/test_gui/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,7 @@
# Created Date: Friday May 24th 2024
# Author: Steven Atkinson ([email protected])

import importlib
import os
import tkinter as tk
from pathlib import Path

import pytest

Expand Down

0 comments on commit 625aa8a

Please sign in to comment.