From 3f7ea45599caa77d4277a455be1e448fd61653f5 Mon Sep 17 00:00:00 2001 From: JmPotato Date: Wed, 21 Jun 2023 17:31:13 +0800 Subject: [PATCH] client: fix the keyspace ID RW race inside tsoServiceDiscovery (#6657) ref tikv/pd#5895 Fix the keyspace ID RW race inside `tsoServiceDiscovery`. Signed-off-by: JmPotato Co-authored-by: ti-chi-bot[bot] <108142056+ti-chi-bot[bot]@users.noreply.github.com> --- client/tso_service_discovery.go | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/client/tso_service_discovery.go b/client/tso_service_discovery.go index e3c67a92fa66..cee079634e9a 100644 --- a/client/tso_service_discovery.go +++ b/client/tso_service_discovery.go @@ -20,6 +20,7 @@ import ( "reflect" "strings" "sync" + "sync/atomic" "time" "github.com/gogo/protobuf/proto" @@ -121,7 +122,7 @@ type tsoServiceDiscovery struct { metacli MetaStorageClient apiSvcDiscovery ServiceDiscovery clusterID uint64 - keyspaceID uint32 + keyspaceID atomic.Uint32 // defaultDiscoveryKey is the etcd path used for discovering the serving endpoints of // the default keyspace group @@ -165,12 +166,12 @@ func newTSOServiceDiscovery( cancel: cancel, metacli: metacli, apiSvcDiscovery: apiSvcDiscovery, - keyspaceID: keyspaceID, clusterID: clusterID, tlsCfg: tlsCfg, option: option, checkMembershipCh: make(chan struct{}, 1), } + c.keyspaceID.Store(keyspaceID) c.keyspaceGroupSD = &keyspaceGroupSvcDiscovery{ primaryAddr: "", secondaryAddrs: make([]string, 0), @@ -269,12 +270,12 @@ func (c *tsoServiceDiscovery) GetClusterID() uint64 { // GetKeyspaceID returns the ID of the keyspace func (c *tsoServiceDiscovery) GetKeyspaceID() uint32 { - return c.keyspaceID + return c.keyspaceID.Load() } // SetKeyspaceID sets the ID of the keyspace func (c *tsoServiceDiscovery) SetKeyspaceID(keyspaceID uint32) { - c.keyspaceID = keyspaceID + c.keyspaceID.Store(keyspaceID) } // GetKeyspaceGroupID returns the ID of the keyspace group. If the keyspace group is unknown, @@ -426,7 +427,7 @@ func (c *tsoServiceDiscovery) updateMember() error { var keyspaceGroup *tsopb.KeyspaceGroup if len(tsoServerAddr) > 0 { - keyspaceGroup, err = c.findGroupByKeyspaceID(c.keyspaceID, tsoServerAddr, updateMemberTimeout) + keyspaceGroup, err = c.findGroupByKeyspaceID(c.GetKeyspaceID(), tsoServerAddr, updateMemberTimeout) if err != nil { if c.tsoServerDiscovery.countFailure() { log.Error("[tso] failed to find the keyspace group", errs.ZapError(err))