Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add annotation framework for Wasm interfaces #519

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package com.dylibso.chicory.function.annotations;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface Allocate {}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package com.dylibso.chicory.function.annotations;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface Free {}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package com.dylibso.chicory.function.annotations;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
public @interface WasmModule {}
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
package com.dylibso.chicory.function.processor;

import static com.github.javaparser.printer.configuration.DefaultPrinterConfiguration.ConfigOption.COLUMN_ALIGN_PARAMETERS;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;
import static javax.tools.Diagnostic.Kind.ERROR;
import static javax.tools.Diagnostic.Kind.NOTE;

import com.github.javaparser.ast.CompilationUnit;
import com.github.javaparser.ast.body.ClassOrInterfaceDeclaration;
import com.github.javaparser.ast.expr.Expression;
import com.github.javaparser.ast.expr.FieldAccessExpr;
import com.github.javaparser.ast.expr.NameExpr;
import com.github.javaparser.ast.expr.StringLiteralExpr;
import com.github.javaparser.printer.DefaultPrettyPrinter;
import com.github.javaparser.printer.configuration.DefaultConfigurationOption;
import com.github.javaparser.printer.configuration.DefaultPrinterConfiguration;
import java.io.IOException;
import java.io.Writer;
import java.lang.annotation.Annotation;
import java.util.Locale;
import java.util.Set;
import javax.annotation.processing.AbstractProcessor;
import javax.annotation.processing.Filer;
import javax.annotation.processing.Generated;
import javax.annotation.processing.RoundEnvironment;
import javax.lang.model.SourceVersion;
import javax.lang.model.element.AnnotationMirror;
import javax.lang.model.element.Element;
import javax.lang.model.element.ElementKind;
import javax.lang.model.element.PackageElement;
import javax.lang.model.element.TypeElement;
import javax.lang.model.type.TypeMirror;
import javax.lang.model.util.Elements;
import javax.tools.Diagnostic;

abstract class AbstractModuleProcessor extends AbstractProcessor {
private final Class<? extends Annotation> annotation;

protected AbstractModuleProcessor(Class<? extends Annotation> annotation) {
this.annotation = requireNonNull(annotation);
}

@Override
public SourceVersion getSupportedSourceVersion() {
return SourceVersion.latestSupported();
}

@Override
public Set<String> getSupportedAnnotationTypes() {
return Set.of(annotation.getName());
}

@Override
public boolean process(Set<? extends TypeElement> annotations, RoundEnvironment round) {
for (Element element : round.getElementsAnnotatedWith(annotation)) {
log(NOTE, "Generating module factory for " + element, null);
try {
processModule((TypeElement) element);
} catch (AbortProcessingException e) {
// skip type
}
}

return false;
}

protected abstract void processModule(TypeElement type);

protected Elements elements() {
return processingEnv.getElementUtils();
}

protected Filer filer() {
return processingEnv.getFiler();
}

protected void log(Diagnostic.Kind kind, String message, Element element) {
processingEnv.getMessager().printMessage(kind, message, element);
}

protected void addGeneratedAnnotation(ClassOrInterfaceDeclaration classDef) {
var processorName = new StringLiteralExpr(getClass().getName());
classDef.addSingleMemberAnnotation(Generated.class, processorName);
}

protected void writeSourceFile(
CompilationUnit cu, PackageElement pkg, TypeElement type, String suffix) {
var prefix = (pkg.isUnnamed()) ? "" : pkg.getQualifiedName().toString() + ".";
var qualifiedName = prefix + type.getSimpleName() + suffix;
try (Writer writer = filer().createSourceFile(qualifiedName, type).openWriter()) {
writer.write(cu.printer(printer()).toString());
} catch (IOException e) {
log(ERROR, format("Failed to create %s file: %s", qualifiedName, e), null);
}
}

protected static DefaultPrettyPrinter printer() {
return new DefaultPrettyPrinter(
new DefaultPrinterConfiguration()
.addOption(new DefaultConfigurationOption(COLUMN_ALIGN_PARAMETERS, true)));
}

protected static PackageElement getPackageName(Element element) {
Element enclosing = element;
while (enclosing.getKind() != ElementKind.PACKAGE) {
enclosing = enclosing.getEnclosingElement();
}
return (PackageElement) enclosing;
}

protected static boolean annotatedWith(
Element element, Class<? extends Annotation> annotation) {
var annotationName = annotation.getName();
return element.getAnnotationMirrors().stream()
.map(AnnotationMirror::getAnnotationType)
.map(TypeMirror::toString)
.anyMatch(annotationName::equals);
}

protected static CompilationUnit createCompilationUnit(PackageElement pkg, TypeElement type) {
var packageName = pkg.getQualifiedName().toString();
var cu = (pkg.isUnnamed()) ? new CompilationUnit() : new CompilationUnit(packageName);
if (!pkg.isUnnamed()) {
cu.setPackageDeclaration(packageName);
cu.addImport(type.getQualifiedName().toString());
}
return cu;
}

protected static Expression valueType(String type) {
return new FieldAccessExpr(new NameExpr("ValueType"), type);
}

protected static String camelCaseToSnakeCase(String name) {
return name.replaceAll("([a-z])([A-Z]+)", "$1_$2").toLowerCase(Locale.ROOT);
}

protected static final class AbortProcessingException extends RuntimeException {}
}
Original file line number Diff line number Diff line change
@@ -1,25 +1,20 @@
package com.dylibso.chicory.function.processor;

