Improve LLM's tool awareness #1657

Closed
drbh opened this issue Mar 20, 2024 · 2 comments · Fixed by #1693
drbh (Collaborator) commented Mar 20, 2024

Feature request

Currently the tools feature does not return the name of the chosen function. This is a consequence of how tools/functions are implemented in TGI (via constrained generation rather than a fine-tuned model).

It has been proposed in #1650 that the internal structure be updated to force the generation to include the function name, and more broadly the function-calling mechanism may be improved.

Opening this issue as a place for others' thoughts, ideas, and use cases on how to make tools as useful as possible.

@puppetm4st3r

puppetm4st3r commented Mar 21, 2024

I have solved the name issue by adding a `const` with the tool name to the original JSON grammar. I'm not familiar with PRs on GitHub, but I will try to open one in the next few days; otherwise I can leave the changes here, as they are small and few. With these changes the output is 100% aligned with the OpenAI spec.

regards
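For context, the trick is to add a constant-valued "name" property to each function's schema, so the constrained generation is forced to emit the function's name alongside its arguments. A minimal sketch, assuming a hypothetical `get_weather` tool:

```json
{
  "type": "object",
  "properties": {
    "name": { "type": "string", "const": "get_weather" },
    "location": { "type": "string" }
  },
  "required": ["name", "location"]
}
```

Because `"name"` is both `const` and `required`, any JSON the model generates under this grammar must carry the chosen function's name, which the server can then extract.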

@puppetm4st3r

puppetm4st3r commented Mar 21, 2024

In server.rs, the extraction of the name and the tool-call parsing look like this:

let (tool_calls, output) = if tool_grammar.is_some() {
    // generated_text should be valid JSON
    let gen_text_value: Value = serde_json::from_str(&generation.generated_text).map_err(|e| {
        (
            StatusCode::UNPROCESSABLE_ENTITY,
            Json(ErrorResponse {
                error: e.to_string(),
                error_type: "Input validation error".to_string(),
            }),
        )
    })?;

    let tool_call_id = generate_random_id();
    let tool_call = ToolCall {
        id: tool_call_id,
        r#type: "function".to_string(),
        function: FunctionDefinitionResponse {
            // Extract the function name from "function" -> "name", not as a constant
            name: gen_text_value
                .get("function") // Access the JSON value of "function"
                .and_then(|f| f.get("name")) // Directly access "name" inside "function"
                .and_then(|name| name.as_str()) // Ensure "name" is a string
                .unwrap_or("default_function_name") // Provide a default name if none is found
                .to_string(),
            // Serialize the JSON object obtained from "function" to an escaped JSON string
            arguments: gen_text_value
                .get("function") // Access the JSON value of "function"
                .map_or_else(
                    || Ok("{}".to_string()), // Use an empty JSON object if "function" does not exist
                    |f| {
                        // Remove the "name" key from properties before serialization
                        let mut f_cloned = f.clone();
                        if let Value::Object(ref mut props) = f_cloned {
                            props.remove("name"); // Remove the "name" key
                        }
                        serde_json::to_string(&f_cloned) // Serialize the modified object
                            .map_err(|e| {
                                (
                                    StatusCode::UNPROCESSABLE_ENTITY,
                                    Json(ErrorResponse {
                                        error: e.to_string(),
                                        error_type: "Input validation error".to_string(),
                                    }),
                                )
                            })
                    },
                )?,
        },
    };
    (Some(vec![tool_call]), None)
} else {
    (None, Some(generation.generated_text))
};
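For illustration (a hypothetical sketch; `get_weather` and its `location` field are assumed names, not from the PR): given a constrained generation like

```json
{
  "function": {
    "name": "get_weather",
    "location": "Paris"
  }
}
```

the code above would build a tool call with `name: "get_weather"` and, after removing the `"name"` key, `arguments: {"location":"Paris"}`.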

and the input grammar generator looks like:

// First, generate `tools_str` and `tool_grammar` without depending on `tool_prompt`.
    let (tools_str, tool_grammar) = if let Some((req_tools, tool_choice)) = req.tools.zip(req.tool_choice) {
        // Determine the tools to use based on `tool_choice`.
        let tools_to_use = match tool_choice {
            ToolType::FunctionName(name) => {
                vec![req_tools
                    .iter()
                    .find(|tool| tool.function.name == *name)
                    .ok_or_else(|| {
                        (
                            StatusCode::UNPROCESSABLE_ENTITY,
                            Json(ErrorResponse {
                                error: "Tool choice not found in tool names".to_string(),
                                error_type: "Tool not found".to_string(),
                            }),
                        )
                    })?
                    .clone()]
            },
            ToolType::OneOf => req_tools.to_owned(),
        };

        // Map each tool to its function and parameters.
        let mut functions: HashMap<String, Value> = tools_to_use
            .iter()
            .map(|tool| {
                let func = tool.function.clone();
                
                // Clone the existing parameters, which are expected to be a JSON object
                let mut params = if let Value::Object(params) = &func.parameters {
                    params.clone()
                } else {
                    Map::new()
                };

                // Insert the function's description at the top level, outside of properties
                params.insert("description".to_string(), Value::String(func.description.clone().unwrap_or_default()));
                
                // Ensure 'properties' exists and is an object
                let properties = params.entry("properties".to_string()).or_insert_with(|| json!({})).as_object_mut().unwrap();

                // Insert the constant for the function name inside 'properties'
                properties.insert("name".to_string(), json!({
                    "type": "string",
                    "const": func.name.clone(),
                    "description": "The name of the function"
                }));

                // Check if 'required' exists, and it is an array. If not, create an empty array.
                let required = params.entry("required".to_string()).or_insert_with(|| json!([])).as_array_mut().unwrap();

                // Add 'name' to the 'required' array if it is not already present
                if !required.iter().any(|r| r == "name") {
                    required.push(json!("name"));
                }

                // Return the function name and its parameters (including the constant) as a JSON object
                (func.name.clone(), Value::Object(params))
            })
            .collect();

        // adds the error notification function for LLM feedback if required
        let mut text_response_properties = Map::new();
        text_response_properties.insert("error".to_string(), serde_json::json!({
            "type": "string",
            "description": "The error or issue to notify"
        }));
        text_response_properties.insert("name".to_string(), serde_json::json!({
            "type": "string",
            "description": "The name of the function",
            "const": "notify_error"
        }));
        let text_response_object = serde_json::json!({
            "description": "Useful to notify when a tool can not be called.",
            "properties": text_response_properties,
            "required": ["error", "name"],
            "type": "object"
        });
        functions.insert("notify_error".to_string(), text_response_object);

        // Collect function references from `tools_to_use`
        let mut function_refs: Vec<FunctionRef> = tools_to_use
        .iter()
        .map(|tool| FunctionRef {
            ref_path: format!("#/$functions/{}", tool.function.name.clone()),
        })
        .collect();

        // Manually add the reference to `text_response` !
        function_refs.push(FunctionRef {
            ref_path: "#/$functions/notify_error".to_string(),
        });

        // Now `function_refs` includes all selected functions plus `text_response`
        let tools = Tools {
            functions_map: FunctionsMap { functions },
            properties: Properties {
                function: function_refs, 
            },
            required: vec!["function".to_string()],
        };

        // Serialize the `tools` object to a string.
        let tools_str = serde_json::to_string(&tools).map_err(|e| {
            (
                StatusCode::UNPROCESSABLE_ENTITY,
                Json(ErrorResponse {
                    error: e.to_string(),
                    error_type: "Input validation error".to_string(),
                }),
            )
        })?;
        (tools_str, Some(GrammarType::Json(serde_json::json!(tools))))
    } else {
        (String::new(), None)
    };
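Under these changes, the serialized grammar sent as `tools_str` would look roughly like the sketch below. This is an assumption-laden illustration: `get_weather` is a hypothetical tool, and the exact top-level keys (`$functions`, `$ref`) depend on the serde attributes of the `Tools`, `FunctionsMap`, `Properties`, and `FunctionRef` structs, which are not shown here.

```json
{
  "$functions": {
    "get_weather": {
      "type": "object",
      "description": "Get the current weather",
      "properties": {
        "location": { "type": "string" },
        "name": {
          "type": "string",
          "const": "get_weather",
          "description": "The name of the function"
        }
      },
      "required": ["location", "name"]
    },
    "notify_error": {
      "type": "object",
      "description": "Useful to notify when a tool can not be called.",
      "properties": {
        "error": { "type": "string", "description": "The error or issue to notify" },
        "name": { "type": "string", "const": "notify_error", "description": "The name of the function" }
      },
      "required": ["error", "name"]
    }
  },
  "properties": {
    "function": [
      { "$ref": "#/$functions/get_weather" },
      { "$ref": "#/$functions/notify_error" }
    ]
  },
  "required": ["function"]
}
```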
    
    // Proceed only if tool_prompt is not None
    if let Some(tool_prompt) = &req.tool_prompt {
        // Find the last message with role 'user'
        if let Some(last_user_message) = req.messages.iter_mut().rev().find(|msg| msg.role == "user") {
            // Generate the additional content combining tool_prompt and tools_str
            let additional_content = format!("\n\n---------------------------\n{}{}\n---------------------------", tool_prompt, tools_str);
            
            // If the last user message has existing content, append the additional content to it
            if let Some(content) = &mut last_user_message.content {
                content.push_str(&additional_content); // Append to the existing content
            } else {
                // If, for some reason, the content is None, replace it with the additional_content
                last_user_message.content = Some(additional_content);
            }
        }
        // Note: If there is no message with role 'user', this block does nothing
    }
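The message-mutation step above can be sketched as a small std-only helper (hypothetical function name; the real code mutates the request's message struct in place):

```rust
// Hypothetical std-only sketch of the appending step above.
// `content` is the last user message's text; `tool_prompt` and
// `tools_str` are as in the surrounding code.
fn append_tool_context(content: &mut String, tool_prompt: &str, tools_str: &str) {
    let additional = format!(
        "\n\n---------------------------\n{}{}\n---------------------------",
        tool_prompt, tools_str
    );
    content.push_str(&additional);
}
```

Keeping the delimiters in one `format!` call makes it easy to adjust the separator without touching the message-selection logic.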

    // Apply the chat template to flatten the request into a single input.
    let inputs = match infer.apply_chat_template(req.messages) {
        Ok(inputs) => inputs,
        Err(err) => {
            metrics::increment_counter!("tgi_request_failure", "err" => "validation");
            tracing::error!("{err}");
            return Err((
                StatusCode::UNPROCESSABLE_ENTITY,
                Json(ErrorResponse {
                    error: err.to_string(),
                    error_type: err.error_type().to_string(),
                }),
            ));
        }
    };

I added a notify function by default for fallback cases. With it, the LLM can report errors, a lack of context or information to select a tool, or other kinds of issues, instead of selecting the wrong tool.

Also updated the tool prompt to:
You will be presented with a JSON schema representing a set of tools.\nIf the user request lacks sufficient information to make a precise tool selection, do not invent any tool properties; instead, notify with an error message.\n\nJSON Schema:\n

Also added a random generator for the call id:

fn generate_random_id() -> String {
    let now = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .expect("Time went backwards")
        .as_millis();
    
    let seed = now as u64;
    let mut rng = StdRng::seed_from_u64(seed);
    let id: String = (0..6)
        .map(|_| rng.sample(rand::distributions::Alphanumeric))
        .map(char::from)
        .collect();
    
    format!("call_{}", id)
}
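Note that seeding `StdRng` from the current millisecond means two calls within the same millisecond produce the same id. A dependency-free variant (a hypothetical sketch, not the code used here) derives the suffix directly from a nanosecond timestamp, which avoids the `rand` dependency and narrows the collision window, at the cost of ids being sequential rather than random-looking:

```rust
use std::time::{SystemTime, UNIX_EPOCH};

// Hypothetical std-only variant: encode the nanosecond timestamp in hex.
fn generate_call_id() -> String {
    let nanos = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .expect("Time went backwards")
        .as_nanos();
    format!("call_{:x}", nanos)
}
```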

It's my first time writing Rust; I hope it's acceptable.

@drbh drbh self-assigned this Apr 1, 2024
drbh added a commit that referenced this issue Apr 16, 2024
This PR makes tool calling aware of the name of the function selected.

Fixes:
#1657

Thank you @puppetm4st3r for the helpful snippets; large parts of this PR are simply refactors of the shared code 🙏

(Opening as a draft PR because small tweaks are needed before merging.)