Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

expression, sessionctx: support rand_seed1 and rand_seed2 sysvar #29936

Merged
merged 9 commits into from
Nov 24, 2021
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 8 additions & 11 deletions expression/builtin_math.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,14 @@ import (
"math"
"strconv"
"strings"
"sync"

"github.com/cznic/mathutil"
"github.com/pingcap/tidb/parser/ast"
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/chunk"
utilMath "github.com/pingcap/tidb/util/math"
"github.com/pingcap/tipb/go-tipb"
)

Expand Down Expand Up @@ -1014,7 +1014,7 @@ func (c *randFunctionClass) getFunction(ctx sessionctx.Context, args []Expressio
}
bt := bf
if len(args) == 0 {
sig = &builtinRandSig{bt, &sync.Mutex{}, NewWithTime()}
sig = &builtinRandSig{bt, ctx.GetSessionVars().Rng}
sig.setPbCode(tipb.ScalarFuncSig_Rand)
} else if _, isConstant := args[0].(*Constant); isConstant {
// According to MySQL manual:
Expand All @@ -1030,7 +1030,7 @@ func (c *randFunctionClass) getFunction(ctx sessionctx.Context, args []Expressio
// The behavior same as MySQL.
seed = 0
}
sig = &builtinRandSig{bt, &sync.Mutex{}, NewWithSeed(seed)}
sig = &builtinRandSig{bt, utilMath.NewWithSeed(seed)}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, but I'm not sure about removing the mutex that was previously used. Can you explain in the notes why it is safe?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sig.setPbCode(tipb.ScalarFuncSig_Rand)
} else {
sig = &builtinRandWithSeedFirstGenSig{bt}
Expand All @@ -1041,22 +1041,19 @@ func (c *randFunctionClass) getFunction(ctx sessionctx.Context, args []Expressio

type builtinRandSig struct {
baseBuiltinFunc
mu *sync.Mutex
mysqlRng *MysqlRng
mysqlRng *utilMath.MysqlRng
}

func (b *builtinRandSig) Clone() builtinFunc {
newSig := &builtinRandSig{mysqlRng: b.mysqlRng, mu: b.mu}
newSig := &builtinRandSig{mysqlRng: b.mysqlRng}
newSig.cloneFrom(&b.baseBuiltinFunc)
return newSig
}

// evalReal evals RAND().
// See https://dev.mysql.com/doc/refman/5.7/en/mathematical-functions.html#function_rand
func (b *builtinRandSig) evalReal(row chunk.Row) (float64, bool, error) {
b.mu.Lock()
res := b.mysqlRng.Gen()
b.mu.Unlock()
return res, false, nil
}

Expand All @@ -1080,11 +1077,11 @@ func (b *builtinRandWithSeedFirstGenSig) evalReal(row chunk.Row) (float64, bool,
// b.args[0] is promised to be a non-constant(such as a column name) in
// builtinRandWithSeedFirstGenSig, the seed is initialized with the value for each
// invocation of RAND().
var rng *MysqlRng
var rng *utilMath.MysqlRng
if !isNull {
rng = NewWithSeed(seed)
rng = utilMath.NewWithSeed(seed)
} else {
rng = NewWithSeed(0)
rng = utilMath.NewWithSeed(0)
}
return rng.Gen(), false, nil
}
Expand Down
3 changes: 2 additions & 1 deletion expression/builtin_math_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"github.com/pingcap/tidb/testkit/trequire"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/chunk"
utilMath "github.com/pingcap/tidb/util/math"
"github.com/pingcap/tipb/go-tipb"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -382,7 +383,7 @@ func TestRand(t *testing.T) {
// issue 3211
f2, err := fc.getFunction(ctx, []Expression{&Constant{Value: types.NewIntDatum(20160101), RetType: types.NewFieldType(mysql.TypeLonglong)}})
require.NoError(t, err)
randGen := NewWithSeed(20160101)
randGen := utilMath.NewWithSeed(20160101)
for i := 0; i < 3; i++ {
v, err = evalBuiltinFunc(f2, chunk.Row{})
require.NoError(t, err)
Expand Down
8 changes: 4 additions & 4 deletions expression/builtin_math_vec.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ import (
"math"
"strconv"

utilMath "github.com/pingcap/tidb/util/math"

"github.com/cznic/mathutil"
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/types"
Expand Down Expand Up @@ -709,11 +711,9 @@ func (b *builtinRandSig) vecEvalReal(input *chunk.Chunk, result *chunk.Column) e
n := input.NumRows()
result.ResizeFloat64(n, false)
f64s := result.Float64s()
b.mu.Lock()
for i := range f64s {
f64s[i] = b.mysqlRng.Gen()
}
b.mu.Unlock()
return nil
}

Expand All @@ -738,9 +738,9 @@ func (b *builtinRandWithSeedFirstGenSig) vecEvalReal(input *chunk.Chunk, result
for i := 0; i < n; i++ {
// When the seed is null we need to use 0 as the seed.
// The behavior same as MySQL.
rng := NewWithSeed(0)
rng := utilMath.NewWithSeed(0)
if !buf.IsNull(i) {
rng = NewWithSeed(i64s[i])
rng = utilMath.NewWithSeed(i64s[i])
}
f64s[i] = rng.Gen()
}
Expand Down
4 changes: 4 additions & 0 deletions expression/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -794,6 +794,10 @@ func (s *testIntegrationSuite2) TestMathBuiltin(c *C) {
tk.MustQuery("select rand(1) from t").Sort().Check(testkit.Rows("0.1418603212962489", "0.40540353712197724", "0.8716141803857071"))
tk.MustQuery("select rand(a) from t").Check(testkit.Rows("0.40540353712197724", "0.6555866465490187", "0.9057697559760601"))
tk.MustQuery("select rand(1), rand(2), rand(3)").Check(testkit.Rows("0.40540353712197724 0.6555866465490187 0.9057697559760601"))
tk.MustQuery("set @@rand_seed1=10000000,@@rand_seed2=1000000")
tk.MustQuery("select rand()").Check(testkit.Rows("0.028870999839968048"))
tk.MustQuery("select rand(1)").Check(testkit.Rows("0.40540353712197724"))
tk.MustQuery("select rand()").Check(testkit.Rows("0.11641535266900002"))
}

func (s *testIntegrationSuite2) TestStringBuiltin(c *C) {
Expand Down
2 changes: 0 additions & 2 deletions sessionctx/variable/noop.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ var noopSysVars = []*SysVar{
{Scope: ScopeGlobal | ScopeSession, Name: BigTables, Value: Off, Type: TypeBool},
{Scope: ScopeNone, Name: "skip_external_locking", Value: "1"},
{Scope: ScopeNone, Name: "innodb_sync_array_size", Value: "1"},
{Scope: ScopeSession, Name: "rand_seed2", Value: ""},
{Scope: ScopeGlobal, Name: ValidatePasswordCheckUserName, Value: Off, Type: TypeBool},
{Scope: ScopeGlobal, Name: ValidatePasswordNumberCount, Value: "1", Type: TypeUnsigned, MinValue: 0, MaxValue: math.MaxUint64},
{Scope: ScopeSession, Name: "gtid_next", Value: ""},
Expand Down Expand Up @@ -275,7 +274,6 @@ var noopSysVars = []*SysVar{
{Scope: ScopeNone, Name: "binlog_gtid_simple_recovery", Value: "1"},
{Scope: ScopeNone, Name: "performance_schema_digests_size", Value: "10000"},
{Scope: ScopeGlobal | ScopeSession, Name: Profiling, Value: Off, Type: TypeBool},
{Scope: ScopeSession, Name: "rand_seed1", Value: ""},
{Scope: ScopeGlobal, Name: "sha256_password_proxy_users", Value: ""},
{Scope: ScopeGlobal | ScopeSession, Name: SQLQuoteShowCreate, Value: On, Type: TypeBool},
{Scope: ScopeGlobal | ScopeSession, Name: "binlogging_impossible_mode", Value: "IGNORE_ERROR"},
Expand Down
6 changes: 6 additions & 0 deletions sessionctx/variable/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ import (
"sync/atomic"
"time"

utilMath "github.com/pingcap/tidb/util/math"

"github.com/pingcap/errors"
pumpcli "github.com/pingcap/tidb-tools/tidb-binlog/pump_client"
"github.com/pingcap/tidb/config"
Expand Down Expand Up @@ -955,6 +957,9 @@ type SessionVars struct {
curr int8
data [2]stmtctx.StatementContext
}

// Rng stores the rand_seed1 and rand_seed2 for Rand() function
Rng *utilMath.MysqlRng
}

// InitStatementContext initializes a StatementContext, the object is reused to reduce allocation.
Expand Down Expand Up @@ -1188,6 +1193,7 @@ func NewSessionVars() *SessionVars {
MPPStoreLastFailTime: make(map[string]time.Time),
MPPStoreFailTTL: DefTiDBMPPStoreFailTTL,
EnablePlacementChecks: DefEnablePlacementCheck,
Rng: utilMath.NewWithTime(),
}
vars.KVVars = tikvstore.NewVariables(&vars.Killed)
vars.Concurrency = Concurrency{
Expand Down
12 changes: 12 additions & 0 deletions sessionctx/variable/sysvar.go
Original file line number Diff line number Diff line change
Expand Up @@ -1861,6 +1861,18 @@ var defaultSysVars = []*SysVar{
{Scope: ScopeNone, Name: "version_compile_os", Value: runtime.GOOS},
{Scope: ScopeNone, Name: "version_compile_machine", Value: runtime.GOARCH},
{Scope: ScopeNone, Name: TiDBAllowFunctionForExpressionIndex, ReadOnly: true, Value: collectAllowFuncName4ExpressionIndex()},
{Scope: ScopeSession, Name: "rand_seed1", Type: TypeInt, Value: "0", skipInit: true, MaxValue: math.MaxInt32, SetSession: func(s *SessionVars, val string) error {
xhebox marked this conversation as resolved.
Show resolved Hide resolved
s.Rng.SetSeed1(uint32(tidbOptPositiveInt32(val, 0)))
return nil
}, GetSession: func(s *SessionVars) (string, error) {
return "0", nil
}},
{Scope: ScopeSession, Name: "rand_seed2", Type: TypeInt, Value: "0", skipInit: true, MaxValue: math.MaxInt32, SetSession: func(s *SessionVars, val string) error {
s.Rng.SetSeed2(uint32(tidbOptPositiveInt32(val, 0)))
return nil
}, GetSession: func(s *SessionVars) (string, error) {
return "0", nil
}},
}

func collectAllowFuncName4ExpressionIndex() string {
Expand Down
30 changes: 27 additions & 3 deletions expression/rand.go → util/math/rand.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.

package expression
package math

import "time"
import (
"sync"
"time"
)

const maxRandValue = 0x3FFFFFFF

Expand All @@ -23,13 +26,18 @@ const maxRandValue = 0x3FFFFFFF
type MysqlRng struct {
seed1 uint32
seed2 uint32
mu *sync.Mutex
}

// NewWithSeed create a rng with random seed.
func NewWithSeed(seed int64) *MysqlRng {
seed1 := uint32(seed*0x10001+55555555) % maxRandValue
seed2 := uint32(seed*0x10000001) % maxRandValue
return &MysqlRng{seed1: seed1, seed2: seed2}
return &MysqlRng{
seed1: seed1,
seed2: seed2,
mu: &sync.Mutex{},
}
}

// NewWithTime create a rng with time stamp.
Expand All @@ -39,7 +47,23 @@ func NewWithTime() *MysqlRng {

// Gen will generate random number.
func (rng *MysqlRng) Gen() float64 {
rng.mu.Lock()
defer rng.mu.Unlock()
rng.seed1 = (rng.seed1*3 + rng.seed2) % maxRandValue
rng.seed2 = (rng.seed1 + rng.seed2 + 33) % maxRandValue
return float64(rng.seed1) / float64(maxRandValue)
}

// SetSeed1 is a interface to set seed1
func (rng *MysqlRng) SetSeed1(seed uint32) {
rng.mu.Lock()
defer rng.mu.Unlock()
rng.seed1 = seed
}

// SetSeed2 is a interface to set seed2
func (rng *MysqlRng) SetSeed2(seed uint32) {
rng.mu.Lock()
defer rng.mu.Unlock()
rng.seed2 = seed
}
21 changes: 18 additions & 3 deletions expression/rand_test.go → util/math/rand_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

package expression
package math

import (
"testing"
Expand Down Expand Up @@ -55,8 +55,23 @@ func TestRandWithSeed(t *testing.T) {
for _, test := range tests {
rng := NewWithSeed(test.seed)
got1 := rng.Gen()
require.True(t, got1 == test.once)
require.Equal(t, got1, test.once)
got2 := rng.Gen()
require.True(t, got2 == test.twice)
require.Equal(t, got2, test.twice)
}
}

func TestRandWithSeed1AndSeed2(t *testing.T) {
t.Parallel()

seed1 := uint32(10000000)
seed2 := uint32(1000000)

rng := NewWithTime()
rng.SetSeed1(seed1)
rng.SetSeed2(seed2)

require.Equal(t, rng.Gen(), 0.028870999839968048)
require.Equal(t, rng.Gen(), 0.11641535266900002)
require.Equal(t, rng.Gen(), 0.49546379455874096)
}