Skip to content

Commit

Permalink
add unit test for function (#708) (#733)
Browse files Browse the repository at this point in the history
* add unit test for function (#708)

* formatter import
  • Loading branch information
csynineyang authored Aug 6, 2023
1 parent 33f2907 commit 3ea47a2
Show file tree
Hide file tree
Showing 3 changed files with 196 additions and 0 deletions.
30 changes: 30 additions & 0 deletions pkg/dataset/filter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"database/sql"
"fmt"
"io"
"strings"
"testing"
)

Expand Down Expand Up @@ -75,4 +76,33 @@ func TestFilter(t *testing.T) {

t.Logf("id=%v, name=%v, gender=%v\n", dest[0], dest[1], dest[2])
}

for i := int64(10); i < 100; i++ {
root.Rows = append(root.Rows, rows.NewTextVirtualRow(fields, []proto.Value{
proto.NewValueInt64(i),
proto.NewValueString(fmt.Sprintf("fake-name-%d", i)),
proto.NewValueInt64(i & 1), // 0=female,1=male
}))
}

preFiltered := Pipe(root, FilterPrefix(func(row proto.Row) bool {
dest := make([]proto.Value, len(fields))
_ = row.Scan(dest)
var fkname sql.NullString
_ = fkname.Scan(dest[1])
assert.True(t, fkname.Valid)
return strings.HasPrefix(fkname.String, "fake-name-1")
}, "fake"))

for {
next, err := preFiltered.Next()
if err == io.EOF {
break
}
assert.NoError(t, err)

dest := make([]proto.Value, len(fields))
_ = next.Scan(dest)
t.Logf("id=%v, name=%v, gender=%v\n", dest[0], dest[1], dest[2])
}
}
127 changes: 127 additions & 0 deletions pkg/dataset/reduce_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package dataset

import (
"fmt"
"testing"
)

import (
"github.com/stretchr/testify/assert"
)

import (
consts "github.com/arana-db/arana/pkg/constants/mysql"
"github.com/arana-db/arana/pkg/mysql"
vrows "github.com/arana-db/arana/pkg/mysql/rows"
"github.com/arana-db/arana/pkg/proto"
"github.com/arana-db/arana/pkg/reduce"
)

func TestReduce(t *testing.T) {
fields := []proto.Field{
mysql.NewField("score", consts.FieldTypeLong),
}

var origin VirtualDataset
origin.Columns = fields

var rows [][]proto.Value
for i := 0; i < 10; i++ {
rows = append(rows, []proto.Value{
proto.NewValueInt64(int64(i)),
})
}

for _, it := range rows {
origin.Rows = append(origin.Rows, vrows.NewTextVirtualRow(fields, it))
}

totalFields := []proto.Field{
mysql.NewField("total", consts.FieldTypeLong),
}

// Simulate: SELECT sum(score) AS total FROM xxx WHERE ...
pSum := Pipe(&origin,
Reduce(
map[int]reduce.Reducer{
0: reduce.Sum(),
},
),
)
for {
next, err := pSum.Next()
if err != nil {
break
}
assert.NoError(t, err)
v := make([]proto.Value, len(totalFields))
_ = next.Scan(v)
assert.Equal(t, "45", fmt.Sprint(v[0]))
t.Logf("next: total=%v\n", v[0])
}

maxFields := []proto.Field{
mysql.NewField("max", consts.FieldTypeLong),
}

// Simulate: SELECT max(score) AS max FROM xxx WHERE ...
pMax := Pipe(&origin,
Reduce(
map[int]reduce.Reducer{
0: reduce.Max(),
},
),
)
for {
next, err := pMax.Next()
if err != nil {
break
}
assert.NoError(t, err)
v := make([]proto.Value, len(maxFields))
_ = next.Scan(v)
assert.Equal(t, "9", fmt.Sprint(v[0]))
t.Logf("next: max=%v\n", v[0])
}

minFields := []proto.Field{
mysql.NewField("min", consts.FieldTypeLong),
}

// Simulate: SELECT min(score) AS min FROM xxx WHERE ...
pMin := Pipe(&origin,
Reduce(
map[int]reduce.Reducer{
0: reduce.Min(),
},
),
)
for {
next, err := pMin.Next()
if err != nil {
break
}
assert.NoError(t, err)
v := make([]proto.Value, len(minFields))
_ = next.Scan(v)
assert.Equal(t, "0", fmt.Sprint(v[0]))
t.Logf("next: min=%v\n", v[0])
}
}
39 changes: 39 additions & 0 deletions pkg/dataset/transform_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@ func TestTransform(t *testing.T) {
mysql.NewField("name", consts.FieldTypeVarChar),
mysql.NewField("level", consts.FieldTypeLong),
}
tsFields := []proto.Field{
mysql.NewField("id", consts.FieldTypeLong),
mysql.NewField("name", consts.FieldTypeVarChar),
mysql.NewField("constlevel", consts.FieldTypeLong),
}

root := &VirtualDataset{
Columns: fields,
Expand Down Expand Up @@ -76,4 +81,38 @@ func TestTransform(t *testing.T) {

assert.Equal(t, "100", fmt.Sprint(dest[2]))
}

for i := int64(0); i < 10; i++ {
root.Rows = append(root.Rows, rows.NewTextVirtualRow(fields, []proto.Value{
proto.NewValueInt64(i),
proto.NewValueString(fmt.Sprintf("fake-name-%d", i)),
proto.NewValueInt64(rand2.Int63n(10)),
}))
}

fdTransformed := Pipe(root, Map(func(fields []proto.Field) []proto.Field {
return tsFields
}, nil))

for {
next, err := fdTransformed.Next()
if err == io.EOF {
break
}

assert.NoError(t, err)

actualFields, _ := fdTransformed.Fields()
actualSet := &VirtualDataset{
Columns: actualFields,
}
actualCell := make([]proto.Value, len(actualFields))
_ = next.Scan(actualCell)
actualSet.Rows = append(actualSet.Rows, rows.NewTextVirtualRow(actualFields, actualCell))

dest := make([]proto.Value, len(actualFields))
_ = actualSet.Rows[0].Scan(dest)

t.Logf("id=%v, myname=%v, constlevel=%v\n", dest[0], dest[1], dest[2])
}
}

0 comments on commit 3ea47a2

Please sign in to comment.