diff --git a/src/RAG.Api/Endpoints/Chat/StreamMessageEndpoint.cs b/src/RAG.Api/Endpoints/Chat/StreamMessageEndpoint.cs index 22f5f03..dadecee 100644 --- a/src/RAG.Api/Endpoints/Chat/StreamMessageEndpoint.cs +++ b/src/RAG.Api/Endpoints/Chat/StreamMessageEndpoint.cs @@ -1,6 +1,7 @@ using System.Text; using System.Text.Json; using FastEndpoints; +using Microsoft.EntityFrameworkCore; using RAG.Domain.Entities; using RAG.Domain.Enums; using RAG.Domain.Interfaces; @@ -40,6 +41,43 @@ public class StreamMessageEndpoint(RagDbContext db, IAIChatAgent chatAgent, ICha await cache.AppendMessageAsync(conversationId, new CachedChatMessage(userMessage.Id, ChatRole.User.ToString(), userMessage.Content, null, userMessage.CreatedAt), ct); + // 加载历史消息构建上下文 + var cached = await cache.GetMessagesAsync(conversationId, ct); + List<(string Role, string Content)> history; + if (cached is { Count: > 0 }) + { + history = cached + .OrderBy(m => m.CreatedAt) + .Select(m => (m.Role, m.Content)) + .ToList(); + } + else + { + var dbMessages = await db.ChatMessages + .Where(m => m.ConversationId == conversationId) + .OrderBy(m => m.CreatedAt) + .ToListAsync(ct); + + history = dbMessages.Select(m => (m.Role.ToString(), m.Content)).ToList(); + + // 回填 Redis 缓存 + if (history.Count > 0) + { + await cache.SetMessagesAsync(conversationId, + history.Select((h, i) => new CachedChatMessage( + Guid.NewGuid(), h.Item1, h.Item2, null, + DateTime.UtcNow.AddSeconds(i))).ToList(), ct); + } + } + + var promptBuilder = new StringBuilder(); + foreach (var (role, content) in history) + { + promptBuilder.AppendLine($"{role}: {content}"); + } + + var prompt = promptBuilder.ToString(); + // SSE 响应 HttpContext.Response.ContentType = "text/event-stream"; HttpContext.Response.Headers.CacheControl = "no-cache"; @@ -47,7 +85,7 @@ public class StreamMessageEndpoint(RagDbContext db, IAIChatAgent chatAgent, ICha var fullReply = new StringBuilder(); - await foreach (var chunk in chatAgent.RunStreamingAsync(req.Content, ct)) + await foreach (var chunk in chatAgent.RunStreamingAsync(prompt, ct)) { fullReply.Append(chunk); var sseData = JsonSerializer.Serialize(new { content = chunk }); @@ -69,7 +107,7 @@ public class StreamMessageEndpoint(RagDbContext db, IAIChatAgent chatAgent, ICha await cache.AppendMessageAsync(conversationId, new CachedChatMessage(assistantMessage.Id, ChatRole.Assistant.ToString(), assistantMessage.Content, null, assistantMessage.CreatedAt), ct); - // 发送结束标记(含完整消息 ID) + // 发送结束标记 var doneData = JsonSerializer.Serialize(new { messageId = assistantMessage.Id }); await HttpContext.Response.WriteAsync($"event: done\ndata: {doneData}\n\n", ct); await HttpContext.Response.Body.FlushAsync(ct); diff --git a/src/RAG.Infrastructure/AI/ChatAgentService.cs b/src/RAG.Infrastructure/AI/ChatAgentService.cs index c1239a3..ee0c4c5 100644 --- a/src/RAG.Infrastructure/AI/ChatAgentService.cs +++ b/src/RAG.Infrastructure/AI/ChatAgentService.cs @@ -1,4 +1,5 @@ using System.Runtime.CompilerServices; +using System.Text; using System.ClientModel; using Microsoft.Agents.AI; using Microsoft.Extensions.Options; @@ -28,16 +29,142 @@ public class ChatAgentService : IAIChatAgent public async Task RunAsync(string prompt, CancellationToken ct) { var response = await _agent.RunAsync(prompt, null, null, ct); - return response.Text; + return FilterThinkTags(response.Text); } public async IAsyncEnumerable RunStreamingAsync( string prompt, [EnumeratorCancellation] CancellationToken ct) { + // 状态机过滤 ... 标签 + var buffer = new StringBuilder(); + var inThink = false; + await foreach (var update in _agent.RunStreamingAsync(prompt, null, null, ct)) { - if (!string.IsNullOrEmpty(update.Text)) - yield return update.Text; + if (string.IsNullOrEmpty(update.Text)) + continue; + + buffer.Append(update.Text); + + // 从 buffer 中提取可以安全输出的文本 + while (buffer.Length > 0) + { + var content = buffer.ToString(); + + if (inThink) + { + var endIdx = content.IndexOf(" 10) + buffer.Remove(0, buffer.Length - 10); + break; + } + + // 找到结束标签,跳过 + var closeIdx = content.IndexOf('>', endIdx); + if (closeIdx < 0) + break; + + inThink = false; + buffer.Remove(0, closeIdx + 1); + } + else + { + var thinkIdx = content.IndexOf(" 7) + { + var safeLen = buffer.Length - 7; + var output = buffer.ToString(0, safeLen); + buffer.Remove(0, safeLen); + if (output.Length > 0) + yield return output; + } + + break; + } + + // 输出 think 标签之前的内容 + if (thinkIdx > 0) + yield return buffer.ToString(0, thinkIdx); + + // 检查同一 chunk 内是否有 + var endIdx = content.IndexOf("= 0) + { + var closeIdx = content.IndexOf('>', endIdx); + if (closeIdx >= 0) + { + // 完整的 think 块在同一 chunk,直接跳过 + buffer.Remove(0, closeIdx + 1); + } + else + { + // 有结束标签但不完整 + inThink = true; + buffer.Clear(); + break; + } + } + else + { + // 进入 think 块 + inThink = true; + buffer.Remove(0, thinkIdx); + break; + } + } + } } + + // 输出 buffer 中剩余的非 think 内容 + if (!inThink && buffer.Length > 0) + yield return buffer.ToString(); + + buffer.Clear(); + } + + /// 过滤完整的文本中的 think 标签(非流式用) + private static string FilterThinkTags(string text) + { + if (string.IsNullOrEmpty(text)) return text; + + var sb = new StringBuilder(); + var span = text.AsSpan(); + while (!span.IsEmpty) + { + var thinkStart = span.IndexOf(" + var afterThink = span[(thinkStart + 6)..]; + var gtIdx = afterThink.IndexOf('>'); + if (gtIdx < 0) + break; + + var endTag = afterThink[(gtIdx + 1)..].IndexOf("'); + if (endGt < 0) + break; + + span = afterEnd[(endGt + 1)..]; + } + + return sb.ToString().TrimStart('\n', '\r'); } }