From 5f4ae36574434b4ae11041ae2a7fafe60a57f65f Mon Sep 17 00:00:00 2001
From: pm3310
Date: Thu, 7 Mar 2024 07:34:43 +0000
Subject: [PATCH 1/2] LLM batch inference command
---
sagify/api/llm.py | 74 ++++++++++++++++
sagify/commands/llm.py | 155 ++++++++++++++++++++++++++++++++++
sagify/sagemaker/sagemaker.py | 103 +++++++++++++++++++++-
3 files changed, 331 insertions(+), 1 deletion(-)
create mode 100644 sagify/api/llm.py
diff --git a/sagify/api/llm.py b/sagify/api/llm.py
new file mode 100644
index 0000000..a029ad2
--- /dev/null
+++ b/sagify/api/llm.py
@@ -0,0 +1,74 @@
+from sagify.sagemaker import sagemaker
+
+
+def batch_inference(
+ model,
+ s3_input_location,
+ s3_output_location,
+ aws_profile,
+ aws_region,
+ num_instances,
+ ec2_type,
+ aws_role=None,
+ external_id=None,
+ tags=None,
+ wait=True,
+ job_name=None,
+ model_version='1.*',
+ max_concurrent_transforms=None
+):
+ """
+ Executes a batch inference job given a foundation model on SageMaker
+
+ :param model: [str], model name
+ :param s3_input_location: [str], S3 input data location
+ :param s3_output_location: [str], S3 location to save predictions
+ :param aws_profile: [str], AWS profile name
+ :param aws_region: [str], AWS region
+ :param num_instances: [int], number of ec2 instances
+ :param ec2_type: [str], ec2 instance type. Refer to:
+ https://aws.amazon.com/sagemaker/pricing/instance-types/
+ :param aws_role: [str, default=None], the AWS role assumed by SageMaker while deploying
+ :param external_id: [str, default=None], Optional external id used when using an IAM role
+ :param tags: [optional[list[dict]], default=None], List of tags for labeling the transform
+ job. For more, see https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html. Example:
+
+ [
+ {
+ 'Key': 'key_name_1',
+ 'Value': key_value_1,
+ },
+ {
+ 'Key': 'key_name_2',
+ 'Value': key_value_2,
+ },
+ ...
+ ]
+ :param wait: [bool, default=True], wait or not for the batch transform to finish
+ :param job_name: [str, default=None], name for the SageMaker batch transform job
+ :param model_version: [str, default='1.*'], model version to use
+ :param max_concurrent_transforms: [int, default=None], max number of concurrent transforms
+
+ :return: [str], transform job status if wait=True.
+ Valid values: 'InProgress'|'Completed'|'Failed'|'Stopping'|'Stopped'
+ """
+ sage_maker_client = sagemaker.SageMakerClient(
+ aws_profile=aws_profile,
+ aws_region=aws_region,
+ aws_role=aws_role,
+ external_id=external_id
+ )
+
+ return sage_maker_client.foundation_model_batch_transform(
+ model_id=model,
+ s3_input_location=s3_input_location,
+ s3_output_location=s3_output_location,
+ num_instances=num_instances,
+ ec2_type=ec2_type,
+ max_concurrent_transforms=max_concurrent_transforms,
+ tags=tags,
+ wait=wait,
+ job_name=job_name,
+ model_version=model_version
+ )
diff --git a/sagify/commands/llm.py b/sagify/commands/llm.py
index e30734d..3012c66 100644
--- a/sagify/commands/llm.py
+++ b/sagify/commands/llm.py
@@ -10,6 +10,7 @@
import docker
from sagify.api import cloud as api_cloud
+from sagify.api import llm as api_llm
from sagify.commands import ASCII_LOGO
from sagify.commands.custom_validators.validators import validate_tags
from sagify.log import logger
@@ -162,6 +163,20 @@
('ml.p3.2xlarge', f'{VANTAGE_URL}/p3.2xlarge'),
('ml.p3.8xlarge', f'{VANTAGE_URL}/p3.8xlarge'),
('ml.p3.16xlarge', f'{VANTAGE_URL}/p3.16xlarge'),
+ ('ml.g4dn.2xlarge', f'{VANTAGE_URL}/g4dn.2xlarge'),
+ ('ml.g4dn.4xlarge', f'{VANTAGE_URL}/g4dn.4xlarge'),
+ ('ml.g4dn.8xlarge', f'{VANTAGE_URL}/g4dn.8xlarge'),
+ ('ml.g4dn.16xlarge', f'{VANTAGE_URL}/g4dn.16xlarge'),
+]
+
+_VALID_EMBEDDINGS_BATCH_INFERENCE_INSTANCE_TYPES = [
+ ('ml.p3.2xlarge', f'{VANTAGE_URL}/p3.2xlarge'),
+ ('ml.p3.8xlarge', f'{VANTAGE_URL}/p3.8xlarge'),
+ ('ml.p3.16xlarge', f'{VANTAGE_URL}/p3.16xlarge'),
+ ('ml.g4dn.2xlarge', f'{VANTAGE_URL}/g4dn.2xlarge'),
+ ('ml.g4dn.4xlarge', f'{VANTAGE_URL}/g4dn.4xlarge'),
+ ('ml.g4dn.8xlarge', f'{VANTAGE_URL}/g4dn.8xlarge'),
+ ('ml.g4dn.16xlarge', f'{VANTAGE_URL}/g4dn.16xlarge'),
]
@@ -684,8 +699,148 @@ def gateway(image, start_local, platform):
logger.info(f"Access service docs: http://localhost:{PORT}/docs")
+@click.command(name="batch-inference")
+@click.option(
+ u"-m", u"--model",
+ required=True,
+ help="Name of the model to use for batch inference",
+ type=click.Path()
+)
+@click.option(
+ u"-i", u"--s3-input-location",
+ required=True,
+ help="s3 input data location",
+ type=click.Path()
+)
+@click.option(
+ u"-o", u"--s3-output-location",
+ required=True,
+ help="s3 location to save predictions",
+ type=click.Path()
+)
+@click.option(
+ u"--aws-profile",
+ required=True,
+ help="The AWS profile to use for the batch inference job"
+)
+@click.option(
+ u"--aws-region",
+ required=True,
+ help="The AWS region to use for the batch inference job"
+)
+@click.option(u"-n", u"--num-instances", required=True, type=int, help="Number of ec2 instances")
+@click.option(u"-e", u"--ec2-type", required=True, help="ec2 instance type")
+@click.option(
+ u"--max-concurrent-transforms",
+ required=False,
+ default=None,
+ type=int,
+ help=" The maximum number of HTTP requests to be made to each individual inference container at one time"
+)
+@click.option(
+ u"-a", u"--aws-tags",
+ callback=validate_tags,
+ required=False,
+ default=None,
+ help='Tags for labeling an inference job of the form "tag1=value1;tag2=value2". For more, see '
+ 'https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.'
+)
+@click.option(
+ u"-r",
+ u"--iam-role-arn",
+ required=False,
+ help="The AWS role to use for this command"
+)
+@click.option(
+ u"-x",
+ u"--external-id",
+ required=False,
+ help="Optional external id used when using an IAM role"
+)
+@click.option(
+ u"-w",
+ u"--wait",
+ default=False,
+ is_flag=True,
+ help="Wait until Batch Inference is finished. "
+ "Default: don't wait!"
+)
+@click.option(
+ u"--job-name",
+ required=False,
+ default=None,
+ help="Optional name for the SageMaker batch inference job."
+)
+def batch_inference(
+ model,
+ s3_input_location,
+ s3_output_location,
+ aws_profile,
+ aws_region,
+ num_instances,
+ ec2_type,
+ max_concurrent_transforms,
+ aws_tags,
+ iam_role_arn,
+ external_id,
+ wait,
+ job_name
+):
+ """
+ Command to execute a batch inference job
+ """
+ logger.info(ASCII_LOGO)
+
+ if model not in _MAPPING_EMBEDDINGS_MODEL_ID_TO_MODEL_NAME['sagemaker']:
+ raise ValueError(
+ "Invalid embeddings model id. Available model ids: {}".format(
+ list(_MAPPING_EMBEDDINGS_MODEL_ID_TO_MODEL_NAME['sagemaker'].keys())
+ )
+ )
+
+ valid_instance_types = [item[0] for item in _VALID_EMBEDDINGS_BATCH_INFERENCE_INSTANCE_TYPES]
+
+ if ec2_type not in valid_instance_types:
+ raise ValueError(
+ "Invalid instance type for embeddings model. Available instance types: {}".format(
+ valid_instance_types
+ )
+ )
+
+ logger.info("Starting batch inference job...\n")
+
+ try:
+ status = api_llm.batch_inference(
+ model=_MAPPING_EMBEDDINGS_MODEL_ID_TO_MODEL_NAME['sagemaker'][model][0],
+ s3_input_location=s3_input_location,
+ s3_output_location=s3_output_location,
+ aws_profile=aws_profile,
+ aws_region=aws_region,
+ num_instances=num_instances,
+ ec2_type=ec2_type,
+ max_concurrent_transforms=max_concurrent_transforms,
+ aws_role=iam_role_arn,
+ external_id=external_id,
+ tags=aws_tags,
+ wait=wait,
+ job_name=job_name
+ )
+
+ if wait:
+ logger.info("Batch inference on SageMaker finished with status: {}".format(status))
+ if status == "Failed":
+ sys.exit(1)
+ else:
+ logger.info("Started batch inference on SageMaker successfully")
+
+ except ValueError as e:
+ logger.info("{}".format(e))
+ sys.exit(-1)
+
+
llm.add_command(platforms)
llm.add_command(models)
llm.add_command(start)
llm.add_command(stop)
llm.add_command(gateway)
+llm.add_command(batch_inference)
diff --git a/sagify/sagemaker/sagemaker.py b/sagify/sagemaker/sagemaker.py
index 8e470c9..8ec0d04 100644
--- a/sagify/sagemaker/sagemaker.py
+++ b/sagify/sagemaker/sagemaker.py
@@ -8,7 +8,7 @@
import sagemaker.huggingface
import sagemaker.xgboost
import sagemaker.sklearn.model
-from sagemaker import payloads
+from sagemaker import image_uris, payloads, model_uris
from sagemaker.jumpstart.model import JumpStartModel
from six.moves.urllib.parse import urlparse
@@ -621,6 +621,107 @@ def deploy_foundation_model(
return model_predictor.endpoint_name, self._generate_foundation_model_query_command(model_id, model_version, model_predictor.endpoint_name)
+ def foundation_model_batch_transform(
+ self,
+ model_id,
+ model_version,
+ s3_input_location,
+ s3_output_location,
+ num_instances,
+ ec2_type,
+ max_concurrent_transforms=None,
+ tags=None,
+ wait=True,
+ job_name=None
+ ):
+ """
+ Execute a batch transform job for a foundation model on SageMaker
+
+ :param model_id: [str], foundation model id
+ :param model_version: [str], foundation model version
+ :param s3_input_location: [str], S3 input data location
+ :param s3_output_location: [str], S3 output data location
+ :param num_instances: [int], number of ec2 instances
+ :param ec2_type: [str], ec2 instance type
+ :param max_concurrent_transforms: [int, default=None], the maximum number of HTTP requests to be made to
+ each individual inference container at one time
+ :param tags: [optional[list[dict]], default=None], List of tags for labeling the transform
+ job. For more, see https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html. Example:
+
+ [
+ {
+ 'Key': 'key_name_1',
+ 'Value': key_value_1,
+ },
+ {
+ 'Key': 'key_name_2',
+ 'Value': key_value_2,
+ },
+ ...
+ ]
+ :param wait: [bool, default=True], wait or not for the batch transform to finish
+ :param job_name: [str, default=None], name for the SageMaker batch transform job
+
+ :return: [str], transform job status if wait=True.
+ Valid values: 'InProgress'|'Completed'|'Failed'|'Stopping'|'Stopped'
+ """
+ # Retrieve the inference docker container uri for the given foundation model.
+ deploy_image_uri = image_uris.retrieve(
+ region=self.aws_region,
+ framework=None, # automatically inferred from model_id
+ image_scope="inference",
+ model_id=model_id,
+ model_version=model_version,
+ instance_type=ec2_type,
+ )
+
+ # Retrieve the model uri.
+ model_uri = model_uris.retrieve(
+ model_id=model_id,
+ model_version=model_version,
+ model_scope="inference",
+ region=self.aws_region
+ )
+
+ model = JumpStartModel(
+ model_id=model_id,
+ model_version=model_version,
+ region=self.aws_region,
+ image_uri=deploy_image_uri,
+ model_data=model_uri,
+ sagemaker_session=self.sagemaker_session,
+ tolerate_deprecated_model=True,
+ tolerate_vulnerable_model=True
+ )
+
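+ # Create a batch transformer from the JumpStart model; results are written to s3_output_location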
+ transformer = model.transformer(
+ instance_count=num_instances,
+ instance_type=ec2_type,
+ output_path=s3_output_location,
+ assemble_with="Line",
+ accept="text/csv",
+ max_concurrent_transforms=max_concurrent_transforms,
+ tags=tags,
+ )
+
+ transformer.transform(
+ s3_input_location,
+ content_type="application/jsonlines",
+ split_type="Line",
+ wait=wait
+ )
+
+ if wait:
+ try:
+ transformer.wait()
+ except Exception:
+ # wait() raises an exception if the transform job fails; swallow it so we can still look up and return the final status
+ pass
+ finally:
+ job_name = transformer.latest_transform_job.job_name
+ job_description = self.sagemaker_client.describe_transform_job(TransformJobName=job_name)
+
+ return job_description['TransformJobStatus']
+
def _generate_foundation_model_query_command(self, model_id, model_version, endpoint_name):
"""
Generate foundation model query command
From 6dab3a7eac8330042916a3890d235860c2f0b6c3 Mon Sep 17 00:00:00 2001
From: pm3310
Date: Thu, 7 Mar 2024 07:35:06 +0000
Subject: [PATCH 2/2] Batch inference docs
---
docs/index.md | 187 +++++++++++++++++++++++++++++++---
sagify/api/llm.py | 10 +-
sagify/commands/llm.py | 2 +-
sagify/sagemaker/sagemaker.py | 19 +++-
4 files changed, 199 insertions(+), 19 deletions(-)
diff --git a/docs/index.md b/docs/index.md
index 417b15f..06d95ba 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -99,9 +99,11 @@ You can change the values for ec2 type (-e), aws region and aws profile with you
Once the Stable Diffusion model is deployed, you can use the generated code snippet to query it. Enjoy!
-### Backend Platforms
+### Restful Inference
-#### OpenAI
+#### Backend Platforms
+
+##### OpenAI
The following models are offered for chat completions:
@@ -129,7 +131,7 @@ And for embeddings:
All these lists of supported models on Openai can be retrieved by running the command `sagify llm models --all --provider openai`. If you want to focus only on chat completions models, then run `sagify llm models --chat-completions --provider openai`. For image creations and embeddings, `sagify llm models --image-creations --provider openai` and `sagify llm models --embeddings --provider openai`, respectively.
-#### Anthropic
+##### Anthropic
The following models are offered for chat completions:
@@ -140,7 +142,7 @@ The following models are offered for chat completions:
|claude-instant-1.2|https://docs.anthropic.com/claude/reference/models|
-#### Open-Source
+##### Open-Source
The following open-source models are offered for chat completions:
@@ -179,7 +181,7 @@ And for embeddings:
All these lists of supported open-source models are supported on AWS Sagemaker and can be retrieved by running the command `sagify llm models --all --provider sagemaker`. If you want to focus only on chat completions models, then run `sagify llm models --chat-completions --provider sagemaker`. For image creations and embeddings, `sagify llm models --image-creations --provider sagemaker` and `sagify llm models --embeddings --provider sagemaker`, respectively.
-### Set up OpenAI
+#### Set up OpenAI
You need to define the following env variables before you start the LLM Gateway server:
@@ -188,14 +190,14 @@ You need to define the following env variables before you start the LLM Gateway
- `OPENAI_EMBEDDINGS_MODEL`: It should have one of values [here](https://platform.openai.com/docs/models/embeddings).
- `OPENAI_IMAGE_CREATION_MODEL`: It should have one of values [here](https://platform.openai.com/docs/models/dall-e).
-### Set up Anthropic
+#### Set up Anthropic
You need to define the following env variables before you start the LLM Gateway server:
- `ANTHROPIC_API_KEY`: Your OpenAI API key. Example: `export ANTHROPIC_API_KEY=...`.
- `ANTHROPIC_CHAT_COMPLETIONS_MODEL`: It should have one of values [here](https://docs.anthropic.com/claude/reference/models). Example `export ANTHROPIC_CHAT_COMPLETIONS_MODEL=claude-2.1`
-### Set up open-source LLMs
+#### Set up open-source LLMs
First step is to deploy the LLM model(s). You can choose to deploy all backend services (chat completions, image creations, embeddings) or some of them.
@@ -227,7 +229,7 @@ It takes 15 to 30 minutes to deploy all the backend services as Sagemaker endpoi
The deployed model names, which are the Sagemaker endpoint names, are printed out and stored in the hidden file `.sagify_llm_infra.json`. You can also access them from the AWS Sagemaker web console.
-### Deploy FastAPI LLM Gateway - Docker
+#### Deploy FastAPI LLM Gateway - Docker
Once you have set up your backend platform, you can deploy the FastAPI LLM Gateway locally.
@@ -273,7 +275,7 @@ sagify llm gateway --image sagify-llm-gateway:v0.1.0 --start-local
If you want to support both platforms (OpenAI and AWS Sagemaker), then pass all the env variables for both platforms.
-### Deploy FastAPI LLM Gateway - AWS Fargate
+#### Deploy FastAPI LLM Gateway - AWS Fargate
In case you want to deploy the LLM Gateway to AWS Fargate, then you can follow these general steps:
@@ -339,11 +341,11 @@ Resources:
-
```
-### LLM Gateway API
+#### LLM Gateway API
Once the LLM Gateway is deployed, you can access it on `HOST_NAME/docs`.
-#### Completions
+##### Completions
Code samples
@@ -488,7 +490,7 @@ print(response.text)
}
```
-#### Embeddings
+##### Embeddings
Code samples
@@ -614,7 +616,7 @@ print(response.text)
}
```
-#### Image Generations
+##### Image Generations
Code samples
@@ -731,7 +733,7 @@ print(response.text)
The above example returns a url to the image. If you want to return a base64 value of the image, then set `response_format` to `base64_json` in the request body params.
-### Upcoming Proprietary & Open-Source LLMs and Cloud Platforms
+#### Upcoming Proprietary & Open-Source LLMs and Cloud Platforms
- [Amazong Bedrock](https://aws.amazon.com/bedrock/)
- [Cohere](https://cohere.com/)
@@ -739,6 +741,97 @@ The above example returns a url to the image. If you want to return a base64 val
- [Gemma](https://blog.google/technology/developers/gemma-open-models/)
- [GCP VertexAI](https://cloud.google.com/vertex-ai)
+### Batch Inference
+
+In the realm of AI/ML, real-time inference via RESTful APIs is undeniably crucial for many applications. However, another equally important, yet often overlooked, aspect of inference lies in batch processing.
+
+While real-time inference caters to immediate, on-the-fly predictions, batch inference empowers users with the ability to process large volumes of data efficiently and cost-effectively.
+
+#### Embeddings
+
+Generating embeddings offline in batch mode is essential for many real-world applications. These embeddings can then be stored in a vector database to serve recommender, search/ranking and other ML-powered systems.
+
+You have to use Sagemaker as the backend platform and only the following open-source models are supported:
+
+| Model Name | URL |
+|:------------:|:-----:|
+|bge-large-en|https://huggingface.co/BAAI/bge-large-en|
+|bge-base-en|https://huggingface.co/BAAI/bge-base-en|
+|gte-large|https://huggingface.co/thenlper/gte-large|
+|gte-base|https://huggingface.co/thenlper/gte-base|
+|e5-large-v2|https://huggingface.co/intfloat/e5-large-v2|
+|bge-small-en|https://huggingface.co/BAAI/bge-small-en|
+|e5-base-v2|https://huggingface.co/intfloat/e5-base-v2|
+|multilingual-e5-large|https://huggingface.co/intfloat/multilingual-e5-large|
+|e5-large|https://huggingface.co/intfloat/e5-large|
+|gte-small|https://huggingface.co/thenlper/gte-small|
+|e5-base|https://huggingface.co/intfloat/e5-base|
+|e5-small-v2|https://huggingface.co/intfloat/e5-small-v2|
+|multilingual-e5-base|https://huggingface.co/intfloat/multilingual-e5-base|
+|all-MiniLM-L6-v2|https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2|
+
+Also, the following ec2 instance types support batch inference:
+
+| Instance Type | Details |
+|:------------:|:-----:|
+|ml.p3.2xlarge|https://instances.vantage.sh/aws/ec2/p3.2xlarge|
+|ml.p3.8xlarge|https://instances.vantage.sh/aws/ec2/p3.8xlarge|
+|ml.p3.16xlarge|https://instances.vantage.sh/aws/ec2/p3.16xlarge|
+|ml.g4dn.2xlarge|https://instances.vantage.sh/aws/ec2/g4dn.2xlarge|
+|ml.g4dn.4xlarge|https://instances.vantage.sh/aws/ec2/g4dn.4xlarge|
+|ml.g4dn.8xlarge|https://instances.vantage.sh/aws/ec2/g4dn.8xlarge|
+|ml.g4dn.16xlarge|https://instances.vantage.sh/aws/ec2/g4dn.16xlarge|
+
+##### How does it work?
+
+It's quite simple. To begin, prepare the input JSONL file(s). Consider the following example:
+
+```json
+{"id":1,"text_inputs":"what is the recipe of mayonnaise?"}
+{"id":2,"text_inputs":"what is the recipe of fish and chips?"}
+```
+
+Each line contains a unique identifier (id) and the corresponding text input (text_inputs). This identifier is crucial for linking inputs to their respective outputs, as illustrated in the output format below:
+
+```json
+{'id': 1, 'embedding': [-0.029919596, -0.0011845357, ..., 0.08851079, 0.021398442]}
+{'id': 2, 'embedding': [-0.041918136, 0.007127975, ..., 0.060178414, 0.031050885]}
+```
+
+Keeping the id field consistent between input and output files lets you join each prediction back to its source record, as shown in the sketch below.
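+
+For example, here is a minimal sketch (not part of Sagify; file names are hypothetical, and it assumes each output line is valid JSON) of joining predictions back to their inputs by id:
+
+```python
+import json
+
+# Local copies of one input file and the corresponding output file.
+# SageMaker batch transform typically names each output file <input-file>.out.
+with open("input.jsonl") as f:
+    records = [json.loads(line) for line in f]
+inputs = {r["id"]: r["text_inputs"] for r in records}
+
+with open("input.jsonl.out") as f:
+    for line in f:
+        prediction = json.loads(line)
+        print(inputs[prediction["id"]], prediction["embedding"][:3], "...")
+```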
+
+Once the input JSONL file(s) are saved in an S3 bucket, you can trigger the batch inference programmatically from your Python codebase or via the Sagify CLI.
+
+##### CLI
+
+The following command does all the magic! Here's an example:
+
+```sh
+sagify llm batch-inference --model gte-small --s3-input-location s3://sagify-llm-playground/batch-input-data-example/embeddings/ --s3-output-location s3://sagify-llm-playground/batch-output-data-example/embeddings/1/ --aws-profile sagemaker-dev --aws-region us-east-1 --num-instances 1 --ec2-type ml.p3.2xlarge --wait
+```
+
+The `--s3-input-location` should be the path where the JSONL file(s) are saved.
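+
+If your input file(s) are still local, you can upload them to that location first, for instance with the AWS CLI (the bucket and paths below are placeholders):
+
+```sh
+aws s3 cp input.jsonl s3://sagify-llm-playground/batch-input-data-example/embeddings/ --profile sagemaker-dev
+
+# Or upload a whole directory of JSONL files
+aws s3 cp ./embeddings-input/ s3://sagify-llm-playground/batch-input-data-example/embeddings/ --recursive --profile sagemaker-dev
+```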
+
+##### SDK
+
+Magic can happen with the Sagify SDK, too. Here's a code snippet:
+
+```python
+from sagify.api.llm import batch_inference
+
+batch_inference(
+ model='gte-small',
+ s3_input_location='s3://sagify-llm-playground/batch-input-data-example/embeddings/',
+ s3_output_location='s3://sagify-llm-playground/batch-output-data-example/embeddings/1/',
+ aws_profile='sagemaker-dev',
+ aws_region='us-east-1',
+ num_instances=1,
+ ec2_type='ml.p3.2xlarge',
+ aws_access_key_id='YOUR_AWS_ACCESS_KEY_ID',
+ aws_secret_access_key='YOUR_AWS_SECRET_ACCESS_KEY',
+ wait=True
+)
+```
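+
+When `wait=True`, the call returns the final SageMaker transform job status (one of `'InProgress'`, `'Completed'`, `'Failed'`, `'Stopping'`, `'Stopped'`), so you can act on it programmatically, for example:
+
+```python
+status = batch_inference(...)  # hypothetical: same arguments as in the snippet above
+if status != 'Completed':
+    raise RuntimeError("Batch inference finished with status: {}".format(status))
+```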
## Machine Learning
@@ -1955,3 +2048,69 @@ It builds gateway docker image and starts the gateway locally.
`--platform PLATFORM`: Operating system. Platform in the format `os[/arch[/variant]]`.
`--start-local`: Flag to indicate if to start the gateway locally.
+
+
+### LLM Batch Inference
+
+#### Name
+
+Command to execute an LLM batch inference job
+
+#### Synopsis
+```sh
+sagify llm batch-inference --model MODEL --s3-input-location S3_INPUT_LOCATION --s3-output-location S3_OUTPUT_LOCATION --aws-profile AWS_PROFILE --aws-region AWS_REGION --num-instances NUMBER_OF_EC2_INSTANCES --ec2-type EC2_TYPE [--aws-tags TAGS] [--iam-role-arn IAM_ROLE] [--external-id EXTERNAL_ID] [--wait] [--job-name JOB_NAME] [--max-concurrent-transforms MAX_CONCURRENT_TRANSFORMS]
+```
+
+#### Description
+
+This command triggers a batch inference job given an LLM model and a batch input.
+
+- The input S3 path should contain a JSONL file or multiple JSONL files. Example of a file:
+```json
+{"id":1,"text_inputs":"what is the recipe of mayonnaise?"}
+{"id":2,"text_inputs":"what is the recipe of fish and chips?"}
+```
+
+Each line contains a unique identifier (id) and the corresponding text input (text_inputs). This identifier is crucial for linking inputs to their respective outputs, as illustrated in the output format below:
+
+```json
+{'id': 1, 'embedding': [-0.029919596, -0.0011845357, ..., 0.08851079, 0.021398442]}
+{'id': 2, 'embedding': [-0.041918136, 0.007127975, ..., 0.060178414, 0.031050885]}
+```
+
+Keeping the id field consistent between input and output files lets you join each prediction back to its source record.
+
+#### Required Flags
+
+`--model MODEL` or `-m MODEL`: LLM model name
+
+`--s3-input-location S3_INPUT_LOCATION` or `-i S3_INPUT_LOCATION`: s3 input data location
+
+`--s3-output-location S3_OUTPUT_LOCATION` or `-o S3_OUTPUT_LOCATION`: s3 location to save predictions
+
+`--num-instances NUMBER_OF_EC2_INSTANCES` or `-n NUMBER_OF_EC2_INSTANCES`: Number of ec2 instances
+
+`--ec2-type EC2_TYPE` or `-e EC2_TYPE`: ec2 type. Refer to https://aws.amazon.com/sagemaker/pricing/instance-types/
+
+`--aws-profile AWS_PROFILE`: The AWS profile to use for the batch inference job
+
+`--aws-region AWS_REGION`: The AWS region to use for the batch inference job
+
+#### Optional Flags
+
+`--aws-tags TAGS` or `-a TAGS`: Tags for labeling an inference job of the form `tag1=value1;tag2=value2`. For more, see https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
+
+`--iam-role-arn IAM_ROLE` or `-r IAM_ROLE`: AWS IAM role to use for the inference job with *SageMaker*
+
+`--external-id EXTERNAL_ID` or `-x EXTERNAL_ID`: Optional external id used when using an IAM role
+
+`--wait`: Optional flag to wait until Batch Inference is finished. (default: don't wait)
+
+`--job-name JOB_NAME`: Optional name for the SageMaker batch inference job
+
+`--max-concurrent-transforms MAX_CONCURRENT_TRANSFORMS`: Optional maximum number of HTTP requests to be made to each individual inference container at one time. Default value: 1
+
+#### Example
+```sh
+sagify llm batch-inference --model gte-small --s3-input-location s3://sagify-llm-playground/batch-input-data-example/embeddings/ --s3-output-location s3://sagify-llm-playground/batch-output-data-example/embeddings/1/ --aws-profile sagemaker-dev --aws-region us-east-1 --num-instances 1 --ec2-type ml.p3.2xlarge --wait
+```
diff --git a/sagify/api/llm.py b/sagify/api/llm.py
index a029ad2..1bc185a 100644
--- a/sagify/api/llm.py
+++ b/sagify/api/llm.py
@@ -15,7 +15,9 @@ def batch_inference(
wait=True,
job_name=None,
model_version='1.*',
- max_concurrent_transforms=None
+ max_concurrent_transforms=None,
+ aws_access_key_id=None,
+ aws_secret_access_key=None,
):
"""
Executes a batch inference job given a foundation model on SageMaker
@@ -49,6 +51,8 @@ def batch_inference(
:param job_name: [str, default=None], name for the SageMaker batch transform job
:param model_version: [str, default='1.*'], model version to use
:param max_concurrent_transforms: [int, default=None], max number of concurrent transforms
+ :param aws_access_key_id: [str, default=None], AWS access key id
+ :param aws_secret_access_key: [str, default=None], AWS secret access key
:return: [str], transform job status if wait=True.
Valid values: 'InProgress'|'Completed'|'Failed'|'Stopping'|'Stopped'
@@ -57,7 +61,9 @@ def batch_inference(
aws_profile=aws_profile,
aws_region=aws_region,
aws_role=aws_role,
- external_id=external_id
+ external_id=external_id,
+ aws_access_key_id=aws_access_key_id,
+ aws_secret_access_key=aws_secret_access_key
)
return sage_maker_client.foundation_model_batch_transform(
diff --git a/sagify/commands/llm.py b/sagify/commands/llm.py
index 3012c66..2846d5a 100644
--- a/sagify/commands/llm.py
+++ b/sagify/commands/llm.py
@@ -735,7 +735,7 @@ def gateway(image, start_local, platform):
required=False,
default=None,
type=int,
- help=" The maximum number of HTTP requests to be made to each individual inference container at one time"
+ help="The maximum number of HTTP requests to be made to each individual inference container at one time"
)
@click.option(
u"-a", u"--aws-tags",
diff --git a/sagify/sagemaker/sagemaker.py b/sagify/sagemaker/sagemaker.py
index 8ec0d04..fa029ff 100644
--- a/sagify/sagemaker/sagemaker.py
+++ b/sagify/sagemaker/sagemaker.py
@@ -23,9 +23,24 @@
class SageMakerClient(object):
- def __init__(self, aws_profile, aws_region, aws_role=None, external_id=None):
+ def __init__(
+ self,
+ aws_profile,
+ aws_region,
+ aws_role=None,
+ external_id=None,
+ aws_access_key_id=None,
+ aws_secret_access_key=None
+ ):
- if aws_role:
+ if aws_access_key_id and aws_secret_access_key:
+ logger.info("AWS access key and secret access key were provided. Using these credentials...")
+ self.boto_session = boto3.Session(
+ aws_access_key_id=aws_access_key_id,
+ aws_secret_access_key=aws_secret_access_key,
+ region_name=aws_region
+ )
+ elif aws_role:
logger.info("An IAM role and corresponding external id were provided. Attempting to assume that role...")
sts_client = boto3.client('sts')