Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add upper and lower max chunk size limits to ChunkingSettings #115130

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
public class SentenceBoundaryChunkingSettings implements ChunkingSettings {
public static final String NAME = "SentenceBoundaryChunkingSettings";
private static final ChunkingStrategy STRATEGY = ChunkingStrategy.SENTENCE;
private static final int MAX_CHUNK_SIZE_LOWER_LIMIT = 20;
private static final int MAX_CHUNK_SIZE_UPPER_LIMIT = 300;
private static final Set<String> VALID_KEYS = Set.of(
ChunkingSettingsOptions.STRATEGY.toString(),
ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(),
Expand Down Expand Up @@ -62,9 +64,11 @@ public static SentenceBoundaryChunkingSettings fromMap(Map<String, Object> map)
);
}

Integer maxChunkSize = ServiceUtils.extractRequiredPositiveInteger(
Integer maxChunkSize = ServiceUtils.extractRequiredPositiveIntegerBetween(
map,
ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(),
MAX_CHUNK_SIZE_LOWER_LIMIT,
MAX_CHUNK_SIZE_UPPER_LIMIT,
ModelConfigurations.CHUNKING_SETTINGS,
validationException
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
public class WordBoundaryChunkingSettings implements ChunkingSettings {
public static final String NAME = "WordBoundaryChunkingSettings";
private static final ChunkingStrategy STRATEGY = ChunkingStrategy.WORD;
private static final int MAX_CHUNK_SIZE_LOWER_LIMIT = 10;
private static final int MAX_CHUNK_SIZE_UPPER_LIMIT = 300;
private static final Set<String> VALID_KEYS = Set.of(
ChunkingSettingsOptions.STRATEGY.toString(),
ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(),
Expand Down Expand Up @@ -56,9 +58,11 @@ public static WordBoundaryChunkingSettings fromMap(Map<String, Object> map) {
);
}

Integer maxChunkSize = ServiceUtils.extractRequiredPositiveInteger(
Integer maxChunkSize = ServiceUtils.extractRequiredPositiveIntegerBetween(
map,
ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(),
MAX_CHUNK_SIZE_LOWER_LIMIT,
MAX_CHUNK_SIZE_UPPER_LIMIT,
ModelConfigurations.CHUNKING_SETTINGS,
validationException
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,30 @@ public static Integer extractRequiredPositiveIntegerLessThanOrEqualToMax(
return field;
}

/**
 * Extracts a required positive integer setting and checks it against the inclusive
 * range {@code [minValue, maxValue]}.
 * <p>
 * The lookup itself is delegated to {@code extractRequiredPositiveInteger}. When that
 * yields a value, a validation error is recorded for each bound the value violates.
 * The value is still returned in that case; callers are expected to inspect
 * {@code validationException} to learn about the failure.
 *
 * @param map                 settings map the value is read from
 * @param settingName         key of the setting to extract
 * @param minValue            smallest accepted value (inclusive)
 * @param maxValue            largest accepted value (inclusive)
 * @param scope               scope name used when building error messages
 * @param validationException collector that receives any validation errors
 * @return the extracted value, or {@code null} when the delegate returned {@code null}
 */
public static Integer extractRequiredPositiveIntegerBetween(
    Map<String, Object> map,
    String settingName,
    int minValue,
    int maxValue,
    String scope,
    ValidationException validationException
) {
    Integer extracted = extractRequiredPositiveInteger(map, settingName, scope, validationException);
    if (extracted == null) {
        // Nothing to range-check; no further errors are added here.
        return null;
    }

    // The two bounds are checked independently on purpose: both errors are reported
    // if a caller passes an inverted range.
    boolean belowMinimum = extracted < minValue;
    if (belowMinimum) {
        validationException.addValidationError(
            ServiceUtils.mustBeGreaterThanOrEqualNumberErrorMessage(settingName, scope, extracted, minValue)
        );
    }

    boolean aboveMaximum = extracted > maxValue;
    if (aboveMaximum) {
        validationException.addValidationError(
            ServiceUtils.mustBeLessThanOrEqualNumberErrorMessage(settingName, scope, extracted, maxValue)
        );
    }

    return extracted;
}

public static Integer extractOptionalPositiveInteger(
Map<String, Object> map,
String settingName,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,25 +38,27 @@ public void testValidChunkingSettingsMap() {
}

private Map<Map<String, Object>, ChunkingSettings> chunkingSettingsMapToChunkingSettings() {
var maxChunkSize = randomNonNegativeInt();
var overlap = randomIntBetween(1, maxChunkSize / 2);
var maxChunkSizeWordBoundaryChunkingSettings = randomIntBetween(10, 300);
var overlap = randomIntBetween(1, maxChunkSizeWordBoundaryChunkingSettings / 2);
var maxChunkSizeSentenceBoundaryChunkingSettings = randomIntBetween(20, 300);

return Map.of(
Map.of(
ChunkingSettingsOptions.STRATEGY.toString(),
ChunkingStrategy.WORD.toString(),
ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(),
maxChunkSize,
maxChunkSizeWordBoundaryChunkingSettings,
ChunkingSettingsOptions.OVERLAP.toString(),
overlap
),
new WordBoundaryChunkingSettings(maxChunkSize, overlap),
new WordBoundaryChunkingSettings(maxChunkSizeWordBoundaryChunkingSettings, overlap),
Map.of(
ChunkingSettingsOptions.STRATEGY.toString(),
ChunkingStrategy.SENTENCE.toString(),
ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(),
maxChunkSize
maxChunkSizeSentenceBoundaryChunkingSettings
),
new SentenceBoundaryChunkingSettings(maxChunkSize, 1)
new SentenceBoundaryChunkingSettings(maxChunkSizeSentenceBoundaryChunkingSettings, 1)
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@ public static ChunkingSettings createRandomChunkingSettings() {

switch (randomStrategy) {
case WORD -> {
var maxChunkSize = randomNonNegativeInt();
var maxChunkSize = randomIntBetween(10, 300);
return new WordBoundaryChunkingSettings(maxChunkSize, randomIntBetween(1, maxChunkSize / 2));
}
case SENTENCE -> {
return new SentenceBoundaryChunkingSettings(randomNonNegativeInt(), randomBoolean() ? 0 : 1);
return new SentenceBoundaryChunkingSettings(randomIntBetween(20, 300), randomBoolean() ? 0 : 1);
}
default -> throw new IllegalArgumentException("Unsupported random strategy [" + randomStrategy + "]");
}
Expand All @@ -38,13 +38,13 @@ public static Map<String, Object> createRandomChunkingSettingsMap() {

switch (randomStrategy) {
case WORD -> {
var maxChunkSize = randomNonNegativeInt();
var maxChunkSize = randomIntBetween(10, 300);
chunkingSettingsMap.put(ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(), maxChunkSize);
chunkingSettingsMap.put(ChunkingSettingsOptions.OVERLAP.toString(), randomIntBetween(1, maxChunkSize / 2));

}
case SENTENCE -> {
chunkingSettingsMap.put(ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(), randomNonNegativeInt());
chunkingSettingsMap.put(ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(), randomIntBetween(20, 300));
}
default -> {
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,8 @@ public void testChunkSplitLargeChunkSizesWithChunkingSettings() {
}

public void testInvalidChunkingSettingsProvided() {
ChunkingSettings chunkingSettings = new WordBoundaryChunkingSettings(randomNonNegativeInt(), randomNonNegativeInt());
var maxChunkSize = randomIntBetween(10, 300);
ChunkingSettings chunkingSettings = new WordBoundaryChunkingSettings(maxChunkSize, randomIntBetween(1, maxChunkSize / 2));
assertThrows(IllegalArgumentException.class, () -> { new SentenceBoundaryChunker().chunk(TEST_TEXT, chunkingSettings); });
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.inference.ChunkingStrategy;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.test.ESTestCase;

import java.io.IOException;
import java.util.HashMap;
Expand All @@ -28,14 +27,14 @@ public void testMaxChunkSizeNotProvided() {
}

public void testInvalidInputsProvided() {
var chunkingSettingsMap = buildChunkingSettingsMap(Optional.of(randomNonNegativeInt()));
var chunkingSettingsMap = buildChunkingSettingsMap(Optional.of(randomIntBetween(20, 300)));
chunkingSettingsMap.put(randomAlphaOfLength(10), randomNonNegativeInt());

assertThrows(ValidationException.class, () -> { SentenceBoundaryChunkingSettings.fromMap(chunkingSettingsMap); });
}

public void testValidInputsProvided() {
int maxChunkSize = randomNonNegativeInt();
int maxChunkSize = randomIntBetween(20, 300);
SentenceBoundaryChunkingSettings settings = SentenceBoundaryChunkingSettings.fromMap(
buildChunkingSettingsMap(Optional.of(maxChunkSize))
);
Expand All @@ -59,12 +58,12 @@ protected Writeable.Reader<SentenceBoundaryChunkingSettings> instanceReader() {

@Override
protected SentenceBoundaryChunkingSettings createTestInstance() {
return new SentenceBoundaryChunkingSettings(randomNonNegativeInt(), randomBoolean() ? 0 : 1);
return new SentenceBoundaryChunkingSettings(randomIntBetween(20, 300), randomBoolean() ? 0 : 1);
}

@Override
protected SentenceBoundaryChunkingSettings mutateInstance(SentenceBoundaryChunkingSettings instance) throws IOException {
var chunkSize = randomValueOtherThan(instance.maxChunkSize, ESTestCase::randomNonNegativeInt);
var chunkSize = randomValueOtherThan(instance.maxChunkSize, () -> randomIntBetween(20, 300));
return new SentenceBoundaryChunkingSettings(chunkSize, instance.sentenceOverlap);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ public void testNumberOfChunksWithWordBoundaryChunkingSettings() {
}

public void testInvalidChunkingSettingsProvided() {
ChunkingSettings chunkingSettings = new SentenceBoundaryChunkingSettings(randomNonNegativeInt(), 0);
ChunkingSettings chunkingSettings = new SentenceBoundaryChunkingSettings(randomIntBetween(20, 300), 0);
assertThrows(IllegalArgumentException.class, () -> { new WordBoundaryChunker().chunk(TEST_TEXT, chunkingSettings); });
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;

Expand All @@ -28,27 +27,28 @@ public void testMaxChunkSizeNotProvided() {

public void testOverlapNotProvided() {
assertThrows(ValidationException.class, () -> {
WordBoundaryChunkingSettings.fromMap(buildChunkingSettingsMap(Optional.of(randomNonNegativeInt()), Optional.empty()));
WordBoundaryChunkingSettings.fromMap(buildChunkingSettingsMap(Optional.of(randomIntBetween(10, 300)), Optional.empty()));
});
}

public void testInvalidInputsProvided() {
var chunkingSettingsMap = buildChunkingSettingsMap(Optional.of(randomNonNegativeInt()), Optional.of(randomNonNegativeInt()));
var maxChunkSize = randomIntBetween(10, 300);
var chunkingSettingsMap = buildChunkingSettingsMap(Optional.of(maxChunkSize), Optional.of(randomIntBetween(1, maxChunkSize / 2)));
chunkingSettingsMap.put(randomAlphaOfLength(10), randomNonNegativeInt());

assertThrows(ValidationException.class, () -> { WordBoundaryChunkingSettings.fromMap(chunkingSettingsMap); });
}

public void testOverlapGreaterThanHalfMaxChunkSize() {
var maxChunkSize = randomNonNegativeInt();
var maxChunkSize = randomIntBetween(10, 300);
var overlap = randomIntBetween((maxChunkSize / 2) + 1, maxChunkSize);
assertThrows(ValidationException.class, () -> {
WordBoundaryChunkingSettings.fromMap(buildChunkingSettingsMap(Optional.of(maxChunkSize), Optional.of(overlap)));
});
}

public void testValidInputsProvided() {
int maxChunkSize = randomNonNegativeInt();
int maxChunkSize = randomIntBetween(10, 300);
int overlap = randomIntBetween(1, maxChunkSize / 2);
WordBoundaryChunkingSettings settings = WordBoundaryChunkingSettings.fromMap(
buildChunkingSettingsMap(Optional.of(maxChunkSize), Optional.of(overlap))
Expand All @@ -75,29 +75,14 @@ protected Writeable.Reader<WordBoundaryChunkingSettings> instanceReader() {

@Override
protected WordBoundaryChunkingSettings createTestInstance() {
var maxChunkSize = randomNonNegativeInt();
var maxChunkSize = randomIntBetween(10, 300);
return new WordBoundaryChunkingSettings(maxChunkSize, randomIntBetween(1, maxChunkSize / 2));
}

@Override
protected WordBoundaryChunkingSettings mutateInstance(WordBoundaryChunkingSettings instance) throws IOException {
var valueToMutate = randomFrom(List.of(ChunkingSettingsOptions.MAX_CHUNK_SIZE, ChunkingSettingsOptions.OVERLAP));
var maxChunkSize = instance.maxChunkSize;
var overlap = instance.overlap;

if (valueToMutate.equals(ChunkingSettingsOptions.MAX_CHUNK_SIZE)) {
while (maxChunkSize == instance.maxChunkSize) {
maxChunkSize = randomNonNegativeInt();
}

if (overlap > maxChunkSize / 2) {
overlap = randomIntBetween(1, maxChunkSize / 2);
}
} else if (valueToMutate.equals(ChunkingSettingsOptions.OVERLAP)) {
while (overlap == instance.overlap) {
overlap = randomIntBetween(1, maxChunkSize / 2);
}
}
var maxChunkSize = randomValueOtherThan(instance.maxChunkSize, () -> randomIntBetween(10, 300));
var overlap = randomValueOtherThan(instance.overlap, () -> randomIntBetween(1, maxChunkSize / 2));

return new WordBoundaryChunkingSettings(maxChunkSize, overlap);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -605,6 +605,60 @@ public void testExtractRequiredPositiveIntegerLessThanOrEqualToMax_AddsErrorWhen
assertThat(validation.validationErrors().get(1), is("[scope] does not contain the required setting [not_key]"));
}

/**
 * A value strictly inside the range is returned and adds no validation error.
 */
public void testExtractRequiredPositiveIntegerBetween_ReturnsValueWhenValueIsBetweenMinAndMax() {
    // Cap the lower bound so that "minValue + 10" cannot overflow int and produce an inverted range.
    var minValue = randomIntBetween(1, Integer.MAX_VALUE - 10);
    var maxValue = randomIntBetween(minValue + 2, minValue + 10);
    testExtractRequiredPositiveIntegerBetween_Successful(minValue, maxValue, randomIntBetween(minValue + 1, maxValue - 1));
}

/**
 * A value equal to the lower bound is accepted (the range is inclusive).
 */
public void testExtractRequiredPositiveIntegerBetween_ReturnsValueWhenValueIsEqualToMin() {
    // The extracted value equals minValue, so it must be strictly positive; 0 would be
    // rejected by the positive-integer check before the range is even considered.
    // The upper cap keeps "minValue + 10" from overflowing int.
    var minValue = randomIntBetween(1, Integer.MAX_VALUE - 10);
    var maxValue = randomIntBetween(minValue + 1, minValue + 10);
    testExtractRequiredPositiveIntegerBetween_Successful(minValue, maxValue, minValue);
}

/**
 * A value equal to the upper bound is accepted (the range is inclusive).
 */
public void testExtractRequiredPositiveIntegerBetween_ReturnsValueWhenValueIsEqualToMax() {
    // Cap the lower bound so that "minValue + 10" cannot overflow int and yield a negative maxValue.
    var minValue = randomIntBetween(1, Integer.MAX_VALUE - 10);
    var maxValue = randomIntBetween(minValue + 1, minValue + 10);
    testExtractRequiredPositiveIntegerBetween_Successful(minValue, maxValue, maxValue);
}

/**
 * Shared assertion for the in-range cases: extracting {@code actualValue} under the bounds
 * {@code [minValue, maxValue]} must return that value, consume the key from the map and
 * leave the pre-seeded validation error as the only recorded error.
 */
private void testExtractRequiredPositiveIntegerBetween_Successful(int minValue, int maxValue, int actualValue) {
    var errors = new ValidationException();
    // Seed one error so an unexpected reset of the collector would be noticed.
    errors.addValidationError("previous error");
    Map<String, Object> settings = modifiableMap(Map.of("key", actualValue));

    var result = ServiceUtils.extractRequiredPositiveIntegerBetween(settings, "key", minValue, maxValue, "scope", errors);

    assertThat(errors.validationErrors(), hasSize(1));
    assertNotNull(result);
    assertThat(result, is(actualValue));
    assertTrue(settings.isEmpty());
}

/**
 * A positive value just below the lower bound is the input for the unsuccessful-extraction helper.
 */
public void testExtractRequiredIntBetween_AddsErrorForValueBelowMin() {
    // minValue >= 2 keeps "minValue - 1" strictly positive, so the value is a legal positive
    // integer that is only wrong with respect to the range. The upper cap prevents
    // "minValue + 10" from overflowing int.
    var minValue = randomIntBetween(2, Integer.MAX_VALUE - 10);
    var maxValue = randomIntBetween(minValue, minValue + 10);
    testExtractRequiredIntBetween_Unsuccessful(minValue, maxValue, minValue - 1);
}

/**
 * A positive value just above the upper bound is the input for the unsuccessful-extraction helper.
 */
public void testExtractRequiredIntBetween_AddsErrorForValueAboveMax() {
    // Positive bounds keep the positivity check out of the picture; the cap guarantees that
    // neither "minValue + 10" nor "maxValue + 1" overflows int.
    var minValue = randomIntBetween(1, Integer.MAX_VALUE - 11);
    var maxValue = randomIntBetween(minValue, minValue + 10);
    testExtractRequiredIntBetween_Unsuccessful(minValue, maxValue, maxValue + 1);
}

/**
 * Shared assertion for the failing cases: extracting {@code actualValue} under the bounds
 * {@code [minValue, maxValue]} must add exactly one validation error on top of the
 * pre-seeded one and consume the key from the map.
 * <p>
 * Previously this helper ignored its parameters and always extracted the literal 0 with the
 * range 1..5, so only the positivity check was exercised, never the range checks.
 *
 * @param minValue    lower bound handed to the extractor
 * @param maxValue    upper bound handed to the extractor
 * @param actualValue value placed in the map; expected to be rejected
 */
private void testExtractRequiredIntBetween_Unsuccessful(int minValue, int maxValue, int actualValue) {
    var validation = new ValidationException();
    validation.addValidationError("previous error");
    Map<String, Object> map = modifiableMap(Map.of("key", actualValue));
    var parsedInt = ServiceUtils.extractRequiredPositiveIntegerBetween(map, "key", minValue, maxValue, "scope", validation);

    assertThat(validation.validationErrors(), hasSize(2));
    if (actualValue > 0) {
        // A range violation is reported through the validation exception, but the extractor
        // still hands back the value it read.
        assertNotNull(parsedInt);
        assertThat(parsedInt, is(actualValue));
    } else {
        // Non-positive input is rejected by the positive-integer extraction itself.
        assertNull(parsedInt);
    }
    assertTrue(map.isEmpty());
    // NOTE(review): assumes the greater-than/less-than range messages also contain
    // "Invalid value", like the positive-integer message does — confirm against ServiceUtils.
    assertThat(validation.validationErrors().get(1), containsString("Invalid value"));
}

public void testExtractOptionalEnum_ReturnsNull_WhenFieldDoesNotExist() {
var validation = new ValidationException();
Map<String, Object> map = modifiableMap(Map.of("key", "value"));
Expand Down