Skip to content

Commit

Permalink
✨ all tasks put into refs
Browse files Browse the repository at this point in the history
  • Loading branch information
RF-Tar-Railt committed Oct 16, 2024
1 parent 25d5ae0 commit fcd1ee0
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 3 deletions.
10 changes: 8 additions & 2 deletions nonebot/adapters/satori/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def __init__(self, driver: Driver, **kwargs: Any):
super().__init__(driver, **kwargs)
# 读取适配器所需的配置项
self.satori_config: Config = get_plugin_config(Config)
self.tasks: list[asyncio.Task] = [] # 存储 ws 任务
self.tasks: set[asyncio.Task] = set() # 存储 ws 任务等
self.sequences: dict[str, int] = {} # 存储 连接序列号
self._bots: defaultdict[str, set[str]] = defaultdict(set) # 存储 identity 和 bot_id 的映射
self.setup()
Expand Down Expand Up @@ -80,7 +80,9 @@ def setup(self) -> None:
async def startup(self) -> None:
"""定义启动时的操作,例如和平台建立连接"""
for client in self.satori_config.satori_clients:
self.tasks.append(asyncio.create_task(self.ws(client)))
t = asyncio.create_task(self.ws(client))
self.tasks.add(t)
t.add_done_callback(self.tasks.discard)

async def shutdown(self) -> None:
for task in self.tasks:
Expand Down Expand Up @@ -184,6 +186,8 @@ async def ws(self, info: ClientInfo) -> None:
await asyncio.sleep(3)
continue
heartbeat_task = asyncio.create_task(self._heartbeat(info, ws))
self.tasks.add(heartbeat_task)
heartbeat_task.add_done_callback(self.tasks.discard)
await self._loop(info, ws)
except WebSocketClosed as e:
log(
Expand Down Expand Up @@ -266,6 +270,8 @@ async def _loop(self, info: ClientInfo, ws: WebSocket):
if isinstance(event, (MessageEvent, InteractionEvent)):
event = event.convert()
_t = asyncio.create_task(bot.handle_event(event))
self.tasks.add(_t)
_t.add_done_callback(self.tasks.discard)
elif isinstance(payload, PongPayload):
log("TRACE", "Pong")
continue
Expand Down
4 changes: 3 additions & 1 deletion tests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@ def _ping(json: dict) -> dict:
return {"op": 2}

for client in adapter.satori_config.satori_clients:
adapter.tasks.append(asyncio.create_task(adapter.ws(client)))
task = asyncio.create_task(adapter.ws(client))
adapter.tasks.add(task)
task.add_done_callback(adapter.tasks.discard)

await asyncio.sleep(5)
bots = nonebot.get_bots()
Expand Down

0 comments on commit fcd1ee0

Please sign in to comment.