Skip to content

Commit

Permalink
Handle the count(uid) subgraph correctly, fixes #4038 (#4122)
Browse files Browse the repository at this point in the history
  • Loading branch information
mangalaman93 authored Oct 10, 2019
1 parent 7e74f43 commit 9864961
Show file tree
Hide file tree
Showing 4 changed files with 334 additions and 56 deletions.
23 changes: 9 additions & 14 deletions gql/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,6 @@ type GraphQuery struct {
// If gq.fragment is nonempty, then it is a fragment reference / spread.
fragment string

// Indicates whether count of uids is requested as a child node. If there
// is an alias, then UidCountAlias will be set (otherwise it will be the
// empty string).
UidCount bool
UidCountAlias string

// True for blocks that don't have a starting function and hence no starting nodes. They are
// used to aggregate and get variables defined in another block.
IsEmpty bool
Expand Down Expand Up @@ -2849,15 +2843,16 @@ func godeep(it *lex.ItemIterator, gq *GraphQuery) error {
}

count = notSeen
gq.UidCount = true
gq.Var = varName
if alias != "" {
gq.UidCountAlias = alias
// This is a count(uid) node.
// Reset the alias here after assigning to UidCountAlias, so that siblings
// of this node don't get it.
alias = ""
child := &GraphQuery{
Attr: "uid",
Alias: alias,
Var: varName,
IsCount: true,
IsInternal: true,
}
gq.Children = append(gq.Children, child)
alias = ""

it.Next()
it.Next()
}
Expand Down
60 changes: 30 additions & 30 deletions query/outputnode.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ type outputNode interface {
SetUID(uid uint64, attr string)
IsEmpty() bool

addCountAtRoot(*SubGraph)
addGroupby(*SubGraph, *groupResults, string)
addAggregations(*SubGraph) error
}
Expand Down Expand Up @@ -386,18 +385,6 @@ func (fj *fastJsonNode) addGroupby(sg *SubGraph, res *groupResults, fname string
fj.AddListChild(fname, g)
}

func (fj *fastJsonNode) addCountAtRoot(sg *SubGraph) {
c := types.ValueForType(types.IntID)
c.Value = int64(len(sg.DestUIDs.Uids))
n1 := fj.New(sg.Params.Alias)
field := sg.Params.UidCountAlias
if field == "" {
field = "count"
}
n1.AddValue(field, c)
fj.AddListChild(sg.Params.Alias, n1)
}

func (fj *fastJsonNode) addAggregations(sg *SubGraph) error {
for _, child := range sg.Children {
aggVal, ok := child.Params.UidToVal[0]
Expand All @@ -423,6 +410,32 @@ func (fj *fastJsonNode) addAggregations(sg *SubGraph) error {
return nil
}

func handleCountUIDNodes(sg *SubGraph, n outputNode, count int) bool {
addedNewChild := false
fieldName := sg.fieldName()
for _, child := range sg.Children {
uidCount := child.Attr == "uid" && child.Params.DoCount && child.IsInternal()
normWithoutAlias := child.Params.Alias == "" && child.Params.Normalize
if uidCount && !normWithoutAlias {
addedNewChild = true

c := types.ValueForType(types.IntID)
c.Value = int64(count)

field := child.Params.Alias
if field == "" {
field = "count"
}

fjChild := n.New(fieldName)
fjChild.AddValue(field, c)
n.AddListChild(fieldName, fjChild)
}
}

return addedNewChild
}

func processNodeUids(fj *fastJsonNode, sg *SubGraph) error {
var seedNode *fastJsonNode
if sg.Params.IsEmpty {
Expand All @@ -434,12 +447,7 @@ func processNodeUids(fj *fastJsonNode, sg *SubGraph) error {
return nil
}

hasChild := false
if sg.Params.UidCount && !(sg.Params.UidCountAlias == "" && sg.Params.Normalize) {
hasChild = true
fj.addCountAtRoot(sg)
}

hasChild := handleCountUIDNodes(sg, fj, len(sg.DestUIDs.Uids))
if sg.Params.IsGroupBy {
if len(sg.GroupbyRes) == 0 {
return errors.Errorf("Expected GroupbyRes to have length > 0.")
Expand Down Expand Up @@ -715,17 +723,9 @@ func (sg *SubGraph) preTraverse(uid uint64, dst outputNode) error {
}
}
}
if pc.Params.UidCount && !(pc.Params.UidCountAlias == "" && pc.Params.Normalize) {
uc := dst.New(fieldName)
c := types.ValueForType(types.IntID)
c.Value = int64(len(ul.Uids))
alias := pc.Params.UidCountAlias
if alias == "" {
alias = "count"
}
uc.AddValue(alias, c)
dst.AddListChild(fieldName, uc)
}

// add value for count(uid) nodes if any.
_ = handleCountUIDNodes(pc, dst, len(ul.Uids))
} else {
if pc.Params.Alias == "" && len(pc.Params.Langs) > 0 {
fieldName += "@"
Expand Down
20 changes: 8 additions & 12 deletions query/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,11 +193,6 @@ type params struct {
// GroupbyAttrs holds the list of attributes to group by.
GroupbyAttrs []gql.GroupByAttr

// UidCount is true when "count(uid)" is used.
UidCount bool
// UidCountAlias holds the alias of the variable used to hold the results of a "count(uid)"
// request, if any.
UidCountAlias string
// ParentIds is a stack that is maintained and passed down to children.
ParentIds []uint64
// IsEmpty is true if the subgraph doesn't have any SrcUids or DestUids.
Expand Down Expand Up @@ -532,8 +527,6 @@ func treeCopy(gq *gql.GraphQuery, sg *SubGraph) error {
GroupbyAttrs: gchild.GroupbyAttrs,
IsGroupBy: gchild.IsGroupby,
IsInternal: gchild.IsInternal,
UidCount: gchild.UidCount,
UidCountAlias: gchild.UidCountAlias,
}

if gchild.IsCount {
Expand Down Expand Up @@ -758,8 +751,6 @@ func newGraph(ctx context.Context, gq *gql.GraphQuery) (*SubGraph, error) {
Var: gq.Var,
GroupbyAttrs: gq.GroupbyAttrs,
IsGroupBy: gq.IsGroupby,
UidCount: gq.UidCount,
UidCountAlias: gq.UidCountAlias,
}

for argk := range gq.Args {
Expand Down Expand Up @@ -1203,8 +1194,11 @@ func (sg *SubGraph) valueVarAggregation(doneVars map[string]varValue, path []*Su
srcMap := doneVars[srcVar.Name]
// The value var can be empty. No need to check for nil.
sg.Params.UidToVal = srcMap.Vals
} else if sg.Attr == "uid" && sg.Params.DoCount {
// This is the count(uid) case.
// We will do the computation later while constructing the result.
} else {
return errors.Errorf("Unhandled pb.node %v with parent %v", sg.Attr, parent.Attr)
return errors.Errorf("Unhandled pb.node <%v> with parent <%v>", sg.Attr, parent.Attr)
}

return nil
Expand Down Expand Up @@ -1378,7 +1372,7 @@ func (sg *SubGraph) populateUidValVar(doneVars map[string]varValue, sgPath []*Su
}
doneVars[sg.Params.Var].Vals[uid] = val
}
} else if sg.Params.UidCount {
} else if sg.Params.DoCount && sg.Attr == "uid" && sg.IsInternal() {
// 2. This is the case where count(uid) is requested in the query and stored as variable.
// In this case there is just one value which is stored corresponding to the uid
// math.MaxUint64 which isn't entirely correct as there could be an actual uid with that
Expand All @@ -1389,9 +1383,11 @@ func (sg *SubGraph) populateUidValVar(doneVars map[string]varValue, sgPath []*Su
strList: sg.valueMatrix,
}

// Because we are counting the number of UIDs in parent
// we use the length of SrcUIDs instead of DestUIDs.
val := types.Val{
Tid: types.IntID,
Value: int64(len(sg.DestUIDs.Uids)),
Value: int64(len(sg.SrcUIDs.Uids)),
}
doneVars[sg.Params.Var].Vals[math.MaxUint64] = val
} else if len(sg.DestUIDs.Uids) != 0 || (sg.Attr == "uid" && sg.SrcUIDs != nil) {
Expand Down
Loading

0 comments on commit 9864961

Please sign in to comment.