Skip to content

Commit

Permalink
Text Embedding task (#21)
Browse files Browse the repository at this point in the history
* updated and re-ran generators

* added embedding concepts to mediapipe-core

* fixed embedding header file and bindings

* adds text embedding classes to text pkg

* updates example with text embedding

* removed dead file

* added more embedding tests

* added embedding model download to CI script

* touch ups

* Update packages/mediapipe-core/lib/src/io/containers.dart

Co-authored-by: Kate Lovett <[email protected]>

* Update packages/mediapipe-task-text/example/.gitignore

Co-authored-by: Kate Lovett <[email protected]>

* Update packages/mediapipe-task-text/example/lib/text_embedding_demo.dart

Co-authored-by: Kate Lovett <[email protected]>

* moved worker dispose method to base class

* docstring & comment improvements

* throw exceptions in impossible code paths instead of returning null

* class hierarchy improvements

* fixed outdated tests

* cleaned up dispose methods

* various tidying

* fixed deprecation warning

* moves repeated widgets into helper method

---------

Co-authored-by: Kate Lovett <[email protected]>
  • Loading branch information
craiglabenz and Piinks authored Apr 3, 2024
1 parent 4c159a7 commit 4304801
Show file tree
Hide file tree
Showing 56 changed files with 2,368 additions and 189 deletions.
6 changes: 6 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,12 @@
headers:
cd tool/builder && dart bin/main.dart headers

# Downloads all necessary task models.
# Invokes the builder tool's `model` command once per task, selecting the task
# with `-m`: first `textclassification`, then `textembedding`.
models:
cd tool/builder && dart bin/main.dart model -m textclassification
cd tool/builder && dart bin/main.dart model -m textembedding


# Runs `ffigen` for all packages
generate: generate_core generate_text

Expand Down
6 changes: 6 additions & 0 deletions packages/mediapipe-core/lib/generated/core_symbols.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@ files:
name: Classifications
c:@S@ClassifierOptions:
name: ClassifierOptions
c:@S@EmbedderOptions:
name: EmbedderOptions
c:@S@Embedding:
name: Embedding
c:@S@EmbeddingResult:
name: EmbeddingResult
c:@S@__darwin_pthread_handler_rec:
name: __darwin_pthread_handler_rec
c:@S@_opaque_pthread_attr_t:
Expand Down
1 change: 1 addition & 0 deletions packages/mediapipe-core/lib/io.dart
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

export 'src/interface/containers.dart' show EmbeddingType;
export 'src/io/mediapipe_core.dart';
1 change: 1 addition & 0 deletions packages/mediapipe-core/lib/mediapipe_core.dart
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
library mediapipe_core;

export 'src/extensions.dart';
export 'src/interface/containers.dart' show EmbeddingType;
export 'universal_mediapipe_core.dart'
if (dart.library.html) 'src/web/mediapipe_core.dart'
if (dart.library.io) 'src/io/mediapipe_core.dart';
77 changes: 75 additions & 2 deletions packages/mediapipe-core/lib/src/interface/containers.dart
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,16 @@
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

import 'dart:typed_data';

import 'package:equatable/equatable.dart';

/// {@template Category}
/// Dart representation of MediaPipe's "Category" concept.
///
/// Category is a util class that contains a [categoryName], its [displayName],
/// a float value as [score], and the [index] of the label in the corresponding
/// label file. It is Typically used as result of classification or detection
/// label file. It is typically used as result of classification or detection
/// tasks.
///
/// See more:
Expand Down Expand Up @@ -45,7 +47,7 @@ abstract class BaseCategory extends Equatable {
/// See also:
/// * [MediaPipe's Classifications documentation](https://developers.google.com/mediapipe/api/solutions/java/com/google/mediapipe/tasks/components/containers/Classifications)
/// {@endtemplate}
abstract base class BaseClassifications extends Equatable {
abstract class BaseClassifications extends Equatable {
/// A list of [Category] objects which contain the actual classification
/// information, including human-readable labels and probability scores.
Iterable<BaseCategory> get categories;
Expand All @@ -71,3 +73,74 @@ abstract base class BaseClassifications extends Equatable {
@override
List<Object?> get props => [categories, headIndex, headName];
}

/// Identifies which flavor of analysis produced a specific [Embedding]
/// instance, and therefore which of its data fields is populated.
enum EmbeddingType {
  /// The [Embedding] carries a non-null [Embedding.floatEmbedding].
  float,

  /// The [Embedding] carries a non-null [Embedding.quantizedEmbedding].
  quantized;

  /// The other member of this enum: [quantized] for [float] and [float] for
  /// [quantized].
  EmbeddingType get opposite {
    // Exhaustive over the enum, so adding a value forces this to be revisited.
    switch (this) {
      case EmbeddingType.float:
        return EmbeddingType.quantized;
      case EmbeddingType.quantized:
        return EmbeddingType.float;
    }
  }
}

/// {@template Embedding}
/// The embedding produced for a single embedder head. Typically used in
/// embedding tasks.
///
/// Exactly one of `floatEmbedding` and `quantizedEmbedding` contains data,
/// depending on whether or not the embedder was configured to perform scalar
/// quantization.
/// {@endtemplate}
abstract class BaseEmbedding extends Equatable {
  /// Length of this embedding.
  int get length;

  /// The index of the embedder head to which these entries refer.
  int get headIndex;

  /// The optional name of the embedder head, which is the corresponding tensor
  /// metadata name.
  String? get headName;

  /// Floating-point embedding. `null` if the embedder was configured to
  /// perform scalar quantization.
  Float32List? get floatEmbedding;

  /// Scalar-quantized embedding. `null` if the embedder was not configured to
  /// perform scalar quantization.
  Uint8List? get quantizedEmbedding;

  /// Indicator for the type of results in this embedding.
  EmbeddingType get type;

  /// Whether this embedding came from an embedder configured to perform
  /// scalar quantization.
  bool get isQuantized {
    return type == EmbeddingType.quantized;
  }

  /// Whether this embedding came from an embedder that was not configured to
  /// perform scalar quantization.
  bool get isFloat {
    return type == EmbeddingType.float;
  }

  @override
  String toString() => 'Embedding('
      'quantizedEmbedding=$quantizedEmbedding, '
      'floatEmbedding=$floatEmbedding, '
      'headIndex=$headIndex, '
      'headName=$headName)';

  /// The values pkg:equatable compares to decide equality.
  @override
  List<Object?> get props {
    return <Object?>[
      quantizedEmbedding,
      floatEmbedding,
      headIndex,
      headName,
    ];
  }
}
42 changes: 41 additions & 1 deletion packages/mediapipe-core/lib/src/interface/task_options.dart
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,13 @@ import 'package:equatable/equatable.dart';
/// including a descendent of the universal options struct, [BaseBaseOptions].
/// The second field will be task-specific.
/// {@endtemplate}
///
/// This implementation is deliberately mutable so that it can track whether
/// `dispose` has been called. All values used by pkg:equatable are in fact
/// immutable.
// ignore: must_be_immutable
abstract class BaseTaskOptions extends Equatable {
/// {@macro TaskOptions}
const BaseTaskOptions();
BaseTaskOptions();

/// {@template TaskOptions.baseOptions}
/// Options class shared by all MediaPipe tasks - namely, how to find and load
Expand Down Expand Up @@ -125,3 +129,39 @@ abstract class BaseClassifierOptions extends BaseInnerTaskOptions {
...(categoryDenylist ?? []),
];
}

/// {@template EmbedderOptions}
/// Dart representation of MediaPipe's "EmbedderOptions" concept.
///
/// These are the embedder options shared across MediaPipe embedding tasks.
///
/// See also:
///  * [MediaPipe's EmbedderOptions documentation](https://developers.google.com/mediapipe/api/solutions/java/com/google/mediapipe/tasks/text/textembedder/TextEmbedder.TextEmbedderOptions)
///  * [BaseOptions], which is often used in conjunction to specify an
///    embedder's desired behavior.
/// {@endtemplate}
abstract class BaseEmbedderOptions extends BaseInnerTaskOptions {
  /// {@macro EmbedderOptions}
  const BaseEmbedderOptions();

  /// Whether to normalize the returned feature vector with L2 norm.
  ///
  /// Use this option only if the model does not already contain a native
  /// L2_NORMALIZATION TF Lite Op. Most models already contain that op, in
  /// which case L2 norm is achieved through TF Lite inference.
  ///
  /// See also:
  ///  * [TutorialsPoint guide on L2 normalization](https://www.tutorialspoint.com/machine_learning_with_python/machine_learning_with_python_ltwo_normalization.htm)
  bool get l2Normalize;

  /// Whether the returned embedding should be quantized to bytes via scalar
  /// quantization.
  ///
  /// Embeddings are implicitly assumed to be unit-norm, and therefore any
  /// dimension is guaranteed to have a value in [-1.0, 1.0]. Use the
  /// l2_normalize option if this is not the case.
  ///
  /// See also:
  ///  * [l2Normalize]
  bool get quantize;

  /// The values pkg:equatable compares to decide equality.
  @override
  List<Object?> get props {
    return <Object?>[l2Normalize, quantize];
  }
}
55 changes: 52 additions & 3 deletions packages/mediapipe-core/lib/src/interface/task_result.dart
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,31 @@
// found in the LICENSE file.

import 'package:mediapipe_core/interface.dart';
import 'package:meta/meta.dart';

/// Anchor class for all result objects from MediaPipe tasks.
///
/// Tracks whether [dispose] has been called and exposes that state through
/// [isClosed].
abstract class TaskResult {
  /// {@template TaskResult.dispose}
  /// Releases all resources for this object.
  ///
  /// See also:
  ///  * [isClosed] which tracks whether this method has been called.
  /// {@endtemplate}
  ///
  /// This base implementation only records that the object is closed, so
  /// overrides must call `super.dispose()` to keep [isClosed] accurate.
  @mustCallSuper
  void dispose() {
    _isClosed = true;
  }

  /// {@template TaskResult.isClosed}
  /// Whether this object has been properly released via `dispose`.
  ///
  /// See also:
  ///  * [dispose], whose calling should set this to `true`.
  /// {@endtemplate}
  bool get isClosed => _isClosed;

  /// Inner tracker for whether [dispose] has been called.
  bool _isClosed = false;
}

/// {@template ClassifierResult}
Expand Down Expand Up @@ -37,7 +55,7 @@ abstract class BaseClassifierResult extends TaskResult {
/// Container for classification results that may describe a slice of time
/// within a larger, streaming data source (.e.g, a video or audio file).
/// {@endtemplate}
mixin TimestampedResult {
mixin TimestampedResult on TaskResult {
/// The optional timestamp (as a [Duration]) of the start of the chunk of data
/// corresponding to these results.
///
Expand All @@ -47,3 +65,34 @@ mixin TimestampedResult {
/// input data is split into multiple chunks starting at different timestamps.
Duration? get timestamp;
}

/// {@template EmbeddingResult}
/// The embedding results of a model. Typically used as a result for embedding
/// tasks.
///
/// This flavor of embedding result will never have a timestamp.
///
/// See also:
///  * [TimestampedEmbeddingResult] for data which may have a timestamp.
///
/// {@endtemplate}
abstract class BaseEmbedderResult extends TaskResult {
  /// {@macro EmbeddingResult}
  BaseEmbedderResult();

  /// The embedding results for each head of the model.
  Iterable<BaseEmbedding> get embeddings;

  /// Summarizes this result as its runtime type plus the number of embeddings
  /// it holds, without printing the embeddings themselves.
  @override
  String toString() =>
      '$runtimeType(embeddings=[...${embeddings.length} items])';

  /// A [toString] variant that includes the full [toString] of each child
  /// embedding, separated by commas. Use with caution - this can produce a
  /// long value.
  String toStringVerbose() {
    final rendered = <String>[
      for (final embedding in embeddings) '$embedding',
    ];
    return '$runtimeType(embeddings=[${rendered.join(', ')}])';
  }
}
Loading

0 comments on commit 4304801

Please sign in to comment.