思源笔记源码折腾记录 - 流式 AI

思源的 AI 响应是不支持流式的,但是现在的 AI 一个赛一个话痨,不支持流式经常会出现等了半天结果响应超时之类的状况。插件倒是能够支持流式但这样就不能使用核心的 AI 接口了,所以就稍微改动了一下思源的源码来支持流式响应。

代码主要都是 AI 写的,我主要负责当 AI 鼓励师,改动好像比较大而且不知道有没有 bug 所以还是先自用吧(其实还因为我 git 不熟,根本就不会 pr.....),完整点的内容可以参考这个 commit:

改动思路

主要是在后端添加流式响应能力,前端增加对应的处理逻辑。改动分为几个层次:

  1. 模型层 - 添加流式接口
  2. API 层 - 支持 SSE 流式传输
  3. 前端工具 - 新增 fetchStream 函数
  4. UI 交互 - 改进聊天界面

后端改动

1. 模型层 - 添加流式接口

首先需要在模型层添加流式处理的能力。

文件: kernel/model/ai.go

// 新增流式接口
func ChatGPTContinueWriteStream(msg string, contextMsgs []string, cloud bool) (stream *openai.ChatCompletionStream, err error) {
	if Conf.AI.OpenAI.APIMaxContexts < len(contextMsgs) {
		contextMsgs = contextMsgs[len(contextMsgs)-Conf.AI.OpenAI.APIMaxContexts:]
	}

	if cloud {
		return nil, errors.New("streaming not supported for CloudGPT")
	}

	gpt := &OpenAIGPT{c: util.NewOpenAIClient(Conf.AI.OpenAI.APIKey, Conf.AI.OpenAI.APIProxy, Conf.AI.OpenAI.APIBaseURL, Conf.AI.OpenAI.APIUserAgent, Conf.AI.OpenAI.APIVersion, Conf.AI.OpenAI.APIProvider)}
	return gpt.chatStream(msg, contextMsgs)
}

// 更新 GPT 接口
type GPT interface {
	chat(msg string, contextMsgs []string) (partRet string, stop bool, err error)
	chatStream(msg string, contextMsgs []string) (stream *openai.ChatCompletionStream, err error)
}

// OpenAIGPT 实现流式方法
func (gpt *OpenAIGPT) chatStream(msg string, contextMsgs []string) (stream *openai.ChatCompletionStream, err error) {
	return util.ChatGPTStream(msg, contextMsgs, gpt.c, Conf.AI.OpenAI.APIModel, Conf.AI.OpenAI.APIMaxTokens, Conf.AI.OpenAI.APITemperature, Conf.AI.OpenAI.APITimeout)
}

文件: kernel/util/openai.go

func ChatGPTStream(msg string, contextMsgs []string, c *openai.Client, model string, maxTokens int, temperature float64, timeout int) (*openai.ChatCompletionStream, error) {
	var reqMsgs []openai.ChatCompletionMessage
	for _, ctxMsg := range contextMsgs {
		if "" == ctxMsg {
			continue
		}

		reqMsgs = append(reqMsgs, openai.ChatCompletionMessage{
			Role:    "user",
			Content: ctxMsg,
		})
	}

	if "" != msg {
		reqMsgs = append(reqMsgs, openai.ChatCompletionMessage{
			Role:    "user",
			Content: msg,
		})
	}

	if 1 > len(reqMsgs) {
		return nil, nil
	}

	req := openai.ChatCompletionRequest{
		Model:       model,
		MaxTokens:   maxTokens,
		Temperature: float32(temperature),
		Messages:    reqMsgs,
		Stream:      true, // 关键:开启流式响应
	}
	ctx, _ := context.WithTimeout(context.Background(), time.Duration(timeout)*time.Second)

	return c.CreateChatCompletionStream(ctx, req)
}

2. API 层 - 支持 SSE 流式传输

原来的 API 是同步返回的,需要改成支持 Server-Sent Events 的流式传输。

原版实现 (kernel/api/ai.go):

func chatGPT(c *gin.Context) {
	ret := gulu.Ret.NewResult()
	defer c.JSON(http.StatusOK, ret)

	arg, ok := util.JsonArg(c, ret)
	if !ok {
		return
	}

	msg := arg["msg"].(string)
	ret.Data = model.ChatGPT(msg) // 同步调用,等待完整响应
}

修改后的实现:

