Skip to content

Commit

Permalink
fix(tasks): update ticks on loss plot axes
Browse files Browse the repository at this point in the history
Signed-off-by: Cameron Smith <[email protected]>
  • Loading branch information
cameronraysmith committed Aug 15, 2024
1 parent eb0b9f9 commit 350499f
Showing 1 changed file with 8 additions and 0 deletions.
8 changes: 8 additions & 0 deletions src/pyrovelocity/tasks/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import numpy as np
from anndata._core.anndata import AnnData
from beartype import beartype
from matplotlib.ticker import ScalarFormatter, SymmetricalLogLocator
from mlflow import MlflowClient
from numpy import ndarray
from scvi.model._utils import parse_device_args
Expand Down Expand Up @@ -416,6 +417,13 @@ def set_loss_plot_axes(ax):
ax.set_xlabel("Epochs")
ax.set_ylabel("-ELBO")

locator = SymmetricalLogLocator(base=10, linthresh=1, subs=[1.0])
ax.yaxis.set_major_locator(locator)

formatter = ScalarFormatter()
formatter.set_scientific(False)
ax.yaxis.set_major_formatter(formatter)


def log_run_info(r: mlflow.entities.run.Run) -> None:
tags = {k: v for k, v in r.data.tags.items() if not k.startswith("mlflow.")}
Expand Down

0 comments on commit 350499f

Please sign in to comment.