-
Notifications
You must be signed in to change notification settings - Fork 3
/
ner_entity.py
42 lines (34 loc) · 961 Bytes
/
ner_entity.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
from __future__ import annotations
from dataclasses import dataclass
from flair.data import Span
@dataclass(frozen=True)
class NerEntityType:
    """Immutable, hashable label describing the kind of a named entity.

    ``name`` is a namespaced type string such as ``"flair:person"``. It is
    ``None`` when the originating tag has no known mapping, so the annotation
    reflects that possibility.
    """

    name: str | None
@dataclass(init=True, repr=True, eq=True, frozen=True)
class NerEntityInfo:
    """Position and kind of a single named entity.

    Instances are frozen, so they are immutable and hashable. Equality compares
    ``start``, ``end`` and ``type`` in that order.
    """

    # Start offset of the entity (taken from the tagger span's ``start_pos``).
    start: int
    # End offset of the entity (taken from the tagger span's ``end_pos``).
    end: int
    # Kind of the entity. The field name shadows the builtin ``type`` inside
    # the generated ``__init__``; it is kept as-is for interface compatibility.
    type: NerEntityType
@dataclass(frozen=True)
class NerEntity:
    """A recognised named entity together with the tagger's confidence.

    Instances are immutable and are normally created from a flair ``Span``
    via :meth:`from_span`.
    """

    entity: NerEntityInfo
    # Tagger confidence, clamped into [0.0, 1.0] by ``from_span``.
    probability: float

    # Flair tag -> namespaced type name. Left unannotated on purpose so the
    # dataclass machinery treats it as a class attribute, not a field.
    _TAG_TO_TYPE = {
        "PER": "flair:person",
        "LOC": "flair:location",
        "ORG": "flair:organization",
    }

    @classmethod
    def from_span(cls, span: Span) -> NerEntity:
        """Build an entity from a tagged flair ``Span``.

        Args:
            span: Tagged span whose ``start_pos``, ``end_pos``, ``tag`` and
                ``score`` attributes are read.

        Returns:
            A new instance of ``cls`` (subclasses get instances of
            themselves). Tags without a known mapping yield a
            ``NerEntityType`` whose ``name`` is ``None``.
        """
        # NOTE(review): start_pos/end_pos/tag match the flair API this was
        # written against; newer flair releases may name these differently
        # (e.g. start_position/end_position) — confirm the pinned version.
        return cls(
            entity=NerEntityInfo(
                start=span.start_pos,
                end=span.end_pos,
                type=NerEntityType(name=cls._retrieve_type_from(span.tag)),
            ),
            probability=cls._normalize_probability(span.score),
        )

    @staticmethod
    def _normalize_probability(score: float) -> float:
        """Return *score* as a ``float`` clamped into the range [0.0, 1.0]."""
        # Float bounds keep the result a float: min(1, x) returns the int 1
        # whenever x > 1, which does not match the ``float`` field.
        return max(0.0, min(1.0, float(score)))

    @staticmethod
    def _retrieve_type_from(tag: str) -> str | None:
        """Map a flair NER tag to a namespaced type name, or ``None`` if unknown."""
        return NerEntity._TAG_TO_TYPE.get(tag)