From ca7463d42b15f508e0d69ce4a462cafb13c9f716 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=90=91=E5=AE=81?= <1772105645@qq.com>
Date: Wed, 20 May 2026 21:07:47 +0800
Subject: [PATCH] =?UTF-8?q?fix:=20AI=20=E6=B5=81=E5=BC=8F=E5=9B=9E?=
=?UTF-8?q?=E5=A4=8D=E8=BF=87=E6=BB=A4=E6=80=9D=E8=80=83=E6=A0=87=E7=AD=BE?=
=?UTF-8?q?=20+=20=E5=8A=A0=E8=BD=BD=E5=8E=86=E5=8F=B2=E4=B8=8A=E4=B8=8B?=
=?UTF-8?q?=E6=96=87?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
- ChatAgentService.RunStreamingAsync 增加状态机过滤 qwen3 <think> 标签
- RunAsync 同步方法也增加过滤
- StreamMessageEndpoint 从 Redis/DB 加载历史消息构建上下文
---
.../Endpoints/Chat/StreamMessageEndpoint.cs | 42 +++++-
src/RAG.Infrastructure/AI/ChatAgentService.cs | 133 +++++++++++++++++-
2 files changed, 170 insertions(+), 5 deletions(-)
diff --git a/src/RAG.Api/Endpoints/Chat/StreamMessageEndpoint.cs b/src/RAG.Api/Endpoints/Chat/StreamMessageEndpoint.cs
index 22f5f03..dadecee 100644
--- a/src/RAG.Api/Endpoints/Chat/StreamMessageEndpoint.cs
+++ b/src/RAG.Api/Endpoints/Chat/StreamMessageEndpoint.cs
@@ -1,6 +1,7 @@
using System.Text;
using System.Text.Json;
using FastEndpoints;
+using Microsoft.EntityFrameworkCore;
using RAG.Domain.Entities;
using RAG.Domain.Enums;
using RAG.Domain.Interfaces;
@@ -40,6 +41,43 @@ public class StreamMessageEndpoint(RagDbContext db, IAIChatAgent chatAgent, ICha
await cache.AppendMessageAsync(conversationId,
new CachedChatMessage(userMessage.Id, ChatRole.User.ToString(), userMessage.Content, null, userMessage.CreatedAt), ct);
+ // 加载历史消息构建上下文
+ var cached = await cache.GetMessagesAsync(conversationId, ct);
+ List<(string Role, string Content)> history;
+ if (cached is { Count: > 0 })
+ {
+ history = cached
+ .OrderBy(m => m.CreatedAt)
+ .Select(m => (m.Role, m.Content))
+ .ToList();
+ }
+ else
+ {
+ var dbMessages = await db.ChatMessages
+ .Where(m => m.ConversationId == conversationId)
+ .OrderBy(m => m.CreatedAt)
+ .ToListAsync(ct);
+
+ history = dbMessages.Select(m => (m.Role.ToString(), m.Content)).ToList();
+
+ // 回填 Redis 缓存
+ if (history.Count > 0)
+ {
+ await cache.SetMessagesAsync(conversationId,
+ history.Select((h, i) => new CachedChatMessage(
+ Guid.NewGuid(), h.Item1, h.Item2, null,
+ DateTime.UtcNow.AddSeconds(i))).ToList(), ct);
+ }
+ }
+
+ var promptBuilder = new StringBuilder();
+ foreach (var (role, content) in history)
+ {
+ promptBuilder.AppendLine($"{role}: {content}");
+ }
+
+ var prompt = promptBuilder.ToString();
+
// SSE 响应
HttpContext.Response.ContentType = "text/event-stream";
HttpContext.Response.Headers.CacheControl = "no-cache";
@@ -47,7 +85,7 @@ public class StreamMessageEndpoint(RagDbContext db, IAIChatAgent chatAgent, ICha
var fullReply = new StringBuilder();
- await foreach (var chunk in chatAgent.RunStreamingAsync(req.Content, ct))
+ await foreach (var chunk in chatAgent.RunStreamingAsync(prompt, ct))
{
fullReply.Append(chunk);
var sseData = JsonSerializer.Serialize(new { content = chunk });
@@ -69,7 +107,7 @@ public class StreamMessageEndpoint(RagDbContext db, IAIChatAgent chatAgent, ICha
await cache.AppendMessageAsync(conversationId,
new CachedChatMessage(assistantMessage.Id, ChatRole.Assistant.ToString(), assistantMessage.Content, null, assistantMessage.CreatedAt), ct);
- // 发送结束标记(含完整消息 ID)
+ // 发送结束标记
var doneData = JsonSerializer.Serialize(new { messageId = assistantMessage.Id });
await HttpContext.Response.WriteAsync($"event: done\ndata: {doneData}\n\n", ct);
await HttpContext.Response.Body.FlushAsync(ct);
diff --git a/src/RAG.Infrastructure/AI/ChatAgentService.cs b/src/RAG.Infrastructure/AI/ChatAgentService.cs
index c1239a3..ee0c4c5 100644
--- a/src/RAG.Infrastructure/AI/ChatAgentService.cs
+++ b/src/RAG.Infrastructure/AI/ChatAgentService.cs
@@ -1,4 +1,5 @@
using System.Runtime.CompilerServices;
+using System.Text;
using System.ClientModel;
using Microsoft.Agents.AI;
using Microsoft.Extensions.Options;
@@ -28,16 +29,142 @@ public class ChatAgentService : IAIChatAgent
public async Task RunAsync(string prompt, CancellationToken ct)
{
var response = await _agent.RunAsync(prompt, null, null, ct);
- return response.Text;
+ return FilterThinkTags(response.Text);
}
public async IAsyncEnumerable RunStreamingAsync(
string prompt, [EnumeratorCancellation] CancellationToken ct)
{
+ // 状态机过滤 <think>...</think> 标签
+ var buffer = new StringBuilder();
+ var inThink = false;
+
await foreach (var update in _agent.RunStreamingAsync(prompt, null, null, ct))
{
- if (!string.IsNullOrEmpty(update.Text))
- yield return update.Text;
+ if (string.IsNullOrEmpty(update.Text))
+ continue;
+
+ buffer.Append(update.Text);
+
+ // 从 buffer 中提取可以安全输出的文本
+ while (buffer.Length > 0)
+ {
+ var content = buffer.ToString();
+
+ if (inThink)
+ {
+ var endIdx = content.IndexOf(" 10)
+ buffer.Remove(0, buffer.Length - 10);
+ break;
+ }
+
+ // 找到结束标签,跳过
+ var closeIdx = content.IndexOf('>', endIdx);
+ if (closeIdx < 0)
+ break;
+
+ inThink = false;
+ buffer.Remove(0, closeIdx + 1);
+ }
+ else
+ {
+ var thinkIdx = content.IndexOf(" 7)
+ {
+ var safeLen = buffer.Length - 7;
+ var output = buffer.ToString(0, safeLen);
+ buffer.Remove(0, safeLen);
+ if (output.Length > 0)
+ yield return output;
+ }
+
+ break;
+ }
+
+ // 输出 think 标签之前的内容
+ if (thinkIdx > 0)
+ yield return buffer.ToString(0, thinkIdx);
+
+ // 检查同一 chunk 内是否有 </think>
+ var endIdx = content.IndexOf("= 0)
+ {
+ var closeIdx = content.IndexOf('>', endIdx);
+ if (closeIdx >= 0)
+ {
+ // 完整的 think 块在同一 chunk,直接跳过
+ buffer.Remove(0, closeIdx + 1);
+ }
+ else
+ {
+ // 有结束标签但不完整
+ inThink = true;
+ buffer.Clear();
+ break;
+ }
+ }
+ else
+ {
+ // 进入 think 块
+ inThink = true;
+ buffer.Remove(0, thinkIdx);
+ break;
+ }
+ }
+ }
}
+
+ // 输出 buffer 中剩余的非 think 内容
+ if (!inThink && buffer.Length > 0)
+ yield return buffer.ToString();
+
+ buffer.Clear();
+ }
+
+ /// <summary>过滤完整的文本中的 think 标签(非流式用)</summary>
+ private static string FilterThinkTags(string text)
+ {
+ if (string.IsNullOrEmpty(text)) return text;
+
+ var sb = new StringBuilder();
+ var span = text.AsSpan();
+ while (!span.IsEmpty)
+ {
+ var thinkStart = span.IndexOf("
+ var afterThink = span[(thinkStart + 6)..];
+ var gtIdx = afterThink.IndexOf('>');
+ if (gtIdx < 0)
+ break;
+
+ var endTag = afterThink[(gtIdx + 1)..].IndexOf("');
+ if (endGt < 0)
+ break;
+
+ span = afterEnd[(endGt + 1)..];
+ }
+
+ return sb.ToString().TrimStart('\n', '\r');
}
}