Skip to content

Commit

Permalink
refactor: Rework sum and count nodes to make use of generics (#757)
Browse files Browse the repository at this point in the history
* Refactor count node to make use of generics

* Refactor sum node to make use of generics
  • Loading branch information
AndrewSisley authored Aug 24, 2022
1 parent d59592f commit aa92a70
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 80 deletions.
70 changes: 25 additions & 45 deletions query/graphql/planner/count.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,62 +99,28 @@ func (n *countNode) Next() (bool, error) {
// v.Len will panic if v is not one of these types, we don't want it to panic
case reflect.Array, reflect.Chan, reflect.Map, reflect.Slice, reflect.String:
if source.Filter != nil {
var arrayCount int
var err error
switch array := property.(type) {
case []core.Doc:
for _, doc := range array {
passed, err := mapper.RunFilter(doc, source.Filter)
if err != nil {
return false, err
}
if passed {
count += 1
}
}
arrayCount, err = countItems(array, source.Filter)

case []bool:
for _, doc := range array {
passed, err := mapper.RunFilter(doc, source.Filter)
if err != nil {
return false, err
}
if passed {
count += 1
}
}
arrayCount, err = countItems(array, source.Filter)

case []int64:
for _, doc := range array {
passed, err := mapper.RunFilter(doc, source.Filter)
if err != nil {
return false, err
}
if passed {
count += 1
}
}
arrayCount, err = countItems(array, source.Filter)

case []float64:
for _, doc := range array {
passed, err := mapper.RunFilter(doc, source.Filter)
if err != nil {
return false, err
}
if passed {
count += 1
}
}
arrayCount, err = countItems(array, source.Filter)

case []string:
for _, doc := range array {
passed, err := mapper.RunFilter(doc, source.Filter)
if err != nil {
return false, err
}
if passed {
count += 1
}
}
arrayCount, err = countItems(array, source.Filter)
}
if err != nil {
return false, err
}
count += arrayCount
} else {
count = count + v.Len()
}
Expand All @@ -165,4 +131,18 @@ func (n *countNode) Next() (bool, error) {
return true, nil
}

func countItems[T any](items []T, filter *mapper.Filter) (int, error) {
count := 0
for _, item := range items {
passed, err := mapper.RunFilter(item, filter)
if err != nil {
return 0, err
}
if passed {
count += 1
}
}
return count, nil
}

func (n *countNode) SetPlan(p planNode) { n.plan = p }
70 changes: 35 additions & 35 deletions query/graphql/planner/sum.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,54 +196,39 @@ func (n *sumNode) Next() (bool, error) {

for _, source := range n.aggregateMapping {
child := n.currentValue.Fields[source.Index]
var collectionSum float64
var err error
switch childCollection := child.(type) {
case []core.Doc:
for _, childItem := range childCollection {
passed, err := mapper.RunFilter(childItem, source.Filter)
if err != nil {
return false, err
}
if !passed {
continue
}

collectionSum, err = sumItems(childCollection, source.Filter, func(childItem core.Doc) float64 {
childProperty := childItem.Fields[source.ChildTarget.Index]
switch v := childProperty.(type) {
case int:
sum += float64(v)
return float64(v)
case int64:
sum += float64(v)
return float64(v)
case uint64:
sum += float64(v)
return float64(v)
case float64:
sum += v
return v
default:
// do nothing, cannot be summed
// return nothing, cannot be summed
return 0
}
}
})
case []int64:
for _, childItem := range childCollection {
passed, err := mapper.RunFilter(childItem, source.Filter)
if err != nil {
return false, err
}
if !passed {
continue
}
sum += float64(childItem)
}
collectionSum, err = sumItems(childCollection, source.Filter, func(childItem int64) float64 {
return float64(childItem)
})
case []float64:
for _, childItem := range childCollection {
passed, err := mapper.RunFilter(childItem, source.Filter)
if err != nil {
return false, err
}
if !passed {
continue
}
sum += childItem
}
collectionSum, err = sumItems(childCollection, source.Filter, func(childItem float64) float64 {
return childItem
})
}
if err != nil {
return false, err
}
sum += collectionSum
}

var typedSum interface{}
Expand All @@ -257,4 +242,19 @@ func (n *sumNode) Next() (bool, error) {
return true, nil
}

func sumItems[T any](items []T, filter *mapper.Filter, toFloat func(T) float64) (float64, error) {
var sum float64 = 0
for _, item := range items {
passed, err := mapper.RunFilter(item, filter)
if err != nil {
return 0, err
}
if !passed {
continue
}
sum += toFloat(item)
}
return sum, nil
}

func (n *sumNode) SetPlan(p planNode) { n.plan = p }

0 comments on commit aa92a70

Please sign in to comment.