Skip to content

Commit

Permalink
Merge branch 'master' into etcd3
Browse files Browse the repository at this point in the history
  • Loading branch information
ti-chi-bot authored Mar 20, 2023
2 parents 469ba03 + 220dbed commit afa43ed
Show file tree
Hide file tree
Showing 50 changed files with 867 additions and 567 deletions.
205 changes: 136 additions & 69 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,13 @@ import (
"google.golang.org/grpc/status"
)

const (
// defaultKeyspaceID is the ID of the default keyspace.
// Valid keyspace IDs lie in the range [0, 0xFFFFFF] (the uint24 maximum, i.e. 16777215).
// ID 0 is reserved for the default keyspace, named "DEFAULT". It is initialized when PD
// bootstraps and is reserved for users who have not been assigned a keyspace.
defaultKeyspaceID = uint32(0)
)

// Region contains information of a region's meta and its peers.
type Region struct {
Meta *metapb.Region
Expand Down Expand Up @@ -239,12 +246,20 @@ func WithMaxErrorRetry(count int) ClientOption {

var _ Client = (*client)(nil)

// serviceModeKeeper groups the state involved in service mode switching: the current
// mode together with the TSO client and TSO service discovery that belong to it.
type serviceModeKeeper struct {
// svcModeMutex is write-locked by setServiceMode for the whole duration of a mode switch.
svcModeMutex sync.RWMutex
// serviceMode is the mode most recently applied by setServiceMode.
serviceMode pdpb.ServiceMode
// tsoClient holds the *tsoClient stored by setServiceMode; read it through getTSOClient,
// which returns nil while nothing has been stored.
tsoClient atomic.Value
// tsoSvcDiscovery is the TSO service discovery that setServiceMode creates when it
// switches to a mode other than PD service mode. Close closes it when it is non-nil.
tsoSvcDiscovery ServiceDiscovery
}

type client struct {
keyspaceID uint32
// svcDiscovery is for pd service discovery
svcDiscovery ServiceDiscovery
tsoClient *tsoClient
keyspaceID uint32
svrUrls []string
pdSvcDiscovery ServiceDiscovery
tokenDispatcher *tokenDispatcher
serviceModeKeeper

// For internal usage.
updateTokenConnectionCh chan struct{}
Expand All @@ -253,6 +268,7 @@ type client struct {
ctx context.Context
cancel context.CancelFunc
wg sync.WaitGroup
tlsCfg *tlsutil.TLSConfig
option *option
}

Expand All @@ -272,56 +288,15 @@ func NewClient(svrAddrs []string, security SecurityOption, opts ...ClientOption)
return NewClientWithContext(context.Background(), svrAddrs, security, opts...)
}

// NewClientWithContext creates a PD client with context.
// NewClientWithContext creates a PD client with context. This API uses the default keyspace id 0.
func NewClientWithContext(ctx context.Context, svrAddrs []string, security SecurityOption, opts ...ClientOption) (Client, error) {
log.Info("[pd] create pd client with endpoints", zap.Strings("pd-address", svrAddrs))
c, clientCtx, clientCancel, tlsCfg := createClient(ctx, 0, &security)
// Inject the client options.
for _, opt := range opts {
opt(c)
}

c.svcDiscovery = newPDServiceDiscovery(clientCtx, clientCancel, &c.wg, addrsToUrls(svrAddrs), tlsCfg, c.option)
c.tsoClient = newTSOClient(clientCtx, clientCancel, &c.wg, c.option, c.keyspaceID, c.svcDiscovery, c.svcDiscovery.(tsoAllocatorEventSource), &pdTSOStreamBuilderFactory{})
if err := c.setup(); err != nil {
c.cancel()
return nil, err
}
if err := c.tsoClient.setup(); err != nil {
c.cancel()
return nil, err
}
return c, nil
return NewClientWithKeyspace(ctx, defaultKeyspaceID, svrAddrs, security, opts...)
}

// NewTSOClientWithContext creates a TSO client with context.
// TODO:
// Merge NewClientWithContext with this API after we let client detect service mode provided on the server side.
// Before that, internal tools call this function to use mcs service.
func NewTSOClientWithContext(ctx context.Context, keyspaceID uint32, svrAddrs []string, security SecurityOption, opts ...ClientOption) (Client, error) {
log.Info("[tso] create tso client with endpoints", zap.Strings("pd(api)-address", svrAddrs))
c, clientCtx, clientCancel, tlsCfg := createClient(ctx, keyspaceID, &security)
// Inject the client options.
for _, opt := range opts {
opt(c)
}

c.svcDiscovery = newPDServiceDiscovery(clientCtx, clientCancel, &c.wg, addrsToUrls(svrAddrs), tlsCfg, c.option)
if err := c.setup(); err != nil {
c.cancel()
return nil, err
}

tsoSvcDiscovery := newTSOServiceDiscovery(clientCtx, clientCancel, &c.wg, MetaStorageClient(c), c.GetClusterID(c.ctx), keyspaceID, addrsToUrls(svrAddrs), tlsCfg, c.option)
c.tsoClient = newTSOClient(clientCtx, clientCancel, &c.wg, c.option, c.keyspaceID, tsoSvcDiscovery, tsoSvcDiscovery.(tsoAllocatorEventSource), &tsoTSOStreamBuilderFactory{})
if err := c.tsoClient.setup(); err != nil {
c.cancel()
return nil, err
}
return c, nil
}
// NewClientWithKeyspace creates a client with context and the specified keyspace id.
func NewClientWithKeyspace(ctx context.Context, keyspaceID uint32, svrAddrs []string, security SecurityOption, opts ...ClientOption) (Client, error) {
log.Info("[pd] create pd client with endpoints and keyspace", zap.Strings("pd-address", svrAddrs), zap.Uint32("keyspace-id", keyspaceID))

func createClient(ctx context.Context, keyspaceID uint32, security *SecurityOption) (*client, context.Context, context.CancelFunc, *tlsutil.TLSConfig) {
tlsCfg := &tlsutil.TLSConfig{
CAPath: security.CAPath,
CertPath: security.CertPath,
Expand All @@ -338,20 +313,33 @@ func createClient(ctx context.Context, keyspaceID uint32, security *SecurityOpti
ctx: clientCtx,
cancel: clientCancel,
keyspaceID: keyspaceID,
svrUrls: addrsToUrls(svrAddrs),
tlsCfg: tlsCfg,
option: newOption(),
}

return c, clientCtx, clientCancel, tlsCfg
// Inject the client options.
for _, opt := range opts {
opt(c)
}

c.pdSvcDiscovery = newPDServiceDiscovery(clientCtx, clientCancel, &c.wg, c.setServiceMode, c.svrUrls, c.tlsCfg, c.option)
if err := c.setup(); err != nil {
c.cancel()
return nil, err
}

return c, nil
}

func (c *client) setup() error {
// Init the client base.
if err := c.svcDiscovery.Init(); err != nil {
if err := c.pdSvcDiscovery.Init(); err != nil {
return err
}

// Register callbacks
c.svcDiscovery.AddServingAddrSwitchedCallback(c.scheduleUpdateTokenConnection)
c.pdSvcDiscovery.AddServingAddrSwitchedCallback(c.scheduleUpdateTokenConnection)

// Create dispatchers
c.createTokenDispatcher()
Expand All @@ -366,8 +354,13 @@ func (c *client) Close() {
c.cancel()
c.wg.Wait()

c.tsoClient.Close()
c.svcDiscovery.Close()
if tsoClient := c.getTSOClient(); tsoClient != nil {
tsoClient.Close()
}
if c.tsoSvcDiscovery != nil {
c.tsoSvcDiscovery.Close()
}
c.pdSvcDiscovery.Close()

if c.tokenDispatcher != nil {
tokenErr := errors.WithStack(errClosing)
Expand All @@ -376,6 +369,67 @@ func (c *client) Close() {
}
}

// setServiceMode switches the client to newMode. It builds a TSO client for that mode
// (plus a dedicated TSO service discovery when the mode is not PD service mode), then
// closes and replaces the ones that belonged to the previous mode.
//
// It is passed to newPDServiceDiscovery as a callback. The call is a no-op when newMode
// equals the current mode or is ServiceMode_UNKNOWN_SVC_MODE. If the new TSO service
// discovery fails to initialize, the error is logged and the current mode is kept.
func (c *client) setServiceMode(newMode pdpb.ServiceMode) {
	c.svcModeMutex.Lock()
	defer c.svcModeMutex.Unlock()

	if newMode == c.serviceMode {
		return
	}

	log.Info("changing service mode", zap.String("old-mode", pdpb.ServiceMode_name[int32(c.serviceMode)]),
		zap.String("new-mode", pdpb.ServiceMode_name[int32(newMode)]))

	if newMode == pdpb.ServiceMode_UNKNOWN_SVC_MODE {
		log.Warn("intend to switch to unknown service mode. do nothing")
		return
	}

	var (
		newTSOCli *tsoClient
		// newTSOSvcDiscovery deliberately stays nil in PD service mode, where the TSO
		// client runs on top of pdSvcDiscovery and no separate discovery exists.
		newTSOSvcDiscovery ServiceDiscovery
	)
	ctx, cancel := context.WithCancel(c.ctx)

	if newMode == pdpb.ServiceMode_PD_SVC_MODE {
		newTSOCli = newTSOClient(ctx, cancel, c.option, c.keyspaceID,
			c.pdSvcDiscovery, c.pdSvcDiscovery.(tsoAllocatorEventSource), &pdTSOStreamBuilderFactory{})
		newTSOCli.Setup()
	} else {
		newTSOSvcDiscovery = newTSOServiceDiscovery(ctx, cancel, MetaStorageClient(c),
			c.GetClusterID(c.ctx), c.keyspaceID, c.svrUrls, c.tlsCfg, c.option)
		newTSOCli = newTSOClient(ctx, cancel, c.option, c.keyspaceID,
			newTSOSvcDiscovery, newTSOSvcDiscovery.(tsoAllocatorEventSource), &tsoTSOStreamBuilderFactory{})
		if err := newTSOSvcDiscovery.Init(); err != nil {
			// Release the context shared by the half-built client and discovery.
			cancel()
			log.Error("failed to initialize tso service discovery. keep the current service mode",
				zap.Strings("svr-urls", c.svrUrls), zap.String("current-mode", pdpb.ServiceMode_name[int32(c.serviceMode)]), zap.Error(err))
			return
		}
		newTSOCli.Setup()
	}

	// Clean up the TSO client of the previous mode.
	if oldTSOCli := c.getTSOClient(); oldTSOCli != nil {
		oldTSOCli.Close()
	}
	// Close the previous TSO service discovery, if any. Because the field is replaced
	// right below (with nil when entering PD service mode), the client never keeps a
	// closed discovery around and Close will not close the same discovery twice.
	if c.tsoSvcDiscovery != nil {
		c.tsoSvcDiscovery.Close()
	}

	c.tsoSvcDiscovery = newTSOSvcDiscovery
	c.tsoClient.Store(newTSOCli)

	log.Info("service mode changed", zap.String("old-mode", pdpb.ServiceMode_name[int32(c.serviceMode)]),
		zap.String("new-mode", pdpb.ServiceMode_name[int32(newMode)]))
	c.serviceMode = newMode
}

// getTSOClient returns the TSO client currently stored in c.tsoClient, or nil when
// no TSO client has been stored yet.
func (c *client) getTSOClient() *tsoClient {
	stored := c.tsoClient.Load()
	if stored == nil {
		return nil
	}
	return stored.(*tsoClient)
}

func (c *client) scheduleUpdateTokenConnection() {
select {
case c.updateTokenConnectionCh <- struct{}{}:
Expand All @@ -385,17 +439,17 @@ func (c *client) scheduleUpdateTokenConnection() {

// GetClusterID returns the ClusterID.
func (c *client) GetClusterID(context.Context) uint64 {
return c.svcDiscovery.GetClusterID()
return c.pdSvcDiscovery.GetClusterID()
}

// GetLeaderAddr returns the leader address.
func (c *client) GetLeaderAddr() string {
return c.svcDiscovery.GetServingAddr()
return c.pdSvcDiscovery.GetServingAddr()
}

// GetServiceDiscovery returns the client-side service discovery object
func (c *client) GetServiceDiscovery() ServiceDiscovery {
return c.svcDiscovery
return c.pdSvcDiscovery
}

// UpdateOption updates the client option.
Expand Down Expand Up @@ -443,7 +497,7 @@ func (c *client) leaderCheckLoop() {
func (c *client) checkLeaderHealth(ctx context.Context) {
ctx, cancel := context.WithTimeout(ctx, c.option.timeout)
defer cancel()
if client := c.svcDiscovery.GetServingEndpointClientConn(); client != nil {
if client := c.pdSvcDiscovery.GetServingEndpointClientConn(); client != nil {
healthCli := healthpb.NewHealthClient(client)
resp, err := healthCli.Check(ctx, &healthpb.HealthCheckRequest{Service: ""})
rpcErr, ok := status.FromError(err)
Expand Down Expand Up @@ -481,7 +535,7 @@ func (c *client) GetAllMembers(ctx context.Context) ([]*pdpb.Member, error) {

// leaderClient gets the client of current PD leader.
func (c *client) leaderClient() pdpb.PDClient {
if client := c.svcDiscovery.GetServingEndpointClientConn(); client != nil {
if client := c.pdSvcDiscovery.GetServingEndpointClientConn(); client != nil {
return pdpb.NewPDClient(client)
}
return nil
Expand All @@ -491,7 +545,7 @@ func (c *client) leaderClient() pdpb.PDClient {
// backup service endpoints randomly. Backup service endpoints are followers in a
// quorum-based cluster or secondaries in a primary/secondary configured cluster.
func (c *client) backupClientConn() (*grpc.ClientConn, string) {
addrs := c.svcDiscovery.GetBackupAddrs()
addrs := c.pdSvcDiscovery.GetBackupAddrs()
if len(addrs) < 1 {
return nil, ""
}
Expand All @@ -501,7 +555,7 @@ func (c *client) backupClientConn() (*grpc.ClientConn, string) {
)
for i := 0; i < len(addrs); i++ {
addr := addrs[rand.Intn(len(addrs))]
if cc, err = c.svcDiscovery.GetOrCreateGRPCConn(addr); err != nil {
if cc, err = c.pdSvcDiscovery.GetOrCreateGRPCConn(addr); err != nil {
continue
}
healthCtx, healthCancel := context.WithTimeout(c.ctx, c.option.timeout)
Expand Down Expand Up @@ -538,14 +592,20 @@ func (c *client) GetLocalTSAsync(ctx context.Context, dcLocation string) TSFutur
req := tsoReqPool.Get().(*tsoRequest)
req.requestCtx = ctx
req.clientCtx = c.ctx
tsoClient := c.getTSOClient()
req.start = time.Now()
req.keyspaceID = c.keyspaceID
req.dcLocation = dcLocation

if err := c.tsoClient.dispatchRequest(dcLocation, req); err != nil {
if tsoClient == nil {
req.done <- errs.ErrClientGetTSO
return req
}

if err := tsoClient.dispatchRequest(dcLocation, req); err != nil {
// Wait for a while and try again
time.Sleep(50 * time.Millisecond)
if err = c.tsoClient.dispatchRequest(dcLocation, req); err != nil {
if err = tsoClient.dispatchRequest(dcLocation, req); err != nil {
req.done <- err
}
}
Expand Down Expand Up @@ -626,7 +686,7 @@ func (c *client) GetRegionFromMember(ctx context.Context, key []byte, memberURLs

var resp *pdpb.GetRegionResponse
for _, url := range memberURLs {
conn, err := c.svcDiscovery.GetOrCreateGRPCConn(url)
conn, err := c.pdSvcDiscovery.GetOrCreateGRPCConn(url)
if err != nil {
log.Error("[pd] can't get grpc connection", zap.String("member-URL", url), errs.ZapError(err))
continue
Expand All @@ -647,7 +707,7 @@ func (c *client) GetRegionFromMember(ctx context.Context, key []byte, memberURLs

if resp == nil {
cmdFailDurationGetRegion.Observe(time.Since(start).Seconds())
c.svcDiscovery.ScheduleCheckMemberChanged()
c.pdSvcDiscovery.ScheduleCheckMemberChanged()
errorMsg := fmt.Sprintf("[pd] can't get region info from member URLs: %+v", memberURLs)
return nil, errors.WithStack(errors.New(errorMsg))
}
Expand Down Expand Up @@ -1044,7 +1104,7 @@ func (c *client) SplitRegions(ctx context.Context, splitKeys [][]byte, opts ...R

func (c *client) requestHeader() *pdpb.RequestHeader {
return &pdpb.RequestHeader{
ClusterId: c.svcDiscovery.GetClusterID(),
ClusterId: c.pdSvcDiscovery.GetClusterID(),
}
}

Expand Down Expand Up @@ -1096,6 +1156,9 @@ func addrsToUrls(addrs []string) []string {

// IsLeaderChange reports whether err indicates a leader change.
//
// It returns true when err is errs.ErrClientTSOStreamClosed, or when the error message
// contains errs.NotLeaderErr or errs.MismatchLeaderErr. A nil err is never a leader change.
func IsLeaderChange(err error) bool {
	// Guard against nil: calling Error on a nil error value would panic.
	if err == nil {
		return false
	}
	if err == errs.ErrClientTSOStreamClosed {
		return true
	}
	errMsg := err.Error()
	return strings.Contains(errMsg, errs.NotLeaderErr) || strings.Contains(errMsg, errs.MismatchLeaderErr)
}
Expand Down Expand Up @@ -1237,7 +1300,7 @@ func (c *client) respForErr(observer prometheus.Observer, start time.Time, err e
if err != nil || header.GetError() != nil {
observer.Observe(time.Since(start).Seconds())
if err != nil {
c.svcDiscovery.ScheduleCheckMemberChanged()
c.pdSvcDiscovery.ScheduleCheckMemberChanged()
return errors.WithStack(err)
}
return errors.WithStack(errors.New(header.GetError().String()))
Expand All @@ -1248,5 +1311,9 @@ func (c *client) respForErr(observer prometheus.Observer, start time.Time, err e
// GetTSOAllocators returns {dc-location -> TSO allocator leader URL} connection map
// For test only.
func (c *client) GetTSOAllocators() *sync.Map {
return c.tsoClient.GetTSOAllocators()
tsoClient := c.getTSOClient()
if tsoClient == nil {
return nil
}
return tsoClient.GetTSOAllocators()
}
21 changes: 12 additions & 9 deletions client/errs/errno.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,18 @@ const (

// client errors
var (
ErrClientGetProtoClient = errors.Normalize("failed to get proto client, %s", errors.RFCCodeText("PD:client:ErrClientGetProtoClient"))
ErrClientCreateTSOStream = errors.Normalize("create TSO stream failed, %s", errors.RFCCodeText("PD:client:ErrClientCreateTSOStream"))
ErrClientGetTSOTimeout = errors.Normalize("get TSO timeout", errors.RFCCodeText("PD:client:ErrClientGetTSOTimeout"))
ErrClientGetTSO = errors.Normalize("get TSO failed, %v", errors.RFCCodeText("PD:client:ErrClientGetTSO"))
ErrClientGetLeader = errors.Normalize("get leader from %v error", errors.RFCCodeText("PD:client:ErrClientGetLeader"))
ErrClientGetMember = errors.Normalize("get member failed", errors.RFCCodeText("PD:client:ErrClientGetMember"))
ErrClientUpdateMember = errors.Normalize("update member failed, %v", errors.RFCCodeText("PD:client:ErrUpdateMember"))
ErrClientProtoUnmarshal = errors.Normalize("failed to unmarshal proto", errors.RFCCodeText("PD:proto:ErrClientProtoUnmarshal"))
ErrClientGetMultiResponse = errors.Normalize("get invalid value response %v, must only one", errors.RFCCodeText("PD:client:ErrClientGetMultiResponse"))
ErrClientGetProtoClient = errors.Normalize("failed to get proto client", errors.RFCCodeText("PD:client:ErrClientGetProtoClient"))
ErrClientCreateTSOStream = errors.Normalize("create TSO stream failed, %s", errors.RFCCodeText("PD:client:ErrClientCreateTSOStream"))
ErrClientTSOStreamClosed = errors.Normalize("encountered TSO stream being closed unexpectedly", errors.RFCCodeText("PD:client:ErrClientTSOStreamClosed"))
ErrClientGetTSOTimeout = errors.Normalize("get TSO timeout", errors.RFCCodeText("PD:client:ErrClientGetTSOTimeout"))
ErrClientGetTSO = errors.Normalize("get TSO failed, %v", errors.RFCCodeText("PD:client:ErrClientGetTSO"))
ErrClientGetLeader = errors.Normalize("get leader from %v error", errors.RFCCodeText("PD:client:ErrClientGetLeader"))
ErrClientGetMember = errors.Normalize("get member failed", errors.RFCCodeText("PD:client:ErrClientGetMember"))
ErrClientGetClusterInfo = errors.Normalize("get cluster info failed", errors.RFCCodeText("PD:client:ErrClientGetClusterInfo"))
ErrClientUpdateMember = errors.Normalize("update member failed, %v", errors.RFCCodeText("PD:client:ErrUpdateMember"))
ErrClientProtoUnmarshal = errors.Normalize("failed to unmarshal proto", errors.RFCCodeText("PD:proto:ErrClientProtoUnmarshal"))
ErrClientGetMultiResponse = errors.Normalize("get invalid value response %v, must only one", errors.RFCCodeText("PD:client:ErrClientGetMultiResponse"))
ErrClientGetServingEndpoint = errors.Normalize("get serving endpoint failed", errors.RFCCodeText("PD:client:ErrClientGetServingEndpoint"))
)

// grpcutil errors
Expand Down
Loading

0 comments on commit afa43ed

Please sign in to comment.