Skip to content

Commit

Permalink
Default E5 endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
davidkyle committed Oct 7, 2024
1 parent d1644b3 commit e6931db
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,13 @@
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.oneOf;

public class DefaultElserIT extends InferenceBaseRestTest {
public class DefaultEndPointsIT extends InferenceBaseRestTest {

private TestThreadPool threadPool;

@Before
public void createThreadPool() {
threadPool = new TestThreadPool(DefaultElserIT.class.getSimpleName());
threadPool = new TestThreadPool(DefaultEndPointsIT.class.getSimpleName());
}

@After
Expand All @@ -38,7 +38,7 @@ public void tearDown() throws Exception {
}

@SuppressWarnings("unchecked")
public void testInferCreatesDefaultElser() throws IOException {
public void testInferDeploysDefaultElser() throws IOException {
assumeTrue("Default config requires a feature flag", DefaultElserFeatureFlag.isEnabled());
var model = getModel(ElasticsearchInternalService.DEFAULT_ELSER_ID);
assertDefaultElserConfig(model);
Expand Down Expand Up @@ -67,4 +67,39 @@ private static void assertDefaultElserConfig(Map<String, Object> modelConfig) {
Matchers.is(Map.of("enabled", true, "min_number_of_allocations", 1, "max_number_of_allocations", 8))
);
}

/**
 * Fetches the default E5 endpoint, checks its configuration, and then runs a
 * text-embedding inference request against it, expecting one embedding per input.
 * Skipped unless the default-config feature flag is enabled.
 */
@SuppressWarnings("unchecked")
public void testInferDeploysDefaultE5() throws IOException {
    assumeTrue("Default config requires a feature flag", DefaultElserFeatureFlag.isEnabled());

    // The endpoint must be retrievable with the expected settings before any inference call is made.
    final String endpointId = ElasticsearchInternalService.DEFAULT_E5_ID;
    var defaultConfig = getModel(endpointId);
    assertDefaultE5Config(defaultConfig);

    // Two inputs are sent, so two embeddings are expected back.
    List<String> texts = List.of("Hello World", "Goodnight moon");
    Map<String, String> requestParams = Map.of("timeout", "120s");
    var response = infer(endpointId, TaskType.TEXT_EMBEDDING, texts, requestParams);

    var textEmbeddings = (List<Map<String, Object>>) response.get("text_embedding");
    assertThat(response.toString(), textEmbeddings, hasSize(2));
}

/**
 * Asserts that {@code modelConfig} describes the default E5 endpoint: the expected
 * inference ID, service name and task type, one of the two accepted E5 model IDs,
 * a single thread, and adaptive allocations enabled with a range of 1 to 8.
 *
 * @param modelConfig the endpoint configuration map to check; its string form is
 *                    used as the failure message of every assertion
 */
@SuppressWarnings("unchecked")
private static void assertDefaultE5Config(Map<String, Object> modelConfig) {
    // Render the config once; every assertion reports it on failure.
    final String failureContext = modelConfig.toString();

    assertEquals(failureContext, ElasticsearchInternalService.DEFAULT_E5_ID, modelConfig.get("inference_id"));
    assertEquals(failureContext, ElasticsearchInternalService.NAME, modelConfig.get("service"));
    assertEquals(failureContext, TaskType.TEXT_EMBEDDING.toString(), modelConfig.get("task_type"));

    var settings = (Map<String, Object>) modelConfig.get("service_settings");

    // Either the generic model ID or the linux-x86_64 variant is accepted.
    Object modelId = settings.get("model_id");
    assertThat(failureContext, modelId, is(oneOf(".multilingual-e5-small", ".multilingual-e5-small_linux-x86_64")));
    assertEquals(failureContext, 1, settings.get("num_threads"));

    var allocationSettings = (Map<String, Object>) settings.get("adaptive_allocations");
    Map<String, Object> expectedAllocations = Map.of(
        "enabled",
        true,
        "min_number_of_allocations",
        1,
        "max_number_of_allocations",
        8
    );
    assertThat(failureContext, allocationSettings, Matchers.is(expectedAllocations));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ public static InferModelAction.Request buildInferenceRequest(
return request;
}

protected abstract boolean isDefaultId(String inferenceId);
/**
 * Reports whether {@code inferenceId} is treated as a default endpoint ID by the
 * concrete service. No modifier is given, so the method is package-private.
 *
 * @param inferenceId the inference endpoint ID to check
 * @return {@code true} if the implementation considers the ID a default one
 */
abstract boolean isDefaultId(String inferenceId);

protected void maybeStartDeployment(
ElasticsearchInternalModel model,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
);

public static final String DEFAULT_ELSER_ID = ".elser-2";
public static final String DEFAULT_E5_ID = ".default-multilingual-e5-small"; // TODO what to name this

private static final Logger logger = LogManager.getLogger(ElasticsearchInternalService.class);
private static final DeprecationLogger DEPRECATION_LOGGER = DeprecationLogger.getLogger(ElasticsearchInternalService.class);
Expand Down Expand Up @@ -723,19 +724,46 @@ public List<UnparsedModel> defaultConfigs() {
)
);

// TODO Chunking settings
Map<String, Object> e5Settings = Map.of(
ModelConfigurations.SERVICE_SETTINGS,
Map.of(
ElasticsearchInternalServiceSettings.MODEL_ID,
MULTILINGUAL_E5_SMALL_MODEL_ID, // TODO pick model depending on platform
ElasticsearchInternalServiceSettings.NUM_THREADS,
1,
ElasticsearchInternalServiceSettings.ADAPTIVE_ALLOCATIONS,
Map.of(
"enabled",
Boolean.TRUE,
"min_number_of_allocations",
1,
"max_number_of_allocations",
8 // no max?
)
)
);

return List.of(
new UnparsedModel(
DEFAULT_ELSER_ID,
TaskType.SPARSE_EMBEDDING,
NAME,
elserSettings,
Map.of() // no secrets
),
new UnparsedModel(
DEFAULT_E5_ID,
TaskType.TEXT_EMBEDDING,
NAME,
e5Settings,
Map.of() // no secrets
)
);
}

@Override
protected boolean isDefaultId(String inferenceId) {
return DEFAULT_ELSER_ID.equals(inferenceId);
boolean isDefaultId(String inferenceId) {
    // Either built-in default endpoint ID qualifies. Calling equals() on the
    // constant means a null inferenceId simply yields false.
    if (DEFAULT_ELSER_ID.equals(inferenceId)) {
        return true;
    }
    return DEFAULT_E5_ID.equals(inferenceId);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1132,6 +1132,13 @@ public void testModelVariantDoesNotMatchArchitecturesAndIsNotPlatformAgnostic()
}
}

/**
 * Checks that the service recognises both default endpoint IDs and rejects an
 * arbitrary one. The IDs are written as literals rather than via the service constants.
 */
public void testIsDefaultId() {
    var internalService = createService(mock(Client.class));

    // Default endpoint IDs.
    assertTrue(internalService.isDefaultId(".elser-2"));
    assertTrue(internalService.isDefaultId(".default-multilingual-e5-small")); // TODO name?

    // Any other ID is not a default.
    assertFalse(internalService.isDefaultId("foo"));
}

private ElasticsearchInternalService createService(Client client) {
var context = new InferenceServiceExtension.InferenceServiceFactoryContext(client, threadPool);
return new ElasticsearchInternalService(context);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -234,9 +234,9 @@ protected void masterOperation(
if (getModelResponse.getResources().results().size() > 1) {
listener.onFailure(
ExceptionsHelper.badRequestException(
"cannot deploy more than one models at the same time; [{}] matches [{}] models]",
"cannot deploy more than one model at the same time; [{}] matches models [{}]",
request.getModelId(),
getModelResponse.getResources().results().size()
getModelResponse.getResources().results().stream().map(TrainedModelConfig::getModelId).toList()
)
);
return;
Expand Down

0 comments on commit e6931db

Please sign in to comment.