Skip to content

Commit 06ea26f

Browse files
committed
feat(summarization): add extractive summarization for SubEM evaluation
Implement BERT-based extractive summarization that preserves exact entity names, addressing the Long-Range Understanding (LRU) benchmark failure caused by abstractive summarization paraphrasing entity names. New classes in com.redis.vl.extensions.summarization: - ExtractiveSelector: K-means sentence selection using BERT embeddings - SentenceSplitter: OpenNLP-based sentence detection - EmbeddedSentence: Clusterable wrapper for k-means algorithm Algorithm: 1. Split document into sentences using OpenNLP 2. Embed sentences with SentenceTransformers/BERT 3. Cluster using k-means++ (Apache Commons Math3) 4. Select sentence closest to each cluster centroid 5. Return sentences in original order (preserves exact text) Key benefit: Unlike abstractive summarization which paraphrases ("Jennifer" -> "the protagonist"), extractive summarization preserves verbatim text, enabling SubEM matching to succeed. Dependencies added: - org.apache.opennlp:opennlp-tools:2.3.0 - org.apache.commons:commons-math3:3.6.1
1 parent b3ab512 commit 06ea26f

File tree

7 files changed

+490
-0
lines changed

7 files changed

+490
-0
lines changed

core/build.gradle.kts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,12 @@ dependencies {
2525
implementation("com.google.guava:guava:33.4.0-jre")
2626
implementation("com.github.f4b6a3:ulid-creator:5.2.3")
2727

28+
// For extractive summarization - sentence splitting
29+
implementation("org.apache.opennlp:opennlp-tools:2.3.0")
30+
31+
// For k-means clustering in extractive summarization
32+
implementation("org.apache.commons:commons-math3:3.6.1")
33+
2834
// Lombok for reducing boilerplate
2935
compileOnly("org.projectlombok:lombok:1.18.36")
3036
annotationProcessor("org.projectlombok:lombok:1.18.36")
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
package com.redis.vl.extensions.summarization;
2+
3+
import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
4+
import org.apache.commons.math3.ml.clustering.Clusterable;
5+
6+
/** A sentence with its embedding, implementing Clusterable for k-means. */
7+
public class EmbeddedSentence implements Clusterable {
8+
9+
private final int index;
10+
private final double[] embedding;
11+
12+
/**
13+
* Create an embedded sentence.
14+
*
15+
* @param index Original index in the sentence list (for preserving order)
16+
* @param embedding The BERT embedding as float array
17+
*/
18+
public EmbeddedSentence(int index, float[] embedding) {
19+
this.index = index;
20+
this.embedding = toDoubleArray(embedding);
21+
}
22+
23+
private static double[] toDoubleArray(float[] floats) {
24+
double[] doubles = new double[floats.length];
25+
for (int i = 0; i < floats.length; i++) {
26+
doubles[i] = floats[i];
27+
}
28+
return doubles;
29+
}
30+
31+
/** Get the original index of this sentence. */
32+
public int index() {
33+
return index;
34+
}
35+
36+
/** Get the embedding as double array (required by Clusterable). */
37+
@Override
38+
@SuppressFBWarnings(
39+
value = "EI_EXPOSE_REP",
40+
justification = "Clusterable interface requires direct array access for k-means performance")
41+
public double[] getPoint() {
42+
return embedding;
43+
}
44+
45+
/** Calculate cosine similarity with another embedded sentence. */
46+
public double cosineSimilarity(EmbeddedSentence other) {
47+
double dotProduct = 0.0;
48+
double normA = 0.0;
49+
double normB = 0.0;
50+
51+
for (int i = 0; i < embedding.length; i++) {
52+
dotProduct += embedding[i] * other.embedding[i];
53+
normA += embedding[i] * embedding[i];
54+
normB += other.embedding[i] * other.embedding[i];
55+
}
56+
57+
if (normA == 0 || normB == 0) return 0.0;
58+
return dotProduct / (Math.sqrt(normA) * Math.sqrt(normB));
59+
}
60+
}
Lines changed: 209 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,209 @@
1+
package com.redis.vl.extensions.summarization;

import com.redis.vl.utils.vectorize.SentenceTransformersVectorizer;
import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
import java.util.stream.IntStream;
import org.apache.commons.math3.ml.clustering.CentroidCluster;
import org.apache.commons.math3.ml.clustering.KMeansPlusPlusClusterer;

/**
 * BERT-based extractive summarization using sentence clustering.
 *
 * <p>This class selects the most representative sentences from a document by embedding sentences
 * with BERT, clustering them with k-means, and selecting the sentence closest to each cluster
 * centroid.
 *
 * <p><b>Key Feature:</b> Preserves original text exactly, which is critical for SubEM (Substring
 * Exact Match) evaluation where paraphrasing fails.
 *
 * <p><b>Note:</b> {@link KMeansPlusPlusClusterer} seeds its centroids randomly, so the selected
 * sentence set may vary between runs on the same input.
 *
 * <h2>Example Usage:</h2>
 *
 * <pre>{@code
 * SentenceTransformersVectorizer vectorizer = SentenceTransformersVectorizer.builder()
 *     .modelName("all-MiniLM-L6-v2")
 *     .build();
 *
 * ExtractiveSelector selector = new ExtractiveSelector(vectorizer);
 * SentenceSplitter splitter = new SentenceSplitter();
 *
 * String document = "Long document text...";
 * List<String> sentences = splitter.split(document);
 * List<String> keySentences = selector.selectKeySentences(sentences, 10);
 *
 * // keySentences contains the 10 most representative sentences
 * // in their original order, with exact original text preserved
 * }</pre>
 */
public class ExtractiveSelector {

  private final SentenceTransformersVectorizer embedder;
  private final int defaultNumSentences;
  private final int maxIterations;

  /**
   * Create an extractive selector with default settings (10 sentences, 100 k-means iterations).
   *
   * @param embedder The sentence transformer vectorizer for embeddings
   */
  public ExtractiveSelector(SentenceTransformersVectorizer embedder) {
    this(embedder, 10, 100);
  }

  /**
   * Create an extractive selector with custom number of sentences.
   *
   * @param embedder The sentence transformer vectorizer for embeddings
   * @param defaultNumSentences Default number of sentences to select; must be positive
   */
  public ExtractiveSelector(SentenceTransformersVectorizer embedder, int defaultNumSentences) {
    this(embedder, defaultNumSentences, 100);
  }

  /**
   * Create an extractive selector with full configuration.
   *
   * @param embedder The sentence transformer vectorizer for embeddings
   * @param defaultNumSentences Default number of sentences to select; must be positive
   * @param maxIterations Maximum k-means iterations; must be positive
   * @throws NullPointerException if embedder is null
   * @throws IllegalArgumentException if defaultNumSentences or maxIterations is not positive
   */
  @SuppressFBWarnings(
      value = "EI_EXPOSE_REP2",
      justification = "Embedder is intentionally shared; it's a heavyweight resource")
  public ExtractiveSelector(
      SentenceTransformersVectorizer embedder, int defaultNumSentences, int maxIterations) {
    // Fail fast here rather than with an obscure commons-math exception at clustering time.
    this.embedder = Objects.requireNonNull(embedder, "embedder");
    if (defaultNumSentences <= 0) {
      throw new IllegalArgumentException(
          "defaultNumSentences must be positive, got " + defaultNumSentences);
    }
    if (maxIterations <= 0) {
      throw new IllegalArgumentException("maxIterations must be positive, got " + maxIterations);
    }
    this.defaultNumSentences = defaultNumSentences;
    this.maxIterations = maxIterations;
  }

  /**
   * Select the most representative sentences using the default count.
   *
   * @param sentences List of sentences to select from
   * @return Selected sentences in original order
   */
  public List<String> selectKeySentences(List<String> sentences) {
    return selectKeySentences(sentences, defaultNumSentences);
  }

  /**
   * Select the k most representative sentences from the input.
   *
   * <p>Algorithm:
   *
   * <ol>
   *   <li>Embed all sentences using BERT
   *   <li>Cluster embeddings using k-means++
   *   <li>For each cluster, select the sentence closest to the centroid
   *   <li>Return sentences in their original order
   * </ol>
   *
   * @param sentences List of sentences to select from
   * @param k Number of sentences to select; non-positive k yields an empty result
   * @return Selected sentences in original order (preserves exact text)
   */
  public List<String> selectKeySentences(List<String> sentences, int k) {
    // k <= 0 previously reached the clusterer and threw NotStrictlyPositiveException;
    // an empty summary is the sensible result for "select zero sentences".
    if (sentences == null || sentences.isEmpty() || k <= 0) {
      return List.of();
    }

    // If we have fewer sentences than k, return all
    if (sentences.size() <= k) {
      return new ArrayList<>(sentences);
    }

    // Filter out empty/whitespace sentences, remembering each one's original position
    List<IndexedSentence> validSentences =
        IntStream.range(0, sentences.size())
            .filter(i -> sentences.get(i) != null && !sentences.get(i).isBlank())
            .mapToObj(i -> new IndexedSentence(i, sentences.get(i)))
            .toList();

    if (validSentences.size() <= k) {
      return validSentences.stream().map(IndexedSentence::text).toList();
    }

    // 1. Embed all sentences
    List<String> textsToEmbed = validSentences.stream().map(IndexedSentence::text).toList();
    List<float[]> embeddings = embedder.embedSentences(textsToEmbed);

    // 2. Create clusterable points
    List<EmbeddedSentence> points =
        IntStream.range(0, validSentences.size())
            .mapToObj(i -> new EmbeddedSentence(validSentences.get(i).index(), embeddings.get(i)))
            .toList();

    // 3. K-means++ clustering
    KMeansPlusPlusClusterer<EmbeddedSentence> clusterer =
        new KMeansPlusPlusClusterer<>(k, maxIterations);
    List<CentroidCluster<EmbeddedSentence>> clusters = clusterer.cluster(points);

    // 4. Select sentence closest to each cluster centroid
    List<Integer> selectedIndices =
        clusters.stream()
            .map(this::findClosestToCentroid)
            .map(EmbeddedSentence::index)
            .sorted() // Preserve original order
            .toList();

    // 5. Return original sentences (exact text, by original index)
    return selectedIndices.stream().map(sentences::get).toList();
  }

  /** Find the sentence closest to the cluster centroid (Euclidean distance). */
  private EmbeddedSentence findClosestToCentroid(CentroidCluster<EmbeddedSentence> cluster) {
    double[] centroid = cluster.getCenter().getPoint();

    return cluster.getPoints().stream()
        .min(Comparator.comparingDouble(point -> euclideanDistance(point.getPoint(), centroid)))
        .orElseThrow(() -> new IllegalStateException("Empty cluster"));
  }

  /** Calculate Euclidean distance between two points of equal dimension. */
  private double euclideanDistance(double[] a, double[] b) {
    double sum = 0.0;
    for (int i = 0; i < a.length; i++) {
      double diff = a[i] - b[i];
      sum += diff * diff;
    }
    return Math.sqrt(sum);
  }

  /** Helper record pairing a sentence with its original position in the input list. */
  private record IndexedSentence(int index, String text) {}

  /**
   * Create a builder for fluent configuration.
   *
   * @param embedder The sentence transformer vectorizer for embeddings
   * @return A new builder
   */
  public static Builder builder(SentenceTransformersVectorizer embedder) {
    return new Builder(embedder);
  }

  /** Fluent builder for {@link ExtractiveSelector}. */
  public static class Builder {
    private final SentenceTransformersVectorizer embedder;
    private int defaultNumSentences = 10;
    private int maxIterations = 100;

    /**
     * Create a builder.
     *
     * @param embedder The sentence transformer vectorizer for embeddings
     */
    @SuppressFBWarnings(
        value = "EI_EXPOSE_REP2",
        justification = "Embedder is intentionally shared; it's a heavyweight resource")
    public Builder(SentenceTransformersVectorizer embedder) {
      this.embedder = embedder;
    }

    /**
     * Set the default number of sentences to select.
     *
     * @param n Default sentence count; must be positive
     * @return this builder
     */
    public Builder defaultNumSentences(int n) {
      this.defaultNumSentences = n;
      return this;
    }

    /**
     * Set the maximum number of k-means iterations.
     *
     * @param n Iteration cap; must be positive
     * @return this builder
     */
    public Builder maxIterations(int n) {
      this.maxIterations = n;
      return this;
    }

    /**
     * Build the selector.
     *
     * @return A configured ExtractiveSelector
     */
    public ExtractiveSelector build() {
      return new ExtractiveSelector(embedder, defaultNumSentences, maxIterations);
    }
  }
}
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
package com.redis.vl.extensions.summarization;

import java.io.IOException;
import java.io.InputStream;
import java.util.Arrays;
import java.util.List;
import opennlp.tools.sentdetect.SentenceDetectorME;
import opennlp.tools.sentdetect.SentenceModel;

/** OpenNLP-based sentence splitting utility. Thread-safe after initialization. */
public class SentenceSplitter {

  private final SentenceDetectorME detector;

  /**
   * Create a sentence splitter using the default English model. The model is loaded from the
   * classpath.
   */
  public SentenceSplitter() {
    this(loadDefaultModel());
  }

  /**
   * Create a sentence splitter with a custom model.
   *
   * @param model The OpenNLP sentence model to use
   */
  public SentenceSplitter(SentenceModel model) {
    this.detector = new SentenceDetectorME(model);
  }

  /** Load the bundled English sentence model from the classpath, failing fast if absent. */
  private static SentenceModel loadDefaultModel() {
    try (InputStream stream =
        SentenceSplitter.class.getResourceAsStream("/models/opennlp/en-sent.bin")) {
      if (stream != null) {
        return new SentenceModel(stream);
      }
      throw new IllegalStateException(
          "OpenNLP English sentence model not found. "
              + "Ensure 'en-sent.bin' is in resources/models/opennlp/");
    } catch (IOException e) {
      throw new IllegalStateException("Failed to load OpenNLP sentence model", e);
    }
  }

  /**
   * Split text into sentences.
   *
   * @param text The text to split
   * @return List of sentences
   */
  public List<String> split(String text) {
    if (hasNoContent(text)) {
      return List.of();
    }
    // SentenceDetectorME is not thread-safe, so every detection call is serialized.
    String[] detected;
    synchronized (detector) {
      detected = detector.sentDetect(text);
    }
    return Arrays.asList(detected);
  }

  /**
   * Split text into sentences with position spans.
   *
   * @param text The text to split
   * @return Array of Span objects with start/end positions
   */
  public opennlp.tools.util.Span[] splitWithSpans(String text) {
    if (hasNoContent(text)) {
      return new opennlp.tools.util.Span[0];
    }
    synchronized (detector) {
      return detector.sentPosDetect(text);
    }
  }

  /** True when the input is null or contains only whitespace. */
  private static boolean hasNoContent(String text) {
    return text == null || text.isBlank();
  }
}

core/src/main/java/com/redis/vl/utils/vectorize/SentenceTransformersVectorizer.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,20 @@ public List<List<Float>> embedBatchAsLists(List<String> texts) {
172172
return result;
173173
}
174174

175+
/**
176+
* Embed multiple sentences for clustering/selection. Useful for extractive summarization where we
177+
* need to compare sentence similarities.
178+
*
179+
* @param sentences List of sentences to embed
180+
* @return List of embedding vectors (float arrays)
181+
*/
182+
public List<float[]> embedSentences(List<String> sentences) {
183+
if (sentences == null || sentences.isEmpty()) {
184+
return List.of();
185+
}
186+
return generateEmbeddingsBatch(sentences, 32);
187+
}
188+
175189
private List<Float> floatArrayToList(float[] array) {
176190
List<Float> list = new ArrayList<>(array.length);
177191
for (float value : array) {
96.2 KB
Binary file not shown.

0 commit comments

Comments
 (0)