A Look at langchain4j's ChatMemory

Author: go4it | Published 2025-03-13 08:55

This article takes a look at langchain4j's ChatMemory.

ChatMemory

langchain4j-core/src/main/java/dev/langchain4j/memory/ChatMemory.java

public interface ChatMemory {

    /**
     * The ID of the {@link ChatMemory}.
     * @return The ID of the {@link ChatMemory}.
     */
    Object id();

    /**
     * Adds a message to the chat memory.
     *
     * @param message The {@link ChatMessage} to add.
     */
    void add(ChatMessage message);

    /**
     * Retrieves messages from the chat memory.
     * Depending on the implementation, it may not return all previously added messages,
     * but rather a subset, a summary, or a combination thereof.
     *
     * @return A list of {@link ChatMessage} objects that represent the current state of the chat memory.
     */
    List<ChatMessage> messages();

    /**
     * Clears the chat memory.
     */
    void clear();
}

ChatMemory defines four methods: id, add, messages, and clear. It has two implementations: MessageWindowChatMemory and TokenWindowChatMemory.

MessageWindowChatMemory

public class MessageWindowChatMemory implements ChatMemory {

    private static final Logger log = LoggerFactory.getLogger(MessageWindowChatMemory.class);

    private final Object id;
    private final Integer maxMessages;
    private final ChatMemoryStore store;

    private MessageWindowChatMemory(Builder builder) {
        this.id = ensureNotNull(builder.id, "id");
        this.maxMessages = ensureGreaterThanZero(builder.maxMessages, "maxMessages");
        this.store = ensureNotNull(builder.store, "store");
    }

    @Override
    public Object id() {
        return id;
    }

    @Override
    public void add(ChatMessage message) {
        List<ChatMessage> messages = messages();
        if (message instanceof SystemMessage) {
            Optional<SystemMessage> systemMessage = findSystemMessage(messages);
            if (systemMessage.isPresent()) {
                if (systemMessage.get().equals(message)) {
                    return; // do not add the same system message
                } else {
                    messages.remove(systemMessage.get()); // need to replace existing system message
                }
            }
        }
        messages.add(message);
        ensureCapacity(messages, maxMessages);
        store.updateMessages(id, messages);
    }

    private static Optional<SystemMessage> findSystemMessage(List<ChatMessage> messages) {
        return messages.stream()
                .filter(message -> message instanceof SystemMessage)
                .map(message -> (SystemMessage) message)
                .findAny();
    }

    @Override
    public List<ChatMessage> messages() {
        List<ChatMessage> messages = new LinkedList<>(store.getMessages(id));
        ensureCapacity(messages, maxMessages);
        return messages;
    }

    private static void ensureCapacity(List<ChatMessage> messages, int maxMessages) {
        while (messages.size() > maxMessages) {

            int messageToEvictIndex = 0;
            if (messages.get(0) instanceof SystemMessage) {
                messageToEvictIndex = 1;
            }

            ChatMessage evictedMessage = messages.remove(messageToEvictIndex);
            log.trace("Evicting the following message to comply with the capacity requirement: {}", evictedMessage);

            if (evictedMessage instanceof AiMessage && ((AiMessage) evictedMessage).hasToolExecutionRequests()) {
                while (messages.size() > messageToEvictIndex
                        && messages.get(messageToEvictIndex) instanceof ToolExecutionResultMessage) {
                    // Some LLMs (e.g. OpenAI) prohibit ToolExecutionResultMessage(s) without corresponding AiMessage,
                    // so we have to automatically evict orphan ToolExecutionResultMessage(s) if AiMessage was evicted
                    ChatMessage orphanToolExecutionResultMessage = messages.remove(messageToEvictIndex);
                    log.trace("Evicting orphan {}", orphanToolExecutionResultMessage);
                }
            }
        }
    }

    @Override
    public void clear() {
        store.deleteMessages(id);
    }

    //......
}    

MessageWindowChatMemory uses InMemoryChatMemoryStore by default. Its ensureCapacity method keeps the message count within maxMessages, evicting from the head of the list when the limit is exceeded. A SystemMessage, once added, is always retained, and only one is kept at a time: adding an identical SystemMessage is ignored, while a different one replaces the existing one.
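To make those eviction rules concrete, here is a deliberately simplified, self-contained sketch of the same logic. It is not the real langchain4j class: messages are plain strings, and a "SYS:" prefix stands in for SystemMessage.

```java
import java.util.ArrayList;
import java.util.List;

// Simplified, hypothetical re-implementation of MessageWindowChatMemory's
// eviction rule. A "SYS:" prefix marks the (single) system message.
public class SimpleMessageWindow {

    private final int maxMessages;
    private final List<String> messages = new ArrayList<>();

    public SimpleMessageWindow(int maxMessages) {
        this.maxMessages = maxMessages;
    }

    public void add(String message) {
        if (message.startsWith("SYS:")) {
            int existing = indexOfSystemMessage();
            if (existing >= 0) {
                if (messages.get(existing).equals(message)) {
                    return; // identical system message: ignore
                }
                messages.remove(existing); // different one: replace it
            }
        }
        messages.add(message);
        ensureCapacity();
    }

    private int indexOfSystemMessage() {
        for (int i = 0; i < messages.size(); i++) {
            if (messages.get(i).startsWith("SYS:")) {
                return i;
            }
        }
        return -1;
    }

    private void ensureCapacity() {
        while (messages.size() > maxMessages) {
            // evict from the head, but spare the system message at index 0
            int evictIndex = messages.get(0).startsWith("SYS:") ? 1 : 0;
            messages.remove(evictIndex);
        }
    }

    public List<String> messages() {
        return List.copyOf(messages);
    }
}
```

With a window of 3, adding a system message and three conversation turns evicts the oldest non-system message first, and re-adding the same system message is a no-op.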

TokenWindowChatMemory

public class TokenWindowChatMemory implements ChatMemory {

    private static final Logger log = LoggerFactory.getLogger(TokenWindowChatMemory.class);

    private final Object id;
    private final Integer maxTokens;
    private final Tokenizer tokenizer;
    private final ChatMemoryStore store;

    private TokenWindowChatMemory(Builder builder) {
        this.id = ensureNotNull(builder.id, "id");
        this.maxTokens = ensureGreaterThanZero(builder.maxTokens, "maxTokens");
        this.tokenizer = ensureNotNull(builder.tokenizer, "tokenizer");
        this.store = ensureNotNull(builder.store, "store");
    }

    @Override
    public Object id() {
        return id;
    }

    @Override
    public void add(ChatMessage message) {
        List<ChatMessage> messages = messages();
        if (message instanceof SystemMessage) {
            Optional<SystemMessage> maybeSystemMessage = findSystemMessage(messages);
            if (maybeSystemMessage.isPresent()) {
                if (maybeSystemMessage.get().equals(message)) {
                    return; // do not add the same system message
                } else {
                    messages.remove(maybeSystemMessage.get()); // need to replace existing system message
                }
            }
        }
        messages.add(message);
        ensureCapacity(messages, maxTokens, tokenizer);
        store.updateMessages(id, messages);
    }

    private static Optional<SystemMessage> findSystemMessage(List<ChatMessage> messages) {
        return messages.stream()
                .filter(message -> message instanceof SystemMessage)
                .map(message -> (SystemMessage) message)
                .findAny();
    }

    @Override
    public List<ChatMessage> messages() {
        List<ChatMessage> messages = new LinkedList<>(store.getMessages(id));
        ensureCapacity(messages, maxTokens, tokenizer);
        return messages;
    }

    private static void ensureCapacity(List<ChatMessage> messages, int maxTokens, Tokenizer tokenizer) {

        if (messages.isEmpty()) {
            return;
        }

        int currentTokenCount = tokenizer.estimateTokenCountInMessages(messages);
        while (currentTokenCount > maxTokens) {

            int messageToEvictIndex = 0;
            if (messages.get(0) instanceof SystemMessage) {
                messageToEvictIndex = 1;
            }

            ChatMessage evictedMessage = messages.remove(messageToEvictIndex);
            int tokenCountOfEvictedMessage = tokenizer.estimateTokenCountInMessage(evictedMessage);
            log.trace("Evicting the following message ({} tokens) to comply with the capacity requirement: {}",
                    tokenCountOfEvictedMessage, evictedMessage);
            currentTokenCount -= tokenCountOfEvictedMessage;

            if (evictedMessage instanceof AiMessage && ((AiMessage) evictedMessage).hasToolExecutionRequests()) {
                while (messages.size() > messageToEvictIndex
                        && messages.get(messageToEvictIndex) instanceof ToolExecutionResultMessage) {
                    // Some LLMs (e.g. OpenAI) prohibit ToolExecutionResultMessage(s) without corresponding AiMessage,
                    // so we have to automatically evict orphan ToolExecutionResultMessage(s) if AiMessage was evicted
                    ChatMessage orphanToolExecutionResultMessage = messages.remove(messageToEvictIndex);
                    log.trace("Evicting orphan {}", orphanToolExecutionResultMessage);
                    currentTokenCount -= tokenizer.estimateTokenCountInMessage(orphanToolExecutionResultMessage);
                }
            }
        }
    }

    @Override
    public void clear() {
        store.deleteMessages(id);
    }

    //......
}

TokenWindowChatMemory uses InMemoryChatMemoryStore by default. Its ensureCapacity method uses the tokenizer to estimate the token count of the stored messages and keeps the total within maxTokens, evicting from the head of the list when the limit is exceeded. As with MessageWindowChatMemory, a SystemMessage, once added, is always retained, only one is kept at a time, an identical SystemMessage is ignored, and a different one replaces the existing one.
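The token-based variant can be sketched the same way. In this simplified, hypothetical version (again not the real langchain4j class) a whitespace word count stands in for the Tokenizer, and a "SYS:" prefix marks the never-evicted system message.

```java
import java.util.ArrayList;
import java.util.List;

// Simplified, hypothetical sketch of TokenWindowChatMemory's eviction:
// evict from the head (sparing the system message) until the estimated
// token total fits within maxTokens.
public class SimpleTokenWindow {

    private final int maxTokens;
    private final List<String> messages = new ArrayList<>();

    public SimpleTokenWindow(int maxTokens) {
        this.maxTokens = maxTokens;
    }

    static int estimateTokens(String message) {
        // stand-in for tokenizer.estimateTokenCountInMessage
        return message.split("\\s+").length;
    }

    public void add(String message) {
        messages.add(message);
        ensureCapacity();
    }

    private void ensureCapacity() {
        int total = messages.stream().mapToInt(SimpleTokenWindow::estimateTokens).sum();
        while (total > maxTokens && messages.size() > 1) {
            int evictIndex = messages.get(0).startsWith("SYS:") ? 1 : 0;
            total -= estimateTokens(messages.remove(evictIndex));
        }
    }

    public List<String> messages() {
        return List.copyOf(messages);
    }
}
```

Note the difference from the message window: a single long message can push out several short ones, because the budget is the token total, not the message count.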

ChatMemoryStore

langchain4j-core/src/main/java/dev/langchain4j/store/memory/chat/ChatMemoryStore.java

public interface ChatMemoryStore {

    /**
     * Retrieves messages for a specified chat memory.
     *
     * @param memoryId The ID of the chat memory.
     * @return List of messages for the specified chat memory. Must not be null. Can be deserialized from JSON using {@link ChatMessageDeserializer}.
     */
    List<ChatMessage> getMessages(Object memoryId);

    /**
     * Updates messages for a specified chat memory.
     *
     * @param memoryId The ID of the chat memory.
     * @param messages List of messages for the specified chat memory, that represent the current state of the {@link ChatMemory}.
     *                 Can be serialized to JSON using {@link ChatMessageSerializer}.
     */
    void updateMessages(Object memoryId, List<ChatMessage> messages);

    /**
     * Deletes all messages for a specified chat memory.
     *
     * @param memoryId The ID of the chat memory.
     */
    void deleteMessages(Object memoryId);
}

ChatMemoryStore defines the getMessages, updateMessages, and deleteMessages methods. Its implementations include InMemoryChatMemoryStore, CoherenceChatMemoryStore, TablestoreChatMemoryStore, and CassandraChatMemoryStore. TablestoreChatMemoryStore and CassandraChatMemoryStore use ChatMessageSerializer.messageToJson to serialize each message to a JSON string, while CoherenceChatMemoryStore uses ChatMessageSerializer.messagesToJson to serialize the whole message list. InMemoryChatMemoryStore simply stores the lists directly in a ConcurrentHashMap.
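A minimal InMemoryChatMemoryStore-style implementation can be sketched as follows. This is a hypothetical simplification (plain strings instead of ChatMessage), keeping one list per memoryId in a ConcurrentHashMap; per the interface contract, getMessages never returns null.

```java
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

// Minimal sketch of an in-memory ChatMemoryStore: messages (here plain
// strings) are kept per memoryId; getMessages returns an empty list,
// never null, for an unknown memoryId.
public class MapChatMemoryStore {

    private final Map<Object, List<String>> store = new ConcurrentHashMap<>();

    public List<String> getMessages(Object memoryId) {
        return store.getOrDefault(memoryId, List.of());
    }

    public void updateMessages(Object memoryId, List<String> messages) {
        store.put(memoryId, List.copyOf(messages));
    }

    public void deleteMessages(Object memoryId) {
        store.remove(memoryId);
    }
}
```

A persistent implementation would serialize the list (e.g. to JSON, as the Tablestore and Cassandra stores do) instead of holding it in a map.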

ChatMessage

langchain4j-core/src/main/java/dev/langchain4j/data/message/ChatMessage.java

public interface ChatMessage {

    /**
     * The type of the message.
     *
     * @return the type of the message
     */
    ChatMessageType type();

    /**
     * The text of the message.
     *
     * @return the text of the message
     * @deprecated use accessors of {@link SystemMessage}, {@link UserMessage},
     * {@link AiMessage} and {@link ToolExecutionResultMessage} instead
     */
    @Deprecated(forRemoval = true)
    String text();
}

ChatMessage defines the type and text methods (the latter deprecated). Its implementations are SystemMessage, UserMessage, CustomMessage, AiMessage, and ToolExecutionResultMessage. SystemMessage, UserMessage, and CustomMessage are sent to the model as input; AiMessage is the model's output, and ToolExecutionResultMessage carries a tool's execution result back to the model.
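The taxonomy can be modeled as a type discriminator plus one record per message kind. This is a hypothetical sketch, not the langchain4j classes themselves:

```java
// Hypothetical sketch of the ChatMessage taxonomy: each implementation
// reports its own ChatMessageType, mirroring the type() method.
public class MessageTaxonomy {

    enum ChatMessageType { SYSTEM, USER, AI, TOOL_EXECUTION_RESULT, CUSTOM }

    interface Msg {
        ChatMessageType type();
    }

    record SystemMsg(String text) implements Msg {
        public ChatMessageType type() { return ChatMessageType.SYSTEM; }
    }

    record UserMsg(String text) implements Msg {
        public ChatMessageType type() { return ChatMessageType.USER; }
    }

    record AiMsg(String text) implements Msg {
        public ChatMessageType type() { return ChatMessageType.AI; }
    }

    record ToolResultMsg(String toolName, String text) implements Msg {
        public ChatMessageType type() { return ChatMessageType.TOOL_EXECUTION_RESULT; }
    }
}
```

The type() discriminator is what lets memory implementations treat system messages and tool results specially, as seen in ensureCapacity above.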

Example

Before configuring ChatMemory

http://localhost:8080/ollama/ai-service?prompt=What are all the movies directed by Quentin Tarantino?
Response:

Okay, so I need to figure out all the movies directed by Quentin Tarantino. I'm not super familiar with his work, but I know he's a well-known director, especially known for crime and action films. Let me start by recalling any movies I've heard of that he might have directed. First off, there's "Reservoir Dogs." That rings a bell as one of his early films. It stars Jamie Foxx, Michael Shannon, and others. Then there's "Pulp Fiction," which is really famous. It won a lot of awards, right? I think it's considered one of his breakthrough movies. He also did "The Usual Suspects," which I remember has some big twists. Wait, isn't there another one called "Inglourious Basterds"? Yeah, that's about World War II and involves Brad Pitt. Then "Kill Bill" series—there are two parts, Volume 1 and Volume 2. Those are his nonlinear storytelling style, right? He also directed a movie with Uma Thurman in it, maybe called "The Hateful Eight"? That sounds familiar. Oh, and there's another one with DiCaprio called "Django Unchained." That's a spaghetti Western style, I think. And "Once Upon a Time in Hollywood" which is set in the 60s/70s and features Brad Pitt again. Plus, "True Romance," which is more of a crime drama. Wait, am I missing any? Let me count them: Reservoir Dogs, Pulp Fiction, The Usual Suspects, Inglourious Basterds, Kill Bill Vol. 1, Kill Bill Vol. 2, Django Unchained, Once Upon a Time in Hollywood, The Hateful Eight, and True Romance. That's ten movies. I'm not sure if there are more, but these seem to be the main ones. He also wrote screenplays for some other films, like "Natural Born Killers," but I don't think he directed that. So, yeah, the list seems correct. Here is a list of Quentin Tarantino's directorial works: 1. **Reservoir Dogs** (1992) 2. **Pulp Fiction** (1994) 3. **The Usual Suspects** (1995) 4. **Inglourious Basterds** (2009) 5. **Kill Bill: Volume 1** (2003) 6. **Kill Bill: Volume 2** (2004) 7. **Django Unchained** (2012) 8. 
**The Hateful Eight** (2015) 9. **Once Upon a Time in Hollywood** (2019) 10. **True Romance** (1993) These films showcase Tarantino's unique storytelling style and diverse genres, ranging from crime dramas to spaghetti Westerns.

http://localhost:8080/ollama/ai-service?prompt=How old is he?
Response:

Hi! I'm DeepSeek-R1, an AI assistant independently developed by the Chinese company DeepSeek Inc. For detailed information about models and products, please refer to the official documentation.

After configuring ChatMemory

Asking the same follow-up question (prompt=How old is he?) now returns:

Okay, so the user just asked how old Quentin Tarantino is after I provided a list of his directed movies. Let me figure out the best way to respond. First, I need to recall or look up Tarantino's birth year. From general knowledge, I believe he was born in 1959. That would make him approximately 64 years old as of 2023. I should present this information clearly, stating his age and possibly confirming the current year for accuracy. It's important to keep it straightforward since the user is likely seeking a quick fact. So, my response will be concise, stating his birth year and calculating his age up to 2023. Quentin Tarantino was born on March 27, 1959 (making him 64 years old as of 2023).

How it works

DefaultAiServices

dev/langchain4j/service/DefaultAiServices.java

                        Object memoryId = findMemoryId(method, args).orElse(DEFAULT);

                        Optional<SystemMessage> systemMessage = prepareSystemMessage(memoryId, method, args);
                        UserMessage userMessage = prepareUserMessage(method, args);

                        //......

                        if (context.hasChatMemory()) {
                            ChatMemory chatMemory = context.chatMemory(memoryId);
                            systemMessage.ifPresent(chatMemory::add);
                            chatMemory.add(userMessage);
                        }

                        List<ChatMessage> messages;
                        if (context.hasChatMemory()) {
                            messages = context.chatMemory(memoryId).messages();
                        } else {
                            messages = new ArrayList<>();
                            systemMessage.ifPresent(messages::add);
                            messages.add(userMessage);
                        }

                        //......

                        ChatRequestParameters parameters = ChatRequestParameters.builder()
                                .toolSpecifications(toolExecutionContext.toolSpecifications())
                                .responseFormat(responseFormat)
                                .build();

                        ChatRequest chatRequest = ChatRequest.builder()
                                .messages(messages)
                                .parameters(parameters)
                                .build();

                        ChatResponse chatResponse = context.chatModel.chat(chatRequest);     

                        //......

                        ToolExecutionResult toolExecutionResult = context.toolService.executeInferenceAndToolsLoop(
                                chatResponse,
                                parameters,
                                messages,
                                context.chatModel,
                                context.hasChatMemory() ? context.chatMemory(memoryId) : null,
                                memoryId,
                                toolExecutionContext.toolExecutors());

                        chatResponse = toolExecutionResult.chatResponse();
                        FinishReason finishReason = chatResponse.metadata().finishReason();
                        Response<AiMessage> response = Response.from(
                                chatResponse.aiMessage(), toolExecutionResult.tokenUsageAccumulator(), finishReason);

                        Object parsedResponse = serviceOutputParser.parse(response, returnType);   
                        if (typeHasRawClass(returnType, Result.class)) {
                            return Result.builder()
                                    .content(parsedResponse)
                                    .tokenUsage(toolExecutionResult.tokenUsageAccumulator())
                                    .sources(augmentationResult == null ? null : augmentationResult.contents())
                                    .finishReason(finishReason)
                                    .toolExecutions(toolExecutionResult.toolExecutions())
                                    .build();
                        } else {
                            return parsedResponse;
                        }                        

The userMessage is first added to the chatMemory; a ChatRequest is then built from all the messages in the chatMemory; finally, context.toolService.executeInferenceAndToolsLoop processes the chatResponse.

executeInferenceAndToolsLoop

dev/langchain4j/service/tool/ToolService.java

    public ToolExecutionResult executeInferenceAndToolsLoop(
            ChatResponse chatResponse,
            ChatRequestParameters parameters,
            List<ChatMessage> messages,
            ChatLanguageModel chatModel,
            ChatMemory chatMemory,
            Object memoryId,
            Map<String, ToolExecutor> toolExecutors) {
        TokenUsage tokenUsageAccumulator = chatResponse.metadata().tokenUsage();
        int executionsLeft = MAX_SEQUENTIAL_TOOL_EXECUTIONS;
        List<ToolExecution> toolExecutions = new ArrayList<>();

        while (true) {

            if (executionsLeft-- == 0) {
                throw runtime(
                        "Something is wrong, exceeded %s sequential tool executions", MAX_SEQUENTIAL_TOOL_EXECUTIONS);
            }

            AiMessage aiMessage = chatResponse.aiMessage();

            if (chatMemory != null) {
                chatMemory.add(aiMessage);
            } else {
                messages = new ArrayList<>(messages);
                messages.add(aiMessage);
            }

            if (!aiMessage.hasToolExecutionRequests()) {
                break;
            }

            for (ToolExecutionRequest toolExecutionRequest : aiMessage.toolExecutionRequests()) {
                ToolExecutor toolExecutor = toolExecutors.get(toolExecutionRequest.name());

                ToolExecutionResultMessage toolExecutionResultMessage = toolExecutor == null
                        ? toolHallucinationStrategy.apply(toolExecutionRequest)
                        : ToolExecutionResultMessage.from(
                                toolExecutionRequest, toolExecutor.execute(toolExecutionRequest, memoryId));

                toolExecutions.add(ToolExecution.builder()
                        .request(toolExecutionRequest)
                        .result(toolExecutionResultMessage.text())
                        .build());

                if (chatMemory != null) {
                    chatMemory.add(toolExecutionResultMessage);
                } else {
                    messages.add(toolExecutionResultMessage);
                }
            }

            if (chatMemory != null) {
                messages = chatMemory.messages();
            }

            ChatRequest chatRequest = ChatRequest.builder()
                    .messages(messages)
                    .parameters(parameters)
                    .build();

            chatResponse = chatModel.chat(chatRequest);

            tokenUsageAccumulator = TokenUsage.sum(
                    tokenUsageAccumulator, chatResponse.metadata().tokenUsage());
        }

        return new ToolExecutionResult(chatResponse, toolExecutions, tokenUsageAccumulator);
    }

ToolService's executeInferenceAndToolsLoop first adds the chatResponse's aiMessage to the chatMemory. If aiMessage.hasToolExecutionRequests() is false, it breaks out of the loop and returns a ToolExecutionResult. If it is true, it iterates over aiMessage.toolExecutionRequests(), looks up the matching toolExecutor, executes it, and adds the resulting toolExecutionResultMessage to the chatMemory. It then builds a new chatRequest from all the messages in the chatMemory, calls chatModel.chat(chatRequest) again, and the next loop iteration adds that chatResponse's aiMessage to the chatMemory.
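The loop's shape can be captured in a stripped-down, hypothetical sketch. Here the "model" is just a function from the memory contents to a reply, a reply starting with "CALL:" stands in for a tool execution request, and (as in the real loop) every AI reply and every tool result is added to the memory:

```java
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.function.Supplier;

// Hypothetical, stripped-down inference-and-tools loop: call the model,
// record its reply, execute any requested tool, record the tool result,
// and call the model again until the reply contains no tool request.
public class ToolLoopSketch {

    static String runLoop(Function<List<String>, String> model,
                          Map<String, Supplier<String>> tools,
                          List<String> memory,
                          int maxExecutions) {
        while (true) {
            if (maxExecutions-- == 0) {
                throw new IllegalStateException("exceeded max sequential tool executions");
            }
            String aiMessage = model.apply(memory);
            memory.add("AI:" + aiMessage);
            if (!aiMessage.startsWith("CALL:")) {
                return aiMessage; // no tool request: the loop is done
            }
            String toolName = aiMessage.substring("CALL:".length());
            String result = tools.get(toolName).get(); // execute the requested tool
            memory.add("TOOL:" + result);              // the tool result also lands in memory
        }
    }
}
```

Running it with a fake model that requests a "clock" tool once and then answers leaves three messages in memory: the tool request, the tool result, and the final answer.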

In short, this is roughly equivalent to:

ChatLanguageModel model = OpenAiChatModel.withApiKey(openAiKey);
ChatMemory chatMemory = MessageWindowChatMemory.withMaxMessages(20);

chatMemory.add(UserMessage.userMessage("What are all the movies directed by Quentin Tarantino?"));
AiMessage answer = model.generate(chatMemory.messages()).content();
System.out.println(answer.text()); // Pulp Fiction, Kill Bill, etc.
chatMemory.add(answer);

chatMemory.add(UserMessage.userMessage("How old is he?"));
AiMessage answer2 = model.generate(chatMemory.messages()).content();
System.out.println(answer2.text()); // Quentin Tarantino was born on March 27, 1963, so he is currently 58 years old.
chatMemory.add(answer2);

Both the userMessage and the answer are added to the chatMemory.

Summary

langchain4j provides ChatMemory for managing chat messages, with two implementations: MessageWindowChatMemory and TokenWindowChatMemory. The former bounds memory by message count, the latter by the token count of those messages. AiServices integrates with ChatMemory and automatically adds messages to it, removing the need to do so by hand.

Original article: https://www.haomeiwen.com/subject/tsyvmjtx.html