Skip to content

Commit

Permalink
Override ml endpoint with PredictEndpoint if set
Browse files Browse the repository at this point in the history
  • Loading branch information
JordonPhillips committed Sep 30, 2020
1 parent acfe56d commit 827e3de
Show file tree
Hide file tree
Showing 6 changed files with 248 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ public final class AwsCustomGoDependency extends AwsGoDependency {
"service/glacier/internal/customizations", "glaciercust");
public static final GoDependency S3_SHARED_CUSTOMIZATION = awsModuleDep(
"service/internal/s3shared", null, "s3shared");
public static final GoDependency MACHINE_LEARNING_CUSTOMIZATION = aws(
"service/machinelearning/internal/customizations", "mlcust");

private AwsCustomGoDependency() {
super();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
package software.amazon.smithy.aws.go.codegen.customization;

import java.util.List;
import software.amazon.smithy.aws.traits.ServiceTrait;
import software.amazon.smithy.codegen.core.SymbolProvider;
import software.amazon.smithy.go.codegen.GoDelegator;
import software.amazon.smithy.go.codegen.GoSettings;
import software.amazon.smithy.go.codegen.GoWriter;
import software.amazon.smithy.go.codegen.SmithyGoDependency;
import software.amazon.smithy.go.codegen.SymbolUtils;
import software.amazon.smithy.go.codegen.integration.GoIntegration;
import software.amazon.smithy.go.codegen.integration.MiddlewareRegistrar;
import software.amazon.smithy.go.codegen.integration.ProtocolUtils;
import software.amazon.smithy.go.codegen.integration.RuntimeClientPlugin;
import software.amazon.smithy.model.Model;
import software.amazon.smithy.model.shapes.OperationShape;
import software.amazon.smithy.model.shapes.ServiceShape;
import software.amazon.smithy.model.shapes.Shape;
import software.amazon.smithy.model.shapes.StructureShape;
import software.amazon.smithy.utils.ListUtils;

public class MachineLearningCustomizations implements GoIntegration {
private static final String ADD_PREDICT_ENDPOINT = "AddPredictEndpointMiddleware";
private static final String ENDPOINT_ACCESSOR = "getPredictEndpoint";

@Override
public byte getOrder() {
// This needs to be run after the generic endpoint resolver gets added
return 50;
}

@Override
public List<RuntimeClientPlugin> getClientPlugins() {
return ListUtils.of(
RuntimeClientPlugin.builder()
.operationPredicate(MachineLearningCustomizations::isPredict)
.registerMiddleware(MiddlewareRegistrar.builder()
.resolvedFunction(SymbolUtils.createValueSymbolBuilder(ADD_PREDICT_ENDPOINT,
AwsCustomGoDependency.MACHINE_LEARNING_CUSTOMIZATION).build())
.functionArguments(ListUtils.of(
SymbolUtils.createValueSymbolBuilder(ENDPOINT_ACCESSOR).build()
))
.build())
.build()
);
}

@Override
public void writeAdditionalFiles(
GoSettings settings,
Model model,
SymbolProvider symbolProvider,
GoDelegator goDelegator
) {
ServiceShape service = settings.getService(model);
if (!isMachineLearning(model, service)) {
return;
}

service.getAllOperations().stream()
.filter(shapeId -> shapeId.getName().equalsIgnoreCase("Predict"))
.findAny()
.map(model::expectShape)
.flatMap(Shape::asOperationShape)
.ifPresent(operation -> {
goDelegator.useShapeWriter(operation, writer -> writeEndpointAccessor(
writer, model, symbolProvider, operation));
});
}

private void writeEndpointAccessor(
GoWriter writer,
Model model,
SymbolProvider symbolProvider,
OperationShape operation
) {
StructureShape input = ProtocolUtils.expectInput(model, operation);
writer.openBlock("func $L(input interface{}) (*string, error) {", "}", ENDPOINT_ACCESSOR, () -> {
writer.write("in, ok := input.($P)", symbolProvider.toSymbol(input));
writer.openBlock("if !ok {", "}", () -> {
writer.addUseImports(SmithyGoDependency.SMITHY);
writer.addUseImports(SmithyGoDependency.FMT);
writer.write("return nil, &smithy.SerializationError{Err: fmt.Errorf("
+ "\"expected $P, but was %T\", input)}", symbolProvider.toSymbol(input));
});
writer.write("return in.PredictEndpoint, nil");
});
}

private static boolean isPredict(Model model, ServiceShape service, OperationShape operation) {
return isMachineLearning(model, service) && operation.getId().getName().equalsIgnoreCase("Predict");
}

private static boolean isMachineLearning(Model model, ServiceShape service) {
return service.expectTrait(ServiceTrait.class).getSdkId().equalsIgnoreCase("Machine Learning");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,5 @@ software.amazon.smithy.aws.go.codegen.customization.GlacierCustomizations
software.amazon.smithy.aws.go.codegen.customization.S3ResponseErrorWrapper
software.amazon.smithy.aws.go.codegen.customization.S3MetadataRetriever
software.amazon.smithy.aws.go.codegen.customization.S3ContentSHA256Header
software.amazon.smithy.aws.go.codegen.customization.BackfillS3ObjectSizeMemberShapeType
software.amazon.smithy.aws.go.codegen.customization.BackfillS3ObjectSizeMemberShapeType
software.amazon.smithy.aws.go.codegen.customization.MachineLearningCustomizations
11 changes: 11 additions & 0 deletions service/machinelearning/api_op_Predict.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

61 changes: 61 additions & 0 deletions service/machinelearning/internal/customizations/predictendpoint.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package customizations

import (
"context"
"fmt"
"github.com/awslabs/smithy-go"
"github.com/awslabs/smithy-go/middleware"
smithyhttp "github.com/awslabs/smithy-go/transport/http"
"net/url"
)

type fetchPredictEndpointFunc func(interface{}) (*string, error)

// AddPredictEndpointMiddleware adds the middleware required to set the endpoint
// based on Predict's PredictEndpoint input member.
func AddPredictEndpointMiddleware(stack *middleware.Stack, endpoint fetchPredictEndpointFunc) {
stack.Serialize.Insert(&predictEndpointMiddleware{}, "ResolveEndpoint", middleware.After)
}

// predictEndpointMiddleware rewrites the endpoint with whatever is specified in the
// operation input if it is non-nil and non-empty.
type predictEndpointMiddleware struct{
fetchPredictEndpoint fetchPredictEndpointFunc
}

// ID returns the id for the middleware.
func (*predictEndpointMiddleware) ID() string { return "MachineLearning:PredictEndpoint" }

// HandleSerialize implements the SerializeMiddleware interface.
func (m *predictEndpointMiddleware) HandleSerialize(
ctx context.Context, in middleware.SerializeInput, next middleware.SerializeHandler,
) (
out middleware.SerializeOutput, metadata middleware.Metadata, err error,
) {
req, ok := in.Request.(*smithyhttp.Request)
if !ok {
return out, metadata, &smithy.SerializationError{
Err: fmt.Errorf("unknown request type %T", in.Request),
}
}

endpoint, err := m.fetchPredictEndpoint(in.Parameters)
if err != nil {
return out, metadata, &smithy.SerializationError{
Err: fmt.Errorf("failed to fetch PredictEndpoint value, %v", err),
}
}

if endpoint != nil && len(*endpoint) != 0 {
uri, err := url.Parse(*endpoint)
if err != nil {
return out, metadata, &smithy.SerializationError{
Err: fmt.Errorf("unable to parse predict endpoint, %v", err),
}
}
req.URL = uri
in.Request = req
}

return next.HandleSerialize(ctx, in)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package customizations

import (
"context"
"github.com/awslabs/smithy-go/middleware"
"github.com/awslabs/smithy-go/ptr"
smithyhttp "github.com/awslabs/smithy-go/transport/http"
"strings"
"testing"
)

func TestPredictEndpointMiddleware(t *testing.T) {
cases := map[string]struct {
PredictEndpoint *string
ExpectedEndpoint string
ExpectedErr string
}{
"nil endpoint": {},
"empty endpoint": {
PredictEndpoint: ptr.String(""),
},
"invalid endpoint": {
PredictEndpoint: ptr.String("::::::::"),
ExpectedErr: "unable to parse",
},
"valid endpoint": {
PredictEndpoint: ptr.String("https://example.amazonaws.com/"),
ExpectedEndpoint: "https://example.amazonaws.com/",
},
}

for name, c := range cases {
t.Run(name, func(t *testing.T) {
m := &predictEndpointMiddleware{
fetchPredictEndpoint: func(i interface{}) (*string, error) {
return c.PredictEndpoint, nil
},
}
_, _, err := m.HandleSerialize(context.Background(),
middleware.SerializeInput{
Request: smithyhttp.NewStackRequest(),
},
middleware.SerializeHandlerFunc(
func(ctx context.Context, input middleware.SerializeInput) (
output middleware.SerializeOutput, metadata middleware.Metadata, err error,
) {

req, ok := input.Request.(*smithyhttp.Request)
if !ok || req == nil {
t.Fatalf("expect smithy request, got %T", input.Request)
}

if c.ExpectedEndpoint != req.URL.String() {
t.Errorf("expected url to be `%v`, but was `%v`", c.ExpectedEndpoint, req.URL.String())
}

return output, metadata, err
}),
)
if len(c.ExpectedErr) != 0 {
if err == nil {
t.Fatalf("expect error, got none")
}
if e, a := c.ExpectedErr, err.Error(); !strings.Contains(a, e) {
t.Fatalf("expect error to contain %v, got %v", e, a)
}
} else {
if err != nil {
t.Fatalf("expect no error, got %v", err)
}
}
})
}

}

0 comments on commit 827e3de

Please sign in to comment.