Skip to content

Commit

Permalink
fix: return error if invalid UUID is supplied to ids filter (#4116)
Browse files Browse the repository at this point in the history
  • Loading branch information
jonas-jonas committed Sep 25, 2024
1 parent 1146599 commit 98140f2
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 17 deletions.
30 changes: 20 additions & 10 deletions identity/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ import (
"strings"
"time"

"github.com/gofrs/uuid"

"github.com/ory/x/crdbx"
"github.com/ory/x/pagination/keysetpagination"

Expand Down Expand Up @@ -193,6 +195,7 @@ type listIdentitiesParameters struct {
// default: errorGeneric
func (h *Handler) list(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
includeCredentials := r.URL.Query()["include_credential"]
var err error
var declassify []CredentialsType
for _, v := range includeCredentials {
tc, ok := ParseCredentialsType(v)
Expand All @@ -204,17 +207,24 @@ func (h *Handler) list(w http.ResponseWriter, r *http.Request, _ httprouter.Para
}
}

var (
err error
params = ListIdentityParameters{
Expand: ExpandDefault,
IdsFilter: r.URL.Query()["ids"],
CredentialsIdentifier: r.URL.Query().Get("credentials_identifier"),
CredentialsIdentifierSimilar: r.URL.Query().Get("preview_credentials_identifier_similar"),
ConsistencyLevel: crdbx.ConsistencyLevelFromRequest(r),
DeclassifyCredentials: declassify,
var idsFilter []uuid.UUID
for _, v := range r.URL.Query()["ids"] {
id, err := uuid.FromString(v)
if err != nil {
h.r.Writer().WriteError(w, r, errors.WithStack(herodot.ErrBadRequest.WithReasonf("Invalid UUID value `%s` for parameter `ids`.", v)))
return
}
)
idsFilter = append(idsFilter, id)
}

params := ListIdentityParameters{
Expand: ExpandDefault,
IdsFilter: idsFilter,
CredentialsIdentifier: r.URL.Query().Get("credentials_identifier"),
CredentialsIdentifierSimilar: r.URL.Query().Get("preview_credentials_identifier_similar"),
ConsistencyLevel: crdbx.ConsistencyLevelFromRequest(r),
DeclassifyCredentials: declassify,
}
if params.CredentialsIdentifier != "" && params.CredentialsIdentifierSimilar != "" {
h.r.Writer().WriteError(w, r, herodot.ErrBadRequest.WithReason("Cannot pass both credentials_identifier and preview_credentials_identifier_similar."))
return
Expand Down
9 changes: 7 additions & 2 deletions identity/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -372,18 +372,23 @@ func TestHandler(t *testing.T) {
require.Equal(t, len(ids), identitiesAmount)
})

t.Run("case= list few identities", func(t *testing.T) {
t.Run("case=list few identities", func(t *testing.T) {
url := "/identities?ids=" + ids[0].String()
for i := 1; i < listAmount; i++ {
url += "&ids=" + ids[i].String()
}
res := get(t, adminTS, url, 200)
res := get(t, adminTS, url, http.StatusOK)

identities := res.Array()
require.Equal(t, len(identities), listAmount)
})
})

t.Run("case=malformed ids should return an error", func(t *testing.T) {
res := get(t, adminTS, "/identities?ids=not-a-uuid", http.StatusBadRequest)
assert.Contains(t, res.Get("error.reason").String(), "Invalid UUID value `not-a-uuid` for parameter `ids`.", "%s", res.Raw)
})

t.Run("suite=create and update", func(t *testing.T) {
var i identity.Identity
createOidcIdentity := func(t *testing.T, identifier, accessToken, refreshToken, idToken string, encrypt bool) string {
Expand Down
2 changes: 1 addition & 1 deletion identity/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import (
type (
ListIdentityParameters struct {
Expand Expandables
IdsFilter []string
IdsFilter []uuid.UUID
CredentialsIdentifier string
CredentialsIdentifierSimilar string
DeclassifyCredentials []CredentialsType
Expand Down
5 changes: 1 addition & 4 deletions identity/test/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -676,10 +676,7 @@ func TestPool(ctx context.Context, p persistence.Persister, m *identity.Manager,
})

t.Run("list some using ids filter", func(t *testing.T) {
var filterIds []string
for _, id := range createdIDs[:2] {
filterIds = append(filterIds, id.String())
}
filterIds := createdIDs[:2]

is, _, err := p.ListIdentities(ctx, identity.ListIdentityParameters{Expand: identity.ExpandDefault, IdsFilter: filterIds})
require.NoError(t, err)
Expand Down

0 comments on commit 98140f2

Please sign in to comment.