1717package org .springframework .ai .transformer .splitter ;
1818
1919import java .util .ArrayList ;
20+ import java .util .Collections ;
2021import java .util .List ;
2122
2223import com .knuddels .jtokkit .Encodings ;
@@ -46,6 +47,8 @@ public class TokenTextSplitter extends TextSplitter {
4647
4748 private static final boolean KEEP_SEPARATOR = true ;
4849
50+ private static final List <Character > DEFAULT_PUNCTUATIONS = List .of ('.' , '?' , '!' , '。' , '?' , '!' , '\n' );
51+
4952 private final EncodingRegistry registry = Encodings .newLazyEncodingRegistry ();
5053
5154 private final Encoding encoding = this .registry .getEncoding (EncodingType .CL100K_BASE );
@@ -64,21 +67,24 @@ public class TokenTextSplitter extends TextSplitter {
6467
6568 private final boolean keepSeparator ;
6669
70+ private final List <Character > punctuations ;
71+
6772 public TokenTextSplitter () {
68- this (DEFAULT_CHUNK_SIZE , MIN_CHUNK_SIZE_CHARS , MIN_CHUNK_LENGTH_TO_EMBED , MAX_NUM_CHUNKS , KEEP_SEPARATOR );
73+ this (DEFAULT_CHUNK_SIZE , MIN_CHUNK_SIZE_CHARS , MIN_CHUNK_LENGTH_TO_EMBED , MAX_NUM_CHUNKS , KEEP_SEPARATOR , DEFAULT_PUNCTUATIONS );
6974 }
7075
7176 public TokenTextSplitter (boolean keepSeparator ) {
72- this (DEFAULT_CHUNK_SIZE , MIN_CHUNK_SIZE_CHARS , MIN_CHUNK_LENGTH_TO_EMBED , MAX_NUM_CHUNKS , keepSeparator );
77+ this (DEFAULT_CHUNK_SIZE , MIN_CHUNK_SIZE_CHARS , MIN_CHUNK_LENGTH_TO_EMBED , MAX_NUM_CHUNKS , keepSeparator , DEFAULT_PUNCTUATIONS );
7378 }
7479
7580 public TokenTextSplitter (int chunkSize , int minChunkSizeChars , int minChunkLengthToEmbed , int maxNumChunks ,
76- boolean keepSeparator ) {
81+ boolean keepSeparator , List < Character > punctuations ) {
7782 this .chunkSize = chunkSize ;
7883 this .minChunkSizeChars = minChunkSizeChars ;
7984 this .minChunkLengthToEmbed = minChunkLengthToEmbed ;
8085 this .maxNumChunks = maxNumChunks ;
8186 this .keepSeparator = keepSeparator ;
87+ this .punctuations = punctuations ;
8288 }
8389
8490 public static Builder builder () {
@@ -124,8 +130,10 @@ protected List<String> doSplit(String text, int chunkSize) {
124130 // This prevents unnecessary splitting of small texts
125131 if (tokens .size () > chunkSize ) {
126132 // Find the last period or punctuation mark in the chunk
127- int lastPunctuation = Math .max (chunkText .lastIndexOf ('.' ), Math .max (chunkText .lastIndexOf ('?' ),
128- Math .max (chunkText .lastIndexOf ('!' ), chunkText .lastIndexOf ('\n' ))));
133+ int lastPunctuation = -1 ;
134+ for (char punctuation : punctuations ) {
135+ lastPunctuation = Math .max (lastPunctuation , chunkText .lastIndexOf (punctuation ));
136+ }
129137
130138 if (lastPunctuation != -1 && lastPunctuation > this .minChunkSizeChars ) {
131139 // Truncate the chunk text at the punctuation mark
@@ -180,6 +188,8 @@ public static final class Builder {
180188
181189 private boolean keepSeparator = KEEP_SEPARATOR ;
182190
191+ private List <Character > punctuations = DEFAULT_PUNCTUATIONS ;
192+
183193 private Builder () {
184194 }
185195
@@ -208,9 +218,18 @@ public Builder withKeepSeparator(boolean keepSeparator) {
208218 return this ;
209219 }
210220
221+ public Builder withPunctuations (char ... punctuations ) {
222+ List <Character > list = new ArrayList <>();
223+ for (char punctuation : punctuations ) {
224+ list .add (punctuation );
225+ }
226+ this .punctuations = Collections .unmodifiableList (list );
227+ return this ;
228+ }
229+
211230 public TokenTextSplitter build () {
212231 return new TokenTextSplitter (this .chunkSize , this .minChunkSizeChars , this .minChunkLengthToEmbed ,
213- this .maxNumChunks , this .keepSeparator );
232+ this .maxNumChunks , this .keepSeparator , this . punctuations );
214233 }
215234
216235 }
0 commit comments