Merge pull request #180 from ku-nlp/dev
v2.1.1
nobu-g authored Jun 3, 2023
2 parents b06245f + 667961e commit d46fd27
Showing 11 changed files with 288 additions and 210 deletions.
8 changes: 7 additions & 1 deletion CHANGELOG.md
@@ -7,6 +7,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

## [v2.1.1] - 2023-06-03
### Fixed
- Fix a bug in the interactive mode.
- Fix a bug in the seq2seq model's output.

## [v2.1.0] - 2023-06-02
### Added
- Support Python 3.11.
@@ -159,7 +164,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Removed
- Remove an unnecessary dependency, `fugashi`.

[Unreleased]: https://github.com/ku-nlp/kwja/compare/v2.1.0...HEAD
[Unreleased]: https://github.com/ku-nlp/kwja/compare/v2.1.1...HEAD
[2.1.1]: https://github.com/ku-nlp/kwja/compare/v2.1.0...v2.1.1
[2.1.0]: https://github.com/ku-nlp/kwja/compare/v2.0.0...v2.1.0
[2.0.0]: https://github.com/ku-nlp/kwja/compare/v1.4.2...v2.0.0
[1.4.2]: https://github.com/ku-nlp/kwja/compare/v1.4.1...v1.4.2
339 changes: 170 additions & 169 deletions poetry.lock

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions pyproject.toml
@@ -1,7 +1,7 @@
[tool.poetry]
name = "kwja"
version = "2.1.0"
description = "A unified language analyzer for Japanese"
version = "2.1.1"
description = "A unified Japanese analyzer based on foundation models"
authors = [
"Hirokazu Kiyomaru <kiyomaru@i.kyoto-u.ac.jp>",
"Nobuhiro Ueda <ueda@nlp.i.kyoto-u.ac.jp>",
12 changes: 12 additions & 0 deletions src/kwja/callbacks/senter_module_writer.py
@@ -94,3 +94,15 @@ def write_on_epoch_end(
batch_indices: Optional[Sequence[Any]] = None,
) -> None:
pass # pragma: no cover

def on_predict_epoch_end(
self,
trainer: pl.Trainer,
pl_module: pl.LightningModule,
) -> None:
output_string: str = "\n"
if isinstance(self.destination, Path):
with self.destination.open(mode="a") as f:
f.write(output_string)
elif isinstance(self.destination, TextIOBase):
self.destination.write(output_string)
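The new on_predict_epoch_end hook appends one trailing newline after all batches have been written, so the senter output is newline-terminated whether the destination is a file or a text stream. A minimal usage sketch, assuming kwja is installed; the temporary-file setup and the Ellipsis placeholders (which the hook ignores) are illustrative and mirror the updated test further below:

    from pathlib import Path
    from tempfile import NamedTemporaryFile

    from kwja.callbacks.senter_module_writer import SenterModuleWriter

    destination = Path(NamedTemporaryFile(suffix=".txt", delete=False).name)
    writer = SenterModuleWriter(destination=destination)
    # ... write_on_batch_end(...) runs once per batch during prediction ...
    writer.on_predict_epoch_end(..., ...)  # appends the final "\n"
    assert destination.read_text().endswith("\n")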
17 changes: 12 additions & 5 deletions src/kwja/callbacks/word_module_writer.py
@@ -42,7 +42,12 @@


class WordModuleWriter(BasePredictionWriter):
def __init__(self, ambig_surf_specs: List[Dict[str, str]], destination: Optional[Union[str, Path]] = None) -> None:
def __init__(
self,
ambig_surf_specs: List[Dict[str, str]],
preserve_reading_lemma_canon: bool = False,
destination: Optional[Union[str, Path]] = None,
) -> None:
super().__init__(write_interval="batch")
if destination is None:
self.destination: Union[Path, TextIO] = sys.stdout
@@ -60,6 +65,8 @@ def __init__(self, ambig_surf_specs: List[Dict[str, str]], destination: Optional
self.jumandic = JumanDic(RESOURCE_PATH / "jumandic")
self.jinf = Jinf()

self.preserve_reading_lemma_canon: bool = preserve_reading_lemma_canon

self.prev_doc_id: Optional[str] = None
self.doc_id_sid2predicted_sentence: Dict[str, Dict[str, Sentence]] = defaultdict(dict)

@@ -110,7 +117,7 @@ def write_on_batch_end(
assert example.doc_id is not None, "doc_id isn't set"
document = dataset.doc_id2document.pop(example.doc_id)
num_morphemes = len(document.morphemes)
if dataset.from_seq2seq is True:
if self.preserve_reading_lemma_canon is True:
word_reading_predictions = [m.reading for m in document.morphemes]
canons = [m.canon for m in document.morphemes]
else:
@@ -131,7 +138,7 @@
word_reading_predictions,
morpheme_attribute_predictions,
canons,
dataset.from_seq2seq,
self.preserve_reading_lemma_canon,
)
predicted_document = chunk_morphemes(document, morphemes, word_feature_probabilities)
predicted_document.doc_id = document.doc_id
@@ -180,7 +187,7 @@ def _build_morphemes(
reading_predictions: List[str],
morpheme_attribute_predictions: Tuple[List[int], List[int], List[int], List[int]],
canons: List[Optional[str]],
from_seq2seq: bool,
preserve_lemma: bool,
) -> List[Morpheme]:
assert len(surfs) == len(norms) == len(reading_predictions)
morphemes = []
@@ -197,7 +204,7 @@
conjform_id = CONJTYPE_TAG_CONJFORM_TAG2CONJFORM_ID[conjtype][conjform]

homograph_ops: List[Dict[str, Any]] = []
if from_seq2seq:
if preserve_lemma is True:
lemma = norm
else:
lemma = self._get_lemma(norm, pos, subpos, conjtype, conjform, homograph_ops)
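This change moves the seq2seq pass-through decision from the dataset (the removed from_seq2seq attribute, see the dataset diffs below) onto the writer itself: with preserve_reading_lemma_canon=True, the readings, lemmas, and canonical forms already attached to the input morphemes (for example, by the seq2seq module) are kept instead of being re-predicted. A hedged construction sketch, assuming kwja and its bundled resources are installed; the empty ambig_surf_specs is an illustrative placeholder, not the repository's actual resource:

    from kwja.callbacks.word_module_writer import WordModuleWriter

    # Default behavior: readings, lemmas, and canons are predicted by the word module.
    writer = WordModuleWriter(ambig_surf_specs=[])

    # After the seq2seq module: keep the reading/lemma/canon it already produced.
    writer = WordModuleWriter(ambig_surf_specs=[], preserve_reading_lemma_canon=True)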
49 changes: 31 additions & 18 deletions src/kwja/cli/cli.py
@@ -41,7 +41,7 @@ def __init__(self, config: CLIConfig, batch_size: int) -> None:
self.module: Optional[pl.LightningModule] = None
self.trainer: Optional[pl.Trainer] = None

def load(self):
def load(self, **writer_kwargs):
self.module = self._load_module()
if self.config.torch_compile is True:
self.module = torch.compile(self.module) # type: ignore
@@ -51,7 +51,9 @@ def load(self):
self.trainer = pl.Trainer(
logger=False,
callbacks=[
hydra.utils.instantiate(self.module.hparams.callbacks.prediction_writer, destination=self.destination),
hydra.utils.instantiate(
self.module.hparams.callbacks.prediction_writer, destination=self.destination, **writer_kwargs
),
hydra.utils.instantiate(self.module.hparams.callbacks.progress_bar),
],
accelerator=self.device_name,
@@ -129,7 +131,7 @@ def apply_module(self, input_file: Path) -> None:
for sent_idx, sentence in enumerate(document.sentences):
sentence.sid = f"{document.doc_id}-{sent_idx}"
sentence.misc_comment = f"kwja:{kwja.__version__}"
output_string += document.to_raw_text().strip()
output_string += document.to_raw_text()
self.destination.write_text(output_string)

def export_prediction(self) -> str:
@@ -178,6 +180,13 @@ def export_prediction(self) -> str:


class WordModuleProcessor(BaseModuleProcessor):
def __init__(self, config: CLIConfig, batch_size: int, from_seq2seq: bool) -> None:
super().__init__(config, batch_size)
self.from_seq2seq = from_seq2seq

def load(self):
super().load(preserve_reading_lemma_canon=self.from_seq2seq)

def _load_module(self) -> pl.LightningModule:
typer.echo("Loading word module", err=True)
checkpoint_path: Path = download_checkpoint(module="word", model_size=self.model_size)
@@ -195,39 +204,43 @@ def export_prediction(self) -> str:


class CLIProcessor:
def __init__(self, config: CLIConfig) -> None:
def __init__(self, config: CLIConfig, tasks: List[str]) -> None:
self.raw_destination = Path(NamedTemporaryFile(suffix=".txt", delete=False).name)
self.processors: Dict[str, BaseModuleProcessor] = {
self._task2processors: Dict[str, BaseModuleProcessor] = {
"typo": TypoModuleProcessor(config, config.typo_batch_size),
"senter": SenterModuleProcessor(config, config.senter_batch_size),
"seq2seq": Seq2SeqModuleProcessor(config, config.seq2seq_batch_size),
"char": CharModuleProcessor(config, config.char_batch_size),
"word": WordModuleProcessor(config, config.word_batch_size),
"word": WordModuleProcessor(
config,
config.word_batch_size,
from_seq2seq="seq2seq" in tasks,
),
}
self.processors: List[BaseModuleProcessor] = [self._task2processors[task] for task in tasks]

def load_modules(self, tasks: List[str]) -> None:
for task in tasks:
self.processors[task].load()
def load_all_modules(self) -> None:
for processor in self.processors:
processor.load()

def refresh(self) -> None:
self.raw_destination.unlink(missing_ok=True)
for processor in self.processors.values():
for processor in self._task2processors.values():
processor.destination.unlink(missing_ok=True)

def run(self, input_documents: List[str], tasks: List[str], interactive: bool = False) -> None:
def run(self, input_documents: List[str], interactive: bool = False) -> None:
self.raw_destination.write_text(
"".join(_normalize_text(input_document) + "\nEOD\n" for input_document in input_documents)
)
input_file = self.raw_destination
for task in tasks:
processor = self.processors[task]
for processor in self.processors:
if interactive is False:
processor.load()
processor.apply_module(input_file)
input_file = processor.destination
if interactive is False:
processor.delete_module_and_trainer()
print(self.processors[tasks[-1]].export_prediction(), end="")
print(self.processors[-1].export_prediction(), end="")


def _normalize_text(text: str) -> str:
@@ -351,24 +364,24 @@ def main(
if word_batch_size is not None:
config.word_batch_size = word_batch_size

processor = CLIProcessor(config)
processor = CLIProcessor(config, specified_tasks)

# Batch mode
if input_text is not None:
if input_text.strip() != "":
processor.run(_split_into_documents(input_text), specified_tasks)
processor.run(_split_into_documents(input_text))
processor.refresh()
raise typer.Exit()

# Interactive mode
processor.load_modules(specified_tasks)
processor.load_all_modules()
typer.echo('Please end your input with a new line and type "EOD"', err=True)
input_text = ""
while True:
input_ = input()
if input_ == "EOD":
processor.refresh()
processor.run([input_text], specified_tasks, interactive=True)
processor.run([input_text], interactive=True)
print("EOD") # To indicate the end of the output.
input_text = ""
else:
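CLIProcessor now takes the task list at construction, builds the processors once in task order, and chains them by feeding each processor's destination file to the next apply_module call; knowing the full task list up front is also what lets the word processor pass preserve_reading_lemma_canon to its writer when seq2seq runs before it. A standalone sketch of the chaining pattern with simplified stand-in processors (the Processor class here is illustrative, not the repository's):

    from pathlib import Path
    from tempfile import NamedTemporaryFile
    from typing import List

    class Processor:
        def __init__(self, name: str) -> None:
            self.name = name
            self.destination = Path(NamedTemporaryFile(suffix=".txt", delete=False).name)

        def apply_module(self, input_file: Path) -> None:
            # A real processor runs its model here; the stub just tags the text.
            self.destination.write_text(f"{self.name}({input_file.read_text()})")

    def run(processors: List[Processor], raw: Path) -> str:
        input_file = raw
        for processor in processors:
            processor.apply_module(input_file)
            input_file = processor.destination  # each output becomes the next input
        return processors[-1].destination.read_text()

    raw = Path(NamedTemporaryFile(suffix=".txt", delete=False).name)
    raw.write_text("text")
    print(run([Processor("senter"), Processor("word")], raw))  # word(senter(text))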
3 changes: 0 additions & 3 deletions src/kwja/datamodule/datasets/word.py
@@ -84,9 +84,6 @@ def __init__(
# some tags are not annotated in editorial articles
self.skip_cohesion_ne_discourse = self.path.parts[-2] == "kyoto_ed"

# ---------- seq2seq ----------
self.from_seq2seq: bool = False

# ---------- reading prediction ----------
reading_resource_path = RESOURCE_PATH / "reading_prediction"
self.reading2reading_id = get_reading2reading_id(reading_resource_path / "vocab.txt")
3 changes: 0 additions & 3 deletions src/kwja/datamodule/datasets/word_inference.py
@@ -55,9 +55,6 @@ def __init__(
self.tokenizer_input_format = "text"

super(BaseDataset, self).__init__(documents, tokenizer, max_seq_length, document_split_stride)
# ---------- seq2seq ----------
self.from_seq2seq: bool = juman_file is not None and juman_file.suffix == ".seq2seq"

# ---------- cohesion analysis ----------
self.cohesion_tasks = [CohesionTask(ct) for ct in cohesion_tasks]
self.exophora_referents = [ExophoraReferent(er) for er in exophora_referents]
4 changes: 3 additions & 1 deletion tests/callbacks/test_senter_module_writer.py
@@ -83,12 +83,14 @@ def test_write_on_batch_end(char_tokenizer: PreTrainedTokenizerBase) -> None:
# S-ID:{doc_id_prefix}-1-1 kwja:{kwja.__version__}
違う文書の一文目。
# S-ID:{doc_id_prefix}-1-2 kwja:{kwja.__version__}
二文目。"""
二文目。
"""
),
]
dataset = SenterInferenceDataset(ListConfig(texts), char_tokenizer, max_seq_length, doc_id_prefix=doc_id_prefix)
trainer = MockTrainer([DataLoader(dataset, batch_size=num_examples)])
writer = SenterModuleWriter(destination=destination)
writer.write_on_batch_end(trainer, ..., prediction, None, ..., 0, 0)
writer.on_predict_epoch_end(trainer, ...)
assert isinstance(writer.destination, Path), "destination isn't set"
assert writer.destination.read_text() == expected_texts[0]
5 changes: 0 additions & 5 deletions tests/cli/sample.txt

This file was deleted.

54 changes: 51 additions & 3 deletions tests/cli/test_cli.py
@@ -1,10 +1,15 @@
import io
import tempfile
import textwrap
from typing import List, Set, Tuple

from rhoknp import Document
from rhoknp.utils.reader import chunk_by_document
from typer.testing import CliRunner

from kwja.cli.cli import app

runner = CliRunner()
runner = CliRunner(mix_stderr=False)


def test_version():
@@ -17,11 +22,54 @@ def test_device():


def test_text_input():
_ = runner.invoke(app, args=["--model-size", "tiny", "--text", "おはよう"])
ret = runner.invoke(app, args=["--model-size", "tiny", "--text", "おはよう"])
assert ret.exception is None


def test_file_input():
_ = runner.invoke(app, args=["--model-size", "tiny", "--filename", "./sample.txt"])
with tempfile.NamedTemporaryFile("wt") as f:
f.write(
textwrap.dedent(
"""\
KWJAは日本語の統合解析ツールです。汎用言語モデルを利用し、様々な言語解析を統一的な方法で解いています。
EOD
計算機による言語理解を実現するためには,計算機に常識・世界知識を与える必要があります.
10年前にはこれは非常に難しい問題でしたが,近年の計算機パワー,計算機ネットワークの飛躍的進展によって計算機が超大規模テキストを取り扱えるようになり,そこから常識を自動獲得することが少しずつ可能になってきました.
EOD
"""
)
)
f.seek(0)
ret = runner.invoke(app, args=["--model-size", "tiny", "--filename", f.name])
assert ret.exception is None


def test_sanity():
with tempfile.NamedTemporaryFile("wt") as f:
f.write(
textwrap.dedent(
"""\
KWJAは日本語の統合解析ツールです。汎用言語モデルを利用し、様々な言語解析を統一的な方法で解いています。
EOD
計算機による言語理解を実現するためには、計算機に常識・世界知識を与える必要があります。
10年前にはこれは非常に難しい問題でしたが、近年の計算機パワー、計算機ネットワークの飛躍的進展によって
計算機が超大規模テキストを取り扱えるようになり、そこから常識を自動獲得することが少しずつ可能になってきました。
EOD
"""
)
)
f.seek(0)
ret = runner.invoke(app, args=["--model-size", "tiny", "--filename", f.name])
documents: list[Document] = []
for knp_text in chunk_by_document(io.StringIO(ret.stdout)):
documents.append(Document.from_knp(knp_text))
assert len(documents) == 2
assert documents[0].text == "KWJAは日本語の統合解析ツールです。汎用言語モデルを利用し、様々な言語解析を統一的な方法で解いています。"
assert documents[1].text == (
"計算機による言語理解を実現するためには、計算機に常識・世界知識を与える必要があります。10年前にはこれは非常に難しい問題でしたが、"
+ "近年の計算機パワー、計算機ネットワークの飛躍的進展によって計算機が超大規模テキストを取り扱えるようになり、そこから常識を"
+ "自動獲得することが少しずつ可能になってきました。"
)


def test_task_input():