diff --git a/core/deployment/src/main/java/io/quarkus/deployment/recording/BytecodeRecorderImpl.java b/core/deployment/src/main/java/io/quarkus/deployment/recording/BytecodeRecorderImpl.java index 6fa666ed0a059..23ff1bda6ba40 100644 --- a/core/deployment/src/main/java/io/quarkus/deployment/recording/BytecodeRecorderImpl.java +++ b/core/deployment/src/main/java/io/quarkus/deployment/recording/BytecodeRecorderImpl.java @@ -8,12 +8,14 @@ import java.lang.reflect.Array; import java.lang.reflect.Constructor; import java.lang.reflect.Field; +import java.lang.reflect.GenericArrayType; import java.lang.reflect.InvocationHandler; import java.lang.reflect.Method; import java.lang.reflect.Modifier; import java.lang.reflect.Parameter; import java.lang.reflect.ParameterizedType; import java.lang.reflect.Proxy; +import java.lang.reflect.WildcardType; import java.net.MalformedURLException; import java.net.URL; import java.time.Duration; @@ -71,6 +73,9 @@ import io.quarkus.runtime.StartupTask; import io.quarkus.runtime.annotations.IgnoreProperty; import io.quarkus.runtime.annotations.RelaxedValidation; +import io.quarkus.runtime.types.GenericArrayTypeImpl; +import io.quarkus.runtime.types.ParameterizedTypeImpl; +import io.quarkus.runtime.types.WildcardTypeImpl; /** * A class that can be used to record invocations to bytecode so they can be replayed later. This is done through the @@ -769,6 +774,68 @@ ResultHandle doLoad(MethodContext context, MethodCreator method, ResultHandle ar } }; } + } else if (param instanceof ParameterizedType parameterized) { + DeferredParameter raw = loadObjectInstance(parameterized.getRawType(), existing, + java.lang.reflect.Type.class, relaxedValidation); + DeferredParameter args = loadObjectInstance(parameterized.getActualTypeArguments(), existing, + java.lang.reflect.Type[].class, relaxedValidation); + DeferredParameter owner = loadObjectInstance(parameterized.getOwnerType(), existing, + java.lang.reflect.Type.class, relaxedValidation); + return new DeferredParameter() { + @Override + ResultHandle doLoad(MethodContext context, MethodCreator method, ResultHandle array) { + return method.newInstance(ofConstructor(ParameterizedTypeImpl.class, java.lang.reflect.Type.class, + java.lang.reflect.Type[].class, java.lang.reflect.Type.class), + context.loadDeferred(raw), context.loadDeferred(args), context.loadDeferred(owner)); + } + }; + } else if (param instanceof GenericArrayType array) { + DeferredParameter res = loadObjectInstance(array.getGenericComponentType(), existing, + java.lang.reflect.Type.class, relaxedValidation); + return new DeferredParameter() { + @Override + ResultHandle doLoad(MethodContext context, MethodCreator method, ResultHandle array) { + return method.newInstance(ofConstructor(GenericArrayTypeImpl.class, java.lang.reflect.Type.class), + context.loadDeferred(res)); + } + }; + } else if (param instanceof WildcardType wildcard) { + java.lang.reflect.Type[] upperBound = wildcard.getUpperBounds(); + java.lang.reflect.Type[] lowerBound = wildcard.getLowerBounds(); + if (lowerBound.length == 0 && upperBound.length == 1 && Object.class.equals(upperBound[0])) { + // unbounded + return new DeferredParameter() { + @Override + ResultHandle doLoad(MethodContext context, MethodCreator method, ResultHandle array) { + return method.invokeStaticMethod(ofMethod(WildcardTypeImpl.class, "defaultInstance", + WildcardType.class)); + } + }; + } else if (lowerBound.length == 0 && upperBound.length == 1) { + // upper bound + DeferredParameter res = loadObjectInstance(upperBound[0], existing, + java.lang.reflect.Type.class, relaxedValidation); + return new DeferredParameter() { + @Override + ResultHandle doLoad(MethodContext context, MethodCreator method, ResultHandle array) { + return method.invokeStaticMethod(ofMethod(WildcardTypeImpl.class, "withUpperBound", + WildcardType.class, java.lang.reflect.Type.class), context.loadDeferred(res)); + } + }; + } else if (lowerBound.length == 1) { + // lower bound + DeferredParameter res = loadObjectInstance(lowerBound[0], existing, + java.lang.reflect.Type.class, relaxedValidation); + return new DeferredParameter() { + @Override + ResultHandle doLoad(MethodContext context, MethodCreator method, ResultHandle array) { + return method.invokeStaticMethod(ofMethod(WildcardTypeImpl.class, "withLowerBound", + WildcardType.class, java.lang.reflect.Type.class), context.loadDeferred(res)); + } + }; + } else { + throw new UnsupportedOperationException("Unsupported wildcard type: " + wildcard); + } } else if (expectedType == boolean.class || expectedType == Boolean.class || param instanceof Boolean) { return new DeferredParameter() { @Override diff --git a/core/deployment/src/main/java/io/quarkus/deployment/types/TypeParser.java b/core/deployment/src/main/java/io/quarkus/deployment/types/TypeParser.java new file mode 100644 index 0000000000000..3e79d19e6f38d --- /dev/null +++ b/core/deployment/src/main/java/io/quarkus/deployment/types/TypeParser.java @@ -0,0 +1,233 @@ +package io.quarkus.deployment.types; + +import java.lang.reflect.Type; +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; + +import io.quarkus.runtime.types.GenericArrayTypeImpl; +import io.quarkus.runtime.types.ParameterizedTypeImpl; +import io.quarkus.runtime.types.WildcardTypeImpl; + +/** + * Creates a {@link Type} by parsing the given string according to the following grammar: + * + *
+ * Type -> VoidType | PrimitiveType | ReferenceType
+ * VoidType -> 'void'
+ * PrimitiveType -> 'boolean' | 'byte' | 'short' | 'int'
+ *                | 'long' | 'float' | 'double' | 'char'
+ * ReferenceType -> PrimitiveType ('[' ']')+
+ *                | ClassType ('<' TypeArgument (',' TypeArgument)* '>')? ('[' ']')*
+ * ClassType -> FULLY_QUALIFIED_NAME
+ * TypeArgument -> ReferenceType | WildcardType
+ * WildcardType -> '?' | '?' ('extends' | 'super') ReferenceType
+ * 
+ * + * Notice that the resulting type never contains type variables, only "proper" types. + * Also notice that the grammar above does not support all kinds of nested types; + * it should be possible to add that later, if there's an actual need. + *

+ * Types produced by this parser can be transferred from build time to runtime + * via the recorder mechanism. + */ +public class TypeParser { + public static Type parse(String str) { + return new TypeParser(str).parse(); + } + + private final String str; + + private int pos = 0; + + private TypeParser(String str) { + this.str = Objects.requireNonNull(str); + } + + private Type parse() { + Type result; + + String token = nextToken(); + if (token.isEmpty()) { + throw unexpected(token); + } else if (token.equals("void")) { + result = void.class; + } else if (isPrimitiveType(token) && peekToken().isEmpty()) { + result = parsePrimitiveType(token); + } else { + result = parseReferenceType(token); + } + + expect(""); + return result; + } + + private Type parseReferenceType(String token) { + if (isPrimitiveType(token)) { + Type primitive = parsePrimitiveType(token); + return parseArrayType(primitive); + } else if (isClassType(token)) { + Type result = parseClassType(token); + if (peekToken().equals("<")) { + expect("<"); + List typeArguments = new ArrayList<>(); + typeArguments.add(parseTypeArgument()); + while (peekToken().equals(",")) { + expect(","); + typeArguments.add(parseTypeArgument()); + } + expect(">"); + result = new ParameterizedTypeImpl(result, typeArguments.toArray(Type[]::new)); + } + if (peekToken().equals("[")) { + return parseArrayType(result); + } + return result; + } else { + throw unexpected(token); + } + } + + private Type parseArrayType(Type elementType) { + expect("["); + expect("]"); + int dimensions = 1; + while (peekToken().equals("[")) { + expect("["); + expect("]"); + dimensions++; + } + + if (elementType instanceof Class clazz) { + return parseClassType("[".repeat(dimensions) + + (clazz.isPrimitive() ? clazz.descriptorString() : "L" + clazz.getName() + ";")); + } else { + Type result = elementType; + for (int i = 0; i < dimensions; i++) { + result = new GenericArrayTypeImpl(result); + } + return result; + } + } + + private Type parseTypeArgument() { + String token = nextToken(); + if (token.equals("?")) { + if (peekToken().equals("extends")) { + expect("extends"); + Type bound = parseReferenceType(nextToken()); + return WildcardTypeImpl.withUpperBound(bound); + } else if (peekToken().equals("super")) { + expect("super"); + Type bound = parseReferenceType(nextToken()); + return WildcardTypeImpl.withLowerBound(bound); + } else { + return WildcardTypeImpl.defaultInstance(); + } + } else { + return parseReferenceType(token); + } + } + + private boolean isPrimitiveType(String token) { + return token.equals("boolean") + || token.equals("byte") + || token.equals("short") + || token.equals("int") + || token.equals("long") + || token.equals("float") + || token.equals("double") + || token.equals("char"); + } + + private Type parsePrimitiveType(String token) { + return switch (token) { + case "boolean" -> boolean.class; + case "byte" -> byte.class; + case "short" -> short.class; + case "int" -> int.class; + case "long" -> long.class; + case "float" -> float.class; + case "double" -> double.class; + case "char" -> char.class; + default -> throw unexpected(token); + }; + } + + private boolean isClassType(String token) { + return !token.isEmpty() && Character.isJavaIdentifierStart(token.charAt(0)); + } + + private Type parseClassType(String token) { + try { + return Class.forName(token, true, Thread.currentThread().getContextClassLoader()); + } catch (ClassNotFoundException e) { + throw new IllegalArgumentException("Unknown class: " + token, e); + } + } + + // --- + + private void expect(String expected) { + String token = nextToken(); + if (!expected.equals(token)) { + throw unexpected(token); + } + } + + private IllegalArgumentException unexpected(String token) { + if (token.isEmpty()) { + throw new IllegalArgumentException("Unexpected end of input: " + str); + } + return new IllegalArgumentException("Unexpected token '" + token + "' at position " + (pos - token.length()) + + ": " + str); + } + + private String peekToken() { + // skip whitespace + while (pos < str.length() && Character.isWhitespace(str.charAt(pos))) { + pos++; + } + + // end of input + if (pos == str.length()) { + return ""; + } + + int pos = this.pos; + + // current char is a token on its own + if (isSpecial(str.charAt(pos))) { + return str.substring(pos, pos + 1); + } + + // token is a keyword or fully qualified name + int begin = pos; + while (pos < str.length() && Character.isJavaIdentifierStart(str.charAt(pos))) { + do { + pos++; + } while (pos < str.length() && Character.isJavaIdentifierPart(str.charAt(pos))); + + if (pos < str.length() && str.charAt(pos) == '.') { + pos++; + } else { + return str.substring(begin, pos); + } + } + + if (pos == str.length()) { + throw new IllegalArgumentException("Unexpected end of input: " + str); + } + throw new IllegalArgumentException("Unexpected character '" + str.charAt(pos) + "' at position " + pos + ": " + str); + } + + private String nextToken() { + String result = peekToken(); + pos += result.length(); + return result; + } + + private boolean isSpecial(char c) { + return c == ',' || c == '?' || c == '<' || c == '>' || c == '[' || c == ']'; + } +} diff --git a/core/deployment/src/test/java/io/quarkus/deployment/types/TypeParserTest.java b/core/deployment/src/test/java/io/quarkus/deployment/types/TypeParserTest.java new file mode 100644 index 0000000000000..b335c674d8889 --- /dev/null +++ b/core/deployment/src/test/java/io/quarkus/deployment/types/TypeParserTest.java @@ -0,0 +1,157 @@ +package io.quarkus.deployment.types; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import java.lang.reflect.Type; +import java.util.List; +import java.util.Map; + +import jakarta.enterprise.util.TypeLiteral; + +import org.junit.jupiter.api.Test; + +public class TypeParserTest { + @Test + public void testVoid() { + assertCorrect("void", void.class); + assertCorrect(" void", void.class); + assertCorrect("void ", void.class); + assertCorrect(" void ", void.class); + } + + @Test + public void testPrimitive() { + assertCorrect("boolean", boolean.class); + assertCorrect(" byte", byte.class); + assertCorrect("short ", short.class); + assertCorrect(" int ", int.class); + assertCorrect("\tlong", long.class); + assertCorrect("float\t", float.class); + assertCorrect("\tdouble\t", double.class); + assertCorrect(" \n char \n ", char.class); + } + + @Test + public void testPrimitiveArray() { + assertCorrect("boolean[]", boolean[].class); + assertCorrect("byte [][]", byte[][].class); + assertCorrect("short [] [] []", short[][][].class); + assertCorrect("int [ ] [ ] [ ] [ ]", int[][][][].class); + assertCorrect("long [][][]", long[][][].class); + assertCorrect(" float[][]", float[][].class); + assertCorrect(" double [] ", double[].class); + assertCorrect(" char [ ][ ] ", char[][].class); + } + + @Test + public void testClass() { + assertCorrect("java.lang.Object", Object.class); + assertCorrect("java.lang.String", String.class); + + assertCorrect(" java.lang.Boolean", Boolean.class); + assertCorrect("java.lang.Byte ", Byte.class); + assertCorrect(" java.lang.Short ", Short.class); + assertCorrect("\tjava.lang.Integer", Integer.class); + assertCorrect("java.lang.Long\t", Long.class); + assertCorrect("\tjava.lang.Float\t", Float.class); + assertCorrect(" java.lang.Double", Double.class); + assertCorrect("java.lang.Character ", Character.class); + } + + @Test + public void testClassArray() { + assertCorrect("java.lang.Object[]", Object[].class); + assertCorrect("java.lang.String[][]", String[][].class); + + assertCorrect("java.lang.Boolean[][][]", Boolean[][][].class); + assertCorrect("java.lang.Byte[][][][]", Byte[][][][].class); + assertCorrect("java.lang.Short[][][]", Short[][][].class); + assertCorrect("java.lang.Integer[][]", Integer[][].class); + assertCorrect("java.lang.Long[]", Long[].class); + assertCorrect("java.lang.Float[][]", Float[][].class); + assertCorrect("java.lang.Double[][][]", Double[][][].class); + assertCorrect("java.lang.Character[][][][]", Character[][][][].class); + } + + @Test + public void testParameterizedType() { + assertCorrect("java.util.List", new TypeLiteral>() { + }.getType()); + assertCorrect("java.util.Map", new TypeLiteral>() { + }.getType()); + + assertCorrect("java.util.List", new TypeLiteral>() { + }.getType()); + assertCorrect("java.util.Map>", new TypeLiteral>>() { + }.getType()); + } + + @Test + public void testParameterizedTypeArray() { + assertCorrect("java.util.List[]", new TypeLiteral[]>() { + }.getType()); + assertCorrect("java.util.Map[][]", new TypeLiteral[][]>() { + }.getType()); + } + + @Test + public void testIncorrect() { + assertIncorrect(""); + assertIncorrect(" "); + assertIncorrect("\t"); + assertIncorrect(" "); + assertIncorrect(" \n "); + + assertIncorrect("."); + assertIncorrect(","); + assertIncorrect("["); + assertIncorrect("]"); + assertIncorrect("<"); + assertIncorrect(">"); + + assertIncorrect("int."); + assertIncorrect("int,"); + assertIncorrect("int["); + assertIncorrect("int]"); + assertIncorrect("int[[]"); + assertIncorrect("int[]["); + assertIncorrect("int[]]"); + assertIncorrect("int[0]"); + assertIncorrect("int<"); + assertIncorrect("int>"); + assertIncorrect("int<>"); + + assertIncorrect("java.util.List<"); + assertIncorrect("java.util.List<>"); + assertIncorrect("java.util.List>"); + assertIncorrect("java.util.List"); + assertIncorrect("java.util.List>>"); + + assertIncorrect("java.util.List"); + assertIncorrect("java.util.Map"); + + assertIncorrect("java.lang.Integer."); + assertIncorrect("java .lang.Integer"); + assertIncorrect("java. lang.Integer"); + assertIncorrect("java . lang.Integer"); + assertIncorrect(".java.lang.Integer"); + assertIncorrect(".java.lang.Integer."); + + assertIncorrect("java.lang.Integer["); + assertIncorrect("java.lang.Integer[[]"); + assertIncorrect("java.lang.Integer[]["); + assertIncorrect("java.lang.Integer[]]"); + assertIncorrect("java.lang.Integer[0]"); + } + + private void assertCorrect(String str, Type expectedType) { + assertEquals(expectedType, TypeParser.parse(str)); + } + + private void assertIncorrect(String str) { + assertThrows(IllegalArgumentException.class, () -> TypeParser.parse(str)); + } +} diff --git a/core/runtime/src/main/java/io/quarkus/runtime/types/GenericArrayTypeImpl.java b/core/runtime/src/main/java/io/quarkus/runtime/types/GenericArrayTypeImpl.java new file mode 100644 index 0000000000000..be64b08c69622 --- /dev/null +++ b/core/runtime/src/main/java/io/quarkus/runtime/types/GenericArrayTypeImpl.java @@ -0,0 +1,53 @@ +package io.quarkus.runtime.types; + +import java.lang.reflect.GenericArrayType; +import java.lang.reflect.Type; + +/** + * @author Marko Luksa + * @author Jozef Hartinger + */ +public class GenericArrayTypeImpl implements GenericArrayType { + + private Type genericComponentType; + + public GenericArrayTypeImpl(Type genericComponentType) { + this.genericComponentType = genericComponentType; + } + + public GenericArrayTypeImpl(Class rawType, Type... actualTypeArguments) { + this.genericComponentType = new ParameterizedTypeImpl(rawType, actualTypeArguments); + } + + @Override + public Type getGenericComponentType() { + return genericComponentType; + } + + @Override + public int hashCode() { + return ((genericComponentType == null) ? 0 : genericComponentType.hashCode()); + } + + @Override + public boolean equals(Object obj) { + if (obj instanceof GenericArrayType) { + GenericArrayType that = (GenericArrayType) obj; + if (genericComponentType == null) { + return that.getGenericComponentType() == null; + } else { + return genericComponentType.equals(that.getGenericComponentType()); + } + } else { + return false; + } + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(genericComponentType.toString()); + sb.append("[]"); + return sb.toString(); + } +} diff --git a/core/runtime/src/main/java/io/quarkus/runtime/types/ParameterizedTypeImpl.java b/core/runtime/src/main/java/io/quarkus/runtime/types/ParameterizedTypeImpl.java new file mode 100644 index 0000000000000..587d961b9c9b1 --- /dev/null +++ b/core/runtime/src/main/java/io/quarkus/runtime/types/ParameterizedTypeImpl.java @@ -0,0 +1,86 @@ +package io.quarkus.runtime.types; + +import java.io.Serializable; +import java.lang.reflect.ParameterizedType; +import java.lang.reflect.Type; +import java.util.Arrays; + +public class ParameterizedTypeImpl implements ParameterizedType, Serializable { + + private static final long serialVersionUID = -3005183010706452884L; + + private final Type[] actualTypeArguments; + private final Type rawType; + private final Type ownerType; + + public ParameterizedTypeImpl(Type rawType, Type... actualTypeArguments) { + this(rawType, actualTypeArguments, null); + } + + public ParameterizedTypeImpl(Type rawType, Type[] actualTypeArguments, Type ownerType) { + this.actualTypeArguments = actualTypeArguments; + this.rawType = rawType; + this.ownerType = ownerType; + } + + @Override + public Type[] getActualTypeArguments() { + return Arrays.copyOf(actualTypeArguments, actualTypeArguments.length); + } + + @Override + public Type getOwnerType() { + return ownerType; + } + + @Override + public Type getRawType() { + return rawType; + } + + @Override + public int hashCode() { + return Arrays.hashCode(actualTypeArguments) ^ (ownerType == null ? 0 : ownerType.hashCode()) + ^ (rawType == null ? 0 : rawType.hashCode()); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } else if (obj instanceof ParameterizedType that) { + Type thatOwnerType = that.getOwnerType(); + Type thatRawType = that.getRawType(); + return (ownerType == null ? thatOwnerType == null : ownerType.equals(thatOwnerType)) + && (rawType == null ? thatRawType == null : rawType.equals(thatRawType)) + && Arrays.equals(actualTypeArguments, that.getActualTypeArguments()); + } else { + return false; + } + + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + if (rawType instanceof Class) { + sb.append(((Class) rawType).getName()); + } else { + sb.append(rawType); + } + if (actualTypeArguments.length > 0) { + sb.append("<"); + for (Type actualType : actualTypeArguments) { + if (actualType instanceof Class) { + sb.append(((Class) actualType).getName()); + } else { + sb.append(actualType); + } + sb.append(", "); + } + sb.delete(sb.length() - 2, sb.length()); + sb.append(">"); + } + return sb.toString(); + } +} diff --git a/core/runtime/src/main/java/io/quarkus/runtime/types/WildcardTypeImpl.java b/core/runtime/src/main/java/io/quarkus/runtime/types/WildcardTypeImpl.java new file mode 100644 index 0000000000000..1a4ac54beec99 --- /dev/null +++ b/core/runtime/src/main/java/io/quarkus/runtime/types/WildcardTypeImpl.java @@ -0,0 +1,73 @@ +package io.quarkus.runtime.types; + +import java.lang.reflect.Type; +import java.lang.reflect.WildcardType; +import java.util.Arrays; + +/** + * This code was mainly copied from Weld codebase. + * + * Implementation of {@link WildcardType}. + * + * Note that per JLS a wildcard may define either the upper bound or the lower bound. A wildcard may not have multiple bounds. + * + * @author Jozef Hartinger + * + */ +public class WildcardTypeImpl implements WildcardType { + + public static WildcardType defaultInstance() { + return DEFAULT_INSTANCE; + } + + public static WildcardType withUpperBound(Type type) { + return new WildcardTypeImpl(new Type[] { type }, DEFAULT_LOWER_BOUND); + } + + public static WildcardType withLowerBound(Type type) { + return new WildcardTypeImpl(DEFAULT_UPPER_BOUND, new Type[] { type }); + } + + private static final Type[] DEFAULT_UPPER_BOUND = new Type[] { Object.class }; + private static final Type[] DEFAULT_LOWER_BOUND = new Type[0]; + private static final WildcardType DEFAULT_INSTANCE = new WildcardTypeImpl(DEFAULT_UPPER_BOUND, DEFAULT_LOWER_BOUND); + + private final Type[] upperBound; + private final Type[] lowerBound; + + private WildcardTypeImpl(Type[] upperBound, Type[] lowerBound) { + this.upperBound = upperBound; + this.lowerBound = lowerBound; + } + + @Override + public Type[] getUpperBounds() { + return upperBound; + } + + @Override + public Type[] getLowerBounds() { + return lowerBound; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null) { + return false; + } + if (!(obj instanceof WildcardType)) { + return false; + } + WildcardType other = (WildcardType) obj; + return Arrays.equals(lowerBound, other.getLowerBounds()) && Arrays.equals(upperBound, other.getUpperBounds()); + } + + @Override + public int hashCode() { + // We deliberately use the logic from JDK/guava + return Arrays.hashCode(lowerBound) ^ Arrays.hashCode(upperBound); + } +} diff --git a/extensions/redis-cache/deployment/src/main/java/io/quarkus/cache/redis/deployment/RedisCacheProcessor.java b/extensions/redis-cache/deployment/src/main/java/io/quarkus/cache/redis/deployment/RedisCacheProcessor.java index 0dbce4bf6bfd4..a37e01324d1b1 100644 --- a/extensions/redis-cache/deployment/src/main/java/io/quarkus/cache/redis/deployment/RedisCacheProcessor.java +++ b/extensions/redis-cache/deployment/src/main/java/io/quarkus/cache/redis/deployment/RedisCacheProcessor.java @@ -16,6 +16,7 @@ import org.jboss.jandex.AnnotationInstance; import org.jboss.jandex.AnnotationValue; +import org.jboss.jandex.ClassType; import org.jboss.jandex.DotName; import org.jboss.jandex.ParameterizedType; import org.jboss.jandex.Type; @@ -34,6 +35,7 @@ import io.quarkus.deployment.annotations.Record; import io.quarkus.deployment.builditem.CombinedIndexBuildItem; import io.quarkus.deployment.builditem.nativeimage.ReflectiveClassBuildItem; +import io.quarkus.deployment.types.TypeParser; import io.quarkus.redis.deployment.client.RequestedRedisClientBuildItem; import io.quarkus.redis.runtime.client.config.RedisConfig; import io.smallrye.mutiny.Uni; @@ -69,11 +71,25 @@ void nativeImage(BuildProducer producer) { @BuildStep @Record(STATIC_INIT) - void determineValueTypes(RedisCacheBuildRecorder recorder, CombinedIndexBuildItem combinedIndex, + void determineKeyValueTypes(RedisCacheBuildRecorder recorder, CombinedIndexBuildItem combinedIndex, CacheNamesBuildItem cacheNamesBuildItem, RedisCachesBuildTimeConfig buildConfig) { - Map resolvedValuesTypesFromAnnotations = valueTypesFromCacheResultAnnotation(combinedIndex); - Map valueTypes = new HashMap<>(); + Map keyTypes = new HashMap<>(); + RedisCacheBuildTimeConfig defaultBuildTimeConfig = buildConfig.defaultConfig; + for (String cacheName : cacheNamesBuildItem.getNames()) { + RedisCacheBuildTimeConfig namedBuildTimeConfig = buildConfig.cachesConfig.get(cacheName); + + if (namedBuildTimeConfig != null && namedBuildTimeConfig.keyType.isPresent()) { + keyTypes.put(cacheName, TypeParser.parse(namedBuildTimeConfig.keyType.get())); + } else if (defaultBuildTimeConfig.keyType.isPresent()) { + keyTypes.put(cacheName, TypeParser.parse(defaultBuildTimeConfig.keyType.get())); + } + } + recorder.setCacheKeyTypes(keyTypes); + + Map resolvedValuesTypesFromAnnotations = valueTypesFromCacheResultAnnotation(combinedIndex); + + Map valueTypes = new HashMap<>(); Optional defaultValueType = buildConfig.defaultConfig.valueType; Set cacheNames = cacheNamesBuildItem.getNames(); for (String cacheName : cacheNames) { @@ -89,12 +105,13 @@ void determineValueTypes(RedisCacheBuildRecorder recorder, CombinedIndexBuildIte } } - if (valueType == null) { // TODO: does it make sense to use the return type of method annotated with @CacheResult as the last resort or should it override the default cache config? - valueType = resolvedValuesTypesFromAnnotations.get(cacheName); + if (valueType == null && resolvedValuesTypesFromAnnotations.containsKey(cacheName)) { + // TODO: does it make sense to use the return type of method annotated with @CacheResult as the last resort or should it override the default cache config? + valueType = typeToString(resolvedValuesTypesFromAnnotations.get(cacheName)); } if (valueType != null) { - valueTypes.put(cacheName, valueType); + valueTypes.put(cacheName, TypeParser.parse(valueType)); } else { throw new DeploymentException("Unable to determine the value type for '" + cacheName + "' Redis cache. An appropriate configuration value for 'quarkus.cache.redis." + cacheName @@ -104,7 +121,7 @@ void determineValueTypes(RedisCacheBuildRecorder recorder, CombinedIndexBuildIte recorder.setCacheValueTypes(valueTypes); } - private static Map valueTypesFromCacheResultAnnotation(CombinedIndexBuildItem combinedIndex) { + private static Map valueTypesFromCacheResultAnnotation(CombinedIndexBuildItem combinedIndex) { Map> valueTypesFromAnnotations = new HashMap<>(); // first go through @CacheResult instances and simply record the return types @@ -133,7 +150,7 @@ private static Map valueTypesFromCacheResultAnnotation(CombinedI return Collections.emptyMap(); } - Map result = new HashMap<>(); + Map result = new HashMap<>(); // now apply our resolution logic on the obtained types for (var entry : valueTypesFromAnnotations.entrySet()) { @@ -146,17 +163,15 @@ private static Map valueTypesFromCacheResultAnnotation(CombinedI } Type type = typeSet.iterator().next(); - String resolvedType = null; - if (type.kind() == Type.Kind.CLASS) { - resolvedType = type.asClassType().name().toString(); - } else if (type.kind() == Type.Kind.PRIMITIVE) { - resolvedType = type.asPrimitiveType().name().toString(); - } else if ((type.kind() == Type.Kind.PARAMETERIZED_TYPE) && UNI.equals(type.name())) { + Type resolvedType = null; + if (type.kind() == Type.Kind.PARAMETERIZED_TYPE && UNI.equals(type.name())) { ParameterizedType parameterizedType = type.asParameterizedType(); List arguments = parameterizedType.arguments(); if (arguments.size() == 1) { - resolvedType = arguments.get(0).name().toString(); + resolvedType = arguments.get(0); } + } else { + resolvedType = type; } if (resolvedType != null) { @@ -170,4 +185,48 @@ private static Map valueTypesFromCacheResultAnnotation(CombinedI return result; } + + private static String typeToString(Type type) { + StringBuilder result = new StringBuilder(); + typeToString(type, result); + return result.toString(); + } + + private static void typeToString(Type type, StringBuilder result) { + switch (type.kind()) { + case VOID, PRIMITIVE, CLASS -> result.append(type.name().toString()); + case ARRAY -> { + typeToString(type.asArrayType().elementType(), result); + result.append("[]".repeat(type.asArrayType().deepDimensions())); + } + case PARAMETERIZED_TYPE -> { + if (type.asParameterizedType().owner() != null) { + throw new IllegalArgumentException("Unsupported type: " + type); + } + + result.append(type.name().toString()); + result.append('<'); + boolean first = true; + for (Type typeArgument : type.asParameterizedType().arguments()) { + if (!first) { + result.append(", "); + } + typeToString(typeArgument, result); + first = false; + } + result.append('>'); + } + case WILDCARD_TYPE -> { + result.append('?'); + if (type.asWildcardType().superBound() != null) { + result.append(" super "); + typeToString(type.asWildcardType().superBound(), result); + } else if (type.asWildcardType().extendsBound() != ClassType.OBJECT_TYPE) { + result.append(" extends "); + typeToString(type.asWildcardType().extendsBound(), result); + } + } + default -> throw new IllegalArgumentException("Unsupported type: " + type); + } + } } diff --git a/extensions/redis-cache/deployment/src/test/java/io/quarkus/cache/redis/deployment/ComplexCachedService.java b/extensions/redis-cache/deployment/src/test/java/io/quarkus/cache/redis/deployment/ComplexCachedService.java new file mode 100644 index 0000000000000..ce7e73022906a --- /dev/null +++ b/extensions/redis-cache/deployment/src/test/java/io/quarkus/cache/redis/deployment/ComplexCachedService.java @@ -0,0 +1,37 @@ +package io.quarkus.cache.redis.deployment; + +import java.util.List; +import java.util.UUID; +import java.util.concurrent.ThreadLocalRandom; + +import jakarta.enterprise.context.ApplicationScoped; + +import io.quarkus.cache.CacheResult; + +@ApplicationScoped +public class ComplexCachedService { + static final String CACHE_NAME_GENERIC = "test-cache-generic"; + static final String CACHE_NAME_ARRAY = "test-cache-array"; + static final String CACHE_NAME_GENERIC_ARRAY = "test-cache-generic-array"; + + @CacheResult(cacheName = CACHE_NAME_GENERIC) + public List genericReturnType(String key) { + return List.of(UUID.randomUUID().toString(), UUID.randomUUID().toString()); + } + + @CacheResult(cacheName = CACHE_NAME_ARRAY) + public int[] arrayReturnType(String key) { + int[] result = new int[2]; + result[0] = ThreadLocalRandom.current().nextInt(); + result[1] = ThreadLocalRandom.current().nextInt(); + return result; + } + + @CacheResult(cacheName = CACHE_NAME_GENERIC_ARRAY) + public List[] genericArrayReturnType(String key) { + return new List[] { + List.of(UUID.randomUUID().toString(), UUID.randomUUID().toString()), + List.of(UUID.randomUUID().toString(), UUID.randomUUID().toString()) + }; + } +} diff --git a/extensions/redis-cache/deployment/src/test/java/io/quarkus/cache/redis/deployment/ComplexTypesRedisCacheTest.java b/extensions/redis-cache/deployment/src/test/java/io/quarkus/cache/redis/deployment/ComplexTypesRedisCacheTest.java new file mode 100644 index 0000000000000..38c09f22623b1 --- /dev/null +++ b/extensions/redis-cache/deployment/src/test/java/io/quarkus/cache/redis/deployment/ComplexTypesRedisCacheTest.java @@ -0,0 +1,102 @@ +package io.quarkus.cache.redis.deployment; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.util.List; + +import jakarta.inject.Inject; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.arc.Arc; +import io.quarkus.redis.datasource.RedisDataSource; +import io.quarkus.test.QuarkusUnitTest; + +public class ComplexTypesRedisCacheTest { + private static final String KEY_1 = "1"; + private static final String KEY_2 = "2"; + private static final String KEY_3 = "3"; + + @RegisterExtension + static final QuarkusUnitTest TEST = new QuarkusUnitTest() + .withApplicationRoot(jar -> jar.addClasses(ComplexCachedService.class, TestUtil.class)); + + @Inject + ComplexCachedService cachedService; + + @Test + public void testGeneric() { + RedisDataSource redisDataSource = Arc.container().select(RedisDataSource.class).get(); + List allKeysAtStart = TestUtil.allRedisKeys(redisDataSource); + + // STEP 1 + // Action: @CacheResult-annotated method call. + // Expected effect: method invoked and result cached. + // Verified by: STEP 2. + List value1 = cachedService.genericReturnType(KEY_1); + List newKeys = TestUtil.allRedisKeys(redisDataSource); + assertEquals(allKeysAtStart.size() + 1, newKeys.size()); + assertThat(newKeys).contains(expectedCacheKey(ComplexCachedService.CACHE_NAME_GENERIC, KEY_1)); + + // STEP 2 + // Action: same call as STEP 1. + // Expected effect: method not invoked and result coming from the cache. + // Verified by: same object reference between STEPS 1 and 2 results. + List value2 = cachedService.genericReturnType(KEY_1); + assertEquals(value1, value2); + assertEquals(allKeysAtStart.size() + 1, TestUtil.allRedisKeys(redisDataSource).size()); + } + + @Test + public void testArray() { + RedisDataSource redisDataSource = Arc.container().select(RedisDataSource.class).get(); + List allKeysAtStart = TestUtil.allRedisKeys(redisDataSource); + + // STEP 1 + // Action: @CacheResult-annotated method call. + // Expected effect: method invoked and result cached. + // Verified by: STEP 2. + int[] value1 = cachedService.arrayReturnType(KEY_2); + List newKeys = TestUtil.allRedisKeys(redisDataSource); + assertEquals(allKeysAtStart.size() + 1, newKeys.size()); + assertThat(newKeys).contains(expectedCacheKey(ComplexCachedService.CACHE_NAME_ARRAY, KEY_2)); + + // STEP 2 + // Action: same call as STEP 1. + // Expected effect: method not invoked and result coming from the cache. + // Verified by: same object reference between STEPS 1 and 2 results. + int[] value2 = cachedService.arrayReturnType(KEY_2); + assertArrayEquals(value1, value2); + assertEquals(allKeysAtStart.size() + 1, TestUtil.allRedisKeys(redisDataSource).size()); + } + + @Test + public void testGenericArray() { + RedisDataSource redisDataSource = Arc.container().select(RedisDataSource.class).get(); + List allKeysAtStart = TestUtil.allRedisKeys(redisDataSource); + + // STEP 1 + // Action: @CacheResult-annotated method call. + // Expected effect: method invoked and result cached. + // Verified by: STEP 2. + List[] value1 = cachedService.genericArrayReturnType(KEY_3); + List newKeys = TestUtil.allRedisKeys(redisDataSource); + assertEquals(allKeysAtStart.size() + 1, newKeys.size()); + assertThat(newKeys).contains(expectedCacheKey(ComplexCachedService.CACHE_NAME_GENERIC_ARRAY, KEY_3)); + + // STEP 2 + // Action: same call as STEP 1. + // Expected effect: method not invoked and result coming from the cache. + // Verified by: same object reference between STEPS 1 and 2 results. + List[] value2 = cachedService.genericArrayReturnType(KEY_3); + assertArrayEquals(value1, value2); + assertEquals(allKeysAtStart.size() + 1, TestUtil.allRedisKeys(redisDataSource).size()); + } + + private static String expectedCacheKey(String cacheName, String key) { + return "cache:" + cacheName + ":" + key; + } +} diff --git a/extensions/redis-cache/runtime/src/main/java/io/quarkus/cache/redis/runtime/RedisCache.java b/extensions/redis-cache/runtime/src/main/java/io/quarkus/cache/redis/runtime/RedisCache.java index 8127fde03e743..6b9faf1a9f102 100644 --- a/extensions/redis-cache/runtime/src/main/java/io/quarkus/cache/redis/runtime/RedisCache.java +++ b/extensions/redis-cache/runtime/src/main/java/io/quarkus/cache/redis/runtime/RedisCache.java @@ -3,6 +3,8 @@ import java.util.function.Function; import java.util.function.Supplier; +import jakarta.enterprise.util.TypeLiteral; + import io.quarkus.cache.Cache; import io.smallrye.mutiny.Uni; @@ -12,32 +14,17 @@ public interface RedisCache extends Cache { * When configured, gets the default type of the value stored in the cache. * The configured type is used when no type is passed into the {@link #get(Object, Class, Function)}. * - * @return the type, {@code null} if not configured. + * @deprecated should have never been exposed publicly + * @return the type, {@code null} if not configured or if not a {@code Class}. */ + @Deprecated Class getDefaultValueType(); @Override - default Uni get(K key, Function valueLoader) { - Class type = (Class) getDefaultValueType(); - if (type == null) { - throw new UnsupportedOperationException("Cannot use `get` method without a default type configured. " + - "Consider using the `get` method accepting the type or configure the default type for the cache " + - getName()); - } - return get(key, type, valueLoader); - } + Uni get(K key, Function valueLoader); - @SuppressWarnings("unchecked") @Override - default Uni getAsync(K key, Function> valueLoader) { - Class type = (Class) getDefaultValueType(); - if (type == null) { - throw new UnsupportedOperationException("Cannot use `getAsync` method without a default type configured. " + - "Consider using the `getAsync` method accepting the type or configure the default type for the cache " + - getName()); - } - return getAsync(key, type, valueLoader); - } + Uni getAsync(K key, Function> valueLoader); /** * Allows retrieving a value from the Redis cache. @@ -51,6 +38,18 @@ default Uni getAsync(K key, Function> valueLoader) { */ Uni get(K key, Class clazz, Function valueLoader); + /** + * Allows retrieving a value from the Redis cache. + * + * @param key the key + * @param type the type of the value + * @param valueLoader the value loader called when there is no value stored in the cache + * @param the type of key + * @param the type of value + * @return the Uni emitting the cached value. + */ + Uni get(K key, TypeLiteral type, Function valueLoader); + /** * Allows retrieving a value from the Redis cache. * @@ -63,6 +62,18 @@ default Uni getAsync(K key, Function> valueLoader) { */ Uni getAsync(K key, Class clazz, Function> valueLoader); + /** + * Allows retrieving a value from the Redis cache. + * + * @param key the key + * @param type the type of the value + * @param valueLoader the value loader called when there is no value stored in the cache + * @param the type of key + * @param the type of value + * @return the Uni emitting the cached value. + */ + Uni getAsync(K key, TypeLiteral type, Function> valueLoader); + /** * Put a value in the cache. * @@ -86,4 +97,6 @@ public V get() { Uni getOrDefault(K key, V defaultValue); Uni getOrNull(K key, Class clazz); + + Uni getOrNull(K key, TypeLiteral type); } diff --git a/extensions/redis-cache/runtime/src/main/java/io/quarkus/cache/redis/runtime/RedisCacheBuildRecorder.java b/extensions/redis-cache/runtime/src/main/java/io/quarkus/cache/redis/runtime/RedisCacheBuildRecorder.java index a55f128de6fbf..1ec0696eb0e9f 100644 --- a/extensions/redis-cache/runtime/src/main/java/io/quarkus/cache/redis/runtime/RedisCacheBuildRecorder.java +++ b/extensions/redis-cache/runtime/src/main/java/io/quarkus/cache/redis/runtime/RedisCacheBuildRecorder.java @@ -1,5 +1,6 @@ package io.quarkus.cache.redis.runtime; +import java.lang.reflect.Type; import java.util.*; import java.util.function.Supplier; @@ -20,7 +21,8 @@ public class RedisCacheBuildRecorder { private final RedisCachesBuildTimeConfig buildConfig; private final RuntimeValue redisCacheConfigRV; - private static Map valueTypes; + private static Map keyTypes; + private static Map valueTypes; public RedisCacheBuildRecorder(RedisCachesBuildTimeConfig buildConfig, RuntimeValue redisCacheConfigRV) { this.buildConfig = buildConfig; @@ -35,13 +37,12 @@ public boolean supports(Context context) { } @Override - @SuppressWarnings({ "rawtypes", "unchecked" }) public Supplier get(Context context) { return new Supplier() { @Override public CacheManager get() { - Set cacheInfos = RedisCacheInfoBuilder.build(context.cacheNames(), buildConfig, - redisCacheConfigRV.getValue(), valueTypes); + Set cacheInfos = RedisCacheInfoBuilder.build(context.cacheNames(), + redisCacheConfigRV.getValue(), keyTypes, valueTypes); if (cacheInfos.isEmpty()) { return new CacheManagerImpl(Collections.emptyMap()); } else { @@ -66,7 +67,11 @@ public CacheManager get() { }; } - public void setCacheValueTypes(Map valueTypes) { + public void setCacheKeyTypes(Map keyTypes) { + RedisCacheBuildRecorder.keyTypes = keyTypes; + } + + public void setCacheValueTypes(Map valueTypes) { RedisCacheBuildRecorder.valueTypes = valueTypes; } } diff --git a/extensions/redis-cache/runtime/src/main/java/io/quarkus/cache/redis/runtime/RedisCacheImpl.java b/extensions/redis-cache/runtime/src/main/java/io/quarkus/cache/redis/runtime/RedisCacheImpl.java index d0be182e72ad0..70d7764160c6d 100644 --- a/extensions/redis-cache/runtime/src/main/java/io/quarkus/cache/redis/runtime/RedisCacheImpl.java +++ b/extensions/redis-cache/runtime/src/main/java/io/quarkus/cache/redis/runtime/RedisCacheImpl.java @@ -1,16 +1,18 @@ package io.quarkus.cache.redis.runtime; +import java.lang.reflect.Type; import java.net.ConnectException; import java.nio.charset.StandardCharsets; import java.time.Duration; import java.util.List; -import java.util.Map; import java.util.Objects; import java.util.Optional; import java.util.function.Function; import java.util.function.Predicate; import java.util.function.Supplier; +import jakarta.enterprise.util.TypeLiteral; + import org.jboss.logging.Logger; import io.quarkus.arc.Arc; @@ -41,22 +43,12 @@ public class RedisCacheImpl extends AbstractCache implements RedisCache { private static final Logger log = Logger.getLogger(RedisCacheImpl.class); - private static final Map> PRIMITIVE_TO_CLASS_MAPPING = Map.of( - "int", Integer.class, - "byte", Byte.class, - "char", Character.class, - "short", Short.class, - "long", Long.class, - "float", Float.class, - "double", Double.class, - "boolean", Boolean.class); - private final Vertx vertx; private final Redis redis; private final RedisCacheInfo cacheInfo; - private final Class classOfValue; - private final Class classOfKey; + private final Type classOfValue; + private final Type classOfKey; private final Marshaller marshaller; @@ -82,18 +74,10 @@ public RedisCacheImpl(RedisCacheInfo cacheInfo, Vertx vertx, Redis redis, Suppli this.cacheInfo = cacheInfo; this.blockingAllowedSupplier = blockingAllowedSupplier; - try { - this.classOfKey = loadClass(this.cacheInfo.keyType); - } catch (ClassNotFoundException e) { - throw new IllegalArgumentException("Unable to load the class " + this.cacheInfo.keyType, e); - } + this.classOfKey = this.cacheInfo.keyType; if (this.cacheInfo.valueType != null) { - try { - this.classOfValue = loadClass(this.cacheInfo.valueType); - } catch (ClassNotFoundException e) { - throw new IllegalArgumentException("Unable to load the class " + this.cacheInfo.valueType, e); - } + this.classOfValue = this.cacheInfo.valueType; this.marshaller = new Marshaller(this.classOfValue, this.classOfKey); } else { this.classOfValue = null; @@ -108,13 +92,6 @@ private static boolean isRecomputableError(Throwable error) { || error instanceof ConnectionPoolTooBusyException; } - private Class loadClass(String type) throws ClassNotFoundException { - if (PRIMITIVE_TO_CLASS_MAPPING.containsKey(type)) { - return PRIMITIVE_TO_CLASS_MAPPING.get(type); - } - return Thread.currentThread().getContextClassLoader().loadClass(type); - } - @Override public String getName() { return Objects.requireNonNullElse(cacheInfo.name, "default-redis-cache"); @@ -127,7 +104,7 @@ public Object getDefaultKey() { @Override public Class getDefaultValueType() { - return classOfValue; + return classOfValue instanceof Class ? (Class) classOfValue : null; } private String encodeKey(K key) { @@ -147,8 +124,27 @@ public V get() { } } + @Override + public Uni get(K key, Function valueLoader) { + if (classOfValue == null) { + throw new UnsupportedOperationException("Cannot use `get` method without a default type configured. " + + "Consider using the `get` method accepting the type or configure the default type for the cache " + + getName()); + } + return get(key, classOfValue, valueLoader); + } + @Override public Uni get(K key, Class clazz, Function valueLoader) { + return get(key, (Type) clazz, valueLoader); + } + + @Override + public Uni get(K key, TypeLiteral type, Function valueLoader) { + return get(key, type.getType(), valueLoader); + } + + private Uni get(K key, Type type, Function valueLoader) { // With optimistic locking: // WATCH K // val = deserialize(GET K) @@ -171,9 +167,9 @@ public Uni apply(RedisConnection connection) { Uni startingPoint; if (cacheInfo.useOptimisticLocking) { startingPoint = watch(connection, encodedKey) - .chain(new GetFromConnectionSupplier<>(connection, clazz, encodedKey, marshaller)); + .chain(new GetFromConnectionSupplier<>(connection, type, encodedKey, marshaller)); } else { - startingPoint = new GetFromConnectionSupplier<>(connection, clazz, encodedKey, marshaller).get(); + startingPoint = new GetFromConnectionSupplier(connection, type, encodedKey, marshaller).get(); } return startingPoint @@ -226,8 +222,27 @@ public Uni apply(Throwable e) { }); } + @Override + public Uni getAsync(K key, Function> valueLoader) { + if (classOfValue == null) { + throw new UnsupportedOperationException("Cannot use `getAsync` method without a default type configured. " + + "Consider using the `getAsync` method accepting the type or configure the default type for the cache " + + getName()); + } + return getAsync(key, classOfValue, valueLoader); + } + @Override public Uni getAsync(K key, Class clazz, Function> valueLoader) { + return getAsync(key, (Type) clazz, valueLoader); + } + + @Override + public Uni getAsync(K key, TypeLiteral type, Function> valueLoader) { + return getAsync(key, type.getType(), valueLoader); + } + + private Uni getAsync(K key, Type type, Function> valueLoader) { byte[] encodedKey = marshaller.encode(computeActualKey(encodeKey(key))); return withConnection(new Function>() { @Override @@ -235,9 +250,9 @@ public Uni apply(RedisConnection connection) { Uni startingPoint; if (cacheInfo.useOptimisticLocking) { startingPoint = watch(connection, encodedKey) - .chain(new GetFromConnectionSupplier<>(connection, clazz, encodedKey, marshaller)); + .chain(new GetFromConnectionSupplier<>(connection, type, encodedKey, marshaller)); } else { - startingPoint = new GetFromConnectionSupplier<>(connection, clazz, encodedKey, marshaller).get(); + startingPoint = new GetFromConnectionSupplier(connection, type, encodedKey, marshaller).get(); } return startingPoint @@ -296,7 +311,6 @@ private void enforceDefaultType() { } } - @SuppressWarnings("unchecked") @Override public Uni getOrDefault(K key, V defaultValue) { enforceDefaultType(); @@ -304,20 +318,29 @@ public Uni getOrDefault(K key, V defaultValue) { return withConnection(new Function>() { @Override public Uni apply(RedisConnection redisConnection) { - return (Uni) doGet(redisConnection, encodedKey, classOfValue, marshaller); + return doGet(redisConnection, encodedKey, classOfValue, marshaller); } }).onItem().ifNull().continueWith(new StaticSupplier<>(defaultValue)); } @Override - @SuppressWarnings("unchecked") public Uni getOrNull(K key, Class clazz) { + return getOrNull(key, (Type) clazz); + } + + @Override + public Uni getOrNull(K key, TypeLiteral type) { + return getOrNull(key, type.getType()); + } + + private Uni getOrNull(K key, Type type) { enforceDefaultType(); byte[] encodedKey = marshaller.encode(computeActualKey(encodeKey(key))); return withConnection(new Function>() { @Override public Uni apply(RedisConnection redisConnection) { - return (Uni) doGet(redisConnection, encodedKey, classOfValue, marshaller); + // TODO maybe use `type` (if non-null?) instead of `classOfValue`? + return doGet(redisConnection, encodedKey, classOfValue, marshaller); } }); } @@ -408,7 +431,7 @@ private Uni watch(RedisConnection connection, byte[] keyToWatch) { .replaceWithVoid(); } - private Uni doGet(RedisConnection connection, byte[] encoded, Class clazz, + private Uni doGet(RedisConnection connection, byte[] encoded, Type clazz, Marshaller marshaller) { if (cacheInfo.expireAfterAccess.isPresent()) { Duration duration = cacheInfo.expireAfterAccess.get(); @@ -470,11 +493,11 @@ public V get() { private class GetFromConnectionSupplier implements Supplier> { private final RedisConnection connection; - private final Class clazz; + private final Type clazz; private final byte[] encodedKey; private final Marshaller marshaller; - public GetFromConnectionSupplier(RedisConnection connection, Class clazz, byte[] encodedKey, Marshaller marshaller) { + public GetFromConnectionSupplier(RedisConnection connection, Type clazz, byte[] encodedKey, Marshaller marshaller) { this.connection = connection; this.clazz = clazz; this.encodedKey = encodedKey; diff --git a/extensions/redis-cache/runtime/src/main/java/io/quarkus/cache/redis/runtime/RedisCacheInfo.java b/extensions/redis-cache/runtime/src/main/java/io/quarkus/cache/redis/runtime/RedisCacheInfo.java index b6408ab061e4e..e47448ccd2fbf 100644 --- a/extensions/redis-cache/runtime/src/main/java/io/quarkus/cache/redis/runtime/RedisCacheInfo.java +++ b/extensions/redis-cache/runtime/src/main/java/io/quarkus/cache/redis/runtime/RedisCacheInfo.java @@ -1,5 +1,6 @@ package io.quarkus.cache.redis.runtime; +import java.lang.reflect.Type; import java.time.Duration; import java.util.Optional; @@ -29,12 +30,12 @@ public class RedisCacheInfo { /** * The default type of the value stored in the cache. */ - public String valueType; + public Type valueType; /** * The key type, {@code String} by default. */ - public String keyType = String.class.getName(); + public Type keyType = String.class; /** * Whether the access to the cache should be using optimistic locking diff --git a/extensions/redis-cache/runtime/src/main/java/io/quarkus/cache/redis/runtime/RedisCacheInfoBuilder.java b/extensions/redis-cache/runtime/src/main/java/io/quarkus/cache/redis/runtime/RedisCacheInfoBuilder.java index 7fd99591c1ba9..e457d348bcdec 100644 --- a/extensions/redis-cache/runtime/src/main/java/io/quarkus/cache/redis/runtime/RedisCacheInfoBuilder.java +++ b/extensions/redis-cache/runtime/src/main/java/io/quarkus/cache/redis/runtime/RedisCacheInfoBuilder.java @@ -1,5 +1,6 @@ package io.quarkus.cache.redis.runtime; +import java.lang.reflect.Type; import java.util.Collections; import java.util.Map; import java.util.Set; @@ -8,8 +9,8 @@ public class RedisCacheInfoBuilder { - public static Set build(Set cacheNames, RedisCachesBuildTimeConfig buildTimeConfig, - RedisCachesConfig runtimeConfig, Map valueTypes) { + public static Set build(Set cacheNames, RedisCachesConfig runtimeConfig, + Map keyTypes, Map valueTypes) { if (cacheNames.isEmpty()) { return Collections.emptySet(); } else { @@ -50,14 +51,8 @@ public static Set build(Set cacheNames, RedisCachesBuild cacheInfo.valueType = valueTypes.get(cacheName); - RedisCacheBuildTimeConfig defaultBuildTimeConfig = buildTimeConfig.defaultConfig; - RedisCacheBuildTimeConfig namedBuildTimeConfig = buildTimeConfig.cachesConfig - .get(cacheInfo.name); - - if (namedBuildTimeConfig != null && namedBuildTimeConfig.keyType.isPresent()) { - cacheInfo.keyType = namedBuildTimeConfig.keyType.get(); - } else if (defaultBuildTimeConfig.keyType.isPresent()) { - cacheInfo.keyType = defaultBuildTimeConfig.keyType.get(); + if (keyTypes.containsKey(cacheName)) { + cacheInfo.keyType = keyTypes.get(cacheName); } if (namedRuntimeConfig != null && namedRuntimeConfig.useOptimisticLocking.isPresent()) { diff --git a/extensions/redis-cache/runtime/src/test/java/io/quarkus/cache/redis/runtime/RedisCacheImplTest.java b/extensions/redis-cache/runtime/src/test/java/io/quarkus/cache/redis/runtime/RedisCacheImplTest.java index adf1bf53a1d0b..3362bf8e7bc9b 100644 --- a/extensions/redis-cache/runtime/src/test/java/io/quarkus/cache/redis/runtime/RedisCacheImplTest.java +++ b/extensions/redis-cache/runtime/src/test/java/io/quarkus/cache/redis/runtime/RedisCacheImplTest.java @@ -46,7 +46,7 @@ public void testPutInTheCache() { String k = UUID.randomUUID().toString(); RedisCacheInfo info = new RedisCacheInfo(); info.name = "foo"; - info.valueType = String.class.getName(); + info.valueType = String.class; info.expireAfterWrite = Optional.of(Duration.ofSeconds(2)); RedisCacheImpl cache = new RedisCacheImpl(info, vertx, redis, BLOCKING_ALLOWED); assertThat(cache.get(k, s -> "hello").await().indefinitely()).isEqualTo("hello"); @@ -59,7 +59,7 @@ public void testExhaustConnectionPool() { String k = UUID.randomUUID().toString(); RedisCacheInfo info = new RedisCacheInfo(); info.name = "foo"; - info.valueType = String.class.getName(); + info.valueType = String.class; info.expireAfterWrite = Optional.of(Duration.ofSeconds(2)); Redis redis = Redis.createClient(vertx, new RedisOptions() @@ -86,7 +86,7 @@ public void testPutInTheCacheWithoutRedis() { String k = UUID.randomUUID().toString(); RedisCacheInfo info = new RedisCacheInfo(); info.name = "foo"; - info.valueType = String.class.getName(); + info.valueType = String.class; info.expireAfterWrite = Optional.of(Duration.ofSeconds(2)); RedisCacheImpl cache = new RedisCacheImpl(info, vertx, redis, BLOCKING_ALLOWED); server.close(); @@ -98,7 +98,7 @@ public void testPutInTheCacheWithOptimisticLocking() { String k = UUID.randomUUID().toString(); RedisCacheInfo info = new RedisCacheInfo(); info.name = "foo"; - info.valueType = String.class.getName(); + info.valueType = String.class; info.expireAfterWrite = Optional.of(Duration.ofSeconds(2)); info.useOptimisticLocking = true; RedisCacheImpl cache = new RedisCacheImpl(info, vertx, redis, BLOCKING_ALLOWED); @@ -111,7 +111,7 @@ public void testPutInTheCacheWithOptimisticLocking() { public void testPutAndWaitForInvalidation() { String k = UUID.randomUUID().toString(); RedisCacheInfo info = new RedisCacheInfo(); - info.valueType = String.class.getName(); + info.valueType = String.class; info.expireAfterWrite = Optional.of(Duration.ofSeconds(1)); RedisCacheImpl cache = new RedisCacheImpl(info, vertx, redis, BLOCKING_ALLOWED); assertThat(cache.get(k, s -> "hello").await().indefinitely()).isEqualTo("hello"); @@ -125,7 +125,7 @@ public void testExpireAfterReadAndWrite() throws InterruptedException { String k = UUID.randomUUID().toString(); RedisCacheInfo info = new RedisCacheInfo(); info.name = "foo"; - info.valueType = String.class.getName(); + info.valueType = String.class; info.expireAfterWrite = Optional.of(Duration.ofSeconds(1)); info.expireAfterAccess = Optional.of(Duration.ofSeconds(1)); RedisCacheImpl cache = new RedisCacheImpl(info, vertx, redis, BLOCKING_ALLOWED); @@ -144,7 +144,7 @@ public void testExpireAfterReadAndWrite() throws InterruptedException { @Test public void testManualInvalidation() { RedisCacheInfo info = new RedisCacheInfo(); - info.valueType = String.class.getName(); + info.valueType = String.class; info.expireAfterWrite = Optional.of(Duration.ofSeconds(10)); RedisCacheImpl cache = new RedisCacheImpl(info, vertx, redis, BLOCKING_ALLOWED); cache.get("foo", s -> "hello").await().indefinitely(); @@ -188,7 +188,7 @@ public int hashCode() { public void testGetOrNull() { RedisCacheInfo info = new RedisCacheInfo(); info.expireAfterWrite = Optional.of(Duration.ofSeconds(10)); - info.valueType = Person.class.getName(); + info.valueType = Person.class; RedisCacheImpl cache = new RedisCacheImpl(info, vertx, redis, BLOCKING_ALLOWED); Person person = cache.getOrNull("foo", Person.class).await().indefinitely(); assertThat(person).isNull(); @@ -208,7 +208,7 @@ public void testGetOrNull() { public void testGetOrDefault() { RedisCacheInfo info = new RedisCacheInfo(); info.expireAfterWrite = Optional.of(Duration.ofSeconds(10)); - info.valueType = Person.class.getName(); + info.valueType = Person.class; RedisCacheImpl cache = new RedisCacheImpl(info, vertx, redis, BLOCKING_ALLOWED); Person person = cache.getOrDefault("foo", new Person("bar", "BAR")).await().indefinitely(); assertThat(person).isNotNull() @@ -234,7 +234,7 @@ public void testGetOrDefault() { public void testCacheNullValue() { RedisCacheInfo info = new RedisCacheInfo(); info.expireAfterWrite = Optional.of(Duration.ofSeconds(10)); - info.valueType = Person.class.getName(); + info.valueType = Person.class; RedisCacheImpl cache = new RedisCacheImpl(info, vertx, redis, BLOCKING_ALLOWED); // with custom key @@ -248,8 +248,8 @@ public void testCacheNullValue() { public void testExceptionInValueLoader() { RedisCacheInfo info = new RedisCacheInfo(); info.expireAfterWrite = Optional.of(Duration.ofSeconds(10)); - info.valueType = Person.class.getName(); - info.keyType = Double.class.getName(); + info.valueType = Person.class; + info.keyType = Double.class; info.name = "foo"; RedisCacheImpl cache = new RedisCacheImpl(info, vertx, redis, BLOCKING_ALLOWED); @@ -271,8 +271,8 @@ public void testExceptionInValueLoader() { public void testPutShouldPopulateCache() { RedisCacheInfo info = new RedisCacheInfo(); info.expireAfterWrite = Optional.of(Duration.ofSeconds(10)); - info.valueType = Person.class.getName(); - info.keyType = Integer.class.getName(); + info.valueType = Person.class; + info.keyType = Integer.class; RedisCacheImpl cache = new RedisCacheImpl(info, vertx, redis, BLOCKING_ALLOWED); cache.put(1, new Person("luke", "skywalker")).await().indefinitely(); @@ -287,8 +287,8 @@ public void testPutShouldPopulateCache() { public void testPutShouldPopulateCacheWithOptimisticLocking() { RedisCacheInfo info = new RedisCacheInfo(); info.expireAfterWrite = Optional.of(Duration.ofSeconds(10)); - info.valueType = Person.class.getName(); - info.keyType = Integer.class.getName(); + info.valueType = Person.class; + info.keyType = Integer.class; info.useOptimisticLocking = true; RedisCacheImpl cache = new RedisCacheImpl(info, vertx, redis, BLOCKING_ALLOWED); @@ -305,7 +305,7 @@ public void testThatConnectionsAreRecycled() { String k = UUID.randomUUID().toString(); RedisCacheInfo info = new RedisCacheInfo(); info.name = "foo"; - info.valueType = String.class.getName(); + info.valueType = String.class; info.expireAfterWrite = Optional.of(Duration.ofSeconds(1)); RedisCacheImpl cache = new RedisCacheImpl(info, vertx, redis, BLOCKING_ALLOWED); @@ -330,7 +330,7 @@ public void testThatConnectionsAreRecycledWithOptimisticLocking() { String k = UUID.randomUUID().toString(); RedisCacheInfo info = new RedisCacheInfo(); info.name = "foo"; - info.valueType = String.class.getName(); + info.valueType = String.class; info.expireAfterWrite = Optional.of(Duration.ofSeconds(1)); info.useOptimisticLocking = true; RedisCacheImpl cache = new RedisCacheImpl(info, vertx, redis, BLOCKING_ALLOWED); @@ -378,7 +378,7 @@ void testAsyncGetWithDefaultType() { RedisCacheInfo info = new RedisCacheInfo(); info.name = "star-wars"; info.expireAfterWrite = Optional.of(Duration.ofSeconds(2)); - info.valueType = Person.class.getName(); + info.valueType = Person.class; RedisCacheImpl cache = new RedisCacheImpl(info, vertx, redis, BLOCKING_ALLOWED); assertThat(cache @@ -408,7 +408,7 @@ void testAsyncGetWithDefaultTypeWithoutRedis() { RedisCacheInfo info = new RedisCacheInfo(); info.name = "star-wars"; info.expireAfterWrite = Optional.of(Duration.ofSeconds(2)); - info.valueType = Person.class.getName(); + info.valueType = Person.class; RedisCacheImpl cache = new RedisCacheImpl(info, vertx, redis, BLOCKING_ALLOWED); server.close(); @@ -434,7 +434,7 @@ void testAsyncGetWithDefaultTypeWithOptimisticLocking() { RedisCacheInfo info = new RedisCacheInfo(); info.name = "star-wars"; info.expireAfterWrite = Optional.of(Duration.ofSeconds(2)); - info.valueType = Person.class.getName(); + info.valueType = Person.class; info.useOptimisticLocking = true; RedisCacheImpl cache = new RedisCacheImpl(info, vertx, redis, BLOCKING_ALLOWED); @@ -465,7 +465,7 @@ void testPut() { RedisCacheInfo info = new RedisCacheInfo(); info.name = "put"; info.expireAfterWrite = Optional.of(Duration.ofSeconds(2)); - info.valueType = Person.class.getName(); + info.valueType = Person.class; RedisCacheImpl cache = new RedisCacheImpl(info, vertx, redis, BLOCKING_ALLOWED); Person luke = new Person("luke", "skywalker"); @@ -484,7 +484,7 @@ void testPutWithSupplier() { RedisCacheInfo info = new RedisCacheInfo(); info.name = "put"; info.expireAfterWrite = Optional.of(Duration.ofSeconds(2)); - info.valueType = Person.class.getName(); + info.valueType = Person.class; RedisCacheImpl cache = new RedisCacheImpl(info, vertx, redis, BLOCKING_ALLOWED); Person luke = new Person("luke", "skywalker"); @@ -498,17 +498,6 @@ void testPutWithSupplier() { .await().indefinitely()).isEqualTo(leia)); } - @Test - void testInitializationWithAnUnknownClass() { - RedisCacheInfo info = new RedisCacheInfo(); - info.name = "put"; - info.expireAfterWrite = Optional.of(Duration.ofSeconds(2)); - info.valueType = Person.class.getPackage().getName() + ".Missing"; - - assertThatThrownBy(() -> new RedisCacheImpl(info, vertx, redis, BLOCKING_ALLOWED)) - .isInstanceOf(IllegalArgumentException.class); - } - @Test void testGetDefaultKey() { RedisCacheInfo info = new RedisCacheInfo();