func chatGPT(c *gin.Context) {
	ret := gulu.Ret.NewResult()
	arg, ok := util.JsonArg(c, ret)
	if !ok {
		c.JSON(http.StatusOK, ret)
		return
	}

	msg := arg["msg"].(string)

	// 设置 SSE Headers
	c.Writer.Header().Set("Content-Type", "text/event-stream")
	c.Writer.Header().Set("Cache-Control", "no-cache")
	c.Writer.Header().Set("Connection", "keep-alive")
	c.Writer.Header().Set("Transfer-Encoding", "chunked")
	c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
	c.Writer.Header().Set("Access-Control-Allow-Headers", "Cache-Control")

	// 调用模型层的流式函数
	stream, err := model.ChatGPTContinueWriteStream(msg, model.CachedContextMsg, false)
	if err != nil {
		// 发送错误事件
		errorData, _ := json.Marshal(map[string]interface{}{
			"error": err.Error(),
		})
		c.Writer.WriteString("data: " + string(errorData) + "\n\n")
		c.Writer.Flush()
		return
	}

	// 监听客户端断开连接
	clientDisconnected := c.Writer.CloseNotify()

	// 发送开始事件
	startData, _ := json.Marshal(map[string]interface{}{
		"status":  "started",
		"message": "开始生成回复...",
	})
	c.Writer.WriteString("data: " + string(startData) + "\n\n")
	c.Writer.Flush()

	for {
		select {
		case <-clientDisconnected:
			return
		default:
			response, err := stream.Recv()
			if errors.Is(err, io.EOF) {
				// 流结束,发送完成事件
				doneData, _ := json.Marshal(map[string]interface{}{
					"status":  "completed",
					"message": "生成完成",
				})
				c.Writer.WriteString("data: " + string(doneData) + "\n\n")
				c.Writer.Flush()
				return
			}
			if err != nil {
				// 发送错误事件
				errorData, _ := json.Marshal(map[string]interface{}{
					"error": err.Error(),
				})
				c.Writer.WriteString("data: " + string(errorData) + "\n\n")
				c.Writer.Flush()
				return
			}

			// 发送内容事件
			content := response.Choices[0].Delta.Content
			if content != "" {
				data, _ := json.Marshal(map[string]interface{}{
					"content": content,
					"status":  "streaming",
				})
				c.Writer.WriteString("data: " + string(data) + "\n\n")
				c.Writer.Flush()
			}
		}
	}
}

前端改动

1. 工具函数 - fetchStream

原来的 fetchPost 不支持流式,需要新增一个 fetchStream 函数来处理 SSE 流。

文件: app/src/util/fetch.ts

export const fetchStream = async (
    params: any, 
    onMessage: (content: string) => void, 
    onDone: () => void, 
    onError: (error: Error) => void, 
    onAbort?: () => void
) => {
    let controller: AbortController | null = null;
    let timeoutId: NodeJS.Timeout | null = null;
    let lastEventTime = Date.now();
  
    const resetTimeout = () => {
        if (timeoutId) {
            clearTimeout(timeoutId);
        }
        lastEventTime = Date.now();
        // 每个事件后重置超时计时器
        timeoutId = setTimeout(() => {
            if (controller) {
                controller.abort();
            }
        }, 10000); // 10秒无事件则超时
    };
  
    try {
        // 创建可取消的请求
        controller = new AbortController();
  
        // 初始超时设置
        resetTimeout();
  
        const response = await fetch("/api/ai/chatGPT", {
            method: "POST",
            headers: {
                "Content-Type": "application/json",
            },
            body: JSON.stringify(params),
            signal: controller.signal,
        });

        if (!response.ok) {
            throw new Error(`HTTP error! status: ${response.status}`);
        }

        const reader = response.body?.getReader();
        if (!reader) {
            throw new Error("Response body is null");
        }

        const decoder = new TextDecoder("utf-8");
        let buffer = "";
        let isFirstChunk = true;

        try {
            while (true) {
                const { done, value } = await reader.read();
                if (done) {
                    break;
                }

                buffer += decoder.decode(value, { stream: true });
                const events = buffer.split("\n\n");
                buffer = events.pop() || ""; // 保留不完整的事件

                for (const event of events) {
                    if (event.startsWith("data: ")) {
                        const dataStr = event.substring(6);
                  
                        // 处理特殊事件
                        if (dataStr === "[DONE]") {
                            return; // 流结束
                        }
                  
                        try {
                            const data = JSON.parse(dataStr) as { 
                                content?: string; 
                                status?: string; 
                                error?: string; 
                                message?: string 
                            };
                      
                            // 处理错误
                            if (data.error) {
                                onError(new Error(data.error));
                                return;
                            }
                      
                            // 处理状态消息
                            if (data.status && data.message) {
                                console.log(`AI Status: ${data.status} - ${data.message}`);
                                resetTimeout(); // 重置超时
                                continue;
                            }
                      
                            // 处理内容
                            if (data.content) {
                                // 第一个chunk可能需要特殊处理
                                if (isFirstChunk && data.content.trim() === "") {
                                    continue; // 跳过空的第一个chunk
                                }
                                isFirstChunk = false;
                                onMessage(data.content);
                                resetTimeout(); // 每次收到内容都重置超时
                            }
                        } catch (e) {
                            // 忽略JSON解析错误,继续处理下一个事件
                            console.warn("Failed to parse SSE data:", dataStr);
                        }
                    }
                }
            }
        } finally {
            reader.releaseLock();
        }
  
        onDone();
    } catch (error) {
        if (error instanceof Error) {
            if (error.name === 'AbortError') {
                // 检查是否是因为超时
                const timeSinceLastEvent = Date.now() - lastEventTime;
                if (timeSinceLastEvent >= 10000) {
                    onError(new Error("响应超时,但已保留已有内容"));
                } else {
                    onError(new Error("请求已终止"));
                }
            } else {
                onError(error);
            }
        } else {
            onError(new Error("未知错误"));
        }
    } finally {
        if (timeoutId) {
            clearTimeout(timeoutId);
        }
    }
  
    // 返回终止函数,供外部调用
    return () => {
        if (controller) {
            controller.abort();
        }
        if (onAbort) {
            onAbort();
        }
    };
};

