Skip to content

Commit

Permalink
feat: Logic of different modes for webhook (#9865)
Browse files Browse the repository at this point in the history
  • Loading branch information
gt2345 authored Aug 28, 2024
1 parent a773551 commit 54b6165
Show file tree
Hide file tree
Showing 16 changed files with 708 additions and 66 deletions.
131 changes: 126 additions & 5 deletions e2e_tests/tests/cluster/test_webhooks.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import random
import time
import uuid

Expand Down Expand Up @@ -125,27 +126,147 @@ def test_log_pattern_send_webhook(should_match: bool) -> None:
),
)

workspace = bindings.post_PostWorkspace(
sess, body=bindings.v1PostWorkspaceRequest(name=f"webhook-test{random.random()}")
).workspace
project = bindings.post_PostProject(
sess,
body=bindings.v1PostProjectRequest(
name=f"webhook-test{random.random()}",
workspaceId=workspace.id,
),
workspaceId=workspace.id,
).project

specific_path = f"/test/path/here/{str(uuid.uuid4())}"
bindings.post_PostWebhook(
sess,
body=bindings.v1Webhook(
url=f"http://localhost:{port}{specific_path}",
webhookType=bindings.v1WebhookType.DEFAULT,
triggers=[webhook_trigger],
mode=bindings.v1WebhookMode.SPECIFIC,
name="specific-webhook",
workspaceId=workspace.id,
),
)

specific_path_unmatch = f"/test/path/here/{str(uuid.uuid4())}"
bindings.post_PostWebhook(
sess,
body=bindings.v1Webhook(
url=f"http://localhost:{port}{specific_path_unmatch}",
webhookType=bindings.v1WebhookType.DEFAULT,
triggers=[webhook_trigger],
mode=bindings.v1WebhookMode.SPECIFIC,
name=f"webhook-test{random.random()}",
workspaceId=1,
),
)

exp_id = exp.create_experiment(
sess,
conf.fixtures_path("no_op/single-medium-train-step.yaml"),
conf.fixtures_path("no_op"),
["--config", "hyperparameters.metrics_sigma=-1.0"],
[
"--config",
"hyperparameters.metrics_sigma=-1.0",
"--config",
"integrations.webhooks.webhook_name=['specific-webhook']",
"--project_id",
f"{project.id}",
],
)
exp.wait_for_experiment_state(sess, exp_id, bindings.experimentv1State.ERROR)

for _ in range(10):
for _ in range(8):
responses = server.return_responses()
if default_path in responses and slack_path in responses:
break
time.sleep(1)

responses = server.close_and_return_responses()
if should_match:
assert len(responses) >= 2
assert len(responses) >= 3
# Only need a spot check we get the default / slack responses.
# Further tested in integrations.
assert "TASK_LOG" in responses[default_path]
assert "This log matched the regex" in responses[slack_path]
assert "TASK_LOG" in responses[specific_path]
assert specific_path_unmatch not in responses
else:
assert default_path not in responses
assert slack_path not in responses
assert specific_path not in responses
assert specific_path_unmatch not in responses


