-
Notifications
You must be signed in to change notification settings - Fork 4.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: migrate text splitters to Component syntax (#2530)
* feat: migrate text splitters to Component syntax
* [autofix.ci] apply automated fixes

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Gabriel Luiz Freitas Almeida <[email protected]>
- Loading branch information
1 parent
cb81852
commit 86aaab0
Showing 8 changed files with 2,498 additions and 1,965 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
from abc import abstractmethod | ||
from typing import Any | ||
from langchain_text_splitters import TextSplitter | ||
|
||
|
||
from langflow.custom import Component | ||
from langflow.io import Output | ||
from langflow.schema import Data | ||
from langflow.utils.util import build_loader_repr_from_data | ||
|
||
|
||
class LCTextSplitterComponent(Component):
    """Base component for nodes that split documents with a LangChain text splitter.

    Subclasses provide the input through ``get_data_input`` and the configured
    splitter through ``build_text_splitter``. ``split_data`` combines the two
    and backs the single ``data`` output declared in ``outputs``.
    """

    trace_type = "text_splitter"
    outputs = [
        Output(display_name="Data", name="data", method="split_data"),
    ]

    def _validate_outputs(self):
        """Check that the required ``data`` output is declared and has a method.

        Raises:
            ValueError: If no output named ``data`` is declared, or if the
                method that output is bound to is not defined on the component.
        """
        # The required name has to match what ``outputs`` declares; the method
        # is resolved through the output rather than assumed to share its name.
        required_output_names = ["data"]
        methods_by_output = {output.name: output.method for output in self.outputs}
        for output_name in required_output_names:
            if output_name not in methods_by_output:
                raise ValueError(f"Output with name '{output_name}' must be defined.")
            method_name = methods_by_output[output_name]
            if not method_name or not hasattr(self, method_name):
                raise ValueError(f"Method '{method_name}' must be defined.")

    def split_data(self) -> list[Data]:
        """Split the component's input into chunks.

        The value returned by ``get_data_input`` is wrapped in a list when it
        is not one already. ``Data`` items are converted with
        ``to_lc_document``; every other item is passed to the splitter as is.

        Returns:
            The result of ``self.to_data`` applied to the split documents.
        """
        raw_input = self.get_data_input()
        items: list[Any] = raw_input if isinstance(raw_input, list) else [raw_input]
        documents = [
            item.to_lc_document() if isinstance(item, Data) else item
            for item in items
        ]

        splitter = self.build_text_splitter()
        docs = splitter.split_documents(documents)
        data = self.to_data(docs)
        # Keep a representation of the produced chunks on the component.
        self.repr_value = build_loader_repr_from_data(data)
        return data

    # NOTE(review): ``@abstractmethod`` is only enforced at instantiation if
    # ``Component`` uses ``ABCMeta`` - confirm against the base class.
    @abstractmethod
    def get_data_input(self) -> Any:
        """
        Get the data input.
        """

    @abstractmethod
    def build_text_splitter(self) -> TextSplitter:
        """
        Build the text splitter.
        """
58 changes: 46 additions & 12 deletions
58
src/backend/base/langflow/components/textsplitters/CharacterTextSplitter.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
112 changes: 37 additions & 75 deletions
112
src/backend/base/langflow/components/textsplitters/LanguageRecursiveTextSplitter.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,85 +1,47 @@ | ||
from typing import List, Optional | ||
from typing import Any | ||
|
||
from langchain_text_splitters import Language, RecursiveCharacterTextSplitter | ||
from langchain_text_splitters import Language, RecursiveCharacterTextSplitter, TextSplitter | ||
|
||
from langflow.custom import CustomComponent | ||
from langflow.schema import Data | ||
from langflow.base.textsplitters.model import LCTextSplitterComponent | ||
from langflow.inputs import IntInput, DataInput, DropdownInput | ||
|
||
|
||
class LanguageRecursiveTextSplitterComponent(CustomComponent): | ||
class LanguageRecursiveTextSplitterComponent(LCTextSplitterComponent): | ||
display_name: str = "Language Recursive Text Splitter" | ||
description: str = "Split text into chunks of a specified length based on language." | ||
documentation: str = "https://docs.langflow.org/components/text-splitters#languagerecursivetextsplitter" | ||
name = "LanguageRecursiveTextSplitter" | ||
|
||
def build_config(self): | ||
options = [x.value for x in Language] | ||
return { | ||
"inputs": {"display_name": "Input", "input_types": ["Document", "Data"]}, | ||
"separator_type": { | ||
"display_name": "Separator Type", | ||
"info": "The type of separator to use.", | ||
"field_type": "str", | ||
"options": options, | ||
"value": "Python", | ||
}, | ||
"separators": { | ||
"display_name": "Separators", | ||
"info": "The characters to split on.", | ||
"is_list": True, | ||
}, | ||
"chunk_size": { | ||
"display_name": "Chunk Size", | ||
"info": "The maximum length of each chunk.", | ||
"field_type": "int", | ||
"value": 1000, | ||
}, | ||
"chunk_overlap": { | ||
"display_name": "Chunk Overlap", | ||
"info": "The amount of overlap between chunks.", | ||
"field_type": "int", | ||
"value": 200, | ||
}, | ||
"code": {"show": False}, | ||
} | ||
|
||
def build( | ||
self, | ||
inputs: List[Data], | ||
chunk_size: Optional[int] = 1000, | ||
chunk_overlap: Optional[int] = 200, | ||
separator_type: str = "Python", | ||
) -> list[Data]: | ||
""" | ||
Split text into chunks of a specified length. | ||
Args: | ||
separators (list[str]): The characters to split on. | ||
chunk_size (int): The maximum length of each chunk. | ||
chunk_overlap (int): The amount of overlap between chunks. | ||
length_function (function): The function to use to calculate the length of the text. | ||
Returns: | ||
list[str]: The chunks of text. | ||
""" | ||
|
||
# Make sure chunk_size and chunk_overlap are ints | ||
if isinstance(chunk_size, str): | ||
chunk_size = int(chunk_size) | ||
if isinstance(chunk_overlap, str): | ||
chunk_overlap = int(chunk_overlap) | ||
|
||
splitter = RecursiveCharacterTextSplitter.from_language( | ||
language=Language(separator_type), | ||
chunk_size=chunk_size, | ||
chunk_overlap=chunk_overlap, | ||
inputs = [ | ||
IntInput( | ||
name="chunk_size", | ||
display_name="Chunk Size", | ||
info="The maximum length of each chunk.", | ||
value=1000, | ||
), | ||
IntInput( | ||
name="chunk_overlap", | ||
display_name="Chunk Overlap", | ||
info="The amount of overlap between chunks.", | ||
value=200, | ||
), | ||
DataInput( | ||
name="data_input", | ||
display_name="Input", | ||
info="The texts to split.", | ||
input_types=["Document", "Data"], | ||
), | ||
DropdownInput( | ||
name="code_language", display_name="Code Language", options=[x.value for x in Language], value="python" | ||
), | ||
] | ||
|
||
def get_data_input(self) -> Any: | ||
return self.data_input | ||
|
||
def build_text_splitter(self) -> TextSplitter: | ||
return RecursiveCharacterTextSplitter.from_language( | ||
language=Language(self.code_language), | ||
chunk_size=self.chunk_size, | ||
chunk_overlap=self.chunk_overlap, | ||
) | ||
documents = [] | ||
for _input in inputs: | ||
if isinstance(_input, Data): | ||
documents.append(_input.to_lc_document()) | ||
else: | ||
documents.append(_input) | ||
docs = splitter.split_documents(documents) | ||
data = self.to_data(docs) | ||
return data |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.