Creates S3 Cache Engine (deepjavalibrary#719)
* S3 Cache Engine

This creates the S3 cache engine. Rather than adding a separate plugin, the
existing DDB cache plugin is expanded into a general cache plugin that handles
S3 as well.

Alongside this, some work was done to unify the cache engines. A new
BaseCacheEngine class holds the logic common to all of them (a sketch of the
implied shape follows below). The former DDB cache engine tests were
generalized into a suite that runs against all three supported cache engines,
which uncovered (and this fixes) some inconsistencies in behavior among them.
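
BaseCacheEngine itself is not shown in this excerpt; the following is only a
sketch of the template-method shape implied by the DdbCacheEngine diff below
(the writeBatch field, the putSingle/putStream hooks, and the joinBytes
helper), not the actual class from the commit:

```java
package ai.djl.serving.cache;

import ai.djl.modality.Output;

import java.util.List;

/** Hypothetical sketch; the real BaseCacheEngine in the repository may differ. */
public abstract class BaseCacheEngine implements CacheEngine {

    /** Chunks buffered per backend write; see SERVING_CACHE_BATCH in the diff below. */
    protected int writeBatch;

    /** Writes a complete (non-streaming) entry to the backend. */
    protected abstract void putSingle(String key, Output output, boolean last);

    /** Writes one batched chunk of a streaming entry to the backend. */
    protected abstract void putStream(String key, Output output, byte[] buf, int index, boolean last);

    /** Concatenates buffered chunks; this is the old DdbCacheEngine.join, now shared. */
    protected byte[] joinBytes(List<byte[]> list) {
        int size = 0;
        for (byte[] buf : list) {
            size += buf.length;
        }
        byte[] batch = new byte[size];
        int pos = 0;
        for (byte[] buf : list) {
            System.arraycopy(buf, 0, batch, pos, buf.length);
            pos += buf.length;
        }
        return batch;
    }
}
```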

Also important: this adds a new test dependency on
localstack (https://localstack.cloud/), which runs a local AWS clone inside a
Docker container and is used to verify the S3 cache (see the test sketch below).
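
For reference, a minimal TestNG setup using localstack-utils might look like
the sketch below. It assumes the cloud.localstack.Localstack singleton API
from localstack-utils 0.2.x, uses a hypothetical test class name, and mirrors
the "skip on startup failure" behavior mentioned in the next bullet; it is not
the actual test from this commit:

```java
import cloud.localstack.Localstack;
import cloud.localstack.docker.annotation.LocalstackDockerConfiguration;

import org.testng.SkipException;
import org.testng.annotations.BeforeClass;

public class S3CacheEngineTest { // hypothetical test class name

    @BeforeClass
    public void setUp() {
        try {
            // Boots the localstack Docker container that serves a local S3 endpoint.
            Localstack.INSTANCE.startup(LocalstackDockerConfiguration.DEFAULT);
        } catch (Exception e) {
            // Skip (rather than fail) when Docker/localstack cannot start.
            throw new SkipException("Could not start localstack", e);
        }
    }
}
```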

* Fix typo and add skip for failure to start localstack
zachgk authored May 16, 2023
1 parent 3f86ca6 commit dae3750
Showing 15 changed files with 771 additions and 260 deletions.
7 changes: 5 additions & 2 deletions plugins/cache/README.md

@@ -1,3 +1,6 @@
-# DJL Serving - DynamoDB Paginator Plugin
+# DJL Serving - Cache Paginator Plugin
 
-Allows the model server to use DynamoDB Cache engine.
+Allows the model server to use additional cache engine types:
+
+- DynamoDB Cache
+- S3 Cache
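
As the CachePlugin diff further down shows, the engine is selected through
environment variables or system properties read via Utils.getEnvOrSystemProperty.
For illustration only, a hypothetical embedded launcher could enable the S3
cache like this (the class name and bucket values are placeholders; only the
property names come from the plugin code):

```java
/** Hypothetical launcher; only the property names come from CachePlugin. */
public final class CacheConfigExample {
    public static void main(String[] args) {
        System.setProperty("SERVING_S3_CACHE", "true");
        System.setProperty("SERVING_S3_CACHE_BUCKET", "my-djl-cache-bucket"); // placeholder
        System.setProperty("SERVING_S3_CACHE_KEY_PREFIX", "cache/"); // placeholder
        System.setProperty("SERVING_S3_CACHE_AUTOCREATE", "true");
        // Alternatively, export the same names as environment variables
        // before starting the model server.
    }
}
```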
5 changes: 4 additions & 1 deletion plugins/cache/build.gradle

@@ -5,14 +5,17 @@ configurations {
 dependencies {
     api platform("ai.djl:bom:${project.version}")
     implementation project(":serving")
-    api "ai.djl.aws:aws-ai"
 
     api platform("software.amazon.awssdk:bom:${awssdk_version}")
     api "software.amazon.awssdk:dynamodb"
+    api "ai.djl.aws:aws-ai"
+    api "software.amazon.awssdk:s3"
 
     testImplementation("org.testng:testng:${testng_version}") {
         exclude group: "junit", module: "junit"
     }
     testImplementation "com.amazonaws:DynamoDBLocal:1.21.1"
+    testImplementation "cloud.localstack:localstack-utils:0.2.15"
 
     exclusion project(":serving")
 }
CachePlugin.java (renamed from DdbCachePlugin.java)

@@ -10,11 +10,10 @@
  * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
  * and limitations under the License.
  */
-package ai.djl.serving.ddbcache;
+package ai.djl.serving.cache;
 
 import ai.djl.aws.s3.S3RepositoryFactory;
 import ai.djl.repository.Repository;
-import ai.djl.serving.cache.CacheManager;
 import ai.djl.serving.plugins.RequestHandler;
 import ai.djl.util.Utils;
 
@@ -27,22 +26,49 @@
 
 import software.amazon.awssdk.core.exception.SdkClientException;
 
-/** A plugin handles DynamoDB caching. */
-public class DdbCachePlugin implements RequestHandler<Void> {
+import java.util.concurrent.CompletionException;
 
-    private static final Logger logger = LoggerFactory.getLogger(DdbCachePlugin.class);
+/** A plugin handles caching options. */
+public class CachePlugin implements RequestHandler<Void> {
 
-    /** Constructs a new {@code DdbCachePlugin} instance. */
-    public DdbCachePlugin() {
+    private static final Logger logger = LoggerFactory.getLogger(CachePlugin.class);
+
+    /** Constructs a new {@code CachePlugin} instance. */
+    public CachePlugin() {
         Repository.registerRepositoryFactory(new S3RepositoryFactory());
+        boolean multiTenant =
+                Boolean.parseBoolean(Utils.getEnvOrSystemProperty("SERVING_CACHE_MULTITENANT"));
         if (Boolean.parseBoolean(Utils.getEnvOrSystemProperty("SERVING_DDB_CACHE"))) {
             try {
                 DdbCacheEngine engine = DdbCacheEngine.newInstance();
                 CacheManager.setCacheEngine(engine);
                 logger.info("DynamoDB cache is enabled.");
             } catch (SdkClientException e) {
-                logger.warn("Failed to create DynamoDB", e);
+                logger.warn("Failed to create DynamoDB cache", e);
             }
+        } else if (Boolean.parseBoolean(Utils.getEnvOrSystemProperty("SERVING_S3_CACHE"))) {
+            try {
+                String bucket = Utils.getEnvOrSystemProperty("SERVING_S3_CACHE_BUCKET");
+                String keyPrefix = Utils.getEnvOrSystemProperty("SERVING_S3_CACHE_KEY_PREFIX");
+                S3CacheEngine engine = new S3CacheEngine(multiTenant, bucket, keyPrefix);
+                if (Boolean.parseBoolean(
+                        Utils.getEnvOrSystemProperty("SERVING_S3_CACHE_AUTOCREATE"))) {
+                    engine.createBucketIfNotExists().join();
+                }
+                CacheManager.setCacheEngine(engine);
+                logger.info("S3 cache is enabled.");
+            } catch (CompletionException e) {
+                logger.warn("Failed to create S3 cache ", e);
+            }
+        } else {
+            String capacity = Utils.getEnvOrSystemProperty("SERVING_MEMORY_CACHE_CAPACITY");
+            MemoryCacheEngine engine;
+            if (capacity != null) {
+                engine = new MemoryCacheEngine(multiTenant, Integer.parseInt(capacity));
+            } else {
+                engine = new MemoryCacheEngine(multiTenant);
+            }
+            CacheManager.setCacheEngine(engine);
+        }
     }
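
For reference, programmatic setup equivalent to the SERVING_S3_CACHE branch
above might look like the following sketch. The constructor and
createBucketIfNotExists calls match the diff; the wrapper class, bucket, and
prefix are placeholders:

```java
import ai.djl.serving.cache.CacheManager;
import ai.djl.serving.cache.S3CacheEngine;

public final class S3CacheSetupExample { // hypothetical class
    public static void main(String[] args) {
        S3CacheEngine engine = new S3CacheEngine(false, "my-djl-cache-bucket", "cache/");
        // Equivalent to SERVING_S3_CACHE_AUTOCREATE=true; join() blocks until done.
        engine.createBucketIfNotExists().join();
        CacheManager.setCacheEngine(engine);
    }
}
```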
DdbCacheEngine.java (moved to package ai.djl.serving.cache)

@@ -10,12 +10,9 @@
  * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
  * and limitations under the License.
  */
-package ai.djl.serving.ddbcache;
+package ai.djl.serving.cache;
 
-import ai.djl.inference.streaming.ChunkedBytesSupplier;
 import ai.djl.modality.Output;
-import ai.djl.ndarray.BytesSupplier;
-import ai.djl.serving.cache.CacheEngine;
 import ai.djl.util.Utils;
 
 import org.slf4j.Logger;

@@ -53,12 +50,10 @@
 import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
-import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.TimeUnit;
 
 /** A {@link CacheEngine} that stores elements in DynamoDB. */
-public final class DdbCacheEngine implements CacheEngine {
+public final class DdbCacheEngine extends BaseCacheEngine {
 
     private static final Logger logger = LoggerFactory.getLogger(DdbCacheEngine.class);

@@ -75,7 +70,6 @@
     private DynamoDbClient ddbClient;
     private long cacheTtl;
-    private int writeBatch;
 
     /**
      * Constructs a {@link DdbCacheEngine}.

@@ -85,7 +79,7 @@
     private DdbCacheEngine(DynamoDbClient ddbClient) {
         this.ddbClient = ddbClient;
         cacheTtl = Duration.ofMillis(30).toMillis();
-        writeBatch = Integer.parseInt(Utils.getenv("SERVING_DDB_BATCH", "5"));
+        writeBatch = Integer.parseInt(Utils.getenv("SERVING_CACHE_BATCH", "5"));
     }
 
     /**

@@ -160,63 +154,26 @@ public boolean isMultiTenant() {
         return false;
     }
 
     /** {@inheritDoc} */
     @Override
-    public CompletableFuture<Void> put(String key, Output output) {
-        return CompletableFuture.<Void>supplyAsync(
-                () -> {
-                    String ttl = String.valueOf(System.currentTimeMillis() + cacheTtl);
-                    BytesSupplier supplier = output.getData();
-                    if (supplier instanceof ChunkedBytesSupplier) {
-                        Output o = new Output();
-                        o.setCode(output.getCode());
-                        o.setMessage(output.getMessage());
-                        o.setProperties(output.getProperties());
-                        ChunkedBytesSupplier cbs = (ChunkedBytesSupplier) supplier;
-                        int index = 0;
-                        writeDdb(key, o, cbs.pollChunk(), index++, ttl, !cbs.hasNext());
-                        List<byte[]> list = new ArrayList<>(writeBatch);
-                        while (cbs.hasNext()) {
-                            try {
-                                list.add(cbs.nextChunk(1, TimeUnit.MINUTES));
-                            } catch (InterruptedException e) {
-                                throw new IllegalStateException(e);
-                            }
-                            if (list.size() >= writeBatch) {
-                                byte[] batch = join(list);
-                                writeDdb(key, null, batch, index++, ttl, !cbs.hasNext());
-                                list.clear();
-                            }
-                        }
-                        if (!list.isEmpty()) {
-                            byte[] batch = join(list);
-                            writeDdb(key, null, batch, index, ttl, true);
-                        }
-                    } else {
-                        boolean last = output.getCode() != 202;
-                        writeDdb(key, output, null, -1, ttl, last);
-                    }
-                    return null;
-                })
-                .exceptionally(
-                        t -> {
-                            logger.warn("Failed to write to DynamoDB", t);
-                            return null;
-                        });
+    protected void putSingle(String key, Output output, boolean last) {
+        String ttl = String.valueOf(System.currentTimeMillis() + cacheTtl);
+        writeDdb(key, output, null, -1, ttl, last);
     }
 
     /** {@inheritDoc} */
     @Override
-    public Output get(String key, int limit) {
-        int start = -1;
-        if (key.length() > 36) {
-            start = Integer.parseInt(key.substring(36));
-            key = key.substring(0, 36);
-        }
+    protected void putStream(String key, Output output, byte[] buf, int index, boolean last) {
+        String ttl = String.valueOf(System.currentTimeMillis() + cacheTtl);
+        writeDdb(key, output, buf, index, ttl, last);
+    }
+
+    /** {@inheritDoc} */
+    @Override
+    public Output get(String key, int start, int limit) {
+        int shiftedStart = start == 0 ? -1 : start;
         Map<String, AttributeValue> attrValues = new ConcurrentHashMap<>();
         attrValues.put(':' + CACHE_ID, AttributeValue.builder().s(key).build());
-        attrValues.put(':' + INDEX, AttributeValue.builder().n(String.valueOf(start)).build());
+        attrValues.put(
+                ':' + INDEX, AttributeValue.builder().n(String.valueOf(shiftedStart)).build());
         QueryRequest request =
                 QueryRequest.builder()
                         .tableName(TABLE_NAME)

@@ -258,11 +215,13 @@ public Output get(String key, int limit) {
             start = Integer.parseInt(item.get(INDEX).n());
         }
         if (!list.isEmpty()) {
-            output.add(join(list));
+            output.add(joinBytes(list));
         }
         if (!complete) {
-            output.addProperty("x-next-token", key + start);
-            output.addProperty("X-Amzn-SageMaker-Custom-Attributes", "x-next-token=" + key + start);
+            String startString = start <= 0 ? "" : Integer.toString(start);
+            output.addProperty("x-next-token", key + startString);
+            output.addProperty(
+                    "X-Amzn-SageMaker-Custom-Attributes", "x-next-token=" + key + startString);
         }
         return output;
     }

@@ -305,20 +264,6 @@ private Output decode(AttributeValue header) {
         }
     }
 
-    byte[] join(List<byte[]> list) {
-        int size = 0;
-        for (byte[] buf : list) {
-            size += buf.length;
-        }
-        byte[] batch = new byte[size];
-        size = 0;
-        for (byte[] buf : list) {
-            System.arraycopy(buf, 0, batch, size, buf.length);
-            size += buf.length;
-        }
-        return batch;
-    }
-
     void writeDdb(String key, Output output, byte[] buf, int index, String cacheTtl, boolean last) {
         Map<String, AttributeValue> map = new ConcurrentHashMap<>();
         try {
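
The x-next-token handling above implies clients poll until the token
disappears. A rough client-side sketch using java.net.http follows; the
endpoint URL and the request header that carries the token back are
assumptions, and only the x-next-token response header comes from the code
above:

```java
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;

public final class CachePollExample { // hypothetical client
    public static void main(String[] args) throws Exception {
        HttpClient client = HttpClient.newHttpClient();
        String token = args[0]; // token taken from the initial 202 response
        while (token != null) {
            HttpRequest request =
                    HttpRequest.newBuilder(URI.create("http://localhost:8080/invocations")) // assumed URL
                            .header("x-starting-token", token) // assumed header name
                            .GET()
                            .build();
            HttpResponse<byte[]> response =
                    client.send(request, HttpResponse.BodyHandlers.ofByteArray());
            // Consume response.body(), then continue while a next token is present.
            token = response.headers().firstValue("x-next-token").orElse(null);
        }
    }
}
```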