Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ci setup #22

Merged
merged 1 commit into from
Apr 9, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
268 changes: 134 additions & 134 deletions api/api.pb.go

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions api/api.proto
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ package api;
service Manager {
rpc CreateStudy(CreateStudyRequest) returns (CreateStudyReply);
rpc StopStudy(StopStudyRequest) returns (StopStudyReply);
rpc GetStudys(GetStudysRequest) returns (GetStudysReply);
rpc GetStudies(GetStudiesRequest) returns (GetStudiesReply);
rpc SuggestTrials(SuggestTrialsRequest) returns (SuggestTrialsReply);
rpc CompleteTrial(CompleteTrialRequest) returns (CompleteTrialReply);
rpc ShouldTrialStop(ShouldTrialStopRequest) returns (ShouldTrialStopReply);
Expand Down Expand Up @@ -145,7 +145,7 @@ message StopStudyRequest {
message StopStudyReply {
}

message GetStudysRequest {
message GetStudiesRequest {
}

message StudyInfo {
Expand All @@ -156,7 +156,7 @@ message StudyInfo {
int32 completed_trial_num = 5;
}

message GetStudysReply {
message GetStudiesReply {
repeated StudyInfo study_infos= 1;
}

Expand Down
4 changes: 2 additions & 2 deletions cli/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ func (m *ManagerAPI) Stopstudy(conn *grpc.ClientConn, args []string) {

func (m *ManagerAPI) Getstudies(conn *grpc.ClientConn, args []string) {
c := pb.NewManagerClient(conn)
req := &pb.GetStudysRequest{}
r, err := c.GetStudys(context.Background(), req)
req := &pb.GetStudiesRequest{}
r, err := c.GetStudies(context.Background(), req)
if err != nil {
log.Fatalf("GetStudy failed: %v", err)
}
Expand Down
44 changes: 32 additions & 12 deletions db/interface_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
package db

import (
"database/sql"
"fmt"
"github.com/golang/protobuf/jsonpb"
"math/rand"
Expand All @@ -15,17 +14,27 @@ import (
api "github.com/kubeflow/hp-tuning/api"

_ "github.com/go-sql-driver/mysql"
"gopkg.in/DATA-DOG/go-sqlmock.v1"
)

var db_interface VizierDBInterface
var mock sqlmock.Sqlmock

func TestMain(m *testing.M) {
db, err := sql.Open("mysql", "root:test123@tcp(localhost:3306)/vizier")
// db, err := sql.Open("mysql", "root:test123@tcp(localhost:3306)/vizier")
db, sm, err := sqlmock.New()
mock = sm
if err != nil {
fmt.Printf("error opening db: %v\n", err)
os.Exit(1)
}
//mock.ExpectBegin()
db_interface = NewWithSqlConn(db)
mock.ExpectExec("CREATE TABLE IF NOT EXISTS studies").WithArgs().WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectExec("CREATE TABLE IF NOT EXISTS study_permissions").WithArgs().WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectExec("CREATE TABLE IF NOT EXISTS trials").WithArgs().WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectExec("CREATE TABLE IF NOT EXISTS trial_logs").WithArgs().WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectExec("CREATE TABLE IF NOT EXISTS workers").WithArgs().WillReturnResult(sqlmock.NewResult(1, 1))
db_interface.DB_Init()

os.Exit(m.Run())
Expand All @@ -40,13 +49,18 @@ func TestGetStudyConfig(t *testing.T) {
t.Errorf("err %v", err)
}

mock.ExpectExec("INSERT INTO studies VALUES").WithArgs().WillReturnResult(sqlmock.NewResult(1, 1))
id, err := db_interface.CreateStudy(&in)
if err != nil {
t.Fatalf("CreateStudy error %v", err)
t.Errorf("CreateStudy error %v", err)
}
// mock.ExpectExec("SELECT * FROM studies WHERE id").WithArgs(id).WillReturnRows(sqlmock.NewRows())
mock.ExpectQuery("SELECT").WillReturnRows(
sqlmock.NewRows([]string{"id", "name", "owner", "optimization_type", "optimization_goal", "parameter_configs", "suggest_algo", "autostop_algo", "study_task_name", "suggestion_parameters", "tags", "objective_value_name", "metrics", "image", "command", "gpu", "scheduler", "mount", "pull_secret"}).
AddRow("abc", "test", "admin", 1, 0.99, "{}", "random", "test", "", "", "", "", "", "", "", 1, "", "", ""))
study, err := db_interface.GetStudyConfig(id)
if err != nil {
t.Fatalf("GetStudyConfig failed: %v", err)
t.Errorf("GetStudyConfig failed: %v", err)
}
fmt.Printf("%v", study)
// TODO: check study data
Expand All @@ -58,19 +72,25 @@ func TestCreateStudyIdGeneration(t *testing.T) {

var ids []string
for i := 0; i < 4; i++ {
rand.Seed(1)
rand.Seed(int64(i))
mock.ExpectExec("INSERT INTO studies VALUES").WithArgs().WillReturnResult(sqlmock.NewResult(1, 1))
id, err := db_interface.CreateStudy(&in)
if i < 3 {
if err != nil {
t.Errorf("CreateStudy error %v", err)
}
ids = append(ids, id)
} else if err == nil {
t.Fatal("Expected error but succeeded")
if err != nil {
t.Errorf("CreateStudy error %v", err)
}
ids = append(ids, id)
t.Logf("id gen %d %s %v\n", i, id, err)
}
encountered := map[string]bool{}
for i := 0; i < len(ids); i++ {
if !encountered[ids[i]] {
encountered[ids[i]] = true
} else {
t.Fatalf("Study ID duplicated %v", ids)
}
}
for _, id := range ids {
mock.ExpectExec("DELETE").WillReturnResult(sqlmock.NewResult(1, 1))
err := db_interface.DeleteStudy(id)
if err != nil {
t.Errorf("DeleteStudy error %v", err)
Expand Down
89 changes: 47 additions & 42 deletions manager/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,43 +98,45 @@ func (s *server) trialIteration(conf *pb.StudyConfig, study_id string, sCh study
for {
select {
case <-tm.C:
err := s.wIF.CheckRunningTrials(study_id, conf.ObjectiveValueName, conf.Metrics)
if err != nil {
return err
}
r, err := s.SuggestTrials(context.Background(), &pb.SuggestTrialsRequest{StudyId: study_id, SuggestAlgorithm: conf.SuggestAlgorithm, Configs: conf})
if err != nil {
log.Printf("SuggestTrials failed %v", err)
return err
}
if r.Completed {
log.Printf("Study %v completed.", study_id)
//s.saveResult(study_id)
return nil
} else if len(r.Trials) > 0 {
for _, trial := range r.Trials {
trial.Status = pb.TrialState_PENDING
trial.StudyId = study_id
err = dbIf.CreateTrial(trial)
if err != nil {
log.Printf("CreateTrial failed %v", err)
return err
}
if conf.SuggestAlgorithm != "" {
err := s.wIF.CheckRunningTrials(study_id, conf.ObjectiveValueName, conf.Metrics)
if err != nil {
return err
}
err = s.wIF.SpawnWorkers(r.Trials, study_id)
r, err := s.SuggestTrials(context.Background(), &pb.SuggestTrialsRequest{StudyId: study_id, SuggestAlgorithm: conf.SuggestAlgorithm, Configs: conf})
if err != nil {
log.Printf("SpawnWorkers failed %v", err)
log.Printf("SuggestTrials failed %v", err)
return err
}
for _, t := range r.Trials {
err = tbif.SpawnTensorBoard(study_id, t.TrialId, k8s_namespace, conf.Mount)
if r.Completed {
log.Printf("Study %v completed.", study_id)
//s.saveResult(study_id)
return nil
} else if len(r.Trials) > 0 {
for _, trial := range r.Trials {
trial.Status = pb.TrialState_PENDING
trial.StudyId = study_id
err = dbIf.CreateTrial(trial)
if err != nil {
log.Printf("CreateTrial failed %v", err)
return err
}
}
err = s.wIF.SpawnWorkers(r.Trials, study_id)
if err != nil {
log.Printf("SpawnTB failed %v", err)
log.Printf("SpawnWorkers failed %v", err)
return err
}
for _, t := range r.Trials {
err = tbif.SpawnTensorBoard(study_id, t.TrialId, k8s_namespace, conf.Mount)
if err != nil {
log.Printf("SpawnTB failed %v", err)
return err
}
}
}
tm.Reset(1 * time.Second)
}
tm.Reset(1 * time.Second)
case <-sCh.stopCh:
log.Printf("Study %v is stopped.", study_id)
for _, t := range s.wIF.GetRunningTrials(study_id) {
Expand All @@ -154,18 +156,21 @@ func (s *server) CreateStudy(ctx context.Context, in *pb.CreateStudyRequest) (*p
}

study_id, err := dbIf.CreateStudy(in.StudyConfig)

_, err = s.InitializeSuggestService(
ctx,
&pb.InitializeSuggestServiceRequest{
StudyId: study_id,
SuggestAlgorithm: in.StudyConfig.SuggestAlgorithm,
SuggestionParameters: in.StudyConfig.SuggestionParameters,
Configs: in.StudyConfig,
},
)
if err != nil {
return &pb.CreateStudyReply{}, err
if in.StudyConfig.SuggestAlgorithm != "" {
_, err = s.InitializeSuggestService(
ctx,
&pb.InitializeSuggestServiceRequest{
StudyId: study_id,
SuggestAlgorithm: in.StudyConfig.SuggestAlgorithm,
SuggestionParameters: in.StudyConfig.SuggestionParameters,
Configs: in.StudyConfig,
},
)
if err != nil {
return &pb.CreateStudyReply{}, err
}
} else {
log.Printf("Suggestion Algorithm is not set.")
}
sCh := studyCh{stopCh: make(chan bool), addMetricsCh: make(chan string)}
go s.trialIteration(in.StudyConfig, study_id, sCh)
Expand Down Expand Up @@ -206,7 +211,7 @@ func spawn_worker(study_task string, params string) error {
return err
}

func (s *server) GetStudys(ctx context.Context, in *pb.GetStudysRequest) (*pb.GetStudysReply, error) {
func (s *server) GetStudies(ctx context.Context, in *pb.GetStudiesRequest) (*pb.GetStudiesReply, error) {
ss := make([]*pb.StudyInfo, len(s.StudyChList))
i := 0
for sid := range s.StudyChList {
Expand All @@ -220,7 +225,7 @@ func (s *server) GetStudys(ctx context.Context, in *pb.GetStudysRequest) (*pb.Ge
}
i++
}
return &pb.GetStudysReply{StudyInfos: ss}, nil
return &pb.GetStudiesReply{StudyInfos: ss}, nil
}

func (s *server) InitializeSuggestService(ctx context.Context, in *pb.InitializeSuggestServiceRequest) (*pb.InitializeSuggestServiceReply, error) {
Expand Down
113 changes: 113 additions & 0 deletions manager/main_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
package main

import (
"context"
"github.com/golang/mock/gomock"
api "github.com/kubeflow/hp-tuning/api"
//"github.com/kubeflow/hp-tuning/mock/mock_api"
"github.com/kubeflow/hp-tuning/mock/mock_db"
"github.com/kubeflow/hp-tuning/mock/mock_worker_interface"
"testing"
)

func TestCreateStudy(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockDB := mock_db.NewMockVizierDBInterface(ctrl)
mockWif := mock_worker_interface.NewMockWorkerInterface(ctrl)
sid := "teststudy"
sc := &api.StudyConfig{
Name: "test",
Owner: "admin",
OptimizationType: 1,
ObjectiveValueName: "obj_name",
Gpu: 1,
}
dbIf = mockDB
mockDB.EXPECT().CreateStudy(
sc,
).Return(sid, nil)
s := &server{wIF: mockWif, StudyChList: make(map[string]studyCh)}
req := &api.CreateStudyRequest{StudyConfig: sc}
ret, err := s.CreateStudy(context.Background(), req)
if err != nil {
t.Fatalf("CreateStudy Error %v", err)
}
if ret.StudyId != sid {
t.Fatalf("Study ID expect "+sid+", get %s", ret.StudyId)
}
if len(s.StudyChList) != 1 {
t.Fatalf("Study register failed. Registered number is %d", len(s.StudyChList))
} else {
_, ok := s.StudyChList[sid]
if !ok {
t.Fatalf("Study %s is failed to register.", sid)
}
}
}
func TestGetStudies(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockDB := mock_db.NewMockVizierDBInterface(ctrl)
mockWif := mock_worker_interface.NewMockWorkerInterface(ctrl)
sid := []string{"teststudy1", "teststudy2"}
s := &server{wIF: mockWif, StudyChList: map[string]studyCh{sid[0]: studyCh{}, sid[1]: studyCh{}}}
dbIf = mockDB

sc := []*api.StudyConfig{
&api.StudyConfig{
Name: "test1",
Owner: "admin",
OptimizationType: 1,
ObjectiveValueName: "obj_name1",
Gpu: 1,
},
&api.StudyConfig{
Name: "test2",
Owner: "admin",
OptimizationType: 1,
ObjectiveValueName: "obj_name2",
},
}
rts := []int32{10, 20}
cts := []int32{5, 1}
for i := range sid {
mockDB.EXPECT().GetStudyConfig(sid[i]).Return(sc[i], nil)
mockWif.EXPECT().GetRunningTrials(sid[i]).Return(make([]*api.Trial, rts[i]))
mockWif.EXPECT().GetCompletedTrials(sid[i]).Return(make([]*api.Trial, cts[i]))
}

req := &api.GetStudiesRequest{}
ret, err := s.GetStudies(context.Background(), req)
if err != nil {
t.Fatalf("CreateStudy Error %v", err)
}
if len(ret.StudyInfos) != len(sid) {
t.Fatalf("Study Info number %d, expected%d", len(ret.StudyInfos), len(sid))
} else {
var j int
for i := range sid {
switch ret.StudyInfos[i].StudyId {
case sid[0]:
j = 0
case sid[1]:
j = 1
default:
t.Fatalf("GetStudy Error Study ID %s is not expected", ret.StudyInfos[j].StudyId)
}
if ret.StudyInfos[i].Name != sc[j].Name {
t.Fatalf("GetStudy Error Name %s expected %s", ret.StudyInfos[i].Name, sc[j].Name)
}
if ret.StudyInfos[i].Owner != sc[j].Owner {
t.Fatalf("GetStudy Error Owner %s expected %s", ret.StudyInfos[i].Owner, sc[j].Owner)
}
if ret.StudyInfos[i].RunningTrialNum != rts[j] {
t.Fatalf("GetStudy Error RunningTrialNum %d expected %d", ret.StudyInfos[i].RunningTrialNum, rts[j])
}
if ret.StudyInfos[i].CompletedTrialNum != cts[j] {
t.Fatalf("GetStudy Error CompletedTrialNum %d expected %d", ret.StudyInfos[i].CompletedTrialNum, cts[j])
}

}
}
}
Loading