Skip to content

Commit

Permalink
Merge branch 'dev'
Browse files Browse the repository at this point in the history
  • Loading branch information
SlimeNull committed Feb 27, 2023
2 parents f0932ea + a5133c3 commit bb7a3f1
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 63 deletions.
74 changes: 49 additions & 25 deletions src/EleCho.GoCqHttpSdk/CqWsSession.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System;
using System.Collections.Concurrent;
using System.IO;
using System.Net.WebSockets;
using System.Text.Json;
Expand Down Expand Up @@ -31,6 +32,7 @@ public class CqWsSession : CqSession, ICqPostSession, ICqActionSession, IDisposa

// 主循环线程
private Task? mainLoopTask;
private Task? mainPostLoopTask;
private Task? standaloneActionLoopTask;

// 三个接入点的套接字
Expand All @@ -39,6 +41,8 @@ public class CqWsSession : CqSession, ICqPostSession, ICqActionSession, IDisposa
private WebSocket? apiWebSocket;
private WebSocket? eventWebSocket;

private ConcurrentQueue<CqPostModel> postQueue;

/// <summary>
/// 已连接
/// </summary>
Expand Down Expand Up @@ -96,6 +100,7 @@ public CqWsSession(CqWsSessionOptions options)
// 初始化 action 发送器 和 post 管道
actionSender = new CqWsActionSender(this, apiWebSocket ?? webSocket ?? throw new InvalidOperationException("This would never happened"));
postPipeline = new CqPostPipeline();
postQueue = new ConcurrentQueue<CqPostModel>();
}

internal CqWsSession(WebSocket remoteWebSocket, Uri baseUri, string? accessToken, int bufferSize)
Expand All @@ -108,34 +113,20 @@ internal CqWsSession(WebSocket remoteWebSocket, Uri baseUri, string? accessToken

actionSender = new CqWsActionSender(this, remoteWebSocket);
postPipeline = new CqPostPipeline();
}

internal async Task ProcPostModelAsync(CqPostModel postModel)
{
CqPostContext? postContext = CqPostContext.FromModel(postModel);
postContext?.SetSession(this);

// 如果 post 上下文不为空, 则使用 PostPipeline 处理该事件
if (postContext != null)
{
await postPipeline.ExecuteAsync(postContext);

// WebSocket 需要模拟 QuickAction
await actionSender.HandleQuickAction(postContext, postModel);
}
postQueue = new ConcurrentQueue<CqPostModel>();
}

