Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add a weights and biases logger #89

Merged
merged 1 commit into from
Jun 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -174,3 +174,4 @@ results/
outputs/
multirun/
.neptune
.wandb
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ The current code in Stoix was initially **largely** taken and subsequently adapt

2. **Hydra Config System:** Leverage the Hydra configuration system for efficient and consistent management of experiments, network architectures, and environments. Hydra facilitates the easy addition of new hyperparameters and supports multi-runs and Optuna hyperparameter optimization. No more need to create large bash scripts to run a series of experiments with differing hyperparameters, network architectures or environments.

3. **Advanced Logging:** Stoix features advanced and configurable logging, ready for output to the terminal, TensorBoard, and other ML tracking dashboards. It also supports logging experiments in JSON format ready for statistical tests and generating RLiable plots (see the plotting notebook). This enables statistically confident comparisons of algorithms natively.
3. **Advanced Logging:** Stoix features advanced and configurable logging, ready for output to the terminal, TensorBoard, and other ML tracking dashboards (WandB and Neptune). It also supports logging experiments in JSON format ready for statistical tests and generating RLiable plots (see the plotting notebook). This enables statistically confident comparisons of algorithms natively.

Stoix currently offers the following building blocks for Single-Agent RL research:

Expand Down
1 change: 1 addition & 0 deletions requirements/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,5 @@ rlax
tdqm
tensorboard_logger
tensorflow_probability
wandb
xminigrid @ git+https://github.com/corl-team/xland-minigrid.git@main
9 changes: 5 additions & 4 deletions stoix/configs/logger/base_logger.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,16 @@ use_console: True # Whether to log to stdout.
use_tb: False # Whether to use tensorboard logging.
use_json: False # Whether to log marl-eval style to json files.
use_neptune: False # Whether to log to neptune.ai.
use_wandb: False # Whether to log to wandb.ai.

# --- Other logger kwargs ---
kwargs:
neptune_project: ~ # Project name in neptune.ai
neptune_tag: [stoix]
detailed_neptune_logging: False # having mean/std/min/max can clutter neptune so we make it optional
project: ~ # Project name in neptune.ai or wandb.ai.
tags: [stoix] # Tags to add to the experiment.
detailed_logging: False # having mean/std/min/max can clutter neptune/wandb so we make it optional
json_path: ~ # If set, json files will be logged to a set path so that multiple experiments can
# write to the same json file for easy downstream aggregation and plotting with marl-eval.
upload_json_data: False # Whether JSON file data should be uploaded to Neptune for downstream
upload_json_data: False # Whether JSON file data should be uploaded to Neptune/WandB for downstream
# aggregation and plotting of data from multiple experiments. Note that when uploading JSON files,
# `json_path` must be unset to ensure that uploaded json files don't continue getting larger
# over time. Setting both will break.
Expand Down
57 changes: 53 additions & 4 deletions stoix/utils/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import jax
import neptune
import numpy as np
import wandb
from colorama import Fore, Style
from jax.typing import ArrayLike
from marl_eval.json_tools import JsonLogger as MarlEvalJsonLogger
Expand Down Expand Up @@ -134,13 +135,13 @@ class NeptuneLogger(BaseLogger):
"""Logger for neptune.ai."""

def __init__(self, cfg: DictConfig, unique_token: str) -> None:
tags = list(cfg.logger.kwargs.neptune_tag)
project = cfg.logger.kwargs.neptune_project
tags = list(cfg.logger.kwargs.tags)
project = cfg.logger.kwargs.project

self.logger = neptune.init_run(project=project, tags=tags)

self.logger["config"] = stringify_unsupported(cfg)
self.detailed_logging = cfg.logger.kwargs.detailed_neptune_logging
self.detailed_logging = cfg.logger.kwargs.detailed_logging

# Store json path for uploading json data to Neptune.
json_exp_path = get_logger_path(cfg, "json")
Expand Down Expand Up @@ -176,6 +177,52 @@ def _zip_and_upload_json(self) -> None:
self.logger[f"metrics/metrics_{self.unique_token}"].upload(zip_file_path)


