diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 92ef92b94..efb184704 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -33,7 +33,7 @@ jobs: - name: 'Set up Temp AWS Credentials' run: | creds=($(aws sts get-session-token \ - --duration-seconds 3600 \ + --duration-seconds 7200 \ --query 'Credentials.[AccessKeyId, SecretAccessKey, SessionToken]' \ --output text \ | xargs)); diff --git a/src/main/core-api/java/com/mysql/cj/util/CacheMap.java b/src/main/core-api/java/com/mysql/cj/util/CacheMap.java new file mode 100644 index 000000000..43e6797f4 --- /dev/null +++ b/src/main/core-api/java/com/mysql/cj/util/CacheMap.java @@ -0,0 +1,138 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License, version 2.0 + * (GPLv2), as published by the Free Software Foundation, with the + * following additional permissions: + * + * This program is distributed with certain software that is licensed + * under separate terms, as designated in a particular file or component + * or in the license documentation. Without limiting your rights under + * the GPLv2, the authors of this program hereby grant you an additional + * permission to link the program and your derivative works with the + * separately licensed software that they have included with the program. + * + * Without limiting the foregoing grant of rights under the GPLv2 and + * additional permission as to separately licensed software, this + * program is also subject to the Universal FOSS Exception, version 1.0, + * a copy of which can be found along with its FAQ at + * http://oss.oracle.com/licenses/universal-foss-exception. + * + * This program is distributed in the hope that it will be useful, but + * WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. + * See the GNU General Public License, version 2.0, for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see + * http://www.gnu.org/licenses/gpl-2.0.html. + */ + +package com.mysql.cj.util; + +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLong; + +public class CacheMap { + + private final Map> cache = new ConcurrentHashMap<>(); + private final long cleanupIntervalNanos = TimeUnit.MINUTES.toNanos(10); + private final AtomicLong cleanupTimeNanos = new AtomicLong(System.nanoTime() + cleanupIntervalNanos); + + public CacheMap() { + } + + public V get(final K key) { + CacheItem cacheItem = cache.computeIfPresent(key, (kk, vv) -> vv.isExpired() ? null : vv); + return cacheItem == null ? null : cacheItem.item; + } + + public V get(final K key, final V defaultItemValue, long itemExpirationNano) { + CacheItem cacheItem = cache.compute(key, + (kk, vv) -> (vv == null || vv.isExpired()) + ? new CacheItem<>(defaultItemValue, System.nanoTime() + itemExpirationNano) + : vv); + return cacheItem.item; + } + + public void put(final K key, final V item, long itemExpirationNano) { + cache.put(key, new CacheItem<>(item, System.nanoTime() + itemExpirationNano)); + cleanUp(); + } + + public void putIfAbsent(final K key, final V item, long itemExpirationNano) { + cache.putIfAbsent(key, new CacheItem<>(item, System.nanoTime() + itemExpirationNano)); + cleanUp(); + } + + public void remove(final K key) { + cache.remove(key); + cleanUp(); + } + + public void clear() { + cache.clear(); + } + + public int size() { return this.cache.size(); } + + private void cleanUp() { + if (this.cleanupTimeNanos.get() < System.nanoTime()) { + this.cleanupTimeNanos.set(System.nanoTime() + cleanupIntervalNanos); + cache.forEach((key, value) -> { + if (value == null || value.isExpired()) { + cache.remove(key); + } + }); + } + } + + private static class CacheItem { + final V item; + final long expirationTime; + + public CacheItem(V item, long expirationTime) { + this.item = item; + this.expirationTime = expirationTime; + } + + boolean isExpired() { + return System.nanoTime() > expirationTime; + } + + @Override + public int hashCode() { + final int prime = 31; + int result = 1; + result = prime * result + ((item == null) ? 0 : item.hashCode()); + return result; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null) { + return false; + } + if (getClass() != obj.getClass()) { + return false; + } + CacheItem other = (CacheItem) obj; + if (item == null) { + return other.item == null; + } else { + return item.equals(other.item); + } + } + + @Override + public String toString() { + return "CacheItem [item=" + item + ", expirationTime=" + expirationTime + "]"; + } + } +} \ No newline at end of file diff --git a/src/main/resources/com/mysql/cj/LocalizedErrorMessages.properties b/src/main/resources/com/mysql/cj/LocalizedErrorMessages.properties index 148ddcc43..5965393f1 100644 --- a/src/main/resources/com/mysql/cj/LocalizedErrorMessages.properties +++ b/src/main/resources/com/mysql/cj/LocalizedErrorMessages.properties @@ -710,7 +710,6 @@ ConnectionFeatureNotAvailableException.0=Feature not available in this distribut IllegalArgumentException.NullParameter=Parameter ''{0}'' must not be null. InvalidLoadBalanceStrategy=Invalid load balancing strategy ''{0}''. DefaultMonitorService.EmptyNodeKeys=Empty NodeKey set passed into DefaultMonitorService. Set should not be empty. -DefaultMonitorService.NoMonitorForContext=Can't find monitor for context passed into DefaultMonitorService. DefaultMonitorService.InvalidContext=Invalid context passed into DefaultMonitorService. Could not find any NodeKey from context. DefaultMonitorService.InvalidNodeKey=Invalid node key passed into DefaultMonitorService. No existing monitor for the given set of node keys. diff --git a/src/main/user-impl/java/com/mysql/cj/jdbc/ha/ConnectionProxy.java b/src/main/user-impl/java/com/mysql/cj/jdbc/ha/ConnectionProxy.java index abfd1e7a7..2dd980bf7 100644 --- a/src/main/user-impl/java/com/mysql/cj/jdbc/ha/ConnectionProxy.java +++ b/src/main/user-impl/java/com/mysql/cj/jdbc/ha/ConnectionProxy.java @@ -74,6 +74,7 @@ public class ConnectionProxy implements ICurrentConnectionProvider, InvocationHa protected ConnectionPluginManager pluginManager = null; private HostInfo currentHostInfo; private JdbcConnection currentConnection; + private Class currentConnectionClass; public ConnectionProxy(ConnectionUrl connectionUrl) throws SQLException { this(connectionUrl, null); @@ -98,6 +99,7 @@ public ConnectionProxy(ConnectionUrl connectionUrl, JdbcConnection connection) t throws SQLException { this.currentHostInfo = connectionUrl.getMainHost(); this.currentConnection = connection; + this.currentConnectionClass = connection == null ? null : connection.getClass(); initLogger(connectionUrl); initSettings(connectionUrl); @@ -175,11 +177,12 @@ public void setCurrentConnection(JdbcConnection connection, HostInfo info) { } this.currentConnection = connection; + this.currentConnectionClass = connection == null ? null : connection.getClass(); this.currentHostInfo = info; } @Override - public synchronized Object invoke(Object proxy, Method method, Object[] args) + public Object invoke(Object proxy, Method method, Object[] args) throws Throwable { final String methodName = method.getName(); @@ -187,24 +190,24 @@ public synchronized Object invoke(Object proxy, Method method, Object[] args) return executeMethodDirectly(methodName, args); } - Object[] argsCopy = args == null ? null : Arrays.copyOf(args, args.length); - - try { - Object result = this.pluginManager.execute( - this.currentConnection.getClass(), - methodName, - () -> method.invoke(currentConnection, args), - argsCopy); - return proxyIfReturnTypeIsJdbcInterface(method.getReturnType(), result); - } catch (Exception e) { - // Check if the captured exception must be wrapped by an unchecked exception. - Class[] declaredExceptions = method.getExceptionTypes(); - for (Class declaredException : declaredExceptions) { - if (declaredException.isAssignableFrom(e.getClass())) { - throw e; + synchronized (currentConnection) { + try { + Object result = this.pluginManager.execute( + this.currentConnectionClass, + methodName, + () -> method.invoke(currentConnection, args), + args); + return proxyIfReturnTypeIsJdbcInterface(method.getReturnType(), result); + } catch (Exception e) { + // Check if the captured exception must be wrapped by an unchecked exception. + Class[] declaredExceptions = method.getExceptionTypes(); + for (Class declaredException : declaredExceptions) { + if (declaredException.isAssignableFrom(e.getClass())) { + throw e; + } } + throw new IllegalStateException(e.getMessage(), e); } - throw new IllegalStateException(e.getMessage(), e); } } @@ -301,10 +304,12 @@ private boolean isDirectExecute(String methodName) { * Proxy class to intercept and deal with errors that may occur in any object bound to the current connection. */ class JdbcInterfaceProxy implements InvocationHandler { - Object invokeOn; + private final Object invokeOn; + private final Class invokeOnClass; JdbcInterfaceProxy(Object toInvokeOn) { this.invokeOn = toInvokeOn; + this.invokeOnClass = toInvokeOn == null ? null : toInvokeOn.getClass(); } /** @@ -329,21 +334,19 @@ private Object executeMethodDirectly(String methodName, Object[] args) { return null; } - public synchronized Object invoke(Object proxy, Method method, Object[] args) throws Throwable { + public Object invoke(Object proxy, Method method, Object[] args) throws Throwable { final String methodName = method.getName(); if (isDirectExecute(methodName)) { return executeMethodDirectly(methodName, args); } - Object[] argsCopy = args == null ? null : Arrays.copyOf(args, args.length); - - synchronized(ConnectionProxy.this) { + synchronized(this.invokeOn) { Object result = ConnectionProxy.this.pluginManager.execute( - this.invokeOn.getClass(), + this.invokeOnClass, methodName, () -> method.invoke(this.invokeOn, args), - argsCopy); + args); return proxyIfReturnTypeIsJdbcInterface(method.getReturnType(), result); } } diff --git a/src/main/user-impl/java/com/mysql/cj/jdbc/ha/plugins/DefaultMonitorService.java b/src/main/user-impl/java/com/mysql/cj/jdbc/ha/plugins/DefaultMonitorService.java index 86baa209d..c916176bc 100644 --- a/src/main/user-impl/java/com/mysql/cj/jdbc/ha/plugins/DefaultMonitorService.java +++ b/src/main/user-impl/java/com/mysql/cj/jdbc/ha/plugins/DefaultMonitorService.java @@ -38,7 +38,7 @@ import com.mysql.cj.jdbc.JdbcConnection; import com.mysql.cj.log.Log; -import java.util.Iterator; +import java.util.Collections; import java.util.Set; import java.util.concurrent.Executors; @@ -51,6 +51,8 @@ public class DefaultMonitorService implements IMonitorService { private final Log logger; final IMonitorInitializer monitorInitializer; + private Set cachedMonitorNodeKeys = null; + private IMonitor cachedMonitor = null; public DefaultMonitorService(Log logger) { this( @@ -96,11 +98,21 @@ public MonitorConnectionContext startMonitoring( throw new IllegalArgumentException(warning); } - final IMonitor monitor = getMonitor(nodeKeys, hostInfo, propertySet); + IMonitor monitor; + if (this.cachedMonitor == null + || this.cachedMonitorNodeKeys == null + || !this.cachedMonitorNodeKeys.equals(nodeKeys)) { + + monitor = getMonitor(nodeKeys, hostInfo, propertySet); + this.cachedMonitor = monitor; + this.cachedMonitorNodeKeys = Collections.unmodifiableSet(nodeKeys); + } else { + monitor = this.cachedMonitor; + } final MonitorConnectionContext context = new MonitorConnectionContext( + monitor, connectionToAbort, - nodeKeys, logger, failureDetectionTimeMillis, failureDetectionIntervalMillis, @@ -118,20 +130,8 @@ public void stopMonitoring(MonitorConnectionContext context) { return; } - context.invalidate(); - - // Any 1 node is enough to find the monitor containing the context - // All nodes will map to the same monitor - IMonitor monitor; - for (Iterator it = context.getNodeKeys().iterator(); it.hasNext();) { - String nodeKey = it.next(); - monitor = this.threadContainer.getMonitor(nodeKey); - if (monitor != null) { - monitor.stopMonitoring(context); - return; - } - } - logger.logTrace(Messages.getString("DefaultMonitorService.NoMonitorForContext")); + IMonitor monitor = context.getMonitor(); + monitor.stopMonitoring(context); } @Override diff --git a/src/main/user-impl/java/com/mysql/cj/jdbc/ha/plugins/Monitor.java b/src/main/user-impl/java/com/mysql/cj/jdbc/ha/plugins/Monitor.java index d43138511..c6e06225e 100644 --- a/src/main/user-impl/java/com/mysql/cj/jdbc/ha/plugins/Monitor.java +++ b/src/main/user-impl/java/com/mysql/cj/jdbc/ha/plugins/Monitor.java @@ -44,8 +44,6 @@ import java.util.Queue; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicLong; /** * This class uses a background thread to monitor a particular server with one or more @@ -62,23 +60,22 @@ static class ConnectionStatus { } } - static final long DEFAULT_CONNECTION_CHECK_INTERVAL_MILLIS = 100; - static final long DEFAULT_CONNECTION_CHECK_TIMEOUT_MILLIS = 3000; private static final int THREAD_SLEEP_WHEN_INACTIVE_MILLIS = 100; + private static final int MIN_CONNECTION_CHECK_TIMEOUT_MILLIS = 3000; private static final String MONITORING_PROPERTY_PREFIX = "monitoring-"; - private final Queue contexts = new ConcurrentLinkedQueue<>(); + private final Queue activeContexts = new ConcurrentLinkedQueue<>(); + private final Queue newContexts = new ConcurrentLinkedQueue<>(); private final IConnectionProvider connectionProvider; private final Log logger; private final PropertySet propertySet; private final HostInfo hostInfo; private final IMonitorService monitorService; - private final AtomicLong connectionCheckIntervalMillis = new AtomicLong(DEFAULT_CONNECTION_CHECK_INTERVAL_MILLIS); - private final AtomicLong contextLastUsedTimestampNano = new AtomicLong(); - private final AtomicBoolean stopped = new AtomicBoolean(true); - private final AtomicBoolean isConnectionCheckIntervalInitialized = new AtomicBoolean(false); + private volatile long contextLastUsedTimestampNano; + private volatile boolean stopped = false; private final long monitorDisposalTimeMillis; private Connection monitoringConn = null; + private long nodeCheckTimeoutMillis = MIN_CONNECTION_CHECK_TIMEOUT_MILLIS; /** * Store the monitoring configuration for a connection. @@ -108,7 +105,7 @@ public Monitor( this.monitorDisposalTimeMillis = monitorDisposalTimeMillis; this.monitorService = monitorService; - this.contextLastUsedTimestampNano.set(this.getCurrentTimeNano()); + this.contextLastUsedTimestampNano = this.getCurrentTimeNano(); } long getCurrentTimeNano() { @@ -117,19 +114,10 @@ long getCurrentTimeNano() { @Override public void startMonitoring(MonitorConnectionContext context) { - if (!this.isConnectionCheckIntervalInitialized.get()) { - this.connectionCheckIntervalMillis.set(context.getFailureDetectionIntervalMillis()); - this.isConnectionCheckIntervalInitialized.set(true); - } else { - this.connectionCheckIntervalMillis.set(Math.min( - this.connectionCheckIntervalMillis.get(), - context.getFailureDetectionIntervalMillis())); - } - final long currentTimeNano = this.getCurrentTimeNano(); context.setStartMonitorTimeNano(currentTimeNano); - this.contextLastUsedTimestampNano.set(currentTimeNano); - this.contexts.add(context); + this.contextLastUsedTimestampNano = currentTimeNano; + this.newContexts.add(context); } @Override @@ -139,42 +127,113 @@ public void stopMonitoring(MonitorConnectionContext context) { return; } - context.invalidate(); - this.contexts.remove(context); + context.setInactive(); - this.connectionCheckIntervalMillis.set(findShortestIntervalMillis()); - this.isConnectionCheckIntervalInitialized.set(true); + this.contextLastUsedTimestampNano = this.getCurrentTimeNano(); } public synchronized void clearContexts() { - this.contexts.clear(); - this.connectionCheckIntervalMillis.set(findShortestIntervalMillis()); - this.isConnectionCheckIntervalInitialized.set(true); + this.newContexts.clear(); + this.activeContexts.clear(); } @Override public void run() { try { - this.stopped.set(false); + this.stopped = false; while (true) { - if (!this.contexts.isEmpty()) { + + // process new contexts + MonitorConnectionContext newMonitorContext; + MonitorConnectionContext firstAddedNewMonitorContext = null; + final long currentTimeNano = this.getCurrentTimeNano(); + while ((newMonitorContext = this.newContexts.poll()) != null) { + if (firstAddedNewMonitorContext == newMonitorContext) { + // This context has already been processed. + // Add it back to the queue and process it in the next round. + this.newContexts.add(newMonitorContext); + break; + } + if (newMonitorContext.isActiveContext()) { + if (newMonitorContext.getExpectedActiveMonitoringStartTimeNano() > currentTimeNano) { + // The context active monitoring time hasn't come. + // Add the context to the queue and check it later. + this.newContexts.add(newMonitorContext); + if (firstAddedNewMonitorContext == null) { + firstAddedNewMonitorContext = newMonitorContext; + } + } else { + // It's time to start actively monitor this context. + this.activeContexts.add(newMonitorContext); + } + } + } + + if (!this.activeContexts.isEmpty()) { + final long statusCheckStartTimeNano = this.getCurrentTimeNano(); - this.contextLastUsedTimestampNano.set(statusCheckStartTimeNano); + this.contextLastUsedTimestampNano = statusCheckStartTimeNano; final ConnectionStatus status = - checkConnectionStatus(this.getConnectionCheckTimeoutMillis()); + checkConnectionStatus(this.nodeCheckTimeoutMillis); + + long delayMillis = -1; + MonitorConnectionContext monitorContext; + MonitorConnectionContext firstAddedMonitorContext = null; + + while ((monitorContext = this.activeContexts.poll()) != null) { + + synchronized (monitorContext) { + // If context is already invalid, just skip it + if (!monitorContext.isActiveContext()) { + continue; + } + + if (firstAddedMonitorContext == monitorContext) { + // this context has already been processed by this loop + // add it to the queue and exit this loop + this.activeContexts.add(monitorContext); + break; + } - for (MonitorConnectionContext monitorContext : this.contexts) { - monitorContext.updateConnectionStatus( - statusCheckStartTimeNano, - statusCheckStartTimeNano + status.elapsedTimeNano, - status.isValid); + // otherwise, process this context + monitorContext.updateConnectionStatus( + this.hostInfo.getHostPortPair(), + statusCheckStartTimeNano, + statusCheckStartTimeNano + status.elapsedTimeNano, + status.isValid); + + // If context is still valid and node is still healthy, it needs to continue updating this context + if (monitorContext.isActiveContext() && !monitorContext.isNodeUnhealthy()) { + this.activeContexts.add(monitorContext); + if (firstAddedMonitorContext == null) { + firstAddedMonitorContext = monitorContext; + } + + if (delayMillis == -1 || delayMillis > monitorContext.getFailureDetectionIntervalMillis()) { + delayMillis = monitorContext.getFailureDetectionIntervalMillis(); + } + } + } + } + + if (delayMillis == -1) { + // No active contexts + delayMillis = THREAD_SLEEP_WHEN_INACTIVE_MILLIS; + } else { + delayMillis -= status.elapsedTimeNano; + // Check for min delay between node health check + if (delayMillis < MIN_CONNECTION_CHECK_TIMEOUT_MILLIS) { + delayMillis = MIN_CONNECTION_CHECK_TIMEOUT_MILLIS; + } + // Use this delay as node checkout timeout since it corresponds to min interval for all active contexts + this.nodeCheckTimeoutMillis = delayMillis; } - TimeUnit.MILLISECONDS.sleep( - Math.max(0, this.getConnectionCheckIntervalMillis() - TimeUnit.NANOSECONDS.toMillis(status.elapsedTimeNano))); + TimeUnit.MILLISECONDS.sleep(delayMillis); + } else { - if ((this.getCurrentTimeNano() - this.contextLastUsedTimestampNano.get()) + if ((this.getCurrentTimeNano() - this.contextLastUsedTimestampNano) >= TimeUnit.MILLISECONDS.toNanos(this.monitorDisposalTimeMillis)) { monitorService.notifyUnused(this); break; @@ -192,7 +251,7 @@ public void run() { // ignore } } - this.stopped.set(true); + this.stopped = true; } } @@ -241,17 +300,9 @@ ConnectionStatus checkConnectionStatus(final long shortestFailureDetectionInterv } } - long getConnectionCheckTimeoutMillis() { - return this.connectionCheckIntervalMillis.get() == 0 ? DEFAULT_CONNECTION_CHECK_TIMEOUT_MILLIS : this.connectionCheckIntervalMillis.get(); - } - - long getConnectionCheckIntervalMillis() { - return this.connectionCheckIntervalMillis.get() == 0 ? DEFAULT_CONNECTION_CHECK_INTERVAL_MILLIS : this.connectionCheckIntervalMillis.get(); - } - @Override public boolean isStopped() { - return this.stopped.get(); + return this.stopped; } private HostInfo copy(HostInfo src, Map props) { @@ -263,12 +314,4 @@ private HostInfo copy(HostInfo src, Map props) { src.getPassword(), props); } - - private long findShortestIntervalMillis() { - long currentMin = Long.MAX_VALUE; - for (MonitorConnectionContext context : this.contexts) { - currentMin = Math.min(currentMin, context.getFailureDetectionIntervalMillis()); - } - return currentMin == Long.MAX_VALUE ? DEFAULT_CONNECTION_CHECK_INTERVAL_MILLIS : currentMin; - } } diff --git a/src/main/user-impl/java/com/mysql/cj/jdbc/ha/plugins/MonitorConnectionContext.java b/src/main/user-impl/java/com/mysql/cj/jdbc/ha/plugins/MonitorConnectionContext.java index 5f4319864..e12d4fe45 100644 --- a/src/main/user-impl/java/com/mysql/cj/jdbc/ha/plugins/MonitorConnectionContext.java +++ b/src/main/user-impl/java/com/mysql/cj/jdbc/ha/plugins/MonitorConnectionContext.java @@ -35,12 +35,8 @@ import com.mysql.cj.log.Log; import java.sql.SQLException; -import java.util.Collections; -import java.util.HashSet; -import java.util.Set; import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.locks.ReentrantLock; /** * Monitoring context for each connection. This contains each connection's criteria for whether a server should be @@ -51,22 +47,23 @@ public class MonitorConnectionContext { private final int failureDetectionTimeMillis; private final int failureDetectionCount; - private final Set nodeKeys; // Variable is never written, so it does not need to be thread-safe private final Log log; private final JdbcConnection connectionToAbort; + private final IMonitor monitor; - private final AtomicBoolean activeContext = new AtomicBoolean(true); - private final AtomicBoolean nodeUnhealthy = new AtomicBoolean(); - private final AtomicLong startMonitorTimeNano = new AtomicLong(); + private volatile boolean activeContext = true; + private volatile boolean nodeUnhealthy = false; + private long startMonitorTimeNano; + private long expectedActiveMonitoringStartTimeNano; private long invalidNodeStartTimeNano; // Only accessed by monitor thread private int failureCount; // Only accessed by monitor thread /** * Constructor. * + * @param monitor A reference to a monitor object. * @param connectionToAbort A reference to the connection associated with this context * that will be aborted in case of server failure. - * @param nodeKeys All valid references to the server. * @param log A {@link Log} implementation. * @param failureDetectionTimeMillis Grace period after which node monitoring starts. * @param failureDetectionIntervalMillis Interval between each failed connection check. @@ -74,14 +71,14 @@ public class MonitorConnectionContext { * database node as unhealthy. */ public MonitorConnectionContext( + IMonitor monitor, JdbcConnection connectionToAbort, - Set nodeKeys, Log log, int failureDetectionTimeMillis, int failureDetectionIntervalMillis, int failureDetectionCount) { + this.monitor = monitor; this.connectionToAbort = connectionToAbort; - this.nodeKeys = new HashSet<>(nodeKeys); // Variable is never written, so it does not need to be thread-safe this.log = log; this.failureDetectionTimeMillis = failureDetectionTimeMillis; this.failureDetectionIntervalMillis = failureDetectionIntervalMillis; @@ -89,11 +86,9 @@ public MonitorConnectionContext( } void setStartMonitorTimeNano(long startMonitorTimeNano) { - this.startMonitorTimeNano.set(startMonitorTimeNano); - } - - Set getNodeKeys() { - return Collections.unmodifiableSet(this.nodeKeys); + this.startMonitorTimeNano = startMonitorTimeNano; + this.expectedActiveMonitoringStartTimeNano = startMonitorTimeNano + + TimeUnit.MILLISECONDS.toNanos(this.failureDetectionTimeMillis); } public int getFailureDetectionTimeMillis() { @@ -108,7 +103,11 @@ public int getFailureDetectionCount() { return failureDetectionCount; } - public int getFailureCount() { + public long getExpectedActiveMonitoringStartTimeNano() { return this.expectedActiveMonitoringStartTimeNano; } + + public IMonitor getMonitor() { return this.monitor; } + + int getFailureCount() { return this.failureCount; } @@ -128,28 +127,29 @@ boolean isInvalidNodeStartTimeDefined() { return this.invalidNodeStartTimeNano > 0; } - public long getInvalidNodeStartTimeNano() { + long getInvalidNodeStartTimeNano() { return this.invalidNodeStartTimeNano; } - public boolean isNodeUnhealthy() { - return this.nodeUnhealthy.get(); + public boolean isNodeUnhealthy() + { + return this.nodeUnhealthy; } void setNodeUnhealthy(boolean nodeUnhealthy) { - this.nodeUnhealthy.set(nodeUnhealthy); + this.nodeUnhealthy = nodeUnhealthy; } public boolean isActiveContext() { - return this.activeContext.get(); + return this.activeContext; } - public void invalidate() { - this.activeContext.set(false); + public void setInactive() { + this.activeContext = false; } - synchronized void abortConnection() { - if (this.connectionToAbort == null || !this.activeContext.get()) { + void abortConnection() { + if (this.connectionToAbort == null || !this.activeContext) { return; } @@ -167,22 +167,25 @@ synchronized void abortConnection() { * Update whether the connection is still valid if the total elapsed time has passed the * grace period. * + * @param nodeName A node name for logging purposes. * @param statusCheckStartNano The time when connection status check started in nanoseconds. * @param statusCheckEndNano The time when connection status check ended in nanoseconds. * @param isValid Whether the connection is valid. */ public void updateConnectionStatus( + String nodeName, long statusCheckStartNano, long statusCheckEndNano, boolean isValid) { - if (!this.activeContext.get()) { + + if (!this.activeContext) { return; } - final long totalElapsedTimeNano = statusCheckEndNano - this.startMonitorTimeNano.get(); + final long totalElapsedTimeNano = statusCheckEndNano - this.startMonitorTimeNano; if (totalElapsedTimeNano > TimeUnit.MILLISECONDS.toNanos(this.failureDetectionTimeMillis)) { - this.setConnectionValid(isValid, statusCheckStartNano, statusCheckEndNano); + this.setConnectionValid(nodeName, isValid, statusCheckStartNano, statusCheckEndNano); } } @@ -197,14 +200,17 @@ public void updateConnectionStatus( *
  • {@code failureDetectionCount}
  • * * + * @param nodeName A node name for logging purposes. * @param connectionValid Boolean indicating whether the server is still responsive. * @param statusCheckStartNano The time when connection status check started in nanoseconds. * @param statusCheckEndNano The time when connection status check ended in nanoseconds. */ void setConnectionValid( + String nodeName, boolean connectionValid, long statusCheckStartNano, long statusCheckEndNano) { + if (!connectionValid) { this.failureCount++; @@ -220,7 +226,7 @@ void setConnectionValid( this.log.logTrace( String.format( "[MonitorConnectionContext] node '%s' is *dead*.", - nodeKeys)); + nodeName)); this.setNodeUnhealthy(true); this.abortConnection(); return; @@ -228,7 +234,7 @@ void setConnectionValid( this.log.logTrace(String.format( "[MonitorConnectionContext] node '%s' is not *responding* (%d).", - nodeKeys, + nodeName, this.getFailureCount())); return; } @@ -240,6 +246,6 @@ void setConnectionValid( this.log.logTrace( String.format( "[MonitorConnectionContext] node '%s' is *alive*.", - nodeKeys)); + nodeName)); } } diff --git a/src/main/user-impl/java/com/mysql/cj/jdbc/ha/plugins/NodeMonitoringConnectionPlugin.java b/src/main/user-impl/java/com/mysql/cj/jdbc/ha/plugins/NodeMonitoringConnectionPlugin.java index 9813d9ba6..da6febfe6 100644 --- a/src/main/user-impl/java/com/mysql/cj/jdbc/ha/plugins/NodeMonitoringConnectionPlugin.java +++ b/src/main/user-impl/java/com/mysql/cj/jdbc/ha/plugins/NodeMonitoringConnectionPlugin.java @@ -45,7 +45,7 @@ import java.sql.SQLException; import java.sql.Statement; import java.util.Arrays; -import java.util.List; +import java.util.HashSet; import java.util.Set; import java.util.concurrent.Callable; import java.util.concurrent.ConcurrentHashMap; @@ -59,8 +59,77 @@ public class NodeMonitoringConnectionPlugin implements IConnectionPlugin { private static final String RETRIEVE_HOST_PORT_SQL = "SELECT CONCAT(@@hostname, ':', @@port)"; - private static final List METHODS_STARTING_WITH = Arrays.asList("get", "abort"); - private static final List METHODS_EQUAL_TO = Arrays.asList("close", "next"); + private static final Set SKIP_MONITORING_METHODS = new HashSet<>(Arrays.asList( + "close", + "next", + "abort", + "closeOnCompletion", + "getName", + "getVendor", + "getVendorTypeNumber", + "getBaseTypeName", + "getBaseType", + "getBinaryStream", + "getBytes", + "getArray", + "getBigDecimal", + "getSubString", + "getCharacterStream", + "getAsciiStream", + "getURL", + "getUserName", + "getDatabaseProductName", + "getParameterCount", + "getPrecision", + "getScale", + "getParameterType", + "getParameterTypeName", + "getParameterClassName", + "getConnection", + "getFetchDirection", + "getFetchSize", + "getColumnCount", + "getColumnDisplaySize", + "getColumnLabel", + "getColumnName", + "getSchemaName", + "getSQLTypeName", + "getSavepointId", + "getSavepointName", + "getMaxFieldSize", + "getMaxRows", + "getQueryTimeout", + "getAttributes", + "getString", + "getTime", + "getTimestamp", + "getType", + "getUnicodeStream", + "getWarnings", + "getBinaryStream", + "getBlob", + "getBoolean", + "getByte", + "getBytes", + "getClob", + "getConcurrency", + "getDate", + "getDouble", + "getFloat", + "getHoldability", + "getInt", + "getLong", + "getMetaData", + "getNCharacterStream", + "getNClob", + "getNString", + "getObject", + "getRef", + "getRow", + "getRowId", + "getSQLXML", + "getShort", + "getStatement")); protected IConnectionPlugin nextPlugin; protected Log logger; @@ -181,27 +250,30 @@ public Object execute( } finally { if (monitorContext != null) { - this.monitorService.stopMonitoring(monitorContext); - - final boolean isConnectionClosed; - try { - isConnectionClosed = this.currentConnectionProvider.getCurrentConnection().isClosed(); - } catch (final SQLException e) { - throw new CJCommunicationsException("Node is unavailable."); - } - - if (monitorContext.isNodeUnhealthy()) { - if (!isConnectionClosed) { - abortConnection(); - throw new CJCommunicationsException("Node is unavailable."); + synchronized (monitorContext) { + this.monitorService.stopMonitoring(monitorContext); + + if (monitorContext.isNodeUnhealthy()) { + + final boolean isConnectionClosed; + try { + isConnectionClosed = this.currentConnectionProvider.getCurrentConnection().isClosed(); + } catch (final SQLException e) { + throw new CJCommunicationsException("Node is unavailable."); + } + + if (!isConnectionClosed) { + abortConnection(); + throw new CJCommunicationsException("Node is unavailable."); + } } } if (this.logger.isTraceEnabled()) { this.logger.logTrace(String.format( - "[NodeMonitoringConnectionPlugin.execute]: method=%s.%s, monitoring is deactivated", - methodInvokeOn.getName(), - methodName)); + "[NodeMonitoringConnectionPlugin.execute]: method=%s.%s, monitoring is deactivated", + methodInvokeOn.getName(), + methodName)); } } } @@ -231,20 +303,7 @@ protected boolean doesNeedMonitoring(Class methodInvokeOn, String methodName) // boolean isJdbcStatement = Statement.class.isAssignableFrom(methodInvokeOn); // boolean isJdbcResultSet = ResultSet.class.isAssignableFrom(methodInvokeOn); - for (final String method : METHODS_STARTING_WITH) { - if (methodName.startsWith(method)) { - return false; - } - } - - for (final String method : METHODS_EQUAL_TO) { - if (method.equals(methodName)) { - return false; - } - } - - // Monitor all the other methods - return true; + return !SKIP_MONITORING_METHODS.contains(methodName); } private void initMonitorService() { diff --git a/src/main/user-impl/java/com/mysql/cj/jdbc/ha/plugins/failover/AuroraTopologyService.java b/src/main/user-impl/java/com/mysql/cj/jdbc/ha/plugins/failover/AuroraTopologyService.java index 307c6cdca..f74e8ed12 100644 --- a/src/main/user-impl/java/com/mysql/cj/jdbc/ha/plugins/failover/AuroraTopologyService.java +++ b/src/main/user-impl/java/com/mysql/cj/jdbc/ha/plugins/failover/AuroraTopologyService.java @@ -39,7 +39,7 @@ import com.mysql.cj.jdbc.JdbcConnection; import com.mysql.cj.log.Log; import com.mysql.cj.log.NullLogger; -import com.mysql.cj.util.ExpiringCache; +import com.mysql.cj.util.CacheMap; import com.mysql.cj.util.Util; import java.sql.ResultSet; @@ -47,16 +47,15 @@ import java.sql.SQLSyntaxErrorException; import java.sql.Statement; import java.sql.Timestamp; -import java.time.Duration; -import java.time.Instant; import java.util.ArrayList; import java.util.HashMap; -import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Properties; import java.util.Set; import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.TimeUnit; import java.util.function.Supplier; /** @@ -70,7 +69,7 @@ public class AuroraTopologyService implements ITopologyService { static final int DEFAULT_REFRESH_RATE_IN_MILLISECONDS = 30000; static final int DEFAULT_CACHE_EXPIRE_MS = 5 * 60 * 1000; // 5 min - private int refreshRateInMilliseconds; + private long refreshRateNanos; static final String RETRIEVE_TOPOLOGY_SQL = "SELECT SERVER_ID, SESSION_ID, LAST_UPDATE_TIMESTAMP, REPLICA_LAG_IN_MILLISECONDS " + "FROM information_schema.replica_host_status " @@ -85,9 +84,10 @@ public class AuroraTopologyService implements ITopologyService { static final String FIELD_LAST_UPDATED = "LAST_UPDATE_TIMESTAMP"; static final String FIELD_REPLICA_LAG = "REPLICA_LAG_IN_MILLISECONDS"; - public static final ExpiringCache topologyCache = - new ExpiringCache<>(DEFAULT_CACHE_EXPIRE_MS); - private static final Object cacheLock = new Object(); + public static final CacheMap> topologyCache = new CacheMap<>(); + public static final CacheMap> downHostCache = new CacheMap<>(); + public static final CacheMap lastUsedReaderCache = new CacheMap<>(); + public static final CacheMap multiWriterClusterCache = new CacheMap<>(); protected String clusterId; protected HostInfo clusterInstanceTemplate; @@ -112,7 +112,7 @@ public AuroraTopologyService(Log log) { */ public AuroraTopologyService(int refreshRateInMilliseconds, Log log, Supplier metricsContainerSupplier) { - this.refreshRateInMilliseconds = refreshRateInMilliseconds; + this.refreshRateNanos = TimeUnit.MILLISECONDS.toNanos(refreshRateInMilliseconds); this.clusterId = UUID.randomUUID().toString(); this.clusterInstanceTemplate = new HostInfo(null, "?", HostInfo.NO_PORT, null, null); this.metricsContainer = metricsContainerSupplier.get(); @@ -122,16 +122,6 @@ public AuroraTopologyService(int refreshRateInMilliseconds, Log log, } } - /** - * Service instances with the same cluster Id share cluster topology. Shared topology is cached - * for a specified period of time. This method sets cache expiration time in millis. - * - * @param expireTimeMs Topology cache expiration time in millis - */ - public static void setExpireTime(int expireTimeMs) { - topologyCache.setExpireTime(expireTimeMs); - } - /** * Sets cluster Id for a service instance. Different service instances with the same cluster Id * share topology cache. @@ -179,7 +169,7 @@ public void setClusterInstanceTemplate(HostInfo clusterInstanceTemplate) { /** * Get cluster topology. It may require an extra call to database to fetch the latest topology. A * cached copy of topology is returned if it's not yet outdated (controlled by {@link - * #refreshRateInMilliseconds }). + * #refreshRateNanos }). * * @param conn A connection to database to fetch the latest topology, if needed. * @param forceUpdate If true, it forces a service to ignore cached copy of topology and to fetch @@ -190,30 +180,27 @@ public void setClusterInstanceTemplate(HostInfo clusterInstanceTemplate) { @Override public List getTopology(JdbcConnection conn, boolean forceUpdate) throws SQLException { - ClusterTopologyInfo clusterTopologyInfo = topologyCache.get(this.clusterId); - if (clusterTopologyInfo == null - || Util.isNullOrEmpty(clusterTopologyInfo.hosts) - || forceUpdate - || refreshNeeded(clusterTopologyInfo)) { + List hosts = topologyCache.get(this.clusterId); + + if (hosts == null || forceUpdate) { ClusterTopologyInfo latestTopologyInfo = queryForTopology(conn); - if (!Util.isNullOrEmpty(latestTopologyInfo.hosts)) { - clusterTopologyInfo = updateCache(clusterTopologyInfo, latestTopologyInfo); - } else if (clusterTopologyInfo == null - || clusterTopologyInfo.hosts == null - || forceUpdate) { - return new ArrayList<>(); + if (latestTopologyInfo != null) { + multiWriterClusterCache.put( + this.clusterId, latestTopologyInfo.getMultiWriterCluster(), this.refreshRateNanos); + + downHostCache.get(this.clusterId, ConcurrentHashMap.newKeySet(), this.refreshRateNanos).clear(); + + if (!Util.isNullOrEmpty(latestTopologyInfo.getHosts())) { + topologyCache.put(this.clusterId, latestTopologyInfo.getHosts(), this.refreshRateNanos); + return latestTopologyInfo.getHosts(); + } } } - return clusterTopologyInfo.hosts; - } - - private boolean refreshNeeded(ClusterTopologyInfo info) { - Instant lastUpdateTime = info.lastUpdated; - return lastUpdateTime == null || Duration.between(lastUpdateTime, Instant.now()).toMillis() > refreshRateInMilliseconds; + return forceUpdate ? new ArrayList<>() : hosts; } /** @@ -242,13 +229,9 @@ protected ClusterTopologyInfo queryForTopology(JdbcConnection conn) throws SQLEx } } - return topologyInfo != null ? topologyInfo - : new ClusterTopologyInfo( - new ArrayList<>(), - new HashSet<>(), - null, - Instant.now(), - false); + return topologyInfo != null + ? topologyInfo + : new ClusterTopologyInfo(new ArrayList<>(),false); } /** @@ -298,12 +281,7 @@ private ClusterTopologyInfo processQueryResults(ResultSet resultSet) hosts.clear(); } - return new ClusterTopologyInfo( - hosts, - new HashSet<>(), - null, - Instant.now(), - writerCount > 1); + return new ClusterTopologyInfo(hosts, writerCount > 1); } private HostInfo createHost(ResultSet resultSet) throws SQLException { @@ -366,34 +344,6 @@ private String convertTimestampToString(Timestamp timestamp) { return timestamp == null ? null : timestamp.toString(); } - /** - * Store the information for the topology in the cache, creating the information object if it did not previously exist - * in the cache. - * - * @param clusterTopologyInfo The cluster topology info that existed in the cache before the topology query. This parameter - * will be null if no topology info for the cluster has been created in the cache yet. - * @param latestTopologyInfo The results of the current topology query - * @return The {@link ClusterTopologyInfo} stored in the cache by this method, representing the most up-to-date - * information we have about the topology. - */ - private ClusterTopologyInfo updateCache( - ClusterTopologyInfo clusterTopologyInfo, - ClusterTopologyInfo latestTopologyInfo) { - if (clusterTopologyInfo == null) { - clusterTopologyInfo = latestTopologyInfo; - } else { - clusterTopologyInfo.hosts = latestTopologyInfo.hosts; - clusterTopologyInfo.downHosts = latestTopologyInfo.downHosts; - clusterTopologyInfo.isMultiWriterCluster = latestTopologyInfo.isMultiWriterCluster; - } - clusterTopologyInfo.lastUpdated = Instant.now(); - - synchronized (cacheLock) { - topologyCache.put(this.clusterId, clusterTopologyInfo); - } - return clusterTopologyInfo; - } - /** * Get cached topology. * @@ -402,8 +352,7 @@ private ClusterTopologyInfo updateCache( */ @Override public List getCachedTopology() { - ClusterTopologyInfo info = topologyCache.get(this.clusterId); - return info == null || refreshNeeded(info) ? null : info.hosts; + return topologyCache.get(this.clusterId); } /** @@ -414,8 +363,7 @@ public List getCachedTopology() { */ @Override public HostInfo getLastUsedReaderHost() { - ClusterTopologyInfo info = topologyCache.get(this.clusterId); - return info == null || refreshNeeded(info) ? null : info.lastUsedReader; + return lastUsedReaderCache.get(this.clusterId); } /** @@ -425,14 +373,10 @@ public HostInfo getLastUsedReaderHost() { */ @Override public void setLastUsedReaderHost(HostInfo reader) { - if (reader != null) { - synchronized (cacheLock) { - ClusterTopologyInfo info = topologyCache.get(this.clusterId); - if (info != null) { - info.lastUsedReader = reader; - } - } + if (reader == null) { + return; } + lastUsedReaderCache.put(this.clusterId, reader, this.refreshRateNanos); } /** @@ -450,9 +394,8 @@ public HostInfo getHostByName(JdbcConnection conn) { if (resultSet.next()) { instanceName = resultSet.getString(GET_INSTANCE_NAME_COL); } - ClusterTopologyInfo clusterTopologyInfo = topologyCache.get(this.clusterId); - return instanceNameToHost( - instanceName, clusterTopologyInfo == null ? null : clusterTopologyInfo.hosts); + List hosts = topologyCache.get(this.clusterId); + return instanceNameToHost(instanceName, hosts); } } catch (SQLException e) { return null; @@ -480,12 +423,7 @@ private HostInfo instanceNameToHost(String name, List hosts) { */ @Override public Set getDownHosts() { - synchronized (cacheLock) { - ClusterTopologyInfo clusterTopologyInfo = topologyCache.get(this.clusterId); - return clusterTopologyInfo != null && clusterTopologyInfo.downHosts != null - ? clusterTopologyInfo.downHosts - : new HashSet<>(); - } + return downHostCache.get(this.clusterId, ConcurrentHashMap.newKeySet(), this.refreshRateNanos); } /** @@ -498,21 +436,8 @@ public void addToDownHostList(HostInfo downHost) { if (downHost == null) { return; } - synchronized (cacheLock) { - ClusterTopologyInfo clusterTopologyInfo = topologyCache.get(this.clusterId); - if (clusterTopologyInfo == null) { - clusterTopologyInfo = new ClusterTopologyInfo( - new ArrayList<>(), - new HashSet<>(), - null, - Instant.now(), - false); - topologyCache.put(this.clusterId, clusterTopologyInfo); - } else if (clusterTopologyInfo.downHosts == null) { - clusterTopologyInfo.downHosts = new HashSet<>(); - } - clusterTopologyInfo.downHosts.add(downHost.getHostPortPair()); - } + downHostCache.get(this.clusterId, ConcurrentHashMap.newKeySet(), this.refreshRateNanos) + .add(downHost.getHostPortPair()); } /** @@ -525,12 +450,8 @@ public void removeFromDownHostList(HostInfo host) { if (host == null) { return; } - synchronized (cacheLock) { - ClusterTopologyInfo clusterTopologyInfo = topologyCache.get(this.clusterId); - if (clusterTopologyInfo != null && clusterTopologyInfo.downHosts != null) { - clusterTopologyInfo.downHosts.remove(host.getHostPortPair()); - } - } + downHostCache.get(this.clusterId, ConcurrentHashMap.newKeySet(), this.refreshRateNanos) + .remove(host.getHostPortPair()); } /** @@ -540,63 +461,48 @@ public void removeFromDownHostList(HostInfo host) { */ @Override public boolean isMultiWriterCluster() { - synchronized (cacheLock) { - ClusterTopologyInfo clusterTopologyInfo = topologyCache.get(this.clusterId); - return (clusterTopologyInfo != null - && clusterTopologyInfo.downHosts != null - && clusterTopologyInfo.isMultiWriterCluster); - } + return multiWriterClusterCache.get(this.clusterId, false, this.refreshRateNanos); } /** * Set new topology refresh rate. Different service instances may have different topology refresh * rate while sharing the same topology cache. * - * @param refreshRate Topology refresh rate in millis. + * @param refreshRateMillis Topology refresh rate in millis. */ @Override - public void setRefreshRate(int refreshRate) { - this.refreshRateInMilliseconds = refreshRate; - if (topologyCache.getExpireTime() < this.refreshRateInMilliseconds) { - synchronized (cacheLock) { - if (topologyCache.getExpireTime() < this.refreshRateInMilliseconds) { - topologyCache.setExpireTime(this.refreshRateInMilliseconds); - } - } - } + public void setRefreshRate(int refreshRateMillis) { + this.refreshRateNanos = TimeUnit.MILLISECONDS.toNanos(refreshRateMillis); } /** Clear topology cache for all clusters. */ @Override public void clearAll() { - synchronized (cacheLock) { - topologyCache.clear(); - } + topologyCache.clear(); + downHostCache.clear(); + multiWriterClusterCache.clear(); + lastUsedReaderCache.clear(); } /** Clear topology cache for the current cluster. */ @Override public void clear() { - synchronized (cacheLock) { - topologyCache.remove(this.clusterId); - } + topologyCache.remove(this.clusterId); + downHostCache.remove(this.clusterId); + multiWriterClusterCache.remove(this.clusterId); + lastUsedReaderCache.remove(this.clusterId); } private static class ClusterTopologyInfo { - public Instant lastUpdated; - public Set downHosts; - public List hosts; - public HostInfo lastUsedReader; - public boolean isMultiWriterCluster; - - ClusterTopologyInfo( - List hosts, Set downHosts, HostInfo lastUsedReader, - Instant lastUpdated, boolean isMultiWriterCluster) { + private List hosts; + private boolean isMultiWriterCluster; + + ClusterTopologyInfo(List hosts, boolean isMultiWriterCluster) { this.hosts = hosts; - this.downHosts = downHosts; - this.lastUsedReader = lastUsedReader; - this.lastUpdated = lastUpdated; this.isMultiWriterCluster = isMultiWriterCluster; } + + List getHosts() { return this.hosts; } + boolean getMultiWriterCluster() { return this.isMultiWriterCluster; } } } diff --git a/src/test/java/com/mysql/cj/jdbc/ha/plugins/DefaultMonitorServiceTest.java b/src/test/java/com/mysql/cj/jdbc/ha/plugins/DefaultMonitorServiceTest.java index d3a6cba05..3b9572ccc 100644 --- a/src/test/java/com/mysql/cj/jdbc/ha/plugins/DefaultMonitorServiceTest.java +++ b/src/test/java/com/mysql/cj/jdbc/ha/plugins/DefaultMonitorServiceTest.java @@ -31,11 +31,14 @@ package com.mysql.cj.jdbc.ha.plugins; +import static org.junit.Assert.assertFalse; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.Mockito.any; +import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.eq; @@ -176,31 +179,10 @@ void test_3_stopMonitoringWithInterruptedThread() { FAILURE_DETECTION_INTERVAL_MILLIS, FAILURE_DETECTION_COUNT); - monitorService.stopMonitoring(context); - - assertEquals(context, contextCaptor.getValue()); - verify(monitorA).stopMonitoring(any()); - } - - @Test - void test_4_stopMonitoringCalledTwice() { - doNothing().when(monitorA).stopMonitoring(contextCaptor.capture()); - - final MonitorConnectionContext context = monitorService.startMonitoring( - connection, - NODE_KEYS, - info, - propertySet, - FAILURE_DETECTION_TIME_MILLIS, - FAILURE_DETECTION_INTERVAL_MILLIS, - FAILURE_DETECTION_COUNT); - - monitorService.stopMonitoring(context); - - assertEquals(context, contextCaptor.getValue()); + assertEquals(monitorA, context.getMonitor()); monitorService.stopMonitoring(context); - verify(monitorA, times(2)).stopMonitoring(any()); + verify(monitorA, atLeastOnce()).stopMonitoring(eq(context)); } @Test diff --git a/src/test/java/com/mysql/cj/jdbc/ha/plugins/MonitorConnectionContextTest.java b/src/test/java/com/mysql/cj/jdbc/ha/plugins/MonitorConnectionContextTest.java index e3e7ef30d..d7e99c69d 100644 --- a/src/test/java/com/mysql/cj/jdbc/ha/plugins/MonitorConnectionContextTest.java +++ b/src/test/java/com/mysql/cj/jdbc/ha/plugins/MonitorConnectionContextTest.java @@ -36,6 +36,8 @@ import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import org.mockito.Mockito; import org.mockito.MockitoAnnotations; import java.util.Collections; @@ -58,8 +60,8 @@ class MonitorConnectionContextTest { void init() { closeable = MockitoAnnotations.openMocks(this); context = new MonitorConnectionContext( + Mockito.mock(IMonitor.class), null, - NODE_KEYS, new NullLogger(MonitorConnectionContextTest.class.getName()), FAILURE_DETECTION_TIME_MILLIS, FAILURE_DETECTION_INTERVAL_MILLIS, @@ -74,7 +76,7 @@ void cleanUp() throws Exception { @Test public void test_1_isNodeUnhealthyWithConnection_returnFalse() { long currentTimeNano = System.nanoTime(); - context.setConnectionValid(true, currentTimeNano, currentTimeNano); + context.setConnectionValid("test-node", true, currentTimeNano, currentTimeNano); Assertions.assertFalse(context.isNodeUnhealthy()); Assertions.assertEquals(0, this.context.getFailureCount()); } @@ -82,7 +84,7 @@ public void test_1_isNodeUnhealthyWithConnection_returnFalse() { @Test public void test_2_isNodeUnhealthyWithInvalidConnection_returnFalse() { long currentTimeNano = System.nanoTime(); - context.setConnectionValid(false, currentTimeNano, currentTimeNano); + context.setConnectionValid("test-node", false, currentTimeNano, currentTimeNano); Assertions.assertFalse(context.isNodeUnhealthy()); Assertions.assertEquals(1, this.context.getFailureCount()); } @@ -94,7 +96,7 @@ public void test_3_isNodeUnhealthyExceedsFailureDetectionCount_returnTrue() { context.resetInvalidNodeStartTime(); long currentTimeNano = System.nanoTime(); - context.setConnectionValid(false, currentTimeNano, currentTimeNano); + context.setConnectionValid("test-node", false, currentTimeNano, currentTimeNano); Assertions.assertFalse(context.isNodeUnhealthy()); Assertions.assertEquals(expectedFailureCount, context.getFailureCount()); @@ -112,7 +114,7 @@ public void test_4_isNodeUnhealthyExceedsFailureDetectionCount() { long statusCheckStartTime = currentTimeNano; long statusCheckEndTime = currentTimeNano + TimeUnit.MILLISECONDS.toNanos(VALIDATION_INTERVAL_MILLIS); - context.setConnectionValid(false, statusCheckStartTime, statusCheckEndTime); + context.setConnectionValid("test-node", false, statusCheckStartTime, statusCheckEndTime); Assertions.assertFalse(context.isNodeUnhealthy()); currentTimeNano += TimeUnit.MILLISECONDS.toNanos(VALIDATION_INTERVAL_MILLIS); @@ -125,7 +127,7 @@ public void test_4_isNodeUnhealthyExceedsFailureDetectionCount() { long statusCheckStartTime = currentTimeNano; long statusCheckEndTime = currentTimeNano + TimeUnit.MILLISECONDS.toNanos(VALIDATION_INTERVAL_MILLIS); - context.setConnectionValid(false, statusCheckStartTime, statusCheckEndTime); + context.setConnectionValid("test-node", false, statusCheckStartTime, statusCheckEndTime); Assertions.assertTrue(context.isNodeUnhealthy()); } } diff --git a/src/test/java/com/mysql/cj/jdbc/ha/plugins/MonitorTest.java b/src/test/java/com/mysql/cj/jdbc/ha/plugins/MonitorTest.java index faff73ae8..c40739a0b 100644 --- a/src/test/java/com/mysql/cj/jdbc/ha/plugins/MonitorTest.java +++ b/src/test/java/com/mysql/cj/jdbc/ha/plugins/MonitorTest.java @@ -141,56 +141,6 @@ void cleanUp() throws Exception { closeable.close(); } - @Test - void test_1_startMonitoringWithDifferentContexts() { - monitor.startMonitoring(contextWithShortInterval); - monitor.startMonitoring(contextWithLongInterval); - - assertEquals( - SHORT_INTERVAL_MILLIS, - monitor.getConnectionCheckIntervalMillis()); - verify(contextWithShortInterval) - .setStartMonitorTimeNano(anyLong()); - verify(contextWithLongInterval) - .setStartMonitorTimeNano(anyLong()); - } - - @Test - void test_2_stopMonitoringWithContextRemaining() { - monitor.startMonitoring(contextWithShortInterval); - monitor.startMonitoring(contextWithLongInterval); - - monitor.stopMonitoring(contextWithShortInterval); - assertEquals( - LONG_INTERVAL_MILLIS, - monitor.getConnectionCheckIntervalMillis()); - } - - @Test - void test_3_stopMonitoringWithNoMatchingContexts() { - assertDoesNotThrow(() -> monitor.stopMonitoring(contextWithLongInterval)); - assertEquals(Monitor.DEFAULT_CONNECTION_CHECK_INTERVAL_MILLIS, - monitor.getConnectionCheckIntervalMillis()); - - monitor.startMonitoring(contextWithShortInterval); - assertDoesNotThrow(() -> monitor.stopMonitoring(contextWithLongInterval)); - assertEquals( - SHORT_INTERVAL_MILLIS, - monitor.getConnectionCheckIntervalMillis()); - } - - @Test - void test_4_stopMonitoringTwiceWithSameContext() { - monitor.startMonitoring(contextWithLongInterval); - assertDoesNotThrow(() -> { - monitor.stopMonitoring(contextWithLongInterval); - monitor.stopMonitoring(contextWithLongInterval); - }); - assertEquals( - Monitor.DEFAULT_CONNECTION_CHECK_INTERVAL_MILLIS, - monitor.getConnectionCheckIntervalMillis()); - } - @Test void test_5_isConnectionHealthyWithNoExistingConnection() throws SQLException { final Monitor.ConnectionStatus status = monitor.checkConnectionStatus(SHORT_INTERVAL_MILLIS); diff --git a/src/test/java/com/mysql/cj/jdbc/ha/plugins/MultiThreadedDefaultMonitorServiceTest.java b/src/test/java/com/mysql/cj/jdbc/ha/plugins/MultiThreadedDefaultMonitorServiceTest.java index e37301d07..587bd0aa5 100644 --- a/src/test/java/com/mysql/cj/jdbc/ha/plugins/MultiThreadedDefaultMonitorServiceTest.java +++ b/src/test/java/com/mysql/cj/jdbc/ha/plugins/MultiThreadedDefaultMonitorServiceTest.java @@ -390,8 +390,8 @@ private List generateContexts(final int numContexts, f nodeKeysList.forEach(nodeKeys -> { monitorThreadContainer.getOrCreateMonitor(nodeKeys, () -> monitor); contexts.add(new MonitorConnectionContext( + monitor, null, - nodeKeys, logger, FAILURE_DETECTION_TIME, FAILURE_DETECTION_INTERVAL, diff --git a/src/test/java/com/mysql/cj/jdbc/ha/plugins/failover/AuroraTopologyServiceTest.java b/src/test/java/com/mysql/cj/jdbc/ha/plugins/failover/AuroraTopologyServiceTest.java index 1f64fd969..1d7d7205e 100644 --- a/src/test/java/com/mysql/cj/jdbc/ha/plugins/failover/AuroraTopologyServiceTest.java +++ b/src/test/java/com/mysql/cj/jdbc/ha/plugins/failover/AuroraTopologyServiceTest.java @@ -36,6 +36,7 @@ import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -73,7 +74,6 @@ void resetProvider() { spyProvider.setClusterInstanceTemplate(new HostInfo(null, "?", HostInfo.NO_PORT, null, null)); spyProvider.setRefreshRate(AuroraTopologyService.DEFAULT_REFRESH_RATE_IN_MILLISECONDS); spyProvider.clearAll(); - AuroraTopologyService.setExpireTime(AuroraTopologyService.DEFAULT_CACHE_EXPIRE_MS); } @Test @@ -101,7 +101,7 @@ public void testTopologyQuery() throws SQLException { final List topology = spyProvider.getTopology(mockConn, false); final HostInfo master = topology.get(FailoverConnectionPlugin.WRITER_CONNECTION_INDEX); - final List slaves = + final List replicas = topology.subList(FailoverConnectionPlugin.WRITER_CONNECTION_INDEX + 1, topology.size()); assertEquals("writer-instance.XYZ.us-east-2.rds.amazonaws.com", master.getHost()); @@ -118,7 +118,7 @@ public void testTopologyQuery() throws SQLException { assertFalse(spyProvider.isMultiWriterCluster()); assertEquals(3, topology.size()); - assertEquals(2, slaves.size()); + assertEquals(2, replicas.size()); } @Test @@ -305,37 +305,6 @@ public void testForceUpdateQueryFailureWithSQLException() throws SQLException { assertThrows(SQLException.class, () -> spyProvider.getTopology(mockConn, true)); } - @Test - public void testQueryFailureReturnsStaleTopology() throws SQLException, InterruptedException { - final JdbcConnection mockConn = Mockito.mock(ConnectionImpl.class); - final Statement mockStatement = Mockito.mock(StatementImpl.class); - final ResultSet mockResultSet = Mockito.mock(ResultSetImpl.class); - stubTopologyQuery(mockConn, mockStatement, mockResultSet); - final String url = - "jdbc:mysql:aws://my-cluster-name.cluster-XYZ.us-east-2.rds.amazonaws.com:1234/test"; - final ConnectionUrl conStr = ConnectionUrl.getConnectionUrlInstance(url, new Properties()); - final HostInfo mainHost = conStr.getMainHost(); - final HostInfo clusterInstanceInfo = - new HostInfo( - conStr, - "?.XYZ.us-east-2.rds.amazonaws.com", - mainHost.getPort(), - mainHost.getUser(), - mainHost.getPassword(), - mainHost.getHostProperties()); - spyProvider.setClusterInstanceTemplate(clusterInstanceInfo); - spyProvider.setRefreshRate(1); - - final List hosts = spyProvider.getTopology(mockConn, false); - when(mockConn.createStatement()).thenThrow(SQLSyntaxErrorException.class); - Thread.sleep(5); - final List staleHosts = spyProvider.getTopology(mockConn, false); - - verify(spyProvider, times(2)).queryForTopology(mockConn); - assertEquals(3, staleHosts.size()); - assertEquals(hosts, staleHosts); - } - @Test public void testGetHostByName_success() throws SQLException { final JdbcConnection mockConn = Mockito.mock(ConnectionImpl.class); @@ -439,27 +408,11 @@ public void testProviderTopologyExpires() throws SQLException, InterruptedExcept mainHost.getHostProperties()); spyProvider.setClusterInstanceTemplate(clusterInstanceInfo); - AuroraTopologyService.setExpireTime(1000); // 1 sec - spyProvider.setRefreshRate( - 10000); // 10 sec; and cache expiration time is also (indirectly) changed to 10 sec - - spyProvider.getTopology(mockConn, false); - verify(spyProvider, times(1)).queryForTopology(mockConn); - - Thread.sleep(3000); + spyProvider.setRefreshRate(1000); // 1 sec - spyProvider.getTopology(mockConn, false); + spyProvider.getTopology(mockConn, false); // this call should be filling cache + spyProvider.getTopology(mockConn, false); // this call should use data in cache verify(spyProvider, times(1)).queryForTopology(mockConn); - - Thread.sleep(3000); - // internal cache has NOT expired yet - spyProvider.getTopology(mockConn, false); - verify(spyProvider, times(1)).queryForTopology(mockConn); - - Thread.sleep(5000); - // internal cache has expired by now - spyProvider.getTopology(mockConn, false); - verify(spyProvider, times(2)).queryForTopology(mockConn); } @Test @@ -483,7 +436,6 @@ public void testProviderTopologyNotExpired() throws SQLException, InterruptedExc mainHost.getHostProperties()); spyProvider.setClusterInstanceTemplate(clusterInstanceInfo); - AuroraTopologyService.setExpireTime(10000); // 10 sec spyProvider.setRefreshRate(1000); // 1 sec spyProvider.getTopology(mockConn, false); @@ -523,9 +475,9 @@ public void testClearProviderCache() throws SQLException { spyProvider.getTopology(mockConn, false); spyProvider.addToDownHostList(clusterInstanceInfo); - assertEquals(1, AuroraTopologyService.topologyCache.size()); + assertEquals(1, AuroraTopologyService.downHostCache.size()); spyProvider.clearAll(); - assertEquals(0, AuroraTopologyService.topologyCache.size()); + assertEquals(0, AuroraTopologyService.downHostCache.size()); } } diff --git a/src/test/java/testsuite/integration/container/AuroraMysqlFailoverIntegrationTest.java b/src/test/java/testsuite/integration/container/AuroraMysqlFailoverIntegrationTest.java index 1f20a8772..63d27bc67 100644 --- a/src/test/java/testsuite/integration/container/AuroraMysqlFailoverIntegrationTest.java +++ b/src/test/java/testsuite/integration/container/AuroraMysqlFailoverIntegrationTest.java @@ -33,7 +33,11 @@ import com.mysql.cj.conf.PropertyKey; import com.mysql.cj.exceptions.MysqlErrorNumbers; +import com.mysql.cj.log.Log; +import com.mysql.cj.log.StandardLogger; import eu.rekawek.toxiproxy.Proxy; +import software.amazon.awssdk.services.rds.model.FailoverDbClusterResponse; +import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import java.io.IOException; @@ -53,6 +57,13 @@ /** Integration testing with Aurora MySQL failover logic. */ public class AuroraMysqlFailoverIntegrationTest extends AuroraMysqlIntegrationBaseTest { + @Override + @BeforeEach + public void setUpEach() throws InterruptedException, SQLException { + waitUntilClusterHasRightState(); + super.setUpEach(); + } + /* Writer connection failover tests. */ /** @@ -214,7 +225,8 @@ public void test_writerFailWithinTransaction_setAutoCommitFalse() final String initialWriterId = instanceIDs[0]; - try (final Connection conn = connectToInstance(initialWriterId + DB_CONN_STR_SUFFIX, MYSQL_PORT, initDefaultProps())) { + Properties props = initDefaultProps(); + try (final Connection conn = connectToInstance(initialWriterId + DB_CONN_STR_SUFFIX, MYSQL_PORT, props)) { final Statement testStmt1 = conn.createStatement(); testStmt1.executeUpdate("DROP TABLE IF EXISTS test3_2"); testStmt1.executeUpdate( @@ -430,10 +442,22 @@ public void test_failoverTimeoutMs() throws SQLException, IOException { } // Helpers + private void failoverClusterAndWaitUntilWriterChanged(String clusterWriterId) throws InterruptedException { + // Trigger failover failoverCluster(); - waitUntilWriterInstanceChanged(clusterWriterId); + + int remainingAttempts = 3; + // let cluster to start and complete failover + while (!waitUntilWriterInstanceChanged(clusterWriterId, TimeUnit.MINUTES.toNanos(3))) { + // if writer is not changed, try to trigger failover again + remainingAttempts--; + if (remainingAttempts == 0) { + throw new RuntimeException("Cluster writer has not changed."); + } + failoverCluster(); + } } private void failoverCluster() throws InterruptedException { @@ -472,18 +496,50 @@ private void failoverClusterWithATargetInstance(String targetInstanceId) private void waitUntilWriterInstanceChanged(String initialWriterInstanceId) throws InterruptedException { + + // wait for cluster recover for up to 10 min + final long timeoutNanos = System.nanoTime() + TimeUnit.MINUTES.toNanos(10); + String nextClusterWriterId = getDBClusterWriterInstanceId(); + while (initialWriterInstanceId.equals(nextClusterWriterId)) { + if (timeoutNanos < System.nanoTime()) { + throw new RuntimeException("Cluster writer has not changed."); + } TimeUnit.MILLISECONDS.sleep(3000); // Calling the RDS API to get writer Id. nextClusterWriterId = getDBClusterWriterInstanceId(); } } + private boolean waitUntilWriterInstanceChanged(String initialWriterInstanceId, long timeoutNanos) + throws InterruptedException { + + // wait for cluster recover for up to 10 min + final long waitUntil = System.nanoTime() + timeoutNanos; + + String nextClusterWriterId = getDBClusterWriterInstanceId(); + + while (initialWriterInstanceId.equals(nextClusterWriterId)) { + if (waitUntil < System.nanoTime()) { + return false; + } + TimeUnit.MILLISECONDS.sleep(3000); + // Calling the RDS API to get writer Id. + nextClusterWriterId = getDBClusterWriterInstanceId(); + } + return true; + } + private void waitUntilClusterHasRightState() throws InterruptedException { + final long timeoutNanos = System.nanoTime() + TimeUnit.MINUTES.toNanos(10); String status = getDBCluster().status(); + while (!"available".equalsIgnoreCase(status)) { - TimeUnit.MILLISECONDS.sleep(1000); + if (timeoutNanos < System.nanoTime()) { + throw new RuntimeException("Cluster is still unavailable."); + } + TimeUnit.MILLISECONDS.sleep(5000); status = getDBCluster().status(); } }