diff --git a/internal/app/subsystems/aio/store/postgres/postgres.go b/internal/app/subsystems/aio/store/postgres/postgres.go index 773e7476..f2130ee3 100644 --- a/internal/app/subsystems/aio/store/postgres/postgres.go +++ b/internal/app/subsystems/aio/store/postgres/postgres.go @@ -6,6 +6,7 @@ import ( "encoding/json" "fmt" "net/url" + "strings" "time" "github.com/resonatehq/resonate/internal/aio" @@ -92,11 +93,12 @@ const ( promises WHERE ($1::int IS NULL OR sort_id < $1) AND - (state & $2 != 0) + state & $2 != 0 AND + id LIKE $3 ORDER BY sort_id DESC LIMIT - $3` + $4` PROMISE_INSERT_STATEMENT = ` INSERT INTO promises @@ -533,6 +535,12 @@ func (w *PostgresStoreWorker) readPromise(tx *sql.Tx, cmd *types.ReadPromiseComm } func (w *PostgresStoreWorker) searchPromises(tx *sql.Tx, cmd *types.SearchPromisesCommand) (*types.Result, error) { + util.Assert(cmd.Q != "", "query cannot be empty") + util.Assert(cmd.States != nil, "states cannot be empty") + + // convert query + query := strings.ReplaceAll(cmd.Q, "*", "%") + // convert list of state to bit mask mask := 0 for _, state := range cmd.States { @@ -540,7 +548,7 @@ func (w *PostgresStoreWorker) searchPromises(tx *sql.Tx, cmd *types.SearchPromis } // select - rows, err := tx.Query(PROMISE_SEARCH_STATEMENT, cmd.SortId, mask, cmd.Limit) + rows, err := tx.Query(PROMISE_SEARCH_STATEMENT, cmd.SortId, mask, query, cmd.Limit) if err != nil { return nil, err } diff --git a/internal/app/subsystems/aio/store/sqlite/sqlite.go b/internal/app/subsystems/aio/store/sqlite/sqlite.go index 9e06b852..2b16065d 100644 --- a/internal/app/subsystems/aio/store/sqlite/sqlite.go +++ b/internal/app/subsystems/aio/store/sqlite/sqlite.go @@ -5,6 +5,7 @@ import ( "database/sql" "encoding/json" "os" + "strings" "time" "github.com/resonatehq/resonate/internal/aio" @@ -84,7 +85,8 @@ const ( promises WHERE (? IS NULL OR sort_id < ?) AND - (state & ? != 0) + state & ? != 0 AND + id LIKE ? ORDER BY sort_id DESC LIMIT @@ -508,6 +510,12 @@ func (w *SqliteStoreWorker) readPromise(tx *sql.Tx, cmd *types.ReadPromiseComman } func (w *SqliteStoreWorker) searchPromises(tx *sql.Tx, cmd *types.SearchPromisesCommand) (*types.Result, error) { + util.Assert(cmd.Q != "", "query cannot be empty") + util.Assert(cmd.States != nil, "states cannot be empty") + + // convert query + query := strings.ReplaceAll(cmd.Q, "*", "%") + // convert list of state to bit mask mask := 0 for _, state := range cmd.States { @@ -515,7 +523,7 @@ func (w *SqliteStoreWorker) searchPromises(tx *sql.Tx, cmd *types.SearchPromises } // select - rows, err := tx.Query(PROMISE_SEARCH_STATEMENT, cmd.SortId, cmd.SortId, mask, cmd.Limit) + rows, err := tx.Query(PROMISE_SEARCH_STATEMENT, cmd.SortId, cmd.SortId, mask, query, cmd.Limit) if err != nil { return nil, err } diff --git a/internal/app/subsystems/aio/store/test/util.go b/internal/app/subsystems/aio/store/test/util.go index b0f82293..05a1d272 100644 --- a/internal/app/subsystems/aio/store/test/util.go +++ b/internal/app/subsystems/aio/store/test/util.go @@ -1072,7 +1072,247 @@ var TestCases = []*testCase{ }, }, { - name: "SearchPromises", + name: "SearchPromisesById", + commands: []*types.Command{ + { + Kind: types.StoreCreatePromise, + CreatePromise: &types.CreatePromiseCommand{ + Id: "foo.a", + Timeout: 2, + Param: promise.Value{ + Headers: map[string]string{}, + Data: []byte{}, + }, + Tags: map[string]string{}, + CreatedOn: 1, + }, + }, + { + Kind: types.StoreCreatePromise, + CreatePromise: &types.CreatePromiseCommand{ + Id: "foo.b", + Timeout: 2, + Param: promise.Value{ + Headers: map[string]string{}, + Data: []byte{}, + }, + Tags: map[string]string{}, + CreatedOn: 1, + }, + }, + { + Kind: types.StoreCreatePromise, + CreatePromise: &types.CreatePromiseCommand{ + Id: "a.bar", + Timeout: 2, + Param: promise.Value{ + Headers: map[string]string{}, + Data: []byte{}, + }, + Tags: map[string]string{}, + CreatedOn: 1, + }, + }, + { + Kind: types.StoreCreatePromise, + CreatePromise: &types.CreatePromiseCommand{ + Id: "b.bar", + Timeout: 2, + Param: promise.Value{ + Headers: map[string]string{}, + Data: []byte{}, + }, + Tags: map[string]string{}, + CreatedOn: 1, + }, + }, + { + Kind: types.StoreSearchPromises, + SearchPromises: &types.SearchPromisesCommand{ + Q: "foo.*", + States: []promise.State{ + promise.Pending, + }, + Limit: 2, + }, + }, + { + Kind: types.StoreSearchPromises, + SearchPromises: &types.SearchPromisesCommand{ + Q: "*.bar", + States: []promise.State{ + promise.Pending, + }, + Limit: 2, + }, + }, + { + Kind: types.StoreSearchPromises, + SearchPromises: &types.SearchPromisesCommand{ + Q: "*", + States: []promise.State{ + promise.Pending, + }, + Limit: 2, + }, + }, + { + Kind: types.StoreSearchPromises, + SearchPromises: &types.SearchPromisesCommand{ + Q: "*", + States: []promise.State{ + promise.Pending, + }, + Limit: 2, + SortId: int64ToPointer(3), + }, + }, + }, + expected: []*types.Result{ + { + Kind: types.StoreCreatePromise, + CreatePromise: &types.AlterPromisesResult{ + RowsAffected: 1, + }, + }, + { + Kind: types.StoreCreatePromise, + CreatePromise: &types.AlterPromisesResult{ + RowsAffected: 1, + }, + }, + { + Kind: types.StoreCreatePromise, + CreatePromise: &types.AlterPromisesResult{ + RowsAffected: 1, + }, + }, + { + Kind: types.StoreCreatePromise, + CreatePromise: &types.AlterPromisesResult{ + RowsAffected: 1, + }, + }, + { + Kind: types.StoreSearchPromises, + SearchPromises: &types.QueryPromisesResult{ + RowsReturned: 2, + LastSortId: 1, + Records: []*promise.PromiseRecord{ + { + Id: "foo.b", + State: 1, + ParamHeaders: []byte("{}"), + ParamData: []byte{}, + Timeout: 2, + CreatedOn: int64ToPointer(1), + Tags: []byte("{}"), + SortId: 2, + }, + { + Id: "foo.a", + State: 1, + ParamHeaders: []byte("{}"), + ParamData: []byte{}, + Timeout: 2, + CreatedOn: int64ToPointer(1), + Tags: []byte("{}"), + SortId: 1, + }, + }, + }, + }, + { + Kind: types.StoreSearchPromises, + SearchPromises: &types.QueryPromisesResult{ + RowsReturned: 2, + LastSortId: 3, + Records: []*promise.PromiseRecord{ + { + Id: "b.bar", + State: 1, + ParamHeaders: []byte("{}"), + ParamData: []byte{}, + Timeout: 2, + CreatedOn: int64ToPointer(1), + Tags: []byte("{}"), + SortId: 4, + }, + { + Id: "a.bar", + State: 1, + ParamHeaders: []byte("{}"), + ParamData: []byte{}, + Timeout: 2, + CreatedOn: int64ToPointer(1), + Tags: []byte("{}"), + SortId: 3, + }, + }, + }, + }, + { + Kind: types.StoreSearchPromises, + SearchPromises: &types.QueryPromisesResult{ + RowsReturned: 2, + LastSortId: 3, + Records: []*promise.PromiseRecord{ + { + Id: "b.bar", + State: 1, + ParamHeaders: []byte("{}"), + ParamData: []byte{}, + Timeout: 2, + CreatedOn: int64ToPointer(1), + Tags: []byte("{}"), + SortId: 4, + }, + { + Id: "a.bar", + State: 1, + ParamHeaders: []byte("{}"), + ParamData: []byte{}, + Timeout: 2, + CreatedOn: int64ToPointer(1), + Tags: []byte("{}"), + SortId: 3, + }, + }, + }, + }, + { + Kind: types.StoreSearchPromises, + SearchPromises: &types.QueryPromisesResult{ + RowsReturned: 2, + LastSortId: 1, + Records: []*promise.PromiseRecord{ + { + Id: "foo.b", + State: 1, + ParamHeaders: []byte("{}"), + ParamData: []byte{}, + Timeout: 2, + CreatedOn: int64ToPointer(1), + Tags: []byte("{}"), + SortId: 2, + }, + { + Id: "foo.a", + State: 1, + ParamHeaders: []byte("{}"), + ParamData: []byte{}, + Timeout: 2, + CreatedOn: int64ToPointer(1), + Tags: []byte("{}"), + SortId: 1, + }, + }, + }, + }, + }, + }, + { + name: "SearchPromisesByState", commands: []*types.Command{ { Kind: types.StoreCreatePromise, diff --git a/test/dst/generator.go b/test/dst/generator.go index 4b61e6f7..d21be826 100644 --- a/test/dst/generator.go +++ b/test/dst/generator.go @@ -125,6 +125,14 @@ func (g *Generator) GenerateSearchPromises(r *rand.Rand, t int64) *types.Request limit := r.Intn(10) states := []promise.State{} + var query string + switch r.Intn(2) { + case 0: + query = fmt.Sprintf("*%d", r.Intn(10)) + default: + query = fmt.Sprintf("%d*", r.Intn(10)) + } + for i := 0; i < r.Intn(5); i++ { switch r.Intn(5) { case 0: @@ -143,7 +151,7 @@ func (g *Generator) GenerateSearchPromises(r *rand.Rand, t int64) *types.Request return &types.Request{ Kind: types.SearchPromises, SearchPromises: &types.SearchPromisesRequest{ - Q: "*", + Q: query, States: states, Limit: limit, }, diff --git a/test/dst/model.go b/test/dst/model.go index b73c6ff8..e9214f52 100644 --- a/test/dst/model.go +++ b/test/dst/model.go @@ -2,6 +2,8 @@ package dst import ( "fmt" + "regexp" + "strings" "github.com/resonatehq/resonate/internal/kernel/types" "github.com/resonatehq/resonate/pkg/promise" @@ -59,11 +61,6 @@ func (m *Model) ValidateReadPromise(req *types.Request, res *types.Response) err } func (m *Model) ValidateSearchPromises(req *types.Request, res *types.Response) error { - states := map[promise.State]bool{} - for _, state := range req.SearchPromises.States { - states[state] = true - } - if res.SearchPromises.Cursor != nil { m.addCursor(&types.Request{ Kind: types.SearchPromises, @@ -71,7 +68,18 @@ func (m *Model) ValidateSearchPromises(req *types.Request, res *types.Response) }) } + regex := regexp.MustCompile(fmt.Sprintf("^%s$", strings.ReplaceAll(req.SearchPromises.Q, "*", ".*"))) + + states := map[promise.State]bool{} + for _, state := range req.SearchPromises.States { + states[state] = true + } + for _, p := range res.SearchPromises.Promises { + if !regex.MatchString(p.Id) { + return fmt.Errorf("promise id '%s' does not match search query '%s'", p.Id, req.SearchPromises.Q) + } + if _, ok := states[p.State]; !ok { return fmt.Errorf("unexpected state %s, searched for %s", p.State, req.SearchPromises.States) }