Skip to content

Commit

Permalink
add test script and argo files
Browse files Browse the repository at this point in the history
Signed-off-by: YujiOshima <[email protected]>
  • Loading branch information
YujiOshima committed Apr 9, 2018
1 parent cfbfbc1 commit 02d3d17
Show file tree
Hide file tree
Showing 37 changed files with 228,250 additions and 257 deletions.
268 changes: 134 additions & 134 deletions api/api.pb.go

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions api/api.proto
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ package api;
service Manager {
rpc CreateStudy(CreateStudyRequest) returns (CreateStudyReply);
rpc StopStudy(StopStudyRequest) returns (StopStudyReply);
rpc GetStudys(GetStudysRequest) returns (GetStudysReply);
rpc GetStudies(GetStudiesRequest) returns (GetStudiesReply);
rpc SuggestTrials(SuggestTrialsRequest) returns (SuggestTrialsReply);
rpc CompleteTrial(CompleteTrialRequest) returns (CompleteTrialReply);
rpc ShouldTrialStop(ShouldTrialStopRequest) returns (ShouldTrialStopReply);
Expand Down Expand Up @@ -145,7 +145,7 @@ message StopStudyRequest {
message StopStudyReply {
}

message GetStudysRequest {
message GetStudiesRequest {
}

message StudyInfo {
Expand All @@ -156,7 +156,7 @@ message StudyInfo {
int32 completed_trial_num = 5;
}

message GetStudysReply {
message GetStudiesReply {
repeated StudyInfo study_infos= 1;
}

Expand Down
4 changes: 2 additions & 2 deletions cli/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ func (m *ManagerAPI) Stopstudy(conn *grpc.ClientConn, args []string) {

func (m *ManagerAPI) Getstudies(conn *grpc.ClientConn, args []string) {
c := pb.NewManagerClient(conn)
req := &pb.GetStudysRequest{}
r, err := c.GetStudys(context.Background(), req)
req := &pb.GetStudiesRequest{}
r, err := c.GetStudies(context.Background(), req)
if err != nil {
log.Fatalf("GetStudy failed: %v", err)
}
Expand Down
44 changes: 32 additions & 12 deletions db/interface_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
package db

import (
"database/sql"
"fmt"
"github.com/golang/protobuf/jsonpb"
"math/rand"
Expand All @@ -15,17 +14,27 @@ import (
api "github.com/kubeflow/hp-tuning/api"

_ "github.com/go-sql-driver/mysql"
"gopkg.in/DATA-DOG/go-sqlmock.v1"
)

var db_interface VizierDBInterface
var mock sqlmock.Sqlmock

func TestMain(m *testing.M) {
db, err := sql.Open("mysql", "root:test123@tcp(localhost:3306)/vizier")
// db, err := sql.Open("mysql", "root:test123@tcp(localhost:3306)/vizier")
db, sm, err := sqlmock.New()
mock = sm
if err != nil {
fmt.Printf("error opening db: %v\n", err)
os.Exit(1)
}
//mock.ExpectBegin()
db_interface = NewWithSqlConn(db)
mock.ExpectExec("CREATE TABLE IF NOT EXISTS studies").WithArgs().WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectExec("CREATE TABLE IF NOT EXISTS study_permissions").WithArgs().WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectExec("CREATE TABLE IF NOT EXISTS trials").WithArgs().WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectExec("CREATE TABLE IF NOT EXISTS trial_logs").WithArgs().WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectExec("CREATE TABLE IF NOT EXISTS workers").WithArgs().WillReturnResult(sqlmock.NewResult(1, 1))
db_interface.DB_Init()

os.Exit(m.Run())
Expand All @@ -40,13 +49,18 @@ func TestGetStudyConfig(t *testing.T) {
t.Errorf("err %v", err)
}

mock.ExpectExec("INSERT INTO studies VALUES").WithArgs().WillReturnResult(sqlmock.NewResult(1, 1))
id, err := db_interface.CreateStudy(&in)
if err != nil {
t.Fatalf("CreateStudy error %v", err)
t.Errorf("CreateStudy error %v", err)
}
// mock.ExpectExec("SELECT * FROM studies WHERE id").WithArgs(id).WillReturnRows(sqlmock.NewRows())
mock.ExpectQuery("SELECT").WillReturnRows(
sqlmock.NewRows([]string{"id", "name", "owner", "optimization_type", "optimization_goal", "parameter_configs", "suggest_algo", "autostop_algo", "study_task_name", "suggestion_parameters", "tags", "objective_value_name", "metrics", "image", "command", "gpu", "scheduler", "mount", "pull_secret"}).
AddRow("abc", "test", "admin", 1, 0.99, "{}", "random", "test", "", "", "", "", "", "", "", 1, "", "", ""))
study, err := db_interface.GetStudyConfig(id)
if err != nil {
t.Fatalf("GetStudyConfig failed: %v", err)
t.Errorf("GetStudyConfig failed: %v", err)
}
fmt.Printf("%v", study)
// TODO: check study data
Expand All @@ -58,19 +72,25 @@ func TestCreateStudyIdGeneration(t *testing.T) {

var ids []string
for i := 0; i < 4; i++ {
rand.Seed(1)
rand.Seed(int64(i))
mock.ExpectExec("INSERT INTO studies VALUES").WithArgs().WillReturnResult(sqlmock.NewResult(1, 1))
id, err := db_interface.CreateStudy(&in)
if i < 3 {
if err != nil {
t.Errorf("CreateStudy error %v", err)
}
ids = append(ids, id)
} else if err == nil {
t.Fatal("Expected error but succeeded")
if err != nil {
t.Errorf("CreateStudy error %v", err)
}
ids = append(ids, id)
t.Logf("id gen %d %s %v\n", i, id, err)
}
encountered := map[string]bool{}
for i := 0; i < len(ids); i++ {
if !encountered[ids[i]] {
encountered[ids[i]] = true
} else {
t.Fatalf("Study ID duplicated %v", ids)
}
}
for _, id := range ids {
mock.ExpectExec("DELETE").WillReturnResult(sqlmock.NewResult(1, 1))
err := db_interface.DeleteStudy(id)
if err != nil {
t.Errorf("DeleteStudy error %v", err)
Expand Down
89 changes: 47 additions & 42 deletions manager/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,43 +98,45 @@ func (s *server) trialIteration(conf *pb.StudyConfig, study_id string, sCh study
for {
select {
case <-tm.C:
err := s.wIF.CheckRunningTrials(study_id, conf.ObjectiveValueName, conf.Metrics)
if err != nil {
return err
}
r, err := s.SuggestTrials(context.Background(), &pb.SuggestTrialsRequest{StudyId: study_id, SuggestAlgorithm: conf.SuggestAlgorithm, Configs: conf})
if err != nil {
log.Printf("SuggestTrials failed %v", err)
return err
}
if r.Completed {
log.Printf("Study %v completed.", study_id)
//s.saveResult(study_id)
return nil
} else if len(r.Trials) > 0 {
for _, trial := range r.Trials {
trial.Status = pb.TrialState_PENDING
trial.StudyId = study_id
err = dbIf.CreateTrial(trial)
if err != nil {
log.Printf("CreateTrial failed %v", err)
return err
}
if conf.SuggestAlgorithm != "" {
err := s.wIF.CheckRunningTrials(study_id, conf.ObjectiveValueName, conf.Metrics)
if err != nil {
return err
}
err = s.wIF.SpawnWorkers(r.Trials, study_id)
r, err := s.SuggestTrials(context.Background(), &pb.SuggestTrialsRequest{StudyId: study_id, SuggestAlgorithm: conf.SuggestAlgorithm, Configs: conf})
if err != nil {
log.Printf("SpawnWorkers failed %v", err)
log.Printf("SuggestTrials failed %v", err)
return err
}
for _, t := range r.Trials {
err = tbif.SpawnTensorBoard(study_id, t.TrialId, k8s_namespace, conf.Mount)
if r.Completed {
log.Printf("Study %v completed.", study_id)
//s.saveResult(study_id)
return nil
} else if len(r.Trials) > 0 {
for _, trial := range r.Trials {
trial.Status = pb.TrialState_PENDING
trial.StudyId = study_id
err = dbIf.CreateTrial(trial)
if err != nil {
log.Printf("CreateTrial failed %v", err)
return err
}
}
err = s.wIF.SpawnWorkers(r.Trials, study_id)
if err != nil {
log.Printf("SpawnTB failed %v", err)
log.Printf("SpawnWorkers failed %v", err)
return err
}
for _, t := range r.Trials {
err = tbif.SpawnTensorBoard(study_id, t.TrialId, k8s_namespace, conf.Mount)
if err != nil {
log.Printf("SpawnTB failed %v", err)
return err
}
}
}
tm.Reset(1 * time.Second)
}
tm.Reset(1 * time.Second)
case <-sCh.stopCh:
log.Printf("Study %v is stopped.", study_id)
for _, t := range s.wIF.GetRunningTrials(study_id) {
Expand All @@ -154,18 +156,21 @@ func (s *server) CreateStudy(ctx context.Context, in *pb.CreateStudyRequest) (*p
}

study_id, err := dbIf.CreateStudy(in.StudyConfig)

_, err = s.InitializeSuggestService(
ctx,
&pb.InitializeSuggestServiceRequest{
StudyId: study_id,
SuggestAlgorithm: in.StudyConfig.SuggestAlgorithm,
SuggestionParameters: in.StudyConfig.SuggestionParameters,
Configs: in.StudyConfig,
},
)
if err != nil {
return &pb.CreateStudyReply{}, err
if in.StudyConfig.SuggestAlgorithm != "" {
_, err = s.InitializeSuggestService(
ctx,
&pb.InitializeSuggestServiceRequest{
StudyId: study_id,
SuggestAlgorithm: in.StudyConfig.SuggestAlgorithm,
SuggestionParameters: in.StudyConfig.SuggestionParameters,
Configs: in.StudyConfig,
},
)
if err != nil {
return &pb.CreateStudyReply{}, err
}
} else {
log.Printf("Suggestion Algorithm is not set.")
}
sCh := studyCh{stopCh: make(chan bool), addMetricsCh: make(chan string)}
go s.trialIteration(in.StudyConfig, study_id, sCh)
Expand Down Expand Up @@ -206,7 +211,7 @@ func spawn_worker(study_task string, params string) error {
return err
}

func (s *server) GetStudys(ctx context.Context, in *pb.GetStudysRequest) (*pb.GetStudysReply, error) {
func (s *server) GetStudies(ctx context.Context, in *pb.GetStudiesRequest) (*pb.GetStudiesReply, error) {
ss := make([]*pb.StudyInfo, len(s.StudyChList))
i := 0
for sid := range s.StudyChList {
Expand All @@ -220,7 +225,7 @@ func (s *server) GetStudys(ctx context.Context, in *pb.GetStudysRequest) (*pb.Ge
}
i++
}
return &pb.GetStudysReply{StudyInfos: ss}, nil
return &pb.GetStudiesReply{StudyInfos: ss}, nil
}

func (s *server) InitializeSuggestService(ctx context.Context, in *pb.InitializeSuggestServiceRequest) (*pb.InitializeSuggestServiceReply, error) {
Expand Down
113 changes: 113 additions & 0 deletions manager/main_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
package main

import (
"context"
"github.com/golang/mock/gomock"
api "github.com/kubeflow/hp-tuning/api"
//"github.com/kubeflow/hp-tuning/mock/mock_api"
"github.com/kubeflow/hp-tuning/mock/mock_db"
"github.com/kubeflow/hp-tuning/mock/mock_worker_interface"
"testing"
)

func TestCreateStudy(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockDB := mock_db.NewMockVizierDBInterface(ctrl)
mockWif := mock_worker_interface.NewMockWorkerInterface(ctrl)
sid := "teststudy"
sc := &api.StudyConfig{
Name: "test",
Owner: "admin",
OptimizationType: 1,
ObjectiveValueName: "obj_name",
Gpu: 1,
}
dbIf = mockDB
mockDB.EXPECT().CreateStudy(
sc,
).Return(sid, nil)
s := &server{wIF: mockWif, StudyChList: make(map[string]studyCh)}
req := &api.CreateStudyRequest{StudyConfig: sc}
ret, err := s.CreateStudy(context.Background(), req)
if err != nil {
t.Fatalf("CreateStudy Error %v", err)
}
if ret.StudyId != sid {
t.Fatalf("Study ID expect "+sid+", get %s", ret.StudyId)
}
if len(s.StudyChList) != 1 {
t.Fatalf("Study register failed. Registered number is %d", len(s.StudyChList))
} else {
_, ok := s.StudyChList[sid]
if !ok {
t.Fatalf("Study %s is failed to register.", sid)
}
}
}
func TestGetStudies(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockDB := mock_db.NewMockVizierDBInterface(ctrl)
mockWif := mock_worker_interface.NewMockWorkerInterface(ctrl)
sid := []string{"teststudy1", "teststudy2"}
s := &server{wIF: mockWif, StudyChList: map[string]studyCh{sid[0]: studyCh{}, sid[1]: studyCh{}}}
dbIf = mockDB

sc := []*api.StudyConfig{
&api.StudyConfig{
Name: "test1",
Owner: "admin",
OptimizationType: 1,
ObjectiveValueName: "obj_name1",
Gpu: 1,
},
&api.StudyConfig{
Name: "test2",
Owner: "admin",
OptimizationType: 1,
ObjectiveValueName: "obj_name2",
},
}
rts := []int32{10, 20}
cts := []int32{5, 1}
for i := range sid {
mockDB.EXPECT().GetStudyConfig(sid[i]).Return(sc[i], nil)
mockWif.EXPECT().GetRunningTrials(sid[i]).Return(make([]*api.Trial, rts[i]))
mockWif.EXPECT().GetCompletedTrials(sid[i]).Return(make([]*api.Trial, cts[i]))
}

req := &api.GetStudiesRequest{}
ret, err := s.GetStudies(context.Background(), req)
if err != nil {
t.Fatalf("CreateStudy Error %v", err)
}
if len(ret.StudyInfos) != len(sid) {
t.Fatalf("Study Info number %d, expected%d", len(ret.StudyInfos), len(sid))
} else {
var j int
for i := range sid {
switch ret.StudyInfos[i].StudyId {
case sid[0]:
j = 0
case sid[1]:
j = 1
default:
t.Fatalf("GetStudy Error Study ID %s is not expected", ret.StudyInfos[j].StudyId)
}
if ret.StudyInfos[i].Name != sc[j].Name {
t.Fatalf("GetStudy Error Name %s expected %s", ret.StudyInfos[i].Name, sc[j].Name)
}
if ret.StudyInfos[i].Owner != sc[j].Owner {
t.Fatalf("GetStudy Error Owner %s expected %s", ret.StudyInfos[i].Owner, sc[j].Owner)
}
if ret.StudyInfos[i].RunningTrialNum != rts[j] {
t.Fatalf("GetStudy Error RunningTrialNum %d expected %d", ret.StudyInfos[i].RunningTrialNum, rts[j])
}
if ret.StudyInfos[i].CompletedTrialNum != cts[j] {
t.Fatalf("GetStudy Error CompletedTrialNum %d expected %d", ret.StudyInfos[i].CompletedTrialNum, cts[j])
}

}
}
}
Loading

0 comments on commit 02d3d17

Please sign in to comment.