Skip to content

Commit

Permalink
feat: migrate text splitters to Component syntax (#2530)
Browse files Browse the repository at this point in the history
* feat: migrate text splitters to Component syntax

* [autofix.ci] apply automated fixes

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Gabriel Luiz Freitas Almeida <[email protected]>
  • Loading branch information
3 people authored Jul 4, 2024
1 parent cb81852 commit 86aaab0
Show file tree
Hide file tree
Showing 8 changed files with 2,498 additions and 1,965 deletions.
Empty file.
58 changes: 58 additions & 0 deletions src/backend/base/langflow/base/textsplitters/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from abc import abstractmethod
from typing import Any
from langchain_text_splitters import TextSplitter


from langflow.custom import Component
from langflow.io import Output
from langflow.schema import Data
from langflow.utils.util import build_loader_repr_from_data


class LCTextSplitterComponent(Component):
trace_type = "text_splitter"
outputs = [
Output(display_name="Data", name="data", method="split_data"),
]

def _validate_outputs(self):
required_output_methods = ["text_splitter"]
output_names = [output.name for output in self.outputs]
for method_name in required_output_methods:
if method_name not in output_names:
raise ValueError(f"Output with name '{method_name}' must be defined.")
elif not hasattr(self, method_name):
raise ValueError(f"Method '{method_name}' must be defined.")

def split_data(self) -> list[Data]:
data_input = self.get_data_input()
documents = []

if not isinstance(data_input, list):
data_input: list[Any] = [data_input]

for _input in data_input:
if isinstance(_input, Data):
documents.append(_input.to_lc_document())
else:
documents.append(_input)

splitter = self.build_text_splitter()
docs = splitter.split_documents(documents)
data = self.to_data(docs)
self.repr_value = build_loader_repr_from_data(data)
return data

@abstractmethod
def get_data_input(self) -> Any:
"""
Get the data input.
"""
pass

@abstractmethod
def build_text_splitter(self) -> TextSplitter:
"""
Build the text splitter.
"""
pass
Original file line number Diff line number Diff line change
@@ -1,24 +1,58 @@
from typing import List
from typing import List, Any

from langchain_text_splitters import CharacterTextSplitter
from langchain_text_splitters import CharacterTextSplitter, TextSplitter

from langflow.custom import CustomComponent
from langflow.base.textsplitters.model import LCTextSplitterComponent
from langflow.inputs import IntInput, DataInput, MessageTextInput
from langflow.schema import Data
from langflow.utils.util import unescape_string


class CharacterTextSplitterComponent(CustomComponent):
class CharacterTextSplitterComponent(LCTextSplitterComponent):
display_name = "CharacterTextSplitter"
description = "Splitting text that looks at characters."
description = "Split text by number of characters."
documentation = "https://docs.langflow.org/components/text-splitters#charactertextsplitter"
name = "CharacterTextSplitter"

def build_config(self):
return {
"inputs": {"display_name": "Input", "input_types": ["Document", "Data"]},
"chunk_overlap": {"display_name": "Chunk Overlap", "default": 200},
"chunk_size": {"display_name": "Chunk Size", "default": 1000},
"separator": {"display_name": "Separator", "default": "\n"},
}
inputs = [
IntInput(
name="chunk_size",
display_name="Chunk Size",
info="The maximum length of each chunk.",
value=1000,
),
IntInput(
name="chunk_overlap",
display_name="Chunk Overlap",
info="The amount of overlap between chunks.",
value=200,
),
DataInput(
name="data_input",
display_name="Input",
info="The texts to split.",
input_types=["Document", "Data"],
),
MessageTextInput(
name="separator",
display_name="Separator",
info='The characters to split on.\nIf left empty defaults to "\\n\\n".',
),
]

def get_data_input(self) -> Any:
return self.data_input

def build_text_splitter(self) -> TextSplitter:
if self.separator:
separator = unescape_string(self.separator)
else:
separator = "\n\n"
return CharacterTextSplitter(
chunk_overlap=self.chunk_overlap,
chunk_size=self.chunk_size,
separator=separator,
)

def build(
self,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,85 +1,47 @@
from typing import List, Optional
from typing import Any

from langchain_text_splitters import Language, RecursiveCharacterTextSplitter
from langchain_text_splitters import Language, RecursiveCharacterTextSplitter, TextSplitter

from langflow.custom import CustomComponent
from langflow.schema import Data
from langflow.base.textsplitters.model import LCTextSplitterComponent
from langflow.inputs import IntInput, DataInput, DropdownInput


class LanguageRecursiveTextSplitterComponent(CustomComponent):
class LanguageRecursiveTextSplitterComponent(LCTextSplitterComponent):
display_name: str = "Language Recursive Text Splitter"
description: str = "Split text into chunks of a specified length based on language."
documentation: str = "https://docs.langflow.org/components/text-splitters#languagerecursivetextsplitter"
name = "LanguageRecursiveTextSplitter"

def build_config(self):
options = [x.value for x in Language]
return {
"inputs": {"display_name": "Input", "input_types": ["Document", "Data"]},
"separator_type": {
"display_name": "Separator Type",
"info": "The type of separator to use.",
"field_type": "str",
"options": options,
"value": "Python",
},
"separators": {
"display_name": "Separators",
"info": "The characters to split on.",
"is_list": True,
},
"chunk_size": {
"display_name": "Chunk Size",
"info": "The maximum length of each chunk.",
"field_type": "int",
"value": 1000,
},
"chunk_overlap": {
"display_name": "Chunk Overlap",
"info": "The amount of overlap between chunks.",
"field_type": "int",
"value": 200,
},
"code": {"show": False},
}

def build(
self,
inputs: List[Data],
chunk_size: Optional[int] = 1000,
chunk_overlap: Optional[int] = 200,
separator_type: str = "Python",
) -> list[Data]:
"""
Split text into chunks of a specified length.
Args:
separators (list[str]): The characters to split on.
chunk_size (int): The maximum length of each chunk.
chunk_overlap (int): The amount of overlap between chunks.
length_function (function): The function to use to calculate the length of the text.
Returns:
list[str]: The chunks of text.
"""

# Make sure chunk_size and chunk_overlap are ints
if isinstance(chunk_size, str):
chunk_size = int(chunk_size)
if isinstance(chunk_overlap, str):
chunk_overlap = int(chunk_overlap)

splitter = RecursiveCharacterTextSplitter.from_language(
language=Language(separator_type),
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
inputs = [
IntInput(
name="chunk_size",
display_name="Chunk Size",
info="The maximum length of each chunk.",
value=1000,
),
IntInput(
name="chunk_overlap",
display_name="Chunk Overlap",
info="The amount of overlap between chunks.",
value=200,
),
DataInput(
name="data_input",
display_name="Input",
info="The texts to split.",
input_types=["Document", "Data"],
),
DropdownInput(
name="code_language", display_name="Code Language", options=[x.value for x in Language], value="python"
),
]

def get_data_input(self) -> Any:
return self.data_input

def build_text_splitter(self) -> TextSplitter:
return RecursiveCharacterTextSplitter.from_language(
language=Language(self.code_language),
chunk_size=self.chunk_size,
chunk_overlap=self.chunk_overlap,
)
documents = []
for _input in inputs:
if isinstance(_input, Data):
documents.append(_input.to_lc_document())
else:
documents.append(_input)
docs = splitter.split_documents(documents)
data = self.to_data(docs)
return data
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
from langchain_text_splitters import RecursiveCharacterTextSplitter

from langflow.custom import Component
from typing import Any
from langchain_text_splitters import RecursiveCharacterTextSplitter, TextSplitter
from langflow.base.textsplitters.model import LCTextSplitterComponent
from langflow.inputs.inputs import DataInput, IntInput, MessageTextInput
from langflow.schema import Data
from langflow.template.field.base import Output
from langflow.utils.util import build_loader_repr_from_data, unescape_string
from langflow.utils.util import unescape_string


class RecursiveCharacterTextSplitterComponent(Component):
class RecursiveCharacterTextSplitterComponent(LCTextSplitterComponent):
display_name: str = "Recursive Character Text Splitter"
description: str = "Split text into chunks of a specified length."
description: str = "Split text trying to keep all related text together."
documentation: str = "https://docs.langflow.org/components/text-splitters#recursivecharactertextsplitter"
name = "RecursiveCharacterTextSplitter"

Expand Down Expand Up @@ -39,49 +37,20 @@ class RecursiveCharacterTextSplitterComponent(Component):
is_list=True,
),
]
outputs = [
Output(display_name="Data", name="data", method="split_data"),
]

def split_data(self) -> list[Data]:
"""
Split text into chunks of a specified length.
Args:
separators (list[str] | None): The characters to split on.
chunk_size (int): The maximum length of each chunk.
chunk_overlap (int): The amount of overlap between chunks.

Returns:
list[str]: The chunks of text.
"""
def get_data_input(self) -> Any:
return self.data_input

if self.separators == "":
self.separators: list[str] | None = None
elif self.separators:
def build_text_splitter(self) -> TextSplitter:
if not self.separators:
separators: list[str] | None = None
else:
# check if the separators list has escaped characters
# if there are escaped characters, unescape them
self.separators = [unescape_string(x) for x in self.separators]
separators = [unescape_string(x) for x in self.separators]

# Make sure chunk_size and chunk_overlap are ints
if self.chunk_size:
self.chunk_size: int = int(self.chunk_size)
if self.chunk_overlap:
self.chunk_overlap: int = int(self.chunk_overlap)
splitter = RecursiveCharacterTextSplitter(
separators=self.separators,
return RecursiveCharacterTextSplitter(
separators=separators,
chunk_size=self.chunk_size,
chunk_overlap=self.chunk_overlap,
)
documents = []
if not isinstance(self.data_input, list):
self.data_input: list[Data] = [self.data_input]
for _input in self.data_input:
if isinstance(_input, Data):
documents.append(_input.to_lc_document())
else:
documents.append(_input)
docs = splitter.split_documents(documents)
data = self.to_data(docs)
self.repr_value = build_loader_repr_from_data(data)
return data
Loading

0 comments on commit 86aaab0

Please sign in to comment.