diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/CreateFromDeploymentIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/CreateFromDeploymentIT.java index f81ebc25dc860..0bfb6e9e43b03 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/CreateFromDeploymentIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/CreateFromDeploymentIT.java @@ -109,6 +109,55 @@ public void testModelIdDoesNotMatch() throws IOException { ); } + public void testNumAllocationsIsUpdated() throws IOException { + var modelId = "update_num_allocations"; + var deploymentId = modelId; + + CustomElandModelIT.createMlNodeTextExpansionModel(modelId, client()); + var response = startMlNodeDeploymemnt(modelId, deploymentId); + assertOkOrCreated(response); + + var inferenceId = "test_num_allocations_updated"; + var putModel = putModel(inferenceId, endpointConfig(deploymentId), TaskType.SPARSE_EMBEDDING); + var serviceSettings = putModel.get("service_settings"); + assertThat( + putModel.toString(), + serviceSettings, + is( + Map.of( + "num_allocations", + 1, + "num_threads", + 1, + "model_id", + "update_num_allocations", + "deployment_id", + "update_num_allocations" + ) + ) + ); + + assertOkOrCreated(updateMlNodeDeploymemnt(deploymentId, 2)); + + var updatedServiceSettings = getModel(inferenceId).get("service_settings"); + assertThat( + updatedServiceSettings.toString(), + updatedServiceSettings, + is( + Map.of( + "num_allocations", + 2, + "num_threads", + 1, + "model_id", + "update_num_allocations", + "deployment_id", + "update_num_allocations" + ) + ) + ); + } + private String endpointConfig(String deploymentId) { return Strings.format(""" { @@ -147,6 +196,20 @@ private Response startMlNodeDeploymemnt(String modelId, String deploymentId) thr return client().performRequest(request); } + private Response updateMlNodeDeploymemnt(String deploymentId, int numAllocations) throws IOException { + String endPoint = "/_ml/trained_models/" + deploymentId + "/deployment/_update"; + + var body = Strings.format(""" + { + "number_of_allocations": %d + } + """, numAllocations); + + Request request = new Request("POST", endPoint); + request.setJsonEntity(body); + return client().performRequest(request); + } + protected void stopMlNodeDeployment(String deploymentId) throws IOException { String endpoint = "/_ml/trained_models/" + deploymentId + "/deployment/_stop"; Request request = new Request("POST", endpoint); diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java index cbc50c361e3b5..37de2caadb475 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java @@ -24,6 +24,7 @@ import java.util.stream.Stream; import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalToIgnoringCase; import static org.hamcrest.Matchers.hasSize; @@ -326,4 +327,9 @@ public void testSupportedStream() throws Exception { deleteModel(modelId); } } + + public void testGetZeroModels() throws IOException { + var models = getModels("_all", TaskType.RERANK); + assertThat(models, empty()); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceModelAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceModelAction.java index 55aad5c55a2ac..bfa0d99462960 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceModelAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceModelAction.java @@ -126,6 +126,11 @@ private void getModelsByTaskType(TaskType taskType, ActionListener unparsedModels, ActionListener listener) { + if (unparsedModels.isEmpty()) { + listener.onResponse(new GetInferenceModelAction.Response(List.of())); + return; + } + var parsedModelsByService = new HashMap>(); try { for (var unparsedModel : unparsedModels) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalModel.java index 642f6f144abc0..c796afe742a34 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalModel.java @@ -92,12 +92,7 @@ public ElasticsearchInternalServiceSettings getServiceSettings() { } public void updateNumAllocation(Integer numAllocations) { - this.internalServiceSettings = new ElasticsearchInternalServiceSettings( - numAllocations, - this.internalServiceSettings.getNumThreads(), - this.internalServiceSettings.modelId(), - this.internalServiceSettings.getAdaptiveAllocationsSettings() - ); + this.internalServiceSettings.setNumAllocations(numAllocations); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java index 3254b2e060847..afcb63131a571 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java @@ -781,7 +781,7 @@ public void updateModelsWithDynamicFields(List models, ActionListener(); for (var model : models) { if (model instanceof ElasticsearchInternalModel esModel) { - modelsByDeploymentIds.put(esModel.internalServiceSettings.deloymentId(), esModel); + modelsByDeploymentIds.put(esModel.mlNodeDeploymentId(), esModel); } else { listener.onFailure( new ElasticsearchStatusException( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettings.java index 5bd8d8cfc5c13..962c939146ef2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettings.java @@ -39,7 +39,7 @@ public class ElasticsearchInternalServiceSettings implements ServiceSettings { public static final String DEPLOYMENT_ID = "deployment_id"; public static final String ADAPTIVE_ALLOCATIONS = "adaptive_allocations"; - private final Integer numAllocations; + private Integer numAllocations; private final int numThreads; private final String modelId; private final AdaptiveAllocationsSettings adaptiveAllocationsSettings; @@ -172,6 +172,10 @@ public ElasticsearchInternalServiceSettings(StreamInput in) throws IOException { : null; } + public void setNumAllocations(Integer numAllocations) { + this.numAllocations = numAllocations; + } + @Override public void writeTo(StreamOutput out) throws IOException { if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_ADAPTIVE_ALLOCATIONS)) { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalModelTests.java index 74cdab79fe79b..50746b99a788e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalModelTests.java @@ -22,9 +22,9 @@ public void testUpdateNumAllocation() { ); model.updateNumAllocation(1); - assertEquals(1, model.internalServiceSettings.getNumAllocations().intValue()); + assertEquals(1, model.getServiceSettings().getNumAllocations().intValue()); model.updateNumAllocation(null); - assertNull(model.internalServiceSettings.getNumAllocations()); + assertNull(model.getServiceSettings().getNumAllocations()); } }