From 9199843947bc3a0fa415dba50ba2221850e0fbad Mon Sep 17 00:00:00 2001 From: Baha Aiman Date: Thu, 19 Sep 2024 21:18:08 +0000 Subject: [PATCH] fix(firestore): Add UTF-8 validation (#10881) --- firestore/client.go | 4 ++-- firestore/client_test.go | 6 +++++ firestore/docref.go | 32 +++++++++++++++++++-------- firestore/docref_test.go | 47 ++++++++++++++++++++++++++++++++++++++++ 4 files changed, 78 insertions(+), 11 deletions(-) diff --git a/firestore/client.go b/firestore/client.go index 28144677644e..5f5e186a0067 100644 --- a/firestore/client.go +++ b/firestore/client.go @@ -256,8 +256,8 @@ func (c *Client) getAll(ctx context.Context, docRefs []*DocumentRef, tid []byte, var docNames []string docIndices := map[string][]int{} // doc name to positions in docRefs for i, dr := range docRefs { - if dr == nil { - return nil, errNilDocRef + if err := dr.isValid(); err != nil { + return nil, err } docNames = append(docNames, dr.Path) docIndices[dr.Path] = append(docIndices[dr.Path], i) diff --git a/firestore/client_test.go b/firestore/client_test.go index 9dd889290c73..f710aaea07bc 100644 --- a/firestore/client_test.go +++ b/firestore/client_test.go @@ -351,6 +351,12 @@ func TestGetAllErrors(t *testing.T) { if _, err := c.GetAll(ctx, []*DocumentRef{c.Doc("C/a")}); err == nil { t.Error("got nil, want error") } + + // Invalid UTF-8 characters + srv.reset() + if _, gotErr := c.GetAll(ctx, []*DocumentRef{c.Doc("C/Mayag\xcfez")}); !errorsMatch(gotErr, errInvalidUtf8DocRef) { + t.Errorf("got: %v, want: %v", gotErr, errInvalidUtf8DocRef) + } } func TestClient_WithReadOptions(t *testing.T) { diff --git a/firestore/docref.go b/firestore/docref.go index 822316316c98..97302868b10e 100644 --- a/firestore/docref.go +++ b/firestore/docref.go @@ -21,6 +21,7 @@ import ( "io" "reflect" "sort" + "unicode/utf8" vkit "cloud.google.com/go/firestore/apiv1" pb "cloud.google.com/go/firestore/apiv1/firestorepb" @@ -31,7 +32,10 @@ import ( "google.golang.org/protobuf/proto" ) -var errNilDocRef = errors.New("firestore: nil DocumentRef") +var ( + errNilDocRef = errors.New("firestore: nil DocumentRef") + errInvalidUtf8DocRef = errors.New("firestore: ID in DocumentRef contains invalid UTF-8 characters") +) // A DocumentRef is a reference to a Firestore document. type DocumentRef struct { @@ -63,6 +67,16 @@ func newDocRef(parent *CollectionRef, id string) *DocumentRef { } } +func (d *DocumentRef) isValid() error { + if d == nil { + return errNilDocRef + } + if !utf8.ValidString(d.ID) { + return errInvalidUtf8DocRef + } + return nil +} + // Collection returns a reference to sub-collection of this document. func (d *DocumentRef) Collection(id string) *CollectionRef { return newCollRefWithParent(d.Parent.c, d, id) @@ -79,8 +93,8 @@ func (d *DocumentRef) Get(ctx context.Context) (_ *DocumentSnapshot, err error) ctx = trace.StartSpan(ctx, "cloud.google.com/go/firestore.DocumentRef.Get") defer func() { trace.EndSpan(ctx, err) }() - if d == nil { - return nil, errNilDocRef + if err := d.isValid(); err != nil { + return nil, err } docsnaps, err := d.Parent.c.getAll(ctx, []*DocumentRef{d}, nil, d.readSettings) @@ -147,8 +161,8 @@ func (d *DocumentRef) Create(ctx context.Context, data interface{}) (_ *WriteRes } func (d *DocumentRef) newCreateWrites(data interface{}) ([]*pb.Write, error) { - if d == nil { - return nil, errNilDocRef + if err := d.isValid(); err != nil { + return nil, err } doc, transforms, err := toProtoDocument(data) if err != nil { @@ -179,8 +193,8 @@ func (d *DocumentRef) Set(ctx context.Context, data interface{}, opts ...SetOpti } func (d *DocumentRef) newSetWrites(data interface{}, opts []SetOption) ([]*pb.Write, error) { - if d == nil { - return nil, errNilDocRef + if err := d.isValid(); err != nil { + return nil, err } if data == nil { return nil, errors.New("firestore: nil document contents") @@ -259,8 +273,8 @@ func (d *DocumentRef) Delete(ctx context.Context, preconds ...Precondition) (_ * } func (d *DocumentRef) newDeleteWrites(preconds []Precondition) ([]*pb.Write, error) { - if d == nil { - return nil, errNilDocRef + if err := d.isValid(); err != nil { + return nil, err } pc, err := processPreconditionsForDelete(preconds) if err != nil { diff --git a/firestore/docref_test.go b/firestore/docref_test.go index f09578bbaf1b..415fa37c969b 100644 --- a/firestore/docref_test.go +++ b/firestore/docref_test.go @@ -89,6 +89,17 @@ func TestDocGet(t *testing.T) { if status.Code(err) != codes.NotFound { t.Errorf("got %v, want NotFound", err) } + + // Invalid UTF-8 characters + if _, gotErr := c.Collection("C").Doc("Mayag\xcfez").Get(ctx); !errorsMatch(gotErr, errInvalidUtf8DocRef) { + t.Errorf("got: %v, want: %v", gotErr, errInvalidUtf8DocRef) + } + + // nil DocRef + var nilDocRef *DocumentRef + if _, gotErr := nilDocRef.Get(ctx); !errorsMatch(gotErr, errNilDocRef) { + t.Errorf("got: %v, want: %v", gotErr, errInvalidUtf8DocRef) + } } func TestDocSet(t *testing.T) { @@ -133,6 +144,18 @@ func TestDocSet(t *testing.T) { if err == nil { t.Errorf("got nil, want error") } + + // Invalid UTF-8 characters + if _, gotErr := c.Collection("C").Doc("Mayag\xcfez"). + Set(ctx, data, Merge([]string{"*", "~"})); !errorsMatch(gotErr, errInvalidUtf8DocRef) { + t.Errorf("got: %v, want: %v", gotErr, errInvalidUtf8DocRef) + } + + // nil DocRef + var nilDocRef *DocumentRef + if _, gotErr := nilDocRef.Set(ctx, data, Merge([]string{"*", "~"})); !errorsMatch(gotErr, errNilDocRef) { + t.Errorf("got: %v, want: %v", gotErr, errInvalidUtf8DocRef) + } } func TestDocCreate(t *testing.T) { @@ -175,6 +198,18 @@ func TestDocCreate(t *testing.T) { if err != nil { t.Fatal(err) } + + // Invalid UTF-8 characters + if _, gotErr := c.Collection("C").Doc("Mayag\xcfez"). + Create(ctx, &create{}); !errorsMatch(gotErr, errInvalidUtf8DocRef) { + t.Errorf("got: %v, want: %v", gotErr, errInvalidUtf8DocRef) + } + + // nil DocRef + var nilDocRef *DocumentRef + if _, gotErr := nilDocRef.Create(ctx, &create{}); !errorsMatch(gotErr, errNilDocRef) { + t.Errorf("got: %v, want: %v", gotErr, errInvalidUtf8DocRef) + } } func TestDocDelete(t *testing.T) { @@ -199,6 +234,18 @@ func TestDocDelete(t *testing.T) { if !testEqual(wr, &WriteResult{}) { t.Errorf("got %+v, want %+v", wr, writeResultForSet) } + + // Invalid UTF-8 characters + if _, gotErr := c.Collection("C").Doc("Mayag\xcfez"). + Delete(ctx); !errorsMatch(gotErr, errInvalidUtf8DocRef) { + t.Errorf("got: %v, want: %v", gotErr, errInvalidUtf8DocRef) + } + + // nil DocRef + var nilDocRef *DocumentRef + if _, gotErr := nilDocRef.Delete(ctx); !errorsMatch(gotErr, errNilDocRef) { + t.Errorf("got: %v, want: %v", gotErr, errInvalidUtf8DocRef) + } } var (