Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package org.springframework.ai.transformer.splitter;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

import com.knuddels.jtokkit.Encodings;
Expand Down Expand Up @@ -46,6 +47,8 @@ public class TokenTextSplitter extends TextSplitter {

private static final boolean KEEP_SEPARATOR = true;

private static final List<Character> DEFAULT_PUNCTUATIONS = List.of('.', '?', '!', '。', '?', '!', '\n');

private final EncodingRegistry registry = Encodings.newLazyEncodingRegistry();

private final Encoding encoding = this.registry.getEncoding(EncodingType.CL100K_BASE);
Expand All @@ -64,21 +67,24 @@ public class TokenTextSplitter extends TextSplitter {

private final boolean keepSeparator;

private final List<Character> punctuations;

public TokenTextSplitter() {
this(DEFAULT_CHUNK_SIZE, MIN_CHUNK_SIZE_CHARS, MIN_CHUNK_LENGTH_TO_EMBED, MAX_NUM_CHUNKS, KEEP_SEPARATOR);
this(DEFAULT_CHUNK_SIZE, MIN_CHUNK_SIZE_CHARS, MIN_CHUNK_LENGTH_TO_EMBED, MAX_NUM_CHUNKS, KEEP_SEPARATOR, DEFAULT_PUNCTUATIONS);
}

public TokenTextSplitter(boolean keepSeparator) {
this(DEFAULT_CHUNK_SIZE, MIN_CHUNK_SIZE_CHARS, MIN_CHUNK_LENGTH_TO_EMBED, MAX_NUM_CHUNKS, keepSeparator);
this(DEFAULT_CHUNK_SIZE, MIN_CHUNK_SIZE_CHARS, MIN_CHUNK_LENGTH_TO_EMBED, MAX_NUM_CHUNKS, keepSeparator, DEFAULT_PUNCTUATIONS);
}

public TokenTextSplitter(int chunkSize, int minChunkSizeChars, int minChunkLengthToEmbed, int maxNumChunks,
boolean keepSeparator) {
boolean keepSeparator, List<Character> punctuations) {
this.chunkSize = chunkSize;
this.minChunkSizeChars = minChunkSizeChars;
this.minChunkLengthToEmbed = minChunkLengthToEmbed;
this.maxNumChunks = maxNumChunks;
this.keepSeparator = keepSeparator;
this.punctuations = punctuations;
}

public static Builder builder() {
Expand Down Expand Up @@ -124,8 +130,10 @@ protected List<String> doSplit(String text, int chunkSize) {
// This prevents unnecessary splitting of small texts
if (tokens.size() > chunkSize) {
// Find the last period or punctuation mark in the chunk
int lastPunctuation = Math.max(chunkText.lastIndexOf('.'), Math.max(chunkText.lastIndexOf('?'),
Math.max(chunkText.lastIndexOf('!'), chunkText.lastIndexOf('\n'))));
int lastPunctuation = -1;
for (char punctuation : punctuations) {
lastPunctuation = Math.max(lastPunctuation, chunkText.lastIndexOf(punctuation));
}

if (lastPunctuation != -1 && lastPunctuation > this.minChunkSizeChars) {
// Truncate the chunk text at the punctuation mark
Expand Down Expand Up @@ -180,6 +188,8 @@ public static final class Builder {

private boolean keepSeparator = KEEP_SEPARATOR;

private List<Character> punctuations = DEFAULT_PUNCTUATIONS;

private Builder() {
}

Expand Down Expand Up @@ -208,9 +218,18 @@ public Builder withKeepSeparator(boolean keepSeparator) {
return this;
}

public Builder withPunctuations(char... punctuations) {
List<Character> list = new ArrayList<>();
for (char punctuation : punctuations) {
list.add(punctuation);
}
this.punctuations = Collections.unmodifiableList(list);
return this;
}

public TokenTextSplitter build() {
return new TokenTextSplitter(this.chunkSize, this.minChunkSizeChars, this.minChunkLengthToEmbed,
this.maxNumChunks, this.keepSeparator);
this.maxNumChunks, this.keepSeparator, this.punctuations);
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -165,4 +165,26 @@ public void testLargeTextStillSplitsAtPunctuation() {
assertThat(splitted.get(0).getText()).endsWith(".");
}

@Test
public void testLargeTextStillSplitsAtChinesePunctuation() {
// Verify that punctuation-based splitting still works when text exceeds chunk
// size
TokenTextSplitter splitter = TokenTextSplitter.builder()
.withKeepSeparator(true)
.withChunkSize(15)
.withMinChunkSizeChars(10)
.build();

// This text has multiple sentences and will exceed 15 tokens
Document testDoc = new Document(
"This is the first sentence with enough words? This is the second sentence! And this is the third sentence。");
List<Document> splitted = splitter.split(testDoc);

// Should split into multiple chunks at punctuation marks
assertThat(splitted.size()).isGreaterThan(1);

// Verify first chunk ends with punctuation
assertThat(splitted.get(0).getText()).endsWith("!");
}

}