@pytest.mark.e2e_cpu
def test_specific_webhook() -> None:
    """Check which SPECIFIC-mode webhooks fire for an experiment that names them by id.

    Two SPECIFIC-mode Slack webhooks are registered with an EXPERIMENT_STATE_CHANGE
    trigger on the COMPLETED state: one on workspace id 1 and one on a workspace
    created by this test. An experiment is then run in a project of the new
    workspace with both webhook ids listed under ``integrations.webhooks.webhook_id``.
    The listener behind the workspace-1 webhook must receive nothing, while the
    listener behind the new workspace's webhook must receive exactly one request.
    """
    # One local listener per webhook so deliveries can be told apart.
    foreign_port = 5007
    foreign_listener = utils.WebhookServer(foreign_port, allow_dupes=True)
    own_port = 5008
    own_listener = utils.WebhookServer(own_port, allow_dupes=True)
    session = api_utils.admin_session()

    # Fresh workspace + project; random suffixes keep names unique across runs.
    workspace_request = bindings.v1PostWorkspaceRequest(name=f"webhook-test{random.random()}")
    new_workspace = bindings.post_PostWorkspace(session, body=workspace_request).workspace
    project_request = bindings.v1PostProjectRequest(
        name=f"webhook-test{random.random()}",
        workspaceId=new_workspace.id,
    )
    new_project = bindings.post_PostProject(
        session,
        body=project_request,
        workspaceId=new_workspace.id,
    ).project

    on_completed = bindings.v1Trigger(
        triggerType=bindings.v1TriggerType.EXPERIMENT_STATE_CHANGE,
        condition={"state": "COMPLETED"},
    )

    # Registered on workspace id 1, i.e. not the workspace the experiment runs in.
    foreign_hook = bindings.v1Webhook(
        url=f"http://localhost:{foreign_port}",
        webhookType=bindings.v1WebhookType.SLACK,
        triggers=[on_completed],
        mode=bindings.v1WebhookMode.SPECIFIC,
        name=f"webhook_1{random.random()}",
        workspaceId=1,
    )

    # Registered on the workspace that owns the experiment's project.
    own_hook = bindings.v1Webhook(
        url=f"http://localhost:{own_port}",
        webhookType=bindings.v1WebhookType.SLACK,
        triggers=[on_completed],
        mode=bindings.v1WebhookMode.SPECIFIC,
        name="webhook_2",
        workspaceId=new_workspace.id,
    )

    created_foreign = bindings.post_PostWebhook(session, body=foreign_hook).webhook
    assert created_foreign.url == foreign_hook.url
    created_own = bindings.post_PostWebhook(session, body=own_hook).webhook
    assert created_own.url == own_hook.url

    # The experiment opts in to both webhooks by id.
    create_args = [
        "--project_id",
        f"{new_project.id}",
        "--config",
        f"integrations.webhooks.webhook_id=[{created_foreign.id},{created_own.id}]",
    ]
    experiment_id = exp.create_experiment(
        session,
        conf.fixtures_path("no_op/single-one-short-step.yaml"),
        conf.fixtures_path("no_op"),
        create_args,
    )

    exp.wait_for_experiment_state(
        session,
        experiment_id,
        bindings.experimentv1State.COMPLETED,
        max_wait_secs=conf.DEFAULT_MAX_WAIT_SECS,
    )

    foreign_hits = foreign_listener.close_and_return_responses()
    assert len(foreign_hits) == 0
    own_hits = own_listener.close_and_return_responses()
    assert len(own_hits) == 1
