Skip to content

Commit

Permalink
Optimize CPU cycles for persistent sessions (keycloak#31702)
Browse files Browse the repository at this point in the history
Closes keycloak#31701

Signed-off-by: Alexander Schwartz <[email protected]>
  • Loading branch information
ahus1 authored Jul 29, 2024
1 parent 9478418 commit 00d8e06
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 48 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,17 @@ abstract public class PersistentSessionsChangelogBasedTransaction<K, V extends S
protected final KeycloakSession kcSession;
protected final Map<K, SessionUpdatesList<V>> updates = new HashMap<>();
protected final Map<K, SessionUpdatesList<V>> offlineUpdates = new HashMap<>();
private final List<SessionChangesPerformer<K, V>> changesPerformers;
private final String cacheName;
private final Cache<K, SessionEntityWrapper<V>> cache;
private final Cache<K, SessionEntityWrapper<V>> offlineCache;
private final RemoteCacheInvoker remoteCacheInvoker;
private final SessionFunction<V> lifespanMsLoader;
private final SessionFunction<V> maxIdleTimeMsLoader;
private final SessionFunction<V> offlineLifespanMsLoader;
private final SessionFunction<V> offlineMaxIdleTimeMsLoader;
private final ArrayBlockingQueue<PersistentUpdate> batchingQueue;
private final SerializeExecutionsByKey<K> serializerOnline;
private final SerializeExecutionsByKey<K> serializerOffline;

public PersistentSessionsChangelogBasedTransaction(KeycloakSession session,
String cacheName,
Expand All @@ -62,57 +66,17 @@ public PersistentSessionsChangelogBasedTransaction(KeycloakSession session,
SerializeExecutionsByKey<K> serializerOnline,
SerializeExecutionsByKey<K> serializerOffline) {
kcSession = session;

changesPerformers = new LinkedList<>();

if (batchingQueue != null) {
changesPerformers.add(new JpaChangesPerformer<>(cacheName, batchingQueue));
} else {
changesPerformers.add(new JpaChangesPerformer<>(cacheName, null) {
@Override
public void applyChanges() {
KeycloakModelUtils.runJobInTransaction(session.getKeycloakSessionFactory(),
super::applyChangesSynchronously);
}
});
}

if (cache != null) {
changesPerformers.add(new EmbeddedCachesChangesPerformer<>(cache, serializerOnline) {
@Override
public boolean shouldConsumeChange(V entity) {
return !entity.isOffline();
}
});
changesPerformers.add(new RemoteCachesChangesPerformer<>(session, cache, remoteCacheInvoker) {
@Override
public boolean shouldConsumeChange(V entity) {
return !entity.isOffline();
}
});
}

if (offlineCache != null) {
changesPerformers.add(new EmbeddedCachesChangesPerformer<>(offlineCache, serializerOffline){
@Override
public boolean shouldConsumeChange(V entity) {
return entity.isOffline();
}
});
changesPerformers.add(new RemoteCachesChangesPerformer<>(session, offlineCache, remoteCacheInvoker) {
@Override
public boolean shouldConsumeChange(V entity) {
return entity.isOffline();
}
});
}

this.cacheName = cacheName;
this.cache = cache;
this.offlineCache = offlineCache;
this.remoteCacheInvoker = remoteCacheInvoker;
this.lifespanMsLoader = lifespanMsLoader;
this.maxIdleTimeMsLoader = maxIdleTimeMsLoader;
this.offlineLifespanMsLoader = offlineLifespanMsLoader;
this.offlineMaxIdleTimeMsLoader = offlineMaxIdleTimeMsLoader;
this.batchingQueue = batchingQueue;
this.serializerOnline = serializerOnline;
this.serializerOffline = serializerOffline;
}

protected Cache<K, SessionEntityWrapper<V>> getCache(boolean offline) {
Expand Down Expand Up @@ -174,8 +138,57 @@ public SessionEntityWrapper<V> get(K key, boolean offline){
}
}

List<SessionChangesPerformer<K, V>> prepareChangesPerformers() {
List<SessionChangesPerformer<K, V>> changesPerformers = new LinkedList<>();

if (batchingQueue != null) {
changesPerformers.add(new JpaChangesPerformer<>(cacheName, batchingQueue));
} else {
changesPerformers.add(new JpaChangesPerformer<>(cacheName, null) {
@Override
public void applyChanges() {
KeycloakModelUtils.runJobInTransaction(kcSession.getKeycloakSessionFactory(),
super::applyChangesSynchronously);
}
});
}

if (cache != null) {
changesPerformers.add(new EmbeddedCachesChangesPerformer<>(cache, serializerOnline) {
@Override
public boolean shouldConsumeChange(V entity) {
return !entity.isOffline();
}
});
changesPerformers.add(new RemoteCachesChangesPerformer<>(kcSession, cache, remoteCacheInvoker) {
@Override
public boolean shouldConsumeChange(V entity) {
return !entity.isOffline();
}
});
}

if (offlineCache != null) {
changesPerformers.add(new EmbeddedCachesChangesPerformer<>(offlineCache, serializerOffline){
@Override
public boolean shouldConsumeChange(V entity) {
return entity.isOffline();
}
});
changesPerformers.add(new RemoteCachesChangesPerformer<>(kcSession, offlineCache, remoteCacheInvoker) {
@Override
public boolean shouldConsumeChange(V entity) {
return entity.isOffline();
}
});
}

return changesPerformers;
}

@Override
protected void commitImpl() {
List<SessionChangesPerformer<K, V>> changesPerformers = null;
for (Map.Entry<K, SessionUpdatesList<V>> entry : Stream.concat(updates.entrySet().stream(), offlineUpdates.entrySet().stream()).toList()) {
SessionUpdatesList<V> sessionUpdates = entry.getValue();
SessionEntityWrapper<V> sessionWrapper = sessionUpdates.getEntityWrapper();
Expand All @@ -193,13 +206,18 @@ protected void commitImpl() {
MergedUpdate<V> merged = MergedUpdate.computeUpdate(sessionUpdates.getUpdateTasks(), sessionWrapper, lifespanMs, maxIdleTimeMs);

if (merged != null) {
if (changesPerformers == null) {
changesPerformers = prepareChangesPerformers();
}
changesPerformers.stream()
.filter(performer -> performer.shouldConsumeChange(entity))
.forEach(p -> p.registerChange(entry, merged));
}
}

changesPerformers.forEach(SessionChangesPerformer::applyChanges);
if (changesPerformers != null) {
changesPerformers.forEach(SessionChangesPerformer::applyChanges);
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@ public void run() {

private void process(ArrayBlockingQueue<PersistentUpdate> queue) throws InterruptedException {
ArrayList<PersistentUpdate> batch = new ArrayList<>();
PersistentUpdate polled = queue.poll(100, TimeUnit.MILLISECONDS);
// Timeout is only a backup if interrupting the worker task in the stop() method didn't work as expected because someone else swallowed the interrupted flag.
PersistentUpdate polled = queue.poll(1, TimeUnit.SECONDS);
if (polled != null) {
batch.add(polled);
queue.drainTo(batch, maxBatchSize - 1);
Expand Down

0 comments on commit 00d8e06

Please sign in to comment.