Skip to content

Commit

Permalink
Adds callable option to get secrets for the W&B plugin (flyteorg#2449)
Browse files Browse the repository at this point in the history
* Adds callable option to get secrets

Signed-off-by: Thomas J. Fan <[email protected]>

* DOC Improve docstring

Signed-off-by: Thomas J. Fan <[email protected]>

* Use wandb.login instead of environment variable

Signed-off-by: Thomas J. Fan <[email protected]>

---------

Signed-off-by: Thomas J. Fan <[email protected]>
Signed-off-by: bugra.gedik <[email protected]>
  • Loading branch information
thomasjpfan authored and bugra.gedik committed Jul 3, 2024
1 parent eb86b4b commit 853a498
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 13 deletions.
24 changes: 18 additions & 6 deletions plugins/flytekit-wandb/flytekitplugins/wandb/tracking.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
from typing import Callable, Optional
from typing import Callable, Optional, Union

import wandb
from flytekit import Secret
Expand All @@ -21,19 +21,22 @@ def __init__(
task_function: Optional[Callable] = None,
project: Optional[str] = None,
entity: Optional[str] = None,
secret: Optional[Secret] = None,
secret: Optional[Union[Secret, Callable]] = None,
id: Optional[str] = None,
host: str = "https://wandb.ai",
api_host: str = "https://api.wandb.ai",
**init_kwargs: dict,
):
"""Weights and Biases plugin.
Args:
task_function (function, optional): The user function to be decorated. Defaults to None.
project (str): The name of the project where you're sending the new run. (Required)
entity (str): An entity is a username or team name where you're sending runs. (Required)
secret (Secret): Secret with your `WANDB_API_KEY`. (Required)
secret (Secret or Callable): Secret with your `WANDB_API_KEY` or a callable that returns the API key.
The callable takes no arguments and returns a string. (Required)
id (str, optional): A unique id for this wandb run.
host (str, optional): URL to your wandb service. The default is "https://wandb.ai".
api_host (str, optional): URL to your API Host, The default is "https://api.wandb.ai".
**init_kwargs (dict): The rest of the arguments are passed directly to `wandb.init`. Please see
[the `wandb.init` docs](https://docs.wandb.ai/ref/python/init) for details.
"""
Expand All @@ -50,6 +53,7 @@ def __init__(
self.init_kwargs = init_kwargs
self.secret = secret
self.host = host
self.api_host = api_host

# All kwargs need to be passed up so that the function wrapping works for both
# `@wandb_init` and `@wandb_init(...)`
Expand All @@ -60,6 +64,7 @@ def __init__(
secret=secret,
id=id,
host=host,
api_host=api_host,
**init_kwargs,
)

Expand All @@ -72,9 +77,16 @@ def execute(self, *args, **kwargs):
# will generate it's own id.
wand_id = self.id
else:
# Set secret for remote execution
secrets = ctx.user_space_params.secrets
os.environ["WANDB_API_KEY"] = secrets.get(key=self.secret.key, group=self.secret.group)
if isinstance(self.secret, Secret):
# Set secret for remote execution
secrets = ctx.user_space_params.secrets
wandb_api_key = secrets.get(key=self.secret.key, group=self.secret.group)
else:
# Get API key with callable
wandb_api_key = self.secret()

wandb.login(key=wandb_api_key, host=self.api_host)

if self.id is None:
# The HOSTNAME is set to {.executionName}-{.nodeID}-{.taskRetryAttempt}
# If HOSTNAME is not defined, use the execution name as a fallback
Expand Down
44 changes: 37 additions & 7 deletions plugins/flytekit-wandb/tests/test_wandb_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,22 @@
from flytekitplugins.wandb import wandb_init
from flytekitplugins.wandb.tracking import WANDB_CUSTOM_TYPE_VALUE, WANDB_EXECUTION_TYPE_VALUE

from flytekit import task
from flytekit import Secret, task

secret = Secret(key="abc", group="xyz")


@pytest.mark.parametrize("id", [None, "abc123"])
def test_wandb_extra_config(id):
wandb_decorator = wandb_init(
project="abc",
entity="xyz",
secret_key="my-secret-key",
secret=secret,
id=id,
host="https://my_org.wandb.org",
)

assert wandb_decorator.secret is secret
extra_config = wandb_decorator.get_extra_config()

if id is None:
Expand All @@ -29,7 +32,7 @@ def test_wandb_extra_config(id):


@task
@wandb_init(project="abc", entity="xyz", secret_key="my-secret-key", secret_group="my-secret-group", tags=["my_tag"])
@wandb_init(project="abc", entity="xyz", secret=secret, tags=["my_tag"])
def train_model():
pass

Expand All @@ -42,7 +45,7 @@ def test_local_execution(wandb_mock):


@task
@wandb_init(project="abc", entity="xyz", secret_key="my-secret-key", tags=["my_tag"], id="1234")
@wandb_init(project="abc", entity="xyz", secret=secret, tags=["my_tag"], id="1234")
def train_model_with_id():
pass

Expand Down Expand Up @@ -71,8 +74,8 @@ def test_non_local_execution(wandb_mock, manager_mock, os_mock):
train_model()

wandb_mock.init.assert_called_with(project="abc", entity="xyz", id="my_execution_id", tags=["my_tag"])
ctx_mock.user_space_params.secrets.get.assert_called_with(key="my-secret-key", group="my-secret-group")
assert os_mock.environ["WANDB_API_KEY"] == "this_is_the_secret"
ctx_mock.user_space_params.secrets.get.assert_called_with(key="abc", group="xyz")
wandb_mock.login.assert_called_with(key="this_is_the_secret", host="https://api.wandb.ai")


def test_errors():
Expand All @@ -82,5 +85,32 @@ def test_errors():
with pytest.raises(ValueError, match="entity must be set"):
wandb_init(project="abc")

with pytest.raises(ValueError, match="secret_key must be set"):
with pytest.raises(ValueError, match="secret must be set"):
wandb_init(project="abc", entity="xyz")


def get_secret():
return "my-wandb-api-key"


@task
@wandb_init(project="my_project", entity="my_entity", secret=get_secret, tags=["my_tag"], id="1234")
def train_model_with_id_callable_secret():
pass


@patch("flytekitplugins.wandb.tracking.os")
@patch("flytekitplugins.wandb.tracking.FlyteContextManager")
@patch("flytekitplugins.wandb.tracking.wandb")
def test_secret_callable_remote(wandb_mock, manager_mock, os_mock):
# Pretend that the execution is remote
ctx_mock = Mock()
ctx_mock.execution_state.is_local_execution.return_value = False

manager_mock.current_context.return_value = ctx_mock
os_mock.environ = {}

train_model_with_id_callable_secret()

wandb_mock.init.assert_called_with(project="my_project", entity="my_entity", id="1234", tags=["my_tag"])
wandb_mock.login.assert_called_with(key=get_secret(), host="https://api.wandb.ai")

0 comments on commit 853a498

Please sign in to comment.