Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix modelsave #52

Merged
merged 4 commits into from
Apr 19, 2018
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dlk/dlkmanager/learningTask.go
Original file line number Diff line number Diff line change
Expand Up @@ -307,12 +307,12 @@ func (lt *learningTask) run() {
}

case <-time.After(1 * time.Second):
lt.pollJobs()
err := lt.checkPodStatus(podState)
if err != nil {
fmt.Println(err.Error())
os.Exit(1)
}
lt.pollJobs()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the motivation for swapping pollJobs() and checkPodStatus()?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If pollJobs is called before checkPodStatus, and the pod writes more logs and then completes after pollJobs has already run, those last logs won't be collected by dlk.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I understood, thanks!


if lt.nrCompletedWorkers == lt.ltc.NrWorker {
state = ltStateCompleted
Expand Down
90 changes: 51 additions & 39 deletions manager/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,54 @@ type server struct {
StudyChList map[string]studyCh
}

// saveCompletedModels saves a model for every COMPLETED trial of the study
// that is not already present in the list returned by GetSavedModels.
//
// studyId selects the trials to inspect; conf supplies the study name used for
// the saved-model lookup and the metric names recorded with each model.
//
// A failure to fetch the saved-model list or the trial list is logged and
// returned. Failures that concern a single trial (status lookup, metric log
// lookup, trial lookup, SaveModel) are logged and only that trial is skipped;
// the function still returns nil in that case. A skipped trial is not marked
// in any way, so the next call will attempt it again.
func (s *server) saveCompletedModels(studyId string, conf *pb.StudyConfig) error {
	ctx := context.Background()

	ret, err := s.GetSavedModels(ctx, &pb.GetSavedModelsRequest{StudyName: conf.Name})
	if err != nil {
		log.Printf("GetSavedModels Err %v", err)
		return err
	}
	ts, err := dbIf.GetTrialList(studyId)
	if err != nil {
		log.Printf("GetTrials Err %v", err)
		return err
	}

trials:
	for _, t := range ts {
		tid := t.TrialId
		tst, err := dbIf.GetTrialStatus(tid)
		if err != nil {
			log.Printf("GetTrialStatus Err %v", err)
			continue
		}
		if tst != pb.TrialState_COMPLETED {
			continue
		}

		// Skip trials whose model has already been saved.
		for _, m := range ret.Models {
			if m.TrialId == tid {
				continue trials
			}
		}

		// Record the most recent logged value of each configured metric.
		met := make([]*pb.Metrics, 0, len(conf.Metrics))
		for _, mn := range conf.Metrics {
			logs, err := dbIf.GetTrialLogs(tid, &kdb.GetTrialLogOpts{Name: mn})
			if err != nil {
				// Lookup failed: leave the trial unsaved so a later call retries it.
				log.Printf("GetTrialLogs Err %v", err)
				continue trials
			}
			if len(logs) == 0 {
				// The metric list can grow while the study runs, so a completed
				// trial may have no entries for a metric. Omit the metric rather
				// than indexing an empty slice.
				log.Printf("Trial %v has no logs for metric %v", tid, mn)
				continue
			}
			met = append(met, &pb.Metrics{Name: mn, Value: logs[len(logs)-1].Value})
		}

		trial, err := dbIf.GetTrial(tid)
		if err != nil {
			log.Printf("GetTrial Err %v", err)
			continue
		}

		_, err = s.SaveModel(ctx, &pb.SaveModelRequest{
			Model: &pb.ModelInfo{
				StudyName:  conf.Name,
				TrialId:    tid,
				Parameters: trial.ParameterSet,
				Metrics:    met,
			},
		})
		if err != nil {
			// Do not report the trial as saved when the save call failed.
			log.Printf("SaveModel Err %v", err)
			continue
		}
		log.Printf("Trial %v in Study %v is saved", tid, conf.Name)
	}
	return nil
}

func (s *server) trialIteration(conf *pb.StudyConfig, study_id string, sCh studyCh) error {
defer delete(s.StudyChList, study_id)
defer s.wIF.CleanWorkers(study_id)
Expand Down Expand Up @@ -85,7 +133,7 @@ func (s *server) trialIteration(conf *pb.StudyConfig, study_id string, sCh study
}
if r.Completed {
log.Printf("Study %v completed.", study_id)
return nil
return s.saveCompletedModels(study_id, conf)
} else if len(r.Trials) > 0 {
for _, trial := range r.Trials {
trial.Status = pb.TrialState_PENDING
Expand All @@ -112,43 +160,7 @@ func (s *server) trialIteration(conf *pb.StudyConfig, study_id string, sCh study
tm.Reset(1 * time.Second)
}
case <-strtm.C:
ret, err := s.GetSavedModels(context.Background(), &pb.GetSavedModelsRequest{StudyName: conf.Name})
if err != nil {
log.Printf("GetSavedModels Err %v", err)
}
ts, err := dbIf.GetTrialList(study_id)
if err != nil {
log.Printf("GetTrials Err %v", err)
}
for _, t := range ts {
tid := t.TrialId
tst, err := dbIf.GetTrialStatus(tid)
if err != nil {
log.Printf("GetTrialStatus Err %v", err)
continue
}
if tst == pb.TrialState_COMPLETED {
for _, m := range ret.Models {
if m.TrialId == tid {
met := make([]*pb.Metrics, len(conf.Metrics))
for i, mn := range conf.Metrics {
l, _ := dbIf.GetTrialLogs(tid, &kdb.GetTrialLogOpts{Name: mn})
met[i] = &pb.Metrics{Name: mn, Value: l[len(l)-1].Value}
}
t, _ := dbIf.GetTrial(tid)
s.SaveModel(context.Background(), &pb.SaveModelRequest{
Model: &pb.ModelInfo{
StudyName: conf.Name,
TrialId: tid,
Parameters: t.ParameterSet,
Metrics: met,
},
})
break
}
}
}
}
s.saveCompletedModels(study_id, conf)
strtm.Reset(defaultSaveInterval * time.Second)

case <-estm.C:
Expand All @@ -168,7 +180,7 @@ func (s *server) trialIteration(conf *pb.StudyConfig, study_id string, sCh study
for _, t := range s.wIF.GetRunningTrials(study_id) {
t.Status = pb.TrialState_KILLED
}
return nil
return s.saveCompletedModels(study_id, conf)
case m := <-sCh.addMetricsCh:
conf.Metrics = append(conf.Metrics, m)
}
Expand Down