diff --git a/dubbo-cluster/src/main/java/org/apache/dubbo/rpc/cluster/loadbalance/RoundRobinLoadBalance.java b/dubbo-cluster/src/main/java/org/apache/dubbo/rpc/cluster/loadbalance/RoundRobinLoadBalance.java index f3b802ef680a..e8f42cbfa41e 100644 --- a/dubbo-cluster/src/main/java/org/apache/dubbo/rpc/cluster/loadbalance/RoundRobinLoadBalance.java +++ b/dubbo-cluster/src/main/java/org/apache/dubbo/rpc/cluster/loadbalance/RoundRobinLoadBalance.java @@ -17,68 +17,139 @@ package org.apache.dubbo.rpc.cluster.loadbalance; import org.apache.dubbo.common.URL; -import org.apache.dubbo.common.utils.AtomicPositiveInteger; import org.apache.dubbo.rpc.Invocation; import org.apache.dubbo.rpc.Invoker; -import java.util.ArrayList; +import java.util.Collection; +import java.util.Iterator; import java.util.List; +import java.util.Map; +import java.util.Map.Entry; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicLong; /** * Round robin load balance. + * + * @author jason */ public class RoundRobinLoadBalance extends AbstractLoadBalance { - public static final String NAME = "roundrobin"; + + private static int RECYCLE_PERIOD = 60000; + + protected static class WeightedRoundRobin { + private int weight; + private AtomicLong current = new AtomicLong(0); + private long lastUpdate; + public int getWeight() { + return weight; + } + public void setWeight(int weight) { + this.weight = weight; + current.set(0); + } + public long increaseCurrent() { + return current.addAndGet(weight); + } + public void sel(int total) { + current.addAndGet(-1 * total); + } + public long getLastUpdate() { + return lastUpdate; + } + public void setLastUpdate(long lastUpdate) { + this.lastUpdate = lastUpdate; + } + } - private final ConcurrentMap sequences = new ConcurrentHashMap(); - - private final ConcurrentMap indexSeqs = new ConcurrentHashMap(); - + private ConcurrentMap> methodWeightMap = new ConcurrentHashMap>(); + private AtomicBoolean updateLock = new AtomicBoolean(); + + /** + * get invoker addr list cached for specified invocation + *

+ * for unit test only + * + * @param invokers + * @param invocation + * @return + */ + protected Collection getInvokerAddrList(List> invokers, Invocation invocation) { + String key = invokers.get(0).getUrl().getServiceKey() + "." + invocation.getMethodName(); + Map map = methodWeightMap.get(key); + if (map != null) { + return map.keySet(); + } + return null; + } + @Override protected Invoker doSelect(List> invokers, URL url, Invocation invocation) { String key = invokers.get(0).getUrl().getServiceKey() + "." + invocation.getMethodName(); - int length = invokers.size(); // Number of invokers - int maxWeight = 0; // The maximum weight - int minWeight = Integer.MAX_VALUE; // The minimum weight - final List> nonZeroWeightedInvokers = new ArrayList<>(); - for (int i = 0; i < length; i++) { - int weight = getWeight(invokers.get(i), invocation); - maxWeight = Math.max(maxWeight, weight); // Choose the maximum weight - minWeight = Math.min(minWeight, weight); // Choose the minimum weight - if (weight > 0) { - nonZeroWeightedInvokers.add(invokers.get(i)); - } - } - AtomicPositiveInteger sequence = sequences.get(key); - if (sequence == null) { - sequences.putIfAbsent(key, new AtomicPositiveInteger()); - sequence = sequences.get(key); + ConcurrentMap map = methodWeightMap.get(key); + if (map == null) { + methodWeightMap.putIfAbsent(key, new ConcurrentHashMap()); + map = methodWeightMap.get(key); } - - if (maxWeight > 0 && minWeight < maxWeight) { - AtomicPositiveInteger indexSeq = indexSeqs.get(key); - if (indexSeq == null) { - indexSeqs.putIfAbsent(key, new AtomicPositiveInteger(-1)); - indexSeq = indexSeqs.get(key); + int totalWeight = 0; + long maxCurrent = Long.MIN_VALUE; + long now = System.currentTimeMillis(); + Invoker selectedInvoker = null; + WeightedRoundRobin selectedWRR = null; + for (Invoker invoker : invokers) { + String identifyString = invoker.getUrl().toIdentityString(); + WeightedRoundRobin weightedRoundRobin = map.get(identifyString); + int weight = getWeight(invoker, invocation); + if (weight < 0) { + weight = 0; } - length = nonZeroWeightedInvokers.size(); - while (true) { - int index = indexSeq.incrementAndGet() % length; - int currentWeight; - if (index == 0) { - currentWeight = sequence.incrementAndGet() % maxWeight; - } else { - currentWeight = sequence.get() % maxWeight; - } - if (getWeight(nonZeroWeightedInvokers.get(index), invocation) > currentWeight) { - return nonZeroWeightedInvokers.get(index); + if (weightedRoundRobin == null) { + weightedRoundRobin = new WeightedRoundRobin(); + weightedRoundRobin.setWeight(weight); + map.putIfAbsent(identifyString, weightedRoundRobin); + weightedRoundRobin = map.get(identifyString); + } + if (weight != weightedRoundRobin.getWeight()) { + //weight changed + weightedRoundRobin.setWeight(weight); + } + long cur = weightedRoundRobin.increaseCurrent(); + weightedRoundRobin.setLastUpdate(now); + if (cur > maxCurrent) { + maxCurrent = cur; + selectedInvoker = invoker; + selectedWRR = weightedRoundRobin; + } + totalWeight += weight; + } + if (!updateLock.get() && invokers.size() != map.size()) { + if (updateLock.compareAndSet(false, true)) { + try { + // copy -> modify -> update reference + ConcurrentMap newMap = new ConcurrentHashMap(); + newMap.putAll(map); + Iterator> it = newMap.entrySet().iterator(); + while (it.hasNext()) { + Entry item = it.next(); + if (now - item.getValue().getLastUpdate() > RECYCLE_PERIOD) { + it.remove(); + } + } + methodWeightMap.put(key, newMap); + } finally { + updateLock.set(false); } } } - // Round robin - return invokers.get(sequence.getAndIncrement() % length); + if (selectedInvoker != null) { + selectedWRR.sel(totalWeight); + return selectedInvoker; + } + // should not happen here + return invokers.get(0); } + } diff --git a/dubbo-cluster/src/test/java/org/apache/dubbo/rpc/cluster/StickyTest.java b/dubbo-cluster/src/test/java/org/apache/dubbo/rpc/cluster/StickyTest.java index ded3ad6af694..869ac524464c 100644 --- a/dubbo-cluster/src/test/java/org/apache/dubbo/rpc/cluster/StickyTest.java +++ b/dubbo-cluster/src/test/java/org/apache/dubbo/rpc/cluster/StickyTest.java @@ -114,12 +114,12 @@ public int testSticky(String methodName, boolean check) { given(invoker1.invoke(invocation)).willReturn(result); given(invoker1.isAvailable()).willReturn(true); - given(invoker1.getUrl()).willReturn(url); + given(invoker1.getUrl()).willReturn(url.setPort(1)); given(invoker1.getInterface()).willReturn(StickyTest.class); given(invoker2.invoke(invocation)).willReturn(result); given(invoker2.isAvailable()).willReturn(true); - given(invoker2.getUrl()).willReturn(url); + given(invoker2.getUrl()).willReturn(url.setPort(2)); given(invoker2.getInterface()).willReturn(StickyTest.class); invocation.setMethodName(methodName); diff --git a/dubbo-cluster/src/test/java/org/apache/dubbo/rpc/cluster/loadbalance/LoadBalanceBaseTest.java b/dubbo-cluster/src/test/java/org/apache/dubbo/rpc/cluster/loadbalance/LoadBalanceBaseTest.java index f9db9aeca091..58cd86d04470 100644 --- a/dubbo-cluster/src/test/java/org/apache/dubbo/rpc/cluster/loadbalance/LoadBalanceBaseTest.java +++ b/dubbo-cluster/src/test/java/org/apache/dubbo/rpc/cluster/loadbalance/LoadBalanceBaseTest.java @@ -29,6 +29,8 @@ import org.junit.BeforeClass; import org.junit.Test; +import com.alibaba.fastjson.JSON; + import java.util.ArrayList; import java.util.List; import java.util.Map; @@ -114,16 +116,21 @@ public void setUp() throws Exception { public Map getInvokeCounter(int runs, String loadbalanceName) { Map counter = new ConcurrentHashMap(); - LoadBalance lb = ExtensionLoader.getExtensionLoader(LoadBalance.class).getExtension(loadbalanceName); + LoadBalance lb = getLoadBalance(loadbalanceName); for (Invoker invoker : invokers) { counter.put(invoker, new AtomicLong(0)); } + URL url = invokers.get(0).getUrl(); for (int i = 0; i < runs; i++) { - Invoker sinvoker = lb.select(invokers, invokers.get(0).getUrl(), invocation); + Invoker sinvoker = lb.select(invokers, url, invocation); counter.get(sinvoker).incrementAndGet(); } return counter; } + + protected AbstractLoadBalance getLoadBalance(String loadbalanceName) { + return (AbstractLoadBalance) ExtensionLoader.getExtensionLoader(LoadBalance.class).getExtension(loadbalanceName); + } @Test public void testLoadBalanceWarmup() { @@ -153,44 +160,83 @@ private static int calculateDefaultWarmupWeight(int uptime) { } /*------------------------------------test invokers for weight---------------------------------------*/ + + protected static class InvokeResult { + private AtomicLong count = new AtomicLong(); + private int weight = 0; + private int totalWeight = 0; + + public InvokeResult(int weight) { + this.weight = weight; + } + + public AtomicLong getCount() { + return count; + } + + public int getWeight() { + return weight; + } + + public int getTotalWeight() { + return totalWeight; + } + + public void setTotalWeight(int totalWeight) { + this.totalWeight = totalWeight; + } + + public int getExpected(int runCount) { + return getWeight() * runCount / getTotalWeight(); + } + + public float getDeltaPercentage(int runCount) { + int expected = getExpected(runCount); + return Math.abs((expected - getCount().get()) * 100.0f / expected); + } + + @Override + public String toString() { + return JSON.toJSONString(this); + } + } protected List> weightInvokers = new ArrayList>(); protected Invoker weightInvoker1; protected Invoker weightInvoker2; protected Invoker weightInvoker3; + protected Invoker weightInvokerTmp; @Before public void before() throws Exception { weightInvoker1 = mock(Invoker.class); weightInvoker2 = mock(Invoker.class); weightInvoker3 = mock(Invoker.class); + weightInvokerTmp = mock(Invoker.class); weightTestInvocation = new RpcInvocation(); weightTestInvocation.setMethodName("test"); - URL url1 = URL.valueOf("test1://0:1/DemoService"); - url1 = url1.addParameter(Constants.WEIGHT_KEY, 1); - url1 = url1.addParameter(weightTestInvocation.getMethodName() + "." + Constants.WEIGHT_KEY, 1); - url1 = url1.addParameter("active", 0); - - URL url2 = URL.valueOf("test2://0:9/DemoService"); - url2 = url2.addParameter(Constants.WEIGHT_KEY, 9); - url2 = url2.addParameter(weightTestInvocation.getMethodName() + "." + Constants.WEIGHT_KEY, 9); - url2 = url2.addParameter("active", 0); - - URL url3 = URL.valueOf("test3://1:6/DemoService"); - url3 = url3.addParameter(Constants.WEIGHT_KEY, 6); - url3 = url3.addParameter(weightTestInvocation.getMethodName() + "." + Constants.WEIGHT_KEY, 6); - url3 = url3.addParameter("active", 1); + URL url1 = URL.valueOf("test1://127.0.0.1:11/DemoService?weight=1&active=0"); + URL url2 = URL.valueOf("test2://127.0.0.1:12/DemoService?weight=9&active=0"); + URL url3 = URL.valueOf("test3://127.0.0.1:13/DemoService?weight=6&active=1"); + URL urlTmp = URL.valueOf("test4://127.0.0.1:9999/DemoService?weight=11&active=0"); given(weightInvoker1.isAvailable()).willReturn(true); + given(weightInvoker1.getInterface()).willReturn(LoadBalanceBaseTest.class); given(weightInvoker1.getUrl()).willReturn(url1); - + given(weightInvoker2.isAvailable()).willReturn(true); + given(weightInvoker2.getInterface()).willReturn(LoadBalanceBaseTest.class); given(weightInvoker2.getUrl()).willReturn(url2); - + given(weightInvoker3.isAvailable()).willReturn(true); + given(weightInvoker3.getInterface()).willReturn(LoadBalanceBaseTest.class); given(weightInvoker3.getUrl()).willReturn(url3); + + given(weightInvokerTmp.isAvailable()).willReturn(true); + given(weightInvokerTmp.getInterface()).willReturn(LoadBalanceBaseTest.class); + given(weightInvokerTmp.getUrl()).willReturn(urlTmp); weightInvokers.add(weightInvoker1); weightInvokers.add(weightInvoker2); @@ -203,4 +249,25 @@ public void before() throws Exception { // weightTestRpcStatus3 active is 1 RpcStatus.beginCount(weightInvoker3.getUrl(), weightTestInvocation.getMethodName()); } + + protected Map getWeightedInvokeResult(int runs, String loadbalanceName) { + Map counter = new ConcurrentHashMap(); + AbstractLoadBalance lb = getLoadBalance(loadbalanceName); + int totalWeight = 0; + for (int i = 0; i < weightInvokers.size(); i ++) { + InvokeResult invokeResult = new InvokeResult(lb.getWeight(weightInvokers.get(i), weightTestInvocation)); + counter.put(weightInvokers.get(i), invokeResult); + totalWeight += invokeResult.getWeight(); + } + for (InvokeResult invokeResult : counter.values()) { + invokeResult.setTotalWeight(totalWeight); + } + URL url = weightInvokers.get(0).getUrl(); + for (int i = 0; i < runs; i++) { + Invoker sinvoker = lb.select(weightInvokers, url, weightTestInvocation); + counter.get(sinvoker).getCount().incrementAndGet(); + } + return counter; + } + } \ No newline at end of file diff --git a/dubbo-cluster/src/test/java/org/apache/dubbo/rpc/cluster/loadbalance/RoundRobinLoadBalanceTest.java b/dubbo-cluster/src/test/java/org/apache/dubbo/rpc/cluster/loadbalance/RoundRobinLoadBalanceTest.java index e10f69fea62e..5242f90badaf 100644 --- a/dubbo-cluster/src/test/java/org/apache/dubbo/rpc/cluster/loadbalance/RoundRobinLoadBalanceTest.java +++ b/dubbo-cluster/src/test/java/org/apache/dubbo/rpc/cluster/loadbalance/RoundRobinLoadBalanceTest.java @@ -20,10 +20,29 @@ import org.junit.Assert; import org.junit.Test; +import java.lang.reflect.Field; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; import java.util.Map; +import java.util.Map.Entry; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicLong; public class RoundRobinLoadBalanceTest extends LoadBalanceBaseTest { + + private void assertStrictWRRResult(int loop, Map resultMap) { + int invokeCount = 0; + for (InvokeResult invokeResult : resultMap.values()) { + int count = (int) invokeResult.getCount().get(); + // Because it's a strictly round robin, so the abs delta should be < 10 too + Assert.assertTrue("delta with expected count should < 10", + Math.abs(invokeResult.getExpected(loop) - count) < 10); + invokeCount += count; + } + Assert.assertEquals("select failed!", invokeCount, loop); + } + @Test public void testRoundRobinLoadBalanceSelect() { int runs = 10000; @@ -36,33 +55,114 @@ public void testRoundRobinLoadBalanceSelect() { @Test public void testSelectByWeight() { - int sumInvoker1 = 0; - int sumInvoker2 = 0; - int sumInvoker3 = 0; - int loop = 10000; - - RoundRobinLoadBalance lb = new RoundRobinLoadBalance(); - for (int i = 0; i < loop; i++) { - Invoker selected = lb.select(weightInvokers, null, weightTestInvocation); - - if (selected.getUrl().getProtocol().equals("test1")) { - sumInvoker1++; - } - - if (selected.getUrl().getProtocol().equals("test2")) { - sumInvoker2++; + final Map totalMap = new HashMap(); + final AtomicBoolean shouldBegin = new AtomicBoolean(false); + final int runs = 10000; + List threads = new ArrayList(); + int threadNum = 10; + for (int i = 0; i < threadNum; i ++) { + threads.add(new Thread() { + @Override + public void run() { + while (!shouldBegin.get()) { + try { + sleep(5); + } catch (InterruptedException e) { + } + } + Map resultMap = getWeightedInvokeResult(runs, RoundRobinLoadBalance.NAME); + synchronized (totalMap) { + for (Entry entry : resultMap.entrySet()) { + if (!totalMap.containsKey(entry.getKey())) { + totalMap.put(entry.getKey(), entry.getValue()); + } else { + totalMap.get(entry.getKey()).getCount().addAndGet(entry.getValue().getCount().get()); + } + } + } + } + }); + } + for (Thread thread : threads) { + thread.start(); + } + // let's rock it! + shouldBegin.set(true); + for (Thread thread : threads) { + try { + thread.join(); + } catch (InterruptedException e) { } - - if (selected.getUrl().getProtocol().equals("test3")) { - sumInvoker3++; + } + assertStrictWRRResult(runs * threadNum, totalMap); + } + + @Test + public void testNodeCacheShouldNotRecycle() { + int loop = 10000; + //tmperately add a new invoker + weightInvokers.add(weightInvokerTmp); + try { + Map resultMap = getWeightedInvokeResult(loop, RoundRobinLoadBalance.NAME); + assertStrictWRRResult(loop, resultMap); + + // inner nodes cache judgement + RoundRobinLoadBalance lb = (RoundRobinLoadBalance)getLoadBalance(RoundRobinLoadBalance.NAME); + Assert.assertEquals(weightInvokers.size(), lb.getInvokerAddrList(weightInvokers, weightTestInvocation).size()); + + weightInvokers.remove(weightInvokerTmp); + + resultMap = getWeightedInvokeResult(loop, RoundRobinLoadBalance.NAME); + assertStrictWRRResult(loop, resultMap); + + Assert.assertNotEquals(weightInvokers.size(), lb.getInvokerAddrList(weightInvokers, weightTestInvocation).size()); + } finally { + //prevent other UT's failure + weightInvokers.remove(weightInvokerTmp); + } + } + + @Test + public void testNodeCacheShouldRecycle() { + { + Field recycleTimeField = null; + try { + //change recycle time to 1 ms + recycleTimeField = RoundRobinLoadBalance.class.getDeclaredField("RECYCLE_PERIOD"); + recycleTimeField.setAccessible(true); + recycleTimeField.setInt(RoundRobinLoadBalance.class, 10); + } catch (NoSuchFieldException e) { + Assert.assertTrue("getField failed", true); + } catch (SecurityException e) { + Assert.assertTrue("getField failed", true); + } catch (IllegalArgumentException e) { + Assert.assertTrue("getField failed", true); + } catch (IllegalAccessException e) { + Assert.assertTrue("getField failed", true); } } - - // 1 : 9 : 6 - System.out.println(sumInvoker1); - System.out.println(sumInvoker2); - System.out.println(sumInvoker3); - Assert.assertEquals("select failed!", sumInvoker1 + sumInvoker2 + sumInvoker3, loop); + + int loop = 10000; + //tmperately add a new invoker + weightInvokers.add(weightInvokerTmp); + try { + Map resultMap = getWeightedInvokeResult(loop, RoundRobinLoadBalance.NAME); + assertStrictWRRResult(loop, resultMap); + + // inner nodes cache judgement + RoundRobinLoadBalance lb = (RoundRobinLoadBalance)getLoadBalance(RoundRobinLoadBalance.NAME); + Assert.assertEquals(weightInvokers.size(), lb.getInvokerAddrList(weightInvokers, weightTestInvocation).size()); + + weightInvokers.remove(weightInvokerTmp); + + resultMap = getWeightedInvokeResult(loop, RoundRobinLoadBalance.NAME); + assertStrictWRRResult(loop, resultMap); + + Assert.assertEquals(weightInvokers.size(), lb.getInvokerAddrList(weightInvokers, weightTestInvocation).size()); + } finally { + //prevent other UT's failure + weightInvokers.remove(weightInvokerTmp); + } } - + } diff --git a/dubbo-cluster/src/test/java/org/apache/dubbo/rpc/cluster/support/AbstractClusterInvokerTest.java b/dubbo-cluster/src/test/java/org/apache/dubbo/rpc/cluster/support/AbstractClusterInvokerTest.java index 247a719ad84e..b07e123235b4 100644 --- a/dubbo-cluster/src/test/java/org/apache/dubbo/rpc/cluster/support/AbstractClusterInvokerTest.java +++ b/dubbo-cluster/src/test/java/org/apache/dubbo/rpc/cluster/support/AbstractClusterInvokerTest.java @@ -90,27 +90,27 @@ public void setUp() throws Exception { given(invoker1.isAvailable()).willReturn(false); given(invoker1.getInterface()).willReturn(IHelloService.class); - given(invoker1.getUrl()).willReturn(turl.addParameter("name", "invoker1")); + given(invoker1.getUrl()).willReturn(turl.setPort(1).addParameter("name", "invoker1")); given(invoker2.isAvailable()).willReturn(true); given(invoker2.getInterface()).willReturn(IHelloService.class); - given(invoker2.getUrl()).willReturn(turl.addParameter("name", "invoker2")); + given(invoker2.getUrl()).willReturn(turl.setPort(2).addParameter("name", "invoker2")); given(invoker3.isAvailable()).willReturn(false); given(invoker3.getInterface()).willReturn(IHelloService.class); - given(invoker3.getUrl()).willReturn(turl.addParameter("name", "invoker3")); + given(invoker3.getUrl()).willReturn(turl.setPort(3).addParameter("name", "invoker3")); given(invoker4.isAvailable()).willReturn(true); given(invoker4.getInterface()).willReturn(IHelloService.class); - given(invoker4.getUrl()).willReturn(turl.addParameter("name", "invoker4")); + given(invoker4.getUrl()).willReturn(turl.setPort(4).addParameter("name", "invoker4")); given(invoker5.isAvailable()).willReturn(false); given(invoker5.getInterface()).willReturn(IHelloService.class); - given(invoker5.getUrl()).willReturn(turl.addParameter("name", "invoker5")); + given(invoker5.getUrl()).willReturn(turl.setPort(5).addParameter("name", "invoker5")); given(mockedInvoker1.isAvailable()).willReturn(false); given(mockedInvoker1.getInterface()).willReturn(IHelloService.class); - given(mockedInvoker1.getUrl()).willReturn(turl.setProtocol("mock")); + given(mockedInvoker1.getUrl()).willReturn(turl.setPort(999).setProtocol("mock")); invokers.add(invoker1); dic = new StaticDirectory(url, invokers, null);