Skip to content

Commit

Permalink
fix(firestore): Add UTF-8 validation (#10881)
Browse files Browse the repository at this point in the history
  • Loading branch information
bhshkh authored Sep 19, 2024
1 parent 9ae039a commit 9199843
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 11 deletions.
4 changes: 2 additions & 2 deletions firestore/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -256,8 +256,8 @@ func (c *Client) getAll(ctx context.Context, docRefs []*DocumentRef, tid []byte,
var docNames []string
docIndices := map[string][]int{} // doc name to positions in docRefs
for i, dr := range docRefs {
if dr == nil {
return nil, errNilDocRef
if err := dr.isValid(); err != nil {
return nil, err
}
docNames = append(docNames, dr.Path)
docIndices[dr.Path] = append(docIndices[dr.Path], i)
Expand Down
6 changes: 6 additions & 0 deletions firestore/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,12 @@ func TestGetAllErrors(t *testing.T) {
if _, err := c.GetAll(ctx, []*DocumentRef{c.Doc("C/a")}); err == nil {
t.Error("got nil, want error")
}

// Invalid UTF-8 characters
srv.reset()
if _, gotErr := c.GetAll(ctx, []*DocumentRef{c.Doc("C/Mayag\xcfez")}); !errorsMatch(gotErr, errInvalidUtf8DocRef) {
t.Errorf("got: %v, want: %v", gotErr, errInvalidUtf8DocRef)
}
}

func TestClient_WithReadOptions(t *testing.T) {
Expand Down
32 changes: 23 additions & 9 deletions firestore/docref.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"io"
"reflect"
"sort"
"unicode/utf8"

vkit "cloud.google.com/go/firestore/apiv1"
pb "cloud.google.com/go/firestore/apiv1/firestorepb"
Expand All @@ -31,7 +32,10 @@ import (
"google.golang.org/protobuf/proto"
)

var errNilDocRef = errors.New("firestore: nil DocumentRef")
var (
errNilDocRef = errors.New("firestore: nil DocumentRef")
errInvalidUtf8DocRef = errors.New("firestore: ID in DocumentRef contains invalid UTF-8 characters")
)

// A DocumentRef is a reference to a Firestore document.
type DocumentRef struct {
Expand Down Expand Up @@ -63,6 +67,16 @@ func newDocRef(parent *CollectionRef, id string) *DocumentRef {
}
}

func (d *DocumentRef) isValid() error {
if d == nil {
return errNilDocRef
}
if !utf8.ValidString(d.ID) {
return errInvalidUtf8DocRef
}
return nil
}

// Collection returns a reference to sub-collection of this document.
func (d *DocumentRef) Collection(id string) *CollectionRef {
return newCollRefWithParent(d.Parent.c, d, id)
Expand All @@ -79,8 +93,8 @@ func (d *DocumentRef) Get(ctx context.Context) (_ *DocumentSnapshot, err error)
ctx = trace.StartSpan(ctx, "cloud.google.com/go/firestore.DocumentRef.Get")
defer func() { trace.EndSpan(ctx, err) }()

if d == nil {
return nil, errNilDocRef
if err := d.isValid(); err != nil {
return nil, err
}

docsnaps, err := d.Parent.c.getAll(ctx, []*DocumentRef{d}, nil, d.readSettings)
Expand Down Expand Up @@ -147,8 +161,8 @@ func (d *DocumentRef) Create(ctx context.Context, data interface{}) (_ *WriteRes
}

func (d *DocumentRef) newCreateWrites(data interface{}) ([]*pb.Write, error) {
if d == nil {
return nil, errNilDocRef
if err := d.isValid(); err != nil {
return nil, err
}
doc, transforms, err := toProtoDocument(data)
if err != nil {
Expand Down Expand Up @@ -179,8 +193,8 @@ func (d *DocumentRef) Set(ctx context.Context, data interface{}, opts ...SetOpti
}

func (d *DocumentRef) newSetWrites(data interface{}, opts []SetOption) ([]*pb.Write, error) {
if d == nil {
return nil, errNilDocRef
if err := d.isValid(); err != nil {
return nil, err
}
if data == nil {
return nil, errors.New("firestore: nil document contents")
Expand Down Expand Up @@ -259,8 +273,8 @@ func (d *DocumentRef) Delete(ctx context.Context, preconds ...Precondition) (_ *
}

func (d *DocumentRef) newDeleteWrites(preconds []Precondition) ([]*pb.Write, error) {
if d == nil {
return nil, errNilDocRef
if err := d.isValid(); err != nil {
return nil, err
}
pc, err := processPreconditionsForDelete(preconds)
if err != nil {
Expand Down
47 changes: 47 additions & 0 deletions firestore/docref_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,17 @@ func TestDocGet(t *testing.T) {
if status.Code(err) != codes.NotFound {
t.Errorf("got %v, want NotFound", err)
}

// Invalid UTF-8 characters
if _, gotErr := c.Collection("C").Doc("Mayag\xcfez").Get(ctx); !errorsMatch(gotErr, errInvalidUtf8DocRef) {
t.Errorf("got: %v, want: %v", gotErr, errInvalidUtf8DocRef)
}

// nil DocRef
var nilDocRef *DocumentRef
if _, gotErr := nilDocRef.Get(ctx); !errorsMatch(gotErr, errNilDocRef) {
t.Errorf("got: %v, want: %v", gotErr, errInvalidUtf8DocRef)
}
}

func TestDocSet(t *testing.T) {
Expand Down Expand Up @@ -133,6 +144,18 @@ func TestDocSet(t *testing.T) {
if err == nil {
t.Errorf("got nil, want error")
}

// Invalid UTF-8 characters
if _, gotErr := c.Collection("C").Doc("Mayag\xcfez").
Set(ctx, data, Merge([]string{"*", "~"})); !errorsMatch(gotErr, errInvalidUtf8DocRef) {
t.Errorf("got: %v, want: %v", gotErr, errInvalidUtf8DocRef)
}

// nil DocRef
var nilDocRef *DocumentRef
if _, gotErr := nilDocRef.Set(ctx, data, Merge([]string{"*", "~"})); !errorsMatch(gotErr, errNilDocRef) {
t.Errorf("got: %v, want: %v", gotErr, errInvalidUtf8DocRef)
}
}

func TestDocCreate(t *testing.T) {
Expand Down Expand Up @@ -175,6 +198,18 @@ func TestDocCreate(t *testing.T) {
if err != nil {
t.Fatal(err)
}

// Invalid UTF-8 characters
if _, gotErr := c.Collection("C").Doc("Mayag\xcfez").
Create(ctx, &create{}); !errorsMatch(gotErr, errInvalidUtf8DocRef) {
t.Errorf("got: %v, want: %v", gotErr, errInvalidUtf8DocRef)
}

// nil DocRef
var nilDocRef *DocumentRef
if _, gotErr := nilDocRef.Create(ctx, &create{}); !errorsMatch(gotErr, errNilDocRef) {
t.Errorf("got: %v, want: %v", gotErr, errInvalidUtf8DocRef)
}
}

func TestDocDelete(t *testing.T) {
Expand All @@ -199,6 +234,18 @@ func TestDocDelete(t *testing.T) {
if !testEqual(wr, &WriteResult{}) {
t.Errorf("got %+v, want %+v", wr, writeResultForSet)
}

// Invalid UTF-8 characters
if _, gotErr := c.Collection("C").Doc("Mayag\xcfez").
Delete(ctx); !errorsMatch(gotErr, errInvalidUtf8DocRef) {
t.Errorf("got: %v, want: %v", gotErr, errInvalidUtf8DocRef)
}

// nil DocRef
var nilDocRef *DocumentRef
if _, gotErr := nilDocRef.Delete(ctx); !errorsMatch(gotErr, errNilDocRef) {
t.Errorf("got: %v, want: %v", gotErr, errInvalidUtf8DocRef)
}
}

var (
Expand Down

0 comments on commit 9199843

Please sign in to comment.