From 7e80483664d4892bd3a10140579cf1bdecd84eb8 Mon Sep 17 00:00:00 2001 From: David Phillips Date: Sun, 8 Sep 2024 12:15:59 -0700 Subject: [PATCH 1/5] Remove TODO comment --- .../dylibso/chicory/function/processor/FunctionProcessor.java | 1 - 1 file changed, 1 deletion(-) diff --git a/function-processor/src/main/java/com/dylibso/chicory/function/processor/FunctionProcessor.java b/function-processor/src/main/java/com/dylibso/chicory/function/processor/FunctionProcessor.java index 088fa2741..aaf34130b 100644 --- a/function-processor/src/main/java/com/dylibso/chicory/function/processor/FunctionProcessor.java +++ b/function-processor/src/main/java/com/dylibso/chicory/function/processor/FunctionProcessor.java @@ -273,7 +273,6 @@ private Expression processMethod( .addArgument(new StringLiteralExpr(name)) .addArgument(new MethodCallExpr(new NameExpr("List"), "of", paramTypes)) .addArgument(new MethodCallExpr(new NameExpr("List"), "of", returnType)); - // TODO: update javaparser and replace with multiline formatting function.setLineComment(""); return function; } From 0b2ef27690d8b2bfdf220ab95529c910d3bebd6c Mon Sep 17 00:00:00 2001 From: David Phillips Date: Sat, 7 Sep 2024 17:15:01 -0700 Subject: [PATCH 2/5] Rename FunctionProcessor to HostFunctionProcessor --- ...rocessor.java => HostModuleProcessor.java} | 2 +- .../javax.annotation.processing.Processor | 2 +- ...Test.java => HostModuleProcessorTest.java} | 43 +++++++++++-------- .../test/resources/{ => host}/BasicMath.java | 0 .../{ => host}/BasicMathGenerated.java | 2 +- .../src/test/resources/{ => host}/Box.java | 0 .../{ => host}/InvalidParameterString.java | 0 .../InvalidParameterUnsupported.java | 0 .../resources/{ => host}/InvalidReturn.java | 0 .../resources/{ => host}/NestedGenerated.java | 2 +- .../test/resources/{ => host}/NoPackage.java | 0 .../{ => host}/NoPackageGenerated.java | 2 +- .../src/test/resources/{ => host}/Simple.java | 0 .../resources/{ => host}/SimpleGenerated.java | 2 +- 14 files changed, 30 insertions(+), 25 deletions(-) rename function-processor/src/main/java/com/dylibso/chicory/function/processor/{FunctionProcessor.java => HostModuleProcessor.java} (99%) rename function-processor/src/test/java/com/dylibso/chicory/function/processor/{FunctionProcessorTest.java => HostModuleProcessorTest.java} (55%) rename function-processor/src/test/resources/{ => host}/BasicMath.java (100%) rename function-processor/src/test/resources/{ => host}/BasicMathGenerated.java (96%) rename function-processor/src/test/resources/{ => host}/Box.java (100%) rename function-processor/src/test/resources/{ => host}/InvalidParameterString.java (100%) rename function-processor/src/test/resources/{ => host}/InvalidParameterUnsupported.java (100%) rename function-processor/src/test/resources/{ => host}/InvalidReturn.java (100%) rename function-processor/src/test/resources/{ => host}/NestedGenerated.java (94%) rename function-processor/src/test/resources/{ => host}/NoPackage.java (100%) rename function-processor/src/test/resources/{ => host}/NoPackageGenerated.java (94%) rename function-processor/src/test/resources/{ => host}/Simple.java (100%) rename function-processor/src/test/resources/{ => host}/SimpleGenerated.java (96%) diff --git a/function-processor/src/main/java/com/dylibso/chicory/function/processor/FunctionProcessor.java b/function-processor/src/main/java/com/dylibso/chicory/function/processor/HostModuleProcessor.java similarity index 99% rename from function-processor/src/main/java/com/dylibso/chicory/function/processor/FunctionProcessor.java rename to function-processor/src/main/java/com/dylibso/chicory/function/processor/HostModuleProcessor.java index aaf34130b..9d3a268d7 100644 --- a/function-processor/src/main/java/com/dylibso/chicory/function/processor/FunctionProcessor.java +++ b/function-processor/src/main/java/com/dylibso/chicory/function/processor/HostModuleProcessor.java @@ -55,7 +55,7 @@ import javax.lang.model.util.Elements; import javax.tools.Diagnostic; -public final class FunctionProcessor extends AbstractProcessor { +public final class HostModuleProcessor extends AbstractProcessor { @Override public SourceVersion getSupportedSourceVersion() { diff --git a/function-processor/src/main/resources/META-INF/services/javax.annotation.processing.Processor b/function-processor/src/main/resources/META-INF/services/javax.annotation.processing.Processor index 21917dfa0..e57d518bf 100644 --- a/function-processor/src/main/resources/META-INF/services/javax.annotation.processing.Processor +++ b/function-processor/src/main/resources/META-INF/services/javax.annotation.processing.Processor @@ -1 +1 @@ -com.dylibso.chicory.function.processor.FunctionProcessor +com.dylibso.chicory.function.processor.HostModuleProcessor diff --git a/function-processor/src/test/java/com/dylibso/chicory/function/processor/FunctionProcessorTest.java b/function-processor/src/test/java/com/dylibso/chicory/function/processor/HostModuleProcessorTest.java similarity index 55% rename from function-processor/src/test/java/com/dylibso/chicory/function/processor/FunctionProcessorTest.java rename to function-processor/src/test/java/com/dylibso/chicory/function/processor/HostModuleProcessorTest.java index efc4e7cb1..a58c7f5f5 100644 --- a/function-processor/src/test/java/com/dylibso/chicory/function/processor/FunctionProcessorTest.java +++ b/function-processor/src/test/java/com/dylibso/chicory/function/processor/HostModuleProcessorTest.java @@ -5,78 +5,83 @@ import com.google.testing.compile.Compilation; import com.google.testing.compile.JavaFileObjects; +import javax.tools.JavaFileObject; import org.junit.jupiter.api.Test; -class FunctionProcessorTest { +class HostModuleProcessorTest { @Test void generateModules() { Compilation compilation = - javac().withProcessors(new FunctionProcessor()) + javac().withProcessors(new HostModuleProcessor()) .compile( - JavaFileObjects.forResource("BasicMath.java"), - JavaFileObjects.forResource("Box.java"), - JavaFileObjects.forResource("NoPackage.java"), - JavaFileObjects.forResource("Simple.java")); + resource("host/BasicMath.java"), + resource("host/Box.java"), + resource("host/NoPackage.java"), + resource("host/Simple.java")); assertThat(compilation).succeededWithoutWarnings(); assertThat(compilation) .generatedSourceFile("chicory.testing.BasicMath_ModuleFactory") - .hasSourceEquivalentTo(JavaFileObjects.forResource("BasicMathGenerated.java")); + .hasSourceEquivalentTo(resource("host/BasicMathGenerated.java")); assertThat(compilation) .generatedSourceFile("chicory.testing.Simple_ModuleFactory") - .hasSourceEquivalentTo(JavaFileObjects.forResource("SimpleGenerated.java")); + .hasSourceEquivalentTo(resource("host/SimpleGenerated.java")); assertThat(compilation) .generatedSourceFile("chicory.testing.Nested_ModuleFactory") - .hasSourceEquivalentTo(JavaFileObjects.forResource("NestedGenerated.java")); + .hasSourceEquivalentTo(resource("host/NestedGenerated.java")); assertThat(compilation) .generatedSourceFile("NoPackage_ModuleFactory") - .hasSourceEquivalentTo(JavaFileObjects.forResource("NoPackageGenerated.java")); + .hasSourceEquivalentTo(resource("host/NoPackageGenerated.java")); } @Test void invalidParameterTypeUnsupported() { Compilation compilation = - javac().withProcessors(new FunctionProcessor()) - .compile(JavaFileObjects.forResource("InvalidParameterUnsupported.java")); + javac().withProcessors(new HostModuleProcessor()) + .compile(resource("host/InvalidParameterUnsupported.java")); assertThat(compilation).failed(); assertThat(compilation) .hadErrorContaining("Unsupported WASM type: java.math.BigDecimal") - .inFile(JavaFileObjects.forResource("InvalidParameterUnsupported.java")) + .inFile(resource("host/InvalidParameterUnsupported.java")) .onLineContaining("public long square(BigDecimal x) {"); } @Test void invalidParameterTypeString() { Compilation compilation = - javac().withProcessors(new FunctionProcessor()) - .compile(JavaFileObjects.forResource("InvalidParameterString.java")); + javac().withProcessors(new HostModuleProcessor()) + .compile(resource("host/InvalidParameterString.java")); assertThat(compilation).failed(); assertThat(compilation) .hadErrorContaining("Missing annotation for WASM type: java.lang.String") - .inFile(JavaFileObjects.forResource("InvalidParameterString.java")) + .inFile(resource("host/InvalidParameterString.java")) .onLineContaining("public long concat(int a, String s) {"); } @Test void invalidReturnType() { Compilation compilation = - javac().withProcessors(new FunctionProcessor()) - .compile(JavaFileObjects.forResource("InvalidReturn.java")); + javac().withProcessors(new HostModuleProcessor()) + .compile(resource("host/InvalidReturn.java")); assertThat(compilation).failed(); assertThat(compilation) .hadErrorContaining("Unsupported WASM type: java.lang.String") - .inFile(JavaFileObjects.forResource("InvalidReturn.java")) + .inFile(resource("host/InvalidReturn.java")) .onLineContaining("public String toString(int x) {"); } + + private static JavaFileObject resource(String resource) { + return JavaFileObjects.forResource(resource); + } } diff --git a/function-processor/src/test/resources/BasicMath.java b/function-processor/src/test/resources/host/BasicMath.java similarity index 100% rename from function-processor/src/test/resources/BasicMath.java rename to function-processor/src/test/resources/host/BasicMath.java diff --git a/function-processor/src/test/resources/BasicMathGenerated.java b/function-processor/src/test/resources/host/BasicMathGenerated.java similarity index 96% rename from function-processor/src/test/resources/BasicMathGenerated.java rename to function-processor/src/test/resources/host/BasicMathGenerated.java index 86a8cb775..5e16bbf89 100644 --- a/function-processor/src/test/resources/BasicMathGenerated.java +++ b/function-processor/src/test/resources/host/BasicMathGenerated.java @@ -7,7 +7,7 @@ import java.util.List; import javax.annotation.processing.Generated; -@Generated("com.dylibso.chicory.function.processor.FunctionProcessor") +@Generated("com.dylibso.chicory.function.processor.HostModuleProcessor") public final class BasicMath_ModuleFactory { private BasicMath_ModuleFactory() {} diff --git a/function-processor/src/test/resources/Box.java b/function-processor/src/test/resources/host/Box.java similarity index 100% rename from function-processor/src/test/resources/Box.java rename to function-processor/src/test/resources/host/Box.java diff --git a/function-processor/src/test/resources/InvalidParameterString.java b/function-processor/src/test/resources/host/InvalidParameterString.java similarity index 100% rename from function-processor/src/test/resources/InvalidParameterString.java rename to function-processor/src/test/resources/host/InvalidParameterString.java diff --git a/function-processor/src/test/resources/InvalidParameterUnsupported.java b/function-processor/src/test/resources/host/InvalidParameterUnsupported.java similarity index 100% rename from function-processor/src/test/resources/InvalidParameterUnsupported.java rename to function-processor/src/test/resources/host/InvalidParameterUnsupported.java diff --git a/function-processor/src/test/resources/InvalidReturn.java b/function-processor/src/test/resources/host/InvalidReturn.java similarity index 100% rename from function-processor/src/test/resources/InvalidReturn.java rename to function-processor/src/test/resources/host/InvalidReturn.java diff --git a/function-processor/src/test/resources/NestedGenerated.java b/function-processor/src/test/resources/host/NestedGenerated.java similarity index 94% rename from function-processor/src/test/resources/NestedGenerated.java rename to function-processor/src/test/resources/host/NestedGenerated.java index bb0804441..dd9b87f7e 100644 --- a/function-processor/src/test/resources/NestedGenerated.java +++ b/function-processor/src/test/resources/host/NestedGenerated.java @@ -8,7 +8,7 @@ import java.util.List; import javax.annotation.processing.Generated; -@Generated("com.dylibso.chicory.function.processor.FunctionProcessor") +@Generated("com.dylibso.chicory.function.processor.HostModuleProcessor") public final class Nested_ModuleFactory { private Nested_ModuleFactory() {} diff --git a/function-processor/src/test/resources/NoPackage.java b/function-processor/src/test/resources/host/NoPackage.java similarity index 100% rename from function-processor/src/test/resources/NoPackage.java rename to function-processor/src/test/resources/host/NoPackage.java diff --git a/function-processor/src/test/resources/NoPackageGenerated.java b/function-processor/src/test/resources/host/NoPackageGenerated.java similarity index 94% rename from function-processor/src/test/resources/NoPackageGenerated.java rename to function-processor/src/test/resources/host/NoPackageGenerated.java index 1539a9f8d..45cfdb84d 100644 --- a/function-processor/src/test/resources/NoPackageGenerated.java +++ b/function-processor/src/test/resources/host/NoPackageGenerated.java @@ -5,7 +5,7 @@ import java.util.List; import javax.annotation.processing.Generated; -@Generated("com.dylibso.chicory.function.processor.FunctionProcessor") +@Generated("com.dylibso.chicory.function.processor.HostModuleProcessor") public final class NoPackage_ModuleFactory { private Nested_ModuleFactory() {} diff --git a/function-processor/src/test/resources/Simple.java b/function-processor/src/test/resources/host/Simple.java similarity index 100% rename from function-processor/src/test/resources/Simple.java rename to function-processor/src/test/resources/host/Simple.java diff --git a/function-processor/src/test/resources/SimpleGenerated.java b/function-processor/src/test/resources/host/SimpleGenerated.java similarity index 96% rename from function-processor/src/test/resources/SimpleGenerated.java rename to function-processor/src/test/resources/host/SimpleGenerated.java index c80bd824f..7ceddb34f 100644 --- a/function-processor/src/test/resources/SimpleGenerated.java +++ b/function-processor/src/test/resources/host/SimpleGenerated.java @@ -7,7 +7,7 @@ import java.util.List; import javax.annotation.processing.Generated; -@Generated("com.dylibso.chicory.function.processor.FunctionProcessor") +@Generated("com.dylibso.chicory.function.processor.HostModuleProcessor") public final class Simple_ModuleFactory { private Simple_ModuleFactory() {} From 16bbdfd8305cf05bc0d390e431c5b51c953623bc Mon Sep 17 00:00:00 2001 From: David Phillips Date: Sun, 8 Sep 2024 11:46:40 -0700 Subject: [PATCH 3/5] Extract AbstractModuleProcessor --- .../processor/AbstractModuleProcessor.java | 140 ++++++++++++++++++ .../processor/HostModuleProcessor.java | 120 ++------------- 2 files changed, 149 insertions(+), 111 deletions(-) create mode 100644 function-processor/src/main/java/com/dylibso/chicory/function/processor/AbstractModuleProcessor.java diff --git a/function-processor/src/main/java/com/dylibso/chicory/function/processor/AbstractModuleProcessor.java b/function-processor/src/main/java/com/dylibso/chicory/function/processor/AbstractModuleProcessor.java new file mode 100644 index 000000000..6d804acfb --- /dev/null +++ b/function-processor/src/main/java/com/dylibso/chicory/function/processor/AbstractModuleProcessor.java @@ -0,0 +1,140 @@ +package com.dylibso.chicory.function.processor; + +import static com.github.javaparser.printer.configuration.DefaultPrinterConfiguration.ConfigOption.COLUMN_ALIGN_PARAMETERS; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; +import static javax.tools.Diagnostic.Kind.ERROR; +import static javax.tools.Diagnostic.Kind.NOTE; + +import com.github.javaparser.ast.CompilationUnit; +import com.github.javaparser.ast.body.ClassOrInterfaceDeclaration; +import com.github.javaparser.ast.expr.Expression; +import com.github.javaparser.ast.expr.FieldAccessExpr; +import com.github.javaparser.ast.expr.NameExpr; +import com.github.javaparser.ast.expr.StringLiteralExpr; +import com.github.javaparser.printer.DefaultPrettyPrinter; +import com.github.javaparser.printer.configuration.DefaultConfigurationOption; +import com.github.javaparser.printer.configuration.DefaultPrinterConfiguration; +import java.io.IOException; +import java.io.Writer; +import java.lang.annotation.Annotation; +import java.util.Locale; +import java.util.Set; +import javax.annotation.processing.AbstractProcessor; +import javax.annotation.processing.Filer; +import javax.annotation.processing.Generated; +import javax.annotation.processing.RoundEnvironment; +import javax.lang.model.SourceVersion; +import javax.lang.model.element.AnnotationMirror; +import javax.lang.model.element.Element; +import javax.lang.model.element.ElementKind; +import javax.lang.model.element.PackageElement; +import javax.lang.model.element.TypeElement; +import javax.lang.model.type.TypeMirror; +import javax.lang.model.util.Elements; +import javax.tools.Diagnostic; + +abstract class AbstractModuleProcessor extends AbstractProcessor { + private final Class annotation; + + protected AbstractModuleProcessor(Class annotation) { + this.annotation = requireNonNull(annotation); + } + + @Override + public SourceVersion getSupportedSourceVersion() { + return SourceVersion.latestSupported(); + } + + @Override + public Set getSupportedAnnotationTypes() { + return Set.of(annotation.getName()); + } + + @Override + public boolean process(Set annotations, RoundEnvironment round) { + for (Element element : round.getElementsAnnotatedWith(annotation)) { + log(NOTE, "Generating module factory for " + element, null); + try { + processModule((TypeElement) element); + } catch (AbortProcessingException e) { + // skip type + } + } + + return false; + } + + protected abstract void processModule(TypeElement type); + + protected Elements elements() { + return processingEnv.getElementUtils(); + } + + protected Filer filer() { + return processingEnv.getFiler(); + } + + protected void log(Diagnostic.Kind kind, String message, Element element) { + processingEnv.getMessager().printMessage(kind, message, element); + } + + protected void addGeneratedAnnotation(ClassOrInterfaceDeclaration classDef) { + var processorName = new StringLiteralExpr(getClass().getName()); + classDef.addSingleMemberAnnotation(Generated.class, processorName); + } + + protected void writeSourceFile( + CompilationUnit cu, PackageElement pkg, TypeElement type, String suffix) { + var prefix = (pkg.isUnnamed()) ? "" : pkg.getQualifiedName().toString() + "."; + var qualifiedName = prefix + type.getSimpleName() + suffix; + try (Writer writer = filer().createSourceFile(qualifiedName, type).openWriter()) { + writer.write(cu.printer(printer()).toString()); + } catch (IOException e) { + log(ERROR, format("Failed to create %s file: %s", qualifiedName, e), null); + } + } + + protected static DefaultPrettyPrinter printer() { + return new DefaultPrettyPrinter( + new DefaultPrinterConfiguration() + .addOption(new DefaultConfigurationOption(COLUMN_ALIGN_PARAMETERS, true))); + } + + protected static PackageElement getPackageName(Element element) { + Element enclosing = element; + while (enclosing.getKind() != ElementKind.PACKAGE) { + enclosing = enclosing.getEnclosingElement(); + } + return (PackageElement) enclosing; + } + + protected static boolean annotatedWith( + Element element, Class annotation) { + var annotationName = annotation.getName(); + return element.getAnnotationMirrors().stream() + .map(AnnotationMirror::getAnnotationType) + .map(TypeMirror::toString) + .anyMatch(annotationName::equals); + } + + protected static CompilationUnit createCompilationUnit(PackageElement pkg, TypeElement type) { + var packageName = pkg.getQualifiedName().toString(); + var cu = (pkg.isUnnamed()) ? new CompilationUnit() : new CompilationUnit(packageName); + if (!pkg.isUnnamed()) { + cu.setPackageDeclaration(packageName); + cu.addImport(type.getQualifiedName().toString()); + } + return cu; + } + + protected static Expression valueType(String type) { + return new FieldAccessExpr(new NameExpr("ValueType"), type); + } + + protected static String camelCaseToSnakeCase(String name) { + return name.replaceAll("([a-z])([A-Z]+)", "$1_$2").toLowerCase(Locale.ROOT); + } + + protected static final class AbortProcessingException extends RuntimeException {} +} diff --git a/function-processor/src/main/java/com/dylibso/chicory/function/processor/HostModuleProcessor.java b/function-processor/src/main/java/com/dylibso/chicory/function/processor/HostModuleProcessor.java index 9d3a268d7..5980602ba 100644 --- a/function-processor/src/main/java/com/dylibso/chicory/function/processor/HostModuleProcessor.java +++ b/function-processor/src/main/java/com/dylibso/chicory/function/processor/HostModuleProcessor.java @@ -1,17 +1,13 @@ package com.dylibso.chicory.function.processor; import static com.github.javaparser.StaticJavaParser.parseType; -import static com.github.javaparser.printer.configuration.DefaultPrinterConfiguration.ConfigOption.COLUMN_ALIGN_PARAMETERS; -import static java.lang.String.format; import static javax.tools.Diagnostic.Kind.ERROR; -import static javax.tools.Diagnostic.Kind.NOTE; import com.dylibso.chicory.function.annotations.Buffer; import com.dylibso.chicory.function.annotations.CString; import com.dylibso.chicory.function.annotations.HostModule; import com.dylibso.chicory.function.annotations.WasmExport; import com.github.javaparser.ast.ArrayCreationLevel; -import com.github.javaparser.ast.CompilationUnit; import com.github.javaparser.ast.NodeList; import com.github.javaparser.ast.body.Parameter; import com.github.javaparser.ast.body.VariableDeclarator; @@ -19,7 +15,6 @@ import com.github.javaparser.ast.expr.ArrayCreationExpr; import com.github.javaparser.ast.expr.ArrayInitializerExpr; import com.github.javaparser.ast.expr.Expression; -import com.github.javaparser.ast.expr.FieldAccessExpr; import com.github.javaparser.ast.expr.IntegerLiteralExpr; import com.github.javaparser.ast.expr.LambdaExpr; import com.github.javaparser.ast.expr.MethodCallExpr; @@ -31,57 +26,18 @@ import com.github.javaparser.ast.stmt.BlockStmt; import com.github.javaparser.ast.stmt.ExpressionStmt; import com.github.javaparser.ast.stmt.ReturnStmt; -import com.github.javaparser.printer.DefaultPrettyPrinter; -import com.github.javaparser.printer.configuration.DefaultConfigurationOption; -import com.github.javaparser.printer.configuration.DefaultPrinterConfiguration; -import java.io.IOException; -import java.io.Writer; -import java.lang.annotation.Annotation; -import java.util.Locale; -import java.util.Set; -import javax.annotation.processing.AbstractProcessor; -import javax.annotation.processing.Filer; -import javax.annotation.processing.Generated; -import javax.annotation.processing.RoundEnvironment; -import javax.lang.model.SourceVersion; -import javax.lang.model.element.AnnotationMirror; import javax.lang.model.element.Element; -import javax.lang.model.element.ElementKind; import javax.lang.model.element.ExecutableElement; -import javax.lang.model.element.PackageElement; import javax.lang.model.element.TypeElement; import javax.lang.model.element.VariableElement; -import javax.lang.model.type.TypeMirror; -import javax.lang.model.util.Elements; -import javax.tools.Diagnostic; -public final class HostModuleProcessor extends AbstractProcessor { - - @Override - public SourceVersion getSupportedSourceVersion() { - return SourceVersion.latestSupported(); - } - - @Override - public Set getSupportedAnnotationTypes() { - return Set.of(HostModule.class.getName()); +public final class HostModuleProcessor extends AbstractModuleProcessor { + public HostModuleProcessor() { + super(HostModule.class); } @Override - public boolean process(Set annotations, RoundEnvironment round) { - for (Element element : round.getElementsAnnotatedWith(HostModule.class)) { - log(NOTE, "Generating module factory for " + element, null); - try { - processModule((TypeElement) element); - } catch (AbortProcessingException e) { - // skip type - } - } - - return false; - } - - private void processModule(TypeElement type) { + protected void processModule(TypeElement type) { var moduleName = type.getAnnotation(HostModule.class).value(); var functions = new NodeList(); @@ -90,13 +46,9 @@ private void processModule(TypeElement type) { functions.add(processMethod(member, (ExecutableElement) member, moduleName)); } } + var pkg = getPackageName(type); - var packageName = pkg.getQualifiedName().toString(); - var cu = (pkg.isUnnamed()) ? new CompilationUnit() : new CompilationUnit(packageName); - if (!pkg.isUnnamed()) { - cu.setPackageDeclaration(packageName); - cu.addImport(type.getQualifiedName().toString()); - } + var cu = createCompilationUnit(pkg, type); cu.addImport("com.dylibso.chicory.runtime.HostFunction"); cu.addImport("com.dylibso.chicory.runtime.Instance"); cu.addImport("com.dylibso.chicory.wasm.types.Value"); @@ -104,12 +56,8 @@ private void processModule(TypeElement type) { cu.addImport("java.util.List"); var typeName = type.getSimpleName().toString(); - var processorName = new StringLiteralExpr(getClass().getName()); - var classDef = - cu.addClass(typeName + "_ModuleFactory") - .setPublic(true) - .setFinal(true) - .addSingleMemberAnnotation(Generated.class, processorName); + var classDef = cu.addClass(typeName + "_ModuleFactory").setPublic(true).setFinal(true); + addGeneratedAnnotation(classDef); classDef.addConstructor().setPrivate(true); @@ -126,13 +74,7 @@ private void processModule(TypeElement type) { .setType("HostFunction[]") .setBody(new BlockStmt(new NodeList<>(new ReturnStmt(newHostFunctions)))); - String prefix = (pkg.isUnnamed()) ? "" : packageName + "."; - String qualifiedName = prefix + type.getSimpleName() + "_ModuleFactory"; - try (Writer writer = filer().createSourceFile(qualifiedName, type).openWriter()) { - writer.write(cu.printer(printer()).toString()); - } catch (IOException e) { - log(ERROR, format("Failed to create %s file: %s", qualifiedName, e), null); - } + writeSourceFile(cu, pkg, type, "_ModuleFactory"); } private Expression processMethod( @@ -277,51 +219,7 @@ private Expression processMethod( return function; } - private Elements elements() { - return processingEnv.getElementUtils(); - } - - private Filer filer() { - return processingEnv.getFiler(); - } - - private void log(Diagnostic.Kind kind, String message, Element element) { - processingEnv.getMessager().printMessage(kind, message, element); - } - - private static PackageElement getPackageName(Element element) { - Element enclosing = element; - while (enclosing.getKind() != ElementKind.PACKAGE) { - enclosing = enclosing.getEnclosingElement(); - } - return (PackageElement) enclosing; - } - - private static boolean annotatedWith(Element element, Class annotation) { - var annotationName = annotation.getName(); - return element.getAnnotationMirrors().stream() - .map(AnnotationMirror::getAnnotationType) - .map(TypeMirror::toString) - .anyMatch(annotationName::equals); - } - private static Expression argExpr(int n) { return new ArrayAccessExpr(new NameExpr("args"), new IntegerLiteralExpr(String.valueOf(n))); } - - private static Expression valueType(String type) { - return new FieldAccessExpr(new NameExpr("ValueType"), type); - } - - private static String camelCaseToSnakeCase(String name) { - return name.replaceAll("([a-z])([A-Z]+)", "$1_$2").toLowerCase(Locale.ROOT); - } - - private static DefaultPrettyPrinter printer() { - return new DefaultPrettyPrinter( - new DefaultPrinterConfiguration() - .addOption(new DefaultConfigurationOption(COLUMN_ALIGN_PARAMETERS, true))); - } - - private static final class AbortProcessingException extends RuntimeException {} } From a555b0dc84a06b4227899f3c0fa55a5fb25382ec Mon Sep 17 00:00:00 2001 From: David Phillips Date: Sun, 8 Sep 2024 13:10:19 -0700 Subject: [PATCH 4/5] Remove unnecessary parameter in HostModuleProcessor --- .../chicory/function/processor/HostModuleProcessor.java | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/function-processor/src/main/java/com/dylibso/chicory/function/processor/HostModuleProcessor.java b/function-processor/src/main/java/com/dylibso/chicory/function/processor/HostModuleProcessor.java index 5980602ba..4fe473093 100644 --- a/function-processor/src/main/java/com/dylibso/chicory/function/processor/HostModuleProcessor.java +++ b/function-processor/src/main/java/com/dylibso/chicory/function/processor/HostModuleProcessor.java @@ -43,7 +43,7 @@ protected void processModule(TypeElement type) { var functions = new NodeList(); for (Element member : elements().getAllMembers(type)) { if (member instanceof ExecutableElement && annotatedWith(member, WasmExport.class)) { - functions.add(processMethod(member, (ExecutableElement) member, moduleName)); + functions.add(processMethod((ExecutableElement) member, moduleName)); } } @@ -77,8 +77,7 @@ protected void processModule(TypeElement type) { writeSourceFile(cu, pkg, type, "_ModuleFactory"); } - private Expression processMethod( - Element member, ExecutableElement executable, String moduleName) { + private Expression processMethod(ExecutableElement executable, String moduleName) { // compute function name var name = executable.getAnnotation(WasmExport.class).value(); if (name.isEmpty()) { @@ -175,7 +174,9 @@ private Expression processMethod( // function invocation Expression invocation = new MethodCallExpr( - new NameExpr("functions"), member.getSimpleName().toString(), arguments); + new NameExpr("functions"), + executable.getSimpleName().toString(), + arguments); // convert return value BlockStmt handleBody = new BlockStmt(); From dd122b22230fabb2fea0cc0c3f550a061b114d4b Mon Sep 17 00:00:00 2001 From: David Phillips Date: Sun, 8 Sep 2024 22:22:10 -0700 Subject: [PATCH 5/5] Add annotation framework for Wasm interfaces --- .../function/annotations/Allocate.java | 10 + .../chicory/function/annotations/Free.java | 10 + .../function/annotations/WasmModule.java | 10 + .../processor/WasmModuleProcessor.java | 452 ++++++++++++++++++ .../javax.annotation.processing.Processor | 1 + .../processor/WasmModuleProcessorTest.java | 29 ++ .../src/test/resources/wasm/Demo.java | 35 ++ .../test/resources/wasm/DemoGenerated.java | 107 +++++ 8 files changed, 654 insertions(+) create mode 100644 function-annotations/src/main/java/com/dylibso/chicory/function/annotations/Allocate.java create mode 100644 function-annotations/src/main/java/com/dylibso/chicory/function/annotations/Free.java create mode 100644 function-annotations/src/main/java/com/dylibso/chicory/function/annotations/WasmModule.java create mode 100644 function-processor/src/main/java/com/dylibso/chicory/function/processor/WasmModuleProcessor.java create mode 100644 function-processor/src/test/java/com/dylibso/chicory/function/processor/WasmModuleProcessorTest.java create mode 100644 function-processor/src/test/resources/wasm/Demo.java create mode 100644 function-processor/src/test/resources/wasm/DemoGenerated.java diff --git a/function-annotations/src/main/java/com/dylibso/chicory/function/annotations/Allocate.java b/function-annotations/src/main/java/com/dylibso/chicory/function/annotations/Allocate.java new file mode 100644 index 000000000..289801ecf --- /dev/null +++ b/function-annotations/src/main/java/com/dylibso/chicory/function/annotations/Allocate.java @@ -0,0 +1,10 @@ +package com.dylibso.chicory.function.annotations; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +@Retention(RetentionPolicy.RUNTIME) +@Target(ElementType.METHOD) +public @interface Allocate {} diff --git a/function-annotations/src/main/java/com/dylibso/chicory/function/annotations/Free.java b/function-annotations/src/main/java/com/dylibso/chicory/function/annotations/Free.java new file mode 100644 index 000000000..6682fadb3 --- /dev/null +++ b/function-annotations/src/main/java/com/dylibso/chicory/function/annotations/Free.java @@ -0,0 +1,10 @@ +package com.dylibso.chicory.function.annotations; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +@Retention(RetentionPolicy.RUNTIME) +@Target(ElementType.METHOD) +public @interface Free {} diff --git a/function-annotations/src/main/java/com/dylibso/chicory/function/annotations/WasmModule.java b/function-annotations/src/main/java/com/dylibso/chicory/function/annotations/WasmModule.java new file mode 100644 index 000000000..f18b23781 --- /dev/null +++ b/function-annotations/src/main/java/com/dylibso/chicory/function/annotations/WasmModule.java @@ -0,0 +1,10 @@ +package com.dylibso.chicory.function.annotations; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +@Retention(RetentionPolicy.RUNTIME) +@Target(ElementType.TYPE) +public @interface WasmModule {} diff --git a/function-processor/src/main/java/com/dylibso/chicory/function/processor/WasmModuleProcessor.java b/function-processor/src/main/java/com/dylibso/chicory/function/processor/WasmModuleProcessor.java new file mode 100644 index 000000000..f8a99d064 --- /dev/null +++ b/function-processor/src/main/java/com/dylibso/chicory/function/processor/WasmModuleProcessor.java @@ -0,0 +1,452 @@ +package com.dylibso.chicory.function.processor; + +import static com.github.javaparser.StaticJavaParser.parseMethodDeclaration; +import static com.github.javaparser.StaticJavaParser.parseType; +import static java.util.stream.Collectors.toCollection; +import static java.util.stream.Collectors.toList; +import static javax.tools.Diagnostic.Kind.ERROR; + +import com.dylibso.chicory.function.annotations.Allocate; +import com.dylibso.chicory.function.annotations.Buffer; +import com.dylibso.chicory.function.annotations.CString; +import com.dylibso.chicory.function.annotations.Free; +import com.dylibso.chicory.function.annotations.WasmExport; +import com.dylibso.chicory.function.annotations.WasmModule; +import com.github.javaparser.ast.NodeList; +import com.github.javaparser.ast.body.ClassOrInterfaceDeclaration; +import com.github.javaparser.ast.body.MethodDeclaration; +import com.github.javaparser.ast.body.Parameter; +import com.github.javaparser.ast.body.VariableDeclarator; +import com.github.javaparser.ast.expr.ArrayAccessExpr; +import com.github.javaparser.ast.expr.ArrayCreationExpr; +import com.github.javaparser.ast.expr.ArrayInitializerExpr; +import com.github.javaparser.ast.expr.AssignExpr; +import com.github.javaparser.ast.expr.BinaryExpr; +import com.github.javaparser.ast.expr.CastExpr; +import com.github.javaparser.ast.expr.Expression; +import com.github.javaparser.ast.expr.FieldAccessExpr; +import com.github.javaparser.ast.expr.IntegerLiteralExpr; +import com.github.javaparser.ast.expr.MethodCallExpr; +import com.github.javaparser.ast.expr.NameExpr; +import com.github.javaparser.ast.expr.ObjectCreationExpr; +import com.github.javaparser.ast.expr.StringLiteralExpr; +import com.github.javaparser.ast.expr.VariableDeclarationExpr; +import com.github.javaparser.ast.stmt.BlockStmt; +import com.github.javaparser.ast.stmt.ExpressionStmt; +import com.github.javaparser.ast.stmt.ReturnStmt; +import com.github.javaparser.ast.stmt.Statement; +import com.github.javaparser.ast.type.Type; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import javax.lang.model.element.ExecutableElement; +import javax.lang.model.element.TypeElement; +import javax.lang.model.element.VariableElement; + +public final class WasmModuleProcessor extends AbstractModuleProcessor { + public WasmModuleProcessor() { + super(WasmModule.class); + } + + @Override + protected void processModule(TypeElement type) { + var pkg = getPackageName(type); + var cu = createCompilationUnit(pkg, type); + + cu.addImport("com.dylibso.chicory.runtime.ExportFunction"); + cu.addImport("com.dylibso.chicory.runtime.Instance"); + cu.addImport("com.dylibso.chicory.runtime.Memory"); + cu.addImport("com.dylibso.chicory.wasm.types.FunctionType"); + cu.addImport("com.dylibso.chicory.wasm.types.Value"); + cu.addImport("com.dylibso.chicory.wasm.types.ValueType"); + cu.addImport("java.nio.charset.StandardCharsets"); + cu.addImport("java.util.List"); + + var typeName = type.getSimpleName().toString(); + var moduleFactory = cu.addClass(typeName + "_ModuleFactory").setPublic(true).setFinal(true); + addGeneratedAnnotation(moduleFactory); + + moduleFactory.addConstructor().setPrivate(true); + + // instance factory method + moduleFactory + .addMethod("create") + .setPublic(true) + .setStatic(true) + .addParameter("Instance", "instance") + .setType(typeName) + .setBody( + new BlockStmt() + .addStatement( + new ReturnStmt( + new ObjectCreationExpr() + .setType(typeName + "_Instance") + .addArgument(new NameExpr("instance"))))); + + // nested instance class + var instance = + new ClassOrInterfaceDeclaration() + .setName(typeName + "_Instance") + .setPrivate(true) + .setStatic(true) + .setFinal(true) + .addImplementedType(typeName); + moduleFactory.addMember(instance); + + // declare memory field + instance.addField("Memory", "memory").setPrivate(true).setFinal(true); + + // assign memory field in constructor + var constructorBody = new BlockStmt(); + constructorBody.addStatement( + new ExpressionStmt( + new AssignExpr( + new NameExpr("memory"), + methodCall("instance", "memory"), + AssignExpr.Operator.ASSIGN))); + + // find annotated methods + var methods = + elements().getAllMembers(type).stream() + .filter(member -> member instanceof ExecutableElement) + .filter(member -> annotatedWith(member, WasmExport.class)) + .map(member -> (ExecutableElement) member) + .collect(toList()); + + // find allocate and free functions + Optional allocate = findAllocate(methods); + Optional free = findFree(methods); + + // declare fields and collect methods + var functions = + methods.stream() + .map( + member -> + processMethod( + member, instance, constructorBody, allocate, free)) + .collect(toCollection(NodeList::new)); + + // declare constructor + instance.addConstructor() + .setPublic(true) + .addParameter("Instance", "instance") + .setBody(constructorBody); + + // declare methods + for (MethodDeclaration function : functions) { + instance.addMember(function); + } + + declareCheckTypeMethods(instance); + + writeSourceFile(cu, pkg, type, "_ModuleFactory"); + } + + private MethodDeclaration processMethod( + ExecutableElement executable, + ClassOrInterfaceDeclaration classDef, + BlockStmt constructor, + Optional allocate, + Optional free) { + + // compute function name + var methodName = executable.getSimpleName().toString(); + var wasmName = executable.getAnnotation(WasmExport.class).value(); + if (wasmName.isEmpty()) { + wasmName = camelCaseToSnakeCase(methodName); + } + + // declare field + String fieldName = "_" + methodName; + classDef.addField("ExportFunction", fieldName).setPrivate(true).setFinal(true); + + // assign field in constructor + constructor.addStatement( + new ExpressionStmt( + new AssignExpr( + new NameExpr(fieldName), + methodCall("instance", "export", new StringLiteralExpr(wasmName)), + AssignExpr.Operator.ASSIGN))); + + // compute parameter types and argument conversions + BlockStmt method = new BlockStmt(); + List cleanup = new ArrayList<>(); + var parameters = new NodeList(); + var paramTypes = new NodeList(); + var arguments = new NodeList(); + for (VariableElement parameter : executable.getParameters()) { + var name = "_" + parameter.getSimpleName().toString(); + var nameExpr = new NameExpr(name); + var paramTypeName = parameter.asType().toString(); + switch (parameter.asType().toString()) { + case "int": + paramTypes.add(valueType("I32")); + arguments.add(methodCall("Value", "i32", nameExpr)); + break; + case "long": + paramTypes.add(valueType("I64")); + arguments.add(methodCall("Value", "i64", nameExpr)); + break; + case "float": + paramTypes.add(valueType("F32")); + arguments.add(methodCall("Value", "fromFloat", nameExpr)); + break; + case "double": + paramTypes.add(valueType("F64")); + arguments.add(methodCall("Value", "fromDouble", nameExpr)); + break; + case "java.lang.String": + paramTypeName = "String"; + // validation + boolean buffer = annotatedWith(parameter, Buffer.class); + boolean cstring = annotatedWith(parameter, CString.class); + if (!buffer && !cstring) { + log(ERROR, "Missing annotation for WASM type: java.lang.String", parameter); + throw new AbortProcessingException(); + } + if (allocate.isEmpty()) { + log(ERROR, "No method is annotated with @Allocate", parameter); + throw new AbortProcessingException(); + } + if (free.isEmpty()) { + log(ERROR, "No method is annotated with @Free", parameter); + throw new AbortProcessingException(); + } + // byte[] bytes$name = $name.getBytes(StandardCharsets.UTF_8); + var bytesName = "bytes" + name; + var utf8 = new FieldAccessExpr(new NameExpr("StandardCharsets"), "UTF_8"); + var getBytes = methodCall(name, "getBytes", utf8); + var bytesLength = new FieldAccessExpr(new NameExpr(bytesName), "length"); + method.addStatement(declareVariable(parseType("byte[]"), bytesName, getBytes)); + // int ptr$name = malloc(...); + var ptrName = "ptr" + name; + var ptrNameExpr = new NameExpr(ptrName); + Expression allocateLength = bytesLength; + if (cstring) { + allocateLength = + new BinaryExpr( + bytesLength, + new IntegerLiteralExpr("1"), + BinaryExpr.Operator.PLUS); + } + method.addStatement( + declareVariable( + parseType("int"), + ptrName, + new MethodCallExpr(allocate.get(), allocateLength))); + // memory.write(ptr$name, bytes$name); + method.addStatement( + new ExpressionStmt( + methodCall( + "memory", + "write", + ptrNameExpr, + new NameExpr(bytesName)))); + if (cstring) { + // memory.writeByte(ptr$name + bytes$name.length, (byte) 0); + method.addStatement( + new ExpressionStmt( + methodCall( + "memory", + "writeByte", + new BinaryExpr( + ptrNameExpr, + bytesLength, + BinaryExpr.Operator.PLUS), + new CastExpr( + parseType("byte"), + new IntegerLiteralExpr("0"))))); + } + // free(ptr$name); + cleanup.add(new ExpressionStmt(new MethodCallExpr(free.get(), ptrNameExpr))); + // arguments + paramTypes.add(valueType("I32")); + arguments.add(methodCall("Value", "i32", ptrNameExpr)); + if (buffer) { + paramTypes.add(valueType("I32")); + arguments.add(methodCall("Value", "i32", bytesLength)); + } + break; + default: + log(ERROR, "Unsupported WASM type: " + parameter.asType(), parameter); + throw new AbortProcessingException(); + } + parameters.add(new Parameter(parseType(paramTypeName), name)); + } + + // wrap arguments + var wrappedArgs = + new ArrayCreationExpr(parseType("Value")) + .setInitializer(new ArrayInitializerExpr(arguments)); + method.addStatement(declareVariable(parseType("Value[]"), "args", wrappedArgs)); + + // function invocation + var invoke = methodCall(fieldName, "apply", new NameExpr("args")); + + // compute return type and conversion + String returnTypeName = executable.getReturnType().toString(); + String resultType = returnTypeName; + Expression result; + NodeList returnType; + switch (returnTypeName) { + case "void": + returnType = new NodeList<>(); + result = null; + break; + case "int": + returnType = new NodeList<>(valueType("I32")); + result = new MethodCallExpr(arrayZero(invoke), "asInt"); + break; + case "long": + returnType = new NodeList<>(valueType("I64")); + result = new MethodCallExpr(arrayZero(invoke), "asLong"); + break; + case "float": + returnType = new NodeList<>(valueType("F32")); + result = new MethodCallExpr(arrayZero(invoke), "asFloat"); + break; + case "double": + returnType = new NodeList<>(valueType("F64")); + result = new MethodCallExpr(arrayZero(invoke), "asDouble"); + break; + case "com.dylibso.chicory.wasm.types.Value[]": + returnType = null; + resultType = "Value[]"; + result = invoke; + break; + default: + log(ERROR, "Unsupported WASM type: " + returnTypeName, executable); + throw new AbortProcessingException(); + } + + // assign return value + if (result != null) { + method.addStatement(declareVariable(parseType(resultType), "result", result)); + } else { + method.addStatement(new ExpressionStmt(invoke)); + } + + // cleanup + cleanup.forEach(method::addStatement); + + // return result + if (result != null) { + method.addStatement(new ReturnStmt(new NameExpr("result"))); + } + + // check type in constructor + Expression checkExpected; + if (returnType != null) { + checkExpected = + methodCall( + "FunctionType", + "of", + new MethodCallExpr(new NameExpr("List"), "of", paramTypes), + new MethodCallExpr(new NameExpr("List"), "of", returnType)); + } else { + checkExpected = new MethodCallExpr(new NameExpr("List"), "of", paramTypes); + } + constructor.addStatement( + new ExpressionStmt( + new MethodCallExpr( + "checkType", + new NameExpr("instance"), + new StringLiteralExpr(wasmName), + checkExpected))); + + return new MethodDeclaration() + .setPublic(true) + .setName(methodName) + .addMarkerAnnotation(Override.class) + .setParameters(parameters) + .setType(resultType) + .setBody(method); + } + + private Optional findAllocate(List methods) { + var list = + methods.stream() + .filter(member -> annotatedWith(member, Allocate.class)) + .collect(toList()); + if (list.isEmpty()) { + return Optional.empty(); + } + if (list.size() > 1) { + log(ERROR, "Method with @Allocate previously declared", list.get(1)); + throw new AbortProcessingException(); + } + ExecutableElement method = list.get(0); + if (method.getParameters().toString().equals("[int]")) { + log(ERROR, "Method with @Allocate must have a single 'int' parameter", method); + throw new AbortProcessingException(); + } + if (!method.getReturnType().toString().equals("int")) { + log(ERROR, "Method with @Allocate must return 'int'", method); + throw new AbortProcessingException(); + } + return Optional.of(method.getSimpleName().toString()); + } + + private Optional findFree(List methods) { + var list = + methods.stream() + .filter(member -> annotatedWith(member, Free.class)) + .collect(toList()); + if (list.isEmpty()) { + return Optional.empty(); + } + if (list.size() > 1) { + log(ERROR, "Method with @Free previously declared", list.get(1)); + throw new AbortProcessingException(); + } + ExecutableElement method = list.get(0); + if (method.getParameters().toString().equals("[int]")) { + log(ERROR, "Method with @Free must have a single 'int' parameter", method); + throw new AbortProcessingException(); + } + if (!method.getReturnType().toString().equals("void")) { + log(ERROR, "Method with @Free must return 'void'", method); + throw new AbortProcessingException(); + } + return Optional.of("free"); + } + + private static void declareCheckTypeMethods(ClassOrInterfaceDeclaration classDef) { + classDef.addMember( + parseMethodDeclaration( + "private static void checkType(Instance instance, String name, FunctionType" + + " expected) {\n" + + " checkType(name, expected, instance.exportType(name));\n" + + "}\n")); + + classDef.addMember( + parseMethodDeclaration( + "private static void checkType(Instance instance, String name," + + " List expected) {\n" + + " checkType(name, expected, instance.exportType(name).params());\n" + + "}\n")); + + classDef.addMember( + parseMethodDeclaration( + "private static void checkType(String name, T expected, T actual) {\n" + + " if (!expected.equals(actual)) {\n" + + " throw new IllegalArgumentException(String.format(\n" + + " \"Function type mismatch for '%s': expected %s <=>" + + " actual %s\", name, expected, actual));\n" + + " }\n" + + "}\n")); + } + + private static Expression methodCall(String scope, String name, Expression... arguments) { + return new MethodCallExpr(new NameExpr(scope), name, new NodeList<>(arguments)); + } + + private static Statement declareVariable(Type type, String name, Expression result) { + return new ExpressionStmt( + new VariableDeclarationExpr(new VariableDeclarator(type, name, result))); + } + + private static Expression arrayZero(Expression array) { + return new ArrayAccessExpr(array, new IntegerLiteralExpr("0")); + } +} diff --git a/function-processor/src/main/resources/META-INF/services/javax.annotation.processing.Processor b/function-processor/src/main/resources/META-INF/services/javax.annotation.processing.Processor index e57d518bf..550c7df25 100644 --- a/function-processor/src/main/resources/META-INF/services/javax.annotation.processing.Processor +++ b/function-processor/src/main/resources/META-INF/services/javax.annotation.processing.Processor @@ -1 +1,2 @@ com.dylibso.chicory.function.processor.HostModuleProcessor +com.dylibso.chicory.function.processor.WasmModuleProcessor diff --git a/function-processor/src/test/java/com/dylibso/chicory/function/processor/WasmModuleProcessorTest.java b/function-processor/src/test/java/com/dylibso/chicory/function/processor/WasmModuleProcessorTest.java new file mode 100644 index 000000000..762028337 --- /dev/null +++ b/function-processor/src/test/java/com/dylibso/chicory/function/processor/WasmModuleProcessorTest.java @@ -0,0 +1,29 @@ +package com.dylibso.chicory.function.processor; + +import static com.google.testing.compile.CompilationSubject.assertThat; +import static com.google.testing.compile.Compiler.javac; + +import com.google.testing.compile.Compilation; +import com.google.testing.compile.JavaFileObjects; +import javax.tools.JavaFileObject; +import org.junit.jupiter.api.Test; + +class WasmModuleProcessorTest { + + @Test + void generateModules() { + Compilation compilation = + javac().withProcessors(new WasmModuleProcessor()) + .compile(resource("wasm/Demo.java")); + + assertThat(compilation).succeededWithoutWarnings(); + + assertThat(compilation) + .generatedSourceFile("chicory.testing.Demo_ModuleFactory") + .hasSourceEquivalentTo(resource("wasm/DemoGenerated.java")); + } + + private static JavaFileObject resource(String resource) { + return JavaFileObjects.forResource(resource); + } +} diff --git a/function-processor/src/test/resources/wasm/Demo.java b/function-processor/src/test/resources/wasm/Demo.java new file mode 100644 index 000000000..936d74566 --- /dev/null +++ b/function-processor/src/test/resources/wasm/Demo.java @@ -0,0 +1,35 @@ +package chicory.testing; + +import com.dylibso.chicory.function.annotations.Allocate; +import com.dylibso.chicory.function.annotations.Buffer; +import com.dylibso.chicory.function.annotations.CString; +import com.dylibso.chicory.function.annotations.Free; +import com.dylibso.chicory.function.annotations.WasmExport; +import com.dylibso.chicory.function.annotations.WasmModule; +import com.dylibso.chicory.runtime.Instance; +import com.dylibso.chicory.wasm.types.Value; + +@WasmModule +public interface Demo { + + @Allocate + @WasmExport + int malloc(int size); + + @Free + @WasmExport + void free(int ptr); + + @WasmExport + void print(@CString String data); + + @WasmExport + int length(@Buffer String data); + + @WasmExport + Value[] multiReturn(float x); + + static Demo create(Instance instance) { + return Demo_ModuleFactory.create(instance); + } +} diff --git a/function-processor/src/test/resources/wasm/DemoGenerated.java b/function-processor/src/test/resources/wasm/DemoGenerated.java new file mode 100644 index 000000000..40cc2e3f1 --- /dev/null +++ b/function-processor/src/test/resources/wasm/DemoGenerated.java @@ -0,0 +1,107 @@ +package chicory.testing; + +import com.dylibso.chicory.runtime.ExportFunction; +import com.dylibso.chicory.runtime.Instance; +import com.dylibso.chicory.runtime.Memory; +import com.dylibso.chicory.wasm.types.FunctionType; +import com.dylibso.chicory.wasm.types.Value; +import com.dylibso.chicory.wasm.types.ValueType; +import java.nio.charset.StandardCharsets; +import java.util.List; +import javax.annotation.processing.Generated; + +@Generated("com.dylibso.chicory.function.processor.WasmModuleProcessor") +public final class Demo_ModuleFactory { + + private Demo_ModuleFactory() {} + + public static Demo create(Instance instance) { + return new Demo_Instance(instance); + } + + private static final class Demo_Instance implements Demo { + + private final Memory memory; + private final ExportFunction _malloc; + private final ExportFunction _free; + private final ExportFunction _print; + private final ExportFunction _length; + private final ExportFunction _multiReturn; + + public Demo_Instance(Instance instance) { + memory = instance.memory(); + + _malloc = instance.export("malloc"); + checkType(instance, "malloc", FunctionType.of(List.of(ValueType.I32), List.of(ValueType.I32))); + + _free = instance.export("free"); + checkType(instance, "free", FunctionType.of(List.of(ValueType.I32), List.of())); + + _print = instance.export("print"); + checkType(instance, "print", FunctionType.of(List.of(ValueType.I32), List.of())); + + _length = instance.export("length"); + checkType(instance, "length", FunctionType.of(List.of(ValueType.I32, ValueType.I32), List.of(ValueType.I32))); + + _multiReturn = instance.export("multi_return"); + checkType(instance, "multi_return", List.of(ValueType.F32)); + } + + @Override + public int malloc(int _size) { + Value[] args = new Value[] { Value.i32(_size) }; + int result = _malloc.apply(args)[0].asInt(); + return result; + } + + @Override + public void free(int _ptr) { + Value[] args = new Value[] { Value.i32(_ptr) }; + _free.apply(args); + } + + @Override + public void print(String _data) { + byte[] bytes_data = _data.getBytes(StandardCharsets.UTF_8); + int ptr_data = malloc(bytes_data.length + 1); + memory.write(ptr_data, bytes_data); + memory.writeByte(ptr_data + bytes_data.length, (byte) 0); + Value[] args = new Value[] { Value.i32(ptr_data) }; + _print.apply(args); + free(ptr_data); + } + + @Override + public int length(String _data) { + byte[] bytes_data = _data.getBytes(StandardCharsets.UTF_8); + int ptr_data = malloc(bytes_data.length); + memory.write(ptr_data, bytes_data); + Value[] args = new Value[] { Value.i32(ptr_data), Value.i32(bytes_data.length) }; + int result = _length.apply(args)[0].asInt(); + free(ptr_data); + return result; + } + + @Override + public Value[] multiReturn(float _x) { + Value[] args = new Value[] { Value.fromFloat(_x) }; + Value[] result = _multiReturn.apply(args); + return result; + } + + private static void checkType(Instance instance, String name, FunctionType expected) { + checkType(name, expected, instance.exportType(name)); + } + + private static void checkType(Instance instance, String name, List expected) { + checkType(name, expected, instance.exportType(name).params()); + } + + private static void checkType(String name, T expected, T actual) { + if (!expected.equals(actual)) { + throw new IllegalArgumentException(String.format( + "Function type mismatch for '%s': expected %s <=> actual %s", name, expected, actual)); + } + } + } +}