Skip to content

Commit

Permalink
Improve markdown element node parser reliability (#16172)
Browse files — browse the repository at this point in the history
  • Loading branch information
logan-markewich committed Sep 24, 2024
1 parent 7160c0c commit be2db5a
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 14 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import Any, List, Optional
from llama_index.core.bridge.pydantic import BaseModel, SerializeAsAny, ConfigDict
from llama_index.core.bridge.pydantic import SerializeAsAny, ConfigDict
from llama_index.core.base.llms.types import (
ChatMessage,
ChatResponse,
Expand Down Expand Up @@ -69,7 +69,7 @@ class LLMStructuredPredictEndEvent(BaseEvent):
output (BaseModel): Predicted output class.
"""

output: SerializeAsAny[BaseModel]
output: SerializeAsAny[Any]

@classmethod
def class_name(cls) -> str:
Expand All @@ -84,7 +84,7 @@ class LLMStructuredPredictInProgressEvent(BaseEvent):
output (BaseModel): Predicted output class.
"""

output: SerializeAsAny[BaseModel]
output: SerializeAsAny[Any]

@classmethod
def class_name(cls) -> str:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,13 @@
from llama_index.core.callbacks.base import CallbackManager
from llama_index.core.llms.llm import LLM
from llama_index.core.node_parser.interface import NodeParser
from llama_index.core.schema import BaseNode, Document, IndexNode, TextNode
from llama_index.core.schema import (
BaseNode,
Document,
IndexNode,
MetadataMode,
TextNode,
)
from llama_index.core.utils import get_tqdm_iterable

DEFAULT_SUMMARY_QUERY_STR = """\
Expand Down Expand Up @@ -191,7 +197,10 @@ async def _get_table_output(table_context: str, summary_query_str: str) -> Any:
query_engine = index.as_query_engine(llm=llm, output_cls=TableOutput)
try:
response = await query_engine.aquery(summary_query_str)
return cast(PydanticResponse, response).response
if isinstance(response, PydanticResponse):
return response.response
else:
raise ValueError(f"Expected PydanticResponse, got {type(response)}")
except (ValidationError, ValueError):
# There was a pydantic validation error, so we will run with text completion
# fill in the summary and leave other fields blank
Expand Down Expand Up @@ -325,7 +334,7 @@ def get_nodes_from_elements(

node_parser = self.nested_node_parser or SentenceSplitter()

nodes = []
nodes: List[BaseNode] = []
cur_text_el_buffer: List[str] = []
for element in elements:
if element.type == "table" or element.type == "table_text":
Expand Down Expand Up @@ -376,15 +385,17 @@ def get_nodes_from_elements(
# attempt to find start_char_idx for table
# raw table string regardless if perfect or not is stored in element.element

start_char_idx: Optional[int] = None
end_char_idx: Optional[int] = None
if ref_doc_text:
start_char_idx = ref_doc_text.find(str(element.element))
if start_char_idx >= 0:
end_char_idx = start_char_idx + len(str(element.element))
else:
start_char_idx = None
end_char_idx = None
start_char_idx = None # type: ignore
end_char_idx = None # type: ignore
else:
start_char_idx = None # type: ignore
end_char_idx = None # type: ignore

# shared index_id and node_id
node_id = str(uuid.uuid4())
index_node = IndexNode(
Expand Down Expand Up @@ -440,7 +451,11 @@ def get_nodes_from_elements(
node.excluded_llm_metadata_keys = (
node_inherited.excluded_llm_metadata_keys
)
return [node for node in nodes if len(node.get_content()) > 0]
return [
node
for node in nodes
if len(node.get_content(metadata_mode=MetadataMode.NONE)) > 0
]

def __call__(self, nodes: Sequence[BaseNode], **kwargs: Any) -> List[BaseNode]:
nodes = self.get_nodes_from_documents(nodes, **kwargs) # type: ignore
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ def __call__(self, *args: Any, **kwds: Any) -> StructuredRefineResponse:
self._prompt,
**kwds,
)
answer = answer.model_dump_json()
if isinstance(answer, BaseModel):
answer = answer.model_dump_json()
else:
answer = self._llm.predict(
self._prompt,
Expand All @@ -94,7 +95,8 @@ async def acall(self, *args: Any, **kwds: Any) -> StructuredRefineResponse:
self._prompt,
**kwds,
)
answer = answer.model_dump_json()
if isinstance(answer, BaseModel):
answer = answer.model_dump_json()
else:
answer = await self._llm.apredict(
self._prompt,
Expand Down Expand Up @@ -185,7 +187,10 @@ def get_response(
prev_response = response
if isinstance(response, str):
if self._output_cls is not None:
response = self._output_cls.model_validate_json(response)
try:
response = self._output_cls.model_validate_json(response)
except ValidationError:
pass
else:
response = response or "Empty Response"
else:
Expand Down

0 comments on commit be2db5a

Please sign in to comment.