Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(agent): cleaning up models that fail to load #5857

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions scheduler/pkg/agent/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,6 @@ func NewClient(
drainerService interfaces.DependencyServiceInterface,
metrics metrics.AgentMetricsHandler,
) *Client {

opts := []grpc.CallOption{
grpc.MaxCallSendMsgSize(math.MaxInt32),
grpc.MaxCallRecvMsgSize(math.MaxInt32),
Expand Down Expand Up @@ -283,7 +282,7 @@ func (c *Client) WaitReadySubServices(isStartup bool) error {
logger.WithError(err).Errorf("Rclone not ready")
}

//TODO make retry configurable
// TODO make retry configurable
err := backoff.RetryNotify(c.ModelRepository.Ready, backoffWithMax, logFailure)
if err != nil {
logger.WithError(err).Error("Failed to wait for model repository to be ready")
Expand Down Expand Up @@ -439,7 +438,7 @@ func (c *Client) StartService() error {
AvailableMemoryBytes: c.stateManager.GetAvailableMemoryBytesWithOverCommit(),
},
grpc_retry.WithMax(1),
) //TODO make configurable
) // TODO make configurable
if err != nil {
return err
}
Expand Down Expand Up @@ -560,7 +559,7 @@ func (c *Client) getArtifactConfig(request *agent.ModelOperationMessage) ([]byte

func (c *Client) LoadModel(request *agent.ModelOperationMessage) error {
if request == nil || request.ModelVersion == nil {
return fmt.Errorf("Empty request received for load model")
return fmt.Errorf("empty request received for load model")
}

logger := c.logger.WithField("func", "LoadModel")
Expand Down Expand Up @@ -591,6 +590,7 @@ func (c *Client) LoadModel(request *agent.ModelOperationMessage) error {
)
if err != nil {
c.sendModelEventError(modelName, modelVersion, agent.ModelEventMessage_LOAD_FAILED, err)
c.cleanup(modelWithVersion)
return err
}
logger.Infof("Chose path %s for model %s:%d", *chosenVersionPath, modelName, modelVersion)
Expand All @@ -606,6 +606,7 @@ func (c *Client) LoadModel(request *agent.ModelOperationMessage) error {
}
if err := backoffWithMaxNumRetry(loaderFn, c.settings.maxLoadRetryCount, c.settings.maxLoadElapsedTime, logger); err != nil {
c.sendModelEventError(modelName, modelVersion, agent.ModelEventMessage_LOAD_FAILED, err)
c.cleanup(modelWithVersion)
return err
}

Expand Down Expand Up @@ -672,6 +673,16 @@ func (c *Client) UnloadModel(request *agent.ModelOperationMessage) error {
return c.sendAgentEvent(modelName, modelVersion, agent.ModelEventMessage_UNLOADED)
}

func (c *Client) cleanup(modelWithVersion string) {
logger := c.logger.WithField("func", "cleanup")
err := c.ModelRepository.RemoveModelVersion(modelWithVersion)
if err != nil {
logger.Errorf("could not remove model %s - %v", modelWithVersion, err)
return
}
logger.Infof("removed model %s", modelWithVersion)
}

func (c *Client) sendModelEventError(
modelName string,
modelVersion uint32,
Expand Down
47 changes: 30 additions & 17 deletions scheduler/pkg/agent/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,22 +54,26 @@ type mockAgentV2Server struct {
}

type FakeModelRepository struct {
err error
err error
modelRemovals int
modelDownloads int
}

func (f FakeModelRepository) RemoveModelVersion(modelName string) error {
func (f *FakeModelRepository) RemoveModelVersion(modelName string) error {
f.modelRemovals++
return nil
}

func (f FakeModelRepository) DownloadModelVersion(modelName string, version uint32, modelSpec *pbs.ModelSpec, config []byte) (*string, error) {
func (f *FakeModelRepository) DownloadModelVersion(modelName string, version uint32, modelSpec *pbs.ModelSpec, config []byte) (*string, error) {
f.modelDownloads++
if f.err != nil {
return nil, f.err
}
path := "path"
return &path, nil
}

func (f FakeModelRepository) Ready() error {
func (f *FakeModelRepository) Ready() error {
return f.err
}

Expand Down Expand Up @@ -197,7 +201,7 @@ func TestClientCreate(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
v2Client := createTestV2Client(addVerionToModels(test.models, 0), test.v2Status)
httpmock.ActivateNonDefault(v2Client.(*testing_utils.V2RestClientForTest).HttpClient)
modelRepository := FakeModelRepository{err: test.modelRepoErr}
modelRepository := &FakeModelRepository{err: test.modelRepoErr}
rpHTTP := FakeDependencyService{err: nil}
rpGRPC := FakeDependencyService{err: nil}
agentDebug := FakeDependencyService{err: nil}
Expand Down Expand Up @@ -271,7 +275,8 @@ func TestLoadModel(t *testing.T) {
replicaConfig: &pb.ReplicaConfig{MemoryBytes: 1000},
expectedAvailableMemory: 500,
v2Status: 200,
success: true}, // Success
success: true,
}, // Success
{
name: "simple - autoscaling enabled",
models: []string{"iris"},
Expand All @@ -291,7 +296,8 @@ func TestLoadModel(t *testing.T) {
expectedAvailableMemory: 500,
v2Status: 200,
success: true,
autoscalingEnabled: true}, // Success
autoscalingEnabled: true,
}, // Success
{
name: "V2Fail",
models: []string{"iris"},
Expand All @@ -309,7 +315,8 @@ func TestLoadModel(t *testing.T) {
replicaConfig: &pb.ReplicaConfig{MemoryBytes: 1000},
expectedAvailableMemory: 1000,
v2Status: 400,
success: false}, // Fail as V2 fail
success: false,
}, // Fail as V2 fail
{
name: "MemoryAvailableFail",
models: []string{"iris"},
Expand All @@ -327,7 +334,8 @@ func TestLoadModel(t *testing.T) {
replicaConfig: &pb.ReplicaConfig{MemoryBytes: 1000},
expectedAvailableMemory: 1000,
v2Status: 200,
success: false}, // Fail due to too much memory required
success: false,
}, // Fail due to too much memory required
}

for tidx, test := range tests {
Expand All @@ -337,7 +345,7 @@ func TestLoadModel(t *testing.T) {
// Set up dependencies
v2Client := createTestV2Client(addVerionToModels(test.models, 0), test.v2Status)
httpmock.ActivateNonDefault(v2Client.(*testing_utils.V2RestClientForTest).HttpClient)
modelRepository := FakeModelRepository{err: test.modelRepoErr}
modelRepository := &FakeModelRepository{err: test.modelRepoErr}
rpHTTP := FakeDependencyService{err: nil}
rpGRPC := FakeDependencyService{err: nil}
agentDebug := FakeDependencyService{err: nil}
Expand Down Expand Up @@ -392,6 +400,7 @@ func TestLoadModel(t *testing.T) {
g.Expect(mockAgentV2Server.loadedEvents).To(Equal(1))
g.Expect(mockAgentV2Server.loadFailedEvents).To(Equal(0))
g.Expect(client.stateManager.GetAvailableMemoryBytes()).To(Equal(test.expectedAvailableMemory))
g.Expect(modelRepository.modelRemovals).To(Equal(0))
loadedVersions := client.stateManager.modelVersions.getVersionsForAllModels()
// we have only one version in the test
g.Expect(proto.Clone(loadedVersions[0])).To(Equal(proto.Clone(test.op.ModelVersion)))
Expand All @@ -414,6 +423,7 @@ func TestLoadModel(t *testing.T) {
g.Expect(mockAgentV2Server.loadedEvents).To(Equal(0))
g.Expect(mockAgentV2Server.loadFailedEvents).To(Equal(1))
g.Expect(client.stateManager.GetAvailableMemoryBytes()).To(Equal(test.expectedAvailableMemory))
g.Expect(modelRepository.modelRemovals).To(Equal(1))
}
client.Stop()
httpmock.DeactivateAndReset()
Expand Down Expand Up @@ -505,7 +515,7 @@ parameters:
t.Logf("Test #%d", tidx)
v2Client := createTestV2Client(addVerionToModels(test.models, 0), test.v2Status)
httpmock.ActivateNonDefault(v2Client.(*testing_utils.V2RestClientForTest).HttpClient)
modelRepository := FakeModelRepository{}
modelRepository := &FakeModelRepository{}
rpHTTP := FakeDependencyService{err: nil}
rpGRPC := FakeDependencyService{err: nil}
agentDebug := FakeDependencyService{err: nil}
Expand Down Expand Up @@ -540,9 +550,11 @@ parameters:
g.Expect(mockAgentV2Server.loadedEvents).To(Equal(1))
g.Expect(mockAgentV2Server.loadFailedEvents).To(Equal(0))
g.Expect(client.stateManager.GetAvailableMemoryBytes()).To(Equal(test.expectedAvailableMemory))
g.Expect(modelRepository.modelRemovals).To(Equal(0))
} else {
g.Expect(err).ToNot(BeNil())
g.Expect(mockAgentV2Server.loadedEvents).To(Equal(0))
g.Expect(modelRepository.modelRemovals).To(Equal(1))
}
client.Stop()
httpmock.DeactivateAndReset()
Expand Down Expand Up @@ -596,7 +608,8 @@ func TestUnloadModel(t *testing.T) {
replicaConfig: &pb.ReplicaConfig{MemoryBytes: 1000},
expectedAvailableMemory: 1000,
v2Status: 200,
success: true}, // Success
success: true,
}, // Success
{
name: "UnknownModel - unload ok",
models: []string{"iris"},
Expand Down Expand Up @@ -625,15 +638,16 @@ func TestUnloadModel(t *testing.T) {
replicaConfig: &pb.ReplicaConfig{MemoryBytes: 1000},
expectedAvailableMemory: 500,
v2Status: 200,
success: true},
success: true,
},
}

for tidx, test := range tests {
t.Run(test.name, func(t *testing.T) {
t.Logf("Test #%d", tidx)
v2Client := createTestV2Client(addVerionToModels(test.models, 0), test.v2Status)
httpmock.ActivateNonDefault(v2Client.(*testing_utils.V2RestClientForTest).HttpClient)
modelRepository := FakeModelRepository{}
modelRepository := &FakeModelRepository{}
rpHTTP := FakeDependencyService{err: nil}
rpGRPC := FakeDependencyService{err: nil}
agentDebug := FakeDependencyService{err: nil}
Expand Down Expand Up @@ -705,7 +719,7 @@ func TestClientClose(t *testing.T) {
v2Client := createTestV2Client(nil, 200)
httpmock.ActivateNonDefault(v2Client.(*testing_utils.V2RestClientForTest).HttpClient)
defer httpmock.DeactivateAndReset()
modelRepository := FakeModelRepository{}
modelRepository := &FakeModelRepository{}
rpHTTP := FakeDependencyService{err: nil}
rpGRPC := FakeDependencyService{err: nil}
agentDebug := FakeDependencyService{err: nil}
Expand Down Expand Up @@ -782,7 +796,6 @@ func TestAgentStopOnSubServicesFailure(t *testing.T) {
maxTimeAfterStart := 1 * time.Millisecond
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {

mockMLServer := &testing_utils.MockGRPCMLServer{}
backEndGRPCPort, err := testing_utils2.GetFreePortForTest()
if err != nil {
Expand All @@ -798,7 +811,7 @@ func TestAgentStopOnSubServicesFailure(t *testing.T) {
v2Client := oip.NewV2Client(
oip.GetV2ConfigWithDefaults("", backEndGRPCPort), log.New())

modelRepository := FakeModelRepository{}
modelRepository := &FakeModelRepository{}
rpHTTP := FakeDependencyService{err: nil}
rpGRPC := FakeDependencyService{err: nil}
agentDebug := FakeDependencyService{err: nil}
Expand Down
Loading