From 1ee20dd30df0affc316e23481a584b45b16a8b6d Mon Sep 17 00:00:00 2001 From: ylfeng Date: Sun, 9 Jun 2024 01:40:36 +0800 Subject: [PATCH] update server --- .github/workflows/ltp-publish.yml | 1 - python/interface/examples/server.py | 174 +++++++++++----------------- 2 files changed, 69 insertions(+), 106 deletions(-) diff --git a/.github/workflows/ltp-publish.yml b/.github/workflows/ltp-publish.yml index a05437a6..a85781fc 100644 --- a/.github/workflows/ltp-publish.yml +++ b/.github/workflows/ltp-publish.yml @@ -29,7 +29,6 @@ jobs: release: name: Release runs-on: ubuntu-latest - if: "startsWith(github.ref, 'refs/tags/')" needs: [interface] steps: - uses: actions/download-artifact@v2 diff --git a/python/interface/examples/server.py b/python/interface/examples/server.py index 43ea160e..f0cbf701 100644 --- a/python/interface/examples/server.py +++ b/python/interface/examples/server.py @@ -9,70 +9,71 @@ python tools/server.py serve """ -import sys -import json -import logging -from typing import List - +from typing import List, Union import torch +from fastapi import FastAPI +from pydantic import BaseModel +from ltp import LTP -from tornado import ioloop -from tornado.httpserver import HTTPServer -from tornado.web import Application, RequestHandler -from tornado.log import app_log, gen_log, access_log, LogFormatter -from fire import Fire -from ltp import LTP +class SRLRole(BaseModel): + text: str + offset: int + length: int + type: str + + +class Parent(BaseModel): + parent: int + relate: str + + +class Word(BaseModel): + id: int + length: int + offset: int + text: str + pos: str + parent: int + relation: str + roles: List[SRLRole] + parents: List[Parent] + + +class NE(BaseModel): + text: str + offset: int + ne: str + length: int + + +class Item(BaseModel): + text: str + nes: List[NE] + words: List[Word] + + +app = FastAPI() + +ltp = LTP("LTP/tiny") +if torch.cuda.is_available(): + ltp.to("cuda") -class LTPHandler(RequestHandler): - def set_default_headers(self): - self.set_header("Access-Control-Allow-Origin", "*") - self.set_header('Access-Control-Allow-Headers', 'Content-Type') - self.set_header('Access-Control-Allow-Methods', 'GET, POST, PUT, DELETE, PATCH, OPTIONS') - self.set_header('Content-Type', 'application/json;charset=UTF-8') - - def initialize(self, ltp): - self.set_default_headers() - self.ltp = ltp - - def post(self): - try: - print(self.request.body.decode('utf-8')) - text = json.loads(self.request.body.decode('utf-8'))['text'] - # print(text) - result = self.ltp._predict([text]) - # print(result) - self.finish(result) - except Exception as e: - self.finish(self.ltp._predict(['服务器遇到错误!'])[0]) - - def options(self): - pass - - -class Server(object): - def __init__(self, path: str = 'LTP/tiny', batch_size: int = 50, device: str = None): - # 2024/6/1 7:9:45 adapt for "ltp==4.2.13" - self.ltp = LTP(path) - self.batch_size = batch_size - # 将模型移动到 GPU 上 - if device is None and torch.cuda.is_available(): - # ltp.cuda() - self.ltp.to("cuda") - elif device is not None: - self.ltp.to(device) - - def _predict(self, sentences: List[str]): - output = self.ltp.pipeline(sentences, tasks=["cws", "pos", "ner", "srl", "dep", "sdp", "sdpg"]) - - # https://github.com/HIT-SCIR/ltp/blob/main/python/interface/docs/quickstart.rst - # 需要注意的是,在依存句法当中,虚节点ROOT占据了0位置,因此节点的下标从1开始。 + +@app.post("/api") +async def predict(sentences: List[str]) -> List[Item]: + output = ltp.pipeline(sentences, tasks=["cws", "pos", "ner", "srl", "dep", "sdp", "sdpg"]) + + # https://github.com/HIT-SCIR/ltp/blob/main/python/interface/docs/quickstart.rst + # 需要注意的是,在依存句法当中,虚节点ROOT占据了0位置,因此节点的下标从1开始。 + result = [] + for idx, sentence in enumerate(sentences): id = 0 offset = 0 words = [] for word, pos, parent, relation in \ - zip(output.cws[0], output.pos[0], output.dep[0]['head'], output.dep[0]['label']): + zip(output.cws[idx], output.pos[idx], output.dep[idx]['head'], output.dep[idx]['label']): # print([id, word, pos, parent, relation]) words.append({ 'id': id, @@ -88,27 +89,24 @@ def _predict(self, sentences: List[str]): id = id + 1 offset = offset + len(word) - for token_srl in output.srl[0]: - for argument in token_srl['arguments']: + for token_srl in output.srl[idx]: + for (argument, text, start, end) in token_srl['arguments']: # print(token_srl['index'], token_srl['predicate'], argument) - text = argument[1] - start = argument[2] offset = words[start]['offset'] words[token_srl['index']]['roles'].append({ 'text': text, 'offset': offset, 'length': len(text), - 'type': argument[0] + 'type': argument }) start = 0 - for end, label in \ - zip(output.sdp[0]['head'], output.sdp[0]['label']): + for end, label in zip(output.sdp[idx]['head'], output.sdp[idx]['label']): words[start]['parents'].append({'parent': end - 1, 'relate': label}) start = start + 1 nes = [] - for role, text, start, end in output.ner[0]: + for role, text, start, end in output.ner[idx]: nes.append({ 'text': text, 'offset': start, @@ -116,46 +114,12 @@ def _predict(self, sentences: List[str]): 'length': len(text) }) - result = { - 'text': sentences[0], - 'nes': nes, - 'words': words - } - - return result - - def serve(self, port: int = 5000, n_process: int = None): - if n_process is None: - n_process = 1 if sys.platform == 'win32' else 8 - - fmt = LogFormatter(fmt='%(asctime)s - %(levelname)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S', color=True) - root_logger = logging.getLogger() - - console_handler = logging.StreamHandler() - file_handler = logging.FileHandler('server.log') - - console_handler.setFormatter(fmt) - file_handler.setFormatter(fmt) - - root_logger.addHandler(console_handler) - root_logger.addHandler(file_handler) - - app_log.setLevel(logging.INFO) - gen_log.setLevel(logging.INFO) - access_log.setLevel(logging.INFO) - - # app_log.info("Model is loading...") - app_log.info("Model Has Been Loaded!") - - app = Application([ - (r"/.*", LTPHandler, dict(ltp=self)) - ]) - - server = HTTPServer(app) - server.bind(port) - server.start(n_process) - ioloop.IOLoop.instance().start() - + result.append( + { + 'text': sentence, + 'nes': nes, + 'words': words + } + ) -if __name__ == '__main__': - Fire(Server) + return result