Skip to content

Commit

Permalink
fix(client-transcribe-streaming): add plugin only for operations with…
Browse files Browse the repository at this point in the history
… streaming trait (#6349)
  • Loading branch information
trivikr authored Aug 2, 2024
1 parent 9836a09 commit 6043d79
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import {
} from "@aws-sdk/middleware-host-header";
import { getLoggerPlugin } from "@aws-sdk/middleware-logger";
import { getRecursionDetectionPlugin } from "@aws-sdk/middleware-recursion-detection";
import { getTranscribeStreamingPlugin } from "@aws-sdk/middleware-sdk-transcribe-streaming";
import {
getUserAgentPlugin,
resolveUserAgentConfig,
Expand Down Expand Up @@ -361,7 +360,6 @@ export class TranscribeStreamingClient extends __Client<
}),
})
);
this.middlewareStack.use(getTranscribeStreamingPlugin(this.config));
this.middlewareStack.use(getHttpSigningPlugin(this.config));
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
// smithy-typescript generated code
import { getEventStreamPlugin } from "@aws-sdk/middleware-eventstream";
import { getTranscribeStreamingPlugin } from "@aws-sdk/middleware-sdk-transcribe-streaming";
import { getWebSocketPlugin } from "@aws-sdk/middleware-websocket";
import { getEndpointPlugin } from "@smithy/middleware-endpoint";
import { getSerdePlugin } from "@smithy/middleware-serde";
Expand Down Expand Up @@ -248,6 +249,7 @@ export class StartCallAnalyticsStreamTranscriptionCommand extends $Command
getWebSocketPlugin(config, {
headerPrefix: "x-amzn-transcribe-",
}),
getTranscribeStreamingPlugin(config),
];
})
.s("Transcribe", "StartCallAnalyticsStreamTranscription", {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
// smithy-typescript generated code
import { getEventStreamPlugin } from "@aws-sdk/middleware-eventstream";
import { getTranscribeStreamingPlugin } from "@aws-sdk/middleware-sdk-transcribe-streaming";
import { getWebSocketPlugin } from "@aws-sdk/middleware-websocket";
import { getEndpointPlugin } from "@smithy/middleware-endpoint";
import { getSerdePlugin } from "@smithy/middleware-serde";
Expand Down Expand Up @@ -229,6 +230,7 @@ export class StartMedicalStreamTranscriptionCommand extends $Command
getWebSocketPlugin(config, {
headerPrefix: "x-amzn-transcribe-",
}),
getTranscribeStreamingPlugin(config),
];
})
.s("Transcribe", "StartMedicalStreamTranscription", {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
// smithy-typescript generated code
import { getEventStreamPlugin } from "@aws-sdk/middleware-eventstream";
import { getTranscribeStreamingPlugin } from "@aws-sdk/middleware-sdk-transcribe-streaming";
import { getWebSocketPlugin } from "@aws-sdk/middleware-websocket";
import { getEndpointPlugin } from "@smithy/middleware-endpoint";
import { getSerdePlugin } from "@smithy/middleware-serde";
Expand Down Expand Up @@ -253,6 +254,7 @@ export class StartStreamTranscriptionCommand extends $Command
getWebSocketPlugin(config, {
headerPrefix: "x-amzn-transcribe-",
}),
getTranscribeStreamingPlugin(config),
];
})
.s("Transcribe", "StartStreamTranscription", {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import software.amazon.smithy.aws.traits.ServiceTrait;
import software.amazon.smithy.codegen.core.SymbolProvider;
import software.amazon.smithy.model.Model;
import software.amazon.smithy.model.knowledge.EventStreamIndex;
import software.amazon.smithy.model.shapes.OperationShape;
import software.amazon.smithy.model.shapes.ServiceShape;
import software.amazon.smithy.typescript.codegen.LanguageTarget;
import software.amazon.smithy.typescript.codegen.TypeScriptSettings;
Expand Down Expand Up @@ -54,7 +56,7 @@ public List<RuntimeClientPlugin> getClientPlugins() {
RuntimeClientPlugin.builder()
.withConventions(AwsDependency.TRANSCRIBE_STREAMING_MIDDLEWARE.dependency,
"TranscribeStreaming", RuntimeClientPlugin.Convention.HAS_MIDDLEWARE)
.servicePredicate((m, s) -> isTranscribeStreaming(s))
.operationPredicate((m, s, o) -> isTranscribeStreaming(s) && hasEventStreamInput(m, s, o))
.build()
);
}
Expand Down Expand Up @@ -92,6 +94,11 @@ private static boolean isTranscribeStreaming(ServiceShape service) {
String serviceId = service.getTrait(ServiceTrait.class).map(ServiceTrait::getSdkId).orElse("");
return serviceId.equals("Transcribe Streaming");
}

private static boolean hasEventStreamInput(Model model, ServiceShape service, OperationShape operation) {
EventStreamIndex eventStreamIndex = EventStreamIndex.of(model);
return eventStreamIndex.getInputInfo(operation).isPresent();
}
}


0 comments on commit 6043d79

Please sign in to comment.