Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added a check if config key is present in yaml file before fetching value #975

Merged
merged 4 commits into from
Oct 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 58 additions & 16 deletions cmd/config-utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package cmd

import (
"github.com/spf13/viper"
"razor/core"
"razor/core/types"
"razor/utils"
"strings"
Expand Down Expand Up @@ -69,10 +70,15 @@ func (*UtilsStruct) GetConfigData() (types.Configurations, error) {
func (*UtilsStruct) GetProvider() (string, error) {
provider, err := flagSetUtils.GetRootStringProvider()
if err != nil {
return "", err
return core.DefaultProvider, err
}
if provider == "" {
provider = viper.GetString("provider")
if viper.IsSet("provider") {
provider = viper.GetString("provider")
} else {
provider = core.DefaultProvider
log.Debug("Provider is not set, taking its default value ", provider)
}
}
if !strings.HasPrefix(provider, "https") {
log.Warn("You are not using a secure RPC URL. Switch to an https URL instead to be safe.")
Expand All @@ -84,10 +90,15 @@ func (*UtilsStruct) GetProvider() (string, error) {
func (*UtilsStruct) GetMultiplier() (float32, error) {
gasMultiplier, err := flagSetUtils.GetRootFloat32GasMultiplier()
if err != nil {
return 1, err
return float32(core.DefaultGasMultiplier), err
}
if gasMultiplier == -1 {
gasMultiplier = float32(viper.GetFloat64("gasmultiplier"))
if viper.IsSet("gasmultiplier") {
gasMultiplier = float32(viper.GetFloat64("gasmultiplier"))
} else {
gasMultiplier = float32(core.DefaultGasMultiplier)
log.Debug("GasMultiplier is not set, taking its default value ", gasMultiplier)
}
}
return gasMultiplier, nil
}
Expand All @@ -96,10 +107,15 @@ func (*UtilsStruct) GetMultiplier() (float32, error) {
func (*UtilsStruct) GetBufferPercent() (int32, error) {
bufferPercent, err := flagSetUtils.GetRootInt32Buffer()
if err != nil {
return 30, err
return int32(core.DefaultBufferPercent), err
}
if bufferPercent == 0 {
bufferPercent = viper.GetInt32("buffer")
if viper.IsSet("buffer") {
bufferPercent = viper.GetInt32("buffer")
} else {
bufferPercent = int32(core.DefaultBufferPercent)
log.Debug("BufferPercent is not set, taking its default value ", bufferPercent)
}
}
return bufferPercent, nil
}
Expand All @@ -108,10 +124,15 @@ func (*UtilsStruct) GetBufferPercent() (int32, error) {
func (*UtilsStruct) GetWaitTime() (int32, error) {
waitTime, err := flagSetUtils.GetRootInt32Wait()
if err != nil {
return 3, err
return int32(core.DefaultWaitTime), err
}
if waitTime == -1 {
waitTime = viper.GetInt32("wait")
if viper.IsSet("wait") {
waitTime = viper.GetInt32("wait")
} else {
waitTime = int32(core.DefaultWaitTime)
log.Debug("WaitTime is not set, taking its default value ", waitTime)
}
}
return waitTime, nil
}
Expand All @@ -120,10 +141,16 @@ func (*UtilsStruct) GetWaitTime() (int32, error) {
func (*UtilsStruct) GetGasPrice() (int32, error) {
gasPrice, err := flagSetUtils.GetRootInt32GasPrice()
if err != nil {
return 0, err
return int32(core.DefaultGasPrice), err
}
if gasPrice == -1 {
gasPrice = viper.GetInt32("gasprice")
if viper.IsSet("gasprice") {
gasPrice = viper.GetInt32("gasprice")
} else {
gasPrice = int32(core.DefaultGasPrice)
log.Debug("GasPrice is not set, taking its default value ", gasPrice)

}
}
return gasPrice, nil
}
Expand All @@ -132,10 +159,15 @@ func (*UtilsStruct) GetGasPrice() (int32, error) {
func (*UtilsStruct) GetLogLevel() (string, error) {
logLevel, err := flagSetUtils.GetRootStringLogLevel()
if err != nil {
return "", err
return core.DefaultLogLevel, err
}
if logLevel == "" {
logLevel = viper.GetString("logLevel")
if viper.IsSet("logLevel") {
logLevel = viper.GetString("logLevel")
} else {
logLevel = core.DefaultLogLevel
log.Debug("LogLevel is not set, taking its default value ", logLevel)
}
}
return logLevel, nil
}
Expand All @@ -144,10 +176,15 @@ func (*UtilsStruct) GetLogLevel() (string, error) {
func (*UtilsStruct) GetGasLimit() (float32, error) {
gasLimit, err := flagSetUtils.GetRootFloat32GasLimit()
if err != nil {
return -1, err
return float32(core.DefaultGasLimit), err
}
if gasLimit == -1 {
gasLimit = float32(viper.GetFloat64("gasLimit"))
if viper.IsSet("gasLimit") {
gasLimit = float32(viper.GetFloat64("gasLimit"))
} else {
gasLimit = float32(core.DefaultGasLimit)
log.Debug("GasLimit is not set, taking its default value ", gasLimit)
}
}
return gasLimit, nil
}
Expand All @@ -156,10 +193,15 @@ func (*UtilsStruct) GetGasLimit() (float32, error) {
func (*UtilsStruct) GetRPCTimeout() (int64, error) {
rpcTimeout, err := flagSetUtils.GetRootInt64RPCTimeout()
if err != nil {
return 10, err
return int64(core.DefaultRPCTimeout), err
}
if rpcTimeout == 0 {
rpcTimeout = viper.GetInt64("rpcTimeout")
if viper.IsSet("rpcTimeout") {
rpcTimeout = viper.GetInt64("rpcTimeout")
} else {
rpcTimeout = int64(core.DefaultRPCTimeout)
log.Debug("RPCTimeout is not set, taking its default value ", rpcTimeout)
}
}
return rpcTimeout, nil
}
26 changes: 13 additions & 13 deletions cmd/config-utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,15 +190,15 @@ func TestGetBufferPercent(t *testing.T) {
args: args{
bufferPercent: 0,
},
want: 0,
want: 20,
wantErr: nil,
},
{
name: "Test 3: When there is an error in getting bufferPercent",
args: args{
bufferPercentErr: errors.New("bufferPercent error"),
},
want: 30,
want: 20,
wantErr: errors.New("bufferPercent error"),
},
}
Expand Down Expand Up @@ -250,15 +250,15 @@ func TestGetGasLimit(t *testing.T) {
args: args{
gasLimit: -1,
},
want: 0,
want: 2,
wantErr: nil,
},
{
name: "Test 3: When there is an error in getting gasLimit",
args: args{
gasLimitErr: errors.New("gasLimit error"),
},
want: -1,
want: 2,
wantErr: errors.New("gasLimit error"),
},
}
Expand Down Expand Up @@ -311,15 +311,15 @@ func TestGetGasPrice(t *testing.T) {
args: args{
gasPrice: -1,
},
want: 0,
want: 1,
wantErr: nil,
},
{
name: "Test 3: When there is an error in getting gasPrice",
args: args{
gasPriceErr: errors.New("gasPrice error"),
},
want: 0,
want: 1,
wantErr: errors.New("gasPrice error"),
},
}
Expand Down Expand Up @@ -433,7 +433,7 @@ func TestGetMultiplier(t *testing.T) {
args: args{
gasMultiplier: -1,
},
want: 0,
want: 1,
wantErr: nil,
},
{
Expand Down Expand Up @@ -502,15 +502,15 @@ func TestGetProvider(t *testing.T) {
args: args{
providerErr: errors.New("provider error"),
},
want: "",
want: "http://127.0.0.1:8545",
wantErr: errors.New("provider error"),
},
{
name: "Test 2: When provider is nil",
name: "Test 4: When provider is nil",
args: args{
provider: "",
},
want: "",
want: "http://127.0.0.1:8545",
wantErr: nil,
},
}
Expand Down Expand Up @@ -563,15 +563,15 @@ func TestGetWaitTime(t *testing.T) {
args: args{
waitTime: -1,
},
want: 0,
want: 1,
wantErr: nil,
},
{
name: "Test 3: When there is an error in getting waitTime",
args: args{
waitTimeErr: errors.New("waitTime error"),
},
want: 3,
want: 1,
wantErr: errors.New("waitTime error"),
},
}
Expand Down Expand Up @@ -623,7 +623,7 @@ func TestGetRPCTimeout(t *testing.T) {
args: args{
rpcTimeout: 0,
},
want: 0,
want: 10,
wantErr: nil,
},
{
Expand Down
9 changes: 9 additions & 0 deletions core/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,12 @@ var StateLength = uint64(EpochLength / NumberOfStates)
var MaxRetries uint = 8
var NilHash = common.Hash{0x00}
var BlockCompletionTimeout = 30

var DefaultProvider = "http://127.0.0.1:8545"
var DefaultGasMultiplier = 1.0
var DefaultBufferPercent = 20
var DefaultGasPrice = 1
var DefaultWaitTime = 1
var DefaultGasLimit = 2
var DefaultRPCTimeout = 10
var DefaultLogLevel = ""