2. UI 交互 - 改进聊天界面

原来的聊天界面是等待完整响应后一次性显示,需要改成实时流式显示。

原版实现 (app/src/ai/chat.ts):

export const AIChat = (protyle: IProtyle, element: Element) => {
    // ... 对话框创建代码 ...
  
    btnsElement[1].addEventListener("click", () => {
        let inputValue = inputElement.value;
        fetchPost("/api/ai/chatGPT", {
            msg: inputValue,
        }, (response) => {
            dialog.destroy();
            let respContent = "";
            if (response.data && "" !== response.data) {
                respContent = "\n\n" + response.data;
            }
            if (inputValue === "Clear context") {
                inputValue = "";
            }
            fillContent(protyle, `${inputValue}${respContent}`, [element]);
        });
    });
};

修改后的实现:

const sendMessage = () => {
    const inputValue = inputTextarea.value.trim();
    if (!inputValue || isStreaming) return;

    // 添加用户消息
    addMessage(inputValue, true);
    inputTextarea.value = '';

    // 添加AI消息占位符
    const aiMessageElement = addMessage('正在思考...', false);
    const aiContentElement = aiMessageElement.querySelector('.ai-message-content') as HTMLElement;

    // 更新按钮状态
    sendButton.textContent = '生成中...';
    sendButton.disabled = true;
    inputTextarea.disabled = true;

    isStreaming = true;
    let responseContent = '';

    // 获取上下文和构建提示词
    const context = getContextInfo();
    // ... 构建 prompt 的代码 ...

    fetchStream(
        { msg: prompt + systemPrompt },
        (contentChunk) => {
            if (isStreaming) {
                responseContent += contentChunk;
                // 实时更新AI消息内容
                updateAIMessageContent(aiMessageElement, responseContent);
                messagesContainer.scrollTop = messagesContainer.scrollHeight;
            }
        },
        () => {
            // 完成回调
            isStreaming = false;
            sendButton.textContent = '发送';
            sendButton.disabled = false;
            inputTextarea.disabled = false;
            inputTextarea.focus();
            abortFunction = null;
        },
        (error) => {
            // 错误处理
            isStreaming = false;
            updateAIMessageContent(aiMessageElement, `生成失败: ${error.message}`);
            aiContentElement.style.color = 'var(--b3-theme-error)';
            sendButton.textContent = '发送';
            sendButton.disabled = false;
            inputTextarea.disabled = false;
            abortFunction = null;
        },
        () => {
            // 取消回调
            isStreaming = false;
            updateAIMessageContent(aiMessageElement, '已终止响应');
            sendButton.textContent = '发送';
            sendButton.disabled = false;
            inputTextarea.disabled = false;
            abortFunction = null;
        }
    ).then((abortFn) => {
        abortFunction = abortFn;
    });
};

总结

主要改动就是这些,从底层到上层:

  1. 模型层 - 添加了 ChatGPTContinueWriteStream 流式接口
  2. API 层 - 改成 SSE 流式传输,支持实时数据推送
  3. 前端工具 - 新增 fetchStream 处理流式响应
  4. UI 交互 - 改成实时显示,支持取消和错误处理

基本木有测试哈,效果差不多这样.

  • 思源笔记

    思源笔记是一款隐私优先的个人知识管理系统,支持完全离线使用,同时也支持端到端加密同步。

    融合块、大纲和双向链接,重构你的思维。

    28442 引用 • 119755 回帖
2 操作
leolee 在 2025-07-18 14:26:34 更新了该帖
JeffreyChen 在 2025-07-18 13:15:35 更新了该帖

相关帖子

欢迎来到这里!

我们正在构建一个小众社区,大家在这里相互信任,以平等 • 自由 • 奔放的价值观进行分享交流。最终,希望大家能够找到与自己志同道合的伙伴,共同成长。

注册 关于
请输入回帖内容 ...