Skip to content

Commit

Permalink
Merge pull request MLT-OSS#26 from liuooo/fix-retrieval-file-key-hallucination
Browse files Browse the repository at this point in the history

fix file key hallucination in retrieval tool
  • Loading branch information
kense-lab authored Feb 19, 2024
2 parents f541c7e + 7e2fdad commit 9ff7eba
Showing 1 changed file with 16 additions and 11 deletions.
27 changes: 16 additions & 11 deletions app/core/tools/retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@


class RetrievalToolInput(BaseModel):
file_keys: List[str] = Field(..., description="file key list to look up in retrieval")
indexes: List[int] = Field(..., description="file index list to look up in retrieval")
query: str = Field(..., description="query to look up in retrieval")


Expand All @@ -26,35 +26,40 @@ class RetrievalTool(BaseTool):

def __init__(self) -> None:
super().__init__()
self.__files = []
self.__filenames = []
self.__keys = []

def configure(self, session: Session, run: Run, **kwargs):
"""
配置当前 Retrieval 涉及文件信息
"""
self.__files = FileService.get_file_list_by_ids(session=session, file_ids=run.file_ids)
files = FileService.get_file_list_by_ids(session=session, file_ids=run.file_ids)
# pre-cache data to prevent thread conflicts that may occur later on.
for file in files:
self.__filenames.append(file.filename)
self.__keys.append(file.key)

def run(self, file_keys: List[str], query: str) -> dict:
# TODO: 实现真正的 retrieval
def run(self, indexes: List[int], query: str) -> dict:
files = {}
for file_key in file_keys:
for index in indexes:
file_key = self.__keys[index]
file_data = storage.load(file_key)
# 截取前 1000 字符,防止超出 LLM 最大上下文限制
files[file_key] = doc_loader.load(file_data)[:1000]
# 截取前 5000 字符,防止超出 LLM 最大上下文限制
files[file_key] = doc_loader.load(file_data)[:5000]

return files

def instruction_supplement(self) -> str:
"""
为 Retrieval 提供文件选择信息,用于 llm 调用抉择
"""
if len(self.__files) == 0:
if len(self.__filenames) == 0:
return ""
else:
filenames_info = [f"{file.filename}({file.key})" for file in self.__files]
filenames_info = [f"({index}){filename}" for index, filename in enumerate(self.__filenames)]
return (
'You can use the "retrieval" tool to retrieve relevant context from the following attached files. '
+ 'Each line represents a file in the format "filename(file key)":\n'
+ 'Each line represents a file in the format "(index)filename":\n'
+ "\n".join(filenames_info)
+ "\nMake sure to be extremely concise when using attached files. "
)

0 comments on commit 9ff7eba

Please sign in to comment.