Add query_arrow_stream (#321)
NotSimone authored Mar 24, 2024
1 parent 7e09978 commit 803d572
Showing 4 changed files with 85 additions and 8 deletions.
40 changes: 37 additions & 3 deletions clickhouse_connect/driver/client.py
@@ -19,7 +19,7 @@
 from clickhouse_connect.driver.insert import InsertContext
 from clickhouse_connect.driver.summary import QuerySummary
 from clickhouse_connect.driver.models import ColumnDef, SettingDef, SettingStatus
-from clickhouse_connect.driver.query import QueryResult, to_arrow, QueryContext, arrow_buffer, quote_identifier
+from clickhouse_connect.driver.query import QueryResult, to_arrow, to_arrow_batches, QueryContext, arrow_buffer, quote_identifier
 
 io.DEFAULT_BUFFER_SIZE = 1024 * 256
 logger = logging.getLogger(__name__)
@@ -255,7 +255,8 @@ def raw_query(self, query: str,
                   settings: Optional[Dict[str, Any]] = None,
                   fmt: str = None,
                   use_database: bool = True,
-                  external_data: Optional[ExternalData] = None) -> bytes:
+                  external_data: Optional[ExternalData] = None,
+                  stream: bool = False) -> Union[bytes, io.IOBase]:
         """
         Query method that simply returns the raw ClickHouse format bytes
         :param query: Query statement/format string
@@ -348,7 +349,7 @@ def query_df_stream(self,
         """
         Query method that returns the results as a StreamContext. For parameter values, see the
         create_query_context method
-        :return: Pandas dataframe representing the result set
+        :return: Generator that yields a Pandas dataframe per block representing the result set
         """
         return self._context_query(locals(), use_numpy=True,
                                    as_pandas=True,
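
The corrected :return: doc changes the caller's mental model: results arrive block by block rather than as a single dataframe. A minimal usage sketch of that contract, assuming a client from clickhouse_connect.get_client against a local server (the host and query are illustrative, not part of this diff):

import clickhouse_connect

client = clickhouse_connect.get_client(host='localhost')  # assumed local server
df_stream = client.query_df_stream('SELECT number FROM system.numbers LIMIT 1000000')
with df_stream:                 # the StreamContext closes the HTTP response on exit
    for df in df_stream:        # one pandas DataFrame per ClickHouse block
        print(len(df))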
@@ -479,6 +480,39 @@ def query_arrow(self,
                                        fmt='Arrow',
                                        external_data=external_data))
 
+    def query_arrow_stream(self,
+                           query: str,
+                           parameters: Optional[Union[Sequence, Dict[str, Any]]] = None,
+                           settings: Optional[Dict[str, Any]] = None,
+                           use_strings: Optional[bool] = None,
+                           external_data: Optional[ExternalData] = None) -> StreamContext:
+        """
+        Query method that returns the results as a stream of Arrow tables
+        :param query: Query statement/format string
+        :param parameters: Optional dictionary used to format the query
+        :param settings: Optional dictionary of ClickHouse settings (key/string values)
+        :param use_strings: Convert ClickHouse String type to Arrow string type (instead of binary)
+        :param external_data: ClickHouse "external data" to send with query
+        :return: Generator that yields a PyArrow.Table per block representing the result set
"""
settings = dict_copy(settings)
if self.database:
settings['database'] = self.database
str_status = self._setting_status(arrow_str_setting)
if use_strings is None:
if str_status.is_writable and not str_status.is_set:
settings[arrow_str_setting] = '1' # Default to returning strings if possible
elif use_strings != str_status.is_set:
if not str_status.is_writable:
raise OperationalError(f'Cannot change readonly {arrow_str_setting} to {use_strings}')
settings[arrow_str_setting] = '1' if use_strings else '0'
return to_arrow_batches(self.raw_query(query,
parameters,
settings,
fmt='ArrowStream',
external_data=external_data,
stream=True))

@abstractmethod
def command(self,
cmd: str,
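A quick usage sketch for the new method, reusing the assumed client from the earlier example; it mirrors the pattern in the integration test added below, where the StreamContext both manages the underlying HTTP response and yields one Arrow table per ClickHouse block:

stream = client.query_arrow_stream('SELECT number FROM system.numbers LIMIT 1000000')
with stream:                        # closes the underlying HTTP response on exit
    for arrow_table in stream:      # one pyarrow.Table per streamed block
        print(arrow_table.num_rows)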
10 changes: 7 additions & 3 deletions clickhouse_connect/driver/httpclient.py
@@ -449,8 +449,11 @@ def ping(self):
 
     def raw_query(self, query: str,
                   parameters: Optional[Union[Sequence, Dict[str, Any]]] = None,
-                  settings: Optional[Dict[str, Any]] = None, fmt: str = None,
-                  use_database: bool = True, external_data: Optional[ExternalData] = None) -> bytes:
+                  settings: Optional[Dict[str, Any]] = None,
+                  fmt: str = None,
+                  use_database: bool = True,
+                  external_data: Optional[ExternalData] = None,
+                  stream: bool = False) -> Union[bytes, HTTPResponse]:
         """
         See BaseClient doc_string for this method
         """
@@ -469,7 +472,8 @@ def raw_query(self, query: str,
         else:
             body = final_query
             fields = None
-        return self._raw_request(body, params, fields=fields).data
+        response = self._raw_request(body, params, fields=fields, stream=stream)
+        return response if stream else response.data
 
     def close(self):
         if self._owns_pool_manager:
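To trace the plumbing end to end: with stream=True, raw_query now hands back the live urllib3 HTTPResponse (a file-like object) instead of its fully-read body, and to_arrow_batches, added to query.py below, wraps that object in an Arrow IPC reader. A hedged sketch of that flow, reusing the assumed client from above and a hypothetical table name:

from clickhouse_connect.driver.query import to_arrow_batches

raw_stream = client.raw_query('SELECT * FROM some_table',  # hypothetical table
                              fmt='ArrowStream', stream=True)
batches = to_arrow_batches(raw_stream)  # StreamContext over a pyarrow ipc reader
with batches:
    for table in batches:
        ...  # each item is a pyarrow.Table for one streamed block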
7 changes: 7 additions & 0 deletions clickhouse_connect/driver/query.py
@@ -5,6 +5,7 @@
 import pytz
 
 from enum import Enum
+from io import IOBase
 from typing import Any, Tuple, Dict, Sequence, Optional, Union, Generator
 from datetime import date, datetime, tzinfo
 
@@ -489,6 +490,12 @@ def to_arrow(content: bytes):
     return reader.read_all()
 
 
+def to_arrow_batches(buffer: IOBase) -> StreamContext:
+    pyarrow = check_arrow()
+    reader = pyarrow.ipc.open_stream(buffer)
+    return StreamContext(buffer, reader)
+
+
 def arrow_buffer(table) -> Tuple[Sequence[str], bytes]:
     pyarrow = check_arrow()
     sink = pyarrow.BufferOutputStream()
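For reference, a self-contained sketch of the Arrow IPC stream format that to_arrow_batches builds on, using only public pyarrow APIs (no ClickHouse required):

import pyarrow as pa

# Write a table as a stream of record batches, then read it back the same way
# to_arrow_batches does: pyarrow.ipc.open_stream over a file-like source
table = pa.table({'counter': pa.array(range(10), type=pa.int64())})
sink = pa.BufferOutputStream()
with pa.ipc.new_stream(sink, table.schema) as writer:
    for batch in table.to_batches(max_chunksize=4):
        writer.write_batch(batch)

reader = pa.ipc.open_stream(sink.getvalue())
assert reader.read_all() == table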
36 changes: 34 additions & 2 deletions tests/integration_tests/test_arrow.py
@@ -1,5 +1,6 @@
 from datetime import date
 from typing import Callable
+import string
 
 import pytest
 
@@ -21,8 +22,9 @@ def test_arrow(test_client: Client, table_context: Callable):
         result_table = test_client.query_arrow('SELECT * FROM test_arrow_insert', use_strings=False)
         arrow_schema = result_table.schema
         assert arrow_schema.field(0).name == 'animal'
-        assert arrow_schema.field(0).type.id == 14
-        assert arrow_schema.field(1).type.bit_width == 64
+        assert arrow_schema.field(0).type == arrow.binary()
+        assert arrow_schema.field(1).name == 'legs'
+        assert arrow_schema.field(1).type == arrow.int64()
         # pylint: disable=no-member
         assert arrow.compute.sum(result_table['legs']).as_py() == 111
         assert len(result_table.columns) == 2
@@ -35,6 +37,36 @@ def test_arrow(test_client: Client, table_context: Callable):
         assert arrow_table.num_rows == 500
 
 
+def test_arrow_stream(test_client: Client, table_context: Callable):
+    if not arrow:
+        pytest.skip('PyArrow package not available')
+    if not test_client.min_version('21'):
+        pytest.skip(f'PyArrow is not supported in this server version {test_client.server_version}')
+    with table_context('test_arrow_insert', ['counter Int64', 'letter String']):
+        counter = arrow.array(range(1000000))
+        alphabet = string.ascii_lowercase
+        letter = arrow.array([alphabet[x % 26] for x in range(1000000)])
+        names = ['counter', 'letter']
+        insert_table = arrow.Table.from_arrays([counter, letter], names=names)
+        test_client.insert_arrow('test_arrow_insert', insert_table)
+        stream = test_client.query_arrow_stream('SELECT * FROM test_arrow_insert', use_strings=True)
+        with stream:
+            result_tables = list(stream)
+        # Hopefully the table is long enough that the result comes back as multiple tables
+        assert len(result_tables) > 1
+        total_rows = 0
+        for table in result_tables:
+            assert table.num_columns == 2
+            arrow_schema = table.schema
+            assert arrow_schema.field(0).name == 'counter'
+            assert arrow_schema.field(0).type == arrow.int64()
+            assert arrow_schema.field(1).name == 'letter'
+            assert arrow_schema.field(1).type == arrow.string()
+            assert table.column(1)[0].as_py() == alphabet[table.column(0)[0].as_py() % 26]
+            total_rows += table.num_rows
+        assert total_rows == 1000000
+
+
 def test_arrow_map(test_client: Client, table_context: Callable):
     if not arrow:
         pytest.skip('PyArrow package not available')
