Skip to content

Commit

Permalink
refactor: rename asha to sha (#733)
Browse files Browse the repository at this point in the history
* refactor: rename asha to sha

* Apply suggestions from code review

Co-authored-by: Danny Zhu <[email protected]>

Co-authored-by: Danny Zhu <[email protected]>
  • Loading branch information
liamcli and dzhu authored Jun 17, 2020
1 parent c75e693 commit e85c8bc
Show file tree
Hide file tree
Showing 7 changed files with 45 additions and 45 deletions.
2 changes: 1 addition & 1 deletion master/pkg/model/defaults.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func DefaultExperimentConfig() ExperimentConfig {
Hyperparameters: make(map[string]Hyperparameter),
Searcher: SearcherConfig{
SmallerIsBetter: true,
AsyncHalvingConfig: &AsyncHalvingConfig{
SyncHalvingConfig: &SyncHalvingConfig{
SmallerIsBetter: true,
Divisor: 4,
TrainStragglers: true,
Expand Down
6 changes: 3 additions & 3 deletions master/pkg/model/searcher_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ type SearcherConfig struct {
SingleConfig *SingleConfig `union:"name,single" json:"-"`
RandomConfig *RandomConfig `union:"name,random" json:"-"`
GridConfig *GridConfig `union:"name,grid" json:"-"`
AsyncHalvingConfig *AsyncHalvingConfig `union:"name,async_halving" json:"-"`
SyncHalvingConfig *SyncHalvingConfig `union:"name,sync_halving" json:"-"`
AdaptiveConfig *AdaptiveConfig `union:"name,adaptive" json:"-"`
AdaptiveSimpleConfig *AdaptiveSimpleConfig `union:"name,adaptive_simple" json:"-"`
PBTConfig *PBTConfig `union:"name,pbt" json:"-"`
Expand Down Expand Up @@ -81,8 +81,8 @@ func (g GridConfig) Validate() []error {
}
}

// AsyncHalvingConfig configures asynchronous successive halving.
type AsyncHalvingConfig struct {
// SyncHalvingConfig configures synchronous successive halving.
type SyncHalvingConfig struct {
Metric string `json:"metric"`
SmallerIsBetter bool `json:"smaller_is_better"`
NumRungs int `json:"num_rungs"`
Expand Down
4 changes: 2 additions & 2 deletions master/pkg/searcher/adaptive.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ func newAdaptiveSearch(config model.AdaptiveConfig) SearchMethod {

methods := make([]SearchMethod, 0, len(brackets))
for _, numRungs := range brackets {
c := model.AsyncHalvingConfig{
c := model.SyncHalvingConfig{
Metric: config.Metric,
SmallerIsBetter: config.SmallerIsBetter,
NumRungs: numRungs,
Expand All @@ -27,7 +27,7 @@ func newAdaptiveSearch(config model.AdaptiveConfig) SearchMethod {
Divisor: config.Divisor,
TrainStragglers: config.TrainStragglers,
}
methods = append(methods, newAsyncHalvingSearch(c))
methods = append(methods, newSyncHalvingSearch(c))
}

return newTournamentSearch(methods...)
Expand Down
18 changes: 9 additions & 9 deletions master/pkg/searcher/adaptive_simple.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ func newAdaptiveSimpleSearch(config model.AdaptiveSimpleConfig) SearchMethod {

methods := make([]SearchMethod, 0, len(brackets))
for i, numRungs := range brackets {
c := model.AsyncHalvingConfig{
c := model.SyncHalvingConfig{
Metric: config.Metric,
SmallerIsBetter: config.SmallerIsBetter,
TargetTrialSteps: config.MaxSteps,
Expand All @@ -30,13 +30,13 @@ func newAdaptiveSimpleSearch(config model.AdaptiveSimpleConfig) SearchMethod {
TrainStragglers: true,
}
numTrials := max(maxTrials(config.MaxTrials, len(brackets), i), 1)
methods = append(methods, newAsyncHalvingSimpleSearch(c, numTrials))
methods = append(methods, newSyncHalvingSimpleSearch(c, numTrials))
}

return newTournamentSearch(methods...)
}

func newAsyncHalvingSimpleSearch(config model.AsyncHalvingConfig, trials int) SearchMethod {
func newSyncHalvingSimpleSearch(config model.SyncHalvingConfig, trials int) SearchMethod {
rungs := make([]*rung, 0, config.NumRungs)
expectedSteps := 0
expectedWorkloads := 0
Expand All @@ -63,11 +63,11 @@ func newAsyncHalvingSimpleSearch(config model.AsyncHalvingConfig, trials int) Se
)
}
config.StepBudget = expectedSteps
return &asyncHalvingSearch{
AsyncHalvingConfig: config,
rungs: rungs,
trialRungs: make(map[RequestID]int),
earlyExitTrials: make(map[RequestID]bool),
expectedWorkloads: expectedWorkloads,
return &syncHalvingSearch{
SyncHalvingConfig: config,
rungs: rungs,
trialRungs: make(map[RequestID]int),
earlyExitTrials: make(map[RequestID]bool),
expectedWorkloads: expectedWorkloads,
}
}
4 changes: 2 additions & 2 deletions master/pkg/searcher/search_method.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ func NewSearchMethod(c model.SearcherConfig) SearchMethod {
return newRandomSearch(*c.RandomConfig)
case c.GridConfig != nil:
return newGridSearch(*c.GridConfig)
case c.AsyncHalvingConfig != nil:
return newAsyncHalvingSearch(*c.AsyncHalvingConfig)
case c.SyncHalvingConfig != nil:
return newSyncHalvingSearch(*c.SyncHalvingConfig)
case c.AdaptiveConfig != nil:
return newAdaptiveSearch(*c.AdaptiveConfig)
case c.AdaptiveSimpleConfig != nil:
Expand Down
40 changes: 20 additions & 20 deletions master/pkg/searcher/asha.go → master/pkg/searcher/sha.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@ import (
"github.com/determined-ai/determined/master/pkg/model"
)

// asyncHalvingSearch implements a search using the asynchronous successive halving algorithm
// (ASHA). Technically, this is closer to SHA than ASHA as the promotions are synchronous.
type asyncHalvingSearch struct {
// syncHalvingSearch implements a search using the synchronous successive halving algorithm
// (SHA).
type syncHalvingSearch struct {
defaultSearchMethod
model.AsyncHalvingConfig
model.SyncHalvingConfig

rungs []*rung
trialRungs map[RequestID]int
Expand All @@ -23,9 +23,9 @@ type asyncHalvingSearch struct {
trialsCompleted int
}

const ashaExitedMetricValue = math.MaxFloat64
const shaExitedMetricValue = math.MaxFloat64

func newAsyncHalvingSearch(config model.AsyncHalvingConfig) SearchMethod {
func newSyncHalvingSearch(config model.SyncHalvingConfig) SearchMethod {
rungs := make([]*rung, 0, config.NumRungs)
expectedSteps := 0
for id := 0; id < config.NumRungs; id++ {
Expand Down Expand Up @@ -61,12 +61,12 @@ func newAsyncHalvingSearch(config model.AsyncHalvingConfig) SearchMethod {
}
}

return &asyncHalvingSearch{
AsyncHalvingConfig: config,
rungs: rungs,
trialRungs: make(map[RequestID]int),
earlyExitTrials: make(map[RequestID]bool),
expectedWorkloads: expectedWorkloads,
return &syncHalvingSearch{
SyncHalvingConfig: config,
rungs: rungs,
trialRungs: make(map[RequestID]int),
earlyExitTrials: make(map[RequestID]bool),
expectedWorkloads: expectedWorkloads,
}
}

Expand Down Expand Up @@ -111,7 +111,7 @@ func (r *rung) promotions(requestID RequestID, metric float64) []RequestID {
}
}

func (s *asyncHalvingSearch) initialOperations(ctx context) ([]Operation, error) {
func (s *syncHalvingSearch) initialOperations(ctx context) ([]Operation, error) {
var ops []Operation
for trial := 0; trial < s.rungs[0].startTrials; trial++ {
create := NewCreate(
Expand All @@ -122,13 +122,13 @@ func (s *asyncHalvingSearch) initialOperations(ctx context) ([]Operation, error)
return ops, nil
}

func (s *asyncHalvingSearch) trainCompleted(
func (s *syncHalvingSearch) trainCompleted(
ctx context, requestID RequestID, message Workload,
) ([]Operation, error) {
return nil, nil
}

func (s *asyncHalvingSearch) validationCompleted(
func (s *syncHalvingSearch) validationCompleted(
ctx context, requestID RequestID, message Workload, metrics ValidationMetrics,
) ([]Operation, error) {
// Extract the relevant metric as a float.
Expand All @@ -143,7 +143,7 @@ func (s *asyncHalvingSearch) validationCompleted(
return s.promoteTrials(ctx, requestID, message, metric)
}

func (s *asyncHalvingSearch) promoteTrials(
func (s *syncHalvingSearch) promoteTrials(
ctx context, requestID RequestID, message Workload, metric float64,
) ([]Operation, error) {
rungIndex := s.trialRungs[requestID]
Expand Down Expand Up @@ -183,7 +183,7 @@ func (s *asyncHalvingSearch) promoteTrials(
//
// 2) We are bounded on the depth of this recursive stack by
// the number of rungs. We default this to max out at 5.
_, err := s.promoteTrials(ctx, promotionID, wkld, ashaExitedMetricValue)
_, err := s.promoteTrials(ctx, promotionID, wkld, shaExitedMetricValue)
return nil, err
}
}
Expand All @@ -204,15 +204,15 @@ func (s *asyncHalvingSearch) promoteTrials(
return ops, nil
}

func (s *asyncHalvingSearch) progress(workloadsCompleted int) float64 {
func (s *syncHalvingSearch) progress(workloadsCompleted int) float64 {
return math.Min(1, float64(workloadsCompleted)/float64(s.expectedWorkloads))
}

func (s *asyncHalvingSearch) trialExitedEarly(
func (s *syncHalvingSearch) trialExitedEarly(
ctx context, requestID RequestID, message Workload,
) ([]Operation, error) {
s.earlyExitTrials[requestID] = true
return s.promoteTrials(ctx, requestID, message, ashaExitedMetricValue)
return s.promoteTrials(ctx, requestID, message, shaExitedMetricValue)
}

func max(initial int, values ...int) int {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ import (
"github.com/determined-ai/determined/master/pkg/model"
)

func TestASHASearcher(t *testing.T) {
actual := model.AsyncHalvingConfig{
func TestSHASearcher(t *testing.T) {
actual := model.SyncHalvingConfig{
Metric: defaultMetric,
NumRungs: 4,
TargetTrialSteps: 800,
Expand All @@ -22,10 +22,10 @@ func TestASHASearcher(t *testing.T) {
toKinds("12S 1V 38S 1V"),
toKinds("12S 1V 38S 1V 150S 1V 600S 1V"),
}
checkSimulation(t, newAsyncHalvingSearch(actual), nil, ConstantValidation, expected)
checkSimulation(t, newSyncHalvingSearch(actual), nil, ConstantValidation, expected)
}

func TestASHASearchMethod(t *testing.T) {
func TestSHASearchMethod(t *testing.T) {
testCases := []valueSimulationTestCase{
{
name: "smaller is better",
Expand All @@ -43,7 +43,7 @@ func TestASHASearchMethod(t *testing.T) {
newConstantPredefinedTrial(0.11, 12, []int{12}, nil),
},
config: model.SearcherConfig{
AsyncHalvingConfig: &model.AsyncHalvingConfig{
SyncHalvingConfig: &model.SyncHalvingConfig{
Metric: "error",
NumRungs: 4,
SmallerIsBetter: true,
Expand All @@ -70,7 +70,7 @@ func TestASHASearchMethod(t *testing.T) {
newEarlyExitPredefinedTrial(0.11, 11, nil, nil),
},
config: model.SearcherConfig{
AsyncHalvingConfig: &model.AsyncHalvingConfig{
SyncHalvingConfig: &model.SyncHalvingConfig{
Metric: "error",
NumRungs: 4,
SmallerIsBetter: true,
Expand All @@ -97,7 +97,7 @@ func TestASHASearchMethod(t *testing.T) {
newConstantPredefinedTrial(0.01, 12, []int{12}, nil),
},
config: model.SearcherConfig{
AsyncHalvingConfig: &model.AsyncHalvingConfig{
SyncHalvingConfig: &model.SyncHalvingConfig{
Metric: "error",
NumRungs: 4,
SmallerIsBetter: false,
Expand All @@ -124,7 +124,7 @@ func TestASHASearchMethod(t *testing.T) {
newEarlyExitPredefinedTrial(0.01, 11, nil, nil),
},
config: model.SearcherConfig{
AsyncHalvingConfig: &model.AsyncHalvingConfig{
SyncHalvingConfig: &model.SyncHalvingConfig{
Metric: "error",
NumRungs: 4,
SmallerIsBetter: false,
Expand Down

0 comments on commit e85c8bc

Please sign in to comment.