Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

oracle: support txnScope for GetStaleTimestamp #21967

Merged
merged 5 commits into from
Dec 24, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion store/mockoracle/oracle.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ func (o *MockOracle) GetTimestamp(ctx context.Context, _ *oracle.Option) (uint64
}

// GetStaleTimestamp implements oracle.Oracle interface.
func (o *MockOracle) GetStaleTimestamp(ctx context.Context, prevSecond uint64) (ts uint64, err error) {
func (o *MockOracle) GetStaleTimestamp(ctx context.Context, txnScope string, prevSecond uint64) (ts uint64, err error) {
physical := oracle.GetPhysical(time.Now().Add(-time.Second * time.Duration(prevSecond)))
ts = oracle.ComposeTS(physical, 0)
return ts, nil
Expand Down
2 changes: 1 addition & 1 deletion store/tikv/oracle/oracle.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ type Oracle interface {
GetTimestampAsync(ctx context.Context, opt *Option) Future
GetLowResolutionTimestamp(ctx context.Context, opt *Option) (uint64, error)
GetLowResolutionTimestampAsync(ctx context.Context, opt *Option) Future
GetStaleTimestamp(ctx context.Context, prevSecond uint64) (uint64, error)
GetStaleTimestamp(ctx context.Context, txnScope string, prevSecond uint64) (uint64, error)
IsExpired(lockTimestamp, TTL uint64, opt *Option) bool
UntilExpired(lockTimeStamp, TTL uint64, opt *Option) int64
Close()
Expand Down
7 changes: 4 additions & 3 deletions store/tikv/oracle/oracles/export_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,13 @@ func SetEmptyPDOracleLastTs(oc oracle.Oracle, ts uint64) {
lastTSPointer := lastTSInterface.(*uint64)
atomic.StoreUint64(lastTSPointer, ts)
}
setEmptyPDOracleLastArrivalTs(oc, ts)
}

// SetEmptyPDOracleLastTs exports PD oracle's global last ts to test.
func SetEmptyPDOracleLastArrivalTs(oc oracle.Oracle, ts uint64) {
// setEmptyPDOracleLastArrivalTs exports PD oracle's global last ts to test.
func setEmptyPDOracleLastArrivalTs(oc oracle.Oracle, ts uint64) {
switch o := oc.(type) {
case *pdOracle:
o.setLastArrivalTS(ts)
o.setLastArrivalTS(ts, oracle.GlobalTxnScope)
}
}
2 changes: 1 addition & 1 deletion store/tikv/oracle/oracles/local.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ func (l *localOracle) GetLowResolutionTimestampAsync(ctx context.Context, opt *o
}

// GetStaleTimestamp return physical
func (l *localOracle) GetStaleTimestamp(ctx context.Context, prevSecond uint64) (ts uint64, err error) {
func (l *localOracle) GetStaleTimestamp(ctx context.Context, txnScope string, prevSecond uint64) (ts uint64, err error) {
physical := oracle.GetPhysical(time.Now().Add(-time.Second * time.Duration(prevSecond)))
ts = oracle.ComposeTS(physical, 0)
return ts, nil
Expand Down
69 changes: 44 additions & 25 deletions store/tikv/oracle/oracles/pd.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,10 @@ const slowDist = 30 * time.Millisecond
type pdOracle struct {
c pd.Client
// txn_scope (string) -> lastTSPointer (*uint64)
lastTSMap sync.Map
lastArrivalTS uint64
quit chan struct{}
lastTSMap sync.Map
// txn_scope (string) -> lastArrivalTSPointer (*uint64)
lastArrivalTSMap sync.Map
quit chan struct{}
}

// NewPdOracle create an Oracle that uses a pd client source.
Expand Down Expand Up @@ -77,9 +78,7 @@ func (o *pdOracle) GetTimestamp(ctx context.Context, opt *oracle.Option) (uint64
if err != nil {
return 0, errors.Trace(err)
}
tsArrival := o.getArrivalTimestamp(ctx)
o.setLastTS(ts, opt.TxnScope)
o.setLastArrivalTS(tsArrival)
return ts, nil
}

Expand Down Expand Up @@ -134,7 +133,7 @@ func (o *pdOracle) getTimestamp(ctx context.Context, txnScope string) (uint64, e
return oracle.ComposeTS(physical, logical), nil
}

func (o *pdOracle) getArrivalTimestamp(ctx context.Context) uint64 {
func (o *pdOracle) getArrivalTimestamp() uint64 {
return oracle.ComposeTS(oracle.GetPhysical(time.Now()), 0)
}

Expand All @@ -147,6 +146,27 @@ func (o *pdOracle) setLastTS(ts uint64, txnScope string) {
lastTSInterface, _ = o.lastTSMap.LoadOrStore(txnScope, new(uint64))
}
lastTSPointer := lastTSInterface.(*uint64)
for {
lastTS := atomic.LoadUint64(lastTSPointer)
if ts <= lastTS {
return
}
if atomic.CompareAndSwapUint64(lastTSPointer, lastTS, ts) {
break
}
}
o.setLastArrivalTS(o.getArrivalTimestamp(), txnScope)
}

func (o *pdOracle) setLastArrivalTS(ts uint64, txnScope string) {
if txnScope == "" {
txnScope = oracle.GlobalTxnScope
}
lastTSInterface, ok := o.lastArrivalTSMap.Load(txnScope)
if !ok {
lastTSInterface, _ = o.lastArrivalTSMap.LoadOrStore(txnScope, new(uint64))
}
lastTSPointer := lastTSInterface.(*uint64)
for {
lastTS := atomic.LoadUint64(lastTSPointer)
if ts <= lastTS {
Expand All @@ -169,6 +189,17 @@ func (o *pdOracle) getLastTS(txnScope string) (uint64, bool) {
return atomic.LoadUint64(lastTSInterface.(*uint64)), true
}

func (o *pdOracle) getLastArrivalTS(txnScope string) (uint64, bool) {
if txnScope == "" {
txnScope = oracle.GlobalTxnScope
}
lastArrivalTSInterface, ok := o.lastArrivalTSMap.Load(txnScope)
if !ok {
return 0, false
}
return atomic.LoadUint64(lastArrivalTSInterface.(*uint64)), true
}

func (o *pdOracle) updateTS(ctx context.Context, interval time.Duration) {
ticker := time.NewTicker(interval)
defer ticker.Stop()
Expand Down Expand Up @@ -239,16 +270,16 @@ func (o *pdOracle) GetLowResolutionTimestampAsync(ctx context.Context, opt *orac
}
}

func (o *pdOracle) getStaleTimestamp(ctx context.Context, prevSecond uint64) (uint64, error) {
ts, ok := o.getLastTS(oracle.GlobalTxnScope)
func (o *pdOracle) getStaleTimestamp(txnScope string, prevSecond uint64) (uint64, error) {
ts, ok := o.getLastTS(txnScope)
if !ok {
return 0, errors.Errorf("get stale timestamp fail, invalid txnScope = %s", oracle.GlobalTxnScope)
}
tsArrival, ok := o.getLastArrivalTS()
arrivalTS, ok := o.getLastArrivalTS(txnScope)
if !ok {
return 0, errors.Errorf("get last arrival timestamp fail, invalid txnScope = %s", oracle.GlobalTxnScope)
return 0, errors.Errorf("get stale arrival timestamp fail, invalid txnScope = %s", oracle.GlobalTxnScope)
}
arrivalTime := oracle.GetTimeFromTS(tsArrival)
arrivalTime := oracle.GetTimeFromTS(arrivalTS)
physicalTime := oracle.GetTimeFromTS(ts)
if uint64(physicalTime.Unix()) <= prevSecond {
return 0, errors.Errorf("invalid prevSecond %v", prevSecond)
Expand All @@ -260,22 +291,10 @@ func (o *pdOracle) getStaleTimestamp(ctx context.Context, prevSecond uint64) (ui
}

// GetStaleTimestamp generate a TSO which represents for the TSO prevSecond secs ago.
func (o *pdOracle) GetStaleTimestamp(ctx context.Context, prevSecond uint64) (ts uint64, err error) {
ts, err = o.getStaleTimestamp(ctx, prevSecond)
func (o *pdOracle) GetStaleTimestamp(ctx context.Context, txnScope string, prevSecond uint64) (ts uint64, err error) {
ts, err = o.getStaleTimestamp(txnScope, prevSecond)
if err != nil {
return 0, errors.Trace(err)
}
return ts, nil
}

func (o *pdOracle) setLastArrivalTS(ts uint64) {
atomic.StoreUint64(&o.lastArrivalTS, ts)
}

func (o *pdOracle) getLastArrivalTS() (uint64, bool) {
ts := atomic.LoadUint64(&o.lastArrivalTS)
if ts > 0 {
return ts, true
}
return 0, false
}
8 changes: 3 additions & 5 deletions store/tikv/oracle/oracles/pd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,7 @@ func TestPdOracle_GetStaleTimestamp(t *testing.T) {
o := oracles.NewEmptyPDOracle()
start := time.Now()
oracles.SetEmptyPDOracleLastTs(o, oracle.ComposeTS(oracle.GetPhysical(start), 0))
oracles.SetEmptyPDOracleLastArrivalTs(o, oracle.ComposeTS(oracle.GetPhysical(start), 0))
ts, err := o.GetStaleTimestamp(context.Background(), 10)
ts, err := o.GetStaleTimestamp(context.Background(), oracle.GlobalTxnScope, 10)
if err != nil {
t.Errorf("%v\n", err)
}
Expand All @@ -54,16 +53,15 @@ func TestPdOracle_GetStaleTimestamp(t *testing.T) {
t.Errorf("stable TS have accuracy err, expect: %d +-2, obtain: %d", 10, duration)
}

_, err = o.GetStaleTimestamp(context.Background(), 1e12)
_, err = o.GetStaleTimestamp(context.Background(), oracle.GlobalTxnScope, 1e12)
if err == nil {
t.Errorf("expect exceed err but get nil")
}

for i := uint64(3); i < 1e9; i += i/100 + 1 {
start = time.Now()
oracles.SetEmptyPDOracleLastTs(o, oracle.ComposeTS(oracle.GetPhysical(start), 0))
oracles.SetEmptyPDOracleLastArrivalTs(o, oracle.ComposeTS(oracle.GetPhysical(start), 0))
ts, err = o.GetStaleTimestamp(context.Background(), i)
ts, err = o.GetStaleTimestamp(context.Background(), oracle.GlobalTxnScope, i)
if err != nil {
t.Errorf("%v\n", err)
}
Expand Down