Skip to content

Commit

Permalink
feat: Add "airports" dataset (#8)
Browse files Browse the repository at this point in the history
  • Loading branch information
Yuan325 authored Oct 16, 2023
1 parent d646454 commit e566554
Show file tree
Hide file tree
Showing 7 changed files with 7,763 additions and 10 deletions.
7,699 changes: 7,699 additions & 0 deletions data/airport_dataset.csv

Large diffs are not rendered by default.

9 changes: 7 additions & 2 deletions extension_service/datastore/datastore.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,17 @@ async def create(cls, config: C) -> "Client":

@abstractmethod
async def initialize_data(
self, toys: List[models.Toy], embeddings: List[models.Embedding]
self,
toys: List[models.Toy],
airports: List[models.Airport],
embeddings: List[models.Embedding],
) -> None:
pass

@abstractmethod
async def export_data(self) -> Tuple[List[models.Toy], List[models.Embedding]]:
async def export_data(
self,
) -> Tuple[List[models.Toy], List[models.Airport], List[models.Embedding]]:
pass

@abstractmethod
Expand Down
35 changes: 32 additions & 3 deletions extension_service/datastore/providers/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,10 @@ async def init(conn):
return cls(pool)

async def initialize_data(
self, toys: List[models.Toy], embeddings: List[models.Embedding]
self,
toys: List[models.Toy],
airports: List[models.Airport],
embeddings: List[models.Embedding],
) -> None:
async with self.__pool.acquire() as conn:
# If the table already exists, drop it to avoid conflicts
Expand All @@ -89,6 +92,26 @@ async def initialize_data(
],
)

# If the table already exists, drop it to avoid conflicts
await conn.execute("DROP TABLE IF EXISTS airports CASCADE")
# Create a new table
await conn.execute(
"""
CREATE TABLE airports(
id INT PRIMARY KEY,
iata TEXT,
name TEXT,
city TEXT,
country TEXT
)
"""
)
# Insert all the data
await conn.executemany(
"""INSERT INTO airports VALUES ($1, $2, $3, $4, $5)""",
[(a.id, a.iata, a.name, a.city, a.country) for a in airports],
)

await conn.execute("CREATE EXTENSION IF NOT EXISTS vector")
await conn.execute("DROP TABLE IF EXISTS product_embeddings")
await conn.execute(
Expand All @@ -105,16 +128,22 @@ async def initialize_data(
[(e.product_id, e.content, e.embedding) for e in embeddings],
)

async def export_data(self) -> Tuple[List[models.Toy], List[models.Embedding]]:
async def export_data(
self,
) -> Tuple[List[models.Toy], List[models.Airport], List[models.Embedding]]:
toy_task = asyncio.create_task(self.__pool.fetch("""SELECT * FROM products"""))
airport_task = asyncio.create_task(
self.__pool.fetch("""SELECT * FROM airports""")
)
emb_task = asyncio.create_task(
self.__pool.fetch("""SELECT * FROM product_embeddings""")
)

toys = [models.Toy.model_validate(dict(t)) for t in await toy_task]
airports = [models.Airport.model_validate(dict(a)) for a in await airport_task]
embeddings = [models.Embedding.model_validate(dict(v)) for v in await emb_task]

return toys, embeddings
return toys, airports, embeddings

async def semantic_similarity_search(
self, query_embedding: List[float], similarity_threshold: float, top_k: int
Expand Down
2 changes: 1 addition & 1 deletion extension_service/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .models import Embedding, Toy
from .models import Airport, Embedding, Toy
8 changes: 8 additions & 0 deletions extension_service/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,14 @@ class Toy(BaseModel):
list_price: Decimal


class Airport(BaseModel):
id: int
iata: str
name: str
city: str
country: str


class Embedding(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)

Expand Down
13 changes: 10 additions & 3 deletions extension_service/run_database_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,18 +27,25 @@ async def main():
cfg = parse_config("config.yml")
ds = await datastore.create(cfg.datastore)

toys, embeddings = await ds.export_data()
toys, airports, embeddings = await ds.export_data()

await ds.close()

with open("data/product_dataset.csv.new", "w") as f:
with open("../data/product_dataset.csv.new", "w") as f:
col_names = ["product_id", "product_name", "description", "list_price"]
writer = csv.DictWriter(f, col_names, delimiter=",")
writer.writeheader()
for t in toys:
writer.writerow(t.model_dump())

with open("data/product_embeddings_dataset.csv.new", "w") as f:
with open("../data/airport_dataset.csv.new", "w") as f:
col_names = ["id", "iata", "name", "city", "country"]
writer = csv.DictWriter(f, col_names, delimiter=",")
writer.writeheader()
for a in airports:
writer.writerow(a.model_dump())

with open("../data/product_embeddings_dataset.csv.new", "w") as f:
col_names = ["product_id", "content", "embedding"]
writer = csv.DictWriter(f, col_names, delimiter=",")
writer.writeheader()
Expand Down
7 changes: 6 additions & 1 deletion extension_service/run_database_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,19 @@ async def main() -> None:
reader = csv.DictReader(f, delimiter=",")
toys = [models.Toy.model_validate(line) for line in reader]

airports: List[models.Airport] = []
with open("../data/airport_dataset.csv", "r") as f:
reader = csv.DictReader(f, delimiter=",")
airports = [models.Airport.model_validate(line) for line in reader]

embeddings: List[models.Embedding] = []
with open("../data/product_embeddings_dataset.csv", "r") as f:
reader = csv.DictReader(f, delimiter=",")
embeddings = [models.Embedding.model_validate(line) for line in reader]

cfg = parse_config("config.yml")
ds = await datastore.create(cfg.datastore)
await ds.initialize_data(toys, embeddings)
await ds.initialize_data(toys, airports, embeddings)
await ds.close()


Expand Down

0 comments on commit e566554

Please sign in to comment.