Commit
add run policy dataset
Yuan325 committed Mar 21, 2024
1 parent 1796410 commit 41feedb
Showing 13 changed files with 125 additions and 49 deletions.
22 changes: 11 additions & 11 deletions data/cymbalair_policy.csv

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion retrieval_service/app/app.py
@@ -18,7 +18,7 @@

import yaml
from fastapi import FastAPI
-from langchain.embeddings import VertexAIEmbeddings
+from langchain_google_vertexai import VertexAIEmbeddings
from pydantic import BaseModel

import datastore
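The embeddings entry point moves from the deprecated langchain.embeddings path to the langchain-google-vertexai partner package. As a quick orientation, a minimal sketch of the new import in use (the model name below is illustrative, not from this commit; the app itself resolves EMBEDDING_MODEL_NAME from its own config):

    from langchain_google_vertexai import VertexAIEmbeddings

    # Illustrative model name; this repo reads it via EMBEDDING_MODEL_NAME.
    embed_service = VertexAIEmbeddings(model_name="textembedding-gecko@003")
    vector = embed_service.embed_query("checked baggage weight limit")
    print(len(vector))  # the policies schema expects 768-dim vectors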
2 changes: 1 addition & 1 deletion retrieval_service/app/routes.py
@@ -18,7 +18,7 @@
from fastapi import APIRouter, Depends, HTTPException, Request
from google.auth.transport import requests # type:ignore
from google.oauth2 import id_token # type:ignore
-from langchain.embeddings.base import Embeddings
+from langchain_core.embeddings import Embeddings

import datastore

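Same migration on the interface side: the Embeddings abstract base class now lives in langchain_core. A hedged sketch of interface-level usage (the helper function is hypothetical):

    from langchain_core.embeddings import Embeddings

    def embed_policy_text(embed_service: Embeddings, text: str) -> list[float]:
        # Any Embeddings implementation, VertexAIEmbeddings included, fits here.
        return embed_service.embed_query(text)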
3 changes: 1 addition & 2 deletions retrieval_service/datastore/datastore.py
@@ -143,9 +143,8 @@ async def export_dataset(

with open(policies_new_path, "w") as f:
col_names = [
"langchain_id",
"id",
"content",
"metadata",
"embedding",
]
writer = csv.DictWriter(f, col_names, delimiter=",")
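An exported policy row now carries only id, content, and embedding. A minimal sketch of the writer in use, with made-up values:

    import csv

    with open("policies.csv.new", "w") as f:
        col_names = ["id", "content", "embedding"]
        writer = csv.DictWriter(f, col_names, delimiter=",")
        writer.writeheader()
        writer.writerow(
            {
                "id": 1,
                "content": "Checked bags must not exceed 50 lbs (23 kg).",
                "embedding": [0.12, -0.03, 0.88],  # truncated; real vectors are 768-dim
            }
        )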
9 changes: 3 additions & 6 deletions retrieval_service/datastore/providers/cloudsql_postgres.py
@@ -13,7 +13,6 @@
# limitations under the License.

import asyncio
-import json
from datetime import datetime
from typing import Any, Dict, Literal, Optional

@@ -247,9 +246,8 @@ async def initialize_data(
text(
"""
CREATE TABLE policies(
-langchain_id INT PRIMARY KEY,
+id INT PRIMARY KEY,
content TEXT NOT NULL,
-metadata JSON,
embedding vector(768) NOT NULL
)
"""
@@ -259,14 +257,13 @@
await conn.execute(
text(
"""
-INSERT INTO policies VALUES (:langchain_id, :content, :metadata, :embedding)
+INSERT INTO policies VALUES (:id, :content, :embedding)
"""
),
[
{
"langchain_id": p.langchain_id,
"id": p.id,
"content": p.content,
"metadata": json.dumps(p.metadata),
"embedding": p.embedding,
}
for p in policies
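With the metadata column dropped, a policies lookup rides entirely on vector similarity. A hedged sketch of a query against this schema (the function is hypothetical; <=> is pgvector's cosine-distance operator, and the embedding is passed as text and cast server-side):

    from sqlalchemy import text

    async def search_policies(conn, query_embedding: list[float], top_k: int = 5):
        # Smaller cosine distance means a closer match.
        results = await conn.execute(
            text(
                """
                SELECT id, content
                FROM policies
                ORDER BY embedding <=> CAST(:query_embedding AS vector)
                LIMIT :top_k
                """
            ),
            {"query_embedding": str(query_embedding), "top_k": top_k},
        )
        return results.fetchall()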
4 changes: 2 additions & 2 deletions retrieval_service/datastore/providers/cloudsql_postgres_test.py
@@ -182,8 +182,8 @@ async def test_export_dataset(ds: cloudsql_postgres.Client):
assert diff_flights["columns_removed"] == []

diff_policies = compare(
load_csv(open(policies_ds_path), "langchain_id"),
load_csv(open(policies_new_path), "langchain_id"),
load_csv(open(policies_ds_path), "id"),
load_csv(open(policies_new_path), "id"),
)
assert diff_policies["added"] == []
assert diff_policies["removed"] == []
5 changes: 2 additions & 3 deletions retrieval_service/datastore/providers/firestore.py
@@ -136,11 +136,10 @@ async def delete_collections(collection_list: list[AsyncCollectionReference]):
for policy in policies:
create_policies_tasks.append(
self.__client.collection("policies")
-.document(str(policy.langchain_id))
+.document(str(policy.id))
.set(
{
"content": policy.content,
"metadata": policy.metadata,
"embedding": policy.embedding,
}
)
@@ -182,7 +181,7 @@ async def export_data(
policies = []
async for doc in policies_docs:
policy_dict = doc.to_dict()
policy_dict["langchain_id"] = doc.langchain_id
policy_dict["id"] = doc.id
policies.append(models.Policy.model_validate(policy_dict))
return airports, amenities, flights, policies

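Policy documents are now keyed by the stringified id, with content and embedding as the only fields. A hedged sketch of reading one back, assuming a firestore.AsyncClient and a hypothetical helper name:

    from google.cloud import firestore

    async def get_policy(client: firestore.AsyncClient, policy_id: int) -> dict:
        snapshot = await client.collection("policies").document(str(policy_id)).get()
        policy_dict = snapshot.to_dict()  # {"content": ..., "embedding": ...}
        policy_dict["id"] = snapshot.id  # document ids come back as strings
        return policy_dict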
10 changes: 4 additions & 6 deletions retrieval_service/datastore/providers/postgres.py
@@ -227,23 +227,21 @@ async def initialize_data(
await conn.execute(
"""
CREATE TABLE policies(
-langchain_id INT PRIMARY KEY,
+id INT PRIMARY KEY,
content TEXT NOT NULL,
-metadata JSON,
embedding vector(768) NOT NULL
)
"""
)
# Insert all the data
await conn.executemany(
"""
-INSERT INTO policies VALUES ($1, $2, $3, $4)
+INSERT INTO policies VALUES ($1, $2, $3)
""",
[
(
-p.langchain_id,
+p.id,
p.content,
-p.metadata,
p.embedding,
)
for p in policies
@@ -268,7 +266,7 @@ async def export_data(
self.__pool.fetch("""SELECT * FROM flights ORDER BY id ASC""")
)
policy_task = asyncio.create_task(
-self.__pool.fetch("""SELECT * FROM policies ORDER BY langchain_id ASC""")
+self.__pool.fetch("""SELECT * FROM policies ORDER BY id ASC""")
)

airports = [models.Airport.model_validate(dict(a)) for a in await airport_task]
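The same schema change lands in the plain asyncpg provider; note that it binds positional $1/$2/$3 parameters where the SQLAlchemy-based Cloud SQL client above binds named ones. A hedged single-row insert, assuming pgvector's asyncpg codec has been registered on the connection (pgvector.asyncpg.register_vector) so Python lists map to the vector column:

    await conn.execute(
        """
        INSERT INTO policies VALUES ($1, $2, $3)
        """,
        1,
        "Carry-on bags must fit in the overhead bin or under the seat.",
        [0.0] * 768,  # placeholder vector; real rows carry model embeddings
    )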
4 changes: 2 additions & 2 deletions retrieval_service/datastore/providers/postgres_test.py
@@ -134,8 +134,8 @@ async def test_export_dataset(ds: postgres.Client):
assert diff_flights["columns_removed"] == []

diff_policies = compare(
load_csv(open(policies_ds_path), "langchain_id"),
load_csv(open(policies_new_path), "langchain_id"),
load_csv(open(policies_ds_path), "id"),
load_csv(open(policies_new_path), "id"),
)
assert diff_policies["added"] == []
assert diff_policies["removed"] == []
10 changes: 1 addition & 9 deletions retrieval_service/models/models.py
@@ -109,9 +109,8 @@ class Ticket(BaseModel):


class Policy(BaseModel):
-langchain_id: int
+id: int
content: str
-metadata: str
embedding: Optional[list[float]] = None

@field_validator("embedding", mode="before")
@@ -120,10 +119,3 @@ def validate(cls, v):
v = ast.literal_eval(v)
v = [float(f) for f in v]
return v

@field_validator("metadata", mode="before")
def convert_json_to_string(cls, v):
try:
return json.loads(v)
except:
return json.dumps(v)
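The slimmed-down Policy model accepts rows straight from the dataset CSV, where embeddings arrive as stringified lists and the remaining validator parses them with ast.literal_eval. A hedged example with made-up values:

    policy = Policy.model_validate(
        {
            "id": 1,
            "content": "Economy passengers may check one bag free of charge.",
            "embedding": "[0.1, 0.2, 0.3]",  # CSV string, parsed by the validator
        }
    )
    assert policy.embedding == [0.1, 0.2, 0.3]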
8 changes: 6 additions & 2 deletions retrieval_service/requirements.txt
@@ -1,10 +1,14 @@
asyncpg==0.29.0
fastapi==0.109.2
google-cloud-firestore==2.14.0
-google-cloud-aiplatform==1.41.0
-langchain==0.1.11
+google-cloud-aiplatform==1.44.0
+langchain-core==0.1.33
pgvector==0.2.5
pydantic==2.6.1
uvicorn[standard]==0.27.0.post1
cloud-sql-python-connector==1.6.0
sqlalchemy==2.0.25
+pandas==2.2.1
+pandas-stubs==2.2.1.240316
+langchain-text-splitters==0.0.1
+langchain-google-vertexai==0.1.1
5 changes: 2 additions & 3 deletions retrieval_service/run_generate_embeddings.py
@@ -15,7 +15,7 @@
import asyncio
import csv

-from langchain.embeddings import VertexAIEmbeddings
+from langchain_google_vertexai import VertexAIEmbeddings

import models
from app import EMBEDDING_MODEL_NAME
@@ -77,9 +77,8 @@ async def main() -> None:

with open("../data/cymbalair_policy.csv.new", "w") as f:
col_names = [
"langchain_id",
"id",
"content",
"metadata",
"embedding",
]
writer = csv.DictWriter(f, col_names, delimiter=",")
@@ -1,4 +1,88 @@
-# Cymbal Air: Passenger Policy
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import time
+
+import pandas as pd
+from langchain_google_vertexai import VertexAIEmbeddings
+from langchain_text_splitters import (
+    MarkdownHeaderTextSplitter,
+    RecursiveCharacterTextSplitter,
+)
+
+from app import EMBEDDING_MODEL_NAME
+
+
+def main() -> None:
+    policies_ds_path = "../data/cymbalair_policy.csv"
+
+    chunked = text_split(_POLICY)
+    data_embeddings = vectorize(chunked)
+    data_embeddings.to_csv(policies_ds_path, index=True, index_label="id")
+
+    print("Done generating policy dataset.")
+
+
+def text_split(data):
+    headers_to_split_on = [("#", "Header 1"), ("##", "Header 2")]
+    markdown_splitter = MarkdownHeaderTextSplitter(
+        headers_to_split_on=headers_to_split_on, strip_headers=False
+    )
+    md_header_splits = markdown_splitter.split_text(data)
+
+    text_splitter = RecursiveCharacterTextSplitter(
+        chunk_size=500,
+        chunk_overlap=30,
+        length_function=len,
+    )
+    splits = text_splitter.split_documents(md_header_splits)
+
+    chunked = [{"content": s.page_content} for s in splits]
+    return chunked
+
+
+def vectorize(chunked):
+    embed_service = VertexAIEmbeddings(model_name=EMBEDDING_MODEL_NAME)
+
+    def retry_with_backoff(func, *args, retry_delay=5, backoff_factor=2, **kwargs):
+        max_attempts = 3
+        retries = 0
+        for i in range(max_attempts):
+            try:
+                return func(*args, **kwargs)
+            except Exception as e:
+                print(f"error: {e}")
+                retries += 1
+                wait = retry_delay * (backoff_factor**retries)
+                print(f"Retry after waiting for {wait} seconds...")
+                time.sleep(wait)
+
+    batch_size = 5
+    for i in range(0, len(chunked), batch_size):
+        request = [x["content"] for x in chunked[i : i + batch_size]]
+        response = retry_with_backoff(embed_service.embed_documents, request)
+        # Store the retrieved vector embeddings for each chunk back.
+        for x, e in zip(chunked[i : i + batch_size], response):
+            x["embedding"] = e
+
+    data_embeddings = pd.DataFrame(chunked)
+    data_embeddings.head()
+    return data_embeddings
+
+
+_POLICY = """# Cymbal Air: Passenger Policy
## Ticket Purchase and Changes
Types of Fares: Cymbal Air offers a variety of fares (Economy, Premium Economy, Business Class, and First Class). Fare restrictions, such as change fees and refundability, vary depending on the fare purchased.
@@ -28,3 +112,7 @@
Cymbal Air strives to maintain on-time performance, but disruptions due to weather, mechanical issues, or other events may occur. In the event of delays or cancellations:
Rebooking: We will make reasonable efforts to rebook affected passengers on the next available flight.
Compensation: Compensation for flight delays and cancellations may be provided in certain situations as outlined by our policies and regulations.
"""

if __name__ == "__main__":
main()
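Once the script has run, the generated CSV can be loaded back with the column layout the retrieval service expects. A hedged sketch, assuming embeddings were serialized as stringified lists:

    import ast

    import pandas as pd

    df = pd.read_csv("../data/cymbalair_policy.csv")  # columns: id, content, embedding
    df["embedding"] = df["embedding"].apply(ast.literal_eval)
    print(df.shape)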