2 changes: 1 addition & 1 deletion master/internal/api_experiment.go
Original file line number Diff line number Diff line change
Expand Up @@ -2349,7 +2349,7 @@ func (a *apiServer) GetModelDef(
func (a *apiServer) GetTaskContextDirectory(
ctx context.Context, req *apiv1.GetTaskContextDirectoryRequest,
) (*apiv1.GetTaskContextDirectoryResponse, error) {
if err := a.canDoActionsOnTask(ctx, model.TaskID(req.TaskId),
if _, _, err := a.canDoActionsOnTask(ctx, model.TaskID(req.TaskId),
experiment.AuthZProvider.Get().CanGetExperimentArtifacts); err != nil {
return nil, err
}
Expand Down
4 changes: 2 additions & 2 deletions master/internal/api_task.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import (
func (a *apiServer) GetTask(
ctx context.Context, req *apiv1.GetTaskRequest,
) (resp *apiv1.GetTaskResponse, err error) {
if err := a.canDoActionsOnTask(ctx, model.TaskID(req.TaskId),
if _, _, err := a.canDoActionsOnTask(ctx, model.TaskID(req.TaskId),
expauth.AuthZProvider.Get().CanGetExperimentArtifacts); err != nil {
return nil, err
}
Expand All @@ -34,7 +34,7 @@ func (a *apiServer) GetTask(
func (a *apiServer) GetGenericTaskConfig(
ctx context.Context, req *apiv1.GetGenericTaskConfigRequest,
) (resp *apiv1.GetGenericTaskConfigResponse, err error) {
if err := a.canDoActionsOnTask(ctx, model.TaskID(req.TaskId)); err != nil {
if _, _, err := a.canDoActionsOnTask(ctx, model.TaskID(req.TaskId)); err != nil {
return nil, err
}

Expand Down
61 changes: 36 additions & 25 deletions master/internal/api_tasks.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"github.com/determined-ai/determined/master/internal/task"
"github.com/determined-ai/determined/master/internal/webhooks"
"github.com/determined-ai/determined/master/pkg/model"
"github.com/determined-ai/determined/master/pkg/ptrs"
"github.com/determined-ai/determined/proto/pkg/apiv1"
"github.com/determined-ai/determined/proto/pkg/taskv1"
)
Expand Down Expand Up @@ -59,65 +60,74 @@ func expFromTaskID(
return true, exp, nil
}

func canAccessNTSCTask(ctx context.Context, curUser model.User, taskID model.TaskID) (bool, error) {
func canAccessNTSCTask(
ctx context.Context, curUser model.User, taskID model.TaskID,
) (bool, model.AccessScopeID, error) {
spec, err := command.IdentifyTask(ctx, taskID)
if errors.Is(err, db.ErrNotFound) {
// Non NTSC case like checkpointGC case or the task just does not exist.
// TODO(nick) eventually control access to checkpointGC.
return true, nil
return true, spec.WorkspaceID, nil
} else if err != nil {
return false, err
return false, spec.WorkspaceID, err
}
err = command.AuthZProvider.Get().CanGetNSC(
ctx, curUser, spec.WorkspaceID)
return !authz.IsPermissionDenied(err), err
return !authz.IsPermissionDenied(err), spec.WorkspaceID, err
}

func (a *apiServer) canDoActionsOnTask(
ctx context.Context, taskID model.TaskID,
actions ...func(context.Context, model.User, *model.Experiment) error,
) error {
) (*model.AccessScopeID, *int, error) {
errTaskNotFound := api.NotFoundErrs("task", fmt.Sprint(taskID), true)
t, err := db.TaskByID(ctx, taskID)
if errors.Is(err, db.ErrNotFound) {
return errTaskNotFound
return nil, nil, errTaskNotFound
} else if err != nil {
return err
return nil, nil, err
}

curUser, _, err := grpcutil.GetUser(ctx)
if err != nil {
return err
return nil, nil, err
}

switch t.TaskType {
case model.TaskTypeTrial:
isExp, exp, err := expFromTaskID(ctx, taskID)
if !isExp {
return fmt.Errorf("error we failed to look up an experiment "+
return nil, nil, fmt.Errorf("error we failed to look up an experiment "+
"from taskID %s when we think it is a trial task", taskID)
}
if err != nil {
return err
return nil, nil, err
}

if err = expauth.AuthZProvider.Get().CanGetExperiment(ctx, *curUser, exp); err != nil {
return authz.SubIfUnauthorized(err, errTaskNotFound)
return nil, nil, authz.SubIfUnauthorized(err, errTaskNotFound)
}
for _, action := range actions {
if err = action(ctx, *curUser, exp); err != nil {
return status.Error(codes.PermissionDenied, err.Error())
return nil, nil, status.Error(codes.PermissionDenied, err.Error())
}
}
workspaceID, err := expauth.GetWorkspaceFromExperiment(ctx, exp)
if err != nil {
return nil, nil, err
}
return ptrs.Ptr(model.AccessScopeID(workspaceID)), ptrs.Ptr(exp.ID), nil
default: // NTSC case + checkpointGC.
if ok, err := canAccessNTSCTask(ctx, *curUser, taskID); err != nil {
ok, workspaceID, err := canAccessNTSCTask(ctx, *curUser, taskID)
if err != nil {
if !ok || authz.IsPermissionDenied(err) {
return errTaskNotFound
return nil, nil, errTaskNotFound
}
return err
return nil, nil, err
}
// When error is nil, workspaceID is guaranteed not nil
return &workspaceID, nil, nil
}
return nil
}

func (a *apiServer) canGetTaskAcceleration(ctx context.Context, taskID string) error {
Expand All @@ -131,7 +141,7 @@ func (a *apiServer) canGetTaskAcceleration(ctx context.Context, taskID string) e
}
if !isExp {
var ok bool
if ok, err = canAccessNTSCTask(ctx, *curUser, model.TaskID(taskID)); err != nil {
if ok, _, err = canAccessNTSCTask(ctx, *curUser, model.TaskID(taskID)); err != nil {
return err
} else if !ok {
return api.NotFoundErrs("task", taskID, true)
Expand Down Expand Up @@ -163,7 +173,7 @@ func (a *apiServer) canGetAllocation(ctx context.Context, allocationID string) e
}
if !isExp {
var ok bool
if ok, err = canAccessNTSCTask(ctx, *curUser, taskID); err != nil {
if ok, _, err = canAccessNTSCTask(ctx, *curUser, taskID); err != nil {
return err
} else if !ok {
return api.NotFoundErrs("allocation", allocationID, true)
Expand Down Expand Up @@ -196,7 +206,7 @@ func (a *apiServer) canEditAllocation(ctx context.Context, allocationID string)
}
if !isExp {
var ok bool
if ok, err = canAccessNTSCTask(ctx, *curUser, taskID); err != nil {
if ok, _, err = canAccessNTSCTask(ctx, *curUser, taskID); err != nil {
return err
} else if !ok {
return api.NotFoundErrs("allocation", allocationID, true)
Expand Down Expand Up @@ -460,8 +470,9 @@ func (a *apiServer) PostTaskLogs(
}
taskID := req.Logs[0].TaskId

if err := a.canDoActionsOnTask(ctx, model.TaskID(taskID),
expauth.AuthZProvider.Get().CanEditExperiment); err != nil {
workspaceID, expID, err := a.canDoActionsOnTask(ctx, model.TaskID(taskID),
expauth.AuthZProvider.Get().CanEditExperiment)
if err != nil {
return nil, err
}

Expand All @@ -486,7 +497,7 @@ func (a *apiServer) PostTaskLogs(
return nil, fmt.Errorf("adding task logs to task log backend: %w", err)
}

switch err := webhooks.ScanLogs(ctx, logs); {
switch err := webhooks.ScanLogs(ctx, logs, *workspaceID, expID); {
case err != nil && errors.Is(err, context.Canceled):
return nil, err
case err != nil:
Expand Down Expand Up @@ -580,7 +591,7 @@ func (a *apiServer) GetTasks(
}

if !isExp {
_, err = canAccessNTSCTask(ctx, *curUser, summary[allocationID].TaskID)
_, _, err = canAccessNTSCTask(ctx, *curUser, summary[allocationID].TaskID)
} else {
err = expauth.AuthZProvider.Get().CanGetExperiment(ctx, *curUser, exp)
}
Expand Down Expand Up @@ -612,7 +623,7 @@ func (a *apiServer) taskLogs(
var timeSinceLastAuth time.Time
fetch := func(r api.BatchRequest) (api.Batch, error) {
if time.Since(timeSinceLastAuth) >= recheckAuthPeriod {
if err = a.canDoActionsOnTask(ctx, taskID,
if _, _, err = a.canDoActionsOnTask(ctx, taskID,
expauth.AuthZProvider.Get().CanGetExperimentArtifacts); err != nil {
return nil, err
}
Expand Down Expand Up @@ -735,7 +746,7 @@ func (a *apiServer) TaskLogsFields(
var timeSinceLastAuth time.Time
fetch := func(lr api.BatchRequest) (api.Batch, error) {
if time.Since(timeSinceLastAuth) >= recheckAuthPeriod {
if err := a.canDoActionsOnTask(resp.Context(), taskID,
if _, _, err := a.canDoActionsOnTask(resp.Context(), taskID,
expauth.AuthZProvider.Get().CanGetExperimentArtifacts); err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion master/internal/api_trials.go
Original file line number Diff line number Diff line change
Expand Up @@ -1547,7 +1547,7 @@ func (a *apiServer) ReportTrialValidationMetrics(
func (a *apiServer) ReportCheckpoint(
ctx context.Context, req *apiv1.ReportCheckpointRequest,
) (*apiv1.ReportCheckpointResponse, error) {
if err := a.canDoActionsOnTask(ctx, model.TaskID(req.Checkpoint.TaskId),
if _, _, err := a.canDoActionsOnTask(ctx, model.TaskID(req.Checkpoint.TaskId),
experiment.AuthZProvider.Get().CanEditExperiment); err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion master/internal/core_task.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func (m *Master) getTasks(c echo.Context) (interface{}, error) {
}

if !isExp {
_, err = canAccessNTSCTask(ctx, curUser, summary[allocationID].TaskID)
_, _, err = canAccessNTSCTask(ctx, curUser, summary[allocationID].TaskID)
} else {
err = expauth.AuthZProvider.Get().CanGetExperiment(ctx, curUser, exp)
}
Expand Down
4 changes: 4 additions & 0 deletions master/internal/db/postgres_test_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,7 @@ type MockExperimentParams struct {
ProjectID *int
ExternalExperimentID *string
State *model.State
Integrations *expconf.IntegrationsConfigV0
}

// RequireMockExperimentParams returns a mock experiment with various parameters.
Expand Down Expand Up @@ -379,6 +380,9 @@ func RequireMockExperimentParams(
}
}
}
if p.Integrations != nil {
notDefaulted.RawIntegrations = p.Integrations
}

cfg := schemas.WithDefaults(notDefaulted)

Expand Down
Loading

0 comments on commit 54b6165

Please sign in to comment.