Commit

fix(pdf_scraper): fix the pdf scraper graph

VinciGit00 committed May 23, 2024
1 parent 00a392b commit d00cde6
Showing 2 changed files with 25 additions and 39 deletions.
scrapegraphai/graphs/abstract_graph.py (32 changes: 19 additions & 13 deletions)

```diff
@@ -181,6 +181,7 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
                 try:
                     self.model_token = models_tokens["ollama"][llm_params["model"]]
                 except KeyError as exc:
+                    print("model not found, using default token size (8192)")
                     self.model_token = 8192
             else:
                 self.model_token = 8192
@@ -191,25 +192,28 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
         elif "hugging_face" in llm_params["model"]:
             try:
                 self.model_token = models_tokens["hugging_face"][llm_params["model"]]
-            except KeyError as exc:
-                raise KeyError("Model not supported") from exc
+            except KeyError:
+                print("model not found, using default token size (8192)")
+                self.model_token = 8192
             return HuggingFace(llm_params)
         elif "groq" in llm_params["model"]:
             llm_params["model"] = llm_params["model"].split("/")[-1]
 
             try:
                 self.model_token = models_tokens["groq"][llm_params["model"]]
-            except KeyError as exc:
-                raise KeyError("Model not supported") from exc
+            except KeyError:
+                print("model not found, using default token size (8192)")
+                self.model_token = 8192
             return Groq(llm_params)
         elif "bedrock" in llm_params["model"]:
             llm_params["model"] = llm_params["model"].split("/")[-1]
             model_id = llm_params["model"]
             client = llm_params.get('client', None)
             try:
                 self.model_token = models_tokens["bedrock"][llm_params["model"]]
-            except KeyError as exc:
-                raise KeyError("Model not supported") from exc
+            except KeyError:
+                print("model not found, using default token size (8192)")
+                self.model_token = 8192
             return Bedrock({
                 "client": client,
                 "model_id": model_id,
@@ -218,13 +222,18 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
                 }
             })
         elif "claude-3-" in llm_params["model"]:
-            self.model_token = models_tokens["claude"]["claude3"]
+            try:
+                self.model_token = models_tokens["claude"]["claude3"]
+            except KeyError:
+                print("model not found, using default token size (8192)")
+                self.model_token = 8192
             return Anthropic(llm_params)
         elif "deepseek" in llm_params["model"]:
             try:
                 self.model_token = models_tokens["deepseek"][llm_params["model"]]
-            except KeyError as exc:
-                raise KeyError("Model not supported") from exc
+            except KeyError:
+                print("model not found, using default token size (8192)")
+                self.model_token = 8192
             return DeepSeek(llm_params)
         else:
             raise ValueError(
@@ -312,10 +321,7 @@ def _create_embedder(self, embedder_config: dict) -> object:
                 models_tokens["bedrock"][embedder_config["model"]]
             except KeyError as exc:
                 raise KeyError("Model not supported") from exc
-            return BedrockEmbeddings(client=client, model_id=embedder_config["model"])
-        else:
-            raise ValueError(
-                "Model provided by the configuration not supported")
+        return BedrockEmbeddings(client=client, model_id=embedder_config["model"])
 
     def get_state(self, key=None) -> dict:
         """""
```
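
In short, every provider branch in `_create_llm` now falls back to a default context size instead of raising `KeyError` when a model is missing from the `models_tokens` table. A minimal sketch of the pattern, using an illustrative stand-in for the library's real `models_tokens` data:

```python
# Sketch of the fallback behavior introduced in this commit; the table
# below is an illustrative placeholder, not the real models_tokens data.
models_tokens = {"groq": {"llama3-8b-8192": 8192}}

def resolve_model_token(provider: str, model: str, default: int = 8192) -> int:
    """Look up a model's context window, falling back to a default size."""
    try:
        return models_tokens[provider][model]
    except KeyError:
        # Same behavior as the commit: warn and keep going.
        print("model not found, using default token size (8192)")
        return default

resolve_model_token("groq", "llama3-8b-8192")   # -> 8192 (found in the table)
resolve_model_token("groq", "brand-new-model")  # warns, then -> 8192 (default)
```

Note the one branch that keeps the old strictness: `_create_embedder` still raises `KeyError("Model not supported")`, so unknown embedder models fail hard while unknown LLM models now fail soft.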
scrapegraphai/graphs/pdf_scraper_graph.py (32 changes: 6 additions & 26 deletions)

```diff
@@ -11,7 +11,7 @@
     FetchNode,
     ParseNode,
     RAGNode,
-    GenerateAnswerNode
+    GenerateAnswerPDFNode
 )
 
 
@@ -48,7 +48,7 @@ class PDFScraperGraph(AbstractGraph):
     """
 
     def __init__(self, prompt: str, source: str, config: dict, schema: Optional[str] = None):
-        super().__init__(prompt, config, source, schema)
+        super().__init__(prompt, config, source)
 
         self.input_key = "pdf" if source.endswith("pdf") else "pdf_dir"
 
@@ -64,41 +64,21 @@ def _create_graph(self) -> BaseGraph:
             input='pdf | pdf_dir',
             output=["doc", "link_urls", "img_urls"],
         )
-        parse_node = ParseNode(
-            input="doc",
-            output=["parsed_doc"],
-            node_config={
-                "chunk_size": self.model_token,
-            }
-        )
-        rag_node = RAGNode(
-            input="user_prompt & (parsed_doc | doc)",
-            output=["relevant_chunks"],
-            node_config={
-                "llm_model": self.llm_model,
-                "embedder_model": self.embedder_model,
-            }
-        )
-        generate_answer_node = GenerateAnswerNode(
+        generate_answer_node_pdf = GenerateAnswerPDFNode(
             input="user_prompt & (relevant_chunks | parsed_doc | doc)",
             output=["answer"],
             node_config={
                 "llm_model": self.llm_model,
-                "schema": self.schema,
             }
         )
 
         return BaseGraph(
             nodes=[
                 fetch_node,
-                parse_node,
-                rag_node,
-                generate_answer_node,
+                generate_answer_node_pdf,
             ],
             edges=[
-                (fetch_node, parse_node),
-                (parse_node, rag_node),
-                (rag_node, generate_answer_node)
+                (fetch_node, generate_answer_node_pdf)
             ],
             entry_point=fetch_node
         )
@@ -114,4 +94,4 @@ def run(self) -> str:
         inputs = {"user_prompt": self.prompt, self.input_key: self.source}
         self.final_state, self.execution_info = self.graph.execute(inputs)
 
-        return self.final_state.get("answer", "No answer found.")
\ No newline at end of file
+        return self.final_state.get("answer", "No answer found.")
```
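
After this commit the PDF pipeline is a two-node graph: `FetchNode` feeds `GenerateAnswerPDFNode` directly, and the intermediate `ParseNode` and `RAGNode` stages are gone. A usage sketch, under the assumption that the graph is driven like the library's other scraper graphs; the config values below are placeholders, not settings taken from this commit:

```python
# Illustrative driver for the simplified PDF pipeline; the llm config
# values are placeholder assumptions, not tested or recommended settings.
from scrapegraphai.graphs import PDFScraperGraph

graph_config = {
    "llm": {
        "model": "ollama/llama3",  # placeholder model name
        "temperature": 0,
    },
}

pdf_scraper = PDFScraperGraph(
    prompt="List the main conclusions of this document.",
    source="example_paper.pdf",  # a source ending in "pdf" selects the "pdf" input key
    config=graph_config,
)

print(pdf_scraper.run())  # the answer, or "No answer found." if the graph produced none
```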