/// <summary>
/// 处理 WebSocket 数据
/// </summary>
/// <param name="wsDataModel"></param>
/// <returns></returns>
private async Task ProcWsDataAsync(CqWsDataModel? wsDataModel)
private void ProcWsDataAsync(CqWsDataModel? wsDataModel)
{
// 如果是 post 上报
if (wsDataModel is CqPostModel postModel)
{
await ProcPostModelAsync(postModel);
postQueue.Enqueue(postModel);
}
// 否则如果是 action 请求响应
else if (wsDataModel is CqActionResultRaw actionResultRaw)
Expand All @@ -156,12 +147,10 @@ private async Task WebSocketLoop(WebSocket webSocket)
MemoryStream ms = new MemoryStream();
while (!disposed)
{
IsConnected = webSocket.State == WebSocketState.Open;
IsConnected &= webSocket.State == WebSocketState.Open;

if (!IsConnected)
{
return;
}

try
{
Expand All @@ -181,12 +170,16 @@ private async Task WebSocketLoop(WebSocket webSocket)
try // 直接捕捉 JSON 反序列化异常
{
#endif
// 反序列化为 WebSocket 数据 (自己抽的类
string json = GlobalConfig.TextEncoding.GetString(ms.ToArray());
CqWsDataModel? wsDataModel = JsonSerializer.Deserialize<CqWsDataModel>(json, JsonHelper.Options);
#if DEBUG
// 反序列化为 WebSocket 数据 (自己抽的类
string json = GlobalConfig.TextEncoding.GetString(ms.ToArray());
#endif

ms.Seek(0, SeekOrigin.Begin);
CqWsDataModel? wsDataModel = JsonSerializer.Deserialize<CqWsDataModel>(ms, JsonHelper.Options);

// 处理 WebSocket 数据
await ProcWsDataAsync(wsDataModel);
ProcWsDataAsync(wsDataModel);

#if DEBUG
if (wsDataModel is not CqPostModel)
Expand All @@ -202,6 +195,34 @@ private async Task WebSocketLoop(WebSocket webSocket)
}
}

private async Task PostProcLoop()
{
while (!disposed)
{
if (!IsConnected)
return;

if (postQueue.TryDequeue(out var postModel))
{
CqPostContext? postContext = CqPostContext.FromModel(postModel);
postContext?.SetSession(this);

// 如果 post 上下文不为空, 则使用 PostPipeline 处理该事件
if (postContext != null)
{
await postPipeline.ExecuteAsync(postContext);

// WebSocket 需要模拟 QuickAction
await actionSender.HandleQuickAction(postContext, postModel);
}
}
else
{
await Task.Delay(1);
}
}
}

/// <summary>
/// 连接
/// </summary>
Expand Down Expand Up @@ -283,6 +304,9 @@ public async Task StartAsync()
// 当使用单独的 API 套接字的时候, 我们需要监听 API 套接字
if (apiWebSocket != null)
standaloneActionLoopTask = WebSocketLoop(apiWebSocket);

// 单独线程处理上报
mainPostLoopTask = PostProcLoop();
}

/// <summary>
Expand All @@ -303,7 +327,7 @@ public async Task WaitForShutdownAsync()
if (mainLoopTask == null)
throw new InvalidOperationException("Session is not started yet");

await mainLoopTask;
await Task.WhenAll(mainLoopTask, mainPostLoopTask);
}

/// <summary>
Expand Down
45 changes: 7 additions & 38 deletions src/TestConsole/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,56 +23,25 @@ internal class Program
{
public const int WebSocketPort = 5701;

static CqRHttpSession rHttpSession = new CqRHttpSession(new CqRHttpSessionOptions()
{
BaseUri = new Uri($"http://localhost:5701"),
});

static CqWsSession session = new CqWsSession(new CqWsSessionOptions()
{
BaseUri = new Uri($"ws://127.0.0.1:{WebSocketPort}"),
UseApiEndPoint = true,
UseEventEndPoint = true,
});

private static async Task Main(string[] args)
{
Console.Write("OpenAI API Key:\n> ");
var apikey =
Console.ReadLine()!;

session.UseMessageMatchPlugin(new MessageMatchPlugin1(session));
session.UseMessageMatchPlugin(new OpenAiMatchPlugin(session, apikey));
session.UseMessageMatchPlugin(new MessageMatchPlugin2(session));

session.UseGroupRequest(async context =>
{
await session.ApproveGroupRequestAsync(context.Flag, context.GroupRequestType);
});

session.UseGroupMessage(async context =>
{
Console.WriteLine($"{context.Sender.Nickname}: {context.Message.Text}");

if (context.Message.Text.StartsWith("ocr ", StringComparison.OrdinalIgnoreCase))
if (context.Message.Text.StartsWith("echo "))
{
var img = context.Message.FirstOrDefault(x => x is CqImageMsg);
if (img is CqImageMsg imgmsg)
{
var ocrrst =
await session.OcrImageAsync(imgmsg.File);

if (ocrrst == null)
return;

StringBuilder sb = new StringBuilder();
sb.AppendLine("OCR:");
foreach (var txtdet in ocrrst.Texts)
sb.AppendLine($"{txtdet.Text} Confidence:{txtdet.Confidence}");

await session.SendGroupMessageAsync(context.GroupId, new CqMessage(sb.ToString()));
}
await session.SendGroupMessageAsync(context.GroupId, new CqMessage(context.Message.Text.Substring(5)));
}

if (context.Message.Text.EndsWith("..."))
{
await session.SendGroupMessageAsync(context.GroupId, context.Message);
}
});

Console.WriteLine("OK");
Expand Down

0 comments on commit bb7a3f1

Please sign in to comment.