From f1d79eacf6c17e0b15044ca70dd51705c5e64706 Mon Sep 17 00:00:00 2001 From: "junjie.jiang" Date: Mon, 10 Jul 2023 15:45:59 +0800 Subject: [PATCH] Add ann search Signed-off-by: junjie.jiang --- .github/workflows/codecov.yml | 2 - Dockerfile | 2 + GPUDockerfile | 2 + tests/unittests/command/test_cmdline.py | 11 +- towhee/doc/source/operator/hub_ops.rst | 5 + towhee/runtime/hub_ops/ann_insert.py | 2 +- towhee/runtime/hub_ops/ann_search.py | 139 ++++++++++++++++++++++ towhee/runtime/hub_ops/operator_parser.py | 7 ++ 8 files changed, 164 insertions(+), 6 deletions(-) create mode 100644 towhee/runtime/hub_ops/ann_search.py diff --git a/.github/workflows/codecov.yml b/.github/workflows/codecov.yml index 5bb4b7aa4a..e8b25c794f 100644 --- a/.github/workflows/codecov.yml +++ b/.github/workflows/codecov.yml @@ -30,9 +30,7 @@ jobs: run: | export TOWHEE_WORKER=True rm -rf ./coverage.xml - pip install coverage pytest pytest-cov pytest-xdist pip install -r test_requirements.txt - apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends -y ffmpeg libsm6 libxext6 coverage erase coverage run -m pytest coverage xml diff --git a/Dockerfile b/Dockerfile index e27ca15ca5..efa8a15c6e 100644 --- a/Dockerfile +++ b/Dockerfile @@ -52,4 +52,6 @@ RUN /opt/conda/bin/conda install -c "${INSTALL_CHANNEL}" -c "${CUDA_CHANNEL}" -y # ut image, build command: # docker build --platform x86_64 --target towhee-ut -t towhee/towhee-ut:latest . 
FROM towhee-conda as towhee-ut +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends -y ffmpeg libsm6 libxext6 +RUN pip install coverage pytest pytest-cov pytest-xdist WORKDIR /workspace \ No newline at end of file diff --git a/GPUDockerfile b/GPUDockerfile index 0961fcc187..8bda46a08c 100644 --- a/GPUDockerfile +++ b/GPUDockerfile @@ -52,4 +52,6 @@ RUN /opt/conda/bin/conda install -c "${INSTALL_CHANNEL}" -c "${CUDA_CHANNEL}" -y # ut image, build command: # docker build --platform x86_64 --target towhee-ut -t towhee/towhee-ut:latest . FROM towhee-conda as towhee-ut +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends -y ffmpeg libsm6 libxext6 +RUN pip install coverage pytest pytest-cov pytest-xdist WORKDIR /workspace \ No newline at end of file diff --git a/tests/unittests/command/test_cmdline.py b/tests/unittests/command/test_cmdline.py index 8b1be651e2..b524e50559 100644 --- a/tests/unittests/command/test_cmdline.py +++ b/tests/unittests/command/test_cmdline.py @@ -32,6 +32,11 @@ FILE_PATH = PUBLIC_PATH.parent.parent / 'towhee' / 'command' / 'cmdline.py' PYTHON_PATH = ':'.join(sys.path) +env = { + 'PYTHONPATH': PYTHON_PATH, + 'HOME': os.environ.get('HOME') +} + class TestCmdline(unittest.TestCase): """ @@ -58,7 +63,7 @@ def test_http_server(self): p = subprocess.Popen( [sys.executable, FILE_PATH, 'server', 'main_test:service', '--host', '0.0.0.0', '--http-port', '40001'], cwd=__file__.rsplit('/', 1)[0], - env={'PYTHONPATH': PYTHON_PATH} + env=env ) time.sleep(2) res = requests.post(url='http://0.0.0.0:40001/echo', data=json.dumps(1), timeout=3).json() @@ -70,7 +75,7 @@ def test_grpc_server(self): p = subprocess.Popen( [sys.executable, FILE_PATH, 'server', 'main_test:service', '--host', '0.0.0.0', '--grpc-port', '50001'], cwd=__file__.rsplit('/', 1)[0], - env={'PYTHONPATH': PYTHON_PATH} + env=env ) time.sleep(2) grpc_client = Client(host='0.0.0.0', port=50001) @@ -92,7 +97,7 @@ 
def test_repo(self): '--param', 'none', 'model_name=resnet34', ], cwd=__file__.rsplit('/', 1)[0], - env={'PYTHONPATH': PYTHON_PATH} + env=env ) while atp < 100: diff --git a/towhee/doc/source/operator/hub_ops.rst b/towhee/doc/source/operator/hub_ops.rst index 562105e32f..4b08e17f0a 100644 --- a/towhee/doc/source/operator/hub_ops.rst +++ b/towhee/doc/source/operator/hub_ops.rst @@ -26,3 +26,8 @@ HubOps :show-inheritance: :member-order: bysource +.. autoclass:: towhee.runtime.hub_ops.ann_search.AnnSearch + :members: + :show-inheritance: + :member-order: bysource + diff --git a/towhee/runtime/hub_ops/ann_insert.py b/towhee/runtime/hub_ops/ann_insert.py index c377947434..90907b2e2b 100644 --- a/towhee/runtime/hub_ops/ann_insert.py +++ b/towhee/runtime/hub_ops/ann_insert.py @@ -128,7 +128,7 @@ class AnnInsert: p = ( pipe.input('collection_name', 'vec') - .map(('collection_name', 'vec'), (), ops.ann_insert.osschat_milvus(host='127.0.0.1', port='19530')) + .map(('collection_name', 'vec'), (), ops.ann_insert.milvus_multi_collections(host='127.0.0.1', port='19530')) .output() ) diff --git a/towhee/runtime/hub_ops/ann_search.py b/towhee/runtime/hub_ops/ann_search.py new file mode 100644 index 0000000000..a0453e9256 --- /dev/null +++ b/towhee/runtime/hub_ops/ann_search.py @@ -0,0 +1,139 @@ +# Copyright 2023 Zilliz. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Any +from towhee.runtime.factory import HubOp + + +class AnnSearch: + """ + The ANN search operator is used to find the closest (or most similar) + point to a given point in a given set, i.e. find similar embeddings. + """ + + faiss_index: HubOp = HubOp('ops.ann_search.faiss') + """ + Only for local test. If you want to use a vector database in a production environment, + you can use Milvus (https://github.com/milvus-io/milvus). + + __init__(self, data_dir: str, top_k: int = 5) + data_dir(`str`): + Path to store data. + top_k(`int`): + top_k similar data + + __call__(self, query: 'ndarray') -> List[Tuple[id: int, score: float, meta: dict]] + query(`ndarray`): + query embedding + + Example: + + .. code-block:: python + + from towhee import pipe, ops + + p = ( + pipe.input('vec') + .flat_map('vec', 'rows', ops.ann_search.faiss_index('./data_dir', 5)) + .map('rows', ('id', 'score'), lambda x: (x[0], x[1])) + .output('id', 'score') + ) + + p() + """ + + milvus_client: HubOp = HubOp('ann_search.milvus_client') + """ + Search embedding in Milvus, please make sure you have inserted data into the Milvus collection. + + __init__(self, host: str = 'localhost', port: int = 19530, collection_name: str = None, + user: str = None, password: str = None, **kwargs) + host(`str`): + The host for Milvus. + port(`str`): + The port for Milvus. + collection_name(`str`): + The collection name for Milvus. + user(`str`): + The user for Zilliz Cloud, defaults to None. + password(`str`): + The password for Zilliz Cloud, defaults to None. + kwargs(`dict`): + The same as for pymilvus search: https://milvus.io/docs/search.md + + __call__(self, query: 'ndarray') -> List[Tuple] + query(`ndarray`): + query embedding + + Example: + + ..
code-block:: python + + from towhee import pipe, ops, DataCollection + + p = ( + pipe.input('text') + .map('text', 'vec', ops.sentence_embedding.transformers(model_name='all-MiniLM-L12-v2')) + .flat_map('vec', 'rows', ops.ann_search.milvus_client(host='127.0.0.1', port='19530', + collection_name='text_db2', **{'output_fields': ['text']})) + .map('rows', ('id', 'score', 'text'), lambda x: (x[0], x[1], x[2])) + .output('id', 'score', 'text') + ) + + DataCollection(p('cat')).show() + + """ + + milvus_multi_collections: HubOp = HubOp('ann_search.osschat_milvus') + """ + `milvus_multi_collections`: a client that can access multiple collections. + + __init__(self, host: str = 'localhost', port: int = 19530, + user: str = None, password: str = None, **kwargs): + host(`str`): + The host for Milvus. + port(`str`): + The port for Milvus. + user(`str`): + The user for Zilliz Cloud, defaults to None. + password(`str`): + The password for Zilliz Cloud, defaults to None. + kwargs(`dict`): + The same as for pymilvus search: https://milvus.io/docs/search.md + + __call__(self, collection_name: str, query: 'ndarray') -> List[Tuple] + collection_name(`str`): + The collection name for Milvus. + query(`ndarray`): + query embedding + + Example: + + ..
code-block:: python + + from towhee import pipe, ops, DataCollection + + p = ( + pipe.input('collection_name', 'text') + .map('text', 'vec', ops.sentence_embedding.transformers(model_name='all-MiniLM-L12-v2')) + .flat_map(('collection_name', 'vec'), 'rows', ops.ann_search.milvus_multi_collections(host='127.0.0.1', port='19530', **{'output_fields': ['text']})) + .map('rows', ('id', 'score', 'text'), lambda x: (x[0], x[1], x[2])) + .output('id', 'score', 'text') + ) + + DataCollection(p('text_db2', 'cat')).show() + """ + + def __call__(self, *args: Any, **kwds: Any) -> Any: + return HubOp('towhee.ann_search')(*args, **kwds) diff --git a/towhee/runtime/hub_ops/operator_parser.py b/towhee/runtime/hub_ops/operator_parser.py index 7d99aa5005..23c1eae099 100644 --- a/towhee/runtime/hub_ops/operator_parser.py +++ b/towhee/runtime/hub_ops/operator_parser.py @@ -18,6 +18,7 @@ from .sentence_embedding import SentenceEmbedding from .data_source import DataSource from .ann_insert import AnnInsert +from .ann_search import AnnSearch class Ops: @@ -73,6 +74,12 @@ class Ops: The ANN Insert Operator is used to insert embeddings and create ANN indexes for fast similarity searches. """ + ann_search: AnnSearch = AnnSearch() + """ + The ANN search operator is used to find the closest (or most similar) + point to a given point in a given set, i.e. find similar embeddings. + """ + @classmethod def __getattr__(cls, name): @ops_parse