tools.py
# Module to parse user-uploaded documents, chunk them, embed them into a FAISS
# index, and answer questions over them with OpenAI.
import re
from io import BytesIO
from typing import Any, Dict, List, Union
import docx2txt
import streamlit as st
from embeddings import OpenAIEmbeddings
from langchain.chains.qa_with_sources import load_qa_with_sources_chain
from langchain.docstore.document import Document
from langchain.llms import OpenAI
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import VectorStore
from langchain.vectorstores.faiss import FAISS
from openai.error import AuthenticationError
from prompts import STUFF_PROMPT
from pypdf import PdfReader

OPENAI_API_KEY = st.secrets["pass"]


@st.cache_data
def parse_docx(file: BytesIO) -> str:
    text = docx2txt.process(file)
    text = re.sub(r"\n\s*\n", "\n\n", text)
    return text


@st.cache_data
def parse_pdf(file: BytesIO) -> List[str]:
    pdf = PdfReader(file)
    output = []
    for page in pdf.pages:
        text = page.extract_text()
        # Merge words hyphenated across line breaks
        text = re.sub(r"(\w+)-\n(\w+)", r"\1\2", text)
        # Replace single newlines with spaces, keeping paragraph breaks
        text = re.sub(r"(?<!\n\s)\n(?!\s\n)", " ", text.strip())
        # Collapse runs of blank lines into a single blank line
        text = re.sub(r"\n\s*\n", "\n\n", text)
        output.append(text)
    return output


@st.cache_data
def parse_txt(file: BytesIO) -> str:
    text = file.read().decode("utf-8")
    text = re.sub(r"\n\s*\n", "\n\n", text)
    return text


@st.cache_data
def text_to_docs(text: Union[str, List[str]]) -> List[Document]:
    """Converts a string or list of page strings into a list of chunked Documents."""
    if isinstance(text, str):
        text = [text]
    page_docs = [Document(page_content=page) for page in text]

    # Add page numbers as metadata
    for i, doc in enumerate(page_docs):
        doc.metadata["page"] = i + 1

    # Split pages into chunks
    doc_chunks = []
    for doc in page_docs:
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=800,
            separators=["\n\n", "\n", ".", "!", "?", ",", " ", ""],
            chunk_overlap=0,
        )
        chunks = text_splitter.split_text(doc.page_content)
        for i, chunk in enumerate(chunks):
            chunk_doc = Document(
                page_content=chunk,
                metadata={"page": doc.metadata["page"], "chunk": i},
            )
            # Add sources as metadata
            chunk_doc.metadata["source"] = (
                f"{chunk_doc.metadata['page']}-{chunk_doc.metadata['chunk']}"
            )
            doc_chunks.append(chunk_doc)
    return doc_chunks
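
# Illustrative example of the metadata produced by text_to_docs (values made up):
# the second chunk of page 3 carries {"page": 3, "chunk": 1, "source": "3-1"};
# get_answer() and get_sources() below rely on this "page-chunk" source key.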


# @st.cache_data
def embed_docs(docs: List[Document]) -> VectorStore:
    """Embeds a list of Documents and returns a FAISS index."""
    if not st.session_state.get("OPENAI_API_KEY"):
        raise AuthenticationError("Invalid OpenAI API key!")
    else:
        # Embed the chunks
        embeddings = OpenAIEmbeddings(
            openai_api_key=st.session_state.get("OPENAI_API_KEY")
        )  # type: ignore
        index = FAISS.from_documents(docs, embeddings)
        return index


# @st.cache_data
def search_docs(index: VectorStore, query: str) -> List[Document]:
    """Searches a FAISS index for chunks similar to the query
    and returns a list of Documents."""
    # Search for similar chunks
    docs = index.similarity_search(query, k=5)
    return docs


# @st.cache_data
def get_answer(docs: List[Document], query: str) -> Dict[str, Any]:
    """Gets an answer to a question from a list of Documents."""
    chain = load_qa_with_sources_chain(
        OpenAI(
            temperature=0, openai_api_key=st.session_state.get("OPENAI_API_KEY")
        ),
        chain_type="stuff",
        prompt=STUFF_PROMPT,
    )
    answer = chain(
        {"input_documents": docs, "question": query}, return_only_outputs=True
    )
    return answer
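
# Illustrative shape of the returned dict, assuming STUFF_PROMPT instructs the
# model to append the source keys it used (the answer text here is made up):
#   {"output_text": "The report covers Q3 revenue. SOURCES: 3-1, 5-0"}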


# @st.cache_data
def get_sources(answer: Dict[str, Any], docs: List[Document]) -> List[Document]:
    """Gets the source documents for an answer."""
    source_keys = answer["output_text"].split("SOURCES: ")[-1].split(", ")
    source_docs = []
    for doc in docs:
        if doc.metadata["source"] in source_keys:
            source_docs.append(doc)
    return source_docs
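

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original app flow): how the helpers
# above chain together. The file path and query below are hypothetical, and
# the OpenAI key is taken from the same Streamlit secret used at the top.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    st.session_state["OPENAI_API_KEY"] = OPENAI_API_KEY

    # Hypothetical input file; parse_docx/parse_txt work the same way for
    # .docx and .txt uploads.
    with open("example.pdf", "rb") as f:
        pages = parse_pdf(BytesIO(f.read()))

    docs = text_to_docs(pages)               # chunked Documents with page/chunk/source metadata
    index = embed_docs(docs)                 # FAISS index over OpenAI embeddings
    query = "What is this document about?"
    matches = search_docs(index, query)      # top-5 most similar chunks
    answer = get_answer(matches, query)      # {"output_text": "... SOURCES: ..."}
    sources = get_sources(answer, matches)   # Documents cited in the answer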