diff --git a/master/internal/rm/agentrm/agent_state.go b/master/internal/rm/agentrm/agent_state.go index f8060f9a4bb..68d7da0aed6 100644 --- a/master/internal/rm/agentrm/agent_state.go +++ b/master/internal/rm/agentrm/agent_state.go @@ -645,8 +645,13 @@ func (a *agentState) restoreContainersField() error { } func clearAgentStates(agentIds []agentID) error { - _, err := db.Bun().NewDelete().Model((*agentSnapshot)(nil)).Where("agent_id in (?)", agentIds).Exec(context.TODO()) - return fmt.Errorf("clearing agent states: %w", err) + if _, err := db.Bun().NewDelete().Model((*agentSnapshot)(nil)). + Where("agent_id in (?)", bun.In(agentIds)). + Exec(context.TODO()); err != nil { + return fmt.Errorf("clearing agent states: %w", err) + } + + return nil } func updateContainerState(c *cproto.Container) error { diff --git a/master/internal/rm/agentrm/agent_state_test.go b/master/internal/rm/agentrm/agent_state_test.go index c38e6e60d92..164e3523496 100644 --- a/master/internal/rm/agentrm/agent_state_test.go +++ b/master/internal/rm/agentrm/agent_state_test.go @@ -11,6 +11,7 @@ import ( "github.com/google/uuid" "github.com/stretchr/testify/require" + "github.com/uptrace/bun" "github.com/determined-ai/determined/master/internal/db" "github.com/determined-ai/determined/master/internal/sproto" @@ -184,6 +185,28 @@ func TestAgentStatePersistence(t *testing.T) { require.False(t, exists) } +func TestClearAgentStates(t *testing.T) { + ctx := context.Background() + agentIDs := []agentID{agentID(uuid.NewString()), agentID(uuid.NewString())} + for _, agentID := range agentIDs { + _, err := db.Bun().NewInsert().Model(&agentSnapshot{ + AgentID: agentID, + UUID: uuid.NewString(), + ResourcePoolName: "rp-name", + Label: "label", + MaxZeroSlotContainers: 0, + }).Exec(ctx) + require.NoError(t, err) + } + + require.NoError(t, clearAgentStates(agentIDs)) + exists, err := db.Bun().NewSelect().Model(&agentSnapshot{}). + Where("agent_id IN (?)", bun.In(agentIDs)). + Exists(ctx) + require.NoError(t, err) + require.False(t, exists) +} + func Test_agentState_checkAgentStartedDevicesMatch(t *testing.T) { stableUUID := uuid.NewString() tests := []struct {