Skip to content

Commit

Permalink
#469 Support Complex type -- part 1 (#471)
Browse files Browse the repository at this point in the history
  • Loading branch information
sitingren authored Nov 11, 2022
1 parent b66ad1a commit dfbd90a
Show file tree
Hide file tree
Showing 9 changed files with 169 additions and 58 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ['2.7', '3.7', '3.8', '3.9', '3.10', 'pypy-3.9']
python-version: ['2.7', '3.7', '3.8', '3.9', '3.10', 'pypy3.9']

steps:
- name: Check out repository
uses: actions/checkout@v2
uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Set up a Vertica server
Expand Down
4 changes: 2 additions & 2 deletions vertica_python/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@
version_info = (1, 1, 1)
__version__ = '.'.join(map(str, version_info))

# The protocol version (3.9) implemented in this library.
PROTOCOL_VERSION = 3 << 16 | 9
# The protocol version (3.12) implemented in this library.
PROTOCOL_VERSION = 3 << 16 | 12

apilevel = 2.0
threadsafety = 1 # Threads may share the module, but not connections!
Expand Down
138 changes: 105 additions & 33 deletions vertica_python/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,49 @@ class VerticaType(object):
LONGVARBINARY = 116
BINARY = 117

ROW = 300
ARRAY = 301 # multidimensional
MAP = 302

# one-dimensional array of a primitive type
ARRAY1D_BOOL = 1505
ARRAY1D_INT8 = 1506
ARRAY1D_FLOAT8 = 1507
ARRAY1D_CHAR = 1508
ARRAY1D_VARCHAR = 1509
ARRAY1D_DATE = 1510
ARRAY1D_TIME = 1511
ARRAY1D_TIMESTAMP = 1512
ARRAY1D_TIMESTAMPTZ = 1513
ARRAY1D_INTERVAL = 1514
ARRAY1D_INTERVALYM = 1521
ARRAY1D_TIMETZ = 1515
ARRAY1D_NUMERIC = 1516
ARRAY1D_VARBINARY = 1517
ARRAY1D_UUID = 1520
ARRAY1D_BINARY = 1522
ARRAY1D_LONGVARCHAR = 1519
ARRAY1D_LONGVARBINARY = 1518

SET_BOOL = 2705
SET_INT8 = 2706
SET_FLOAT8 = 2707
SET_CHAR = 2708
SET_VARCHAR = 2709
SET_DATE = 2710
SET_TIME = 2711
SET_TIMESTAMP = 2712
SET_TIMESTAMPTZ = 2713
SET_INTERVAL = 2714
SET_INTERVALYM = 2721
SET_TIMETZ = 2715
SET_NUMERIC = 2716
SET_VARBINARY = 2717
SET_UUID = 2720
SET_BINARY = 2722
SET_LONGVARCHAR = 2719
SET_LONGVARBINARY = 2718

def __init__(self, *values):
self.values = values

Expand Down Expand Up @@ -159,44 +202,73 @@ def __ne__(self, other):
INTERVAL_MASK_HOUR2SEC = INTERVAL_MASK_HOUR | INTERVAL_MASK_MINUTE | INTERVAL_MASK_SECOND
INTERVAL_MASK_MIN2SEC = INTERVAL_MASK_MINUTE | INTERVAL_MASK_SECOND

