diff --git a/docs/source/apis.ipynb b/docs/source/apis.ipynb index d0eb914..dd79525 100644 --- a/docs/source/apis.ipynb +++ b/docs/source/apis.ipynb @@ -15,7 +15,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "id": "0b4597b2-2a43-4491-8830-bf9f79428074", "metadata": { "nbsphinx": "hidden", @@ -35,7 +35,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "id": "c719e4fc-3ccf-4633-a787-b2fe0d1eac65", "metadata": { "tags": [] @@ -48,36 +48,17 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "id": "f1313c02-d415-4ce6-bff0-3df537cc06c2", "metadata": { "tags": [] }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING! frequency_penalty is not default parameter.\n", - " frequency_penalty was transferred to model_kwargs.\n", - " Please confirm that frequency_penalty is what you intended.\n", - "WARNING! presence_penalty is not default parameter.\n", - " presence_penalty was transferred to model_kwargs.\n", - " Please confirm that presence_penalty is what you intended.\n", - "WARNING! top_p is not default parameter.\n", - " top_p was transferred to model_kwargs.\n", - " Please confirm that top_p is what you intended.\n" - ] - } - ], + "outputs": [], "source": [ "llm = ChatOpenAI(\n", " model_name=\"gpt-3.5-turbo\",\n", " temperature=0,\n", " max_tokens=2000,\n", - " frequency_penalty=0,\n", - " presence_penalty=0,\n", - " top_p=1.0,\n", ")" ] }, @@ -99,7 +80,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "id": "f61c94db-c05d-43ba-9ffc-b58552c715c3", "metadata": { "tags": [] @@ -156,7 +137,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "id": "ff9ad27f-7a81-4123-8d0b-1e14802df67e", "metadata": { "tags": [] @@ -176,7 +157,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 6, "id": "760baa5f-9368-4b5a-abc0-6ac65c34b7a7", "metadata": { "tags": [] @@ -188,7 +169,7 @@ "{'player': {'action': 'stop'}}" ] }, - "execution_count": 7, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } @@ -199,7 +180,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 7, "id": "462303c0-e83a-4e39-86cd-cab6875b40ef", "metadata": { "tags": [] @@ -211,7 +192,7 @@ "{'player': {'action': 'play'}}" ] }, - "execution_count": 8, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } @@ -222,7 +203,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 8, "id": "02c7f1e5-1c8d-4e9f-82e6-c37a41d6de14", "metadata": { "tags": [] @@ -234,7 +215,7 @@ "{'player': {'album': ['the lion king soundtrack']}}" ] }, - "execution_count": 9, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } @@ -245,7 +226,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 9, "id": "7a6d918c-53fe-426b-b37e-eec2abb8a704", "metadata": { "tags": [] @@ -260,7 +241,7 @@ " 'validated_data': {}}" ] }, - "execution_count": 10, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -271,7 +252,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 10, "id": "b18acf0a-d99e-48de-ace5-fb01bded5a41", "metadata": { "tags": [] @@ -283,7 +264,7 @@ "{'player': {'action': 'previous'}}" ] }, - "execution_count": 11, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -304,7 +285,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 11, "id": "c50b080b-7179-4bbe-b234-83ce59e2d215", "metadata": { "tags": [] @@ -351,7 +332,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 12, "id": "404e4f1a-d316-41f2-ab94-040e22001fc4", "metadata": { "tags": [] @@ -363,7 +344,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 13, "id": "73c31ace-32dd-4a33-ae39-475db6934f6d", "metadata": { "tags": [] @@ -373,24 +354,22 @@ "data": { "text/plain": [ "{'action': {'sport': 'baseball',\n", - " 'location': 'LA',\n", + " 'location': 'LA area',\n", " 'price_range': {'price_max': '100', 'currency': '$'}}}" ] }, - "execution_count": 14, + "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "chain.run(\n", - " \"I want to buy tickets for a baseball game in LA area under $100\"\n", - ")[\"data\"]" + "chain.run(\"I want to buy tickets for a baseball game in LA area under $100\")[\"data\"]" ] }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 14, "id": "78e3b3af-bfa8-4503-854a-b83a7f8f49e6", "metadata": { "tags": [] @@ -404,7 +383,7 @@ " 'price_range': {'price_min': '20', 'price_max': '40', 'currency': '$'}}}" ] }, - "execution_count": 16, + "execution_count": 14, "metadata": {}, "output_type": "execute_result" } @@ -431,7 +410,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 15, "id": "2b0bcf09-a3ae-4a8a-9ce3-f86834ce6ca2", "metadata": { "tags": [] @@ -594,7 +573,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 16, "id": "7b389d20-ae6b-4764-9209-3cd3c2f0a715", "metadata": { "tags": [] @@ -614,7 +593,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 17, "id": "b203ac4a-4f9f-45c6-a509-39b9a6cfd98f", "metadata": { "tags": [] @@ -626,7 +605,7 @@ "{}" ] }, - "execution_count": 19, + "execution_count": 17, "metadata": {}, "output_type": "execute_result" } @@ -641,7 +620,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 18, "id": "398377bf-5d30-4b4c-b637-e9af969d16a4", "metadata": { "tags": [] @@ -657,7 +636,7 @@ " 'attribute_selection': ['revenue', 'eps']}}" ] }, - "execution_count": 20, + "execution_count": 18, "metadata": {}, "output_type": "execute_result" } @@ -672,7 +651,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 19, "id": "2a620246-4c85-4256-8f58-0acbcc9455a3", "metadata": { "tags": [] @@ -686,7 +665,7 @@ " 'value': ['red', 'blue']}]}}" ] }, - "execution_count": 21, + "execution_count": 19, "metadata": {}, "output_type": "execute_result" } @@ -698,7 +677,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 20, "id": "4745517e-507e-4d1a-97e0-d143fa34cea2", "metadata": { "tags": [] @@ -713,7 +692,7 @@ " 'attribute_selection': ['revenue']}}" ] }, - "execution_count": 22, + "execution_count": 20, "metadata": {}, "output_type": "execute_result" } @@ -725,7 +704,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 25, "id": "b206407f-57e0-4212-8e75-970cb49b52e5", "metadata": { "tags": [] @@ -734,10 +713,14 @@ { "data": { "text/plain": [ - "{}" + "{'search_for_companies': {'attribute_filter': [{'attribute': 'market cap',\n", + " 'op': '>',\n", + " 'value': '1000000'},\n", + " {'attribute': 'building color', 'op': 'in', 'value': ['red', 'blue']}],\n", + " 'attribute_selection': ['revenue', 'eps']}}" ] }, - "execution_count": 23, + "execution_count": 25, "metadata": {}, "output_type": "execute_result" } @@ -745,14 +728,14 @@ "source": [ "text = (\n", " \"revenue, eps of indian companies that have market cap of over 1 million, \"\n", - " \"but less than 50 employees and own red and blue buildings\"\n", + " \"that own red and blue buildings\"\n", ")\n", "chain.run(text)[\"data\"]" ] }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 22, "id": "a1025f99-eb0a-4d96-923e-35f36e4ac6b2", "metadata": { "tags": [] @@ -867,7 +850,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.2" + "version": "3.10.12" } }, "nbformat": 4, diff --git a/docs/source/nested_objects.ipynb b/docs/source/nested_objects.ipynb index fa1556d..3de2aba 100644 --- a/docs/source/nested_objects.ipynb +++ b/docs/source/nested_objects.ipynb @@ -54,31 +54,12 @@ "metadata": { "tags": [] }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING! frequency_penalty is not default parameter.\n", - " frequency_penalty was transferred to model_kwargs.\n", - " Please confirm that frequency_penalty is what you intended.\n", - "WARNING! presence_penalty is not default parameter.\n", - " presence_penalty was transferred to model_kwargs.\n", - " Please confirm that presence_penalty is what you intended.\n", - "WARNING! top_p is not default parameter.\n", - " top_p was transferred to model_kwargs.\n", - " Please confirm that top_p is what you intended.\n" - ] - } - ], + "outputs": [], "source": [ "llm = ChatOpenAI(\n", " model_name=\"gpt-3.5-turbo\",\n", " temperature=0,\n", " max_tokens=2000,\n", - " frequency_penalty=0,\n", - " presence_penalty=0,\n", - " top_p=1.0,\n", ")" ] }, @@ -182,17 +163,11 @@ "data": { "text/plain": [ "{'information': [{'person_name': 'Alice Doe',\n", - " 'to_address': {'street': '100 Main St',\n", - " 'city': 'Boston',\n", - " 'state': 'MA',\n", - " 'zipcode': '23232',\n", - " 'country': 'USA'}},\n", + " 'from_address': {'city': 'New York'},\n", + " 'to_address': {'city': 'Boston', 'state': 'MA'}},\n", " {'person_name': 'Bob Smith',\n", - " 'to_address': {'street': '100 Main St',\n", - " 'city': 'New York',\n", - " 'state': 'NY',\n", - " 'zipcode': '10001',\n", - " 'country': 'USA'}}]}" + " 'from_address': {'city': 'Boston', 'state': 'MA'},\n", + " 'to_address': {'city': 'New York'}}]}" ] }, "execution_count": 6, @@ -218,11 +193,17 @@ "data": { "text/plain": [ "{'information': [{'person_name': 'Alice Doe',\n", + " 'from_address': {'city': 'New York'},\n", + " 'to_address': {'city': 'Boston'}},\n", + " {'person_name': 'Bob Smith',\n", + " 'from_address': {'city': 'New York'},\n", " 'to_address': {'city': 'Boston'}},\n", - " {'person_name': 'Bob Smith', 'to_address': {'city': 'Boston'}},\n", " {'person_name': 'Andrew', 'to_address': {'city': 'Boston'}},\n", " {'person_name': 'Joana', 'to_address': {'city': 'Boston'}},\n", - " {'person_name': 'Paul', 'to_address': {'city': 'Boston'}}]}" + " {'person_name': 'Paul', 'to_address': {'city': 'Boston'}},\n", + " {'person_name': 'Betty',\n", + " 'from_address': {'city': 'Boston'},\n", + " 'to_address': {'city': 'New York'}}]}" ] }, "execution_count": 7, @@ -330,10 +311,11 @@ "data": { "text/plain": [ "{'information': [{'person_name': 'Alice Doe',\n", - " 'to_address': [{'street': 'New York'}]},\n", - " {'person_name': 'Bob Smith', 'to_address': [{'street': 'New York'}]},\n", - " {'person_name': 'Bob Smith', 'to_address': [{'street': 'Boston'}]},\n", - " {'person_name': 'Bob Smith', 'to_address': [{'street': 'LA'}]}]}" + " 'from_address': [{'city': 'New York'}],\n", + " 'to_address': [{'city': 'Boston'}]},\n", + " {'person_name': 'Bob Smith',\n", + " 'from_address': [{'city': 'New York'}],\n", + " 'to_address': [{'city': 'Boston'}, {'city': 'LA'}]}]}" ] }, "execution_count": 10, @@ -346,14 +328,6 @@ " \"Alice Doe and Bob Smith moved from New York to Boston. Bob later moved to LA.\"\n", ")[\"data\"]" ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "2710486b-05c5-45b6-ab91-d990da92983f", - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { @@ -372,7 +346,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.2" + "version": "3.10.12" } }, "nbformat": 4, diff --git a/docs/source/objects.ipynb b/docs/source/objects.ipynb index a089167..8c0ad3b 100644 --- a/docs/source/objects.ipynb +++ b/docs/source/objects.ipynb @@ -52,31 +52,12 @@ "metadata": { "tags": [] }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING! frequency_penalty is not default parameter.\n", - " frequency_penalty was transferred to model_kwargs.\n", - " Please confirm that frequency_penalty is what you intended.\n", - "WARNING! presence_penalty is not default parameter.\n", - " presence_penalty was transferred to model_kwargs.\n", - " Please confirm that presence_penalty is what you intended.\n", - "WARNING! top_p is not default parameter.\n", - " top_p was transferred to model_kwargs.\n", - " Please confirm that top_p is what you intended.\n" - ] - } - ], + "outputs": [], "source": [ "llm = ChatOpenAI(\n", " model_name=\"gpt-3.5-turbo\",\n", " temperature=0,\n", " max_tokens=2000,\n", - " frequency_penalty=0,\n", - " presence_penalty=0,\n", - " top_p=1.0,\n", ")" ] }, @@ -249,8 +230,8 @@ "chain = create_extraction_chain(llm, schema)\n", "print(\n", " chain.run(\n", - " \"My name is Bob Alice and my phone number is (123)-444-9999. I found my true love one\"\n", - " \" on a blue sunday. Her number was (333)1232832. Her name was Moana Sunrise and she was 10 years old.\"\n", + " \"My name is Bob Alice and my phone number is (123)-444-9999. I found my true love one\"\n", + " \" on a blue sunday. Her number was (333)1232832. Her name was Moana Sunrise and she was 10 years old.\"\n", " )[\"data\"]\n", ")" ] @@ -284,8 +265,8 @@ ], "source": [ "chain.run(\n", - " \"My phone number is (123)-444-9999. I found my true love one on a blue sunday.\"\n", - " \" Her number was (333)1232832\"\n", + " \"My phone number is (123)-444-9999. I found my true love one on a blue sunday.\"\n", + " \" Her number was (333)1232832\"\n", ")[\"data\"]" ] }, @@ -345,7 +326,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 10, "id": "5c694d79-e72c-4712-b891-111bc0279032", "metadata": { "tags": [] @@ -358,7 +339,7 @@ " {'first_name': 'Moana', 'last_name': 'Sunrise', 'age': '10'}]}" ] }, - "execution_count": 11, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -366,8 +347,8 @@ "source": [ "chain = create_extraction_chain(llm, schema)\n", "chain.run(\n", - " \"My name is Bob Alice and my phone number is (123)-444-9999. I found my true love one\"\n", - " \" on a blue sunday. Her number was (333)1232832. Her name was Moana Sunrise and she was 10 years old.\"\n", + " \"My name is Bob Alice and my phone number is (123)-444-9999. I found my true love one\"\n", + " \" on a blue sunday. Her number was (333)1232832. Her name was Moana Sunrise and she was 10 years old.\"\n", ")[\"data\"]" ] }, @@ -381,7 +362,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 11, "id": "a2944e8c-4630-4b29-b505-b2ca6fceba01", "metadata": { "tags": [] @@ -439,7 +420,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.2" + "version": "3.10.12" } }, "nbformat": 4, diff --git a/docs/source/prompt.ipynb b/docs/source/prompt.ipynb index a1f82c1..51afe95 100644 --- a/docs/source/prompt.ipynb +++ b/docs/source/prompt.ipynb @@ -144,7 +144,7 @@ " template=(\n", " \"[Pep talk for your LLM goes here]\\n\\n\"\n", " \"Add some type description\\n\\n\"\n", - " \"{type_description}\\n\\n\" # Can comment out\n", + " \"{type_description}\\n\\n\" # Can comment out\n", " \"Add some format instructions\\n\\n\"\n", " \"{format_instructions}\\n\"\n", " \"Suffix heren\\n\"\n", @@ -154,7 +154,7 @@ "\n", "chain = create_extraction_chain(llm, schema, instruction_template=instruction_template)\n", "\n", - "print(chain.prompt.format_prompt(text='hello').to_string())" + "print(chain.prompt.format_prompt(text=\"hello\").to_string())" ] }, { @@ -200,9 +200,8 @@ "source": [ "class CatType(TypeDescriptor):\n", " def describe(self, node: Object) -> str:\n", - " \"\"\"Describe the schema of the node.\"\"\" \n", - " return f\"A 😼 ate the schema of {type(node)} 😼\"\n", - " " + " \"\"\"Describe the schema of the node.\"\"\"\n", + " return f\"A 😼 ate the schema of {type(node)} 😼\"" ] }, { @@ -242,7 +241,7 @@ " template=(\n", " \"[Pep talk for your LLM goes here]\\n\\n\"\n", " \"Add some type description\\n\\n\"\n", - " \"{type_description}\\n\\n\" # Can comment out\n", + " \"{type_description}\\n\\n\" # Can comment out\n", " \"Add some format instructions\\n\\n\"\n", " \"{format_instructions}\\n\"\n", " \"Suffix heren\\n\"\n", @@ -250,9 +249,15 @@ ")\n", "\n", "\n", - "chain = create_extraction_chain(llm, schema, instruction_template=instruction_template, encoder_or_encoder_class=CatEncoder, type_descriptor=CatType())\n", + "chain = create_extraction_chain(\n", + " llm,\n", + " schema,\n", + " instruction_template=instruction_template,\n", + " encoder_or_encoder_class=CatEncoder,\n", + " type_descriptor=CatType(),\n", + ")\n", "\n", - "print(chain.prompt.format_prompt(text='hello').to_string())" + "print(chain.prompt.format_prompt(text=\"hello\").to_string())" ] } ], @@ -272,7 +277,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.1" + "version": "3.10.12" } }, "nbformat": 4, diff --git a/docs/source/schema_serialization.ipynb b/docs/source/schema_serialization.ipynb index a199f3b..0ac6e44 100644 --- a/docs/source/schema_serialization.ipynb +++ b/docs/source/schema_serialization.ipynb @@ -7,7 +7,9 @@ "source": [ "# Schema serialization\n", "\n", - "A Kor schema can be serialized and deserialzed to JSON. This lets you store the schema outside of the code." + "A Kor schema can be serialized and deserialzed to JSON. This lets you store the schema outside of the code.\n", + "\n", + "**ATTENTION** This only works with pydantic v1 at the moment." ] }, { @@ -214,7 +216,7 @@ " model_name=\"gpt-3.5-turbo\",\n", " temperature=0,\n", " max_tokens=2000,\n", - " model_kwargs={\"frequency_penalty\":0,\"presence_penalty\":0, \"top_p\": 1.0}\n", + " model_kwargs={\"frequency_penalty\": 0, \"presence_penalty\": 0, \"top_p\": 1.0},\n", ")" ] }, @@ -273,8 +275,8 @@ "chain = create_extraction_chain(llm, schema)\n", "print(\n", " chain.run(\n", - " \"My name is Bob Alice and my phone number is (123)-444-9999. I found my true love one\"\n", - " \" on a blue sunday. Her number was (333)1232832. Her name was Moana Sunrise and she was 10 years old.\"\n", + " \"My name is Bob Alice and my phone number is (123)-444-9999. I found my true love one\"\n", + " \" on a blue sunday. Her number was (333)1232832. Her name was Moana Sunrise and she was 10 years old.\"\n", " )[\"data\"]\n", ")" ] @@ -296,7 +298,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.2" + "version": "3.10.12" } }, "nbformat": 4, diff --git a/docs/source/tutorial.ipynb b/docs/source/tutorial.ipynb index 47c79c6..9094bac 100644 --- a/docs/source/tutorial.ipynb +++ b/docs/source/tutorial.ipynb @@ -139,9 +139,6 @@ " model_name=\"gpt-3.5-turbo\",\n", " temperature=0,\n", " max_tokens=2000,\n", - " frequency_penalty=0,\n", - " presence_penalty=0,\n", - " top_p=1.0,\n", ")" ] }, @@ -272,6 +269,8 @@ "Please output the extracted information in CSV format in Excel dialect. Please use a | as the delimiter. \n", " Do NOT add any clarifying information. Output MUST follow the schema above. Do NOT add any additional columns that do not appear in the schema.\n", "\n", + "\n", + "\n", "Input: Alice and Bob are friends\n", "Output: first_name\n", "Alice\n", @@ -387,7 +386,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.1" + "version": "3.10.12" } }, "nbformat": 4, diff --git a/docs/source/type_descriptors.ipynb b/docs/source/type_descriptors.ipynb index 905eb30..9c8dbd4 100644 --- a/docs/source/type_descriptors.ipynb +++ b/docs/source/type_descriptors.ipynb @@ -338,7 +338,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.1" + "version": "3.10.12" } }, "nbformat": 4, diff --git a/docs/source/untyped_objects.ipynb b/docs/source/untyped_objects.ipynb index f3f14a7..7af97ba 100644 --- a/docs/source/untyped_objects.ipynb +++ b/docs/source/untyped_objects.ipynb @@ -132,8 +132,8 @@ ], "source": [ "chain.run(\n", - " \"Alice Doe and Bob Smith moved from New York to Boston. Andrew was 12 years\"\n", - " \" old. He also moved to Boston. So did Joana and Paul. Betty did the opposite.\"\n", + " \"Alice Doe and Bob Smith moved from New York to Boston. Andrew was 12 years\"\n", + " \" old. He also moved to Boston. So did Joana and Paul. Betty did the opposite.\"\n", ")[\"data\"]" ] }, diff --git a/docs/source/validation.ipynb b/docs/source/validation.ipynb index d24a713..451667d 100644 --- a/docs/source/validation.ipynb +++ b/docs/source/validation.ipynb @@ -94,16 +94,18 @@ "\n", "class MusicRequest(BaseModel):\n", " song: Optional[List[str]] = Field(\n", - " description=\"The song(s) that the user would like to be played.\"\n", + " default=None, description=\"The song(s) that the user would like to be played.\"\n", " )\n", " album: Optional[List[str]] = Field(\n", - " description=\"The album(s) that the user would like to be played.\"\n", + " default=None, description=\"The album(s) that the user would like to be played.\"\n", " )\n", " artist: Optional[List[str]] = Field(\n", + " default=None,\n", " description=\"The artist(s) whose music the user would like to hear.\",\n", " examples=[(\"Songs by paul simon\", \"paul simon\")],\n", " )\n", " action: Optional[Action] = Field(\n", + " default=None,\n", " description=\"The action that should be taken; one of `play`, `stop`, `next`, `previous`\",\n", " examples=[\n", " (\"Please stop the music\", \"stop\"),\n", @@ -247,9 +249,7 @@ } ], "source": [ - "chain.run(\"i want to hear yellow submarine by the beatles\")[\n", - " \"validated_data\"\n", - "]" + "chain.run(\"i want to hear yellow submarine by the beatles\")[\"validated_data\"]" ] }, { @@ -318,9 +318,7 @@ } ], "source": [ - "chain.run(\"play songs by paul simon and led zeppelin and the doors\")[\n", - " \"validated_data\"\n", - "]" + "chain.run(\"play songs by paul simon and led zeppelin and the doors\")[\"validated_data\"]" ] }, { @@ -343,9 +341,7 @@ } ], "source": [ - "chain.run(\"could you play the previous song again?\")[\n", - " \"validated_data\"\n", - "]" + "chain.run(\"could you play the previous song again?\")[\"validated_data\"]" ] }, { @@ -416,13 +412,15 @@ " description=\"The song(s) that the user would like to be played.\"\n", " ) # <-- Note this is NOT Optional\n", " album: Optional[List[str]] = Field(\n", - " description=\"The album(s) that the user would like to be played.\"\n", + " default=None, description=\"The album(s) that the user would like to be played.\"\n", " )\n", " artist: Optional[List[str]] = Field(\n", + " default=None,\n", " description=\"The artist(s) whose music the user would like to hear.\",\n", " examples=[(\"Songs by paul simon\", \"paul simon\")],\n", " )\n", " action: Optional[Action] = Field(\n", + " default=None,\n", " description=\"The action that should be taken; one of `play`, `stop`, `next`, `previous`\",\n", " examples=[\n", " (\"Please stop the music\", \"stop\"),\n", @@ -468,22 +466,32 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 19, "id": "d594eb6e-02a0-4774-9dca-421db192372d", "metadata": { "tags": [] }, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Error in LangChainTracer.on_chain_end callback: No constructor defined\n" + ] + }, { "data": { "text/plain": [ "{'data': {'player': {'action': 'stop'}},\n", " 'raw': '{\"player\": {\"action\": \"stop\"}}',\n", - " 'errors': [ValidationError(model='Player', errors=[{'loc': ('song',), 'msg': 'field required', 'type': 'value_error.missing'}])],\n", + " 'errors': [1 validation error for Player\n", + " song\n", + " Field required [type=missing, input_value={'action': 'stop'}, input_type=dict]\n", + " For further information visit https://errors.pydantic.dev/2.3/v/missing],\n", " 'validated_data': None}" ] }, - "execution_count": 18, + "execution_count": 19, "metadata": {}, "output_type": "execute_result" } @@ -494,7 +502,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 21, "id": "672f9908-c6b7-47cf-8c82-03b0e5b8fa84", "metadata": { "tags": [] @@ -506,15 +514,13 @@ "Player(song=['yellow submarine'], album=None, artist=['the beatles'], action=None)" ] }, - "execution_count": 19, + "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "chain.run(\"i want to hear yellow submarine by the beatles\")[\n", - " \"validated_data\"\n", - "]" + "chain.run(\"i want to hear yellow submarine by the beatles\")[\"validated_data\"]" ] }, { @@ -543,7 +549,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 22, "id": "cd9dffbe-82bf-4d1f-9b0a-3ec2c58b63d6", "metadata": { "tags": [] @@ -557,7 +563,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 23, "id": "f859b6e1-c2d8-48e0-af17-2ffb286bffe9", "metadata": { "tags": [] @@ -577,7 +583,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 24, "id": "af6c6339-81db-482b-9507-31f41d2fa489", "metadata": { "tags": [] @@ -618,7 +624,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 25, "id": "837a08c2-8de0-448a-9984-0cf19a73d4a3", "metadata": { "tags": [] @@ -630,15 +636,13 @@ "[Person(name='john', age=13), Person(name='maria', age=24)]" ] }, - "execution_count": 23, + "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "chain.run(\"john is 13 years old. maria is 24 years old\")[\n", - " \"validated_data\"\n", - "]" + "chain.run(\"john is 13 years old. maria is 24 years old\")[\"validated_data\"]" ] }, { @@ -655,7 +659,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 26, "id": "ee92ae5e-52e9-4405-9718-be71d25ce412", "metadata": { "tags": [] @@ -684,7 +688,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 27, "id": "7464aab0-45fb-4e22-bed4-b695c7f60629", "metadata": { "tags": [] @@ -699,7 +703,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 28, "id": "8becd13d-bd23-4d37-8fd3-7548a7fe51c1", "metadata": { "tags": [] @@ -715,7 +719,7 @@ " 'validated_data': Root(people=[Person(name='tom', age=23), Person(name='Jessica', age=75)])}" ] }, - "execution_count": 26, + "execution_count": 28, "metadata": {}, "output_type": "execute_result" } @@ -728,7 +732,7 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 29, "id": "7ea06fe0-104b-487a-ab78-cd23de66ec88", "metadata": { "tags": [] @@ -768,7 +772,7 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 41, "id": "38245cd6-e188-40c9-a940-184da8755983", "metadata": { "tags": [] @@ -778,17 +782,17 @@ "class Pet(BaseModel):\n", " name: str = Field(description=\"the name of the pet\")\n", " species: Optional[str] = Field(\n", - " description=\"The species of the pet; e.g., dog or cat\"\n", + " default=None, description=\"The species of the pet; e.g., dog or cat\"\n", " )\n", - " age: Optional[int] = Field(description=\"The number of the age; e.g.,\")\n", + " age: Optional[int] = Field(default=None, description=\"The number of the age; e.g.,\")\n", " age_unit: Optional[str] = Field(\n", - " description=\"The unit of the age; e.g., days or weeks\"\n", + " default=None, description=\"The unit of the age; e.g., days or weeks\"\n", " )\n", "\n", "\n", "class Person(BaseModel):\n", " name: str = Field(description=\"The person's name\")\n", - " age: Optional[int] = Field(description=\"The age of the person\")\n", + " age: Optional[int] = Field(default=None, description=\"The age of the person\")\n", " pets: List[Pet] = Field(\n", " description=\"The pets owned by the person\",\n", " examples=[\n", @@ -809,7 +813,7 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 42, "id": "236a9510-6f69-4d63-8854-62e2c57380a6", "metadata": { "tags": [] @@ -826,7 +830,7 @@ }, { "cell_type": "code", - "execution_count": 30, + "execution_count": 43, "id": "843e992a-32c5-4382-95ab-33eb3cd7810b", "metadata": { "tags": [] @@ -874,7 +878,7 @@ }, { "cell_type": "code", - "execution_count": 36, + "execution_count": 45, "id": "e34b6194-9a2b-43a1-95c6-2c9930d036ed", "metadata": { "tags": [] @@ -883,17 +887,17 @@ { "data": { "text/plain": [ - "Root(people=[Person(name='Neo', age=None, pets=[Pet(name='Tom', species='dog', age=None, age_unit=None), Pet(name='Weeby', species='cat', age=23, age_unit='days')]), Person(name='Julia', age=None, pets=[Pet(name='Wind', species='horse', age=7, age_unit='years')])])" + "Root(people=[Person(name='Neo', age=None, pets=[Pet(name='Tom', species='dog', age=None, age_unit=None), Pet(name='Weeby', species='cat', age=23, age_unit='days')]), Person(name='Julia', age=None, pets=[Pet(name='Wind', species='horse', age=None, age_unit=None)])])" ] }, - "execution_count": 36, + "execution_count": 45, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "chain.predict_and_parse(\n", - " text=\"Neo had a dog by the name of Tom and a cat by the name of Weeby. Weeby was 23 days old. Julia owned a horse. The horses name was Wind. And he was 7 years old\"\n", + "chain.run(\n", + " text=\"Neo had a dog by the name of Tom and a cat by the name of Weeby. Weeby was 23 days old. Julia owned a horse. The horses name was Wind\"\n", ")[\"validated_data\"]" ] } @@ -914,7 +918,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.2" + "version": "3.10.12" } }, "nbformat": 4, diff --git a/kor/adapters.py b/kor/adapters.py index 6ce0a0a..02678fb 100644 --- a/kor/adapters.py +++ b/kor/adapters.py @@ -143,7 +143,7 @@ def _translate_pydantic_to_kor( type_to_use = unpacked_optional if is_optional_equivalent else type_ # If the type is a parameterized generic, we want to extract - # the innter type; e.g., List[str] -> str + # the inner type; e.g., List[str] -> str if not isinstance(type_to_use, type): # i.e., parameterized generic origin_ = get_origin(type_to_use) if not isinstance(origin_, type) or not issubclass(origin_, List): diff --git a/kor/validators.py b/kor/validators.py index d576b25..3acf92c 100644 --- a/kor/validators.py +++ b/kor/validators.py @@ -2,8 +2,9 @@ import abc from typing import Any, List, Mapping, Optional, Tuple, Type, Union -from pydantic import BaseModel -from pydantic import ValidationError as PydanticValidationError +from pydantic import BaseModel, ValidationError + +from ._pydantic import PYDANTIC_MAJOR_VERSION class Validator(abc.ABC): @@ -53,12 +54,21 @@ def clean_data( for item in data: try: - records.append(self.model_class.parse_obj(item)) - except PydanticValidationError as e: + if PYDANTIC_MAJOR_VERSION == 1: + record = self.model_class.parse_obj(item) + else: + record = self.model_class.model_validate(item) + + records.append(record) + except ValidationError as e: exceptions.append(e) return records, exceptions else: try: - return self.model_class.parse_obj(data), [] - except PydanticValidationError as e: + if PYDANTIC_MAJOR_VERSION == 1: + record = self.model_class.parse_obj(data) + else: + record = self.model_class.model_validate(data) + return record, [] + except ValidationError as e: return None, [e] diff --git a/poetry.lock b/poetry.lock index 9c5edbb..84eb612 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2320,40 +2320,6 @@ files = [ {file = "numpy-1.24.4.tar.gz", hash = "sha256:80f5e3a4e498641401868df4208b74581206afbee7cf7b8329daae82676d9463"}, ] -[[package]] -name = "numpy" -version = "1.25.0" -description = "Fundamental package for array computing in Python" -optional = false -python-versions = ">=3.9" -files = [ - {file = "numpy-1.25.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:8aa130c3042052d656751df5e81f6d61edff3e289b5994edcf77f54118a8d9f4"}, - {file = "numpy-1.25.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9e3f2b96e3b63c978bc29daaa3700c028fe3f049ea3031b58aa33fe2a5809d24"}, - {file = "numpy-1.25.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d6b267f349a99d3908b56645eebf340cb58f01bd1e773b4eea1a905b3f0e4208"}, - {file = "numpy-1.25.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4aedd08f15d3045a4e9c648f1e04daca2ab1044256959f1f95aafeeb3d794c16"}, - {file = "numpy-1.25.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:6d183b5c58513f74225c376643234c369468e02947b47942eacbb23c1671f25d"}, - {file = "numpy-1.25.0-cp310-cp310-win32.whl", hash = "sha256:d76a84998c51b8b68b40448ddd02bd1081bb33abcdc28beee6cd284fe11036c6"}, - {file = "numpy-1.25.0-cp310-cp310-win_amd64.whl", hash = "sha256:c0dc071017bc00abb7d7201bac06fa80333c6314477b3d10b52b58fa6a6e38f6"}, - {file = "numpy-1.25.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4c69fe5f05eea336b7a740e114dec995e2f927003c30702d896892403df6dbf0"}, - {file = "numpy-1.25.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:9c7211d7920b97aeca7b3773a6783492b5b93baba39e7c36054f6e749fc7490c"}, - {file = "numpy-1.25.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ecc68f11404930e9c7ecfc937aa423e1e50158317bf67ca91736a9864eae0232"}, - {file = "numpy-1.25.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e559c6afbca484072a98a51b6fa466aae785cfe89b69e8b856c3191bc8872a82"}, - {file = "numpy-1.25.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:6c284907e37f5e04d2412950960894b143a648dea3f79290757eb878b91acbd1"}, - {file = "numpy-1.25.0-cp311-cp311-win32.whl", hash = "sha256:95367ccd88c07af21b379be1725b5322362bb83679d36691f124a16357390153"}, - {file = "numpy-1.25.0-cp311-cp311-win_amd64.whl", hash = "sha256:b76aa836a952059d70a2788a2d98cb2a533ccd46222558b6970348939e55fc24"}, - {file = "numpy-1.25.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:b792164e539d99d93e4e5e09ae10f8cbe5466de7d759fc155e075237e0c274e4"}, - {file = "numpy-1.25.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:7cd981ccc0afe49b9883f14761bb57c964df71124dcd155b0cba2b591f0d64b9"}, - {file = "numpy-1.25.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5aa48bebfb41f93043a796128854b84407d4df730d3fb6e5dc36402f5cd594c0"}, - {file = "numpy-1.25.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5177310ac2e63d6603f659fadc1e7bab33dd5a8db4e0596df34214eeab0fee3b"}, - {file = "numpy-1.25.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:0ac6edfb35d2a99aaf102b509c8e9319c499ebd4978df4971b94419a116d0790"}, - {file = "numpy-1.25.0-cp39-cp39-win32.whl", hash = "sha256:7412125b4f18aeddca2ecd7219ea2d2708f697943e6f624be41aa5f8a9852cc4"}, - {file = "numpy-1.25.0-cp39-cp39-win_amd64.whl", hash = "sha256:26815c6c8498dc49d81faa76d61078c4f9f0859ce7817919021b9eba72b425e3"}, - {file = "numpy-1.25.0-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:5b1b90860bf7d8a8c313b372d4f27343a54f415b20fb69dd601b7efe1029c91e"}, - {file = "numpy-1.25.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:85cdae87d8c136fd4da4dad1e48064d700f63e923d5af6c8c782ac0df8044542"}, - {file = "numpy-1.25.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:cc3fda2b36482891db1060f00f881c77f9423eead4c3579629940a3e12095fe8"}, - {file = "numpy-1.25.0.tar.gz", hash = "sha256:f1accae9a28dc3cda46a91de86acf69de0d1b5f4edd44a9b0c3ceb8036dfff19"}, -] - [[package]] name = "overrides" version = "7.3.1" diff --git a/tests/test_validators.py b/tests/test_validators.py index 7f65e6b..a5cc870 100644 --- a/tests/test_validators.py +++ b/tests/test_validators.py @@ -1,10 +1,10 @@ """Test validator module.""" -from pydantic import BaseModel, ValidationError +from typing import Optional + +from pydantic import BaseModel, Field, ValidationError from pydantic import validator as pydantic_validator -from kor.validators import ( - PydanticValidator, -) +from kor.validators import PydanticValidator def test_pydantic_validator() -> None: @@ -15,6 +15,8 @@ class ToyModel(BaseModel): name: str age: int + foo: Optional[str] = None + foo2: Optional[str] = Field(default=None, description="some field") @pydantic_validator("age") def age_must_be_positive(cls, v: int) -> int: