diff --git a/pkg/client/config.go b/pkg/client/config.go index bd84852d..4c56be27 100644 --- a/pkg/client/config.go +++ b/pkg/client/config.go @@ -31,6 +31,7 @@ import ( "github.com/knadh/koanf/parsers/toml" "github.com/knadh/koanf/parsers/yaml" "github.com/knadh/koanf/providers/rawbytes" + "github.com/seata/seata-go/pkg/discovery" "github.com/seata/seata-go/pkg/datasource/sql" diff --git a/pkg/remoting/getty/rpc_client.go b/pkg/remoting/getty/rpc_client.go index 601064ef..7311876b 100644 --- a/pkg/remoting/getty/rpc_client.go +++ b/pkg/remoting/getty/rpc_client.go @@ -25,6 +25,7 @@ import ( getty "github.com/apache/dubbo-getty" gxsync "github.com/dubbogo/gost/sync" + "github.com/seata/seata-go/pkg/discovery" "github.com/seata/seata-go/pkg/protocol/codec" "github.com/seata/seata-go/pkg/remoting/config" diff --git a/pkg/remoting/loadbalance/consistent_hash_loadbalance.go b/pkg/remoting/loadbalance/consistent_hash_loadbalance.go new file mode 100644 index 00000000..626a0104 --- /dev/null +++ b/pkg/remoting/loadbalance/consistent_hash_loadbalance.go @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package loadbalance + +import ( + "crypto/md5" + "fmt" + "sort" + "sync" + + getty "github.com/apache/dubbo-getty" +) + +var ( + once sync.Once + defaultVirtualNodeNumber = 10 + consistentInstance *Consistent +) + +type Consistent struct { + sync.RWMutex + virtualNodeCount int + // consistent hashCircle + hashCircle map[int64]getty.Session + sortedHashNodes []int64 +} + +func (c *Consistent) put(key int64, session getty.Session) { + c.Lock() + defer c.Unlock() + c.hashCircle[key] = session +} + +func (c *Consistent) hash(key string) int64 { + hashByte := md5.Sum([]byte(key)) + var res int64 + for i := 0; i < 4; i++ { + res <<= 8 + res |= int64(hashByte[i]) & 0xff + } + + return res +} + +// pick get a node +func (c *Consistent) pick(sessions *sync.Map, key string) getty.Session { + hashKey := c.hash(key) + index := sort.Search(len(c.sortedHashNodes), func(i int) bool { + return c.sortedHashNodes[i] >= hashKey + }) + + if index == len(c.sortedHashNodes) { + return RandomLoadBalance(sessions, key) + } + + c.RLock() + session, ok := c.hashCircle[c.sortedHashNodes[index]] + if !ok { + c.RUnlock() + return RandomLoadBalance(sessions, key) + } + c.RUnlock() + + if session.IsClosed() { + go c.refreshHashCircle(sessions) + return c.firstKey() + } + + return session +} + +// refreshHashCircle refresh hashCircle +func (c *Consistent) refreshHashCircle(sessions *sync.Map) { + var sortedHashNodes []int64 + hashCircle := make(map[int64]getty.Session) + var session getty.Session + sessions.Range(func(key, value interface{}) bool { + session = key.(getty.Session) + for i := 0; i < defaultVirtualNodeNumber; i++ { + if !session.IsClosed() { + position := c.hash(fmt.Sprintf("%s%d", session.RemoteAddr(), i)) + hashCircle[position] = session + sortedHashNodes = append(sortedHashNodes, position) + } else { + sessions.Delete(key) + } + } + return true + }) + + // virtual node sort + sort.Slice(sortedHashNodes, func(i, j int) bool { + return sortedHashNodes[i] < sortedHashNodes[j] + }) + + c.sortedHashNodes = sortedHashNodes + c.hashCircle = hashCircle +} + +func (c *Consistent) firstKey() getty.Session { + c.RLock() + defer c.RUnlock() + + if len(c.sortedHashNodes) > 0 { + return c.hashCircle[c.sortedHashNodes[0]] + } + + return nil +} + +func newConsistenceInstance(sessions *sync.Map) *Consistent { + once.Do(func() { + consistentInstance = &Consistent{ + hashCircle: make(map[int64]getty.Session), + } + // construct hash circle + sessions.Range(func(key, value interface{}) bool { + session := key.(getty.Session) + for i := 0; i < defaultVirtualNodeNumber; i++ { + if !session.IsClosed() { + position := consistentInstance.hash(fmt.Sprintf("%s%d", session.RemoteAddr(), i)) + consistentInstance.put(position, session) + consistentInstance.sortedHashNodes = append(consistentInstance.sortedHashNodes, position) + } else { + sessions.Delete(key) + } + } + return true + }) + + // virtual node sort + sort.Slice(consistentInstance.sortedHashNodes, func(i, j int) bool { + return consistentInstance.sortedHashNodes[i] < consistentInstance.sortedHashNodes[j] + }) + }) + + return consistentInstance +} + +func ConsistentHashLoadBalance(sessions *sync.Map, xid string) getty.Session { + if consistentInstance == nil { + newConsistenceInstance(sessions) + } + + // pick a node + return consistentInstance.pick(sessions, xid) +} diff --git a/pkg/remoting/loadbalance/consistent_hash_loadbalance_test.go b/pkg/remoting/loadbalance/consistent_hash_loadbalance_test.go new file mode 100644 index 00000000..3fc5b509 --- /dev/null +++ b/pkg/remoting/loadbalance/consistent_hash_loadbalance_test.go @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package loadbalance + +import ( + "fmt" + "sync" + "testing" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + + "github.com/seata/seata-go/pkg/remoting/mock" +) + +func TestConsistentHashLoadBalance(t *testing.T) { + ctrl := gomock.NewController(t) + sessions := &sync.Map{} + + for i := 0; i < 3; i++ { + session := mock.NewMockTestSession(ctrl) + session.EXPECT().IsClosed().Return(false).AnyTimes() + session.EXPECT().RemoteAddr().AnyTimes().DoAndReturn(func() string { + return "127.0.0.1:8000" + }) + sessions.Store(session, fmt.Sprintf("session-%d", i)) + } + + result := ConsistentHashLoadBalance(sessions, "test_xid") + assert.NotNil(t, result) + assert.False(t, result.IsClosed()) + + sessions.Range(func(key, value interface{}) bool { + t.Logf("key: %v, value: %v", key, value) + return true + }) +} diff --git a/pkg/remoting/loadbalance/loadbalance.go b/pkg/remoting/loadbalance/loadbalance.go index f867793b..5704eb39 100644 --- a/pkg/remoting/loadbalance/loadbalance.go +++ b/pkg/remoting/loadbalance/loadbalance.go @@ -37,6 +37,8 @@ func Select(loadBalanceType string, sessions *sync.Map, xid string) getty.Sessio return RandomLoadBalance(sessions, xid) case xidLoadBalance: return XidLoadBalance(sessions, xid) + case consistentHashLoadBalance: + return ConsistentHashLoadBalance(sessions, xid) case leastActiveLoadBalance: return LeastActiveLoadBalance(sessions, xid) case roundRobinLoadBalance: diff --git a/pkg/remoting/loadbalance/random_loadbalance_test.go b/pkg/remoting/loadbalance/random_loadbalance_test.go index 5db9c882..e63a74cb 100644 --- a/pkg/remoting/loadbalance/random_loadbalance_test.go +++ b/pkg/remoting/loadbalance/random_loadbalance_test.go @@ -23,8 +23,9 @@ import ( "testing" "github.com/golang/mock/gomock" - "github.com/seata/seata-go/pkg/remoting/mock" "github.com/stretchr/testify/assert" + + "github.com/seata/seata-go/pkg/remoting/mock" ) func TestRandomLoadBalance_Normal(t *testing.T) { diff --git a/pkg/remoting/loadbalance/xid_loadbalance_test.go b/pkg/remoting/loadbalance/xid_loadbalance_test.go index d361f338..cd47cdd8 100644 --- a/pkg/remoting/loadbalance/xid_loadbalance_test.go +++ b/pkg/remoting/loadbalance/xid_loadbalance_test.go @@ -22,8 +22,9 @@ import ( "testing" "github.com/golang/mock/gomock" - "github.com/seata/seata-go/pkg/remoting/mock" "github.com/stretchr/testify/assert" + + "github.com/seata/seata-go/pkg/remoting/mock" ) func TestXidLoadBalance(t *testing.T) {