Skip to content

Commit

Permalink
fix: fallback to context classloader when loading class
Browse files Browse the repository at this point in the history
Currently classes are loaded with Class.forName, that uses the classloader from
the calling class. However, this might not work on all environments because of
different class loading setup (e.g in Quarkus testing).
This change uses the current thread context class loader as fallback when
Class.forName is not able to load a class.

Fixes vaadin/quarkus#139
Part of #1655
  • Loading branch information
mcollovati committed Jan 25, 2024
1 parent 82b0761 commit b7a0f6d
Show file tree
Hide file tree
Showing 8 changed files with 76 additions and 54 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import com.vaadin.testbench.unit.internal.MockVaadin;
import com.vaadin.testbench.unit.internal.Routes;
import com.vaadin.testbench.unit.internal.ShortcutsKt;
import com.vaadin.testbench.unit.internal.UtilsKt;
import com.vaadin.testbench.unit.mocks.MockedUI;

/**
Expand Down Expand Up @@ -93,8 +94,8 @@ private static Map<Class<?>, Class<? extends ComponentTester>> scanForTesters(
.extendsSuperclass(ComponentTester.class))
.forEach(classInfo -> {
try {
final Class<?> tester = Class
.forName(classInfo.getName());
final Class<?> tester = UtilsKt
.findClassOrThrow(classInfo.getName());
final Class<? extends Component>[] annotation = tester
.getAnnotation(Tests.class).value();
for (Class<? extends Component> component : annotation) {
Expand All @@ -108,7 +109,7 @@ private static Map<Class<?>, Class<? extends ComponentTester>> scanForTesters(

Arrays.stream(classes).map(clazz -> {
try {
return Class.forName(clazz);
return UtilsKt.findClassOrThrow(clazz);
} catch (ClassNotFoundException e) {
throw new RuntimeException(e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import com.vaadin.flow.server.VaadinServletService;
import com.vaadin.flow.spring.SpringServlet;
import com.vaadin.testbench.unit.internal.Routes;
import com.vaadin.testbench.unit.internal.UtilsKt;

/**
* Makes sure that the {@link #routes} are properly registered, and that
Expand Down Expand Up @@ -142,19 +143,13 @@ private static Authentication authentication() {
}

private static boolean hasSpringSecurity() {
try {
Class.forName(
"org.springframework.security.core.context.SecurityContextHolder");
return true;
} catch (ClassNotFoundException e) {
// Ignore error
}
return false;
return UtilsKt.findClass(
"org.springframework.security.core.context.SecurityContextHolder") != null;
}

private static UnaryOperator<HttpServletRequest> springSecurityRequestWrapper() {
try {
Constructor<?> constructor = Class.forName(
Constructor<?> constructor = UtilsKt.findClassOrThrow(
"org.springframework.security.web.servletapi.SecurityContextHolderAwareRequestWrapper")
.getConstructor(HttpServletRequest.class, String.class);
return req -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ import java.util.stream.Stream
import kotlin.reflect.KClass
import kotlin.reflect.KProperty1
import kotlin.streams.toList
import com.vaadin.testbench.unit.internal.findClassOrThrow

/**
* Returns the item on given row. Fails if the row index is invalid. The data provider is
Expand Down Expand Up @@ -483,8 +484,8 @@ public fun Grid<*>.expectRow(rowIndex: Int, vararg row: String) {
internal val HeaderRow.HeaderCell.column: Any
get() = _AbstractCell_getColumn.invoke(this)

private val abstractCellClass: Class<*> = Class.forName("com.vaadin.flow.component.grid.AbstractRow\$AbstractCell")
private val abstractColumnClass: Class<*> = Class.forName("com.vaadin.flow.component.grid.AbstractColumn")
private val abstractCellClass: Class<*> = findClassOrThrow("com.vaadin.flow.component.grid.AbstractRow\$AbstractCell")
private val abstractColumnClass: Class<*> = findClassOrThrow("com.vaadin.flow.component.grid.AbstractColumn")
private val _AbstractCell_getColumn: Method by lazy(LazyThreadSafetyMode.PUBLICATION) {
val m: Method = abstractCellClass.getDeclaredMethod("getColumn")
m.isAccessible = true
Expand All @@ -495,7 +496,7 @@ internal val <T> ColumnPathRenderer<T>.provider: ValueProvider<T, *>
get() = _ColumnPathRenderer_provider.get(this) as (ValueProvider<T,*>)

private val _ColumnPathRenderer_provider: Field by lazy(LazyThreadSafetyMode.PUBLICATION) {
val f: Field = Class.forName("com.vaadin.flow.component.grid.ColumnPathRenderer").getDeclaredField("provider")
val f: Field = findClassOrThrow("com.vaadin.flow.component.grid.ColumnPathRenderer").getDeclaredField("provider")
f.isAccessible = true
f
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -305,11 +305,7 @@ val FormLayout.FormItem.label: String get() {
* The `HasLabel` interface has been introduced in Vaadin 21 but is missing in Vaadin 14.
* Use reflection.
*/
private val _HasLabel: Class<*>? = try {
Class.forName("com.vaadin.flow.component.HasLabel")
} catch (ex: ClassNotFoundException) {
null
}
private val _HasLabel: Class<*>? = findClass("com.vaadin.flow.component.HasLabel")
private val _HasLabel_getLabel: Method? = _HasLabel?.getDeclaredMethod("getLabel")
private val _HasLabel_setLabel: Method? = _HasLabel?.getDeclaredMethod("setLabel", String::class.java)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,10 @@ data class Routes(
.acceptPackages(*(packageNames.map { it ?: "" }.toTypedArray()))
classGraph.scan().use { scanResult: ScanResult ->
scanResult.getClassesWithAnnotation(Route::class.java.name).mapTo(routes) { info: ClassInfo ->
Class.forName(info.name).asSubclass(Component::class.java)
findClassOrThrow(info.name).asSubclass(Component::class.java)
}
scanResult.getClassesImplementing(HasErrorParameter::class.java.name).mapTo(errorRoutes) { info: ClassInfo ->
Class.forName(info.name).asSubclass(HasErrorParameter::class.java)
findClassOrThrow(info.name).asSubclass(HasErrorParameter::class.java)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,11 +132,8 @@ val Component.isTemplate: Boolean
*/
var testingLifecycleHook: TestingLifecycleHook = TestingLifecycleHook.default

private val _ConfirmDialog_Class: Class<*>? = try {
Class.forName("com.vaadin.flow.component.confirmdialog.ConfirmDialog")
} catch (e: ClassNotFoundException) {
null
}
private val _ConfirmDialog_Class: Class<*>? = findClass("com.vaadin.flow.component.confirmdialog.ConfirmDialog")

private val _ConfirmDialog_isOpened: Method? =
_ConfirmDialog_Class?.getMethod("isOpened")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@ import java.io.ByteArrayOutputStream
import java.io.ObjectInputStream
import java.io.ObjectOutputStream
import java.io.Serializable
import jakarta.servlet.Servlet
import jakarta.servlet.ServletContext
import kotlin.jvm.Throws
import com.vaadin.flow.component.UI
import com.vaadin.flow.internal.ReflectTools
import com.vaadin.flow.router.HasErrorParameter
Expand All @@ -35,11 +34,19 @@ import elemental.json.Json
import elemental.json.JsonArray
import elemental.json.JsonValue
import elemental.json.impl.JreJsonValue
import jakarta.servlet.Servlet
import jakarta.servlet.ServletContext


fun Serializable.serializeToBytes(): ByteArray =
ByteArrayOutputStream().use { ObjectOutputStream(it).writeObject(this); it }
.toByteArray()

inline fun <reified T : Serializable> ByteArray.deserialize(): T =
ObjectInputStream(inputStream()).readObject() as T

fun Serializable.serializeToBytes(): ByteArray = ByteArrayOutputStream().use { ObjectOutputStream(it).writeObject(this); it }.toByteArray()
inline fun <reified T: Serializable> ByteArray.deserialize(): T = ObjectInputStream(inputStream()).readObject() as T
inline fun <reified T: Serializable> T.serializeDeserialize(): T = serializeToBytes().deserialize<T>()
inline fun <reified T : Serializable> T.serializeDeserialize(): T =
serializeToBytes().deserialize<T>()

val IntRange.size: Int get() = (endInclusive + 1 - start).coerceAtLeast(0)

Expand Down Expand Up @@ -88,7 +95,7 @@ fun List<JsonValue>.unwrap(): List<Any?> =
/**
* Unwraps this value into corresponding Java type. Unwraps arrays recursively.
*/
fun JsonValue.unwrap(): Any? = when(this) {
fun JsonValue.unwrap(): Any? = when (this) {
is JsonArray -> this.toList().unwrap()
else -> (this as JreJsonValue).`object`
}
Expand All @@ -109,9 +116,14 @@ internal fun String.parseJvmVersion(): Int {
}

private val regexWhitespace = Regex("\\s+")
internal fun String.splitByWhitespaces(): List<String> = split(regexWhitespace).filterNot { it.isBlank() }
internal fun String.splitByWhitespaces(): List<String> =
split(regexWhitespace).filterNot { it.isBlank() }

internal fun String.containsWhitespace(): Boolean = any { it.isWhitespace() }
internal fun String.ellipsize(maxLength: Int, ellipsize: String = "..."): String {
internal fun String.ellipsize(
maxLength: Int,
ellipsize: String = "..."
): String {
require(maxLength >= ellipsize.length) { "maxLength must be at least the size of ellipsize $ellipsize but it was $maxLength" }
return when {
(length <= maxLength) || (length <= ellipsize.length) -> this
Expand All @@ -125,24 +137,24 @@ internal fun String.ellipsize(maxLength: Int, ellipsize: String = "..."): String
* [HasErrorParameter] interface.
*/
internal fun Class<*>.getErrorParameterType(): Class<*>? =
ReflectTools.getGenericInterfaceType(this, HasErrorParameter::class.java)
ReflectTools.getGenericInterfaceType(this, HasErrorParameter::class.java)

internal val Class<*>.isRouteNotFound: Boolean
get() = getErrorParameterType() == NotFoundException::class.java

val currentRequest: VaadinRequest
get() = VaadinService.getCurrentRequest()
?: throw IllegalStateException("No current request. Have you called MockVaadin.setup()?")
?: throw IllegalStateException("No current request. Have you called MockVaadin.setup()?")
val currentResponse: VaadinResponse
get() = VaadinService.getCurrentResponse()
?: throw IllegalStateException("No current response. Have you called MockVaadin.setup()?")
?: throw IllegalStateException("No current response. Have you called MockVaadin.setup()?")

/**
* Returns the [UI.getCurrent]; fails with informative error message if the UI.getCurrent() is null.
*/
val currentUI: UI
get() = UI.getCurrent()
?: throw IllegalStateException("UI.getCurrent() is null. Have you called MockVaadin.setup()?")
?: throw IllegalStateException("UI.getCurrent() is null. Have you called MockVaadin.setup()?")

/**
* Retrieves the mock request which backs up [VaadinRequest].
Expand Down Expand Up @@ -175,12 +187,33 @@ val Servlet.isInitialized: Boolean get() = servletConfig != null
internal fun Class<*>.hasCustomToString(): Boolean =
getMethod("toString").declaringClass != java.lang.Object::class.java

internal val polymerTemplateClass =
internal val polymerTemplateClass = findClass(
"com.vaadin.flow.component.polymertemplate.PolymerTemplate"
)

internal fun hasPolymerTemplates(): Boolean = polymerTemplateClass != null

internal fun findClass(className: String): Class<*>? {
try {
Class.forName("com.vaadin.flow.component.polymertemplate.PolymerTemplate")
return Class.forName(className)
} catch (ex: ClassNotFoundException) {
null
try {
return Class.forName(
className,
true,
Thread.currentThread().contextClassLoader
)
} catch (ex: ClassNotFoundException) {
return null
}
}
}

internal fun hasPolymerTemplates() : Boolean = polymerTemplateClass != null

@Throws(ClassNotFoundException::class)
internal fun findClassOrThrow(className: String): Class<*>{
val clazz = findClass(className)
if (clazz == null) {
throw ClassNotFoundException(className)
}
return clazz
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ import com.vaadin.flow.server.VaadinContext
import com.vaadin.flow.server.VaadinServlet
import com.vaadin.flow.server.VaadinServletContext
import com.vaadin.flow.server.startup.LookupServletContainerInitializer
import com.vaadin.testbench.unit.internal.findClass
import com.vaadin.testbench.unit.internal.findClassOrThrow
import elemental.json.Json
import elemental.json.JsonObject

Expand Down Expand Up @@ -56,9 +58,9 @@ object MockVaadinHelper {
// the same flow-build-info.json that Vaadin reads.

val ctx: VaadinContext = MockVaadinHelper.createMockVaadinContext()
val acf = lookup(ctx, Class.forName("com.vaadin.flow.server.startup.ApplicationConfigurationFactory"))
val acf = lookup(ctx, findClassOrThrow("com.vaadin.flow.server.startup.ApplicationConfigurationFactory"))
checkNotNull(acf) { "ApplicationConfigurationFactory is null" }
val dacfClass = Class.forName("com.vaadin.flow.server.startup.DefaultApplicationConfigurationFactory")
val dacfClass = findClassOrThrow("com.vaadin.flow.server.startup.DefaultApplicationConfigurationFactory")
if (dacfClass.isInstance(acf)) {
val m = dacfClass.getDeclaredMethod("getTokenFileFromClassloader", VaadinContext::class.java)
m.isAccessible = true
Expand Down Expand Up @@ -97,16 +99,13 @@ object MockVaadinHelper {
val loaders = mutableSetOf<Class<*>>(
*lookupServices.toTypedArray(),
LookupInitializer::class.java,
Class.forName("com.vaadin.flow.di.LookupInitializer${'$'}ResourceProviderImpl")
findClassOrThrow("com.vaadin.flow.di.LookupInitializer${'$'}ResourceProviderImpl")
)

fun tryLoad(clazz: String) {
try {
loaders.add(Class.forName(clazz))
} catch (ex: ClassNotFoundException) {
// sometimes customers don't include entire vaadin-core and exclude stuff like fusion on purpose.
// load the class only if it exists.
}
fun tryLoad(className: String) {
// sometimes customers don't include entire vaadin-core and exclude stuff like fusion on purpose.
// load the class only if it exists.
findClass(className)?.let { clazz -> loaders.add(clazz) }
}

tryLoad("com.vaadin.flow.component.polymertemplate.rpc.PolymerPublishedEventRpcHandler")
Expand Down

0 comments on commit b7a0f6d

Please sign in to comment.