Skip to content

Commit

Permalink
Added encoder
Browse files Browse the repository at this point in the history
  • Loading branch information
ankane committed Sep 27, 2024
1 parent c523ae9 commit 551229a
Show file tree
Hide file tree
Showing 7 changed files with 68 additions and 23 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
## 0.1.1 (unreleased)
## 0.2.0 (unreleased)

- Added `Vector`, `HalfVector`, `Bit`, and `SparseVector` classes

Expand Down
18 changes: 13 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,14 @@ Import the library
import 'package:pgvector/pgvector.dart';
```

Add the encoder

```dart
var connection = await Connection.open(endpoint,
settings: ConnectionSettings(
typeRegistry: TypeRegistry(encoders: [pgvectorEncoder])));
```

Enable the extension

```dart
Expand All @@ -49,9 +57,9 @@ Insert vectors
await connection.execute(
Sql.named('INSERT INTO items (embedding) VALUES (@a), (@b), (@c)'),
parameters: {
'a': pgvector.encode([1, 1, 1]),
'b': pgvector.encode([2, 2, 2]),
'c': pgvector.encode([1, 1, 2])
'a': Vector([1, 1, 1]),
'b': Vector([2, 2, 2]),
'c': Vector([1, 1, 2])
});
```

Expand All @@ -61,11 +69,11 @@ Get the nearest neighbors
List<List<dynamic>> results = await connection.execute(
Sql.named('SELECT id, embedding FROM items ORDER BY embedding <-> @embedding LIMIT 5'),
parameters: {
'embedding': pgvector.encode([1, 1, 1])
'embedding': Vector([1, 1, 1])
});
for (final row in results) {
print(row[0]);
print(pgvector.decode(row[1].bytes));
print(Vector.fromBinary(row[1].bytes));
}
```

Expand Down
6 changes: 4 additions & 2 deletions examples/openai/example.dart
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ void main() async {
port: 5432,
database: 'pgvector_example',
username: Platform.environment['USER']),
settings: ConnectionSettings(sslMode: SslMode.disable));
settings: ConnectionSettings(
sslMode: SslMode.disable,
typeRegistry: TypeRegistry(encoders: [pgvectorEncoder])));

await connection.execute('CREATE EXTENSION IF NOT EXISTS vector');

Expand All @@ -52,7 +54,7 @@ void main() async {
'INSERT INTO documents (content, embedding) VALUES (@content, @embedding)'),
parameters: {
'content': input[i],
'embedding': pgvector.encode(List<double>.from(embeddings[i]))
'embedding': Vector(List<double>.from(embeddings[i]))
});
}

Expand Down
1 change: 1 addition & 0 deletions lib/pgvector.dart
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
export 'src/bit.dart' show Bit;
export 'src/halfvec.dart' show HalfVector;
export 'src/pgvector.dart' show pgvector;
export 'src/postgres.dart' show pgvectorEncoder;
export 'src/sparsevec.dart' show SparseVector;
export 'src/vector.dart' show Vector;
32 changes: 32 additions & 0 deletions lib/src/postgres.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import 'dart:convert';
import 'package:postgres/postgres.dart';
import 'bit.dart';
import 'halfvec.dart';
import 'sparsevec.dart';
import 'vector.dart';

EncodedValue? pgvectorEncoder(TypedValue input, CodecContext context) {
final value = input.value;

if (value is Vector) {
final v = value as Vector;
return EncodedValue.binary(v.toBinary());
}

if (value is HalfVector) {
final v = value as HalfVector;
return EncodedValue.text(utf8.encode(v.toString()));
}

if (value is Bit) {
final v = value as Bit;
return EncodedValue.binary(v.toBinary());
}

if (value is SparseVector) {
final v = value as SparseVector;
return EncodedValue.binary(v.toBinary());
}

return null;
}
2 changes: 1 addition & 1 deletion pubspec.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ environment:
sdk: ^3.0.0

dependencies:
postgres: ^3.4.0

dev_dependencies:
lints: ^2.0.0
postgres: ^3.0.0
test: ^1.21.0
30 changes: 16 additions & 14 deletions test/postgres_test.dart
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ void main() {
port: 5432,
database: 'pgvector_dart_test',
username: Platform.environment['USER']),
settings: ConnectionSettings(sslMode: SslMode.disable));
settings: ConnectionSettings(
sslMode: SslMode.disable,
typeRegistry: TypeRegistry(encoders: [pgvectorEncoder])));

await connection.execute('CREATE EXTENSION IF NOT EXISTS vector');
await connection.execute('DROP TABLE IF EXISTS items');
Expand All @@ -23,25 +25,25 @@ void main() {
Sql.named(
'INSERT INTO items (embedding, half_embedding, binary_embedding, sparse_embedding) VALUES (@embedding1, @half_embedding1, @binary_embedding1, @sparse_embedding1), (@embedding2, @half_embedding2, @binary_embedding2, @sparse_embedding2), (@embedding3, @half_embedding3, @binary_embedding3, @sparse_embedding3)'),
parameters: {
'embedding1': Vector([1, 1, 1]).toString(),
'embedding2': Vector([2, 2, 2]).toString(),
'embedding3': Vector([1, 1, 2]).toString(),
'half_embedding1': HalfVector([1, 1, 1]).toString(),
'half_embedding2': HalfVector([2, 2, 2]).toString(),
'half_embedding3': HalfVector([1, 1, 2]).toString(),
'binary_embedding1': '000',
'binary_embedding2': '101',
'binary_embedding3': '111',
'sparse_embedding1': SparseVector([1, 1, 1]).toString(),
'sparse_embedding2': SparseVector([2, 2, 2]).toString(),
'sparse_embedding3': SparseVector([1, 1, 2]).toString()
'embedding1': Vector([1, 1, 1]),
'embedding2': Vector([2, 2, 2]),
'embedding3': Vector([1, 1, 2]),
'half_embedding1': HalfVector([1, 1, 1]),
'half_embedding2': HalfVector([2, 2, 2]),
'half_embedding3': HalfVector([1, 1, 2]),
'binary_embedding1': Bit([false, false, false]),
'binary_embedding2': Bit([true, false, true]),
'binary_embedding3': Bit([true, true, true]),
'sparse_embedding1': SparseVector([1, 1, 1]),
'sparse_embedding2': SparseVector([2, 2, 2]),
'sparse_embedding3': SparseVector([1, 1, 2])
});

List<List<dynamic>> results = await connection.execute(
Sql.named(
'SELECT id, embedding, binary_embedding, sparse_embedding FROM items ORDER BY embedding <-> @embedding LIMIT 5'),
parameters: {
'embedding': Vector([1, 1, 1]).toString()
'embedding': Vector([1, 1, 1])
});
expect(results.map((r) => r[0]), equals([1, 3, 2]));
expect(Vector.fromBinary(results[1][1].bytes), equals(Vector([1, 1, 2])));
Expand Down

0 comments on commit 551229a

Please sign in to comment.