TYPENAME = {
VerticaType.UNKNOWN: "Unknown",
VerticaType.BOOL: "Boolean",
VerticaType.INT8: "Integer",
VerticaType.FLOAT8: "Float",
VerticaType.CHAR: "Char",
VerticaType.VARCHAR: "Varchar",
VerticaType.LONGVARCHAR: "Long Varchar",
VerticaType.DATE: "Date",
VerticaType.TIME: "Time",
VerticaType.TIMETZ: "TimeTz",
VerticaType.TIMESTAMP: "Timestamp",
VerticaType.TIMESTAMPTZ: "TimestampTz",
VerticaType.BINARY: "Binary",
VerticaType.VARBINARY: "Varbinary",
VerticaType.LONGVARBINARY: "Long Varbinary",
VerticaType.NUMERIC: "Numeric",
VerticaType.UUID: "Uuid",
VerticaType.ROW: "Row",
VerticaType.ARRAY: "Array",
VerticaType.MAP: "Map",
VerticaType.ARRAY1D_BOOL: "Array[Boolean]",
VerticaType.ARRAY1D_INT8: "Array[Int8]",
VerticaType.ARRAY1D_FLOAT8: "Array[Float8]",
VerticaType.ARRAY1D_CHAR: "Array[Char]",
VerticaType.ARRAY1D_VARCHAR: "Array[Varchar]",
VerticaType.ARRAY1D_DATE: "Array[Date]",
VerticaType.ARRAY1D_TIME: "Array[Time]",
VerticaType.ARRAY1D_TIMESTAMP: "Array[Timestamp]",
VerticaType.ARRAY1D_TIMESTAMPTZ: "Array[TimestampTz]",
VerticaType.ARRAY1D_TIMETZ: "Array[TimeTz]",
VerticaType.ARRAY1D_NUMERIC: "Array[Numeric]",
VerticaType.ARRAY1D_VARBINARY: "Array[Varbinary]",
VerticaType.ARRAY1D_UUID: "Array[Uuid]",
VerticaType.ARRAY1D_BINARY: "Array[Binary]",
VerticaType.ARRAY1D_LONGVARCHAR: "Array[Long Varchar]",
VerticaType.ARRAY1D_LONGVARBINARY: "Array[Long Varbinary]",
VerticaType.SET_BOOL: "Set[Boolean]",
VerticaType.SET_INT8: "Set[Int8]",
VerticaType.SET_FLOAT8: "Set[Float8]",
VerticaType.SET_CHAR: "Set[Char]",
VerticaType.SET_VARCHAR: "Set[Varchar]",
VerticaType.SET_DATE: "Set[Date]",
VerticaType.SET_TIME: "Set[Time]",
VerticaType.SET_TIMESTAMP: "Set[Timestamp]",
VerticaType.SET_TIMESTAMPTZ: "Set[TimestampTz]",
VerticaType.SET_TIMETZ: "Set[TimeTz]",
VerticaType.SET_NUMERIC: "Set[Numeric]",
VerticaType.SET_VARBINARY: "Set[Varbinary]",
VerticaType.SET_UUID: "Set[Uuid]",
VerticaType.SET_BINARY: "Set[Binary]",
VerticaType.SET_LONGVARCHAR: "Set[Long Varchar]",
VerticaType.SET_LONGVARBINARY: "Set[Long Varbinary]",
}

def getTypeName(data_type_oid, type_modifier):
"""Returns the base type name according to data_type_oid and type_modifier"""

if data_type_oid == VerticaType.BOOL:
return "Boolean"
elif data_type_oid == VerticaType.INT8:
return "Integer"
elif data_type_oid == VerticaType.FLOAT8:
return "Float"
elif data_type_oid == VerticaType.CHAR:
return "Char"
elif data_type_oid in (VerticaType.VARCHAR, VerticaType.UNKNOWN):
return "Varchar"
elif data_type_oid == VerticaType.LONGVARCHAR:
return "Long Varchar"
elif data_type_oid == VerticaType.DATE:
return "Date"
elif data_type_oid == VerticaType.TIME:
return "Time"
elif data_type_oid == VerticaType.TIMETZ:
return "TimeTz"
elif data_type_oid == VerticaType.TIMESTAMP:
return "Timestamp"
elif data_type_oid == VerticaType.TIMESTAMPTZ:
return "TimestampTz"
if data_type_oid in TYPENAME:
return TYPENAME[data_type_oid]
elif data_type_oid in (VerticaType.INTERVAL, VerticaType.INTERVALYM):
return "Interval " + getIntervalRange(data_type_oid, type_modifier)
elif data_type_oid == VerticaType.BINARY:
return "Binary"
elif data_type_oid == VerticaType.VARBINARY:
return "Varbinary"
elif data_type_oid == VerticaType.LONGVARBINARY:
return "Long Varbinary"
elif data_type_oid == VerticaType.NUMERIC:
return "Numeric"
elif data_type_oid == VerticaType.UUID:
return "Uuid"
elif data_type_oid in (VerticaType.ARRAY1D_INTERVAL, VerticaType.ARRAY1D_INTERVALYM):
# TODO IntervalRange
return "Array[Interval]"
elif data_type_oid in (VerticaType.SET_INTERVAL, VerticaType.SET_INTERVALYM):
# TODO IntervalRange
return "Set[Interval]"
else:
return "Unknown"

