Skip to content

Commit

Permalink
add search
Browse files Browse the repository at this point in the history
  • Loading branch information
julesbarbosa committed Aug 14, 2024
1 parent 7301a4f commit f61213e
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 1 deletion.
12 changes: 12 additions & 0 deletions client/src/lib/Api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,18 @@ export async function similarSearch(imagePath: string, limit: number, excludeLab
return response;
}

export async function TextSearch(queryStr: string, limit: number, excludeLabeled: boolean): Promise<Hits> {
const response = await fetchJSON<Hits>("/text_search", {
"q": queryStr,
"exclude_labeled": excludeLabeled,
limit: limit.toString(),
}).catch((e) => {
console.error(e);
throw new Error("Failed to retrieve text search results.", { cause: e })
});
return response;
}

export async function random(limit: number): Promise<Hits> {
const response = await fetchJSON<Hits>("/random", {
limit: limit.toString(),
Expand Down
5 changes: 4 additions & 1 deletion mmdx/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,9 +138,12 @@ def search_by_image_path(
)
return df_hits

def search_by_text(
def search_by_seller(
self, query_string: str, limit: int, exclude_labeled: bool
) -> pd.DataFrame:
original_path = os.environ.get("CSV_PATH")
original_df = pd.read_csv(original_path)

query_str_embedding = self.model.embed_text(query=query_string)
df_hits = self.__vector_embedding_search(
query_str_embedding, limit, exclude_labeled
Expand Down
8 changes: 8 additions & 0 deletions server.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,14 @@ def image_search():
return {"total": len(hits.index), "hits": hits.to_dict("records")}


@app.route("/api/v1/text_search")
def text_search():
query: str = request.args.get("q")
exclude_labeled: bool = request.args.get("exclude_labeled", "false") == "true"
limit: int = request.args.get("limit", 12, type=int)
hits = db.search_by_seller(query_string=query, limit=limit, exclude_labeled=exclude_labeled)
return {"total": len(hits.index), "hits": hits.to_dict("records")}

@app.route("/api/v1/labeled")
def labeled_search():
limit: int = request.args.get("limit", 12, type=int)
Expand Down

0 comments on commit f61213e

Please sign in to comment.