Skip to content

Commit

Permalink
Merge pull request #120 from RohanArepally/master
Browse files Browse the repository at this point in the history
Fix bug with querying rows that are nested containers
  • Loading branch information
maximdanilchenko authored Oct 8, 2024
2 parents 81a9253 + 838d614 commit 6e23527
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 11 deletions.
9 changes: 5 additions & 4 deletions aiochclient/_types.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ cdef class StrType:
self.container = container

cdef str _convert(self, str string):
string = decode(string.encode())
if self.container:
return remove_single_quotes(string)
return string
Expand Down Expand Up @@ -524,7 +525,7 @@ cdef class TupleType:

cdef tuple _convert(self, str string):
return tuple(
tp(decode(val.encode()))
tp(val)
for tp, val in zip(self.types, seq_parser(string[1:-1]))
)

Expand Down Expand Up @@ -554,7 +555,7 @@ cdef class MapType:
cdef dict _convert(self, str string):
key, value = string[1:-1].split(':', 1)
return {
self.key_type.p_type(decode(key.encode())): self.value_type.p_type(decode(value.encode()))
self.key_type.p_type(key): self.value_type.p_type(value)
}

cpdef dict p_type(self, string):
Expand All @@ -579,7 +580,7 @@ cdef class ArrayType:
)

cdef list _convert(self, str string):
return [self.type.p_type(decode(val.encode())) for val in seq_parser(string[1:-1])]
return [self.type.p_type(val) for val in seq_parser(string[1:-1])]

cpdef list p_type(self, str string):
return self._convert(string)
Expand Down Expand Up @@ -611,7 +612,7 @@ cdef class NestedType:
for val in seq_parser(string[1:-1]):
temp = []
for tp, elem in zip(self.types, seq_parser(val.strip("()"))):
temp.append(tp.p_type(decode(elem.encode())))
temp.append(tp.p_type(elem))
result.append(tuple(temp))
return result

Expand Down
12 changes: 6 additions & 6 deletions aiochclient/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ def unconvert(value) -> bytes:

class StrType(BaseType):
def p_type(self, string: str) -> str:
string = self.decode(string.encode())
if self.container:
return remove_single_quotes(string)
return string
Expand Down Expand Up @@ -299,7 +300,7 @@ def __init__(self, name: str, **kwargs):

def p_type(self, string: str) -> tuple:
return tuple(
tp.p_type(self.decode(val.encode()))
tp.p_type(val)
for tp, val in zip(self.types, self.seq_parser(string.strip("()")))
)

Expand All @@ -324,9 +325,8 @@ def __init__(self, name: str, **kwargs):
def p_type(self, string: str) -> dict:
key, value = string[1:-1].split(':', 1)
return {
self.key_type.p_type(self.decode(key.encode())): self.value_type.p_type(
self.decode(value.encode())
)
self.key_type.p_type(key): self.value_type.p_type(value)

}

def convert(self, value: bytes) -> dict:
Expand All @@ -350,7 +350,7 @@ def __init__(self, name: str, **kwargs):

def p_type(self, string: str) -> list:
return [
self.type.p_type(self.decode(val.encode()))
self.type.p_type(val)
for val in self.seq_parser(string[1:-1])
]

Expand All @@ -375,7 +375,7 @@ def __init__(self, name: str, **kwargs):
def p_type(self, string: str) -> List[tuple]:
return [
tuple(
tp.p_type(self.decode(elem.encode()))
tp.p_type(elem)
for tp, elem in zip(self.types, self.seq_parser(val.strip("()")))
)
for val in self.seq_parser(string[1:-1])
Expand Down
18 changes: 17 additions & 1 deletion tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def rows(uuid):
["hello", "world"],
["hello", "world"],
["hello", None],
[("hello\'", 3, "hello")],
"'\b\f\r\n\t\\",
uuid,
[uuid, uuid, uuid],
Expand Down Expand Up @@ -105,6 +106,7 @@ def rows(uuid):
[],
[],
[],
[],
"'\b\f\r\n\t\\",
None,
[],
Expand Down Expand Up @@ -198,6 +200,7 @@ async def all_types_db(chclient, rows):
array_string Array(String),
array_low_cardinality_string Array(LowCardinality(String)),
array_nullable_string Array(Nullable(String)),
array_tuple Array(Tuple(String, UInt8, String)),
escape_string String,
uuid Nullable(UUID),
array_uuid Array(UUID),
Expand Down Expand Up @@ -258,7 +261,7 @@ async def all_types_db(chclient, rows):
def class_chclient(chclient, all_types_db, rows, request):
request.cls.ch = chclient
cls_rows = rows
cls_rows[1][44] = dt.datetime(
cls_rows[1][45] = dt.datetime(
2019, 1, 1, 3, 0
) # DateTime64 always returns datetime type
request.cls.rows = [tuple(r) for r in cls_rows]
Expand Down Expand Up @@ -676,6 +679,19 @@ async def test_array_string(self):
record = await self.select_record_bytes("array_string")
assert record[0] == result
assert record["array_string"] == result

async def test_array_tuple(self):
result = [("hello'", 3, "hello")]
assert await self.select_field("array_tuple") == result
record = await self.select_record("array_tuple")
assert record[0] == result
assert record["array_tuple"] == result

result = b"[('hello\\'',3,'hello')]"
assert await self.select_field_bytes("array_tuple") == result
record = await self.select_record_bytes("array_tuple")
assert record[0] == result
assert record["array_tuple"] == result

async def test_array_low_cardinality_string(self):
result = ["hello", "world"]
Expand Down

0 comments on commit 6e23527

Please sign in to comment.