Expand Down
8 changes: 8 additions & 0 deletions vertica_python/vertica/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,17 @@ def __init__(self, col):
self.null_ok = col['null_ok']
self.is_identity = col['is_identity']
self.format_code = col['format_code']
self.child_columns = []
self.props = ColumnTuple(self.name, self.type_code, self.display_size, self.internal_size,
self.precision, self.scale, self.null_ok)

def add_child_column(self, col):
"""
Complex types involve multiple columns arranged in a hierarchy of parents and children.
Each parent column stores references to child columns in a list.
"""
self.child_columns.append(col)

def __str__(self):
return as_str(str(self.props))

Expand Down
5 changes: 5 additions & 0 deletions vertica_python/vertica/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,9 @@ def __init__(self, options=None):
self.options['host'], self.options['port']))
self.startup_connection()

# Complex types metadata is returned since protocol version 3.12
self.complex_types_enabled = self.parameters['protocol_version'] >= (3 << 16 | 12) and \
self.parameters.get('request_complex_types', 'off') == 'on'
self._logger.info('Connection is ready')

#############################################
Expand Down Expand Up @@ -672,6 +675,8 @@ def read_message(self):
else:
# The rest of the message is read later with write_to_disk()
message = messages.WriteFile(filename, file_length)
elif type_ == messages.RowDescription.message_id:
message = BackendMessage.from_type(type_, self.read_bytes(size - 4), complex_types_enabled=self.complex_types_enabled)
else:
message = BackendMessage.from_type(type_, self.read_bytes(size - 4))
self._logger.debug('<= %s', message)
Expand Down
8 changes: 4 additions & 4 deletions vertica_python/vertica/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ def fetchone(self):
self._message = self.connection.read_message()
return row
elif isinstance(self._message, messages.RowDescription):
self.description = [Column(fd) for fd in self._message.fields]
self.description = self._message.get_description()
self._deserializers = self._des.get_row_deserializers(self.description,
{'unicode_error': self.unicode_error,
'session_tz': self.connection.parameters.get('timezone', 'unknown')})
Expand Down Expand Up @@ -366,7 +366,7 @@ def nextset(self):
# there might be another set, read next message to find out
self._message = self.connection.read_message()
if isinstance(self._message, messages.RowDescription):
self.description = [Column(fd) for fd in self._message.fields]
self.description = self._message.get_description()
self._deserializers = self._des.get_row_deserializers(self.description,
{'unicode_error': self.unicode_error,
'session_tz': self.connection.parameters.get('timezone', 'unknown')})
Expand Down Expand Up @@ -657,7 +657,7 @@ def _execute_simple_query(self, query):
if isinstance(self._message, messages.ErrorResponse):
raise errors.QueryError.from_error_response(self._message, query)
elif isinstance(self._message, messages.RowDescription):
self.description = [Column(fd) for fd in self._message.fields]
self.description = self._message.get_description()
self._deserializers = self._des.get_row_deserializers(self.description,
{'unicode_error': self.unicode_error,
'session_tz': self.connection.parameters.get('timezone', 'unknown')})
Expand Down Expand Up @@ -850,7 +850,7 @@ def _prepare(self, query):
if isinstance(self._message, messages.NoData):
self.description = None # response was NoData for a DDL/transaction PreparedStatement
else:
self.description = [Column(fd) for fd in self._message.fields]
self.description = self._message.get_description()
self._deserializers = self._des.get_row_deserializers(self.description,
{'unicode_error': self.unicode_error,
'session_tz': self.connection.parameters.get('timezone', 'unknown')})
Expand Down
53 changes: 39 additions & 14 deletions vertica_python/vertica/messages/backend_messages/row_description.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,17 @@
from six.moves import range