class WandBLogger(BaseLogger):
    """Logger that sends experiment metrics to wandb.ai.

    The run is started on construction using the project name and tags found
    under `cfg.logger.kwargs`. When `upload_json_data` is set, the JSON
    metrics file is zipped and handed to `wandb.save` when the logger stops.
    """

    def __init__(self, cfg: DictConfig, unique_token: str) -> None:
        """Start a WandB run and record the settings needed for later logging.

        Args:
            cfg: Experiment config; `cfg.logger.kwargs` supplies `tags`,
                `project`, `detailed_logging` and `upload_json_data`, and
                `cfg.logger.base_exp_path` is the root of the JSON path.
            unique_token: Token used as the per-experiment directory name in
                the JSON metrics path.
        """
        run_tags = list(cfg.logger.kwargs.tags)
        project_name = cfg.logger.kwargs.project

        wandb.init(
            project=project_name,
            tags=run_tags,
            config=stringify_unsupported(cfg),
        )

        self.detailed_logging = cfg.logger.kwargs.detailed_logging

        # Remember where the JSON metrics file lives so that it can be
        # zipped and uploaded to WandB when the logger is stopped.
        json_exp_path = get_logger_path(cfg, "json")
        metrics_relative_path = f"{json_exp_path}/{unique_token}/metrics.json"
        self.json_file_path = os.path.join(cfg.logger.base_exp_path, metrics_relative_path)
        self.unique_token = unique_token
        self.upload_json_data = cfg.logger.kwargs.upload_json_data

    def log_stat(self, key: str, value: float, step: int, eval_step: int, event: LogEvent) -> None:
        """Log a single metric to WandB under the `<event>/<key>` name.

        Unless detailed logging is enabled, only "main" metrics are logged:
        keys that contain no '/' at all, or keys that end with '/mean'.
        """
        if not self.detailed_logging:
            is_main_metric = key.endswith("/mean") or "/" not in key
            if not is_main_metric:
                # Skip the std/min/max style metrics to keep the dashboard tidy.
                return

        wandb.log({f"{event.value}/{key}": value}, step=step)

    def stop(self) -> None:
        """Optionally upload the zipped JSON metrics, then finish the run."""
        if self.upload_json_data:
            self._zip_and_upload_json()
        wandb.finish()

    def _zip_and_upload_json(self) -> None:
        """Zip the JSON metrics file next to itself and pass it to `wandb.save`."""
        # The archive path is the JSON path with its '.json' suffix swapped for '.zip'.
        path_without_suffix = self.json_file_path.rsplit(".json", 1)[0]
        zip_file_path = path_without_suffix + ".zip"

        # Write a compressed archive that contains only the JSON metrics file.
        with zipfile.ZipFile(
            zip_file_path, mode="w", compression=zipfile.ZIP_DEFLATED
        ) as archive:
            archive.write(self.json_file_path)

        wandb.save(zip_file_path)


class TensorboardLogger(BaseLogger):
"""Logger for tensorboard"""

Expand Down Expand Up @@ -290,7 +337,7 @@ def _make_multi_logger(cfg: DictConfig) -> BaseLogger:
unique_token = datetime.now().strftime("%Y%m%d%H%M%S")

if (
cfg.logger.use_neptune
(cfg.logger.use_neptune or cfg.logger.use_wandb)
and cfg.logger.use_json
and cfg.logger.kwargs.upload_json_data
and cfg.logger.kwargs.json_path
Expand All @@ -305,6 +352,8 @@ def _make_multi_logger(cfg: DictConfig) -> BaseLogger:

if cfg.logger.use_neptune:
loggers.append(NeptuneLogger(cfg, unique_token))
if cfg.logger.use_wandb:
loggers.append(WandBLogger(cfg, unique_token))
if cfg.logger.use_tb:
loggers.append(TensorboardLogger(cfg, unique_token))
if cfg.logger.use_json:
Expand Down
Loading