Skip to content

Commit 66796f0

Browse files
committed
feat: GH-4589 Add Option to Store All User Messages in MessageChatMemoryAdvisor
1 parent 133eb40 commit 66796f0

File tree

2 files changed

+86
-15
lines changed

2 files changed

+86
-15
lines changed

spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/advisor/MessageChatMemoryAdvisor.java

Lines changed: 37 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -54,15 +54,18 @@ public final class MessageChatMemoryAdvisor implements BaseChatMemoryAdvisor {
5454

5555
private final Scheduler scheduler;
5656

57+
private final boolean storeAllUserMessages;
58+
5759
private MessageChatMemoryAdvisor(ChatMemory chatMemory, String defaultConversationId, int order,
58-
Scheduler scheduler) {
60+
Scheduler scheduler, boolean storeAllUserMessages) {
5961
Assert.notNull(chatMemory, "chatMemory cannot be null");
6062
Assert.hasText(defaultConversationId, "defaultConversationId cannot be null or empty");
6163
Assert.notNull(scheduler, "scheduler cannot be null");
6264
this.chatMemory = chatMemory;
6365
this.defaultConversationId = defaultConversationId;
6466
this.order = order;
6567
this.scheduler = scheduler;
68+
this.storeAllUserMessages = storeAllUserMessages;
6669
}
6770

6871
@Override
@@ -88,12 +91,19 @@ public ChatClientRequest before(ChatClientRequest chatClientRequest, AdvisorChai
8891

8992
// 3. Create a new request with the advised messages.
9093
ChatClientRequest processedChatClientRequest = chatClientRequest.mutate()
91-
.prompt(chatClientRequest.prompt().mutate().messages(processedMessages).build())
92-
.build();
94+
.prompt(chatClientRequest.prompt().mutate().messages(processedMessages).build())
95+
.build();
9396

9497
// 4. Add the new user message to the conversation memory.
95-
UserMessage userMessage = processedChatClientRequest.prompt().getUserMessage();
96-
this.chatMemory.add(conversationId, userMessage);
98+
if (this.storeAllUserMessages) {
99+
// Store all user messages: add the new message to the existing message list
100+
List allUserMessages = processedChatClientRequest.prompt().getUserMessages();
101+
this.chatMemory.add(conversationId, allUserMessages);
102+
} else {
103+
// Store only the latest user message
104+
UserMessage userMessage = processedChatClientRequest.prompt().getUserMessage();
105+
this.chatMemory.add(conversationId, userMessage);
106+
}
97107

98108
return processedChatClientRequest;
99109
}
@@ -103,10 +113,10 @@ public ChatClientResponse after(ChatClientResponse chatClientResponse, AdvisorCh
103113
List<Message> assistantMessages = new ArrayList<>();
104114
if (chatClientResponse.chatResponse() != null) {
105115
assistantMessages = chatClientResponse.chatResponse()
106-
.getResults()
107-
.stream()
108-
.map(g -> (Message) g.getOutput())
109-
.toList();
116+
.getResults()
117+
.stream()
118+
.map(g -> (Message) g.getOutput())
119+
.toList();
110120
}
111121
this.chatMemory.add(this.getConversationId(chatClientResponse.context(), this.defaultConversationId),
112122
assistantMessages);
@@ -121,11 +131,11 @@ public Flux<ChatClientResponse> adviseStream(ChatClientRequest chatClientRequest
121131

122132
// Process the request with the before method
123133
return Mono.just(chatClientRequest)
124-
.publishOn(scheduler)
125-
.map(request -> this.before(request, streamAdvisorChain))
126-
.flatMapMany(streamAdvisorChain::nextStream)
127-
.transform(flux -> new ChatClientMessageAggregator().aggregateChatClientResponse(flux,
128-
response -> this.after(response, streamAdvisorChain)));
134+
.publishOn(scheduler)
135+
.map(request -> this.before(request, streamAdvisorChain))
136+
.flatMapMany(streamAdvisorChain::nextStream)
137+
.transform(flux -> new ChatClientMessageAggregator().aggregateChatClientResponse(flux,
138+
response -> this.after(response, streamAdvisorChain)));
129139
}
130140

131141
public static Builder builder(ChatMemory chatMemory) {
@@ -142,6 +152,8 @@ public static final class Builder {
142152

143153
private ChatMemory chatMemory;
144154

155+
private boolean storeAllUserMessages = false;
156+
145157
private Builder(ChatMemory chatMemory) {
146158
this.chatMemory = chatMemory;
147159
}
@@ -171,12 +183,22 @@ public Builder scheduler(Scheduler scheduler) {
171183
return this;
172184
}
173185

186+
/**
187+
* Configure whether to store all user messages or only the latest one.
188+
* @param storeAllUserMessages true to store all user messages, false to store only the latest
189+
* @return the builder
190+
*/
191+
public Builder storeAllUserMessages(boolean storeAllUserMessages) {
192+
this.storeAllUserMessages = storeAllUserMessages;
193+
return this;
194+
}
195+
174196
/**
175197
* Build the advisor.
176198
* @return the advisor
177199
*/
178200
public MessageChatMemoryAdvisor build() {
179-
return new MessageChatMemoryAdvisor(this.chatMemory, this.conversationId, this.order, this.scheduler);
201+
return new MessageChatMemoryAdvisor(this.chatMemory, this.conversationId, this.order, this.scheduler, this.storeAllUserMessages);
180202
}
181203

182204
}

spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/advisor/MessageChatMemoryAdvisorTests.java

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,4 +108,53 @@ void testDefaultValues() {
108108
assertThat(advisor.getOrder()).isEqualTo(Advisor.DEFAULT_CHAT_MEMORY_PRECEDENCE_ORDER);
109109
}
110110

111+
@Test
112+
void whenStoreAllUserMessagesIsTrueThenPreserveAllMessages() {
113+
// Create a chat memory
114+
ChatMemory chatMemory = MessageWindowChatMemory.builder()
115+
.chatMemoryRepository(new InMemoryChatMemoryRepository())
116+
.build();
117+
118+
// Create advisor with storeAllUserMessages set to true
119+
MessageChatMemoryAdvisor advisor = MessageChatMemoryAdvisor.builder(chatMemory)
120+
.storeAllUserMessages(true)
121+
.build();
122+
123+
// Verify the advisor was built successfully
124+
assertThat(advisor).isNotNull();
125+
}
126+
127+
@Test
128+
void whenStoreAllUserMessagesIsFalseThenStoreOnlyLatest() {
129+
// Create a chat memory
130+
ChatMemory chatMemory = MessageWindowChatMemory.builder()
131+
.chatMemoryRepository(new InMemoryChatMemoryRepository())
132+
.build();
133+
134+
// Create advisor with storeAllUserMessages set to false (default)
135+
MessageChatMemoryAdvisor advisor = MessageChatMemoryAdvisor.builder(chatMemory)
136+
.storeAllUserMessages(false)
137+
.build();
138+
139+
// Verify the advisor was built successfully
140+
assertThat(advisor).isNotNull();
141+
}
142+
143+
@Test
144+
void testDefaultStoreAllUserMessagesValue() {
145+
// Create a chat memory
146+
ChatMemory chatMemory = MessageWindowChatMemory.builder()
147+
.chatMemoryRepository(new InMemoryChatMemoryRepository())
148+
.build();
149+
150+
// Create advisor with default values
151+
MessageChatMemoryAdvisor advisor = MessageChatMemoryAdvisor.builder(chatMemory).build();
152+
153+
// Verify the advisor was built successfully
154+
assertThat(advisor).isNotNull();
155+
// Note: We cannot directly verify the default value of storeAllUserMessages
156+
// since it's a private field, but the construction should succeed
157+
}
158+
159+
111160
}

0 commit comments

Comments
 (0)