from ..message import BackendMessage
from ...column import Column
from ....datatypes import getTypeName


class RowDescription(BackendMessage):
message_id = b'T'

def __init__(self, data):
def __init__(self, data, complex_types_enabled):
BackendMessage.__init__(self)
self.fields = []
field_dict = {}
field_count = unpack('!H', data[0:2])[0]

if field_count == 0:
Expand All @@ -74,7 +76,7 @@ def __init__(self, data):
user_types.append((base_type_oid, type_name.decode('utf-8')))

# read info of each field
offset = calcsize("!HBIHHHiH")
offset = calcsize("!BIHHHiH")
for _ in range(field_count):
field_name = unpack_from("!{0}sx".format(data.find(b'\x00', pos) - pos), data, pos)[0]
pos += len(field_name) + 1
Expand All @@ -92,30 +94,53 @@ def __init__(self, data):
pos += len(table_name) + 1
table_name = table_name.decode('utf-8')

field_info = unpack_from("!HBIhHHiH", data, pos)
attribute_number = unpack_from("!H", data, pos)[0]
pos += 2
if complex_types_enabled:
parent_attribute_number = unpack_from("!H", data, pos)[0]
pos += 2
else:
parent_attribute_number = 0

field_info = unpack_from("!BIhHHiH", data, pos)
pos += offset

if field_info[1] == 1:
data_type_oid, data_type_name = user_types[field_info[2]]
if field_info[0] == 1:
data_type_oid, data_type_name = user_types[field_info[1]]
else:
data_type_oid = field_info[2]
data_type_name = getTypeName(data_type_oid, field_info[6])
data_type_oid = field_info[1]
data_type_name = getTypeName(data_type_oid, field_info[5])

self.fields.append({
# Create a Column object
column = Column({
'name': field_name,
'table_oid': table_oid,
'schema_name': schema_name,
'table_name': table_name,
'attribute_number': field_info[0],
'attribute_number': attribute_number,
'data_type_oid': data_type_oid,
'data_type_size': field_info[3],
'data_type_size': field_info[2],
'data_type_name': data_type_name,
'null_ok': field_info[4] == 1,
'is_identity': field_info[5] == 1,
'type_modifier': field_info[6],
'format_code': field_info[7],
'null_ok': field_info[3] == 1,
'is_identity': field_info[4] == 1,
'type_modifier': field_info[5],
'format_code': field_info[6],
})

# Add every column description to the dict so we can set the parents later
field_dict[(table_oid, attribute_number)] = column
if parent_attribute_number == 0:
self.fields.append(column)
else:
parent_col = field_dict.get((table_oid, parent_attribute_number))
if not parent_col:
raise KeyError("Complex type parent column not found: table_oid={}, attribute_number={}".format(table_oid, parent_attribute_number))
parent_col.add_child_column(column)

def get_description(self):
# return a list of Column objects for Cursor.description
return self.fields

def __str__(self):
return "RowDescription: {}".format(self.fields)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ def __init__(self, user, database, session_label, os_user_name, autocommit, bina
b'client_pid': pid,
b'autocommit': 'on' if autocommit else 'off',
b'binary_data_protocol': '1' if binary_transfer else '0', # Defaults to text format '0'
b'protocol_features': '{"request_complex_types":true}',
}

def read_bytes(self):
Expand Down
4 changes: 2 additions & 2 deletions vertica_python/vertica/messages/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,10 @@ class BackendMessage(Message):
_message_id_map = {}

@classmethod
def from_type(cls, type_, data):
def from_type(cls, type_, data, **kwargs):
klass = cls._message_id_map.get(type_)
if klass is not None:
return klass(data)
return klass(data, **kwargs)
else:
from .backend_messages import Unknown
return Unknown(type_, data)
Expand Down

0 comments on commit dfbd90a

Please sign in to comment.