From d00cde60309935e283ba9116cf0b114e53cb9640 Mon Sep 17 00:00:00 2001
From: Marco Vinciguerra
Date: Thu, 23 May 2024 20:03:16 +0200
Subject: [PATCH] fix(pdf_scraper): fix the pdf scraper graph

---
 scrapegraphai/graphs/abstract_graph.py    | 32 ++++++++++++++---------
 scrapegraphai/graphs/pdf_scraper_graph.py | 32 +++++------------------
 2 files changed, 25 insertions(+), 39 deletions(-)

diff --git a/scrapegraphai/graphs/abstract_graph.py b/scrapegraphai/graphs/abstract_graph.py
index 6a0c7a4c..e9ba1213 100644
--- a/scrapegraphai/graphs/abstract_graph.py
+++ b/scrapegraphai/graphs/abstract_graph.py
@@ -181,6 +181,7 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
                 try:
                     self.model_token = models_tokens["ollama"][llm_params["model"]]
                 except KeyError as exc:
+                    print("model not found, using default token size (8192)")
                     self.model_token = 8192
             else:
                 self.model_token = 8192
@@ -191,16 +192,18 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
         elif "hugging_face" in llm_params["model"]:
             try:
                 self.model_token = models_tokens["hugging_face"][llm_params["model"]]
-            except KeyError as exc:
-                raise KeyError("Model not supported") from exc
+            except KeyError:
+                print("model not found, using default token size (8192)")
+                self.model_token = 8192
             return HuggingFace(llm_params)
         elif "groq" in llm_params["model"]:
             llm_params["model"] = llm_params["model"].split("/")[-1]

             try:
                 self.model_token = models_tokens["groq"][llm_params["model"]]
-            except KeyError as exc:
-                raise KeyError("Model not supported") from exc
+            except KeyError:
+                print("model not found, using default token size (8192)")
+                self.model_token = 8192
             return Groq(llm_params)
         elif "bedrock" in llm_params["model"]:
             llm_params["model"] = llm_params["model"].split("/")[-1]
@@ -208,8 +211,9 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
             client = llm_params.get('client', None)
             try:
                 self.model_token = models_tokens["bedrock"][llm_params["model"]]
-            except KeyError as exc:
-                raise KeyError("Model not supported") from exc
+            except KeyError:
+                print("model not found, using default token size (8192)")
+                self.model_token = 8192
             return Bedrock({
                 "client": client,
                 "model_id": model_id,
@@ -218,13 +222,18 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
                 }
             })
         elif "claude-3-" in llm_params["model"]:
-            self.model_token = models_tokens["claude"]["claude3"]
+            try:
+                self.model_token = models_tokens["claude"]["claude3"]
+            except KeyError:
+                print("model not found, using default token size (8192)")
+                self.model_token = 8192
             return Anthropic(llm_params)
         elif "deepseek" in llm_params["model"]:
             try:
                 self.model_token = models_tokens["deepseek"][llm_params["model"]]
-            except KeyError as exc:
-                raise KeyError("Model not supported") from exc
+            except KeyError:
+                print("model not found, using default token size (8192)")
+                self.model_token = 8192
             return DeepSeek(llm_params)
         else:
             raise ValueError(
@@ -312,10 +321,7 @@ def _create_embedder(self, embedder_config: dict) -> object:
                 models_tokens["bedrock"][embedder_config["model"]]
             except KeyError as exc:
                 raise KeyError("Model not supported") from exc
-            return BedrockEmbeddings(client=client, model_id=embedder_config["model"])
-        else:
-            raise ValueError(
-                "Model provided by the configuration not supported")
+        return BedrockEmbeddings(client=client, model_id=embedder_config["model"])

     def get_state(self, key=None) -> dict:
         """""
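Note on the abstract_graph.py hunks above: every provider branch of _create_llm
now degrades gracefully when a model is missing from models_tokens, printing a
warning and defaulting to 8192 tokens instead of raising
KeyError("Model not supported"). A minimal standalone sketch of that fallback
pattern follows; the helper name and the sample dictionary entry are
illustrative, not part of the library, which inlines the try/except in each
branch:

    # Sketch of the lookup-with-fallback pattern this patch applies per provider.
    # `lookup_model_tokens` is a hypothetical helper; the real models_tokens
    # table ships with scrapegraphai and the entry below is an example only.
    DEFAULT_TOKENS = 8192

    models_tokens = {
        "groq": {"llama3-70b-8192": 8192},  # example entry
    }

    def lookup_model_tokens(provider: str, model: str) -> int:
        """Return the known context size for a model, else a safe default."""
        try:
            return models_tokens[provider][model]
        except KeyError:
            # pre-patch behavior: raise KeyError("Model not supported")
            print("model not found, using default token size (8192)")
            return DEFAULT_TOKENS
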
diff --git a/scrapegraphai/graphs/pdf_scraper_graph.py b/scrapegraphai/graphs/pdf_scraper_graph.py
index 86ab2a49..39278ab7 100644
--- a/scrapegraphai/graphs/pdf_scraper_graph.py
+++ b/scrapegraphai/graphs/pdf_scraper_graph.py
@@ -11,7 +11,7 @@
     FetchNode,
     ParseNode,
     RAGNode,
-    GenerateAnswerNode
+    GenerateAnswerPDFNode
 )


@@ -48,7 +48,7 @@ class PDFScraperGraph(AbstractGraph):
     """

     def __init__(self, prompt: str, source: str, config: dict, schema: Optional[str] = None):
-        super().__init__(prompt, config, source, schema)
+        super().__init__(prompt, config, source)

         self.input_key = "pdf" if source.endswith("pdf") else "pdf_dir"

@@ -64,41 +64,21 @@ def _create_graph(self) -> BaseGraph:
             input='pdf | pdf_dir',
             output=["doc", "link_urls", "img_urls"],
         )
-        parse_node = ParseNode(
-            input="doc",
-            output=["parsed_doc"],
-            node_config={
-                "chunk_size": self.model_token,
-            }
-        )
-        rag_node = RAGNode(
-            input="user_prompt & (parsed_doc | doc)",
-            output=["relevant_chunks"],
-            node_config={
-                "llm_model": self.llm_model,
-                "embedder_model": self.embedder_model,
-            }
-        )
-        generate_answer_node = GenerateAnswerNode(
+        generate_answer_node_pdf = GenerateAnswerPDFNode(
             input="user_prompt & (relevant_chunks | parsed_doc | doc)",
             output=["answer"],
             node_config={
                 "llm_model": self.llm_model,
-                "schema": self.schema,
             }
         )

         return BaseGraph(
             nodes=[
                 fetch_node,
-                parse_node,
-                rag_node,
-                generate_answer_node,
+                generate_answer_node_pdf,
             ],
             edges=[
-                (fetch_node, parse_node),
-                (parse_node, rag_node),
-                (rag_node, generate_answer_node)
+                (fetch_node, generate_answer_node_pdf)
             ],
             entry_point=fetch_node
         )
@@ -114,4 +94,4 @@ def run(self) -> str:
         inputs = {"user_prompt": self.prompt, self.input_key: self.source}
         self.final_state, self.execution_info = self.graph.execute(inputs)

-        return self.final_state.get("answer", "No answer found.")
\ No newline at end of file
+        return self.final_state.get("answer", "No answer found.")
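For reference, a hypothetical usage sketch of the simplified pipeline this
patch leaves behind (FetchNode feeding GenerateAnswerPDFNode directly, with
the ParseNode/RAGNode stages removed). The config layout follows
ScrapeGraphAI's usual {"llm": {...}} shape, but the API key, model name, and
file path are placeholders:

    # Assumed usage of PDFScraperGraph after this patch; values are placeholders.
    from scrapegraphai.graphs import PDFScraperGraph

    graph_config = {
        "llm": {
            "api_key": "YOUR_OPENAI_API_KEY",  # placeholder
            "model": "gpt-3.5-turbo",
        },
    }

    # A source ending in "pdf" maps to the "pdf" input key; anything else is
    # treated as a directory of PDFs ("pdf_dir"), per __init__ above.
    pdf_scraper = PDFScraperGraph(
        prompt="Summarize the key findings of this paper.",
        source="example.pdf",
        config=graph_config,
    )

    # run() executes fetch_node -> generate_answer_node_pdf and returns the
    # "answer" key from the final state, or "No answer found."
    print(pdf_scraper.run())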