diff --git a/tests/system/providers/amazon/aws/example_bedrock.py b/tests/system/providers/amazon/aws/example_bedrock.py index f2c6988a0500f0..3c015496f992e4 100644 --- a/tests/system/providers/amazon/aws/example_bedrock.py +++ b/tests/system/providers/amazon/aws/example_bedrock.py @@ -56,6 +56,12 @@ # the code snippets for docs, and we can manually run the full tests. SKIP_LONG_TASKS = environ.get("SKIP_LONG_SYSTEM_TEST_TASKS", default=True) +# No-commitment Provisioned Throughput is currently restricted to external +# customers only and will fail with a ServiceQuotaExceededException if run +# on the AWS System Test stack. +SKIP_PROVISION_THROUGHPUT = environ.get("SKIP_RESTRICTED_SYSTEM_TEST_TASKS", default=True) + + LLAMA_SHORT_MODEL_ID = "meta.llama2-13b-chat-v1" TITAN_MODEL_ID = "amazon.titan-text-express-v1:0:8k" TITAN_SHORT_MODEL_ID = TITAN_MODEL_ID.split(":")[0] @@ -107,9 +113,43 @@ def run_or_skip(): chain(run_or_skip, customize_model, await_custom_model_job, delete_custom_model(), end_workflow) -@task -def delete_provision_throughput(provisioned_model_id: str): - BedrockHook().conn.delete_provisioned_model_throughput(provisionedModelId=provisioned_model_id) +@task_group +def provision_throughput_workflow(): + # [START howto_operator_provision_throughput] + provision_throughput = BedrockCreateProvisionedModelThroughputOperator( + task_id="provision_throughput", + model_units=1, + provisioned_model_name=provisioned_model_name, + model_id=f"{model_arn_prefix}{TITAN_MODEL_ID}", + ) + # [END howto_operator_provision_throughput] + + # [START howto_sensor_provision_throughput] + await_provision_throughput = BedrockProvisionModelThroughputCompletedSensor( + task_id="await_provision_throughput", + model_id=provision_throughput.output, + ) + # [END howto_sensor_provision_throughput] + + @task + def delete_provision_throughput(provisioned_model_id: str): + BedrockHook().conn.delete_provisioned_model_throughput(provisionedModelId=provisioned_model_id) + + @task.branch + def run_or_skip(): + return end_workflow.task_id if SKIP_PROVISION_THROUGHPUT else provision_throughput.task_id + + run_or_skip = run_or_skip() + end_workflow = EmptyOperator(task_id="end_workflow", trigger_rule=TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS) + + chain(run_or_skip, Label("Quota-restricted tasks skipped"), end_workflow) + chain( + run_or_skip, + provision_throughput, + await_provision_throughput, + delete_provision_throughput(provision_throughput.output), + end_workflow, + ) with DAG( @@ -157,23 +197,6 @@ def delete_provision_throughput(provisioned_model_id: str): ) # [END howto_operator_invoke_titan_model] - # [START howto_operator_provision_throughput] - provision_throughput = BedrockCreateProvisionedModelThroughputOperator( - task_id="provision_throughput", - model_units=1, - provisioned_model_name=provisioned_model_name, - model_id=f"{model_arn_prefix}{TITAN_MODEL_ID}", - ) - # [END howto_operator_provision_throughput] - provision_throughput.wait_for_completion = False - - # [START howto_sensor_provision_throughput] - await_provision_throughput = BedrockProvisionModelThroughputCompletedSensor( - task_id="await_provision_throughput", - model_id=provision_throughput.output, - ) - # [END howto_sensor_provision_throughput] - delete_bucket = S3DeleteBucketOperator( task_id="delete_bucket", trigger_rule=TriggerRule.ALL_DONE, @@ -189,10 +212,8 @@ def delete_provision_throughput(provisioned_model_id: str): # TEST BODY [invoke_llama_model, invoke_titan_model], customize_model_workflow(), - provision_throughput, - await_provision_throughput, + provision_throughput_workflow(), # TEST TEARDOWN - delete_provision_throughput(provision_throughput.output), delete_bucket, )