Skip to content

Commit

Permalink
Implement support for vector type (#439)
Browse files Browse the repository at this point in the history
Vectors get decoded into array.array. Encoding supports any list-like
array of numbers, but has an optimized fast path for things like array
and ndarray that avoids needing to box integers.
  • Loading branch information
msullivan authored and fantix committed Jun 18, 2023
1 parent ec90e35 commit 0bee718
Show file tree
Hide file tree
Showing 2 changed files with 245 additions and 3 deletions.
107 changes: 104 additions & 3 deletions edgedb/protocol/codecs/codecs.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,16 @@
#


import array
import decimal
import uuid
import datetime
from edgedb import describe
from edgedb import enums
from edgedb.datatypes import datatypes

from libc.string cimport memcpy


include "./edb_types.pxi"

Expand Down Expand Up @@ -347,14 +350,16 @@ cdef dict BASE_SCALAR_CODECS = {}
cdef register_base_scalar_codec(
str name,
pgproto.encode_func encoder,
pgproto.decode_func decoder):
pgproto.decode_func decoder,
object tid = None):

cdef:
BaseCodec codec

tid = TYPE_IDS.get(name)
if tid is None:
raise RuntimeError(f'cannot find known ID for type {name!r}')
tid = TYPE_IDS.get(name)
if tid is None:
raise RuntimeError(f'cannot find known ID for type {name!r}')
tid = tid.bytes

if tid in BASE_SCALAR_CODECS:
Expand Down Expand Up @@ -510,6 +515,94 @@ cdef config_memory_decode(pgproto.CodecContext settings, FRBuffer *buf):
return datatypes.ConfigMemory(bytes=bytes)


DEF PGVECTOR_MAX_DIM = (1 << 16) - 1


cdef pgvector_encode_memview(pgproto.CodecContext settings, WriteBuffer buf,
float[:] obj):
cdef:
float item
Py_ssize_t objlen
Py_ssize_t i

objlen = len(obj)
if objlen > PGVECTOR_MAX_DIM:
raise ValueError('too many elements in vector value')

buf.write_int32(4 + objlen*4)
buf.write_int16(objlen)
buf.write_int16(0)
for i in range(objlen):
buf.write_float(obj[i])


cdef pgvector_encode(pgproto.CodecContext settings, WriteBuffer buf,
object obj):
cdef:
float item
Py_ssize_t objlen
float[:] memview
Py_ssize_t i

# If we can take a typed memview of the object, we use that.
# That is good, because it means we can consume array.array and
# numpy.ndarray without needing to unbox.
# Otherwise we take the slow path, indexing into the array using
# the normal protocol.
try:
memview = obj
except (ValueError, TypeError) as e:
pass
else:
pgvector_encode_memview(settings, buf, memview)
return

if not _is_array_iterable(obj):
raise TypeError(
'a sized iterable container expected (got type {!r})'.format(
type(obj).__name__))

# Annoyingly, this is literally identical code to the fast path...
# but the types are different in critical ways.
objlen = len(obj)
if objlen > PGVECTOR_MAX_DIM:
raise ValueError('too many elements in vector value')

buf.write_int32(4 + objlen*4)
buf.write_int16(objlen)
buf.write_int16(0)
for i in range(objlen):
buf.write_float(obj[i])


cdef object ONE_EL_ARRAY = array.array('f', [0.0])


cdef pgvector_decode(pgproto.CodecContext settings, FRBuffer *buf):
cdef:
int32_t dim
Py_ssize_t size
Py_buffer view
char *p
float[:] array_view

dim = hton.unpack_uint16(frb_read(buf, 2))
frb_read(buf, 2)

size = dim * 4
p = frb_read(buf, size)

# Create a float array with size dim
val = ONE_EL_ARRAY * dim

# And fill it with the buffer contents
array_view = val
memcpy(&array_view[0], p, size)
val.byteswap()

return val


cdef checked_decimal_encode(
pgproto.CodecContext settings, WriteBuffer buf, obj
):
Expand Down Expand Up @@ -708,4 +801,12 @@ cdef register_base_scalar_codecs():
config_memory_decode)


register_base_scalar_codec(
'ext::pgvector::vector',
pgvector_encode,
pgvector_decode,
uuid.UUID('9565dd88-04f5-11ee-a691-0b6ebe179825'),
)


register_base_scalar_codecs()
141 changes: 141 additions & 0 deletions tests/test_vector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2019-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from edgedb import _testbase as tb
import edgedb

import array


# An array.array subtype where indexing doesn't work.
# We use this to verify that the non-boxing memoryview based
# fast path works, since the slow path won't work on this object.
class brokenarray(array.array):
def __getitem__(self, i):
raise AssertionError("the fast path wasn't used!")


class TestVector(tb.SyncQueryTestCase):
def setUp(self):
super().setUp()

if not self.client.query_required_single('''
select exists (
select sys::ExtensionPackage filter .name = 'vector'
)
'''):
self.skipTest("feature not implemented")

self.client.execute('''
create extension vector version '1.0'
''')

def tearDown(self):
try:
self.client.execute('''
drop extension vector version '1.0'
''')
finally:
super().tearDown()

async def test_vector_01(self):
# if not self.client.query_required_single('''
# select exists (
# select sys::ExtensionPackage filter .name = 'vector'
# )
# '''):
# self.skipTest("feature not implemented")

# self.client.execute('''
# create extension vector version '1.0'
# ''')

val = self.client.query_single('''
select <vector::vector>'[1.5,2.0,3.8]'
''')
self.assertTrue(isinstance(val, array.array))
self.assertEqual(val, array.array('f', [1.5, 2.0, 3.8]))

val = self.client.query_single(
'''
select <str><vector::vector>$0
''',
[3.0, 9.0, -42.5],
)
self.assertEqual(val, '[3,9,-42.5]')

val = self.client.query_single(
'''
select <str><vector::vector>$0
''',
array.array('f', [3.0, 9.0, -42.5])
)
self.assertEqual(val, '[3,9,-42.5]')

val = self.client.query_single(
'''
select <str><vector::vector>$0
''',
array.array('i', [1, 2, 3]),
)
self.assertEqual(val, '[1,2,3]')

# Test that the fast-path works: if the encoder tries to
# call __getitem__ on this brokenarray, it will fail.
val = self.client.query_single(
'''
select <str><vector::vector>$0
''',
brokenarray('f', [3.0, 9.0, -42.5])
)
self.assertEqual(val, '[3,9,-42.5]')

# I don't think it's worth adding a dependency to test this,
# but this works too:
# import numpy as np
# val = self.client.query_single(
# '''
# select <str><vector::vector>$0
# ''',
# np.asarray([3.0, 9.0, -42.5], dtype=np.float32),
# )

# Some sad path tests
with self.assertRaises(edgedb.InvalidArgumentError):
self.client.query_single(
'''
select <vector::vector>$0
''',
[3.0, None, -42.5],
)

with self.assertRaises(edgedb.InvalidArgumentError):
self.client.query_single(
'''
select <vector::vector>$0
''',
[3.0, 'x', -42.5],
)

with self.assertRaises(edgedb.InvalidArgumentError):
self.client.query_single(
'''
select <vector::vector>$0
''',
'foo',
)

0 comments on commit 0bee718

Please sign in to comment.