diff --git a/lang/core/concat_func.go b/lang/core/concat_func.go index f13c8637c..023a7a712 100644 --- a/lang/core/concat_func.go +++ b/lang/core/concat_func.go @@ -30,6 +30,8 @@ package core import ( + "context" + "github.com/purpleidea/mgmt/lang/funcs" "github.com/purpleidea/mgmt/lang/funcs/simple" "github.com/purpleidea/mgmt/lang/types" @@ -48,7 +50,7 @@ func init() { } // Concat concatenates two strings together. -func Concat(input []types.Value) (types.Value, error) { +func Concat(ctx context.Context, input []types.Value) (types.Value, error) { return &types.StrValue{ V: input[0].Str() + input[1].Str(), }, nil diff --git a/lang/core/convert/format_bool.go b/lang/core/convert/format_bool.go index 06b03af84..0bf445b45 100644 --- a/lang/core/convert/format_bool.go +++ b/lang/core/convert/format_bool.go @@ -30,6 +30,7 @@ package convert import ( + "context" "strconv" "github.com/purpleidea/mgmt/lang/funcs/simple" @@ -45,7 +46,7 @@ func init() { // FormatBool converts a boolean to a string representation that can be consumed // by ParseBool. This value will be `"true"` or `"false"`. -func FormatBool(input []types.Value) (types.Value, error) { +func FormatBool(ctx context.Context, input []types.Value) (types.Value, error) { return &types.StrValue{ V: strconv.FormatBool(input[0].Bool()), }, nil diff --git a/lang/core/convert/parse_bool.go b/lang/core/convert/parse_bool.go index 91b5e1349..fc490f2eb 100644 --- a/lang/core/convert/parse_bool.go +++ b/lang/core/convert/parse_bool.go @@ -30,6 +30,7 @@ package convert import ( + "context" "fmt" "strconv" @@ -48,7 +49,7 @@ func init() { // it an invalid value. Valid values match what is accepted by the golang // strconv.ParseBool function. It's recommended to use the strings `true` or // `false` if you are undecided about what string representation to choose. -func ParseBool(input []types.Value) (types.Value, error) { +func ParseBool(ctx context.Context, input []types.Value) (types.Value, error) { s := input[0].Str() b, err := strconv.ParseBool(s) if err != nil { diff --git a/lang/core/convert/to_float.go b/lang/core/convert/to_float.go index f51cc052a..959dcbf5a 100644 --- a/lang/core/convert/to_float.go +++ b/lang/core/convert/to_float.go @@ -30,6 +30,8 @@ package convert import ( + "context" + "github.com/purpleidea/mgmt/lang/funcs/simple" "github.com/purpleidea/mgmt/lang/types" ) @@ -42,7 +44,7 @@ func init() { } // ToFloat converts an integer to a float. -func ToFloat(input []types.Value) (types.Value, error) { +func ToFloat(ctx context.Context, input []types.Value) (types.Value, error) { return &types.FloatValue{ V: float64(input[0].Int()), }, nil diff --git a/lang/core/convert/to_float_test.go b/lang/core/convert/to_float_test.go index 8311d97d1..887f5cebd 100644 --- a/lang/core/convert/to_float_test.go +++ b/lang/core/convert/to_float_test.go @@ -30,13 +30,14 @@ package convert import ( + "context" "testing" "github.com/purpleidea/mgmt/lang/types" ) func testToFloat(t *testing.T, input int64, expected float64) { - got, err := ToFloat([]types.Value{&types.IntValue{V: input}}) + got, err := ToFloat(context.Background(), []types.Value{&types.IntValue{V: input}}) if err != nil { t.Error(err) return diff --git a/lang/core/convert/to_int.go b/lang/core/convert/to_int.go index 3843443bf..6eaf67bc4 100644 --- a/lang/core/convert/to_int.go +++ b/lang/core/convert/to_int.go @@ -30,6 +30,8 @@ package convert import ( + "context" + "github.com/purpleidea/mgmt/lang/funcs/simple" "github.com/purpleidea/mgmt/lang/types" ) @@ -42,7 +44,7 @@ func init() { } // ToInt converts a float to an integer. -func ToInt(input []types.Value) (types.Value, error) { +func ToInt(ctx context.Context, input []types.Value) (types.Value, error) { return &types.IntValue{ V: int64(input[0].Float()), }, nil diff --git a/lang/core/convert/to_int_test.go b/lang/core/convert/to_int_test.go index fd4899bcf..1814bca79 100644 --- a/lang/core/convert/to_int_test.go +++ b/lang/core/convert/to_int_test.go @@ -30,6 +30,7 @@ package convert import ( + "context" "testing" "github.com/purpleidea/mgmt/lang/types" @@ -37,7 +38,7 @@ import ( func testToInt(t *testing.T, input float64, expected int64) { - got, err := ToInt([]types.Value{&types.FloatValue{V: input}}) + got, err := ToInt(context.Background(), []types.Value{&types.FloatValue{V: input}}) if err != nil { t.Error(err) return diff --git a/lang/core/convert/to_str.go b/lang/core/convert/to_str.go index 40582a951..00b782b8f 100644 --- a/lang/core/convert/to_str.go +++ b/lang/core/convert/to_str.go @@ -30,6 +30,7 @@ package convert import ( + "context" "strconv" "github.com/purpleidea/mgmt/lang/funcs/simple" @@ -49,7 +50,7 @@ func init() { } // IntToStr converts an integer to a string. -func IntToStr(input []types.Value) (types.Value, error) { +func IntToStr(ctx context.Context, input []types.Value) (types.Value, error) { return &types.StrValue{ V: strconv.Itoa(int(input[0].Int())), }, nil diff --git a/lang/core/datetime/format_func.go b/lang/core/datetime/format_func.go index 7fbd4a44d..6db469859 100644 --- a/lang/core/datetime/format_func.go +++ b/lang/core/datetime/format_func.go @@ -30,6 +30,7 @@ package coredatetime import ( + "context" "fmt" "time" @@ -48,7 +49,7 @@ func init() { // has to be defined like specified by the golang "time" package. The time is // the number of seconds since the epoch, and matches what comes from our Now // function. Golang documentation: https://golang.org/pkg/time/#Time.Format -func Format(input []types.Value) (types.Value, error) { +func Format(ctx context.Context, input []types.Value) (types.Value, error) { epochDelta := input[0].Int() if epochDelta < 0 { return nil, fmt.Errorf("epoch delta must be positive") diff --git a/lang/core/datetime/format_func_test.go b/lang/core/datetime/format_func_test.go index 3c28e811e..d3e63d35f 100644 --- a/lang/core/datetime/format_func_test.go +++ b/lang/core/datetime/format_func_test.go @@ -32,6 +32,7 @@ package coredatetime import ( + "context" "testing" "github.com/purpleidea/mgmt/lang/types" @@ -41,7 +42,7 @@ func TestFormat(t *testing.T) { inputVal := &types.IntValue{V: 1443158163} inputFormat := &types.StrValue{V: "2006"} - val, err := Format([]types.Value{inputVal, inputFormat}) + val, err := Format(context.Background(), []types.Value{inputVal, inputFormat}) if err != nil { t.Error(err) } diff --git a/lang/core/datetime/hour_func.go b/lang/core/datetime/hour_func.go index b09d50b95..7beeb4f84 100644 --- a/lang/core/datetime/hour_func.go +++ b/lang/core/datetime/hour_func.go @@ -30,6 +30,7 @@ package coredatetime import ( + "context" "fmt" "time" @@ -47,7 +48,7 @@ func init() { // Hour returns the hour of the day corresponding to the input time. The time is // the number of seconds since the epoch, and matches what comes from our Now // function. -func Hour(input []types.Value) (types.Value, error) { +func Hour(ctx context.Context, input []types.Value) (types.Value, error) { epochDelta := input[0].Int() if epochDelta < 0 { return nil, fmt.Errorf("epoch delta must be positive") diff --git a/lang/core/datetime/print_func.go b/lang/core/datetime/print_func.go index 4dc6ea380..e1d3c61d8 100644 --- a/lang/core/datetime/print_func.go +++ b/lang/core/datetime/print_func.go @@ -30,6 +30,7 @@ package coredatetime import ( + "context" "fmt" "time" @@ -41,7 +42,7 @@ func init() { // FIXME: consider renaming this to printf, and add in a format string? simple.ModuleRegister(ModuleName, "print", &types.FuncValue{ T: types.NewType("func(a int) str"), - V: func(input []types.Value) (types.Value, error) { + V: func(ctx context.Context, input []types.Value) (types.Value, error) { epochDelta := input[0].Int() if epochDelta < 0 { return nil, fmt.Errorf("epoch delta must be positive") diff --git a/lang/core/datetime/weekday_func.go b/lang/core/datetime/weekday_func.go index bddafe132..62e905e6d 100644 --- a/lang/core/datetime/weekday_func.go +++ b/lang/core/datetime/weekday_func.go @@ -30,6 +30,7 @@ package coredatetime import ( + "context" "fmt" "strings" "time" @@ -48,7 +49,7 @@ func init() { // Weekday returns the lowercased day of the week corresponding to the input // time. The time is the number of seconds since the epoch, and matches what // comes from our Now function. -func Weekday(input []types.Value) (types.Value, error) { +func Weekday(ctx context.Context, input []types.Value) (types.Value, error) { epochDelta := input[0].Int() if epochDelta < 0 { return nil, fmt.Errorf("epoch delta must be positive") diff --git a/lang/core/embedded/provisioner/provisioner.go b/lang/core/embedded/provisioner/provisioner.go index 51acb4ce1..a6b2adefe 100644 --- a/lang/core/embedded/provisioner/provisioner.go +++ b/lang/core/embedded/provisioner/provisioner.go @@ -418,11 +418,14 @@ func (obj *provisioner) Register(moduleName string) error { // Build a few separately... simple.ModuleRegister(moduleName, "cli_password", &types.FuncValue{ T: types.NewType("func() str"), - V: func(input []types.Value) (types.Value, error) { + V: func(ctx context.Context, input []types.Value) (types.Value, error) { if obj.localArgs == nil { // programming error return nil, fmt.Errorf("could not convert/access our struct") } + + // TODO: plumb through the password lookup here instead? + //localArgs := *obj.localArgs // optional return &types.StrValue{ V: obj.password, diff --git a/lang/core/example/answer_func.go b/lang/core/example/answer_func.go index 0d61e88d0..29a41e99b 100644 --- a/lang/core/example/answer_func.go +++ b/lang/core/example/answer_func.go @@ -30,6 +30,8 @@ package coreexample import ( + "context" + "github.com/purpleidea/mgmt/lang/funcs/simple" "github.com/purpleidea/mgmt/lang/types" ) @@ -40,7 +42,7 @@ const Answer = 42 func init() { simple.ModuleRegister(ModuleName, "answer", &types.FuncValue{ T: types.NewType("func() int"), - V: func([]types.Value) (types.Value, error) { + V: func(context.Context, []types.Value) (types.Value, error) { return &types.IntValue{V: Answer}, nil }, }) diff --git a/lang/core/example/errorbool_func.go b/lang/core/example/errorbool_func.go index fee2da9f3..7b5277d86 100644 --- a/lang/core/example/errorbool_func.go +++ b/lang/core/example/errorbool_func.go @@ -30,6 +30,7 @@ package coreexample import ( + "context" "fmt" "github.com/purpleidea/mgmt/lang/funcs/simple" @@ -39,7 +40,7 @@ import ( func init() { simple.ModuleRegister(ModuleName, "errorbool", &types.FuncValue{ T: types.NewType("func(a bool) str"), - V: func(input []types.Value) (types.Value, error) { + V: func(ctx context.Context, input []types.Value) (types.Value, error) { if input[0].Bool() { return nil, fmt.Errorf("we errored on request") } diff --git a/lang/core/example/int2str_func.go b/lang/core/example/int2str_func.go index 6ffb5172f..cf672374a 100644 --- a/lang/core/example/int2str_func.go +++ b/lang/core/example/int2str_func.go @@ -30,6 +30,7 @@ package coreexample import ( + "context" "fmt" "github.com/purpleidea/mgmt/lang/funcs/simple" @@ -39,7 +40,7 @@ import ( func init() { simple.ModuleRegister(ModuleName, "int2str", &types.FuncValue{ T: types.NewType("func(a int) str"), - V: func(input []types.Value) (types.Value, error) { + V: func(ctx context.Context, input []types.Value) (types.Value, error) { return &types.StrValue{ V: fmt.Sprintf("%d", input[0].Int()), }, nil diff --git a/lang/core/example/nested/hello_func.go b/lang/core/example/nested/hello_func.go index f6a64a9cb..3559f219a 100644 --- a/lang/core/example/nested/hello_func.go +++ b/lang/core/example/nested/hello_func.go @@ -30,6 +30,8 @@ package corenested import ( + "context" + coreexample "github.com/purpleidea/mgmt/lang/core/example" "github.com/purpleidea/mgmt/lang/funcs/simple" "github.com/purpleidea/mgmt/lang/types" @@ -43,7 +45,7 @@ func init() { } // Hello returns some string. This is just to test nesting. -func Hello(input []types.Value) (types.Value, error) { +func Hello(ctx context.Context, input []types.Value) (types.Value, error) { return &types.StrValue{ V: "Hello!", }, nil diff --git a/lang/core/example/plus_func.go b/lang/core/example/plus_func.go index e044944f1..76dcbc335 100644 --- a/lang/core/example/plus_func.go +++ b/lang/core/example/plus_func.go @@ -30,6 +30,8 @@ package coreexample import ( + "context" + "github.com/purpleidea/mgmt/lang/funcs/simple" "github.com/purpleidea/mgmt/lang/types" ) @@ -42,7 +44,7 @@ func init() { } // Plus returns y + z. -func Plus(input []types.Value) (types.Value, error) { +func Plus(ctx context.Context, input []types.Value) (types.Value, error) { y, z := input[0].Str(), input[1].Str() return &types.StrValue{ V: y + z, diff --git a/lang/core/example/str2int_func.go b/lang/core/example/str2int_func.go index 8ed0f2536..1326ff5a9 100644 --- a/lang/core/example/str2int_func.go +++ b/lang/core/example/str2int_func.go @@ -30,6 +30,7 @@ package coreexample import ( + "context" "strconv" "github.com/purpleidea/mgmt/lang/funcs/simple" @@ -39,7 +40,7 @@ import ( func init() { simple.ModuleRegister(ModuleName, "str2int", &types.FuncValue{ T: types.NewType("func(a str) int"), - V: func(input []types.Value) (types.Value, error) { + V: func(ctx context.Context, input []types.Value) (types.Value, error) { var i int64 if val, err := strconv.ParseInt(input[0].Str(), 10, 64); err == nil { i = val diff --git a/lang/core/iter/map_func.go b/lang/core/iter/map_func.go index d4fd3e3f8..3145452b5 100644 --- a/lang/core/iter/map_func.go +++ b/lang/core/iter/map_func.go @@ -766,7 +766,7 @@ func (obj *MapFunc) replaceSubGraph(subgraphInput interfaces.Func) error { outputListFunc := structs.SimpleFnToDirectFunc( "mapOutputList", &types.FuncValue{ - V: func(args []types.Value) (types.Value, error) { + V: func(_ context.Context, args []types.Value) (types.Value, error) { listValue := &types.ListValue{ V: args, T: obj.outputListType, @@ -788,7 +788,7 @@ func (obj *MapFunc) replaceSubGraph(subgraphInput interfaces.Func) error { inputElemFunc := structs.SimpleFnToDirectFunc( fmt.Sprintf("mapInputElem[%d]", i), &types.FuncValue{ - V: func(args []types.Value) (types.Value, error) { + V: func(_ context.Context, args []types.Value) (types.Value, error) { if len(args) != 1 { return nil, fmt.Errorf("inputElemFunc: expected a single argument") } diff --git a/lang/core/len_func.go b/lang/core/len_func.go index 60d334a32..fb769a6eb 100644 --- a/lang/core/len_func.go +++ b/lang/core/len_func.go @@ -30,6 +30,7 @@ package core import ( + "context" "fmt" "github.com/purpleidea/mgmt/lang/funcs/simplepoly" @@ -56,7 +57,7 @@ func init() { // Len returns the number of elements in a list or the number of key pairs in a // map. It can operate on either of these types. -func Len(input []types.Value) (types.Value, error) { +func Len(ctx context.Context, input []types.Value) (types.Value, error) { var length int switch k := input[0].Type().Kind; k { case types.KindStr: diff --git a/lang/core/math/fortytwo_func.go b/lang/core/math/fortytwo_func.go index 65252ea0c..f6a70e26d 100644 --- a/lang/core/math/fortytwo_func.go +++ b/lang/core/math/fortytwo_func.go @@ -30,6 +30,7 @@ package coremath import ( + "context" "fmt" "github.com/purpleidea/mgmt/lang/funcs/simplepoly" @@ -57,8 +58,8 @@ func init() { // in a sig field, like how we demonstrate in the implementation of FortyTwo. If // the API doesn't change, then this is an example of how to build this as a // wrapper. -func fortyTwo(sig *types.Type) func([]types.Value) (types.Value, error) { - return func(input []types.Value) (types.Value, error) { +func fortyTwo(sig *types.Type) func(context.Context, []types.Value) (types.Value, error) { + return func(ctx context.Context, input []types.Value) (types.Value, error) { return FortyTwo(sig, input) } } diff --git a/lang/core/math/minus1_func.go b/lang/core/math/minus1_func.go index a1310afde..ee42518b7 100644 --- a/lang/core/math/minus1_func.go +++ b/lang/core/math/minus1_func.go @@ -30,6 +30,8 @@ package coremath import ( + "context" + "github.com/purpleidea/mgmt/lang/funcs/simple" "github.com/purpleidea/mgmt/lang/types" ) @@ -42,7 +44,7 @@ func init() { } // Minus1 takes an int and subtracts one from it. -func Minus1(input []types.Value) (types.Value, error) { +func Minus1(ctx context.Context, input []types.Value) (types.Value, error) { // TODO: check for overflow return &types.IntValue{ V: input[0].Int() - 1, diff --git a/lang/core/math/mod_func.go b/lang/core/math/mod_func.go index 3e4588c00..002a67931 100644 --- a/lang/core/math/mod_func.go +++ b/lang/core/math/mod_func.go @@ -30,6 +30,7 @@ package coremath import ( + "context" "fmt" "math" @@ -54,7 +55,7 @@ func init() { // both of KindInt or both of KindFloat, and it will return the same kind. If // you pass in a divisor of zero, this will error, eg: mod(x, 0) = NaN. // TODO: consider returning zero instead of erroring? -func Mod(input []types.Value) (types.Value, error) { +func Mod(ctx context.Context, input []types.Value) (types.Value, error) { var x, y float64 var float bool k := input[0].Type().Kind diff --git a/lang/core/math/pow_func.go b/lang/core/math/pow_func.go index 939d73260..30e51b307 100644 --- a/lang/core/math/pow_func.go +++ b/lang/core/math/pow_func.go @@ -30,6 +30,7 @@ package coremath import ( + "context" "fmt" "math" @@ -45,7 +46,7 @@ func init() { } // Pow returns x ^ y, the base-x exponential of y. -func Pow(input []types.Value) (types.Value, error) { +func Pow(ctx context.Context, input []types.Value) (types.Value, error) { x, y := input[0].Float(), input[1].Float() // FIXME: check for overflow z := math.Pow(x, y) diff --git a/lang/core/math/sqrt_func.go b/lang/core/math/sqrt_func.go index aae53b5a1..32a1bd883 100644 --- a/lang/core/math/sqrt_func.go +++ b/lang/core/math/sqrt_func.go @@ -30,6 +30,7 @@ package coremath import ( + "context" "fmt" "math" @@ -45,7 +46,7 @@ func init() { } // Sqrt returns sqrt(x), the square root of x. -func Sqrt(input []types.Value) (types.Value, error) { +func Sqrt(ctx context.Context, input []types.Value) (types.Value, error) { x := input[0].Float() y := math.Sqrt(x) if math.IsNaN(y) { diff --git a/lang/core/math/sqrt_func_test.go b/lang/core/math/sqrt_func_test.go index bc35cd119..0f55eac04 100644 --- a/lang/core/math/sqrt_func_test.go +++ b/lang/core/math/sqrt_func_test.go @@ -30,6 +30,7 @@ package coremath import ( + "context" "fmt" "math" "testing" @@ -40,7 +41,7 @@ import ( func testSqrtSuccess(input, sqrt float64) error { inputVal := &types.FloatValue{V: input} - val, err := Sqrt([]types.Value{inputVal}) + val, err := Sqrt(context.Background(), []types.Value{inputVal}) if err != nil { return err } @@ -52,7 +53,7 @@ func testSqrtSuccess(input, sqrt float64) error { func testSqrtError(input float64) error { inputVal := &types.FloatValue{V: input} - _, err := Sqrt([]types.Value{inputVal}) + _, err := Sqrt(context.Background(), []types.Value{inputVal}) if err == nil { return fmt.Errorf("expected error for input %f, got nil", input) } diff --git a/lang/core/net/cidr_to_ip_func.go b/lang/core/net/cidr_to_ip_func.go index 59ffe631d..8606419fd 100644 --- a/lang/core/net/cidr_to_ip_func.go +++ b/lang/core/net/cidr_to_ip_func.go @@ -30,6 +30,7 @@ package corenet import ( + "context" "net" "strings" @@ -45,7 +46,7 @@ func init() { } // CidrToIP returns the IP from a CIDR address -func CidrToIP(input []types.Value) (types.Value, error) { +func CidrToIP(ctx context.Context, input []types.Value) (types.Value, error) { cidr := input[0].Str() ip, _, err := net.ParseCIDR(strings.TrimSpace(cidr)) if err != nil { diff --git a/lang/core/net/cidr_to_ip_func_test.go b/lang/core/net/cidr_to_ip_func_test.go index 11fcc4319..f3c24c970 100644 --- a/lang/core/net/cidr_to_ip_func_test.go +++ b/lang/core/net/cidr_to_ip_func_test.go @@ -30,6 +30,7 @@ package corenet import ( + "context" "fmt" "testing" @@ -61,7 +62,7 @@ func TestCidrToIP(t *testing.T) { for _, ts := range cidrtests { test := ts t.Run(test.name, func(t *testing.T) { - output, err := CidrToIP([]types.Value{&types.StrValue{V: test.input}}) + output, err := CidrToIP(context.Background(), []types.Value{&types.StrValue{V: test.input}}) expectedStr := &types.StrValue{V: test.expected} if test.err != nil && err.Error() != test.err.Error() { diff --git a/lang/core/net/macfmt_func.go b/lang/core/net/macfmt_func.go index 4674fbd35..a775e89a4 100644 --- a/lang/core/net/macfmt_func.go +++ b/lang/core/net/macfmt_func.go @@ -30,6 +30,7 @@ package corenet import ( + "context" "fmt" "net" "strings" @@ -51,7 +52,7 @@ func init() { // MacFmt takes a MAC address with hyphens and converts it to a format with // colons. -func MacFmt(input []types.Value) (types.Value, error) { +func MacFmt(ctx context.Context, input []types.Value) (types.Value, error) { mac := input[0].Str() // Check if the MAC address is valid. @@ -70,7 +71,7 @@ func MacFmt(input []types.Value) (types.Value, error) { // OldMacFmt takes a MAC address with colons and converts it to a format with // hyphens. This is the old deprecated style that nobody likes. -func OldMacFmt(input []types.Value) (types.Value, error) { +func OldMacFmt(ctx context.Context, input []types.Value) (types.Value, error) { mac := input[0].Str() // Check if the MAC address is valid. diff --git a/lang/core/net/macfmt_func_test.go b/lang/core/net/macfmt_func_test.go index e93aa3d17..ed20c5cae 100644 --- a/lang/core/net/macfmt_func_test.go +++ b/lang/core/net/macfmt_func_test.go @@ -30,6 +30,7 @@ package corenet import ( + "context" "testing" "github.com/purpleidea/mgmt/lang/types" @@ -51,7 +52,7 @@ func TestMacFmt(t *testing.T) { for _, tt := range tests { tt := tt t.Run(tt.name, func(t *testing.T) { - m, err := MacFmt([]types.Value{&types.StrValue{V: tt.in}}) + m, err := MacFmt(context.Background(), []types.Value{&types.StrValue{V: tt.in}}) if (err != nil) != tt.wantErr { t.Errorf("func MacFmt() error = %v, wantErr %v", err, tt.wantErr) return @@ -81,7 +82,7 @@ func TestOldMacFmt(t *testing.T) { for _, tt := range tests { tt := tt t.Run(tt.name, func(t *testing.T) { - m, err := OldMacFmt([]types.Value{&types.StrValue{V: tt.in}}) + m, err := OldMacFmt(context.Background(), []types.Value{&types.StrValue{V: tt.in}}) if (err != nil) != tt.wantErr { t.Errorf("func MacFmt() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/lang/core/os/args_func.go b/lang/core/os/args_func.go index 023afbf91..3d8ee1ea6 100644 --- a/lang/core/os/args_func.go +++ b/lang/core/os/args_func.go @@ -30,6 +30,7 @@ package coreos import ( + "context" "os" "github.com/purpleidea/mgmt/lang/funcs/simple" @@ -47,7 +48,7 @@ func init() { // return different values depending on how this is deployed, so don't expect a // result on your deploy client to behave the same as a server receiving code. // FIXME: Sanitize any command-line secrets we might pass in by cli. -func Args([]types.Value) (types.Value, error) { +func Args(context.Context, []types.Value) (types.Value, error) { values := []types.Value{} for _, s := range os.Args { values = append(values, &types.StrValue{V: s}) diff --git a/lang/core/os/distro_func.go b/lang/core/os/distro_func.go index a7ac79517..2af702a34 100644 --- a/lang/core/os/distro_func.go +++ b/lang/core/os/distro_func.go @@ -30,6 +30,7 @@ package coreos import ( + "context" "fmt" "strings" @@ -51,7 +52,7 @@ func init() { // ParseDistroUID parses a distro UID into its component values. If it cannot // parse correctly, all the struct fields have the zero values. // NOTE: The UID pattern is subject to change. -func ParseDistroUID(input []types.Value) (types.Value, error) { +func ParseDistroUID(ctx context.Context, input []types.Value) (types.Value, error) { fn := func(distro, version, arch string) (types.Value, error) { st := types.NewStruct(types.NewType(structDistroUID)) if err := st.Set("distro", &types.StrValue{V: distro}); err != nil { diff --git a/lang/core/os/family_func.go b/lang/core/os/family_func.go index 82ef7af5d..bd53e1a70 100644 --- a/lang/core/os/family_func.go +++ b/lang/core/os/family_func.go @@ -30,6 +30,7 @@ package coreos import ( + "context" "os" "github.com/purpleidea/mgmt/lang/funcs/simple" @@ -54,8 +55,9 @@ func init() { // IsDebian detects if the os family is debian. // TODO: Detect OS changes. -func IsDebian(input []types.Value) (types.Value, error) { +func IsDebian(ctx context.Context, input []types.Value) (types.Value, error) { exists := true + // TODO: use ctx around io operations _, err := os.Stat("/etc/debian_version") if os.IsNotExist(err) { exists = false @@ -67,8 +69,9 @@ func IsDebian(input []types.Value) (types.Value, error) { // IsRedHat detects if the os family is redhat. // TODO: Detect OS changes. -func IsRedHat(input []types.Value) (types.Value, error) { +func IsRedHat(ctx context.Context, input []types.Value) (types.Value, error) { exists := true + // TODO: use ctx around io operations _, err := os.Stat("/etc/redhat-release") if os.IsNotExist(err) { exists = false @@ -80,8 +83,9 @@ func IsRedHat(input []types.Value) (types.Value, error) { // IsArchLinux detects if the os family is archlinux. // TODO: Detect OS changes. -func IsArchLinux(input []types.Value) (types.Value, error) { +func IsArchLinux(ctx context.Context, input []types.Value) (types.Value, error) { exists := true + // TODO: use ctx around io operations _, err := os.Stat("/etc/arch-release") if os.IsNotExist(err) { exists = false diff --git a/lang/core/panic_func.go b/lang/core/panic_func.go index 82c88bb5a..9719ff716 100644 --- a/lang/core/panic_func.go +++ b/lang/core/panic_func.go @@ -30,6 +30,7 @@ package core import ( + "context" "fmt" "github.com/purpleidea/mgmt/lang/funcs/simplepoly" @@ -52,7 +53,7 @@ func init() { // Panic returns an error when it receives a non-empty string or a true boolean. // The error should cause the function engine to shutdown. If there's no error, // it returns false. -func Panic(input []types.Value) (types.Value, error) { +func Panic(ctx context.Context, input []types.Value) (types.Value, error) { switch k := input[0].Type().Kind; k { case types.KindBool: if input[0].Bool() { diff --git a/lang/core/regexp/match_func.go b/lang/core/regexp/match_func.go index fa32619b8..39c026cd9 100644 --- a/lang/core/regexp/match_func.go +++ b/lang/core/regexp/match_func.go @@ -30,6 +30,7 @@ package coreregexp import ( + "context" "regexp" "github.com/purpleidea/mgmt/lang/funcs/simple" @@ -45,7 +46,7 @@ func init() { } // Match matches whether a string matches the regexp pattern. -func Match(input []types.Value) (types.Value, error) { +func Match(ctx context.Context, input []types.Value) (types.Value, error) { pattern := input[0].Str() s := input[1].Str() diff --git a/lang/core/regexp/match_func_test.go b/lang/core/regexp/match_func_test.go index 7cb4f3701..41e8951f6 100644 --- a/lang/core/regexp/match_func_test.go +++ b/lang/core/regexp/match_func_test.go @@ -30,6 +30,7 @@ package coreregexp import ( + "context" "testing" "github.com/purpleidea/mgmt/lang/types" @@ -76,7 +77,7 @@ func TestMatch0(t *testing.T) { for i, x := range values { pattern := &types.StrValue{V: x.pattern} s := &types.StrValue{V: x.s} - val, err := Match([]types.Value{pattern, s}) + val, err := Match(context.Background(), []types.Value{pattern, s}) if err != nil { t.Errorf("test index %d failed with: %+v", i, err) } diff --git a/lang/core/strings/split_func.go b/lang/core/strings/split_func.go index ac556e0a0..6409e76da 100644 --- a/lang/core/strings/split_func.go +++ b/lang/core/strings/split_func.go @@ -30,6 +30,7 @@ package corestrings import ( + "context" "strings" "github.com/purpleidea/mgmt/lang/funcs/simple" @@ -45,7 +46,7 @@ func init() { // Split splits the input string using the separator and returns the segments as // a list. -func Split(input []types.Value) (types.Value, error) { +func Split(ctx context.Context, input []types.Value) (types.Value, error) { str, sep := input[0].Str(), input[1].Str() segments := strings.Split(str, sep) diff --git a/lang/core/strings/split_func_test.go b/lang/core/strings/split_func_test.go index beef6098f..6102eba51 100644 --- a/lang/core/strings/split_func_test.go +++ b/lang/core/strings/split_func_test.go @@ -30,6 +30,7 @@ package corestrings import ( + "context" "fmt" "testing" @@ -40,7 +41,7 @@ import ( func testSplit(input, sep string, output []string) error { inputVal, sepVal := &types.StrValue{V: input}, &types.StrValue{V: sep} - val, err := Split([]types.Value{inputVal, sepVal}) + val, err := Split(context.Background(), []types.Value{inputVal, sepVal}) if err != nil { return err } diff --git a/lang/core/strings/to_lower_func.go b/lang/core/strings/to_lower_func.go index 888169325..c1895365f 100644 --- a/lang/core/strings/to_lower_func.go +++ b/lang/core/strings/to_lower_func.go @@ -30,6 +30,7 @@ package corestrings import ( + "context" "strings" "github.com/purpleidea/mgmt/lang/funcs/simple" @@ -44,7 +45,7 @@ func init() { } // ToLower turns a string to lowercase. -func ToLower(input []types.Value) (types.Value, error) { +func ToLower(ctx context.Context, input []types.Value) (types.Value, error) { return &types.StrValue{ V: strings.ToLower(input[0].Str()), }, nil diff --git a/lang/core/strings/to_lower_func_test.go b/lang/core/strings/to_lower_func_test.go index cc3eef049..764090b84 100644 --- a/lang/core/strings/to_lower_func_test.go +++ b/lang/core/strings/to_lower_func_test.go @@ -30,6 +30,7 @@ package corestrings import ( + "context" "testing" "github.com/purpleidea/mgmt/lang/types" @@ -37,7 +38,7 @@ import ( func testToLower(t *testing.T, input, expected string) { inputStr := &types.StrValue{V: input} - value, err := ToLower([]types.Value{inputStr}) + value, err := ToLower(context.Background(), []types.Value{inputStr}) if err != nil { t.Error(err) return diff --git a/lang/core/sys/env_func.go b/lang/core/sys/env_func.go index 367a824ee..949f46815 100644 --- a/lang/core/sys/env_func.go +++ b/lang/core/sys/env_func.go @@ -30,6 +30,7 @@ package coresys import ( + "context" "os" "strings" @@ -58,7 +59,7 @@ func init() { // GetEnv gets environment variable by name or returns empty string if non // existing. -func GetEnv(input []types.Value) (types.Value, error) { +func GetEnv(ctx context.Context, input []types.Value) (types.Value, error) { return &types.StrValue{ V: os.Getenv(input[0].Str()), }, nil @@ -66,7 +67,7 @@ func GetEnv(input []types.Value) (types.Value, error) { // DefaultEnv gets environment variable by name or returns default if non // existing. -func DefaultEnv(input []types.Value) (types.Value, error) { +func DefaultEnv(ctx context.Context, input []types.Value) (types.Value, error) { value, exists := os.LookupEnv(input[0].Str()) if !exists { value = input[1].Str() @@ -77,7 +78,7 @@ func DefaultEnv(input []types.Value) (types.Value, error) { } // HasEnv returns true if environment variable exists. -func HasEnv(input []types.Value) (types.Value, error) { +func HasEnv(ctx context.Context, input []types.Value) (types.Value, error) { _, exists := os.LookupEnv(input[0].Str()) return &types.BoolValue{ V: exists, @@ -85,7 +86,7 @@ func HasEnv(input []types.Value) (types.Value, error) { } // Env returns a map of all keys and their values. -func Env(input []types.Value) (types.Value, error) { +func Env(ctx context.Context, input []types.Value) (types.Value, error) { environ := make(map[types.Value]types.Value) for _, keyval := range os.Environ() { if i := strings.IndexRune(keyval, '='); i != -1 { diff --git a/lang/core/template_func.go b/lang/core/template_func.go index 1edf8668f..8bced27c9 100644 --- a/lang/core/template_func.go +++ b/lang/core/template_func.go @@ -390,7 +390,7 @@ func (obj *TemplateFunc) Init(init *interfaces.Init) error { } // run runs a template and returns the result. -func (obj *TemplateFunc) run(templateText string, vars types.Value) (string, error) { +func (obj *TemplateFunc) run(ctx context.Context, templateText string, vars types.Value) (string, error) { // see: https://golang.org/pkg/text/template/#FuncMap for more info // note: we can override any other functions by adding them here... funcMap := map[string]interface{}{ @@ -425,7 +425,7 @@ func (obj *TemplateFunc) run(templateText string, vars types.Value) (string, err // parameter types. Functions meant to apply to arguments of // arbitrary type can use parameters of type interface{} or of // type reflect.Value. - f, err := wrap(name, fn) // wrap it so that it meets API expectations + f, err := wrap(ctx, name, fn) // wrap it so that it meets API expectations if err != nil { obj.init.Logf("warning, skipping function named: `%s`, err: %v", name, err) continue @@ -538,7 +538,7 @@ func (obj *TemplateFunc) Stream(ctx context.Context) error { vars = nil } - result, err := obj.run(tmpl, vars) + result, err := obj.run(ctx, tmpl, vars) if err != nil { return err // no errwrap needed b/c helper func } @@ -585,7 +585,7 @@ func safename(name string) string { // function API with what is expected from the reflection API. It returns a // version that includes the optional second error return value so that our // functions can return errors without causing a panic. -func wrap(name string, fn *types.FuncValue) (_ interface{}, reterr error) { +func wrap(ctx context.Context, name string, fn *types.FuncValue) (_ interface{}, reterr error) { defer func() { // catch unhandled panics if r := recover(); r != nil { @@ -633,8 +633,8 @@ func wrap(name string, fn *types.FuncValue) (_ interface{}, reterr error) { innerArgs = append(innerArgs, v) } - result, err := fn.Call(innerArgs) // call it - if err != nil { // function errored :( + result, err := fn.Call(ctx, innerArgs) // call it + if err != nil { // function errored :( // errwrap is a better way to report errors, if allowed! r := reflect.ValueOf(errwrap.Wrapf(err, "function `%s` errored", name)) if !r.Type().ConvertibleTo(errorType) { // for fun! diff --git a/lang/core/test/oneinstance_fact.go b/lang/core/test/oneinstance_fact.go index e1c57b806..f4e71166d 100644 --- a/lang/core/test/oneinstance_fact.go +++ b/lang/core/test/oneinstance_fact.go @@ -123,7 +123,7 @@ func init() { simple.ModuleRegister(ModuleName, OneInstanceBFuncName, &types.FuncValue{ T: types.NewType("func() str"), - V: func([]types.Value) (types.Value, error) { + V: func(context.Context, []types.Value) (types.Value, error) { oneInstanceBMutex.Lock() if oneInstanceBFlag { panic("should not get called twice") @@ -135,7 +135,7 @@ func init() { }) simple.ModuleRegister(ModuleName, OneInstanceDFuncName, &types.FuncValue{ T: types.NewType("func() str"), - V: func([]types.Value) (types.Value, error) { + V: func(context.Context, []types.Value) (types.Value, error) { oneInstanceDMutex.Lock() if oneInstanceDFlag { panic("should not get called twice") @@ -147,7 +147,7 @@ func init() { }) simple.ModuleRegister(ModuleName, OneInstanceFFuncName, &types.FuncValue{ T: types.NewType("func() str"), - V: func([]types.Value) (types.Value, error) { + V: func(context.Context, []types.Value) (types.Value, error) { oneInstanceFMutex.Lock() if oneInstanceFFlag { panic("should not get called twice") @@ -159,7 +159,7 @@ func init() { }) simple.ModuleRegister(ModuleName, OneInstanceHFuncName, &types.FuncValue{ T: types.NewType("func() str"), - V: func([]types.Value) (types.Value, error) { + V: func(context.Context, []types.Value) (types.Value, error) { oneInstanceHMutex.Lock() if oneInstanceHFlag { panic("should not get called twice") diff --git a/lang/funcs/funcgen/fixtures/func_base.tpl b/lang/funcs/funcgen/fixtures/func_base.tpl index 8dc4afd0c..aa587b1bd 100644 --- a/lang/funcs/funcgen/fixtures/func_base.tpl +++ b/lang/funcs/funcgen/fixtures/func_base.tpl @@ -30,6 +30,7 @@ package core import ( + "context" "testpkg" "github.com/purpleidea/mgmt/lang/funcs/funcgen/util" @@ -65,25 +66,25 @@ func init() { } -func TestpkgAllKind(input []types.Value) (types.Value, error) { +func TestpkgAllKind(ctx context.Context, input []types.Value) (types.Value, error) { return &types.FloatValue{ V: testpkg.AllKind(input[0].Int(), input[1].Str()), }, nil } -func TestpkgToUpper(input []types.Value) (types.Value, error) { +func TestpkgToUpper(ctx context.Context, input []types.Value) (types.Value, error) { return &types.StrValue{ V: testpkg.ToUpper(input[0].Str()), }, nil } -func TestpkgMax(input []types.Value) (types.Value, error) { +func TestpkgMax(ctx context.Context, input []types.Value) (types.Value, error) { return &types.FloatValue{ V: testpkg.Max(input[0].Float(), input[1].Float()), }, nil } -func TestpkgWithError(input []types.Value) (types.Value, error) { +func TestpkgWithError(ctx context.Context, input []types.Value) (types.Value, error) { v, err := testpkg.WithError(input[0].Str()) if err != nil { return nil, err @@ -93,13 +94,13 @@ func TestpkgWithError(input []types.Value) (types.Value, error) { }, nil } -func TestpkgWithInt(input []types.Value) (types.Value, error) { +func TestpkgWithInt(ctx context.Context, input []types.Value) (types.Value, error) { return &types.StrValue{ V: testpkg.WithInt(input[0].Float(), int(input[1].Int()), input[2].Int(), int(input[3].Int()), int(input[4].Int()), input[5].Bool(), input[6].Str()), }, nil } -func TestpkgSuperByte(input []types.Value) (types.Value, error) { +func TestpkgSuperByte(ctx context.Context, input []types.Value) (types.Value, error) { return &types.StrValue{ V: string(testpkg.SuperByte([]byte(input[0].Str()), input[1].Str())), }, nil diff --git a/lang/funcs/funcgen/templates/generated_funcs.go.tpl b/lang/funcs/funcgen/templates/generated_funcs.go.tpl index 2c5572c64..e2c018316 100644 --- a/lang/funcs/funcgen/templates/generated_funcs.go.tpl +++ b/lang/funcs/funcgen/templates/generated_funcs.go.tpl @@ -30,6 +30,7 @@ package core import ( + "context" {{ range $i, $func := .Packages }} {{ if not (eq .Alias "") }}{{.Alias}} {{end}}"{{.Name}}" {{ end }} "github.com/purpleidea/mgmt/lang/funcs/funcgen/util" @@ -45,7 +46,7 @@ func init() { {{ end }} } {{ range $i, $func := .Functions }} -{{$func.Help}}func {{$func.InternalName}}(input []types.Value) (types.Value, error) { +{{$func.Help}}func {{$func.InternalName}}(ctx context.Context, input []types.Value) (types.Value, error) { {{- if $func.Errorful }} v, err := {{ if not (eq $func.GolangPackage.Alias "") }}{{$func.GolangPackage.Alias}}{{else}}{{$func.GolangPackage.Name}}{{end}}.{{$func.GolangFunc}}({{$func.MakeGolangArgs}}) if err != nil { diff --git a/lang/funcs/operator_func.go b/lang/funcs/operator_func.go index fa38d21d9..76eaa61db 100644 --- a/lang/funcs/operator_func.go +++ b/lang/funcs/operator_func.go @@ -55,7 +55,7 @@ func init() { // concatenation RegisterOperator("+", &types.FuncValue{ T: types.NewType("func(a str, b str) str"), - V: func(input []types.Value) (types.Value, error) { + V: func(ctx context.Context, input []types.Value) (types.Value, error) { return &types.StrValue{ V: input[0].Str() + input[1].Str(), }, nil @@ -64,7 +64,7 @@ func init() { // addition RegisterOperator("+", &types.FuncValue{ T: types.NewType("func(a int, b int) int"), - V: func(input []types.Value) (types.Value, error) { + V: func(ctx context.Context, input []types.Value) (types.Value, error) { //if l := len(input); l != 2 { // return nil, fmt.Errorf("expected two inputs, got: %d", l) //} @@ -77,7 +77,7 @@ func init() { // floating-point addition RegisterOperator("+", &types.FuncValue{ T: types.NewType("func(a float, b float) float"), - V: func(input []types.Value) (types.Value, error) { + V: func(ctx context.Context, input []types.Value) (types.Value, error) { return &types.FloatValue{ V: input[0].Float() + input[1].Float(), }, nil @@ -87,7 +87,7 @@ func init() { // subtraction RegisterOperator("-", &types.FuncValue{ T: types.NewType("func(a int, b int) int"), - V: func(input []types.Value) (types.Value, error) { + V: func(ctx context.Context, input []types.Value) (types.Value, error) { return &types.IntValue{ V: input[0].Int() - input[1].Int(), }, nil @@ -96,7 +96,7 @@ func init() { // floating-point subtraction RegisterOperator("-", &types.FuncValue{ T: types.NewType("func(a float, b float) float"), - V: func(input []types.Value) (types.Value, error) { + V: func(ctx context.Context, input []types.Value) (types.Value, error) { return &types.FloatValue{ V: input[0].Float() - input[1].Float(), }, nil @@ -106,7 +106,7 @@ func init() { // multiplication RegisterOperator("*", &types.FuncValue{ T: types.NewType("func(a int, b int) int"), - V: func(input []types.Value) (types.Value, error) { + V: func(ctx context.Context, input []types.Value) (types.Value, error) { // FIXME: check for overflow? return &types.IntValue{ V: input[0].Int() * input[1].Int(), @@ -116,7 +116,7 @@ func init() { // floating-point multiplication RegisterOperator("*", &types.FuncValue{ T: types.NewType("func(a float, b float) float"), - V: func(input []types.Value) (types.Value, error) { + V: func(ctx context.Context, input []types.Value) (types.Value, error) { return &types.FloatValue{ V: input[0].Float() * input[1].Float(), }, nil @@ -127,7 +127,7 @@ func init() { // division RegisterOperator("/", &types.FuncValue{ T: types.NewType("func(a int, b int) float"), - V: func(input []types.Value) (types.Value, error) { + V: func(ctx context.Context, input []types.Value) (types.Value, error) { divisor := input[1].Int() if divisor == 0 { return nil, fmt.Errorf("can't divide by zero") @@ -140,7 +140,7 @@ func init() { // floating-point division RegisterOperator("/", &types.FuncValue{ T: types.NewType("func(a float, b float) float"), - V: func(input []types.Value) (types.Value, error) { + V: func(ctx context.Context, input []types.Value) (types.Value, error) { divisor := input[1].Float() if divisor == 0.0 { return nil, fmt.Errorf("can't divide by zero") @@ -154,7 +154,7 @@ func init() { // string equality RegisterOperator("==", &types.FuncValue{ T: types.NewType("func(a str, b str) bool"), - V: func(input []types.Value) (types.Value, error) { + V: func(ctx context.Context, input []types.Value) (types.Value, error) { return &types.BoolValue{ V: input[0].Str() == input[1].Str(), }, nil @@ -163,7 +163,7 @@ func init() { // bool equality RegisterOperator("==", &types.FuncValue{ T: types.NewType("func(a bool, b bool) bool"), - V: func(input []types.Value) (types.Value, error) { + V: func(ctx context.Context, input []types.Value) (types.Value, error) { return &types.BoolValue{ V: input[0].Bool() == input[1].Bool(), }, nil @@ -172,7 +172,7 @@ func init() { // int equality RegisterOperator("==", &types.FuncValue{ T: types.NewType("func(a int, b int) bool"), - V: func(input []types.Value) (types.Value, error) { + V: func(ctx context.Context, input []types.Value) (types.Value, error) { return &types.BoolValue{ V: input[0].Int() == input[1].Int(), }, nil @@ -181,7 +181,7 @@ func init() { // floating-point equality RegisterOperator("==", &types.FuncValue{ T: types.NewType("func(a float, b float) bool"), - V: func(input []types.Value) (types.Value, error) { + V: func(ctx context.Context, input []types.Value) (types.Value, error) { // TODO: should we do an epsilon check? return &types.BoolValue{ V: input[0].Float() == input[1].Float(), @@ -192,7 +192,7 @@ func init() { // string in-equality RegisterOperator("!=", &types.FuncValue{ T: types.NewType("func(a str, b str) bool"), - V: func(input []types.Value) (types.Value, error) { + V: func(ctx context.Context, input []types.Value) (types.Value, error) { return &types.BoolValue{ V: input[0].Str() != input[1].Str(), }, nil @@ -201,7 +201,7 @@ func init() { // bool in-equality RegisterOperator("!=", &types.FuncValue{ T: types.NewType("func(a bool, b bool) bool"), - V: func(input []types.Value) (types.Value, error) { + V: func(ctx context.Context, input []types.Value) (types.Value, error) { return &types.BoolValue{ V: input[0].Bool() != input[1].Bool(), }, nil @@ -210,7 +210,7 @@ func init() { // int in-equality RegisterOperator("!=", &types.FuncValue{ T: types.NewType("func(a int, b int) bool"), - V: func(input []types.Value) (types.Value, error) { + V: func(ctx context.Context, input []types.Value) (types.Value, error) { return &types.BoolValue{ V: input[0].Int() != input[1].Int(), }, nil @@ -219,7 +219,7 @@ func init() { // floating-point in-equality RegisterOperator("!=", &types.FuncValue{ T: types.NewType("func(a float, b float) bool"), - V: func(input []types.Value) (types.Value, error) { + V: func(ctx context.Context, input []types.Value) (types.Value, error) { // TODO: should we do an epsilon check? return &types.BoolValue{ V: input[0].Float() != input[1].Float(), @@ -230,7 +230,7 @@ func init() { // less-than RegisterOperator("<", &types.FuncValue{ T: types.NewType("func(a int, b int) bool"), - V: func(input []types.Value) (types.Value, error) { + V: func(ctx context.Context, input []types.Value) (types.Value, error) { return &types.BoolValue{ V: input[0].Int() < input[1].Int(), }, nil @@ -239,7 +239,7 @@ func init() { // floating-point less-than RegisterOperator("<", &types.FuncValue{ T: types.NewType("func(a float, b float) bool"), - V: func(input []types.Value) (types.Value, error) { + V: func(ctx context.Context, input []types.Value) (types.Value, error) { // TODO: should we do an epsilon check? return &types.BoolValue{ V: input[0].Float() < input[1].Float(), @@ -249,7 +249,7 @@ func init() { // greater-than RegisterOperator(">", &types.FuncValue{ T: types.NewType("func(a int, b int) bool"), - V: func(input []types.Value) (types.Value, error) { + V: func(ctx context.Context, input []types.Value) (types.Value, error) { return &types.BoolValue{ V: input[0].Int() > input[1].Int(), }, nil @@ -258,7 +258,7 @@ func init() { // floating-point greater-than RegisterOperator(">", &types.FuncValue{ T: types.NewType("func(a float, b float) bool"), - V: func(input []types.Value) (types.Value, error) { + V: func(ctx context.Context, input []types.Value) (types.Value, error) { // TODO: should we do an epsilon check? return &types.BoolValue{ V: input[0].Float() > input[1].Float(), @@ -268,7 +268,7 @@ func init() { // less-than-equal RegisterOperator("<=", &types.FuncValue{ T: types.NewType("func(a int, b int) bool"), - V: func(input []types.Value) (types.Value, error) { + V: func(ctx context.Context, input []types.Value) (types.Value, error) { return &types.BoolValue{ V: input[0].Int() <= input[1].Int(), }, nil @@ -277,7 +277,7 @@ func init() { // floating-point less-than-equal RegisterOperator("<=", &types.FuncValue{ T: types.NewType("func(a float, b float) bool"), - V: func(input []types.Value) (types.Value, error) { + V: func(ctx context.Context, input []types.Value) (types.Value, error) { // TODO: should we do an epsilon check? return &types.BoolValue{ V: input[0].Float() <= input[1].Float(), @@ -287,7 +287,7 @@ func init() { // greater-than-equal RegisterOperator(">=", &types.FuncValue{ T: types.NewType("func(a int, b int) bool"), - V: func(input []types.Value) (types.Value, error) { + V: func(ctx context.Context, input []types.Value) (types.Value, error) { return &types.BoolValue{ V: input[0].Int() >= input[1].Int(), }, nil @@ -296,7 +296,7 @@ func init() { // floating-point greater-than-equal RegisterOperator(">=", &types.FuncValue{ T: types.NewType("func(a float, b float) bool"), - V: func(input []types.Value) (types.Value, error) { + V: func(ctx context.Context, input []types.Value) (types.Value, error) { // TODO: should we do an epsilon check? return &types.BoolValue{ V: input[0].Float() >= input[1].Float(), @@ -309,7 +309,7 @@ func init() { // short-circuit operators, and does it matter? RegisterOperator("and", &types.FuncValue{ T: types.NewType("func(a bool, b bool) bool"), - V: func(input []types.Value) (types.Value, error) { + V: func(ctx context.Context, input []types.Value) (types.Value, error) { return &types.BoolValue{ V: input[0].Bool() && input[1].Bool(), }, nil @@ -318,7 +318,7 @@ func init() { // logical or RegisterOperator("or", &types.FuncValue{ T: types.NewType("func(a bool, b bool) bool"), - V: func(input []types.Value) (types.Value, error) { + V: func(ctx context.Context, input []types.Value) (types.Value, error) { return &types.BoolValue{ V: input[0].Bool() || input[1].Bool(), }, nil @@ -328,7 +328,7 @@ func init() { // logical not (unary operator) RegisterOperator("not", &types.FuncValue{ T: types.NewType("func(a bool) bool"), - V: func(input []types.Value) (types.Value, error) { + V: func(ctx context.Context, input []types.Value) (types.Value, error) { return &types.BoolValue{ V: !input[0].Bool(), }, nil @@ -338,7 +338,7 @@ func init() { // pi operator (this is an easter egg to demo a zero arg operator) RegisterOperator("π", &types.FuncValue{ T: types.NewType("func() float"), - V: func(input []types.Value) (types.Value, error) { + V: func(ctx context.Context, input []types.Value) (types.Value, error) { return &types.FloatValue{ V: math.Pi, }, nil @@ -938,7 +938,7 @@ func (obj *OperatorFunc) Stream(ctx context.Context) error { lastOp = op var result types.Value - result, err := fn.Call(args) // run the function + result, err := fn.Call(ctx, args) // run the function if err != nil { return errwrap.Wrapf(err, "problem running function") } diff --git a/lang/funcs/simple/simple.go b/lang/funcs/simple/simple.go index 5d84ccdb4..27a9ad807 100644 --- a/lang/funcs/simple/simple.go +++ b/lang/funcs/simple/simple.go @@ -173,7 +173,7 @@ func (obj *WrappedFunc) Stream(ctx context.Context) error { values = append(values, x) } - result, err := obj.Fn.Call(values) // (Value, error) + result, err := obj.Fn.Call(ctx, values) // (Value, error) if err != nil { return errwrap.Wrapf(err, "simple function errored") } @@ -244,7 +244,7 @@ func StructRegister(moduleName string, args interface{}) error { ModuleRegister(moduleName, name, &types.FuncValue{ T: types.NewType(fmt.Sprintf("func() %s", typed.String())), - V: func(input []types.Value) (types.Value, error) { + V: func(ctx context.Context, input []types.Value) (types.Value, error) { //if args == nil { // // programming error // return nil, fmt.Errorf("could not convert/access our struct") diff --git a/lang/funcs/simplepoly/simplepoly.go b/lang/funcs/simplepoly/simplepoly.go index 6ab4061f3..9ff6eec72 100644 --- a/lang/funcs/simplepoly/simplepoly.go +++ b/lang/funcs/simplepoly/simplepoly.go @@ -602,7 +602,7 @@ func (obj *WrappedFunc) Stream(ctx context.Context) error { if obj.init.Debug { obj.init.Logf("Calling function with: %+v", values) } - result, err := obj.fn.Call(values) // (Value, error) + result, err := obj.fn.Call(ctx, values) // (Value, error) if err != nil { if obj.init.Debug { obj.init.Logf("Function returned error: %+v", err) diff --git a/lang/types/value.go b/lang/types/value.go index ed079e85f..1e4a48ae4 100644 --- a/lang/types/value.go +++ b/lang/types/value.go @@ -30,6 +30,7 @@ package types import ( + "context" "errors" "fmt" "net" @@ -230,7 +231,7 @@ func ValueOf(v reflect.Value) (Value, error) { return nil, fmt.Errorf("cannot only represent functions with one output value") } - f := func(args []Value) (Value, error) { + f := func(ctx context.Context, args []Value) (Value, error) { in := []reflect.Value{} for _, x := range args { // TODO: should we build this method instead? @@ -239,6 +240,7 @@ func ValueOf(v reflect.Value) (Value, error) { in = append(in, v) } + // FIXME: can we pass in ctx ? // FIXME: can we trap panic's ? out := value.Call(in) // []reflect.Value if len(out) != 1 { // TODO: panic, b/c already checked in TypeOf? @@ -1207,7 +1209,7 @@ func (obj *StructValue) Lookup(k string) (value Value, exists bool) { // Func nodes. type FuncValue struct { Base - V func([]Value) (Value, error) + V func(context.Context, []Value) (Value, error) T *Type // contains ordered field types, arg names are a bonus part } @@ -1217,7 +1219,7 @@ func NewFunc(t *Type) *FuncValue { if t.Kind != KindFunc { return nil // sanity check } - v := func([]Value) (Value, error) { + v := func(context.Context, []Value) (Value, error) { // You were not supposed to call the temporary function, you // were supposed to replace it with a real implementation! return nil, fmt.Errorf("nil function") @@ -1301,7 +1303,7 @@ func (obj *FuncValue) Value() interface{} { // Call runs the function value and returns its result. It returns an error if // something goes wrong during execution, and panic's if you call this with // inappropriate input types, or if it returns an inappropriate output type. -func (obj *FuncValue) Call(args []Value) (Value, error) { +func (obj *FuncValue) Call(ctx context.Context, args []Value) (Value, error) { // cmp input args type to obj.T length := len(obj.T.Ord) if length != len(args) { @@ -1313,7 +1315,7 @@ func (obj *FuncValue) Call(args []Value) (Value, error) { } } - result, err := obj.V(args) // call it + result, err := obj.V(ctx, args) // call it if result == nil { if err == nil { return nil, fmt.Errorf("function returned nil result")