Lack of Documentation and 400 Bad Request Error in BedrockRuntime ConverseRequest Tools #5568

Open
KaisNeffati opened this issue Sep 6, 2024 · 5 comments
Labels
documentation This is a problem with documentation. p2 This is a standard priority issue service-api This issue is due to a problem in a service API, not the SDK implementation.

Comments

@KaisNeffati

KaisNeffati commented Sep 6, 2024

Describe the issue

I am encountering two issues when trying to use the Amazon BedrockRuntime ConverseRequest Tools API:

1. Lack of Documentation: There is insufficient documentation on how to set up and use ConverseRequest for tool use with Bedrock models. The official documentation does not clearly explain the request structure, model interaction, or tool configuration.
2. 400 Bad Request: When executing the code below, I consistently receive a 400 Bad Request error from the Bedrock service. The error does not say which part of the request is wrong: the request structure appears to match the expected API, and the lack of clear examples in the documentation makes further debugging difficult.

Here is the code that I used:

import jakarta.ws.rs.GET;
import jakarta.ws.rs.Path;
import jakarta.ws.rs.Produces;
import jakarta.ws.rs.core.MediaType;
import software.amazon.awssdk.core.SdkNumber;
import software.amazon.awssdk.core.document.Document;
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient;
import software.amazon.awssdk.services.bedrockruntime.model.*;

import java.util.*;
import java.util.logging.Logger;
import java.util.logging.Level;
import java.util.stream.Collectors;

@Path("/tools")
public class ConverseAPI {

    private static final Logger logger = Logger.getLogger(ConverseAPI.class.getName());

    // Generates text using Amazon Bedrock Model and handles tool use requests.
    public static String generateText(BedrockRuntimeClient bedrockClient, String modelId, String inputText) {
        logger.info("Generating text with model " + modelId);
        try {
            List<Message> messages = createInitialMessage(inputText);
            ToolConfiguration toolConfig = configureTools();

            // Create initial ConverseRequest
            ConverseResponse response = sendConverseRequest(bedrockClient, modelId, messages, toolConfig);

            // Handle tool use if requested
            while (StopReason.TOOL_USE.equals(response.stopReason())) {
                handleToolUse(response, messages, bedrockClient, modelId, toolConfig);
                response = sendConverseRequest(bedrockClient, modelId, messages, toolConfig);
            }

            // Process and return final response
            return processFinalResponse(response);

        } catch (BedrockRuntimeException e) {
            logger.log(Level.SEVERE, "Error calling Bedrock service: ", e);
            return "";
        }
    }

    // Helper to create the initial message list
    private static List<Message> createInitialMessage(String inputText) {
        return Collections.singletonList(
                Message.builder()
                        .role("user")
                        .content(ContentBlock.builder().text(inputText).build())
                        .build()
        );
    }

    // Helper to configure tool settings
    private static ToolConfiguration configureTools() {
        return ToolConfiguration.builder()
                .tools(Tool.builder()
                        .toolSpec(ToolSpecification.builder()
                                .name("sum")
                                .description("Sums two given numbers")
                                .inputSchema(ToolInputSchema.builder()
                                        .json(Document.mapBuilder()
                                                .putNumber("a", "first parameter")
                                                .putNumber("b", "second parameter")
                                                .build())
                                        .build())
                                .build())
                        .build())
                .build();
    }

    // Helper to send ConverseRequest
    private static ConverseResponse sendConverseRequest(BedrockRuntimeClient bedrockClient, String modelId, List<Message> messages, ToolConfiguration toolConfig) {
        ConverseRequest request = ConverseRequest.builder()
                .modelId(modelId)
                .messages(messages)
                .toolConfig(toolConfig)
                .build();
        return bedrockClient.converse(request);
    }

    // Helper to handle tool use and generate the result
    private static void handleToolUse(ConverseResponse response, List<Message> messages, BedrockRuntimeClient bedrockClient, String modelId, ToolConfiguration toolConfig) {
        List<ContentBlock> toolRequests = response.output().message().content();

        for (ContentBlock toolRequest : toolRequests) {
            ToolUseBlock tool = toolRequest.toolUse();
            if (tool != null) {
                logger.info("Requesting tool " + tool.name() + ". Request ID: " + tool.toolUseId());

                // Handle tool processing
                ToolResultBlock toolResult = processToolRequest(tool);

                // Add tool result to messages
                messages.add(Message.builder()
                        .role("user")
                        .content(ContentBlock.builder().toolResult(toolResult).build())
                        .build());
            }
        }
    }

    // Helper to process tool requests
    private static ToolResultBlock processToolRequest(ToolUseBlock tool) {
        if ("sum".equals(tool.name())) {
            Document input = tool.input();
            try {
                Map<String, Document> inputMap = input.asMap();
                SdkNumber a = Optional.ofNullable(inputMap.get("a")).map(Document::asNumber).orElse(SdkNumber.fromDouble(0));
                SdkNumber b = Optional.ofNullable(inputMap.get("b")).map(Document::asNumber).orElse(SdkNumber.fromDouble(0));

                double result = a.doubleValue() + b.doubleValue();
                return ToolResultBlock.builder()
                        .toolUseId(tool.toolUseId())
                        .content(ToolResultContentBlock.builder().json(Document.fromNumber(result)).build())
                        .status(ToolResultStatus.SUCCESS)
                        .build();

            } catch (Exception ex) {
                logger.log(Level.SEVERE, "Error processing tool request: " + ex.getMessage(), ex);
                return ToolResultBlock.builder()
                        .toolUseId(tool.toolUseId())
                        .content(ToolResultContentBlock.builder().text(ex.getMessage()).build())
                        .status(ToolResultStatus.ERROR)
                        .build();
            }
        }
        return null;
    }

    // Helper to process the final response
    private static String processFinalResponse(ConverseResponse response) {
        Message outputMessage = response.output().message();
        List<ContentBlock> finalContent = outputMessage.content();

        return finalContent.stream()
                .map(ContentBlock::text)
                .collect(Collectors.joining("\n"));
    }

    @GET
    @Produces(MediaType.TEXT_PLAIN)
    public String tools() {

        // Example input
        String modelId = "anthropic.claude-3-sonnet-20240229-v1:0";
        String inputText = "What is the sum of 5 and 6?";

        // Create a Bedrock client
        BedrockRuntimeClient bedrockClient = BedrockRuntimeClient.create();

        try {
            logger.info("Question: " + inputText);
            return generateText(bedrockClient, modelId, inputText);
        } catch (Exception e) {
            logger.log(Level.SEVERE, "A client error occurred: " + e.getMessage());
        } finally {
            logger.info("Finished generating text with model " + modelId);
        }
        return "";
    }
}
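Separately from the 400 error, two parts of the code above would likely fail once the request itself is accepted. `createInitialMessage()` returns `Collections.singletonList(...)`, which is immutable, so the `messages.add(...)` call in `handleToolUse()` would throw `UnsupportedOperationException`. And `processFinalResponse()` maps every block through `ContentBlock::text`, which is `null` for non-text blocks, so `Collectors.joining` would emit the literal string `null`. The stdlib-only sketch below models messages and content blocks as plain strings (a simplification, not the SDK types; `canAppend` and `joinText` are hypothetical helpers) to show both behaviors.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;

public class ConversationListPitfalls {

    // Returns true if appending to the list succeeds, false if the list is immutable.
    static boolean canAppend(List<String> messages) {
        try {
            messages.add("toolResult");
            return true;
        } catch (UnsupportedOperationException e) {
            return false;
        }
    }

    // Joins block texts, skipping blocks without text. Strings stand in for
    // ContentBlock here; in the SDK, text() is null for toolUse/toolResult blocks.
    static String joinText(List<String> blockTexts) {
        return blockTexts.stream()
                .filter(Objects::nonNull)
                .collect(Collectors.joining("\n"));
    }

    public static void main(String[] args) {
        // Immutable, like the list returned by createInitialMessage(): prints false.
        System.out.println(canAppend(Collections.singletonList("user question")));
        // A mutable copy accepts the assistant and tool-result messages: prints true.
        System.out.println(canAppend(new ArrayList<>(List.of("user question"))));

        // Without the null filter, joining would produce "The sum is 11.\nnull".
        List<String> texts = new ArrayList<>();
        texts.add("The sum is 11.");
        texts.add(null);
        System.out.println(joinText(texts)); // prints "The sum is 11."
    }
}
```

In the SDK code, the corresponding fixes would be `new ArrayList<>(List.of(...))` for the initial message list and a `.filter(Objects::nonNull)` step after `.map(ContentBlock::text)`.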

Steps to Reproduce:

  1. Run the above code that interacts with the BedrockRuntime ConverseRequest API.
  2. A 400 Bad Request error is returned when invoking the converse() method.

Expected Result:

  • The ConverseRequest should successfully generate text and handle tool requests, summing two numbers as described in the request.

Actual Result:

  • A 400 Bad Request error is returned.
  • No clear documentation or explanation is available to debug or resolve this issue.

Suggested Improvements:

  1. Improved Documentation:
  • Detailed examples on how to configure and use tools within the ConverseRequest.
  • Proper request structure examples, including ToolConfiguration, ToolInputSchema, and ConverseRequest.
  2. Error Handling Clarifications:
  • Clearer error messages when a 400 Bad Request is returned, detailing what part of the request is malformed.
  • Provide a reference to specific sections in the documentation to address common issues like this one.

Links

https://docs.aws.amazon.com/fr_fr/bedrock/latest/userguide/tool-use.html

@KaisNeffati KaisNeffati added documentation This is a problem with documentation. needs-triage This issue or PR still needs to be triaged. labels Sep 6, 2024
@KaisNeffati KaisNeffati changed the title Lack of Documentation and 400 Bad Request Error in BedrockRuntime ConverseRequest Lack of Documentation and 400 Bad Request Error in BedrockRuntime ConverseRequest Tools Sep 6, 2024
@debora-ito
Member

@KaisNeffati thank you for your feedback, I'll pass it to the Bedrock team.
Can you provide the full 400 Bad Request error message?

Just in case you don't know, you can submit documentation feedback directly on the page: click the "Feedback" button at the top right and fill in the form with as much detail as you can. The feedback goes straight to the team that owns that page.


@debora-ito debora-ito added service-api This issue is due to a problem in a service API, not the SDK implementation. p0 This issue is the highest priority p2 This is a standard priority issue and removed needs-triage This issue or PR still needs to be triaged. p0 This issue is the highest priority labels Sep 6, 2024
@KaisNeffati
Author

KaisNeffati commented Sep 8, 2024

Hi Debora, thank you for your response. Below is the complete log:

DIS-APP(^-^) 2024-09-08 12:56:31.293GMT ****  INFO  [com.*****.aiservice.poem.ConverseAPI] (executor-thread-1) Question: What is the sum of 5 and 6?
DIS-APP(^-^) 2024-09-08 12:56:31.293GMT ***** ***** INFO  [com.*****.aiservice.poem.ConverseAPI] (executor-thread-1) Generating text with model anthropic.claude-3-sonnet-20240229-v1:0
DIS-APP(^-^) 2024-09-08 12:56:31.939GMT ***** ***** SEVERE [com.*****.aiservice.poem.ConverseAPI] (executor-thread-1) Error calling Bedrock service: : software.amazon.awssdk.services.bedrockruntime.model.BedrockRuntimeException: Unable to parse request body (Service: BedrockRuntime, Status Code: 400, Request ID: e6550847-35ef-4cb0-8f45-f4bfd6d7ed96)
	at software.amazon.awssdk.core.internal.http.CombinedResponseHandler.handleErrorResponse(CombinedResponseHandler.java:125)
	at software.amazon.awssdk.core.internal.http.CombinedResponseHandler.handleResponse(CombinedResponseHandler.java:82)
	at software.amazon.awssdk.core.internal.http.CombinedResponseHandler.handle(CombinedResponseHandler.java:60)
	at software.amazon.awssdk.core.internal.http.CombinedResponseHandler.handle(CombinedResponseHandler.java:41)
	at software.amazon.awssdk.core.internal.http.pipeline.stages.HandleResponseStage.execute(HandleResponseStage.java:50)
	at software.amazon.awssdk.core.internal.http.pipeline.stages.HandleResponseStage.execute(HandleResponseStage.java:38)
	at software.amazon.awssdk.core.internal.http.pipeline.RequestPipelineBuilder$ComposingRequestPipelineStage.execute(RequestPipelineBuilder.java:206)
	at software.amazon.awssdk.core.internal.http.pipeline.stages.ApiCallAttemptTimeoutTrackingStage.execute(ApiCallAttemptTimeoutTrackingStage.java:74)
	at software.amazon.awssdk.core.internal.http.pipeline.stages.ApiCallAttemptTimeoutTrackingStage.execute(ApiCallAttemptTimeoutTrackingStage.java:43)
	at software.amazon.awssdk.core.internal.http.pipeline.stages.TimeoutExceptionHandlingStage.execute(TimeoutExceptionHandlingStage.java:79)
	at software.amazon.awssdk.core.internal.http.pipeline.stages.TimeoutExceptionHandlingStage.execute(TimeoutExceptionHandlingStage.java:41)
	at software.amazon.awssdk.core.internal.http.pipeline.stages.ApiCallAttemptMetricCollectionStage.execute(ApiCallAttemptMetricCollectionStage.java:55)
	at software.amazon.awssdk.core.internal.http.pipeline.stages.ApiCallAttemptMetricCollectionStage.execute(ApiCallAttemptMetricCollectionStage.java:39)
	at software.amazon.awssdk.core.internal.http.pipeline.stages.RetryableStage2.executeRequest(RetryableStage2.java:93)
	at software.amazon.awssdk.core.internal.http.pipeline.stages.RetryableStage2.execute(RetryableStage2.java:56)
	at software.amazon.awssdk.core.internal.http.pipeline.stages.RetryableStage2.execute(RetryableStage2.java:36)
	at software.amazon.awssdk.core.internal.http.pipeline.RequestPipelineBuilder$ComposingRequestPipelineStage.execute(RequestPipelineBuilder.java:206)
	at software.amazon.awssdk.core.internal.http.StreamManagingStage.execute(StreamManagingStage.java:53)
	at software.amazon.awssdk.core.internal.http.StreamManagingStage.execute(StreamManagingStage.java:35)
	at software.amazon.awssdk.core.internal.http.pipeline.stages.ApiCallTimeoutTrackingStage.executeWithTimer(ApiCallTimeoutTrackingStage.java:82)
	at software.amazon.awssdk.core.internal.http.pipeline.stages.ApiCallTimeoutTrackingStage.execute(ApiCallTimeoutTrackingStage.java:62)
	at software.amazon.awssdk.core.internal.http.pipeline.stages.ApiCallTimeoutTrackingStage.execute(ApiCallTimeoutTrackingStage.java:43)
	at software.amazon.awssdk.core.internal.http.pipeline.stages.ApiCallMetricCollectionStage.execute(ApiCallMetricCollectionStage.java:50)
	at software.amazon.awssdk.core.internal.http.pipeline.stages.ApiCallMetricCollectionStage.execute(ApiCallMetricCollectionStage.java:32)
	at software.amazon.awssdk.core.internal.http.pipeline.RequestPipelineBuilder$ComposingRequestPipelineStage.execute(RequestPipelineBuilder.java:206)
	at software.amazon.awssdk.core.internal.http.pipeline.RequestPipelineBuilder$ComposingRequestPipelineStage.execute(RequestPipelineBuilder.java:206)
	at software.amazon.awssdk.core.internal.http.pipeline.stages.ExecutionFailureExceptionReportingStage.execute(ExecutionFailureExceptionReportingStage.java:37)
	at software.amazon.awssdk.core.internal.http.pipeline.stages.ExecutionFailureExceptionReportingStage.execute(ExecutionFailureExceptionReportingStage.java:26)
	at software.amazon.awssdk.core.internal.http.AmazonSyncHttpClient$RequestExecutionBuilderImpl.execute(AmazonSyncHttpClient.java:210)
	at software.amazon.awssdk.core.internal.handler.BaseSyncClientHandler.invoke(BaseSyncClientHandler.java:103)
	at software.amazon.awssdk.core.internal.handler.BaseSyncClientHandler.doExecute(BaseSyncClientHandler.java:173)
	at software.amazon.awssdk.core.internal.handler.BaseSyncClientHandler.lambda$execute$1(BaseSyncClientHandler.java:80)
	at software.amazon.awssdk.core.internal.handler.BaseSyncClientHandler.measureApiCallSuccess(BaseSyncClientHandler.java:182)
	at software.amazon.awssdk.core.internal.handler.BaseSyncClientHandler.execute(BaseSyncClientHandler.java:74)
	at software.amazon.awssdk.core.client.handler.SdkSyncClientHandler.execute(SdkSyncClientHandler.java:45)
	at software.amazon.awssdk.awscore.client.handler.AwsSyncClientHandler.execute(AwsSyncClientHandler.java:53)
	at software.amazon.awssdk.services.bedrockruntime.DefaultBedrockRuntimeClient.converse(DefaultBedrockRuntimeClient.java:244)
	at com.*******.aiservice.poem.ConverseAPI.sendConverseRequest(ConverseAPI.java:84)
	at com.*******.aiservice.poem.ConverseAPI.generateText(ConverseAPI.java:32)
	at com.*******.aiservice.poem.ConverseAPI.tools(ConverseAPI.java:159)
	at com.*******.aiservice.poem.ConverseAPI$quarkusrestinvoker$tools_a1f6adcb3caa16816de4341cf37bdd026d1c4a8a.invoke(Unknown Source)
	at org.jboss.resteasy.reactive.server.handlers.InvocationHandler.handle(InvocationHandler.java:29)
	at io.quarkus.resteasy.reactive.server.runtime.QuarkusResteasyReactiveRequestContext.invokeHandler(QuarkusResteasyReactiveRequestContext.java:141)
	at org.jboss.resteasy.reactive.common.core.AbstractResteasyReactiveContext.run(AbstractResteasyReactiveContext.java:147)
	at io.quarkus.vertx.core.runtime.VertxCoreRecorder$14.runWith(VertxCoreRecorder.java:635)
	at org.jboss.threads.EnhancedQueueExecutor$Task.doRunWith(EnhancedQueueExecutor.java:2516)
	at org.jboss.threads.EnhancedQueueExecutor$Task.run(EnhancedQueueExecutor.java:2495)
	at org.jboss.threads.EnhancedQueueExecutor$ThreadBody.run(EnhancedQueueExecutor.java:1521)
	at org.jboss.threads.DelegatingRunnable.run(DelegatingRunnable.java:11)
	at org.jboss.threads.ThreadLocalResettingRunnable.run(ThreadLocalResettingRunnable.java:11)
	at io.netty.util.concurrent.FastThreadLocalRunnable.run(FastThreadLocalRunnable.java:30)
	at java.base/java.lang.Thread.run(Thread.java:840)

DIS-APP(^-^) 2024-09-08 12:56:31.940GMT ***** ***** INFO  [com.*****.aiservice.poem.ConverseAPI] (executor-thread-1) Finished generating text with model anthropic.claude-3-sonnet-20240229-v1:0
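The service reports `Unable to parse request body`, which suggests the JSON body itself was malformed rather than rejected by validation. A likely cause (an interpretation of the code, not confirmed in this thread) is `configureTools()`: `Document.mapBuilder().putNumber("a", "first parameter")` resolves to the `putNumber(String, String)` overload, which treats its second argument as the textual form of a number, so the description would be serialized as a bare number token. The stdlib-only sketch below checks tokens against the JSON number grammar of RFC 8259 (the `JsonNumberCheck` class is hypothetical) to show why that text cannot appear in number position.

```java
import java.util.regex.Pattern;

public class JsonNumberCheck {

    // JSON number grammar (RFC 8259, section 6): optional minus sign, integer part
    // without leading zeros, optional fraction, optional exponent.
    private static final Pattern JSON_NUMBER =
            Pattern.compile("-?(0|[1-9][0-9]*)(\\.[0-9]+)?([eE][+-]?[0-9]+)?");

    static boolean isJsonNumber(String token) {
        return JSON_NUMBER.matcher(token).matches();
    }

    public static void main(String[] args) {
        System.out.println(isJsonNumber("6"));               // true
        System.out.println(isJsonNumber("-2.5e3"));          // true
        System.out.println(isJsonNumber("first parameter")); // false

        // Hypothetical illustration: a description emitted as a raw number token
        // yields a body that no JSON parser accepts.
        System.out.println("{\"a\": " + "first parameter" + "}"); // {"a": first parameter}
    }
}
```

If that interpretation is right, replacing the two `putNumber` calls with a JSON Schema object (`type`, `properties`, `required`), as in the working examples in this thread, should remove the parse error.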

@herbert-beckman

herbert-beckman commented Sep 9, 2024

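You didn't pass a correct parameter in the tool configuration: `ToolInputSchema.json(...)` expects a JSON Schema object (`type`, `properties`, `required`), like the one built by `createDocument()` below. Here is an example of code that works: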

package com.hb;

import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider;
import software.amazon.awssdk.core.document.Document;
import software.amazon.awssdk.core.exception.SdkClientException;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient;
import software.amazon.awssdk.services.bedrockruntime.model.ContentBlock;
import software.amazon.awssdk.services.bedrockruntime.model.ConversationRole;
import software.amazon.awssdk.services.bedrockruntime.model.ConverseResponse;
import software.amazon.awssdk.services.bedrockruntime.model.Message;
import software.amazon.awssdk.services.bedrockruntime.model.Tool;
import software.amazon.awssdk.services.bedrockruntime.model.ToolConfiguration;
import software.amazon.awssdk.services.bedrockruntime.model.ToolInputSchema;
import software.amazon.awssdk.services.bedrockruntime.model.ToolSpecification;

public class Main {

    public static String converse() {
        var client = createClient();
        var modelId = "anthropic.claude-3-5-sonnet-20240620-v1:0";
        var toolConfig = createToolConfig();
        var message = createMessage("What is the temperature in Paris?");

        try {
            ConverseResponse response = client.converse(request -> request
                    .modelId(modelId)
                    .toolConfig(toolConfig)
                    .messages(message)
                    .inferenceConfig(config -> config
                            .maxTokens(512)
                            .temperature(0.5F)
                            .topP(0.9F))
            );

            System.out.println("Response:" + response);

            var responseText = response.output().message().content().get(0).text();
            System.out.println("Text Response:" + responseText);

            return responseText;

        } catch (SdkClientException e) {
            System.err.printf("ERROR: Can't invoke '%s'. Reason: %s", modelId, e.getMessage());
            throw new RuntimeException(e);
        }
    }

    private static Message createMessage(String inputText) {
        return Message.builder()
                .content(ContentBlock.fromText(inputText))
                .role(ConversationRole.USER)
                .build();
    }

    private static ToolConfiguration createToolConfig() {
        return ToolConfiguration.builder()
                .tools(Tool.builder()
                        .toolSpec(ToolSpecification.builder()
                                .name("currentTemperature")
                                .description("Returns the current temperature of a city")
                                .inputSchema(ToolInputSchema.builder()
                                        .json(createDocument())
                                        .build())
                                .build())
                        .build())
                .build();
    }

    private static BedrockRuntimeClient createClient() {
        return BedrockRuntimeClient.builder()
                .credentialsProvider(DefaultCredentialsProvider.create())
                .region(Region.US_EAST_1)
                .build();
    }

    private static Document createDocument() {
        var cityParameter = Document.mapBuilder()
                .putString("type", "string")
                .putString("description", "City name")
                .build();

        var properties = Document.mapBuilder()
                .putDocument("city", cityParameter)
                .build();

        var required = Document.listBuilder()
                .addString("city")
                .build();

        return Document.mapBuilder()
                .putString("type", "object")
                .putDocument("properties", properties)
                .putDocument("required", required)
                .build();
    }

    public static void main(String[] args) {
        converse();
    }

}

Output:

Response:ConverseResponse(Output=ConverseOutput(Message=Message(Role=assistant, Content=[ContentBlock(Text=To answer your question about the current temperature in Paris, I can use the currentTemperature function to retrieve that information for you. Let me do that now.), ContentBlock(ToolUse=ToolUseBlock(ToolUseId=tooluse_BVAdWPn8Tg-Hm3y9rRKlwg, Name=currentTemperature, Input={"city": "Paris"}))])), StopReason=tool_use, Usage=TokenUsage(InputTokens=371, OutputTokens=88, TotalTokens=459), Metrics=ConverseMetrics(LatencyMs=2064))
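Applied to the original `sum` tool, the same JSON Schema shape that `createDocument()` builds for `city` would declare `a` and `b` as `number` properties with their descriptions and list both as required. The stdlib-only sketch below builds that JSON as a string so the target structure is visible; `SumToolSchema`, `Param`, and `objectSchema` are hypothetical names, and the helper does no JSON escaping. With the SDK, the same structure would come from nested `Document.mapBuilder()` calls (`putString("type", "number")`, `putString("description", ...)`, then `putDocument("a", ...)` under `properties`), mirroring `createDocument()`.

```java
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.stream.Collectors;

public class SumToolSchema {

    // One tool parameter: its JSON Schema type and a human-readable description.
    record Param(String type, String description) {}

    // Builds a JSON Schema "object" with one property per parameter and marks every
    // parameter as required. Assumes names and descriptions contain no quotes or
    // backslashes, since no escaping is done.
    static String objectSchema(Map<String, Param> params) {
        String properties = params.entrySet().stream()
                .map(e -> "\"" + e.getKey() + "\":{\"type\":\"" + e.getValue().type()
                        + "\",\"description\":\"" + e.getValue().description() + "\"}")
                .collect(Collectors.joining(","));
        String required = params.keySet().stream()
                .map(name -> "\"" + name + "\"")
                .collect(Collectors.joining(","));
        return "{\"type\":\"object\",\"properties\":{" + properties
                + "},\"required\":[" + required + "]}";
    }

    public static void main(String[] args) {
        // LinkedHashMap keeps "a" before "b" in the output.
        Map<String, Param> params = new LinkedHashMap<>();
        params.put("a", new Param("number", "first parameter"));
        params.put("b", new Param("number", "second parameter"));
        System.out.println(objectSchema(params));
    }
}
```

Running `main` prints `{"type":"object","properties":{"a":{"type":"number","description":"first parameter"},"b":{"type":"number","description":"second parameter"}},"required":["a","b"]}`, which is the shape the `sum` tool's `inputSchema` would need.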

@herbert-beckman

Another example that follows the logic of your code, including the tool-use loop:

package com.hb;

import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider;
import software.amazon.awssdk.core.document.Document;
import software.amazon.awssdk.core.exception.SdkClientException;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient;
import software.amazon.awssdk.services.bedrockruntime.model.ContentBlock;
import software.amazon.awssdk.services.bedrockruntime.model.ConversationRole;
import software.amazon.awssdk.services.bedrockruntime.model.ConverseRequest;
import software.amazon.awssdk.services.bedrockruntime.model.ConverseResponse;
import software.amazon.awssdk.services.bedrockruntime.model.Message;
import software.amazon.awssdk.services.bedrockruntime.model.StopReason;
import software.amazon.awssdk.services.bedrockruntime.model.Tool;
import software.amazon.awssdk.services.bedrockruntime.model.ToolConfiguration;
import software.amazon.awssdk.services.bedrockruntime.model.ToolInputSchema;
import software.amazon.awssdk.services.bedrockruntime.model.ToolResultBlock;
import software.amazon.awssdk.services.bedrockruntime.model.ToolResultContentBlock;
import software.amazon.awssdk.services.bedrockruntime.model.ToolResultStatus;
import software.amazon.awssdk.services.bedrockruntime.model.ToolSpecification;
import software.amazon.awssdk.services.bedrockruntime.model.ToolUseBlock;

import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.logging.Level;
import java.util.logging.Logger;

public class Main {

    private static final Logger LOG = Logger.getLogger(Main.class.getName());

    public static String converse() {
        var client = createClient();
        var modelId = "anthropic.claude-3-5-sonnet-20240620-v1:0";
        var toolConfig = createToolConfig();
        var messages = new ArrayList<>(List.of(createUserMessage("What is the temperature in Paris?")));

        try {
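            var response = sendConverse(client, modelId, toolConfig, messages);

            while (StopReason.TOOL_USE.equals(response.stopReason())) {
                // Keep the assistant's toolUse turn in the history before sending the tool
                // results, on every round (not only the first), so multi-step tool use stays valid.
                messages.add(response.output().message());
                handleToolUse(response, messages);
                response = sendConverse(client, modelId, toolConfig, messages);
            }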

            LOG.info("messages=%s".formatted(messages));
            LOG.info("response=%s".formatted(response));

            var responseText = response.output().message().content().get(0).text();
            LOG.info("textResponse=%s".formatted(responseText));

            return responseText;

        } catch (SdkClientException e) {
            System.err.printf("ERROR: Can't invoke '%s'. Reason: %s", modelId, e.getMessage());
            throw new RuntimeException(e);
        }
    }

    private static Message createUserMessage(String inputText) {
        return Message.builder()
                .role(ConversationRole.USER)
                .content(ContentBlock.fromText(inputText))
                .build();
    }

    private static ToolConfiguration createToolConfig() {
        return ToolConfiguration.builder()
                .tools(Tool.builder()
                        .toolSpec(ToolSpecification.builder()
                                .name("currentTemperature")
                                .description("Returns the current temperature of a city")
                                .inputSchema(ToolInputSchema.builder()
                                        .json(createToolSpecDocument())
                                        .build())
                                .build())
                        .build())
                .build();
    }

    private static BedrockRuntimeClient createClient() {
        return BedrockRuntimeClient.builder()
                .credentialsProvider(DefaultCredentialsProvider.create())
                .region(Region.US_EAST_1)
                .build();
    }

    private static Document createToolSpecDocument() {
        var cityParameter = Document.mapBuilder()
                .putString("type", "string")
                .putString("description", "City name")
                .build();

        var properties = Document.mapBuilder()
                .putDocument("city", cityParameter)
                .build();

        var required = Document.listBuilder()
                .addString("city")
                .build();

        return Document.mapBuilder()
                .putString("type", "object")
                .putDocument("properties", properties)
                .putDocument("required", required)
                .build();
    }

    private static ConverseResponse sendConverse(BedrockRuntimeClient client, String modelId, ToolConfiguration toolConfig, List<Message> messages) {
        var converse = ConverseRequest.builder()
                .modelId(modelId)
                .toolConfig(toolConfig)
                .messages(messages)
                .inferenceConfig(config -> config
                        .maxTokens(512)
                        .temperature(0.5F)
                        .topP(0.9F))
                .build();

        LOG.info("converseRequest=%s".formatted(converse));
        return client.converse(converse);
    }

    private static void handleToolUse(ConverseResponse response, List<Message> messages) {
        var toolRequests = response.output()
                .message()
                .content()
                .stream()
                .filter(contentBlock -> Objects.nonNull(contentBlock.toolUse()))
                .toList();

        for (var toolRequest : toolRequests) {
            var tool = toolRequest.toolUse();
            LOG.info("tool=%s id=%s".formatted(tool.name(), tool.toolUseId()));

            ToolResultBlock toolResult = processToolRequest(tool);

            messages.add(Message.builder()
                    .role(ConversationRole.USER)
                    .content(ContentBlock.builder().toolResult(toolResult).build())
                    .build());
        }
    }

    private static ToolResultBlock processToolRequest(ToolUseBlock tool) {
        if ("currentTemperature".equals(tool.name())) {
            try {
                var input = tool.input();
                var inputMap = input.asMap();
                var cityName = Optional.ofNullable(inputMap.get("city"))
                        .map(Document::asString)
                        .orElse("");

                double result = currentTemperature(cityName);

                return ToolResultBlock.builder()
                        .toolUseId(tool.toolUseId())
                        .content(ToolResultContentBlock.builder().json(createToolResultDocument(result)).build())
                        .status(ToolResultStatus.SUCCESS)
                        .build();

            } catch (Exception ex) {
                LOG.log(Level.SEVERE, "Error processing tool request: " + ex.getMessage(), ex);
                return ToolResultBlock.builder()
                        .toolUseId(tool.toolUseId())
                        .content(ToolResultContentBlock.builder().text(ex.getMessage()).build())
                        .status(ToolResultStatus.ERROR)
                        .build();
            }
        }
        return null;
    }

    private static double currentTemperature(String cityName) {
        if ("Paris".equalsIgnoreCase(cityName)) {
            return 25.0;
        }
        return 40.0;
    }

    private static Document createToolResultDocument(double temperature) {
        return Document.mapBuilder()
                .putString("type", "object")
                .putDocument("content", Document.listBuilder()
                        .addNumber(temperature)
                        .build())
                .build();
    }

    public static void main(String[] args) {
        converse();
    }

}
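The second example appends one `user` message per `toolUse` block. When the model requests several tools in a single turn, that produces consecutive `user` messages. A safer pattern, assuming the Converse API expects `user` and `assistant` roles to alternate (worth checking against the current API reference), is to answer every `toolUse` from one assistant turn in a single `user` message, matched by `toolUseId`. The sketch below uses hypothetical `Block` and `Msg` records in place of the SDK's `ContentBlock` and `Message`, and a placeholder `runTool` instead of a real tool.

```java
import java.util.ArrayList;
import java.util.List;

public class ToolTurnSketch {

    // Hypothetical stand-ins for the SDK's ContentBlock and Message types.
    record Block(String kind, String toolUseId, String payload) {}
    record Msg(String role, List<Block> content) {}

    // Builds one user message holding a toolResult for every toolUse block
    // in the assistant message, each carrying the matching toolUseId.
    static Msg toolResultsTurn(Msg assistant) {
        List<Block> results = new ArrayList<>();
        for (Block b : assistant.content()) {
            if ("toolUse".equals(b.kind())) {
                results.add(new Block("toolResult", b.toolUseId(), runTool(b.payload())));
            }
        }
        return new Msg("user", results);
    }

    // Placeholder tool executor (hypothetical): just echoes the input.
    static String runTool(String input) {
        return "result for " + input;
    }

    // True if no two consecutive messages share the same role.
    static boolean alternates(List<Msg> messages) {
        for (int i = 1; i < messages.size(); i++) {
            if (messages.get(i).role().equals(messages.get(i - 1).role())) {
                return false;
            }
        }
        return true;
    }

    public static void main(String[] args) {
        List<Msg> messages = new ArrayList<>();
        messages.add(new Msg("user", List.of(new Block("text", null, "Temperature in Paris and Lyon?"))));
        Msg assistant = new Msg("assistant", List.of(
                new Block("toolUse", "id-1", "Paris"),
                new Block("toolUse", "id-2", "Lyon")));
        messages.add(assistant);                   // keep the assistant's toolUse turn
        messages.add(toolResultsTurn(assistant));  // then one user turn with both results
        System.out.println(messages.size());                  // 3
        System.out.println(messages.get(2).content().size()); // 2
        System.out.println(alternates(messages));             // true
    }
}
```

With the SDK, the equivalent would be collecting the `ToolResultBlock`s into a `List<ContentBlock>` and passing that list to a single `Message.builder().role(ConversationRole.USER).content(...)` call.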

Console output:

22:59:21: Executing ':com.hb.Main.main()'...

> Task :compileJava
> Task :processResources NO-SOURCE
> Task :classes

> Task :com.hb.Main.main()
SLF4J: Failed to load class "org.slf4j.impl.StaticLoggerBinder".
SLF4J: Defaulting to no-operation (NOP) logger implementation
SLF4J: See http://www.slf4j.org/codes.html#StaticLoggerBinder for further details.
Sep 08, 2024 10:59:22 PM com.hb.Main sendConverse
INFO: converseRequest=ConverseRequest(ModelId=anthropic.claude-3-5-sonnet-20240620-v1:0, Messages=[Message(Role=user, Content=[ContentBlock(Text=What is the temperature in Paris?)])], InferenceConfig=InferenceConfiguration(MaxTokens=512, Temperature=0.5, TopP=0.9), ToolConfig=ToolConfiguration(Tools=[Tool(ToolSpec=ToolSpecification(Name=currentTemperature, Description=Returns the current temperature of a city, InputSchema=ToolInputSchema(Json={"type": "object","properties": {"city": {"type": "string","description": "City name"}},"required": ["city"]})))]))
Sep 08, 2024 10:59:25 PM com.hb.Main handleToolUse
INFO: tool=currentTemperature id=tooluse_ZsZKoRYnSmaMUhWdYF-Rlg
Sep 08, 2024 10:59:25 PM com.hb.Main sendConverse
INFO: converseRequest=ConverseRequest(ModelId=anthropic.claude-3-5-sonnet-20240620-v1:0, Messages=[Message(Role=user, Content=[ContentBlock(Text=What is the temperature in Paris?)]), Message(Role=assistant, Content=[ContentBlock(Text=To answer your question about the current temperature in Paris, I can use the currentTemperature function to retrieve that information for you. Let me do that now.), ContentBlock(ToolUse=ToolUseBlock(ToolUseId=tooluse_ZsZKoRYnSmaMUhWdYF-Rlg, Name=currentTemperature, Input={"city": "Paris"}))]), Message(Role=user, Content=[ContentBlock(ToolResult=ToolResultBlock(ToolUseId=tooluse_ZsZKoRYnSmaMUhWdYF-Rlg, Content=[ToolResultContentBlock(Json={"type": "object","content": [25.0]})], Status=success))])], InferenceConfig=InferenceConfiguration(MaxTokens=512, Temperature=0.5, TopP=0.9), ToolConfig=ToolConfiguration(Tools=[Tool(ToolSpec=ToolSpecification(Name=currentTemperature, Description=Returns the current temperature of a city, InputSchema=ToolInputSchema(Json={"type": "object","properties": {"city": {"type": "string","description": "City name"}},"required": ["city"]})))]))
Sep 08, 2024 10:59:27 PM com.hb.Main converse
INFO: messages=[Message(Role=user, Content=[ContentBlock(Text=What is the temperature in Paris?)]), Message(Role=assistant, Content=[ContentBlock(Text=To answer your question about the current temperature in Paris, I can use the currentTemperature function to retrieve that information for you. Let me do that now.), ContentBlock(ToolUse=ToolUseBlock(ToolUseId=tooluse_ZsZKoRYnSmaMUhWdYF-Rlg, Name=currentTemperature, Input={"city": "Paris"}))]), Message(Role=user, Content=[ContentBlock(ToolResult=ToolResultBlock(ToolUseId=tooluse_ZsZKoRYnSmaMUhWdYF-Rlg, Content=[ToolResultContentBlock(Json={"type": "object","content": [25.0]})], Status=success))])]
Sep 08, 2024 10:59:27 PM com.hb.Main converse
INFO: response=ConverseResponse(Output=ConverseOutput(Message=Message(Role=assistant, Content=[ContentBlock(Text=Based on the information I've retrieved, the current temperature in Paris is 25.0 degrees Celsius (which is equivalent to 77 degrees Fahrenheit).

Is there anything else you'd like to know about the weather in Paris or any other city?)])), StopReason=end_turn, Usage=TokenUsage(InputTokens=482, OutputTokens=59, TotalTokens=541), Metrics=ConverseMetrics(LatencyMs=1789))
Sep 08, 2024 10:59:27 PM com.hb.Main converse
INFO: textResponse=Based on the information I've retrieved, the current temperature in Paris is 25.0 degrees Celsius (which is equivalent to 77 degrees Fahrenheit).

Is there anything else you'd like to know about the weather in Paris or any other city?

BUILD SUCCESSFUL in 6s
2 actionable tasks: 2 executed
22:59:27: Execution finished ':com.hb.Main.main()'.

@KaisNeffati
Author

Many thanks @herbert-beckman! My code is working perfectly now. I hope this solution helps others implement it until official documentation is available.