import static com.github.javaparser.StaticJavaParser.parseType;
import static com.github.javaparser.printer.configuration.DefaultPrinterConfiguration.ConfigOption.COLUMN_ALIGN_PARAMETERS;
import static java.lang.String.format;
import static javax.tools.Diagnostic.Kind.ERROR;
import static javax.tools.Diagnostic.Kind.NOTE;

import com.dylibso.chicory.function.annotations.Buffer;
import com.dylibso.chicory.function.annotations.CString;
import com.dylibso.chicory.function.annotations.HostModule;
import com.dylibso.chicory.function.annotations.WasmExport;
import com.github.javaparser.ast.ArrayCreationLevel;
import com.github.javaparser.ast.CompilationUnit;
import com.github.javaparser.ast.NodeList;
import com.github.javaparser.ast.body.Parameter;
import com.github.javaparser.ast.body.VariableDeclarator;
import com.github.javaparser.ast.expr.ArrayAccessExpr;
import com.github.javaparser.ast.expr.ArrayCreationExpr;
import com.github.javaparser.ast.expr.ArrayInitializerExpr;
import com.github.javaparser.ast.expr.Expression;
import com.github.javaparser.ast.expr.FieldAccessExpr;
import com.github.javaparser.ast.expr.IntegerLiteralExpr;
import com.github.javaparser.ast.expr.LambdaExpr;
import com.github.javaparser.ast.expr.MethodCallExpr;
Expand All @@ -31,85 +26,38 @@
import com.github.javaparser.ast.stmt.BlockStmt;
import com.github.javaparser.ast.stmt.ExpressionStmt;
import com.github.javaparser.ast.stmt.ReturnStmt;
import com.github.javaparser.printer.DefaultPrettyPrinter;
import com.github.javaparser.printer.configuration.DefaultConfigurationOption;
import com.github.javaparser.printer.configuration.DefaultPrinterConfiguration;
import java.io.IOException;
import java.io.Writer;
import java.lang.annotation.Annotation;
import java.util.Locale;
import java.util.Set;
import javax.annotation.processing.AbstractProcessor;
import javax.annotation.processing.Filer;
import javax.annotation.processing.Generated;
import javax.annotation.processing.RoundEnvironment;
import javax.lang.model.SourceVersion;
import javax.lang.model.element.AnnotationMirror;
import javax.lang.model.element.Element;
import javax.lang.model.element.ElementKind;
import javax.lang.model.element.ExecutableElement;
import javax.lang.model.element.PackageElement;
import javax.lang.model.element.TypeElement;
import javax.lang.model.element.VariableElement;
import javax.lang.model.type.TypeMirror;
import javax.lang.model.util.Elements;
import javax.tools.Diagnostic;

