Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ML] Fix issues in dynamically reading the number of allocations #115095

Open
wants to merge 2 commits into
base: 8.16
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,55 @@ public void testModelIdDoesNotMatch() throws IOException {
);
}

/**
 * Checks that the service settings returned for an inference endpoint reflect
 * a change to the number of allocations of the deployment it is attached to.
 */
public void testNumAllocationsIsUpdated() throws IOException {
    var modelId = "update_num_allocations";
    var deploymentId = modelId;

    // Set up the model and start its deployment before the endpoint is created.
    CustomElandModelIT.createMlNodeTextExpansionModel(modelId, client());
    assertOkOrCreated(startMlNodeDeploymemnt(modelId, deploymentId));

    var inferenceId = "test_num_allocations_updated";
    var createdEndpoint = putModel(inferenceId, endpointConfig(deploymentId), TaskType.SPARSE_EMBEDDING);
    var initialSettings = createdEndpoint.get("service_settings");

    // The endpoint is expected to start out reporting a single allocation.
    Map<String, Object> expectedInitialSettings = Map.of(
        "num_allocations",
        1,
        "num_threads",
        1,
        "model_id",
        "update_num_allocations",
        "deployment_id",
        "update_num_allocations"
    );
    assertThat(createdEndpoint.toString(), initialSettings, is(expectedInitialSettings));

    // Change the allocation count on the deployment itself, not on the endpoint.
    assertOkOrCreated(updateMlNodeDeploymemnt(deploymentId, 2));

    // Reading the endpoint back must now show the new count; all other settings stay as they were.
    var refreshedSettings = getModel(inferenceId).get("service_settings");
    Map<String, Object> expectedRefreshedSettings = Map.of(
        "num_allocations",
        2,
        "num_threads",
        1,
        "model_id",
        "update_num_allocations",
        "deployment_id",
        "update_num_allocations"
    );
    assertThat(refreshedSettings.toString(), refreshedSettings, is(expectedRefreshedSettings));
}

private String endpointConfig(String deploymentId) {
return Strings.format("""
{
Expand Down Expand Up @@ -147,6 +196,20 @@ private Response startMlNodeDeploymemnt(String modelId, String deploymentId) thr
return client().performRequest(request);
}

/**
 * Posts an update for the given deployment that sets {@code number_of_allocations}
 * to the supplied value.
 *
 * @param deploymentId   id placed in the {@code /_ml/trained_models/{id}/deployment/_update} path
 * @param numAllocations value written into the request body
 * @return the response from the REST client
 * @throws IOException if the request fails
 */
private Response updateMlNodeDeploymemnt(String deploymentId, int numAllocations) throws IOException {
    var updateRequest = new Request("POST", "/_ml/trained_models/" + deploymentId + "/deployment/_update");
    updateRequest.setJsonEntity(Strings.format("""
        {
        "number_of_allocations": %d
        }
        """, numAllocations));
    return client().performRequest(updateRequest);
}

protected void stopMlNodeDeployment(String deploymentId) throws IOException {
String endpoint = "/_ml/trained_models/" + deploymentId + "/deployment/_stop";
Request request = new Request("POST", endpoint);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import java.util.stream.Stream;

import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.equalToIgnoringCase;
import static org.hamcrest.Matchers.hasSize;
Expand Down Expand Up @@ -326,4 +327,9 @@ public void testSupportedStream() throws Exception {
deleteModel(modelId);
}
}

/**
 * Fetches every endpoint of the RERANK task type and expects an empty
 * collection back.
 */
public void testGetZeroModels() throws IOException {
    assertThat(getModels("_all", TaskType.RERANK), empty());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,11 @@ private void getModelsByTaskType(TaskType taskType, ActionListener<GetInferenceM
}

private void parseModels(List<UnparsedModel> unparsedModels, ActionListener<GetInferenceModelAction.Response> listener) {
if (unparsedModels.isEmpty()) {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Without this check, the GroupedActionListener was called with zero requests, which throws an exception.

listener.onResponse(new GetInferenceModelAction.Response(List.of()));
return;
}

var parsedModelsByService = new HashMap<String, List<Model>>();
try {
for (var unparsedModel : unparsedModels) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,12 +92,7 @@ public ElasticsearchInternalServiceSettings getServiceSettings() {
}

/**
 * Updates the number of allocations held in this model's service settings.
 *
 * The existing settings object is modified in place. Rebuilding it through the
 * four-argument constructor (allocations, threads, model id, adaptive
 * allocations) would discard any setting not passed to that constructor, and
 * is redundant when the setter is applied anyway.
 *
 * @param numAllocations the new number of allocations; passed through unchanged
 */
public void updateNumAllocation(Integer numAllocations) {
    this.internalServiceSettings.setNumAllocations(numAllocations);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -781,7 +781,7 @@ public void updateModelsWithDynamicFields(List<Model> models, ActionListener<Lis
var modelsByDeploymentIds = new HashMap<String, ElasticsearchInternalModel>();
for (var model : models) {
if (model instanceof ElasticsearchInternalModel esModel) {
modelsByDeploymentIds.put(esModel.internalServiceSettings.deloymentId(), esModel);
modelsByDeploymentIds.put(esModel.mlNodeDeploymentId(), esModel);
} else {
listener.onFailure(
new ElasticsearchStatusException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public class ElasticsearchInternalServiceSettings implements ServiceSettings {
public static final String DEPLOYMENT_ID = "deployment_id";
public static final String ADAPTIVE_ALLOCATIONS = "adaptive_allocations";

private final Integer numAllocations;
private Integer numAllocations;
private final int numThreads;
private final String modelId;
private final AdaptiveAllocationsSettings adaptiveAllocationsSettings;
Expand Down Expand Up @@ -172,6 +172,10 @@ public ElasticsearchInternalServiceSettings(StreamInput in) throws IOException {
: null;
}

/**
 * Replaces the number of allocations stored in these settings.
 *
 * @param numAllocations the new value, assigned as-is without validation
 */
public void setNumAllocations(Integer numAllocations) {
this.numAllocations = numAllocations;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_ADAPTIVE_ALLOCATIONS)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ public void testUpdateNumAllocation() {
);

model.updateNumAllocation(1);
assertEquals(1, model.internalServiceSettings.getNumAllocations().intValue());
assertEquals(1, model.getServiceSettings().getNumAllocations().intValue());

model.updateNumAllocation(null);
assertNull(model.internalServiceSettings.getNumAllocations());
assertNull(model.getServiceSettings().getNumAllocations());
}
}