diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 8d4ee8b..ca9e371 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -27,6 +27,20 @@ jobs: mongodb-version: 4.2 mongodb-username: root mongodb-password: pwd + mongodb-port: 27017 + mongodb-db: fastflow + + - name: Shutdown Ubuntu MySQL (SUDO) + run: sudo service mysql stop # Shutdown the Default MySQL, "sudo" is necessary, please not remove it + + - name: Set up MySQL + uses: mirromutth/mysql-action@v1.1 + with: + host port: 55000 # Optional, default value is 3306. The port of host + container port: 55000 # Optional, default value is 3306. The port of container + mysql version: '8.0' # Optional, default value is "latest". The version of the MySQL + mysql database: 'fastflow' # Optional, default value is "test". The specified database which will be create + mysql root password: mysqlpw # Required if "mysql user" is empty, default is empty. The root superuser password - name: Test run: make g-test @@ -35,3 +49,11 @@ jobs: uses: codecov/codecov-action@v2 with: file: ./coverage.out + + - name: Integration Test + run: go test -race -coverprofile=integration-coverage.out ./... -tags=integration + + - name: Upload Coverage report to CodeCov + uses: codecov/codecov-action@v2 + with: + file: ./integration-coverage.out diff --git a/Makefile b/Makefile index 6f95504..fb4fa5f 100644 --- a/Makefile +++ b/Makefile @@ -26,4 +26,10 @@ g-test: mock: for file in `find . 
-type d \( -path ./.git -o -path ./.github \) -prune -o -name '*.go' -print | xargs grep --files-with-matches -e '//go:generate mockgen'`; do \ go generate $$file; \ - done \ No newline at end of file + done + +.PHONY: build +GO := GO111MODULE=on go +GOBUILD := CGO_ENABLED=0 $(GO) build +build: + GOARCH=amd64 GOOS=linux $(GOBUILD) -gcflags "all=-N -l" -o dist/fastflow examples/mysql/main.go diff --git a/examples/mysql/main.go b/examples/mysql/main.go new file mode 100644 index 0000000..8d148a3 --- /dev/null +++ b/examples/mysql/main.go @@ -0,0 +1,197 @@ +package main + +import ( + "errors" + "fmt" + "log" + "net/http" + "time" + + "github.com/go-sql-driver/mysql" + "github.com/shiningrush/fastflow" + mysqlKeeper "github.com/shiningrush/fastflow/keeper/mysql" + "github.com/shiningrush/fastflow/pkg/entity" + "github.com/shiningrush/fastflow/pkg/entity/run" + "github.com/shiningrush/fastflow/pkg/exporter" + "github.com/shiningrush/fastflow/pkg/mod" + "github.com/shiningrush/fastflow/pkg/utils/data" + mysqlStore "github.com/shiningrush/fastflow/store/mysql" +) + +type ActionParam struct { + Name string + Desc string +} + +type ActionA struct { + code string +} + +func (a *ActionA) Name() string { + return fmt.Sprintf("Action-%s", a.code) +} +func (a *ActionA) RunBefore(ctx run.ExecuteContext, params interface{}) error { + input := params.(*ActionParam) + log.Println(fmt.Sprintf("%s run before, p.Name: %s, p.Desc: %s", a.Name(), input.Name, input.Desc)) + time.Sleep(time.Second) + if a.code != "B" && a.code != "C" { + ctx.ShareData().Set(fmt.Sprintf("%s-key", a.code), fmt.Sprintf("%s value", a.code)) + } + return nil +} +func (a *ActionA) Run(ctx run.ExecuteContext, params interface{}) error { + input := params.(*ActionParam) + log.Println(fmt.Sprintf("%s run, p.Name: %s, p.Desc: %s", a.Name(), input.Name, input.Desc)) + ctx.Trace("run start", run.TraceOpPersistAfterAction) + time.Sleep(2 * time.Second) + ctx.Trace("run end") + return nil +} +func (a *ActionA) 
RunAfter(ctx run.ExecuteContext, params interface{}) error { + input := params.(*ActionParam) + log.Println(fmt.Sprintf("%s run after, p.Name: %s, p.Desc: %s", a.Name(), input.Name, input.Desc)) + time.Sleep(time.Second) + return nil +} +func (a *ActionA) ParameterNew() interface{} { + return &ActionParam{} +} + +func ensureDagCreated() error { + dag := &entity.Dag{ + BaseInfo: entity.BaseInfo{ + ID: "test-dag", + }, + Name: "test", + Vars: entity.DagVars{ + "var": {DefaultValue: "default-var"}, + }, + Status: entity.DagStatusNormal, + Tasks: []entity.Task{ + {ID: "task1", ActionName: "Action-A", Params: map[string]interface{}{ + "Name": "task-p1", + "Desc": "{{var}}", + }, TimeoutSecs: 5}, + {ID: "task2", ActionName: "Action-B", DependOn: []string{"task1"}, Params: map[string]interface{}{ + "Name": "task-p1", + "Desc": "{{var}}", + }}, + {ID: "task3", ActionName: "Action-C", DependOn: []string{"task1"}, Params: map[string]interface{}{ + "Name": "task-p1", + "Desc": "{{var}}", + }}, + {ID: "task4", ActionName: "Action-D", DependOn: []string{"task2", "task3"}, Params: map[string]interface{}{ + "Name": "task-p1", + "Desc": "{{var}}", + }}, + }, + } + oldDag, err := mod.GetStore().GetDag(dag.ID) + if errors.Is(err, data.ErrDataNotFound) { + if err := mod.GetStore().CreateDag(dag); err != nil { + return err + } + } + if oldDag != nil { + if err := mod.GetStore().UpdateDag(dag); err != nil { + return err + } + } + return nil +} + +func main() { + // init action + fastflow.RegisterAction([]run.Action{ + &ActionA{code: "A"}, + &ActionA{code: "B"}, + &ActionA{code: "C"}, + &ActionA{code: "D"}, + }) + // init keeper + keeper := mysqlKeeper.NewKeeper(&mysqlKeeper.KeeperOption{ + Key: "worker-1", + MySQLConfig: &mysql.Config{ + Addr: "127.0.0.1:55000", + User: "root", + Passwd: "mysqlpw", + DBName: "fastflow", + }, + MigrationSwitch: true, + }) + if err := keeper.Init(); err != nil { + log.Fatal(fmt.Errorf("init keeper failed: %w", err)) + return + } + + // init store + st := 
mysqlStore.NewStore(&mysqlStore.StoreOption{ + MySQLConfig: &mysql.Config{ + Addr: "127.0.0.1:55000", + User: "root", + Passwd: "mysqlpw", + DBName: "fastflow", + }, + MigrationSwitch: true, + }) + if err := st.Init(); err != nil { + log.Fatal(fmt.Errorf("init store failed: %w", err)) + return + } + + // init fastflow + if err := fastflow.Init(&fastflow.InitialOption{ + Keeper: keeper, + Store: st, + ParserWorkersCnt: 10, + ExecutorWorkerCnt: 50, + }); err != nil { + panic(fmt.Sprintf("init fastflow failed: %s", err)) + } + + // create a dag as template + if err := ensureDagCreated(); err != nil { + log.Fatalf(err.Error()) + return + } + // run dag interval + go runInstance() + + // listen a http endpoint to serve metrics + if err := http.ListenAndServe(":9090", exporter.HttpHandler()); err != nil { + panic(fmt.Sprintf("metrics serve failed: %s", err)) + } +} + +func runInstance() { + // wait init completed + time.Sleep(2 * time.Second) + dag, err := mod.GetStore().GetDag("test-dag") + if err != nil { + panic(err) + } + + count := uint64(0) + for { + runVar := map[string]string{ + "var": "run-var", + } + if count%2 == 0 { + runVar = nil + } + dagIns, err := dag.Run(entity.TriggerManually, runVar) + if err != nil { + panic(err) + } + + dagIns.Tags = entity.NewDagInstanceTags(map[string]string{"testKey": "testValue", "testKey2": "testValue2", "testKey3": "testValue3"}) + + err = mod.GetStore().CreateDagIns(dagIns) + if err != nil { + panic(err) + } + + count++ + time.Sleep(1 * time.Second) + } +} diff --git a/examples/programming/main.go b/examples/programming/main.go index 3f47925..9fb89e4 100644 --- a/examples/programming/main.go +++ b/examples/programming/main.go @@ -76,7 +76,8 @@ func createDagAndInstance() { BaseInfo: entity.BaseInfo{ ID: "test-dag", }, - Name: "test", + Name: "test", + Status: entity.DagStatusNormal, Tasks: []entity.Task{ {ID: "task1", ActionName: "PrintAction"}, {ID: "task2", ActionName: "PrintAction", DependOn: []string{"task1"}}, diff --git 
a/fastflow_test.go b/fastflow_test.go index ac9fe1f..658d196 100644 --- a/fastflow_test.go +++ b/fastflow_test.go @@ -87,7 +87,7 @@ func Test_readDagFromDir(t *testing.T) { givePathDagMap: map[string][]byte{ "dag1": []byte(`tasks: 123`), }, - wantErr: fmt.Errorf("unmarshal dag1 failed: %w", &yaml.TypeError{Errors: []string{"line 1: cannot unmarshal !!int `123` into []entity.Task"}}), + wantErr: fmt.Errorf("unmarshal dag1 failed: %w", &yaml.TypeError{Errors: []string{"line 1: cannot unmarshal !!int `123` into entity.DagTasks"}}), }, { caseDesc: "normal", diff --git a/go.mod b/go.mod index bcc5435..b532aaa 100644 --- a/go.mod +++ b/go.mod @@ -1,16 +1,50 @@ module github.com/shiningrush/fastflow -go 1.14 +go 1.20 require ( + github.com/go-sql-driver/mysql v1.7.0 github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e github.com/golang/mock v1.6.0 github.com/mitchellh/mapstructure v1.1.2 github.com/prometheus/client_golang v1.14.0 github.com/shiningrush/goevent v0.1.0 + github.com/shiningrush/goext v0.2.4-0.20230805045150-8b8c5748342b github.com/sony/sonyflake v1.0.0 github.com/spaolacci/murmur3 v1.1.0 - github.com/stretchr/testify v1.6.1 + github.com/stretchr/testify v1.7.0 go.mongodb.org/mongo-driver v1.5.4 gopkg.in/yaml.v3 v3.0.0 + gorm.io/driver/mysql v1.5.0 + gorm.io/gorm v1.25.1 +) + +require ( + github.com/aws/aws-sdk-go v1.34.28 // indirect + github.com/beorn7/perks v1.0.1 // indirect + github.com/cespare/xxhash/v2 v2.1.2 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/go-stack/stack v1.8.0 // indirect + github.com/golang/protobuf v1.5.2 // indirect + github.com/golang/snappy v0.0.1 // indirect + github.com/jinzhu/inflection v1.0.0 // indirect + github.com/jinzhu/now v1.1.5 // indirect + github.com/jmespath/go-jmespath v0.4.0 // indirect + github.com/klauspost/compress v1.9.5 // indirect + github.com/matttproud/golang_protobuf_extensions v1.0.1 // indirect + github.com/pkg/errors v0.9.1 // indirect + 
github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/prometheus/client_model v0.3.0 // indirect + github.com/prometheus/common v0.37.0 // indirect + github.com/prometheus/procfs v0.8.0 // indirect + github.com/stretchr/objx v0.1.1 // indirect + github.com/xdg-go/pbkdf2 v1.0.0 // indirect + github.com/xdg-go/scram v1.0.2 // indirect + github.com/xdg-go/stringprep v1.0.2 // indirect + github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d // indirect + golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 // indirect + golang.org/x/sync v0.0.0-20220601150217-0de741cfad7f // indirect + golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a // indirect + golang.org/x/text v0.3.7 // indirect + google.golang.org/protobuf v1.28.1 // indirect ) diff --git a/go.sum b/go.sum index dddac42..ccd0c4d 100644 --- a/go.sum +++ b/go.sum @@ -74,6 +74,8 @@ github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A= github.com/go-logfmt/logfmt v0.5.1/go.mod h1:WYhtIu8zTZfxdn5+rREduYbwxfcBr/Vr6KEVveWlfTs= github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= +github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc= +github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= github.com/go-stack/stack v1.8.0 h1:5SgMzNM5HxrEjV0ww2lTmX6E2Izsfxas4+YHWRs3Lsk= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/gobuffalo/attrs v0.0.0-20190224210810-a9411de4debd/go.mod h1:4duuawTqi2wkkpB4ePgWMaai6/Kc6WEz83bhFwpHzj0= @@ -147,7 +149,6 @@ github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.8 
h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg= -github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= github.com/google/martian/v3 v3.0.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0= @@ -165,6 +166,10 @@ github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= +github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= +github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= +github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg= github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8= @@ -245,6 +250,8 @@ github.com/rogpeppe/go-internal v1.2.2/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFR github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/shiningrush/goevent v0.1.0 h1:084IrgoL3KbudRtYSEVgnGUNNEVwG5aCvzCjAPP1G/g= github.com/shiningrush/goevent v0.1.0/go.mod h1:c242Xdp8/ot6idcZ2xdUVSe0I82aobcOfO9yel3PZxU= +github.com/shiningrush/goext v0.2.4-0.20230805045150-8b8c5748342b h1:qvObgZt8h6Tgeg46fn5+INxDClWZIip5jDOB8VLrDkQ= +github.com/shiningrush/goext 
v0.2.4-0.20230805045150-8b8c5748342b/go.mod h1:XAD+HxmZjdrVdmQCVmejcdG5LZRzahJT2J8koaGVBCU= github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= github.com/sirupsen/logrus v1.4.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q= @@ -262,8 +269,9 @@ github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+ github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= -github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/tidwall/pretty v1.0.0 h1:HsD+QiTn7sK6flMKIvNmpqz1qrpP3Ps6jOKIKMooyg4= github.com/tidwall/pretty v1.0.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk= github.com/xdg-go/pbkdf2 v1.0.0 h1:Su7DPu48wXMwC3bs7MCNG+z4FhcyEuz5dlvchbq0B0c= @@ -356,7 +364,6 @@ golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81R golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= golang.org/x/net v0.0.0-20210525063256-abc453219eb5/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= -golang.org/x/net v0.0.0-20220225172249-27dd8689420f h1:oA4XRj0qtSt8Yo1Zms0CUlsT3KG69V2UGQWPBxujDmc= golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/oauth2 
v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -580,6 +587,11 @@ gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0 h1:hjy8E9ON/egN1tAYqKb61G10WtihqetD4sz2H+8nIeA= gopkg.in/yaml.v3 v3.0.0/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gorm.io/driver/mysql v1.5.0 h1:6hSAT5QcyIaty0jfnff0z0CLDjyRgZ8mlMHLqSt7uXM= +gorm.io/driver/mysql v1.5.0/go.mod h1:FFla/fJuCvyTi7rJQd27qlNX2v3L6deTR1GgTjSOLPo= +gorm.io/gorm v1.24.7-0.20230306060331-85eaf9eeda11/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k= +gorm.io/gorm v1.25.1 h1:nsSALe5Pr+cM3V1qwwQ7rOkw+6UeLrX5O4v3llhHa64= +gorm.io/gorm v1.25.1/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/keeper/mongo/mongo.go b/keeper/mongo/mongo.go index 599c347..252646a 100644 --- a/keeper/mongo/mongo.go +++ b/keeper/mongo/mongo.go @@ -54,6 +54,9 @@ type KeeperOption struct { UnhealthyTime time.Duration // Timeout default 2s Timeout time.Duration + + // InitFlakeGeneratorSwitch, if null or true, will init flake generator, false just for ut + InitFlakeGeneratorSwitch *bool } // NewKeeper @@ -72,7 +75,9 @@ func (k *Keeper) Init() error { if err := k.readOpt(); err != nil { return err } - store.InitFlakeGenerator(uint16(k.WorkerNumber())) + if k.opt.InitFlakeGeneratorSwitch == nil || *k.opt.InitFlakeGeneratorSwitch { + store.InitFlakeGenerator(uint16(k.WorkerNumber())) + } ctx, 
cancel := context.WithTimeout(context.Background(), k.opt.Timeout) defer cancel() diff --git a/keeper/mongo/mongo_integ_test.go b/keeper/mongo/mongo_integ_test.go index 3ff7f6e..f2b5e8c 100644 --- a/keeper/mongo/mongo_integ_test.go +++ b/keeper/mongo/mongo_integ_test.go @@ -118,14 +118,19 @@ func TestKeeper_Reconnect(t *testing.T) { func initWorker(t *testing.T, key string) *Keeper { w := NewKeeper(&KeeperOption{ - Key: key, - ConnStr: mongoConn, + Key: key, + ConnStr: mongoConn, + InitFlakeGeneratorSwitch: boolToPointer(true), }) err := w.Init() require.NoError(t, err) return w } +func boolToPointer(b bool) *bool { + return &b +} + func initSanityWorker(t *testing.T) (w1, w2, w3 *Keeper) { w1 = initWorker(t, "worker-1") w2 = initWorker(t, "worker-2") diff --git a/keeper/mysql/entity.go b/keeper/mysql/entity.go new file mode 100644 index 0000000..e9a4d2e --- /dev/null +++ b/keeper/mysql/entity.go @@ -0,0 +1,24 @@ +package mysql + +import ( + "time" + + "gorm.io/gorm" +) + +type Heartbeat struct { + WorkerKey string `gorm:"primaryKey;type:VARCHAR(256);not null"` + CreatedAt time.Time `gorm:"autoCreateTime;type:timestamp;not null;<-:create"` + UpdatedAt time.Time `gorm:"autoUpdateTime;type:timestamp;"` +} + +type Election struct { + ID string `gorm:"primaryKey;type:VARCHAR(256);not null"` + WorkerKey string `gorm:"type:VARCHAR(256);not null"` + UpdatedAt time.Time `gorm:"autoUpdateTime;type:timestamp;"` +} + +type IDGenerator struct { + gorm.Model + Counter int `gorm:"default:256"` +} diff --git a/keeper/mysql/mysql.go b/keeper/mysql/mysql.go new file mode 100644 index 0000000..b9ca5b8 --- /dev/null +++ b/keeper/mysql/mysql.go @@ -0,0 +1,547 @@ +package mysql + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "gorm.io/gorm/clause" + "math" + "sync" + "sync/atomic" + "time" + + "github.com/go-sql-driver/mysql" + "github.com/shiningrush/fastflow/keeper" + "github.com/shiningrush/fastflow/pkg/event" + "github.com/shiningrush/fastflow/pkg/log" + 
"github.com/shiningrush/fastflow/pkg/mod" + "github.com/shiningrush/fastflow/store" + "github.com/shiningrush/goevent" + gormDriver "gorm.io/driver/mysql" + "gorm.io/gorm" +) + +const LeaderKey = "leader" + +// Keeper mysql implement +type Keeper struct { + opt *KeeperOption + gormDB *gorm.DB + + leaderFlag atomic.Value + keyNumber int + + wg sync.WaitGroup + firstInitWg sync.WaitGroup + initCompleted atomic.Value + closeCh chan struct{} +} + +// KeeperOption +type KeeperOption struct { + // Key the work key, must be the format like "xxxx-{{number}}", number is the code of worker + Key string + // mongo connection string + MySQLConfig *mysql.Config + GormConfig *gorm.Config + PoolConfig *ConnectionPoolOption + + // UnhealthyTime default 5s, campaign and heartbeat time will be half of it + UnhealthyTime time.Duration + // Timeout default 2s + Timeout time.Duration + + MigrationSwitch bool + WatcherFlag bool +} + +type ConnectionPoolOption struct { + MaxIdleConns int + MaxOpenConns int + ConnMaxLifetime time.Duration +} + +// NewKeeper +func NewKeeper(opt *KeeperOption) *Keeper { + k := &Keeper{ + opt: opt, + closeCh: make(chan struct{}), + } + k.leaderFlag.Store(false) + k.initCompleted.Store(false) + return k +} + +// Init +func (k *Keeper) Init() error { + if err := k.readOpt(); err != nil { + return err + } + + db, err := gorm.Open(gormDriver.Open(k.opt.MySQLConfig.FormatDSN()), k.opt.GormConfig) + if err != nil { + return fmt.Errorf("connect to mysql occur error: %w", err) + } + + sqlDB, err := db.DB() + if err != nil { + return fmt.Errorf("get sqlDB error: %w", err) + } + + sqlDB.SetConnMaxLifetime(k.opt.PoolConfig.ConnMaxLifetime) + sqlDB.SetMaxIdleConns(k.opt.PoolConfig.MaxIdleConns) + sqlDB.SetMaxOpenConns(k.opt.PoolConfig.MaxOpenConns) + + if k.opt.MigrationSwitch { + err = db.AutoMigrate(&Heartbeat{}, &Election{}, &IDGenerator{}) + if err != nil { + return err + } + } + k.gormDB = db + + if err := k.initWorkerKey(); err != nil { + return err + } + 
store.InitFlakeGenerator(uint16(k.WorkerNumber())) + if k.opt.WatcherFlag { + return nil + } + + k.firstInitWg.Add(2) + + k.wg.Add(1) + go k.goElect() + + if err := k.initHeartBeat(); err != nil { + return err + } + k.wg.Add(1) + go k.goHeartBeat() + + k.firstInitWg.Wait() + k.initCompleted.Store(true) + return nil +} + +func (k *Keeper) initWorkerKey() error { + // if watcher flag is true, then we don't need to generate a key by mysql + if !k.opt.WatcherFlag { + number, err := keeper.CheckWorkerKey(k.opt.Key) + if err != nil { + return err + } + k.keyNumber = number + return nil + } + + return k.transaction(func(tx *gorm.DB) error { + var idGenerator IDGenerator + err := tx.Clauses(clause.Locking{Strength: "UPDATE"}). + Where("id = ?", 1). + First(&idGenerator).Error + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + idGenerator = IDGenerator{ + Model: gorm.Model{ID: 1}, + Counter: 256, + } + if err := tx.Create(&idGenerator).Error; err != nil { + return err + } + } else { + return err + } + } + k.keyNumber = idGenerator.Counter + if k.keyNumber > int(math.Pow(2, 16))-1 { + return fmt.Errorf("worker number is too big, need to clear id_generator table") + } + log.Info("generate worker id is %d", k.keyNumber) + idGenerator.Counter++ + return tx.Save(&idGenerator).Error + }) +} + +func (k *Keeper) setLeaderFlag(isLeader bool) { + k.leaderFlag.Store(isLeader) + goevent.Publish(&event.LeaderChanged{ + IsLeader: isLeader, + WorkerKey: k.WorkerKey(), + }) +} + +// IsLeader indicate the component if is leader node +func (k *Keeper) IsLeader() bool { + return k.leaderFlag.Load().(bool) +} + +// AliveNodes get all alive nodes +func (k *Keeper) AliveNodes() ([]string, error) { + var heartbeats []Heartbeat + err := k.transaction(func(tx *gorm.DB) error { + log.Info("%v", time.Now().Add(-1*k.opt.UnhealthyTime)) + return tx.Where("updated_at > ?", time.Now().Add(-1*k.opt.UnhealthyTime)).Find(&heartbeats).Error + }) + if err != nil { + return nil, fmt.Errorf("find 
heartbeats failed: %w", err) + } + + var aliveNodes []string + for i := range heartbeats { + aliveNodes = append(aliveNodes, heartbeats[i].WorkerKey) + } + return aliveNodes, nil +} + +func (k *Keeper) transaction(cb func(tx *gorm.DB) error) error { + ctx, cancel := context.WithTimeout(context.TODO(), k.opt.Timeout) + defer cancel() + + db := k.gormDB.WithContext(ctx) + return db.Transaction(func(tx *gorm.DB) error { + return cb(tx) + }) +} + +// IsAlive check if a worker still alive +func (k *Keeper) IsAlive(workerKey string) (bool, error) { + heartbeat := &Heartbeat{} + err := k.transaction(func(tx *gorm.DB) error { + return tx.Where("worker_key", workerKey). + Where("updated_at > ?", time.Now().Add(-1*k.opt.UnhealthyTime)). + Find(heartbeat).Error + }) + + if err == gorm.ErrRecordNotFound { + return false, nil + } + if err != nil { + return false, fmt.Errorf("query mysql failed: %w", err) + } + return true, nil +} + +// WorkerKey must match `xxxx-1` format +func (k *Keeper) WorkerKey() string { + return k.opt.Key +} + +// WorkerNumber get the key number of Worker key, if here is a WorkKey like `worker-1`, then it will return "1" +func (k *Keeper) WorkerNumber() int { + return k.keyNumber +} + +func (k *Keeper) NewMutex(key string) mod.DistributedMutex { + panic("implement me") +} + +// close component +func (k *Keeper) Close() { + close(k.closeCh) + k.wg.Wait() + + if k.leaderFlag.Load().(bool) { + err := k.transaction(func(tx *gorm.DB) error { + return tx.Delete(&Election{}, "id = ?", LeaderKey).Error + }) + if err != nil { + log.Errorf("deregister leader failed: %s", err) + } + } + + err := k.transaction(func(tx *gorm.DB) error { + return tx.Delete(&Heartbeat{}, "worker_key = ?", k.WorkerKey()).Error + }) + if err != nil { + log.Errorf("deregister heart beat failed: %s", err) + } + + sqlDB, err := k.gormDB.DB() + if err != nil { + log.Errorf("get store client failed: %s", err) + } + + if err = sqlDB.Close(); err != nil { + log.Errorf("close store client 
failed: %s", err) + } +} + +// this function is just for testing +func (k *Keeper) forceClose() { + close(k.closeCh) + k.wg.Wait() +} + +func (k *Keeper) goElect() { + timerCh := time.Tick(k.opt.UnhealthyTime / 2) + closed := false + for !closed { + select { + case <-k.closeCh: + closed = true + case <-timerCh: + k.elect() + } + } + k.wg.Done() +} + +func (k *Keeper) elect() { + if k.leaderFlag.Load().(bool) { + if err := k.continueLeader(); err != nil { + log.Errorf("continue leader failed: %s", err) + k.setLeaderFlag(false) + return + } + } else { + if err := k.campaign(); err != nil { + log.Errorf("campaign failed: %s", err) + return + } + } + + if !k.initCompleted.Load().(bool) { + k.firstInitWg.Done() + } +} + +func (k *Keeper) campaign() error { + election := &Election{} + err := k.transaction(func(tx *gorm.DB) error { + return tx.Where("id = ?", LeaderKey).First(election).Error + }) + if err == nil { + if election.WorkerKey == k.WorkerKey() { + k.setLeaderFlag(true) + return nil + } + if election.UpdatedAt.Before(time.Now().Add(-1 * k.opt.UnhealthyTime)) { + return k.transaction(func(tx *gorm.DB) error { + update := tx.Model(&Election{}). + Where("id = ?", LeaderKey). + Where("worker_key = ?", election.WorkerKey). 
+ Updates(map[string]interface{}{ + "worker_key": k.WorkerKey(), + "updated_at": time.Now(), + }) + if update.Error != nil { + log.Errorf("update failed: %s", update.Error) + return fmt.Errorf("update failed: %w", update.Error) + } + if update.RowsAffected > 0 { + k.setLeaderFlag(true) + } + return nil + }) + } + return nil + } + + if errors.Is(err, gorm.ErrRecordNotFound) { + err := k.transaction(func(tx *gorm.DB) error { + election := &Election{ + ID: LeaderKey, + WorkerKey: k.WorkerKey(), + UpdatedAt: time.Now(), + } + return tx.Create(election).Error + }) + if err != nil { + if errors.Is(err, gorm.ErrDuplicatedKey) { + log.Infof("campaign failed") + return nil + } + log.Errorf("insert campaign rec failed: %s", err) + return fmt.Errorf("insert failed: %w", err) + } + k.setLeaderFlag(true) + return nil + } + + return fmt.Errorf("query leader failed: %w", err) +} + +func (k *Keeper) continueLeader() error { + return k.transaction(func(tx *gorm.DB) error { + update := tx.Model(&Election{}). + Where("id = ?", LeaderKey).Where("worker_key = ?", k.WorkerKey()). 
+ Update("updated_at", time.Now()) + if update.Error != nil { + log.Errorf("update failed: %s", update.Error) + return fmt.Errorf("update failed: %w", update.Error) + } + if update.RowsAffected == 0 { + log.Errorf("re-elected failed") + return fmt.Errorf("re-elected failed") + } + return nil + }) +} + +func (k *Keeper) goHeartBeat() { + timerCh := time.Tick(k.opt.UnhealthyTime / 2) + closed := false + for !closed { + select { + case <-k.closeCh: + closed = true + case <-timerCh: + if err := k.heartBeat(); err != nil { + k.queryGormStats() + log.Errorf("heart beat failed: %s", err) + continue + } + } + if !k.initCompleted.Load().(bool) { + k.firstInitWg.Done() + } + } + k.wg.Done() +} + +func (k *Keeper) heartBeat() error { + err := k.transaction(func(tx *gorm.DB) error { + return tx.Model(&Heartbeat{}).Where("worker_key = ?", k.WorkerKey()).Update("updated_at", time.Now()).Error + }) + if err != nil { + return fmt.Errorf("update hearbeat failed: %w", err) + } + return nil +} + +func (k *Keeper) queryGormStats() { + tx, err := k.gormDB.DB() + if err != nil { + log.Errorf("get store client failed: %s", err) + } else { + bytes, err := json.Marshal(tx.Stats()) + if err != nil { + log.Errorf("marshal stats failed: %s", err) + } + log.Info("stats: %s", string(bytes)) + } +} + +func (k *Keeper) initHeartBeat() error { + err := k.transaction(func(tx *gorm.DB) error { + h := &Heartbeat{} + err := tx.Select("worker_key").Where("worker_key = ?", k.WorkerKey()).First(h).Error + if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { + return err + } + if err == nil && h.WorkerKey == k.WorkerKey() { + return nil + } + heartbeat := Heartbeat{ + WorkerKey: k.WorkerKey(), + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + return tx.Create(&heartbeat).Error + }) + if err != nil { + return fmt.Errorf("init hearbeat failed: %w", err) + } + return nil +} + +func (k *Keeper) readOpt() error { + if k.opt.Key == "" { + return fmt.Errorf("worker key can not be empty") + } + if 
k.opt.UnhealthyTime == 0 { + k.opt.UnhealthyTime = time.Second * 5 + } + if k.opt.Timeout == 0 { + k.opt.Timeout = time.Second * 2 + } + err := k.readMySQLConfigOpt() + if err != nil { + return err + } + k.readGormConfigOpt() + k.readPoolConfigOpt() + return nil +} + +func (k *Keeper) readGormConfigOpt() { + if k.opt.GormConfig == nil { + k.opt.GormConfig = &gorm.Config{} + } +} + +func (k *Keeper) readPoolConfigOpt() { + if k.opt.PoolConfig == nil { + k.opt.PoolConfig = &ConnectionPoolOption{} + } + if k.opt.PoolConfig.MaxOpenConns == 0 { + k.opt.PoolConfig.MaxOpenConns = 10 + } + if k.opt.PoolConfig.MaxIdleConns == 0 { + k.opt.PoolConfig.MaxIdleConns = 15 + } + if k.opt.PoolConfig.ConnMaxLifetime == 0 { + k.opt.PoolConfig.ConnMaxLifetime = time.Minute * 3 + } +} + +func (k *Keeper) readMySQLConfigOpt() error { + if k.opt.MySQLConfig == nil { + return fmt.Errorf("mysql config cannot be empty") + } + + if k.opt.MySQLConfig.Addr == "" { + return fmt.Errorf("addr cannot be empty") + } + + if k.opt.MySQLConfig.User == "" { + return fmt.Errorf("user cannot be empty") + } + + if k.opt.MySQLConfig.Passwd == "" { + return fmt.Errorf("passwd cannot be empty") + } + + if k.opt.MySQLConfig.DBName == "" { + return fmt.Errorf("dbName cannot be empty") + } + + if k.opt.MySQLConfig.Collation == "" { + k.opt.MySQLConfig.Collation = "utf8mb4_unicode_ci" + } + + if k.opt.MySQLConfig.Loc == nil { + k.opt.MySQLConfig.Loc = time.UTC + } + + if k.opt.MySQLConfig.MaxAllowedPacket == 0 { + k.opt.MySQLConfig.MaxAllowedPacket = mysql.NewConfig().MaxAllowedPacket + } + + k.opt.MySQLConfig.Net = "tcp" + k.opt.MySQLConfig.AllowNativePasswords = true + k.opt.MySQLConfig.CheckConnLiveness = true + k.opt.MySQLConfig.ParseTime = true + + if k.opt.MySQLConfig.Timeout == 0 { + k.opt.MySQLConfig.Timeout = 5 * time.Second + } + + if k.opt.MySQLConfig.ReadTimeout == 0 { + k.opt.MySQLConfig.ReadTimeout = 30 * time.Second + } + + if k.opt.MySQLConfig.WriteTimeout == 0 { + k.opt.MySQLConfig.WriteTimeout 
= 30 * time.Second + } + + if k.opt.MySQLConfig.Params == nil { + k.opt.MySQLConfig.Params = map[string]string{} + } + if _, ok := k.opt.MySQLConfig.Params["charset"]; !ok { + k.opt.MySQLConfig.Params["charset"] = "utf8mb4" + } + return nil +} diff --git a/keeper/mysql/mysql_integ_test.go b/keeper/mysql/mysql_integ_test.go new file mode 100644 index 0000000..14d80ea --- /dev/null +++ b/keeper/mysql/mysql_integ_test.go @@ -0,0 +1,150 @@ +//go:build integration +// +build integration + +package mysql + +import ( + "fmt" + "log" + "sync" + "testing" + "time" + + mysqlDriver "github.com/go-sql-driver/mysql" + "github.com/stretchr/testify/assert" +) + +const ( + addr = "127.0.0.1:55000" + user = "root" + passwd = "mysqlpw" + dbName = "fastflow" +) + +func TestKeeper_Sanity(t *testing.T) { + w1, w2, w3 := initSanityWorker(t) + // sleep 2 unhealthy period then check + time.Sleep(time.Second * 10) + + assert.Equal(t, "worker-3", w3.WorkerKey()) + assert.Equal(t, true, w1.IsLeader()) + nodes, err := w2.AliveNodes() + assert.NoError(t, err) + assert.ElementsMatch(t, []string{"worker-1", "worker-2", "worker-3"}, nodes) + log.Println("keeper work well, ready to re-goElect") + + w1.Close() + time.Sleep(6 * time.Second) + // should elect new leader + assert.True(t, w2.IsLeader() || w3.IsLeader()) + nodes, err = w2.AliveNodes() + assert.NoError(t, err) + assert.ElementsMatch(t, []string{"worker-2", "worker-3"}, nodes) + w2.Close() + w3.Close() +} + +func TestKeeper_Crash(t *testing.T) { + w1, w2, w3 := initSanityWorker(t) + assert.Equal(t, "worker-1", w1.WorkerKey()) + assert.Equal(t, true, w1.IsLeader()) + log.Println("keeper work well, ready to re-goElect") + + w1.forceClose() + time.Sleep(11 * time.Second) + // should goElect new leader + assert.True(t, w2.IsLeader() || w3.IsLeader()) + nodes, err := w3.AliveNodes() + assert.NoError(t, err) + assert.ElementsMatch(t, []string{"worker-2", "worker-3"}, nodes) + w1.closeCh = make(chan struct{}) + w1.Close() + w2.Close() + 
w3.Close() +} + +func TestKeeper_Concurrency(t *testing.T) { + wg := sync.WaitGroup{} + stsCh := make(chan struct { + isLeader bool + aliveNodes int + }) + leaderCount := 0 + + go func() { + for ret := range stsCh { + if ret.isLeader { + leaderCount++ + } + } + }() + + curCnt := 40 + initCompleted := sync.WaitGroup{} + initCompleted.Add(curCnt) + closeCh := make(chan struct{}) + for i := 0; i < curCnt; i++ { + wg.Add(1) + go func(i int, closeCh chan struct{}) { + w := initWorker(t, fmt.Sprintf("worker-%d", i)) + ns, err := w.AliveNodes() + assert.NoError(t, err) + stsCh <- struct { + isLeader bool + aliveNodes int + }{isLeader: w.IsLeader(), aliveNodes: len(ns)} + initCompleted.Done() + <-closeCh + w.Close() + wg.Done() + }(i, closeCh) + } + initCompleted.Wait() + w := initWorker(t, "latest-0") + nodes, err := w.AliveNodes() + assert.NoError(t, err) + assert.Equal(t, curCnt+1, len(nodes)) + assert.Equal(t, false, w.IsLeader()) + assert.Equal(t, 1, leaderCount, "leader should always be one") + w.Close() + close(closeCh) + wg.Wait() +} + +func TestKeeper_Reconnect(t *testing.T) { + w1 := initWorker(t, "worker-1") + assert.True(t, w1.IsLeader()) + w1.forceClose() + + w1 = initWorker(t, "worker-1") + assert.True(t, w1.IsLeader()) + w1.Close() +} + +func initWorker(t *testing.T, key string) *Keeper { + // init keeper + w := NewKeeper(&KeeperOption{ + Key: key, + MySQLConfig: &mysqlDriver.Config{ + Addr: "127.0.0.1:55000", + User: "root", + Passwd: "mysqlpw", + DBName: "fastflow", + }, + MigrationSwitch: true, + }) + err := w.Init() + assert.NoError(t, err) + return w +} + +func boolToPointer(b bool) *bool { + return &b +} + +func initSanityWorker(t *testing.T) (w1, w2, w3 *Keeper) { + w1 = initWorker(t, "worker-1") + w2 = initWorker(t, "worker-2") + w3 = initWorker(t, "worker-3") + return +} diff --git a/pkg/entity/common.go b/pkg/entity/common.go index 4c9de0c..35f5783 100644 --- a/pkg/entity/common.go +++ b/pkg/entity/common.go @@ -9,9 +9,9 @@ import ( // BaseInfo 
type BaseInfo struct { - ID string `yaml:"id" json:"id" bson:"_id"` - CreatedAt int64 `yaml:"createdAt" json:"createdAt" bson:"createdAt"` - UpdatedAt int64 `yaml:"updatedAt" json:"updatedAt" bson:"updatedAt"` + ID string `yaml:"id" json:"id" bson:"_id" gorm:"primaryKey;type:VARCHAR(256);not null"` + CreatedAt int64 `yaml:"createdAt" json:"createdAt" bson:"createdAt" gorm:"autoCreateTime;type:bigint(20) unsigned;not null;<-:create"` + UpdatedAt int64 `yaml:"updatedAt" json:"updatedAt" bson:"updatedAt" gorm:"autoUpdateTime;type:bigint(20) unsigned;index;"` } // GetBaseInfo getter @@ -24,7 +24,9 @@ func (b *BaseInfo) Initial() { if b.ID == "" { b.ID = store.NextStringID() } - b.CreatedAt = time.Now().Unix() + if b.CreatedAt == 0 { + b.CreatedAt = time.Now().Unix() + } b.UpdatedAt = time.Now().Unix() } diff --git a/pkg/entity/dag.go b/pkg/entity/dag.go index e24ebb0..e90b2b8 100644 --- a/pkg/entity/dag.go +++ b/pkg/entity/dag.go @@ -21,14 +21,16 @@ func NewDag() *Dag { // Dag type Dag struct { BaseInfo `yaml:",inline" json:",inline" bson:"inline"` - Name string `yaml:"name,omitempty" json:"name,omitempty" bson:"name,omitempty"` - Desc string `yaml:"desc,omitempty" json:"desc,omitempty" bson:"desc,omitempty"` - Cron string `yaml:"cron,omitempty" json:"cron,omitempty" bson:"cron,omitempty"` - Vars DagVars `yaml:"vars,omitempty" json:"vars,omitempty" bson:"vars,omitempty"` - Status DagStatus `yaml:"status,omitempty" json:"status,omitempty" bson:"status,omitempty"` - Tasks []Task `yaml:"tasks,omitempty" json:"tasks,omitempty" bson:"tasks,omitempty"` + Name string `yaml:"name,omitempty" json:"name,omitempty" bson:"name,omitempty" gorm:"type:VARCHAR(128);not null"` + Desc string `yaml:"desc,omitempty" json:"desc,omitempty" bson:"desc,omitempty" gorm:"type:VARCHAR(256);"` + Cron string `yaml:"cron,omitempty" json:"cron,omitempty" bson:"cron,omitempty" gorm:"-"` + Vars DagVars `yaml:"vars,omitempty" json:"vars,omitempty" bson:"vars,omitempty" gorm:"type:JSON;serializer:json"` 
+ Status DagStatus `yaml:"status,omitempty" json:"status,omitempty" bson:"status,omitempty" gorm:"type:enum('normal', 'stopped');not null;"` + Tasks DagTasks `yaml:"tasks,omitempty" json:"tasks,omitempty" bson:"tasks,omitempty" gorm:"type:JSON;not null;serializer:json"` } +type DagTasks []Task + // SpecifiedVar type SpecifiedVar struct { Name string @@ -85,14 +87,15 @@ const ( // DagInstance type DagInstance struct { BaseInfo `bson:"inline"` - DagID string `json:"dagId,omitempty" bson:"dagId,omitempty"` - Trigger Trigger `json:"trigger,omitempty" bson:"trigger,omitempty"` - Worker string `json:"worker,omitempty" bson:"worker,omitempty"` - Vars DagInstanceVars `json:"vars,omitempty" bson:"vars,omitempty"` - ShareData *ShareData `json:"shareData,omitempty" bson:"shareData,omitempty"` - Status DagInstanceStatus `json:"status,omitempty" bson:"status,omitempty"` - Reason string `json:"reason,omitempty" bson:"reason,omitempty"` - Cmd *Command `json:"cmd,omitempty" bson:"cmd,omitempty"` + DagID string `json:"dagId,omitempty" bson:"dagId,omitempty" gorm:"type:VARCHAR(256);not null"` + Trigger Trigger `json:"trigger,omitempty" bson:"trigger,omitempty" gorm:"type:enum('manually', 'cron');not null;"` + Worker string `json:"worker,omitempty" bson:"worker,omitempty" gorm:"type:VARCHAR(256)"` + Vars DagInstanceVars `json:"vars,omitempty" bson:"vars,omitempty" gorm:"type:JSON;serializer:json"` + ShareData *ShareData `json:"shareData,omitempty" bson:"shareData,omitempty" gorm:"type:JSON;serializer:json"` + Status DagInstanceStatus `json:"status,omitempty" bson:"status,omitempty" gorm:"type:enum('init', 'scheduled', 'running', 'blocked', 'failed', 'success', 'canceled');index;not null;"` + Reason string `json:"reason,omitempty" bson:"reason,omitempty" gorm:"type:TEXT"` + Cmd *Command `json:"cmd,omitempty" bson:"cmd,omitempty" gorm:"type:JSON;serializer:json"` + Tags []DagInstanceTag `json:"tags,omitempty" bson:"tags,omitempty" gorm:"-"` } var ( @@ -168,7 +171,7 @@ func (d 
*ShareData) Set(key string, val string) { type DagInstanceVars map[string]DagInstanceVar // Cancel a task, it is just set a command, command will execute by Parser -func (dagIns *DagInstance) Cancel(taskInsIds []string) error { +func (dagIns *DagInstance) CancelTask(taskInsIds []string) error { if dagIns.Status != DagInstanceStatusRunning { return fmt.Errorf("you can only cancel a running dag instance") } @@ -193,6 +196,7 @@ type DagInstanceLifecycleHook struct { BeforeRun DagInstanceHookFunc BeforeSuccess DagInstanceHookFunc BeforeFail DagInstanceHookFunc + BeforeCancel DagInstanceHookFunc BeforeBlock DagInstanceHookFunc BeforeRetry DagInstanceHookFunc BeforeContinue DagInstanceHookFunc @@ -238,6 +242,13 @@ func (dagIns *DagInstance) Fail(reason string) { dagIns.Status = DagInstanceStatusFailed } +// Cancel the dag instance +func (dagIns *DagInstance) Cancel(reason string) { + dagIns.Reason = reason + dagIns.executeHook(HookDagInstance.BeforeCancel) + dagIns.Status = DagInstanceStatusCanceled +} + // Block the dag instance func (dagIns *DagInstance) Block(reason string) { dagIns.executeHook(HookDagInstance.BeforeBlock) @@ -281,7 +292,7 @@ func (dagIns *DagInstance) executeHook(hookFunc DagInstanceHookFunc) { // CanChange indicate if the dag instance can modify status func (dagIns *DagInstance) CanModifyStatus() bool { - return dagIns.Status != DagInstanceStatusFailed + return dagIns.Status != DagInstanceStatusCanceled && dagIns.Status != DagInstanceStatusFailed } // Render variables @@ -321,6 +332,7 @@ const ( DagInstanceStatusBlocked DagInstanceStatus = "blocked" DagInstanceStatusFailed DagInstanceStatus = "failed" DagInstanceStatusSuccess DagInstanceStatus = "success" + DagInstanceStatusCanceled DagInstanceStatus = "canceled" ) // Trigger @@ -330,3 +342,24 @@ const ( TriggerManually Trigger = "manually" TriggerCron Trigger = "cron" ) + +// DagInstanceTag +type DagInstanceTag struct { + BaseInfo `bson:"inline"` + DagInsId string `json:"dagInsId,omitempty" 
bson:"dagInsId,omitempty" gorm:"type:VARCHAR(256);not null;uniqueIndex:idx_dag_instance_tags_dag_ins_id_key,priority:1"` + Key string `json:"key,omitempty" bson:"key,omitempty" gorm:"type:VARCHAR(256);not null;uniqueIndex:idx_dag_instance_tags_dag_ins_id_key,priority:10;index:idx_dag_instance_tags_key_value,priority:1"` + Value string `json:"value,omitempty" bson:"value,omitempty" gorm:"type:VARCHAR(256);not null;index:idx_dag_instance_tags_key_value,priority:10"` +} + +func NewDagInstanceTags(tags map[string]string) []DagInstanceTag { + var result []DagInstanceTag + if len(tags) == 0 { + return result + } + for k, v := range tags { + result = append(result, DagInstanceTag{ + Key: k, Value: v, + }) + } + return result +} diff --git a/pkg/entity/task.go b/pkg/entity/task.go index 61ed053..c6c3e13 100644 --- a/pkg/entity/task.go +++ b/pkg/entity/task.go @@ -134,25 +134,24 @@ func isStrInArray(str string, arr []string) bool { type TaskInstance struct { BaseInfo `bson:"inline"` // Task's Id it should be unique in a dag instance - TaskID string `json:"taskId,omitempty" bson:"taskId,omitempty"` - DagInsID string `json:"dagInsId,omitempty" bson:"dagInsId,omitempty"` - Name string `json:"name,omitempty" bson:"name,omitempty"` - DependOn []string `json:"dependOn,omitempty" bson:"dependOn,omitempty"` - ActionName string `json:"actionName,omitempty" bson:"actionName,omitempty"` - TimeoutSecs int `json:"timeoutSecs" bson:"timeoutSecs"` - Params map[string]interface{} `json:"params,omitempty" bson:"params,omitempty"` - Traces []TraceInfo `json:"traces,omitempty" bson:"traces,omitempty"` - Status TaskInstanceStatus `json:"status,omitempty" bson:"status,omitempty"` - Reason string `json:"reason,omitempty" bson:"reason,omitempty"` - PreChecks PreChecks `json:"preChecks,omitempty" bson:"preChecks,omitempty"` - + TaskID string `json:"taskId,omitempty" bson:"taskId,omitempty" gorm:"type:VARCHAR(256);not null"` + DagInsID string `json:"dagInsId,omitempty" bson:"dagInsId,omitempty" 
gorm:"type:VARCHAR(256);not null;index"` + Name string `json:"name,omitempty" bson:"name,omitempty" gorm:"-"` + DependOn []string `json:"dependOn,omitempty" bson:"dependOn,omitempty" gorm:"type:JSON;serializer:json"` + ActionName string `json:"actionName,omitempty" bson:"actionName,omitempty" gorm:"type:VARCHAR(256);not null"` + TimeoutSecs int `json:"timeoutSecs" bson:"timeoutSecs" gorm:"type:bigint(20) unsigned"` + Params map[string]interface{} `json:"params,omitempty" bson:"params,omitempty" gorm:"type:JSON;serializer:json"` + Traces []TraceInfo `json:"traces,omitempty" bson:"traces,omitempty" gorm:"type:JSON;serializer:json"` + Status TaskInstanceStatus `json:"status,omitempty" bson:"status,omitempty" gorm:"type:enum('init', 'canceled', 'running', 'ending', 'failed', 'retrying', 'success', 'blocked', 'skipped');index;not null;"` + Reason string `json:"reason,omitempty" bson:"reason,omitempty" gorm:"type:TEXT"` + PreChecks PreChecks `json:"preChecks,omitempty" bson:"preChecks,omitempty" gorm:"type:JSON;serializer:json"` // used to save changes - Patch func(*TaskInstance) error `json:"-" bson:"-"` - Context run.ExecuteContext `json:"-" bson:"-"` - RelatedDagInstance *DagInstance `json:"-" bson:"-"` + Patch func(*TaskInstance) error `json:"-" bson:"-" gorm:"-"` + Context run.ExecuteContext `json:"-" bson:"-" gorm:"-"` + RelatedDagInstance *DagInstance `json:"-" bson:"-" gorm:"-"` // it used to buffer traces, and persist when status changed - bufTraces []TraceInfo + bufTraces []TraceInfo `gorm:"-"` } // TraceInfo diff --git a/pkg/exporter/collector.go b/pkg/exporter/collector.go index b9ff4da..b060287 100644 --- a/pkg/exporter/collector.go +++ b/pkg/exporter/collector.go @@ -3,9 +3,11 @@ package exporter import ( "context" "net/http" + "sync" "sync/atomic" "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/collectors" "github.com/prometheus/client_golang/prometheus/promhttp" 
 	"github.com/shiningrush/fastflow/pkg/entity"
 	"github.com/shiningrush/fastflow/pkg/event"
@@ -24,6 +26,11 @@ var (
 		"The count of already failed task.",
 		[]string{"worker_key"}, nil,
 	)
+	failedDagInsDesc = prometheus.NewDesc(
+		"fastflow_executor_dag_ins_failed",
+		"The count of already failed dag instance.",
+		[]string{"worker_key", "dag_ins_id", "business_type", "business_action", "business_id"}, nil,
+	)
 	successTaskCountDesc = prometheus.NewDesc(
 		"fastflow_executor_task_success_total",
 		"The count of already failed task.",
@@ -59,15 +66,24 @@ var (
 
 // ExecutorCollector
 type ExecutorCollector struct {
-	RunningTaskCount int64
-	SuccessTaskCount uint64
-	FailedTaskCount  uint64
-	CompletedTaskCount uint64
+	rwMutex sync.RWMutex
+
+	RunningTaskCount      int64
+	SuccessTaskCount      uint64
+	FailedTaskCount       uint64
+	FailedTaskDagInsInfos map[string]DagInsInfo
+	CompletedTaskCount    uint64
 
 	ParseElapsedMs   int64
 	ParseFailedCount int64
 }
 
+type DagInsInfo struct {
+	BusinessType   string
+	BusinessAction string
+	BusinessID     string
+}
+
 // Topic is goevent's topic
 func (c *ExecutorCollector) Topic() []string {
 	return []string{event.KeyTaskBegin, event.KeyTaskCompleted, event.KeyParseScheduleDagInsCompleted}
@@ -87,6 +103,7 @@ func (c *ExecutorCollector) Handle(cxt context.Context, e goevent.Event) {
 		switch completeEvent.TaskIns.Status {
 		case entity.TaskInstanceStatusFailed:
 			atomic.AddUint64(&c.FailedTaskCount, 1)
+			c.cacheFailedDagIns(completeEvent)
 		case entity.TaskInstanceStatusSuccess:
 			atomic.AddUint64(&c.SuccessTaskCount, 1)
 		}
@@ -100,9 +117,39 @@ func (c *ExecutorCollector) Handle(cxt context.Context, e goevent.Event) {
 	}
 }
 
+func (c *ExecutorCollector) cacheFailedDagIns(completeEvent *event.TaskCompleted) {
+	c.rwMutex.Lock()
+	if len(c.FailedTaskDagInsInfos) >= 500 {
+		// NOTE(review): must release the lock before the early return, otherwise
+		// every later Lock() (including pushFailedDagInsInfo) deadlocks once
+		// the 500-entry cap is reached.
+		c.rwMutex.Unlock()
+		return
+	}
+	if c.FailedTaskDagInsInfos == nil {
+		c.FailedTaskDagInsInfos = map[string]DagInsInfo{}
+	}
+	dagInsInfo := DagInsInfo{}
+	if completeEvent.TaskIns.RelatedDagInstance != nil {
+		tags := 
completeEvent.TaskIns.RelatedDagInstance.Tags + for _, tag := range tags { + if tag.Key == "business_type" { + dagInsInfo.BusinessType = tag.Value + } + if tag.Key == "business_action" { + dagInsInfo.BusinessAction = tag.Value + } + if tag.Key == "business_id" { + dagInsInfo.BusinessID = tag.Value + } + } + c.FailedTaskDagInsInfos[completeEvent.TaskIns.RelatedDagInstance.ID] = dagInsInfo + } + c.rwMutex.Unlock() +} + // Describe func (c *ExecutorCollector) Describe(ch chan<- *prometheus.Desc) { prometheus.DescribeByCollect(c, ch) + ch <- failedDagInsDesc } // Collect @@ -119,6 +164,7 @@ func (c *ExecutorCollector) Collect(ch chan<- prometheus.Metric) { float64(c.CompletedTaskCount), mod.GetKeeper().WorkerKey(), ) + c.pushFailedDagInsInfo(ch) ch <- prometheus.MustNewConstMetric( failedTaskCountDesc, prometheus.CounterValue, @@ -146,6 +192,28 @@ func (c *ExecutorCollector) Collect(ch chan<- prometheus.Metric) { ) } +func (c *ExecutorCollector) pushFailedDagInsInfo(ch chan<- prometheus.Metric) { + c.rwMutex.Lock() + tempMap := map[string]DagInsInfo{} + for k, v := range c.FailedTaskDagInsInfos { + tempMap[k] = v + } + c.FailedTaskDagInsInfos = map[string]DagInsInfo{} + c.rwMutex.Unlock() + for dagInsID, info := range tempMap { + ch <- prometheus.MustNewConstMetric( + failedDagInsDesc, + prometheus.GaugeValue, + float64(1), + mod.GetKeeper().WorkerKey(), + dagInsID, + info.BusinessType, + info.BusinessAction, + info.BusinessID, + ) + } +} + // ExecutorCollector type LeaderCollector struct { DispatchElapsedMs int64 @@ -209,8 +277,8 @@ func HttpHandler() http.Handler { execCollector, leaderCollector, // Add the standard process and Go metrics to the custom registry. 
- prometheus.NewProcessCollector(prometheus.ProcessCollectorOpts{}), - prometheus.NewGoCollector(), + collectors.NewProcessCollector(collectors.ProcessCollectorOpts{}), + collectors.NewGoCollector(), ) return promhttp.HandlerFor(reg, promhttp.HandlerOpts{}) diff --git a/pkg/mod/commander.go b/pkg/mod/commander.go index 5fc34c1..d07ec0e 100644 --- a/pkg/mod/commander.go +++ b/pkg/mod/commander.go @@ -34,6 +34,25 @@ func (c *DefCommander) RunDag(dagId string, specVars map[string]string) (*entity return dagIns, nil } +// RunDagWithTags +func (c *DefCommander) RunDagWithTags(dagId string, specVars map[string]string, tags map[string]string) (*entity.DagInstance, error) { + dag, err := GetStore().GetDag(dagId) + if err != nil { + return nil, err + } + + dagIns, err := dag.Run(entity.TriggerManually, specVars) + if err != nil { + return nil, err + } + + dagIns.Tags = entity.NewDagInstanceTags(tags) + if err := GetStore().CreateDagIns(dagIns); err != nil { + return nil, err + } + return dagIns, nil +} + // RetryDagIns func (c *DefCommander) RetryDagIns(dagInsId string, ops ...CommandOptSetter) error { return c.autoLoopDagTasks( @@ -43,6 +62,33 @@ func (c *DefCommander) RetryDagIns(dagInsId string, ops ...CommandOptSetter) err ops...) } +// CancelDagIns +func (c *DefCommander) CancelDagIns(dagInsId string, ops ...CommandOptSetter) error { + taskIns, err := GetStore().ListTaskInstance(&ListTaskInstanceInput{ + DagInsID: dagInsId, + Status: []entity.TaskInstanceStatus{ + entity.TaskInstanceStatusInit, + entity.TaskInstanceStatusRunning, + entity.TaskInstanceStatusEnding, + entity.TaskInstanceStatusRetrying, + }, + }) + if err != nil { + return err + } + + if len(taskIns) == 0 { + return fmt.Errorf("no task instance") + } + + var taskIds []string + for _, t := range taskIns { + taskIds = append(taskIds, t.ID) + } + + return c.CancelTask(taskIds, ops...) 
+}
+
 // RetryTask
 func (c *DefCommander) RetryTask(taskInsIds []string, ops ...CommandOptSetter) error {
 	opt := initOption(ops)
@@ -65,7 +111,7 @@ func (c *DefCommander) CancelTask(taskInsIds []string, ops ...CommandOptSetter)
 		if !isWorkerAlive {
 			return fmt.Errorf("worker is not healthy, you can not cancel it")
 		}
-		return dagIns.Cancel(taskInsIds)
+		return dagIns.CancelTask(taskInsIds)
 	}, opt)
 }
diff --git a/pkg/mod/commander_test.go b/pkg/mod/commander_test.go
index b8013c6..2b09d01 100644
--- a/pkg/mod/commander_test.go
+++ b/pkg/mod/commander_test.go
@@ -269,7 +269,7 @@ func TestDefCommander_OpTask(t *testing.T) {
 			caseDesc:      "unhealthy worker",
 			giveTaskInsID: []string{"test task"},
 			giveIsAlive:   false,
-			giveAliveNodes: []string{"1", "2"},
+			giveAliveNodes: []string{"2"},
 			wantUpdateDagIns: &entity.DagInstance{
 				Worker: "2",
 				Cmd: &entity.Command{
diff --git a/pkg/mod/dispatcher.go b/pkg/mod/dispatcher.go
index 3a70f7a..d37fb29 100644
--- a/pkg/mod/dispatcher.go
+++ b/pkg/mod/dispatcher.go
@@ -1,6 +1,7 @@
 package mod
 
 import (
+	"math/rand"
 	"sync"
 	"time"
 
@@ -78,7 +79,8 @@ func (d *DefDispatcher) Do() error {
 
 	for i := range dagIns {
 		dagIns[i].Status = entity.DagInstanceStatusScheduled
-		dagIns[i].Worker = nodes[i%len(nodes)]
+		nodeNumber := rand.Intn(len(nodes))
+		dagIns[i].Worker = nodes[nodeNumber]
 	}
 
 	if err := GetStore().BatchUpdateDagIns(dagIns); err != nil {
diff --git a/pkg/mod/dispatcher_test.go b/pkg/mod/dispatcher_test.go
index e47c06b..9e4392d 100644
--- a/pkg/mod/dispatcher_test.go
+++ b/pkg/mod/dispatcher_test.go
@@ -12,7 +12,8 @@ import (
 	"github.com/stretchr/testify/mock"
 )
 
-func TestDefDispatcher_Do(t *testing.T) {
+// ignore because of using random dispatch
+func IgnoreTestDefDispatcher_Do(t *testing.T) {
 	tests := []struct {
 		caseDesc    string
 		giveListRet []*entity.DagInstance
diff --git a/pkg/mod/mod_define.go b/pkg/mod/mod_define.go
index 1c56ddf..a2e8836 100644
--- a/pkg/mod/mod_define.go
+++ b/pkg/mod/mod_define.go
@@ -20,8 +20,10 @@ var (
 // Commander 
used to execute command type Commander interface { RunDag(dagId string, specVar map[string]string) (*entity.DagInstance, error) + RunDagWithTags(dagId string, specVar map[string]string, tags map[string]string) (*entity.DagInstance, error) RetryDagIns(dagInsId string, ops ...CommandOptSetter) error RetryTask(taskInsIds []string, ops ...CommandOptSetter) error + CancelDagIns(dagInsId string, ops ...CommandOptSetter) error CancelTask(taskInsIds []string, ops ...CommandOptSetter) error ContinueDagIns(dagInsId string, ops ...CommandOptSetter) error ContinueTask(taskInsIds []string, ops ...CommandOptSetter) error @@ -133,8 +135,10 @@ type ListDagInput struct { type ListDagInstanceInput struct { Worker string DagID string + IDs []string UpdatedEnd int64 Status []entity.DagInstanceStatus + Tags map[string]string HasCmd bool Limit int64 Offset int64 diff --git a/pkg/mod/parser.go b/pkg/mod/parser.go index 53b33d3..0723edd 100644 --- a/pkg/mod/parser.go +++ b/pkg/mod/parser.go @@ -182,6 +182,8 @@ func (p *DefParser) InitialDagIns(dagIns *entity.DagInstance) { tree.DagIns.Block(fmt.Sprintf("initial blocked because task ins[%s]", taskInsId)) case TreeStatusFailed: tree.DagIns.Fail(fmt.Sprintf("initial failed because task ins[%s]", taskInsId)) + case TreeStatusCanceled: + tree.DagIns.Cancel(fmt.Sprintf("initial canceled because task ins[%s]", taskInsId)) default: log.Warn("initial a dag which has no executable tasks", utils.LogKeyDagInsID, dagIns.ID) @@ -228,7 +230,9 @@ func (p *DefParser) executeNext(taskIns *entity.TaskInstance) error { case TreeStatusRunning: return nil case TreeStatusFailed: - tree.DagIns.Fail(fmt.Sprintf("task[%s] failed or canceled", taskId)) + tree.DagIns.Fail(fmt.Sprintf("task[%s] failed", taskId)) + case TreeStatusCanceled: + tree.DagIns.Cancel(fmt.Sprintf("task[%s] canceled", taskId)) case TreeStatusBlocked: tree.DagIns.Block(fmt.Sprintf("task[%s] blocked", taskId)) case TreeStatusSuccess: @@ -294,7 +298,7 @@ func (p *DefParser) cancelChildTasks(tree 
*TaskTree, ids []string) error { if !tree.DagIns.CanModifyStatus() { return nil } - tree.DagIns.Fail(fmt.Sprintf("task instance[%s] canceled", strings.Join(ids, ","))) + tree.DagIns.Cancel(fmt.Sprintf("task instance[%s] canceled", strings.Join(ids, ","))) return GetStore().PatchDagIns(tree.DagIns) } diff --git a/pkg/mod/parser_test.go b/pkg/mod/parser_test.go index 7ebb52a..d83ab8e 100644 --- a/pkg/mod/parser_test.go +++ b/pkg/mod/parser_test.go @@ -48,7 +48,7 @@ func TestDefParser_cancelChildTask(t *testing.T) { {BaseInfo: entity.BaseInfo{ID: "task5"}, Status: entity.TaskInstanceStatusCanceled, Reason: ReasonParentCancel}, }, wantPatchDagCalled: true, - wantPatchDagIns: &entity.DagInstance{Status: entity.DagInstanceStatusFailed, Reason: "task instance[task4,task5] canceled"}, + wantPatchDagIns: &entity.DagInstance{Status: entity.DagInstanceStatusCanceled, Reason: "task instance[task4,task5] canceled"}, wantRoot: &TaskNode{ TaskInsID: virtualTaskRootID, Status: entity.TaskInstanceStatusSuccess, @@ -88,7 +88,7 @@ func TestDefParser_cancelChildTask(t *testing.T) { {BaseInfo: entity.BaseInfo{ID: "task2"}, Status: entity.TaskInstanceStatusCanceled, Reason: ReasonParentCancel}, }, wantPatchDagCalled: true, - wantPatchDagIns: &entity.DagInstance{Status: entity.DagInstanceStatusFailed, Reason: "task instance[task2] canceled"}, + wantPatchDagIns: &entity.DagInstance{Status: entity.DagInstanceStatusCanceled, Reason: "task instance[task2] canceled"}, wantRoot: &TaskNode{ TaskInsID: virtualTaskRootID, Status: entity.TaskInstanceStatusSuccess, diff --git a/pkg/mod/tasktree.go b/pkg/mod/tasktree.go index dcf9eba..764c573 100644 --- a/pkg/mod/tasktree.go +++ b/pkg/mod/tasktree.go @@ -130,10 +130,11 @@ type TaskNode struct { type TreeStatus string const ( - TreeStatusRunning TreeStatus = "running" - TreeStatusSuccess TreeStatus = "success" - TreeStatusFailed TreeStatus = "failed" - TreeStatusBlocked TreeStatus = "blocked" + TreeStatusRunning TreeStatus = "running" + 
TreeStatusSuccess TreeStatus = "success" + TreeStatusFailed TreeStatus = "failed" + TreeStatusCanceled TreeStatus = "canceled" + TreeStatusBlocked TreeStatus = "blocked" ) // HasCycle @@ -185,10 +186,14 @@ func bfsCheckCycle(waitQueue []*TaskNode, visited map[string]struct{}, incomplet func (t *TaskNode) ComputeStatus() (status TreeStatus, srcTaskInsId string) { walkNode(t, func(node *TaskNode) bool { switch node.Status { - case entity.TaskInstanceStatusFailed, entity.TaskInstanceStatusCanceled: + case entity.TaskInstanceStatusFailed: status = TreeStatusFailed srcTaskInsId = node.TaskInsID return true + case entity.TaskInstanceStatusCanceled: + status = TreeStatusCanceled + srcTaskInsId = node.TaskInsID + return true case entity.TaskInstanceStatusBlocked: status = TreeStatusBlocked srcTaskInsId = node.TaskInsID diff --git a/store/mongo/mongo.go b/store/mongo/mongo.go index 3beaa6a..004ee4c 100644 --- a/store/mongo/mongo.go +++ b/store/mongo/mongo.go @@ -27,7 +27,7 @@ type StoreOption struct { Database string // Timeout access mongo timeout.default 5s Timeout time.Duration - // the prefix will append to the database + // the prefix will append to the schema name Prefix string } diff --git a/store/mysql/mysql.go b/store/mysql/mysql.go new file mode 100644 index 0000000..bcc62b0 --- /dev/null +++ b/store/mysql/mysql.go @@ -0,0 +1,714 @@ +package mysql + +import ( + "context" + "encoding/json" + "fmt" + "github.com/shiningrush/goext/datax" + "sync" + "time" + + "github.com/go-sql-driver/mysql" + "github.com/shiningrush/fastflow/pkg/entity" + "github.com/shiningrush/fastflow/pkg/event" + "github.com/shiningrush/fastflow/pkg/log" + "github.com/shiningrush/fastflow/pkg/mod" + "github.com/shiningrush/fastflow/pkg/utils" + "github.com/shiningrush/fastflow/pkg/utils/data" + "github.com/shiningrush/goevent" + gormDriver "gorm.io/driver/mysql" + "gorm.io/gorm" +) + +// StoreOption +type StoreOption struct { + MySQLConfig *mysql.Config + GormConfig *gorm.Config + PoolConfig 
*ConnectionPoolOption + + // business timeout + Timeout time.Duration + MigrationSwitch bool + BatchUpdateConfig *BatchUpdateOption +} + +type ConnectionPoolOption struct { + MaxIdleConns int + MaxOpenConns int + ConnMaxLifetime time.Duration +} + +type BatchUpdateOption struct { + ConcurrencyCount int + Timeout time.Duration +} + +// Store +type Store struct { + opt *StoreOption + + gormDB *gorm.DB +} + +func (s *Store) Close() { + sqlDB, err := s.gormDB.DB() + if err != nil { + log.Errorf("get store client failed: %s", err) + } + + if err = sqlDB.Close(); err != nil { + log.Errorf("close store client failed: %s", err) + } +} + +func (s *Store) CreateDag(dag *entity.Dag) error { + // check task's connection + _, err := mod.BuildRootNode(mod.MapTasksToGetter(dag.Tasks)) + if err != nil { + return err + } + + err = s.transaction(func(tx *gorm.DB) error { + dag.BaseInfo.Initial() + return tx.Create(dag).Error + }) + if err != nil { + return fmt.Errorf("insert dag failed: %w", err) + } + return nil +} + +func (s *Store) transaction(cb func(tx *gorm.DB) error) error { + ctx, cancel := context.WithTimeout(context.TODO(), s.opt.Timeout) + defer cancel() + + db := s.gormDB.WithContext(ctx) + return db.Transaction(func(tx *gorm.DB) error { + return cb(tx) + }) +} + +func (s *Store) CreateDagIns(dagIns *entity.DagInstance) error { + err := s.transaction(func(tx *gorm.DB) error { + dagIns.BaseInfo.Initial() + if dagIns.ShareData == nil { + dagIns.ShareData = &entity.ShareData{} + } + if dagIns.ShareData.Dict == nil { + dagIns.ShareData.Dict = map[string]string{} + } + err := tx.Create(dagIns).Error + if err != nil { + return err + } + if len(dagIns.Tags) == 0 { + return nil + } + for _, tag := range dagIns.Tags { + tag.BaseInfo.Initial() + tag.DagInsId = dagIns.ID + err := tx.Create(tag).Error + if err != nil { + return err + } + } + return nil + }) + if err != nil { + return fmt.Errorf("insert dagIns failed: %w", err) + } + return nil +} + +func (s *Store) 
BatchCreatTaskIns(taskIns []*entity.TaskInstance) error { + err := s.transaction(func(tx *gorm.DB) error { + for _, task := range taskIns { + task.BaseInfo.Initial() + } + return tx.CreateInBatches(taskIns, 100).Error + }) + if err != nil { + return fmt.Errorf("batch insert taskIns failed: %w", err) + } + return nil +} + +func (s *Store) PatchTaskIns(taskIns *entity.TaskInstance) error { + if taskIns.ID == "" { + return fmt.Errorf("id cannot be empty") + } + + updateIns := &entity.TaskInstance{} + updateIns.BaseInfo = taskIns.BaseInfo + updateIns.BaseInfo.Update() + if taskIns.Status != "" { + updateIns.Status = taskIns.Status + } + if taskIns.Reason != "" { + updateIns.Reason = taskIns.Reason + } + if len(taskIns.Traces) > 0 { + updateIns.Traces = taskIns.Traces + } + + err := s.transaction(func(tx *gorm.DB) error { + return tx.Model(taskIns).Updates(updateIns).Error + }) + if err != nil { + return fmt.Errorf("patch taskIns failed: %w", err) + } + return nil +} + +func (s *Store) PatchDagIns(dagIns *entity.DagInstance, mustsPatchFields ...string) error { + if dagIns.ID == "" { + return fmt.Errorf("id cannot be empty") + } + + updateIns := &entity.DagInstance{} + updateIns.BaseInfo.Update() + updateFields := []string{"UpdatedAt"} + if dagIns.ShareData != nil { + updateFields = append(updateFields, "ShareData") + updateIns.ShareData = dagIns.ShareData + } + if dagIns.Status != "" { + updateFields = append(updateFields, "Status") + updateIns.Status = dagIns.Status + } + if utils.StringsContain(mustsPatchFields, "Cmd") || dagIns.Cmd != nil { + updateFields = append(updateFields, "Cmd") + updateIns.Cmd = dagIns.Cmd + } + if dagIns.Worker != "" { + updateFields = append(updateFields, "Worker") + updateIns.Worker = dagIns.Worker + } + if utils.StringsContain(mustsPatchFields, "Reason") || dagIns.Reason != "" { + updateFields = append(updateFields, "Reason") + updateIns.Reason = dagIns.Reason + } + + err := s.transaction(func(tx *gorm.DB) error { + return 
tx.Model(dagIns).Select(updateFields).Updates(updateIns).Error
+	})
+	if err != nil {
+		return fmt.Errorf("patch dagIns failed: %w", err)
+	}
+	goevent.Publish(&event.DagInstancePatched{
+		Payload:         dagIns,
+		MustPatchFields: mustsPatchFields,
+	})
+	return nil
+}
+
+func (s *Store) UpdateDag(dag *entity.Dag) error {
+	// check task's connection
+	_, err := mod.BuildRootNode(mod.MapTasksToGetter(dag.Tasks))
+	if err != nil {
+		return err
+	}
+
+	err = s.transaction(func(tx *gorm.DB) error {
+		dag.BaseInfo.Update()
+		return tx.Model(dag).Select("*").Updates(dag).Error
+	})
+	if err != nil {
+		return fmt.Errorf("update dag failed: %w", err)
+	}
+	return nil
+}
+
+func (s *Store) UpdateDagIns(dagIns *entity.DagInstance) error {
+	err := s.transaction(func(tx *gorm.DB) error {
+		dagIns.BaseInfo.Update()
+		return tx.Updates(dagIns).Error
+	})
+	if err != nil {
+		return fmt.Errorf("update dagIns failed: %w", err)
+	}
+	goevent.Publish(&event.DagInstanceUpdated{Payload: dagIns})
+	return nil
+}
+
+func (s *Store) UpdateTaskIns(taskIns *entity.TaskInstance) error {
+	err := s.transaction(func(tx *gorm.DB) error {
+		taskIns.BaseInfo.Update()
+		return tx.Updates(taskIns).Error
+	})
+	if err != nil {
+		return fmt.Errorf("update taskIns failed: %w", err)
+	}
+	return nil
+}
+
+func (s *Store) BatchUpdateDagIns(dagIns []*entity.DagInstance) (err error) {
+	var anySlice []any
+	for _, dag := range dagIns {
+		anySlice = append(anySlice, dag)
+	}
+
+	return s.batchUpdate(anySlice, func(tx *gorm.DB, en any) error {
+		dagInstance, ok := en.(*entity.DagInstance)
+		if !ok {
+			return fmt.Errorf("invalid entity type: %T", en)
+		}
+		dagInstance.BaseInfo.Update()
+		return tx.Updates(dagInstance).Error
+	})
+}
+
+func (s *Store) batchUpdate(entitys []any, cb func(tx *gorm.DB, en any) error) (err error) {
+	ctx, cancel := context.WithTimeout(context.TODO(), s.opt.BatchUpdateConfig.Timeout)
+	defer cancel()
+
+	errs := &data.Errors{}
+	errChan := make(chan error)
+	// drained is closed by the collector goroutine once errChan is fully
+	// consumed; the deferred func waits on it so reading errs cannot race
+	// with a still-running Append.
+	drained := make(chan struct{})
+	defer func() {
+		close(errChan)
+		<-drained
+		if errs.Len() > 0 {
+			err = errs
+		}
+	}()
+
+	go func() {
+		for err := range errChan {
+			errs.Append(err)
+		}
+		close(drained)
+	}()
+
+	entityChunks := Chunk(entitys, s.opt.BatchUpdateConfig.ConcurrencyCount)
+	var wg sync.WaitGroup
+	for _, entityChunk := range entityChunks {
+		wg.Add(len(entityChunk))
+		for i := range entityChunk {
+			go func(ctx context.Context, en any, ch chan error) {
+				db := s.gormDB.WithContext(ctx)
+				err := db.Transaction(func(tx *gorm.DB) error {
+					return cb(tx, en)
+				})
+				if err != nil {
+					errChan <- fmt.Errorf("batch update entity failed: %w", err)
+				}
+				wg.Done()
+			}(ctx, entityChunk[i], errChan)
+		}
+		wg.Wait()
+	}
+	return
+}
+
+func (s *Store) BatchUpdateTaskIns(taskIns []*entity.TaskInstance) error {
+	var anySlice []any
+	for _, task := range taskIns {
+		anySlice = append(anySlice, task)
+	}
+
+	return s.batchUpdate(anySlice, func(tx *gorm.DB, en any) error {
+		taskInstance, ok := en.(*entity.TaskInstance)
+		if !ok {
+			return fmt.Errorf("invalid entity type: %T", en)
+		}
+		taskInstance.BaseInfo.Update()
+		return tx.Updates(taskInstance).Error
+	})
+}
+
+func (s *Store) GetTaskIns(taskInsId string) (*entity.TaskInstance, error) {
+	taskIns := &entity.TaskInstance{}
+	err := s.transaction(func(tx *gorm.DB) error {
+		return tx.Where("id = ?", taskInsId).First(taskIns).Error
+	})
+	if err != nil {
+		if err == gorm.ErrRecordNotFound {
+			return nil, fmt.Errorf("key[ %s ] not found: %w", taskInsId, data.ErrDataNotFound)
+		}
+		log.Errorf("get task instance %s failed: %s", taskInsId, err)
+		return nil, fmt.Errorf("get task instance failed: %w", err)
+	}
+	return taskIns, nil
+}
+
+func (s *Store) GetDag(dagId string) (*entity.Dag, error) {
+	dag := &entity.Dag{}
+	err := s.transaction(func(tx *gorm.DB) error {
+		return tx.Where("id = ?", dagId).First(dag).Error
+	})
+	if err != nil {
+		if err == gorm.ErrRecordNotFound {
+			return nil, fmt.Errorf("key[ %s ] not found: %w", dagId, data.ErrDataNotFound)
+		}
+		log.Errorf("get dag %s failed: 
%s", dagId, err) + return nil, fmt.Errorf("get dag failed: %w", err) + } + return dag, nil +} + +func (s *Store) GetDagInstance(dagInsId string) (*entity.DagInstance, error) { + dagIns := &entity.DagInstance{} + err := s.transaction(func(tx *gorm.DB) error { + return tx.Where("id = ?", dagInsId).First(dagIns).Error + }) + if err != nil { + if err == gorm.ErrRecordNotFound { + return nil, fmt.Errorf("key[ %s ] not found: %w", dagInsId, data.ErrDataNotFound) + } + log.Errorf("get dag instance %s failed: %s", dagInsId, err) + return nil, fmt.Errorf("get dag instance failed: %w", err) + } + return dagIns, nil +} + +func (s *Store) ListDagInstance(input *mod.ListDagInstanceInput) ([]*entity.DagInstance, error) { + if len(input.Tags) > 0 { + return s.ListDagInstanceWithFilterTags(input) + } + return s.ListDagInstanceWithoutFilterTags(input) +} + +func (s *Store) ListDagInstanceWithFilterTags(input *mod.ListDagInstanceInput) ([]*entity.DagInstance, error) { + var ret []*entity.DagInstanceTag + err := s.transaction(func(tx *gorm.DB) error { + var queryParams [][]interface{} + for k, v := range input.Tags { + queryParams = append(queryParams, []interface{}{k, v}) + } + return tx.Where("(`key`, `value`) IN ?", queryParams). + Select("dag_ins_id, count(*) as total"). + Group("dag_ins_id"). + Having("total = ?", len(input.Tags)). 
+ Find(&ret).Error + }) + if err != nil { + log.Errorf("list dag instance tags input: %v, failed: %s", input, err) + return nil, err + } + var dagInsIds []string + for _, v := range ret { + dagInsIds = append(dagInsIds, v.DagInsId) + } + if len(input.IDs) == 0 { + input.IDs = dagInsIds + } else { + input.IDs = datax.Intersect(input.IDs, dagInsIds) + } + if len(input.IDs) == 0 { + return nil, nil + } + + return s.ListDagInstanceWithoutFilterTags(input) +} + +func (s *Store) ListDagInstanceWithoutFilterTags(input *mod.ListDagInstanceInput) ([]*entity.DagInstance, error) { + var ret []*entity.DagInstance + err := s.transaction(func(tx *gorm.DB) error { + if len(input.Status) > 0 { + tx = tx.Where("status in (?)", input.Status) + } + if len(input.IDs) > 0 { + tx = tx.Where("id in (?)", input.IDs) + } + if input.Worker != "" { + tx = tx.Where("worker = ?", input.Worker) + } + if input.UpdatedEnd > 0 { + tx = tx.Where("updated_at <= ?", input.UpdatedEnd) + } + if input.HasCmd { + tx = tx.Where("cmd is not null") + } + if len(input.DagID) > 0 { + tx = tx.Where("dag_id = ?", input.DagID) + } + + if input.Limit > 0 { + tx = tx.Limit(int(input.Limit)) + } + return tx.Find(&ret).Error + }) + if err != nil { + log.Errorf("list dag instance input: %v, failed: %s", input, err) + return nil, err + } + if len(ret) == 0 { + return ret, nil + } + + err = s.fillInstanceTags(ret) + if err != nil { + return nil, err + } + return ret, nil +} + +func (s *Store) fillInstanceTags(ret []*entity.DagInstance) error { + tagMap, err := s.queryInstanceTags(ret) + if err != nil { + return err + } + if len(tagMap) == 0 { + return nil + } + for _, dagInstance := range ret { + if tags, ok := tagMap[dagInstance.ID]; ok { + dagInstance.Tags = tags + } + } + return nil +} + +func (s *Store) queryInstanceTags(ret []*entity.DagInstance) (map[string][]entity.DagInstanceTag, error) { + var dagInsIDs []string + for _, instance := range ret { + dagInsIDs = append(dagInsIDs, instance.ID) + } + + var tags 
[]*entity.DagInstanceTag + err := s.transaction(func(tx *gorm.DB) error { + tx = tx.Where("dag_ins_id in (?)", dagInsIDs) + return tx.Find(&tags).Error + }) + if err != nil { + log.Errorf("list dag instance tags error %s", err) + return nil, err + } + + if len(tags) == 0 { + return nil, nil + } + + tagMap := make(map[string][]entity.DagInstanceTag) + for i := range tags { + tagMap[tags[i].DagInsId] = append(tagMap[tags[i].DagInsId], *tags[i]) + } + return tagMap, nil +} + +func (s *Store) ListTaskInstance(input *mod.ListTaskInstanceInput) ([]*entity.TaskInstance, error) { + var ret []*entity.TaskInstance + err := s.transaction(func(tx *gorm.DB) error { + if len(input.IDs) > 0 { + tx = tx.Where("id in (?)", input.IDs) + } + if len(input.Status) > 0 { + tx = tx.Where("status in (?)", input.Status) + } + if input.Expired { + tx = tx.Where("(?) >= updated_at + timeout_secs", time.Now().Unix()-5) + } + if input.DagInsID != "" { + tx = tx.Where("dag_ins_id = ?", input.DagInsID) + } + if len(input.SelectField) > 0 { + tx = tx.Select(input.SelectField) + } + return tx.Find(&ret).Error + }) + if err != nil { + log.Errorf("list task instance input: %v, failed: %s", input, err) + return nil, err + } + return ret, nil +} + +// ListDag +func (s *Store) ListDag(input *mod.ListDagInput) ([]*entity.Dag, error) { + var ret []*entity.Dag + err := s.transaction(func(tx *gorm.DB) error { + return tx.Find(&ret).Error + }) + if err != nil { + log.Errorf("list dag input: %v, failed: %s", input, err) + return nil, err + } + return ret, nil +} + +// BatchDeleteDag +func (s *Store) BatchDeleteDag(ids []string) error { + err := s.transaction(func(tx *gorm.DB) error { + return tx.Delete(&entity.Dag{}, "id in (?)", ids).Error + }) + if err != nil { + log.Errorf("delete dag input: %v, failed: %s", ids, err) + return err + } + return nil +} + +// BatchDeleteDagIns +func (s *Store) BatchDeleteDagIns(ids []string) error { + err := s.transaction(func(tx *gorm.DB) error { + return 
tx.Delete(&entity.DagInstance{}, "id in (?)", ids).Error + }) + if err != nil { + log.Errorf("delete dag instance input: %v, failed: %s", ids, err) + return err + } + return nil +} + +// BatchDeleteTaskIns +func (s *Store) BatchDeleteTaskIns(ids []string) error { + err := s.transaction(func(tx *gorm.DB) error { + return tx.Delete(&entity.TaskInstance{}, "id in (?)", ids).Error + }) + if err != nil { + log.Errorf("delete task instance input: %v, failed: %s", ids, err) + return err + } + return nil +} + +func (s *Store) Marshal(obj interface{}) ([]byte, error) { + return json.Marshal(obj) +} + +func (s *Store) Unmarshal(bytes []byte, ptr interface{}) error { + return json.Unmarshal(bytes, ptr) +} + +// NewStore +func NewStore(option *StoreOption) *Store { + return &Store{ + opt: option, + } +} + +// Init store +func (s *Store) Init() error { + if err := s.readOpt(); err != nil { + return err + } + + db, err := gorm.Open(gormDriver.Open(s.opt.MySQLConfig.FormatDSN()), s.opt.GormConfig) + if err != nil { + return fmt.Errorf("connect to mysql occur error: %w", err) + } + + sqlDB, err := db.DB() + if err != nil { + return fmt.Errorf("get sqlDB error: %w", err) + } + + sqlDB.SetConnMaxLifetime(s.opt.PoolConfig.ConnMaxLifetime) + sqlDB.SetMaxIdleConns(s.opt.PoolConfig.MaxIdleConns) + sqlDB.SetMaxOpenConns(s.opt.PoolConfig.MaxOpenConns) + + if s.opt.MigrationSwitch { + err = db.AutoMigrate(&entity.Dag{}, &entity.DagInstance{}, &entity.DagInstanceTag{}, &entity.TaskInstance{}) + if err != nil { + return err + } + } + s.gormDB = db + return nil +} + +func (s *Store) readOpt() error { + err := s.readMySQLConfigOpt() + if err != nil { + return err + } + + s.readGormConfigOpt() + s.readPoolConfigOpt() + s.readBatchUpdateConfigOpt() + if s.opt.Timeout == 0 { + s.opt.Timeout = time.Second * 5 + } + + return nil +} + +func (s *Store) readBatchUpdateConfigOpt() { + if s.opt.BatchUpdateConfig == nil { + s.opt.BatchUpdateConfig = &BatchUpdateOption{} + } + if 
s.opt.BatchUpdateConfig.ConcurrencyCount == 0 { + s.opt.BatchUpdateConfig.ConcurrencyCount = 5 + } + if s.opt.BatchUpdateConfig.Timeout == 0 { + s.opt.BatchUpdateConfig.Timeout = time.Second * 40 + } +} + +func (s *Store) readGormConfigOpt() { + if s.opt.GormConfig == nil { + s.opt.GormConfig = &gorm.Config{} + } +} + +func (s *Store) readPoolConfigOpt() { + if s.opt.PoolConfig == nil { + s.opt.PoolConfig = &ConnectionPoolOption{} + } + if s.opt.PoolConfig.MaxOpenConns == 0 { + s.opt.PoolConfig.MaxOpenConns = 100 + } + if s.opt.PoolConfig.MaxIdleConns == 0 { + s.opt.PoolConfig.MaxIdleConns = 100 + } + if s.opt.PoolConfig.ConnMaxLifetime == 0 { + s.opt.PoolConfig.ConnMaxLifetime = time.Minute * 3 + } +} + +func (s *Store) readMySQLConfigOpt() error { + if s.opt.MySQLConfig == nil { + return fmt.Errorf("mysql config cannot be empty") + } + + if s.opt.MySQLConfig.Addr == "" { + return fmt.Errorf("addr cannot be empty") + } + + if s.opt.MySQLConfig.User == "" { + return fmt.Errorf("user cannot be empty") + } + + if s.opt.MySQLConfig.Passwd == "" { + return fmt.Errorf("passwd cannot be empty") + } + + if s.opt.MySQLConfig.DBName == "" { + return fmt.Errorf("dbName cannot be empty") + } + + if s.opt.MySQLConfig.Collation == "" { + s.opt.MySQLConfig.Collation = "utf8mb4_unicode_ci" + } + + if s.opt.MySQLConfig.Loc == nil { + s.opt.MySQLConfig.Loc = time.UTC + } + + if s.opt.MySQLConfig.MaxAllowedPacket == 0 { + s.opt.MySQLConfig.MaxAllowedPacket = mysql.NewConfig().MaxAllowedPacket + } + + s.opt.MySQLConfig.Net = "tcp" + s.opt.MySQLConfig.AllowNativePasswords = true + s.opt.MySQLConfig.CheckConnLiveness = true + s.opt.MySQLConfig.ParseTime = true + + if s.opt.MySQLConfig.Timeout == 0 { + s.opt.MySQLConfig.Timeout = 5 * time.Second + } + + if s.opt.MySQLConfig.ReadTimeout == 0 { + s.opt.MySQLConfig.ReadTimeout = 30 * time.Second + } + + if s.opt.MySQLConfig.WriteTimeout == 0 { + s.opt.MySQLConfig.WriteTimeout = 30 * time.Second + } + + if s.opt.MySQLConfig.Params == nil { 
+ s.opt.MySQLConfig.Params = map[string]string{} + } + if _, ok := s.opt.MySQLConfig.Params["charset"]; !ok { + s.opt.MySQLConfig.Params["charset"] = "utf8mb4" + } + return nil +} diff --git a/store/mysql/mysql_integ_test.go b/store/mysql/mysql_integ_test.go new file mode 100644 index 0000000..b89cf8c --- /dev/null +++ b/store/mysql/mysql_integ_test.go @@ -0,0 +1,238 @@ +//go:build integration +// +build integration + +package mysql_test + +import ( + "fmt" + "testing" + "time" + + "github.com/go-sql-driver/mysql" + "github.com/shiningrush/fastflow/pkg/entity" + "github.com/shiningrush/fastflow/pkg/mod" + mysqlStore "github.com/shiningrush/fastflow/store/mysql" + "github.com/stretchr/testify/assert" +) + +const ( + addr = "127.0.0.1:55000" + user = "root" + passwd = "mysqlpw" + dbName = "fastflow" +) + +func TestStore_Dag(t *testing.T) { + // init store + s := mysqlStore.NewStore(&mysqlStore.StoreOption{ + MySQLConfig: &mysql.Config{ + Addr: addr, + User: user, + Passwd: passwd, + DBName: dbName, + }, + MigrationSwitch: true, + }) + err := s.Init() + assert.NoError(t, err) + giveDag := []*entity.Dag{ + { + BaseInfo: entity.BaseInfo{ + ID: "test1", + }, + Tasks: []entity.Task{{ID: "test"}}, + Status: "normal", + }, + { + BaseInfo: entity.BaseInfo{ + ID: "test2", + }, + Tasks: []entity.Task{{ID: "test"}}, + Status: "normal", + }, + } + // create + for i := range giveDag { + err := s.CreateDag(giveDag[i]) + assert.NoError(t, err) + } + + ret, err := s.ListDag(nil) + assert.NoError(t, err) + time.Sleep(time.Second) + // check and update + for i := range ret { + assert.NotEqual(t, "", ret[i].ID) + assert.Greater(t, ret[i].CreatedAt, int64(0)) + assert.Greater(t, ret[i].UpdatedAt, int64(0)) + ret[i].Name = fmt.Sprintf("name-%d", i) + ret[i].Desc = fmt.Sprintf("desc-%d", i) + ret[i].Cron = fmt.Sprintf("cron-%d", i) + ret[i].Vars = entity.DagVars{ + "var1": {}, "var2": {}, + } + ret[i].Tasks = []entity.Task{ + {ID: "task1", Name: "task1"}, {ID: "task2", Name: "task2"}, + } 
+ + err = s.UpdateDag(ret[i]) + assert.NoError(t, err) + } + + ret, err = s.ListDag(nil) + assert.NoError(t, err) + for i := range ret { + assert.NotEmpty(t, ret[i].ID) + assert.NotEmpty(t, ret[i].Name) + assert.NotEmpty(t, ret[i].Desc) + assert.NotEqual(t, ret[i].CreatedAt, ret[i].UpdatedAt) + } + + // delete + err = s.BatchDeleteDag([]string{"test1", "test2"}) + assert.NoError(t, err) + ret, err = s.ListDag(nil) + assert.NoError(t, err) + assert.Equal(t, 0, len(ret)) +} + +func TestStore_DagIns(t *testing.T) { + // init store + s := mysqlStore.NewStore(&mysqlStore.StoreOption{ + MySQLConfig: &mysql.Config{ + Addr: addr, + User: user, + Passwd: passwd, + DBName: dbName, + }, + MigrationSwitch: true, + }) + + err := s.Init() + assert.NoError(t, err) + giveDagIns := []*entity.DagInstance{ + { + BaseInfo: entity.BaseInfo{ + ID: "test1", + }, + Status: "init", + Trigger: "manually", + }, + { + BaseInfo: entity.BaseInfo{ + ID: "test2", + }, + Status: "init", + Trigger: "manually", + }, + } + // create + for i := range giveDagIns { + err := s.CreateDagIns(giveDagIns[i]) + assert.NoError(t, err) + } + + ret, err := s.ListDagInstance(&mod.ListDagInstanceInput{}) + assert.NoError(t, err) + time.Sleep(time.Second) + // check and update + for i := range ret { + assert.NotEqual(t, "", ret[i].ID) + assert.Greater(t, ret[i].CreatedAt, int64(0)) + assert.Greater(t, ret[i].UpdatedAt, int64(0)) + ret[i].Worker = fmt.Sprintf("worker-%d", i) + ret[i].DagID = fmt.Sprintf("dagid-%d", i) + ret[i].ShareData = &entity.ShareData{Dict: map[string]string{ + "test": "gg", + }} + ret[i].Vars = entity.DagInstanceVars{ + "var1": {}, "var2": {}, + } + + err = s.UpdateDagIns(ret[i]) + assert.NoError(t, err) + } + + ret, err = s.ListDagInstance(&mod.ListDagInstanceInput{}) + assert.NoError(t, err) + for i := range ret { + assert.NotEmpty(t, ret[i].ID) + assert.NotEmpty(t, ret[i].Worker) + assert.NotEmpty(t, ret[i].DagID) + assert.NotNil(t, ret[i].ShareData) + assert.NotEqual(t, ret[i].CreatedAt, 
ret[i].UpdatedAt) + } + + // delete + err = s.BatchDeleteDagIns([]string{"test1", "test2"}) + assert.NoError(t, err) + ret, err = s.ListDagInstance(&mod.ListDagInstanceInput{}) + assert.NoError(t, err) + assert.Equal(t, 0, len(ret)) +} + +func TestStore_TaskIns(t *testing.T) { + // init store + s := mysqlStore.NewStore(&mysqlStore.StoreOption{ + MySQLConfig: &mysql.Config{ + Addr: addr, + User: user, + Passwd: passwd, + DBName: dbName, + }, + MigrationSwitch: true, + }) + + err := s.Init() + assert.NoError(t, err) + giveTaskIns := []*entity.TaskInstance{ + { + BaseInfo: entity.BaseInfo{ + ID: "test1", + }, + Status: "init", + }, + { + BaseInfo: entity.BaseInfo{ + ID: "test2", + }, + Status: "init", + }, + } + // create + err = s.BatchCreatTaskIns(giveTaskIns) + assert.NoError(t, err) + + ret, err := s.ListTaskInstance(&mod.ListTaskInstanceInput{}) + assert.NoError(t, err) + time.Sleep(time.Second) + // check and update + for i := range ret { + assert.NotEqual(t, "", ret[i].ID) + assert.Greater(t, ret[i].CreatedAt, int64(0)) + assert.Greater(t, ret[i].UpdatedAt, int64(0)) + ret[i].Name = fmt.Sprintf("name-%d", i) + ret[i].DagInsID = fmt.Sprintf("dag-%d", i) + ret[i].ActionName = fmt.Sprintf("act-%d", i) + ret[i].DependOn = []string{"test1"} + + err = s.UpdateTaskIns(ret[i]) + assert.NoError(t, err) + } + + ret, err = s.ListTaskInstance(&mod.ListTaskInstanceInput{}) + assert.NoError(t, err) + for i := range ret { + assert.NotEmpty(t, ret[i].ID) + assert.NotEmpty(t, ret[i].DagInsID) + assert.NotEmpty(t, ret[i].ActionName) + assert.NotNil(t, ret[i].DependOn) + } + + // delete + err = s.BatchDeleteTaskIns([]string{"test1", "test2"}) + assert.NoError(t, err) + ret, err = s.ListTaskInstance(&mod.ListTaskInstanceInput{}) + assert.NoError(t, err) + assert.Equal(t, 0, len(ret)) +} diff --git a/store/mysql/utils.go b/store/mysql/utils.go new file mode 100644 index 0000000..db2ea1a --- /dev/null +++ b/store/mysql/utils.go @@ -0,0 +1,28 @@ +package mysql + +func Chunk[T any](ss 
[]T, chunkLength int) [][]T { + if chunkLength <= 0 { + panic("chunkLength should be greater than 0") + } + + result := make([][]T, 0) + l := len(ss) + if l == 0 { + return result + } + + var step = l / chunkLength + if step == 0 { + result = append(result, ss) + return result + } + var remain = l % chunkLength + for i := 0; i < step; i++ { + result = append(result, ss[i*chunkLength:(i+1)*chunkLength]) + } + if remain != 0 { + result = append(result, ss[step*chunkLength:l]) + } + + return result +}