Skip to content

Commit

Permalink
Added pgx codec for halfvec
Browse files Browse the repository at this point in the history
  • Loading branch information
ankane committed Jul 24, 2024
1 parent f07e20f commit d82eb00
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 12 deletions.
29 changes: 18 additions & 11 deletions halfvec.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"database/sql/driver"
"encoding/json"
"fmt"
"slices"
"strconv"
"strings"
)
Expand All @@ -26,17 +27,9 @@ func (v HalfVector) Slice() []float32 {

// String returns a string representation of the half vector.
//
// The formatting is delegated to EncodeText so that both methods always
// produce the same text.
func (v HalfVector) String() string {
	// EncodeText should never return an error; if it ever does, falling back
	// to whatever (possibly empty) buffer it returned is acceptable here, so
	// the error is deliberately discarded.
	buf, _ := v.EncodeText(nil)
	return string(buf)
}

Expand All @@ -54,6 +47,20 @@ func (v *HalfVector) Parse(s string) error {
return nil
}

// EncodeText appends the text representation of the half vector to buf and
// returns the extended buffer. Elements are written comma-separated inside
// square brackets, each formatted with 32-bit precision. The returned error
// is always nil.
func (v HalfVector) EncodeText(buf []byte) (newBuf []byte, err error) {
	// Reserve room up front: two brackets plus an estimated 16 bytes per element.
	out := slices.Grow(buf, 2+16*len(v.vec))
	out = append(out, '[')
	for idx, elem := range v.vec {
		if idx != 0 {
			out = append(out, ',')
		}
		out = strconv.AppendFloat(out, float64(elem), 'f', -1, 32)
	}
	return append(out, ']'), nil
}

// Compile-time check that *HalfVector implements sql.Scanner; the build fails
// here if it does not.
var _ sql.Scanner = (*HalfVector)(nil)

Expand Down
83 changes: 83 additions & 0 deletions pgx/halfvec.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package pgx

import (
"database/sql/driver"
"fmt"

"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
"github.com/pgvector/pgvector-go"
)

// HalfVectorCodec is the pgx codec for pgvector.HalfVector values.
// It carries no state and supports the text wire format only.
type HalfVectorCodec struct{}

// FormatSupported reports whether the codec handles the given wire format.
// Only the text format is accepted.
func (HalfVectorCodec) FormatSupported(format int16) bool {
	switch format {
	case pgx.TextFormatCode:
		return true
	default:
		return false
	}
}

// PreferredFormat returns the wire format the codec prefers, which is the
// text format.
func (HalfVectorCodec) PreferredFormat() int16 {
	var preferred int16 = pgx.TextFormatCode
	return preferred
}

// PlanEncode returns the plan used to encode value, or nil when no plan
// applies. A plan is returned only when value is a pgvector.HalfVector and
// the text format is requested.
func (HalfVectorCodec) PlanEncode(m *pgtype.Map, oid uint32, format int16, value any) pgtype.EncodePlan {
	if _, isHalfVec := value.(pgvector.HalfVector); !isHalfVec {
		return nil
	}
	if format != pgx.TextFormatCode {
		return nil
	}
	return encodePlanHalfVectorCodecText{}
}

// encodePlanHalfVectorCodecText is the encode plan HalfVectorCodec.PlanEncode
// returns for the text format.
type encodePlanHalfVectorCodecText struct{}

// Encode appends the text form of value to buf and returns the result.
// value must be a pgvector.HalfVector; any other type makes the type
// assertion panic. PlanEncode only hands out this plan for that type.
func (encodePlanHalfVectorCodecText) Encode(value any, buf []byte) (newBuf []byte, err error) {
	return value.(pgvector.HalfVector).EncodeText(buf)
}

// scanPlanHalfVectorCodecText is the scan plan HalfVectorCodec.PlanScan
// returns for the text format.
type scanPlanHalfVectorCodecText struct{}

// PlanScan returns the plan used to scan into target, or nil when no plan
// applies. A plan is returned only when target is a *pgvector.HalfVector and
// the text format is requested.
func (HalfVectorCodec) PlanScan(m *pgtype.Map, oid uint32, format int16, target any) pgtype.ScanPlan {
	if _, isHalfVecPtr := target.(*pgvector.HalfVector); !isHalfVecPtr {
		return nil
	}
	if format != pgx.TextFormatCode {
		return nil
	}
	return scanPlanHalfVectorCodecText{}
}

// Scan passes src to the destination's own Scan method.
// dst must be a *pgvector.HalfVector; any other type makes the type
// assertion panic. PlanScan only hands out this plan for that type.
func (scanPlanHalfVectorCodecText) Scan(src []byte, dst any) error {
	return dst.(*pgvector.HalfVector).Scan(src)
}

// DecodeDatabaseSQLValue decodes src for database/sql by delegating to
// DecodeValue and returning its result unchanged.
func (c HalfVectorCodec) DecodeDatabaseSQLValue(m *pgtype.Map, oid uint32, format int16, src []byte) (driver.Value, error) {
	decoded, err := c.DecodeValue(m, oid, format, src)
	return decoded, err
}

// DecodeValue decodes src into a pgvector.HalfVector.
//
// A nil src yields (nil, nil). An error is returned when PlanScan has no plan
// for the requested format or when scanning src fails.
func (c HalfVectorCodec) DecodeValue(m *pgtype.Map, oid uint32, format int16, src []byte) (any, error) {
	if src == nil {
		return nil, nil
	}

	var decoded pgvector.HalfVector
	plan := c.PlanScan(m, oid, format, &decoded)
	if plan == nil {
		return nil, fmt.Errorf("Unable to decode halfvec type")
	}
	if err := plan.Scan(src, &decoded); err != nil {
		return nil, err
	}
	return decoded, nil
}
7 changes: 6 additions & 1 deletion pgx/register.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@ import (

func RegisterTypes(ctx context.Context, conn *pgx.Conn) error {
var vectorOid *uint32
var halfvecOid *uint32
var sparsevecOid *uint32
err := conn.QueryRow(ctx, "SELECT to_regtype('vector')::oid, to_regtype('sparsevec')::oid").Scan(&vectorOid, &sparsevecOid)
err := conn.QueryRow(ctx, "SELECT to_regtype('vector')::oid, to_regtype('halfvec')::oid, to_regtype('sparsevec')::oid").Scan(&vectorOid, &halfvecOid, &sparsevecOid)
if err != nil {
return err
}
Expand All @@ -23,6 +24,10 @@ func RegisterTypes(ctx context.Context, conn *pgx.Conn) error {
tm := conn.TypeMap()
tm.RegisterType(&pgtype.Type{Name: "vector", OID: *vectorOid, Codec: &VectorCodec{}})

if halfvecOid != nil {
tm.RegisterType(&pgtype.Type{Name: "halfvec", OID: *halfvecOid, Codec: &HalfVectorCodec{}})
}

if sparsevecOid != nil {
tm.RegisterType(&pgtype.Type{Name: "sparsevec", OID: *sparsevecOid, Codec: &SparseVectorCodec{}})
}
Expand Down

0 comments on commit d82eb00

Please sign in to comment.