Skip to content

Commit

Permalink
Fixes NullReferenceException when using long running producer (#7166)
Browse files Browse the repository at this point in the history
* Fixing potential NPE when token managers are created multiple times.

* Fix test assertions.

* Add test for provider.
  • Loading branch information
conniey authored Jan 7, 2020
1 parent c6cc384 commit 0f27612
Show file tree
Hide file tree
Showing 8 changed files with 160 additions and 112 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,18 @@
import com.azure.core.amqp.exception.AmqpResponseCode;
import com.azure.core.exception.AzureException;
import com.azure.core.util.logging.ClientLogger;
import reactor.core.Disposable;
import reactor.core.publisher.EmitterProcessor;
import reactor.core.publisher.Flux;
import reactor.core.publisher.FluxSink;
import reactor.core.publisher.Mono;
import reactor.core.publisher.ReplayProcessor;

import java.time.Duration;
import java.time.OffsetDateTime;
import java.time.ZoneOffset;
import java.util.Timer;
import java.util.TimerTask;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;

/**
* Manages the re-authorization of the client to the token audience against the CBS node.
Expand All @@ -30,27 +31,19 @@ public class ActiveClientTokenManager implements TokenManager {
private final Mono<ClaimsBasedSecurityNode> cbsNode;
private final String tokenAudience;
private final String scopes;
private final Timer timer;
private final Flux<AmqpResponseCode> authorizationResults;
private FluxSink<AmqpResponseCode> sink;
private final ReplayProcessor<AmqpResponseCode> authorizationResults = ReplayProcessor.create(1);
private final FluxSink<AmqpResponseCode> authorizationResultsSink =
authorizationResults.sink(FluxSink.OverflowStrategy.BUFFER);
private final EmitterProcessor<Duration> durationSource = EmitterProcessor.create();
private final FluxSink<Duration> durationSourceSink = durationSource.sink();
private final AtomicReference<Duration> lastRefreshInterval = new AtomicReference<>(Duration.ofMinutes(1));

// last refresh interval in milliseconds.
private AtomicLong lastRefreshInterval = new AtomicLong();
private volatile Disposable subscription;

public ActiveClientTokenManager(Mono<ClaimsBasedSecurityNode> cbsNode, String tokenAudience, String scopes) {
this.timer = new Timer(tokenAudience + "-tokenManager");
this.cbsNode = cbsNode;
this.tokenAudience = tokenAudience;
this.scopes = scopes;
this.authorizationResults = Flux.create(sink -> {
if (hasDisposed.get()) {
sink.complete();
} else {
this.sink = sink;
}
});

lastRefreshInterval.set(Duration.ofMinutes(1).getSeconds() * 1000);
}

/**
Expand Down Expand Up @@ -82,15 +75,18 @@ public Mono<Long> authorize() {

// We want to refresh the token when 90% of the time before expiry has elapsed.
final long refreshSeconds = (long) Math.floor(between.getSeconds() * 0.9);

// This converts it to milliseconds
final long refreshIntervalMS = refreshSeconds * 1000;

lastRefreshInterval.set(refreshIntervalMS);

// If this is the first time authorize is called, the task will not have been scheduled yet.
if (!hasScheduled.getAndSet(true)) {
logger.info("Scheduling refresh token task.");
scheduleRefreshTokenTask(refreshIntervalMS);
logger.info("Scheduling refresh token task");

final Duration firstInterval = Duration.ofMillis(refreshIntervalMS);
lastRefreshInterval.set(firstInterval);
authorizationResultsSink.next(AmqpResponseCode.ACCEPTED);
subscription = scheduleRefreshTokenTask(firstInterval);
}

return refreshIntervalMS;
Expand All @@ -99,52 +95,51 @@ public Mono<Long> authorize() {

@Override
public void close() {
if (!hasDisposed.getAndSet(true)) {
if (this.sink != null) {
this.sink.complete();
}

this.timer.cancel();
if (hasDisposed.getAndSet(true)) {
return;
}
}

private void scheduleRefreshTokenTask(Long refreshIntervalInMS) {
try {
timer.schedule(new RefreshAuthorizationToken(), refreshIntervalInMS);
} catch (IllegalStateException e) {
logger.warning("Unable to schedule RefreshAuthorizationToken task.", e);
hasScheduled.set(false);
authorizationResultsSink.complete();
durationSourceSink.complete();

if (subscription != null) {
subscription.dispose();
}
}

private class RefreshAuthorizationToken extends TimerTask {
@Override
public void run() {
logger.info("Refreshing authorization token.");
authorize().subscribe(
(Long refreshIntervalInMS) -> {

if (hasDisposed.get()) {
logger.info("Token manager has been disposed of. Not rescheduling.");
return;
}

logger.info("Authorization successful. Refreshing token in {} ms.", refreshIntervalInMS);
sink.next(AmqpResponseCode.ACCEPTED);

scheduleRefreshTokenTask(refreshIntervalInMS);
}, error -> {
if ((error instanceof AmqpException) && ((AmqpException) error).isTransient()) {
logger.error("Error is transient. Rescheduling authorization task.", error);
scheduleRefreshTokenTask(lastRefreshInterval.get());
} else {
logger.error("Error occurred while refreshing token that is not retriable. Not scheduling"
+ " refresh task. Use ActiveClientTokenManager.authorize() to schedule task again.", error);
hasScheduled.set(false);
}

sink.error(error);
private Disposable scheduleRefreshTokenTask(Duration initialRefresh) {
// EmitterProcessor can queue up an initial refresh interval before any subscribers are received.
durationSourceSink.next(initialRefresh);

return Flux.switchOnNext(durationSource.map(Flux::interval))
.flatMap(delay -> {
logger.info("Refreshing token.");
return authorize();
})
.onErrorContinue(
error -> (error instanceof AmqpException) && ((AmqpException) error).isTransient(),
(amqpException, interval) -> {
final Duration lastRefresh = lastRefreshInterval.get();

logger.error("Error is transient. Rescheduling authorization task at interval {} ms.",
lastRefresh.toMillis(), amqpException);
durationSourceSink.next(lastRefreshInterval.get());
})
.subscribe(interval -> {
logger.info("Authorization successful. Refreshing token in {} ms.", interval);
authorizationResultsSink.next(AmqpResponseCode.ACCEPTED);

final Duration nextRefresh = Duration.ofMillis(interval);
lastRefreshInterval.set(nextRefresh);
durationSourceSink.next(Duration.ofMillis(interval));
}, error -> {
logger.error("Error occurred while refreshing token that is not retriable. Not scheduling"
+ " refresh task. Use ActiveClientTokenManager.authorize() to schedule task again.", error);
hasScheduled.set(false);
durationSourceSink.complete();
authorizationResultsSink.error(error);
}, () -> {
logger.info("Completed refresh token task.");
});
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ public AzureTokenManagerProvider(CbsAuthorizationType authorizationType, String
public TokenManager getTokenManager(Mono<ClaimsBasedSecurityNode> cbsNodeMono, String resource) {
final String scopes = getResourceString(resource);
final String tokenAudience = String.format(Locale.US, TOKEN_AUDIENCE_FORMAT, fullyQualifiedNamespace, resource);

logger.info("Creating new token manager for audience[{}], scopes[{}]", tokenAudience, scopes);
return new ActiveClientTokenManager(cbsNodeMono, tokenAudience, scopes);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import com.azure.core.amqp.exception.AmqpResponseCode;
import com.azure.core.exception.AzureException;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.Mock;
Expand All @@ -25,7 +26,7 @@
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.when;

public class ActiveClientTokenManagerTest {
class ActiveClientTokenManagerTest {
private static final String AUDIENCE = "an-audience-test";
private static final String SCOPES = "scopes-test";
private static final Duration TIMEOUT = Duration.ofSeconds(4);
Expand All @@ -34,12 +35,12 @@ public class ActiveClientTokenManagerTest {
private ClaimsBasedSecurityNode cbsNode;

@BeforeEach
public void setup() {
void setup() {
MockitoAnnotations.initMocks(this);
}

@AfterEach
public void teardown() {
void teardown() {
Mockito.framework().clearInlineMocks();
cbsNode = null;
}
Expand All @@ -48,7 +49,7 @@ public void teardown() {
* Verify that we can get successes and errors from CBS node.
*/
@Test
public void getAuthorizationResults() {
void getAuthorizationResults() {
// Arrange
final Mono<ClaimsBasedSecurityNode> cbsNodeMono = Mono.fromCallable(() -> cbsNode);
when(cbsNode.authorize(any(), any())).thenReturn(getNextExpiration(3));
Expand All @@ -60,8 +61,9 @@ public void getAuthorizationResults() {
.then(() -> tokenManager.authorize().block(TIMEOUT))
.expectNext(AmqpResponseCode.ACCEPTED)
.expectNext(AmqpResponseCode.ACCEPTED)
.then(tokenManager::close)
.verifyComplete();
.then(() -> tokenManager.close())
.expectComplete()
.verify();
}

/**
Expand All @@ -70,7 +72,7 @@ public void getAuthorizationResults() {
*/
@SuppressWarnings("unchecked")
@Test
public void getAuthorizationResultsSuccessFailure() {
void getAuthorizationResultsSuccessFailure() {
// Arrange
final Mono<ClaimsBasedSecurityNode> cbsNodeMono = Mono.fromCallable(() -> cbsNode);
final IllegalArgumentException error = new IllegalArgumentException("Some error");
Expand All @@ -83,6 +85,7 @@ public void getAuthorizationResultsSuccessFailure() {
StepVerifier.create(tokenManager.getAuthorizationResults())
.then(() -> tokenManager.authorize().block(TIMEOUT))
.expectNext(AmqpResponseCode.ACCEPTED)
.expectNext(AmqpResponseCode.ACCEPTED)
.expectError(IllegalArgumentException.class)
.verifyThenAssertThat()
.hasNotDroppedElements()
Expand All @@ -95,7 +98,7 @@ public void getAuthorizationResultsSuccessFailure() {
* Verify that we cannot authorize with CBS node when it has already been disposed of.
*/
@Test
public void cannotAuthorizeDisposedInstance() {
void cannotAuthorizeDisposedInstance() {
// Arrange
final Mono<ClaimsBasedSecurityNode> cbsNodeMono = Mono.fromCallable(() -> cbsNode);
when(cbsNode.authorize(any(), any())).thenReturn(getNextExpiration(2));
Expand All @@ -114,31 +117,64 @@ public void cannotAuthorizeDisposedInstance() {
*/
@SuppressWarnings("unchecked")
@Test
public void getAuthorizationResultsRetriableError() {
void getAuthorizationResultsRetriableError() {
// Arrange
final Mono<ClaimsBasedSecurityNode> cbsNodeMono = Mono.fromCallable(() -> cbsNode);
final AmqpException error = new AmqpException(true, AmqpErrorCondition.TIMEOUT_ERROR, "Timed out",
final AmqpException error = new AmqpException(false, AmqpErrorCondition.ARGUMENT_ERROR,
"Non-retryable argument error",
new AmqpErrorContext("Test-context-namespace"));

when(cbsNode.authorize(any(), any())).thenReturn(getNextExpiration(3), Mono.error(error),
getNextExpiration(5), getNextExpiration(10),
getNextExpiration(45));
when(cbsNode.authorize(any(), any())).thenReturn(getNextExpiration(5), Mono.error(error),
getNextExpiration(5));

// Act & Assert
try (ActiveClientTokenManager tokenManager = new ActiveClientTokenManager(cbsNodeMono, AUDIENCE, SCOPES)) {
StepVerifier.create(tokenManager.getAuthorizationResults())
.then(() -> tokenManager.authorize().block(TIMEOUT))
.expectError(AmqpException.class)
.verify();

StepVerifier.create(tokenManager.getAuthorizationResults())
.expectNext(AmqpResponseCode.ACCEPTED)
.expectNext(AmqpResponseCode.ACCEPTED)
.then(tokenManager::close)
.verifyComplete();
.expectErrorSatisfies(exception -> {
Assertions.assertTrue(exception instanceof AmqpException);

AmqpException amqpException = (AmqpException) exception;
Assertions.assertFalse(amqpException.isTransient());
Assertions.assertEquals(error.getErrorCondition(), amqpException.getErrorCondition());
})
.verify(Duration.ofSeconds(30));
}
}


/**
* Verify that the ActiveClientTokenManager does not get more authorization tasks.
*/
@SuppressWarnings("unchecked")
@Test
void getAuthorizationResultsNonRetriableError() {
// Arrange
final Mono<ClaimsBasedSecurityNode> cbsNodeMono = Mono.fromCallable(() -> cbsNode);
final AmqpException error = new AmqpException(true, AmqpErrorCondition.TIMEOUT_ERROR, "Test CBS node error.",
new AmqpErrorContext("Test-context-namespace"));

when(cbsNode.authorize(any(), any())).thenReturn(getNextExpiration(5), Mono.error(error),
getNextExpiration(5), getNextExpiration(10),
getNextExpiration(45));

// Act & Assert
final ActiveClientTokenManager tokenManager = new ActiveClientTokenManager(cbsNodeMono, AUDIENCE, SCOPES);

StepVerifier.create(tokenManager.getAuthorizationResults())
.then(() -> tokenManager.authorize().block(TIMEOUT))
.expectNext(AmqpResponseCode.ACCEPTED)
.expectNext(AmqpResponseCode.ACCEPTED)
.then(() -> {
System.out.println("Closing");
tokenManager.close();
})
.expectComplete()
.verify(Duration.ofSeconds(30));
}


private Mono<OffsetDateTime> getNextExpiration(long secondsToWait) {
return Mono.fromCallable(() -> OffsetDateTime.now(ZoneOffset.UTC).plusSeconds(secondsToWait));
}
Expand Down
Loading

0 comments on commit 0f27612

Please sign in to comment.