public final class FunctionProcessor extends AbstractProcessor {

@Override
public SourceVersion getSupportedSourceVersion() {
return SourceVersion.latestSupported();
}

@Override
public Set<String> getSupportedAnnotationTypes() {
return Set.of(HostModule.class.getName());
public final class HostModuleProcessor extends AbstractModuleProcessor {
public HostModuleProcessor() {
super(HostModule.class);
}

@Override
public boolean process(Set<? extends TypeElement> annotations, RoundEnvironment round) {
for (Element element : round.getElementsAnnotatedWith(HostModule.class)) {
log(NOTE, "Generating module factory for " + element, null);
try {
processModule((TypeElement) element);
} catch (AbortProcessingException e) {
// skip type
}
}

return false;
}

private void processModule(TypeElement type) {
protected void processModule(TypeElement type) {
var moduleName = type.getAnnotation(HostModule.class).value();

var functions = new NodeList<Expression>();
for (Element member : elements().getAllMembers(type)) {
if (member instanceof ExecutableElement && annotatedWith(member, WasmExport.class)) {
functions.add(processMethod(member, (ExecutableElement) member, moduleName));
functions.add(processMethod((ExecutableElement) member, moduleName));
}
}

var pkg = getPackageName(type);
var packageName = pkg.getQualifiedName().toString();
var cu = (pkg.isUnnamed()) ? new CompilationUnit() : new CompilationUnit(packageName);
if (!pkg.isUnnamed()) {
cu.setPackageDeclaration(packageName);
cu.addImport(type.getQualifiedName().toString());
}
var cu = createCompilationUnit(pkg, type);
cu.addImport("com.dylibso.chicory.runtime.HostFunction");
cu.addImport("com.dylibso.chicory.runtime.Instance");
cu.addImport("com.dylibso.chicory.wasm.types.Value");
cu.addImport("com.dylibso.chicory.wasm.types.ValueType");
cu.addImport("java.util.List");

var typeName = type.getSimpleName().toString();
var processorName = new StringLiteralExpr(getClass().getName());
var classDef =
cu.addClass(typeName + "_ModuleFactory")
.setPublic(true)
.setFinal(true)
.addSingleMemberAnnotation(Generated.class, processorName);
var classDef = cu.addClass(typeName + "_ModuleFactory").setPublic(true).setFinal(true);
addGeneratedAnnotation(classDef);

classDef.addConstructor().setPrivate(true);

Expand All @@ -126,17 +74,10 @@ private void processModule(TypeElement type) {
.setType("HostFunction[]")
.setBody(new BlockStmt(new NodeList<>(new ReturnStmt(newHostFunctions))));

String prefix = (pkg.isUnnamed()) ? "" : packageName + ".";
String qualifiedName = prefix + type.getSimpleName() + "_ModuleFactory";
try (Writer writer = filer().createSourceFile(qualifiedName, type).openWriter()) {
writer.write(cu.printer(printer()).toString());
} catch (IOException e) {
log(ERROR, format("Failed to create %s file: %s", qualifiedName, e), null);
}
writeSourceFile(cu, pkg, type, "_ModuleFactory");
}

private Expression processMethod(
Element member, ExecutableElement executable, String moduleName) {
private Expression processMethod(ExecutableElement executable, String moduleName) {
// compute function name
var name = executable.getAnnotation(WasmExport.class).value();
if (name.isEmpty()) {
Expand Down Expand Up @@ -233,7 +174,9 @@ private Expression processMethod(
// function invocation
Expression invocation =
new MethodCallExpr(
new NameExpr("functions"), member.getSimpleName().toString(), arguments);
new NameExpr("functions"),
executable.getSimpleName().toString(),
arguments);

// convert return value
BlockStmt handleBody = new BlockStmt();
Expand Down Expand Up @@ -273,56 +216,11 @@ private Expression processMethod(
.addArgument(new StringLiteralExpr(name))
.addArgument(new MethodCallExpr(new NameExpr("List"), "of", paramTypes))
.addArgument(new MethodCallExpr(new NameExpr("List"), "of", returnType));
// TODO: update javaparser and replace with multiline formatting
function.setLineComment("");
return function;
}

private Elements elements() {
return processingEnv.getElementUtils();
}

private Filer filer() {
return processingEnv.getFiler();
}

private void log(Diagnostic.Kind kind, String message, Element element) {
processingEnv.getMessager().printMessage(kind, message, element);
}

private static PackageElement getPackageName(Element element) {
Element enclosing = element;
while (enclosing.getKind() != ElementKind.PACKAGE) {
enclosing = enclosing.getEnclosingElement();
}
return (PackageElement) enclosing;
}

private static boolean annotatedWith(Element element, Class<? extends Annotation> annotation) {
var annotationName = annotation.getName();
return element.getAnnotationMirrors().stream()
.map(AnnotationMirror::getAnnotationType)
.map(TypeMirror::toString)
.anyMatch(annotationName::equals);
}

private static Expression argExpr(int n) {
return new ArrayAccessExpr(new NameExpr("args"), new IntegerLiteralExpr(String.valueOf(n)));
}

private static Expression valueType(String type) {
return new FieldAccessExpr(new NameExpr("ValueType"), type);
}

private static String camelCaseToSnakeCase(String name) {
return name.replaceAll("([a-z])([A-Z]+)", "$1_$2").toLowerCase(Locale.ROOT);
}

private static DefaultPrettyPrinter printer() {
return new DefaultPrettyPrinter(
new DefaultPrinterConfiguration()
.addOption(new DefaultConfigurationOption(COLUMN_ALIGN_PARAMETERS, true)));
}

private static final class AbortProcessingException extends RuntimeException {}
}
Loading
Loading