diff --git a/spring-ai-commons/src/main/java/org/springframework/ai/transformer/splitter/TokenTextSplitter.java b/spring-ai-commons/src/main/java/org/springframework/ai/transformer/splitter/TokenTextSplitter.java index 9ff4f5ccffa..73cdeae648f 100644 --- a/spring-ai-commons/src/main/java/org/springframework/ai/transformer/splitter/TokenTextSplitter.java +++ b/spring-ai-commons/src/main/java/org/springframework/ai/transformer/splitter/TokenTextSplitter.java @@ -17,6 +17,7 @@ package org.springframework.ai.transformer.splitter; import java.util.ArrayList; +import java.util.Collections; import java.util.List; import com.knuddels.jtokkit.Encodings; @@ -46,6 +47,8 @@ public class TokenTextSplitter extends TextSplitter { private static final boolean KEEP_SEPARATOR = true; + private static final List DEFAULT_PUNCTUATIONS = List.of('.', '?', '!', '。', '?', '!', '\n'); + private final EncodingRegistry registry = Encodings.newLazyEncodingRegistry(); private final Encoding encoding = this.registry.getEncoding(EncodingType.CL100K_BASE); @@ -64,21 +67,24 @@ public class TokenTextSplitter extends TextSplitter { private final boolean keepSeparator; + private final List punctuations; + public TokenTextSplitter() { - this(DEFAULT_CHUNK_SIZE, MIN_CHUNK_SIZE_CHARS, MIN_CHUNK_LENGTH_TO_EMBED, MAX_NUM_CHUNKS, KEEP_SEPARATOR); + this(DEFAULT_CHUNK_SIZE, MIN_CHUNK_SIZE_CHARS, MIN_CHUNK_LENGTH_TO_EMBED, MAX_NUM_CHUNKS, KEEP_SEPARATOR, DEFAULT_PUNCTUATIONS); } public TokenTextSplitter(boolean keepSeparator) { - this(DEFAULT_CHUNK_SIZE, MIN_CHUNK_SIZE_CHARS, MIN_CHUNK_LENGTH_TO_EMBED, MAX_NUM_CHUNKS, keepSeparator); + this(DEFAULT_CHUNK_SIZE, MIN_CHUNK_SIZE_CHARS, MIN_CHUNK_LENGTH_TO_EMBED, MAX_NUM_CHUNKS, keepSeparator, DEFAULT_PUNCTUATIONS); } public TokenTextSplitter(int chunkSize, int minChunkSizeChars, int minChunkLengthToEmbed, int maxNumChunks, - boolean keepSeparator) { + boolean keepSeparator, List punctuations) { this.chunkSize = chunkSize; this.minChunkSizeChars = minChunkSizeChars; this.minChunkLengthToEmbed = minChunkLengthToEmbed; this.maxNumChunks = maxNumChunks; this.keepSeparator = keepSeparator; + this.punctuations = punctuations; } public static Builder builder() { @@ -124,8 +130,10 @@ protected List doSplit(String text, int chunkSize) { // This prevents unnecessary splitting of small texts if (tokens.size() > chunkSize) { // Find the last period or punctuation mark in the chunk - int lastPunctuation = Math.max(chunkText.lastIndexOf('.'), Math.max(chunkText.lastIndexOf('?'), - Math.max(chunkText.lastIndexOf('!'), chunkText.lastIndexOf('\n')))); + int lastPunctuation = -1; + for (char punctuation : punctuations) { + lastPunctuation = Math.max(lastPunctuation, chunkText.lastIndexOf(punctuation)); + } if (lastPunctuation != -1 && lastPunctuation > this.minChunkSizeChars) { // Truncate the chunk text at the punctuation mark @@ -180,6 +188,8 @@ public static final class Builder { private boolean keepSeparator = KEEP_SEPARATOR; + private List punctuations = DEFAULT_PUNCTUATIONS; + private Builder() { } @@ -208,9 +218,18 @@ public Builder withKeepSeparator(boolean keepSeparator) { return this; } + public Builder withPunctuations(char... punctuations) { + List list = new ArrayList<>(); + for (char punctuation : punctuations) { + list.add(punctuation); + } + this.punctuations = Collections.unmodifiableList(list); + return this; + } + public TokenTextSplitter build() { return new TokenTextSplitter(this.chunkSize, this.minChunkSizeChars, this.minChunkLengthToEmbed, - this.maxNumChunks, this.keepSeparator); + this.maxNumChunks, this.keepSeparator, this.punctuations); } } diff --git a/spring-ai-commons/src/test/java/org/springframework/ai/transformer/splitter/TokenTextSplitterTest.java b/spring-ai-commons/src/test/java/org/springframework/ai/transformer/splitter/TokenTextSplitterTest.java index 8076135cfb2..b7711155dda 100644 --- a/spring-ai-commons/src/test/java/org/springframework/ai/transformer/splitter/TokenTextSplitterTest.java +++ b/spring-ai-commons/src/test/java/org/springframework/ai/transformer/splitter/TokenTextSplitterTest.java @@ -165,4 +165,26 @@ public void testLargeTextStillSplitsAtPunctuation() { assertThat(splitted.get(0).getText()).endsWith("."); } + @Test + public void testLargeTextStillSplitsAtChinesePunctuation() { + // Verify that punctuation-based splitting still works when text exceeds chunk + // size + TokenTextSplitter splitter = TokenTextSplitter.builder() + .withKeepSeparator(true) + .withChunkSize(15) + .withMinChunkSizeChars(10) + .build(); + + // This text has multiple sentences and will exceed 15 tokens + Document testDoc = new Document( + "This is the first sentence with enough words? This is the second sentence! And this is the third sentence。"); + List splitted = splitter.split(testDoc); + + // Should split into multiple chunks at punctuation marks + assertThat(splitted.size()).isGreaterThan(1); + + // Verify first chunk ends with punctuation + assertThat(splitted.get(0).getText()).endsWith("!"); + } + }