From ef2c43c1028eef95a65aa963bda17ea2cd83d977 Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Wed, 27 Mar 2024 19:39:37 -0700 Subject: [PATCH 1/2] Introduce Amazon Bedrock service Includes the basic doc page, hook, operator, unit tests, and system test. --- airflow/providers/amazon/aws/hooks/bedrock.py | 39 ++++++++ .../providers/amazon/aws/operators/bedrock.py | 89 ++++++++++++++++++ airflow/providers/amazon/provider.yaml | 12 +++ .../operators/bedrock.rst | 72 ++++++++++++++ .../aws/Amazon-Bedrock_light-bg@4x.png | Bin 0 -> 12621 bytes .../amazon/aws/hooks/test_bedrock.py | 27 ++++++ .../amazon/aws/operators/test_bedrock.py | 59 ++++++++++++ .../providers/amazon/aws/example_bedrock.py | 76 +++++++++++++++ 8 files changed, 374 insertions(+) create mode 100644 airflow/providers/amazon/aws/hooks/bedrock.py create mode 100644 airflow/providers/amazon/aws/operators/bedrock.py create mode 100644 docs/apache-airflow-providers-amazon/operators/bedrock.rst create mode 100644 docs/integration-logos/aws/Amazon-Bedrock_light-bg@4x.png create mode 100644 tests/providers/amazon/aws/hooks/test_bedrock.py create mode 100644 tests/providers/amazon/aws/operators/test_bedrock.py create mode 100644 tests/system/providers/amazon/aws/example_bedrock.py diff --git a/airflow/providers/amazon/aws/hooks/bedrock.py b/airflow/providers/amazon/aws/hooks/bedrock.py new file mode 100644 index 0000000000000..11bacd9414598 --- /dev/null +++ b/airflow/providers/amazon/aws/hooks/bedrock.py @@ -0,0 +1,39 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook + + +class BedrockRuntimeHook(AwsBaseHook): + """ + Interact with the Amazon Bedrock Runtime. + + Provide thin wrapper around :external+boto3:py:class:`boto3.client("bedrock-runtime") <BedrockRuntime.Client>`. + + Additional arguments (such as ``aws_conn_id``) may be specified and + are passed down to the underlying AwsBaseHook. + + .. seealso:: + - :class:`airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook` + """ + + client_type = "bedrock-runtime" + + def __init__(self, *args, **kwargs) -> None: + kwargs["client_type"] = self.client_type + super().__init__(*args, **kwargs) diff --git a/airflow/providers/amazon/aws/operators/bedrock.py b/airflow/providers/amazon/aws/operators/bedrock.py new file mode 100644 index 0000000000000..3bb5fe4c89072 --- /dev/null +++ b/airflow/providers/amazon/aws/operators/bedrock.py @@ -0,0 +1,89 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import json +from typing import TYPE_CHECKING, Any, Sequence + +from airflow.providers.amazon.aws.hooks.bedrock import BedrockRuntimeHook +from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator +from airflow.utils.helpers import prune_dict + +if TYPE_CHECKING: + from airflow.utils.context import Context + + +DEFAULT_CONN_ID = "aws_default" + + +class BedrockInvokeModelOperator(AwsBaseOperator[BedrockRuntimeHook]): + """ + Invoke the specified Bedrock model to run inference using the input provided. + + Use InvokeModel to run inference for text models, image models, and embedding models. + To see the format and content of the input_data field for different models, refer to + `Inference parameters docs <https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html>`_. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:BedrockInvokeModelOperator` + + :param model_id: The ID of the Bedrock model. (templated) + :param input_data: Input data in the format specified in the content-type request header. (templated) + :param content_type: The MIME type of the input data in the request. (templated) Default: application/json + :param accept_type: The desired MIME type of the inference body in the response. + (templated) Default: application/json + + :param aws_conn_id: The Airflow connection used for AWS credentials. + If this is ``None`` or empty then the default boto3 behaviour is used. 
If + running Airflow in a distributed manner and aws_conn_id is None or + empty, then default boto3 configuration would be used (and must be + maintained on each worker node). + :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. + """ + + aws_hook_class = BedrockRuntimeHook + template_fields: Sequence[str] = ("model_id", "input_data", "content_type", "accept_type") + + def __init__( + self, + model_id: str, + input_data: dict[str, Any], + content_type: str | None = None, + accept_type: str | None = None, + **kwargs, + ): + self.model_id = model_id + self.input_data = input_data + self.content_type = content_type + self.accept_type = accept_type + super().__init__(**kwargs) + + def execute(self, context: Context) -> dict[str, str | int]: + # These are optional values which the API defaults to "application/json" if not provided here. + invoke_kwargs = prune_dict({"contentType": self.content_type, "accept": self.accept_type}) + + response = self.hook.conn.invoke_model( + body=json.dumps(self.input_data), + modelId=self.model_id, + **invoke_kwargs, + ) + + response_body = json.loads(response["body"].read()) + self.log.info("Bedrock %s prompt: %s", self.model_id, self.input_data) + self.log.info("Bedrock model response: %s", response_body) + return response_body diff --git a/airflow/providers/amazon/provider.yaml b/airflow/providers/amazon/provider.yaml index e2b0df930eff1..4c4f7cf5970ab 100644 --- a/airflow/providers/amazon/provider.yaml +++ b/airflow/providers/amazon/provider.yaml @@ -142,6 +142,12 @@ integrations: - /docs/apache-airflow-providers-amazon/operators/athena/athena_boto.rst - /docs/apache-airflow-providers-amazon/operators/athena/athena_sql.rst tags: [aws] + - integration-name: Amazon Bedrock + external-doc-url: https://aws.amazon.com/bedrock/ + logo: /integration-logos/aws/Amazon-Bedrock_light-bg@4x.png + how-to-guide: + - /docs/apache-airflow-providers-amazon/operators/bedrock.rst + tags: [aws] - 
integration-name: Amazon Chime external-doc-url: https://aws.amazon.com/chime/ logo: /integration-logos/aws/Amazon-Chime-light-bg.png @@ -363,6 +369,9 @@ operators: - integration-name: AWS Batch python-modules: - airflow.providers.amazon.aws.operators.batch + - integration-name: Amazon Bedrock + python-modules: + - airflow.providers.amazon.aws.operators.bedrock - integration-name: Amazon CloudFormation python-modules: - airflow.providers.amazon.aws.operators.cloud_formation @@ -514,6 +523,9 @@ hooks: python-modules: - airflow.providers.amazon.aws.hooks.athena - airflow.providers.amazon.aws.hooks.athena_sql + - integration-name: Amazon Bedrock + python-modules: + - airflow.providers.amazon.aws.hooks.bedrock - integration-name: Amazon Chime python-modules: - airflow.providers.amazon.aws.hooks.chime diff --git a/docs/apache-airflow-providers-amazon/operators/bedrock.rst b/docs/apache-airflow-providers-amazon/operators/bedrock.rst new file mode 100644 index 0000000000000..3e84cbc445357 --- /dev/null +++ b/docs/apache-airflow-providers-amazon/operators/bedrock.rst @@ -0,0 +1,72 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. 
+ +============== +Amazon Bedrock +============== + +`Amazon Bedrock <https://aws.amazon.com/bedrock/>`__ is a fully managed service that +offers a choice of high-performing foundation models (FMs) from leading AI companies +like AI21 Labs, Anthropic, Cohere, Meta, Mistral AI, Stability AI, and Amazon via a +single API, along with a broad set of capabilities you need to build generative AI +applications with security, privacy, and responsible AI. + +Prerequisite Tasks +------------------ + +.. include:: ../_partials/prerequisite_tasks.rst + +Generic Parameters +------------------ + +.. include:: ../_partials/generic_parameters.rst + +Operators +--------- + +.. _howto/operator:BedrockInvokeModelOperator: + +Invoke an existing Amazon Bedrock Model +======================================= + +To invoke an existing Amazon Bedrock model, you can use +:class:`~airflow.providers.amazon.aws.operators.bedrock.BedrockInvokeModelOperator`. + +Note that every model family has different input and output formats. +For example, to invoke a Meta Llama model you would use: + +.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_bedrock.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_invoke_llama_model] + :end-before: [END howto_operator_invoke_llama_model] + +To invoke an Amazon Titan model you would use: + +.. 
exampleinclude:: /../../tests/system/providers/amazon/aws/example_bedrock.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_invoke_titan_model] + :end-before: [END howto_operator_invoke_titan_model] + +For details on the different formats, see `Inference parameters for foundation models `__ + + +Reference +--------- + +* `AWS boto3 library documentation for Amazon Bedrock `__ diff --git a/docs/integration-logos/aws/Amazon-Bedrock_light-bg@4x.png b/docs/integration-logos/aws/Amazon-Bedrock_light-bg@4x.png new file mode 100644 index 0000000000000000000000000000000000000000..e6af4b727609de237151a9d8cf3e94c7b059f599 GIT binary patch literal 12621 zcmeHug;!k7)8-6r!Gi?15Zn@6LvTxQcL~8A1{f^3zXS`OV8Pu74Z+>r2KV5;m+$xP z_wAps=jPohzpsed;Qwhc8{vP~28(AS{*U~uCp=u-0$k7?Wp!LYAR+AM z4|uRHCkO=MG?n`xq3#Jj%s_2WpPG7{*&LS;k6Sm|l|4p7pruy*uwH=I{5pxmEc`pJ z-sG(6Bd=`Kg+fU)o(S` zVq)W%soiCtsXJm%yl_d8=E6!~pI}DU^YGNXBebQiy z$H9&Ck*=WLkv_Lx-M$scqM5R2Nn`qQ#mjC^0wS@V?;zySimi+NX>RH59q8Gzx=tyq z%}1JRP$p!hCmOgJM+Drga@>G4Ixm~2RvnpXy;F{)A=0jJi|I1gx+dXY2VMl819h7{ z+iZuE%N&;1=Nq57yV5p;*kF9g655ptT~~JuBaDv%chaB(PS!FWw{osyptzMy!|SRI(-PlsTX0Y z=pd30FW07>teq@h3FXyD&XEp0dzq$D%=C^$U4FA#ZbV(b(k)?dhVq(T#L#weo5iBI4v!6j zG>gdUoJQDLn%AqLXmPB8H6X&rjM2Sfgn~wKshc&?5Us9(=W4p3exGSrOMiii(J32= z4NO@4Jc2#j)s8h=mv81Y{<{FjKpr)$*crobiBfscMAziO(nIgQFwG_5W&{(h!NRNU z-cn4fe&q*67EwSuV4=NSPr=>kCoS&Wrt|IHGLN)2O7#1O?_bIt4#pVe3xf7_lBRaM zZj4G3)biVyS?0)GBUo!3#e|7PYD!W#7j=^|9Rrqo##qjO7Rd^cAW(dvq~$@$Bu%-_ zQO8xy4X?58W)KIng8Uk46$LI`NtfT4h9@4{@)@r73IH zft?l7OKH!B<1VDUg(cWC@kxWfG+2`yA?2*Af)DFdVFt~b4hJs9^8Dy%i@>yaMGpV% zsKDI(hbY5&x!p36P1)w=;fF`=g-xs;@Z+j7);`svetc2h^; z6pFLYwMZBqdqm`+(T?CXMaK(Cm-M1OMY;HkrrbK&bXc?D>^hyxM0vAbWc@z=q#;_= zQAmnZ0UNRpMTJ1!*OTMTnOWDhN&kLWuIa?j20>PrS|Dsj_H6v6c?(|)Hvb9`ND377 zOD2+7P=OB)YqI44kcg(zf8d+pOTl&~jMv7~00dEfH0HM?F5 z^CJ+7C$rk@9$_7nRL=+iylZ-n{FA@$P){gVHmf+IN2ONbUy1O z{d@8kxs`(kqv5mSFgQR7f^k8USOkB8&S-Ug?NuoaLHeFZjzEvt+UA-46 z10D9d?%ZPew64fjH7+ozlK%blLkygsj!!DM4-UI8dFw;?-R}aI__dmcvesM(vO!y0 
zoel)@t8S>oBqk2W{*V_G1n;$fS9041Nd9M%h?4kRBKZT@m_Mla*34FY<8J|Zdpp5X zryNN}aR1m-adJy$_NHlSBj_VBu{vy7f&zq+29wUx>JsGhZl1iwk&gFy#YemkQs&pV zOoo_6ux(t5+_ew814Ab^D%|&OVtW%bJrQp78cr-sV(4+v5sKA&?Z?sd72 zFxr87TddMtSu(GaH#MWU3$GQyHB!GpatsyjJ-noEH>D~Uy7q%{#y`yW25AZ-q$9FY z-pRskJGLR7=dPUV$s+?ubmHJB!?cPID7gOF(OTP9LN54!*k6d^#Dca#U1kH{nUE<1txSg~wsufh0X{J<&on=GuG4vdc$3!nnwI>l5ZS?YfcA5knI3j>DwWR}|k zpE@9nAe#!PUbyJ;lh*J(aCeZQoie-8uIC>Bsb15IgU`|9#(>@iffMy@{h_uqJ~cQC zuwkNLZb9@o|3*OEs-sX)Mxb9!0Cp~;xHKaepEDT{7wgRW07W0*6u3p|-Jusp*aF1w z3+}7}fT9Yp$Fe}A8A!72`he)N5W@f{p2+!o2lyhGpF5(@F6oimA1R6)0Y=Jotn-3d zqxQuieUu$7GRDWU++1-1E;9VE*IfxIzVy=Yc z9|K1Bn>9HuTk{N5eyR6waS(cJ5hx+tQ^ge+2u>HV`0b7x-R^JB7?z}BY|V;3(W#xQ z9vF?J(Vp=^*BDa8H?n80JY#9bVakXyyY}<76QB4|`juAqmztf$R_o(w#KD7KBAOD5 zdZ`4-*~Dz7B?43Eg8YX`eaeKJekj^Ar?DFyEM_K}UAI9FFJO(+kL$-L3J&(zaN>u$ zpP(a~pAnn#n?*?m@IQ4j=pmI7uaI`OhlKh`qjmeJh@Z5M40u)C50?nciM%G>mmi1! z4C#%p&A7N7NB0G=q<)*R?yc6NtDG_2^^o6!1uF^%0P3>wdxTd+1RX+6Y*J86VAX0y zH^C8W?*zrJRm0AL4845|pAKHwz!yCX6E^jDL&E8jtR0qNYOl)-GVt4QdO7~urGEw= zy==Kd~ zV|31eQPPkSjHK<$+`D(KqJ{6Jv^eBP&w`-dzQlvKv8>B29U2rQDM*U1KdN#=%-=pv z68ryH@e)0cpwz46tWBx3->yTH$3p$gnyEiq0xReyz{SV^GKo&nX+j}o8jPcRp#ch$F+gPDt()8@{Z5^++QG!4@XX1pG_M2-BIDTpC?(JQ) zV{mNW>eHC5k`RO~;qni9iIyKr#GM7PxgyU+8W|C4d3Gd78kak~+(4F~)*()e}ssu*9Hu4I|Rgu$@gxMtBWG{H(cW~fTJF0BW9Y6iYE zqm$wbGasCjm>6IENTd0z<3Cl_QUd894$Mo071+6fm-D8n(PPjY2c@^~+=qW|goygf z#$1>j>{bRYDy=j$-c))#+~*&fPu{t}r-OMctO72}h?{Doh{auZmq_k!pm{xNY+>!E z>TmeSi#VVjfx;(i-mP_;G;C@R-#095!Uz{ET8Ub>>GUFd5K2Z zOcO?liGqqTX;OEyrl>VG0_3-*naAI(pX6_Lq@?b-i<2{{K9Ej$NAYCo`AU}`qlh%- z&HT2aYaBJva2Lf)5+}%D%yeH&ZSCf$R?omLV+8Og6M#eixp89^CFAgSFZK6o-QQ|s zuru=Gb}r!AW)@P8lh*lU`E6-$T?KjyjGqO)Y0jvev66X1I>NQh*EGXvt$mR_ z%8Z3a?tLDvDi`|4sd5(6m`m9hdvDWG^|Fv$@1=T1Y78F&ovwQVn0>no2LHOoNu;@M z`iCQs@F&gcXbWUvCQvQGn*O6Lf+M&&#|}rPTl*QH{$P9+ZFKf&k3u?EdftY-1e1bz zS5EqneDEy#^OVwKYYQt8J9p(VSP;TOj?$S_*G3`m#;phSa&26e8Uv|pLhtc_2+^aM zY~ZEKYq=i2#~dy0K_#}Njyck&3Q4q;HPTM%l0@Hn!Nn!!Fs_j|{@C9!lm}*N9*+!H 
zBftB3=#1yF3CisUeGv*9qNc~C1TTs`@EkVkkRsXbOja^!u@rZSonjqnquDS9ZojqP zn5~>Ivh9XSgKMn82ONgN=Pi>4l&Q%Ki8I6w=E+dr%9|9VP6o%4WkGEjEx%IR-z~qm z)*O=OVyH5+mRlZ9^7E#RXer7t0!Ih>q79NY5g%`_^%x?Z3^vL~A?`s!hI> zZd>zm=~!MK&~;TlJn1ES(-WkS=23MH{y?}$W!ltj2-?OH)k<%wG3T}9I>wYs{uyGZ zYi5O3?RCZ8Z0tj+hLGSk^Cu)@fz2Z*O4VqJA&rlJ`31MQM>2Ixq6rQcy9)P`=`qTN z=DjL*He_6?as*FU>q@uPuM@fMn7EUgl`R#?{dkEbcAlSg<0NqAE`v7NMrE|W`Q->M ziV0p%f99A%!n&i2{(-l@W^~E}E5=pu>g?%V^NTpMJ<3djmh`rbxPHHL@}l7d3V1T) zADL8p-yaN-?HLMI1KsnFnv9*|qel~07z+#IJtn>_BQm(5g-rt3n5Yb^e;W8nkU#2i zVLbYf^sk4Nvc&Ar=0QdS^@5Ctr`8^QsRk4D&WXy)+tw;(*(BAvHuU?&*fJSB+>UJY zkO$Ep8x{N(95unxdrD4D+F)slr6-kYTC>^e!Mpc{HOf|!c_L3x(r&Bj~3B3qy!6n zx+gi2BsB|Md=2{pYPPK@N&Zr+=(eDj%uKC96EYJvA;`9)D5qq?7K#T)D9TYF=}p7c z^SvVtw&o_vqq4bHjgygR6YZ#$2Hk^@?~Nv6^G!OpisWohODA5&iR~5jS7|@Z8rzAy z!Gr1BT^Y^Y+*ndv^xYAq3u{M|Zf@-vFG>x)g)eZjPBX}!os=$9zR5wXX21au(3Lu= zKU+!4@NPZTM=!ZtZ{5vw5G8BQGtXiD;w$xDVOVNH} zSq7T2TfnG^14muGNvI5YdU{4&4plaRJZ0y-%$3!AMt(etPCbk#Y=$L^{c3VR5hv2` z2P}w2tq0DG-4=-+z&2m(EpdF@E;fIKWGaSw?W=)#XZ`1PMBgfG&7RmG=kcP2sVpSl z5h6|gTk4hlkb=(AR>#TnnMBwgr;D7>!&1*L#PLFMEUWdbsfkngDX3ADue(c_W?WoX zSu66B^tVK51UbNO!1^_{#tPY7luP<$Rd$3TfD>9%VR~R*OJWg~U{=iFC`uo|3be^e zE%6x92iJ0gvL|~{9Bs+7%y=k=bidIfk9pg#;RlsLPzJoFjE{fuwu-zGEM!7BG(UD3 z6iD^+u}Y_YEp~1T!Waz_q%dj2N|h) ze<+b*#OXb8YqOOM~C$IPwVbm>Cf=ds=_6l?!UwlV#|FZyW>R{Zj3Lk02FVg7Qi?Gk!IodW^%%5p!%avakPoExqHL1Zkx zYP8&#fz<@lpsRND`B&kt5qh%y;G$z-!LFTy5CCeq6L|@XT{tO=n^OS z4g$&$;kXwsBa1VZ8r=4{uN$U0nhh5&PQoWFTbLKl%`I6re6n9=tY%5+y+kNb0jnNX=IAzZ@cR zYM)dKL8Z_Rj-57QGsjp&`<^@4a*Rh~&%!BpW6J#vm0;UUo1x^)S-_@3e;)hWSL4}E zd%o6X$*(9%j(r7~|12m%W!w=NM+TYwQp$aYrGxw`Zm_xBYiv@$Q6y32bM1cg)U}B<$`&-R>#l9 z*dNO&Da2IVcf;9&IhzfQ%D7iwm-ATO3(fqL^_!76W^Y~DSnA%5=>7!y#K$ixiNv){ zoU~?o_UU*eUXOSfUW@8QO(urgpdU-PC;GTMIm)X-aR!aP=@7b5Uh0+{!Ssc4ip=nN zoHTlcw%lOMf8eP911rlX32F8_8aLYdY!Lg)yMd{j{j)&EFzPLnda!eza|Yz)Dm*oR z>dUPhbDdN4G1#(hFxD*|kvQJ0btJL_l_ly%f&(V{&99)0e7SCiqS0p8wxePRNeDH$ z!@H$WMC)OF9#L^-D%{Ulup}$X#=|zTL+y`j>N2_PU8Z;m67Ez{q2C2`s;YQ#v@SF| 
zbbFfmZR}onUmHsr&Az+63)zGJ$<~>fZSGf%3N_ zpVqt+v3S8*K;vw~12*mK-OyN-b8+>)?oz&k>GdqSaTDGu=phtaW*~TI9`kE&=}iWe zqI@|{*OK1C)j!WBd{gqIHmVu44rBf8V|@FN1(?{~o2}|}f<@Ha8u%nleg4|--R)_S z=FO=Q&0ZGBg_%u;qy4%wHqvl!(BvkFC1p0L?tGtvVl80Sba`{muq)T3_EpS!E@I9M z-+5`}-@gS^5$K49SRPMqgMCcO_D{Xpe9pVg1Zo9z1^Ze1+jfzbTepImD<;8I$5u1* z7n0>#LCjOCX$XLGP607|ZH~J@Y0P{7eh}w%-c0BkdwR{?qzEfn_2F>?#qlG~Nfhr&T4337ixv#fw|lvST7tsP7yHBgA)m z1R`0%OybtHytFWXm1G2Qwftjo0vsaNL+gyQjSI}taDf(~{TxTSb!`Taq+J#;HO zpyRc?A)&rP_J^8et3#vPt4&Ne@o}qlq6KXxgFI-0?G^mq7}LP1Zhk0!nCGs*nxzG* z>A|%K+k(|2c@S9XjU(gweG?PgB~F5t{2TASP<1p#bS>!He^9O@l;BsG@!K?+`%=v} zCkK5U$^~UfC%)ERf?Y46|Mh=-7u@!$?sa2j%@cK`{+wmH(9C6MIE z(q;9SwhPIA=9IdkM#n!Mp+z)W_u!^(_T8jjN@Mq*ZMnr|8=^+LHc(EXp@S8r zzx!uAiw4!c#rpNDb~lVT7lTk_=f~yh;T(ZB6~A!gSRWVYE@vaOGwY=jS7)D9#+Rv~ zt2bOLYM`rk@BVewO*d(8%tA#oA5g1T$P>11fu~-satyIu9-=h5^o~2e>T*M#>93gI zpL$)QV?kP*sxZSj@Egll76^7Va7;1|jNBVrI_IC?t`7Pvh+1u_;tUT?v|g)gc`!Ps zX;^Lu-Rt*ky{WG$j_iVCqt!tcwh*Bsk9ql_8LD&EDUU+PoGO_4+o(=qdm`*t#A%cJ zLd4!4wa^PUuAWkp2JgfB#o}iBiGJc``$T@Dmh9FiUGHj#G6RO05kEcA6HXOdtr%62 z?N}fTOejuvIh%O;J@Q6EP1Ez?U}?zRvQFaj`9_c{*G^9;6-`f3LHj3yOmuxDzTJ3R zH-)nmJw!&Ua?u~r=Y%vYlFM>@k2g*yD7o9Hu_RY1b~}aTC;72Oot65AF-4tw_Q^

_!Oi!JTkV`mR3w^4P<}3Cz7PcOdOc%ybsk%zj2Q1`uK7-3sd+&1Ed|0AF4kkgY?7thThYr)_AwWl zj3bl|P5ysb13A7+kl8JYvD?gu;}zpwNHbxZqY*9|9y@0hH@U4p3^og~end4sb-bj}Am3UK}(LOSQ zQ%>mRkuHVC%hNACMV<{!e?8s2c3mA|ANPalPUDFWg@B0(gcRiUPzk{UT0XBlQ|{NVkt&zk;-^cMfd=)yW%>hyN(*12Sal2xVy*eT?J+_B~rCa>9-Fl`H&0WyJ>m(Ebms#H+d|z6oX3w%wutiKY28UnXFPGhO zG2aDBez^t8HTnc{aS0!+CqI(TXTs3MMVofGI6rCQ;Fy?NJA#PVkjd2l$%#8T4$7j^ zH)xo*S+XS*w-J@{eRc7(C<23#CRMWC(R(56DnI||q@C-MzmGIDHGkx+>lZs(C@@va zArzFZ&hD3i0L*q`_AGolC(CSdT(tWmCs@?aMwZYeIqDejX`F)#qv-{czXZXHH!#wD zF&8%f#-pVdh5U*kNT$zGd%F1&57tx$mU>%fb-|LO*AIovGT9_yNL2uQulE12sihcb zB;fInh10l-MnflXT@t@n(qD2_*!aiEyVuo553Y^rp4w+rgrQwaRMcz4^sCQ& z|GTPA>4q=M@JW5+&GF;!5yOMQn}oZ3&8N)ZbjVKwB1%dp;iEYqZ!f?$H;aII6-3$KYx6~6`~}f@XLIZIf`%k zwtaQzJz%0yM4ZP87huO4T``ytOdJ&(G07>!qHU;*g3kGm9zp z`#3(f`R6g%X9Xr5t#46S?4oZEq(0U@*DV~NA5#&^$* zG}*`UoGCNWNDx58wMTzk*sLq^ufF=G-~8RpeBQ~bze@(YW{Jq%Xa-y-!mJXE#-us} zWIW_YA>$OAIJEYL1vO|+U##A0#@?_Z0NsQ6I;#lS%r{v<%_ z)KC`onA0^D(v=gB0uu>fCmaHIt^&94y5%N!HI003} z0*l7t!mD1}Km*$qU?v1NpfKPPogZ%8;h!xz~oJxoh zQOZ1ipbMbVQOAlTSJ~EvJF1FKKH$T0n^$tTEk^1MQDT&x(RmfQB6S8^g7QdKCg1q8 zFSxJIPnDUb&sXHJl0u;UcznL)p~XlDSm}mvewGp1Pz}wWT5O5DyH$K$KtXHd{UQ&E#(t!dPO z_p&cUu#s(E|1yIg0321VP@RQX%>D5|?Y$;%g9JfwPtpj=ukzf)(%HQOL#yJq*R0Q+ zW3taFRZ-7ZD|o3}RznoixNNKfL^~<_o{$YAH}K~Vv$1gZ((B%=?!Kkic^a$6FV#b2Ui(&yn9_m5pD~lLs6o9$GYcixkA1Hfo}t*Z8bt)PHu! 
z`JCuNZK&&hXDSm7`Oms2Y-m%+vXG8)edta?^e-vo3 zRbpx8=>J;9L3za0c@cVwWr9oORB}K%MBp{?Ln%Ag{9Kmjp7h6%li%*@orgp6BLW`@ z2MXT9khj_uRX3bI`cjba&S;51FCX0AhV0RviPAnl#H!lsw%JJn)@3xOy zI7fP^-@G$jXEU;PbhYwQKP-mQQnZV<=Q(lDLD$XYJWGg&;+~P>=8>GDwQU1(otYZ5 zG7^j{{_XT>ztV7Kw~sftLdflOr(V!cKJqzfDa`IGGPW(*omVw>pnzC?2xafoJedzo z&*&>50m?nz6Mq7|^n5WjVgYuj$9wv%8w>{jMAzQi=b%u+(6Vs>)>&}czs_Ocz=8D` zf%iaier#G-jJ3D)Gwqo{w{}_!9lU&*RZ^-$@ZW@Ek*+kDu{~bPNSgMsyTG~MHI1}b zq^W`xmhf|lOohlUn9pdbg3}d%cz<& zP~BtcfhN4xuwzXip*^6N+sb=g=q4pPIjVbH|3^R-H6wijKKR%Nz?o5~_r)szBJbtF zUS8{>W!>iyI7{gE%Mgu2>JJib&Cj$|ZV;2$ zdVzhV-t-G%GneSXNNlSrDmyk%a&iV~yiVomv4D1u(Br-VR#i5CkkVlJ!r%%JkUZ|s z$Gpc2s;dB+ALcke=vdw>|Ll=zBq;3nk3TVRUU5ufkQ7ANdvwu1)Rs0!CKKdw(TAXk zmM{vK6(06<5{Z0!!BWiO*CgylNZdW5V#8HN0BA9Oqs7gyq-=w#Y`1bRZiJ6p2nP5~ zwJ>pacO$PBy61dw7@v=IdpvyjcKyGQtyq-&3)6*>bLFYq+!w38qx?u?9qGXOW<5W;d8<*7!tl;^AX5DGAT<>SfcJq zr?emM3g=B(*f^PfArkT;4G7lSb|TBJlrJcbN`2jNQs($vq{snal6KBQA|@psSehMXWfdidm>VM9==);d6d05k35D=^Z8qqO{v@6o9_N=6d--&)$C_hZXipW z6UmcNmXkac14v~Fcz$-?SCNVc^tf2SJNo~{|IfdZM;tCRCen8F*Zt2wMUazH{!sSb HIPiY~jRm45 literal 0 HcmV?d00001 diff --git a/tests/providers/amazon/aws/hooks/test_bedrock.py b/tests/providers/amazon/aws/hooks/test_bedrock.py new file mode 100644 index 0000000000000..73612aacbc86d --- /dev/null +++ b/tests/providers/amazon/aws/hooks/test_bedrock.py @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from airflow.providers.amazon.aws.hooks.bedrock import BedrockRuntimeHook + + +class TestBedrockRuntimeHook: + def test_conn_returns_a_boto3_connection(self): + hook = BedrockRuntimeHook() + + assert hook.conn is not None + assert hook.conn.meta.service_model.service_name == "bedrock-runtime" diff --git a/tests/providers/amazon/aws/operators/test_bedrock.py b/tests/providers/amazon/aws/operators/test_bedrock.py new file mode 100644 index 0000000000000..f6274de48f0b0 --- /dev/null +++ b/tests/providers/amazon/aws/operators/test_bedrock.py @@ -0,0 +1,59 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from __future__ import annotations + +import json +from typing import Generator +from unittest import mock + +import pytest +from moto import mock_aws + +from airflow.providers.amazon.aws.hooks.bedrock import BedrockRuntimeHook +from airflow.providers.amazon.aws.operators.bedrock import BedrockInvokeModelOperator + +MODEL_ID = "meta.llama2-13b-chat-v1" +PROMPT = "A very important question." +GENERATED_RESPONSE = "An important answer." +MOCK_RESPONSE = json.dumps( + { + "generation": GENERATED_RESPONSE, + "prompt_token_count": len(PROMPT), + "generation_token_count": len(GENERATED_RESPONSE), + "stop_reason": "stop", + } +) + + +@pytest.fixture +def runtime_hook() -> Generator[BedrockRuntimeHook, None, None]: + with mock_aws(): + yield BedrockRuntimeHook(aws_conn_id="aws_default") + + +class TestBedrockInvokeModelOperator: + @mock.patch.object(BedrockRuntimeHook, "conn") + def test_invoke_model_prompt_good_combinations(self, mock_conn): + mock_conn.invoke_model.return_value["body"].read.return_value = MOCK_RESPONSE + operator = BedrockInvokeModelOperator( + task_id="test_task", model_id=MODEL_ID, input_data={"input_data": {"prompt": PROMPT}} + ) + + response = operator.execute({}) + + assert response["generation"] == GENERATED_RESPONSE diff --git a/tests/system/providers/amazon/aws/example_bedrock.py b/tests/system/providers/amazon/aws/example_bedrock.py new file mode 100644 index 0000000000000..e86e5a2e92b9d --- /dev/null +++ b/tests/system/providers/amazon/aws/example_bedrock.py @@ -0,0 +1,76 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from datetime import datetime + +from airflow.models.baseoperator import chain +from airflow.models.dag import DAG +from airflow.providers.amazon.aws.operators.bedrock import BedrockInvokeModelOperator +from tests.system.providers.amazon.aws.utils import SystemTestContextBuilder + +sys_test_context_task = SystemTestContextBuilder().build() + +DAG_ID = "example_bedrock" +PROMPT = "What color is an orange?" + +with DAG( + dag_id=DAG_ID, + schedule="@once", + start_date=datetime(2021, 1, 1), + tags=["example"], + catchup=False, +) as dag: + test_context = sys_test_context_task() + env_id = test_context["ENV_ID"] + + # [START howto_operator_invoke_llama_model] + invoke_llama_model = BedrockInvokeModelOperator( + task_id="invoke_llama", + model_id="meta.llama2-13b-chat-v1", + input_data={"prompt": PROMPT}, + ) + # [END howto_operator_invoke_llama_model] + + # [START howto_operator_invoke_titan_model] + invoke_titan_model = BedrockInvokeModelOperator( + task_id="invoke_titan", + model_id="amazon.titan-text-express-v1", + input_data={"inputText": PROMPT}, + ) + # [END howto_operator_invoke_titan_model] + + chain( + # TEST SETUP + test_context, + # TEST BODY + invoke_llama_model, + invoke_titan_model, + # TEST TEARDOWN + ) + + from tests.system.utils.watcher import watcher + + # This test needs watcher in order to properly mark success/failure + # when "tearDown" task with trigger rule is part of the DAG + list(dag.tasks) >> watcher() + + +from tests.system.utils import get_test_run # noqa: E402 + +# Needed to run the example 
DAG with pytest (see: tests/system/README.md#run_via_pytest) +test_run = get_test_run(dag) From c06377b278a67445fa66ac8ee01d6d4c72381747 Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Fri, 29 Mar 2024 23:54:29 -0700 Subject: [PATCH 2/2] taragolis fix list --- airflow/providers/amazon/aws/operators/bedrock.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/airflow/providers/amazon/aws/operators/bedrock.py b/airflow/providers/amazon/aws/operators/bedrock.py index 3bb5fe4c89072..d8eaf9e5d3c23 100644 --- a/airflow/providers/amazon/aws/operators/bedrock.py +++ b/airflow/providers/amazon/aws/operators/bedrock.py @@ -21,15 +21,13 @@ from airflow.providers.amazon.aws.hooks.bedrock import BedrockRuntimeHook from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator +from airflow.providers.amazon.aws.utils.mixins import aws_template_fields from airflow.utils.helpers import prune_dict if TYPE_CHECKING: from airflow.utils.context import Context -DEFAULT_CONN_ID = "aws_default" - - class BedrockInvokeModelOperator(AwsBaseOperator[BedrockRuntimeHook]): """ Invoke the specified Bedrock model to run inference using the input provided. @@ -54,10 +52,16 @@ class BedrockInvokeModelOperator(AwsBaseOperator[BedrockRuntimeHook]): empty, then default boto3 configuration would be used (and must be maintained on each worker node). :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. + :param verify: Whether or not to verify SSL certificates. See: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html + :param botocore_config: Configuration dictionary (key-values) for botocore client. 
See: + https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html """ aws_hook_class = BedrockRuntimeHook - template_fields: Sequence[str] = ("model_id", "input_data", "content_type", "accept_type") + template_fields: Sequence[str] = aws_template_fields( + "model_id", "input_data", "content_type", "accept_type" + ) def __init__( self, @@ -67,11 +71,11 @@ def __init__( accept_type: str | None = None, **kwargs, ): + super().__init__(**kwargs) self.model_id = model_id self.input_data = input_data self.content_type = content_type self.accept_type = accept_type - super().__init__(**kwargs) def execute(self, context: Context) -> dict[str, str | int]: # These are optional values which the API defaults to "application/json" if not provided here.