diff --git a/.github/workflows/publish-package.yml b/.github/workflows/publish-package.yml new file mode 100644 index 0000000..e7f3d7f --- /dev/null +++ b/.github/workflows/publish-package.yml @@ -0,0 +1,26 @@ +name: publish package + +on: + release: + types: [ created ] + +jobs: + build: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: '3.8' + - name: Install dependencies + run: | + python -m pip install --upgrade pip setuptools wheel + python -m pip install twine + - name: Build and publish + env: + TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} + TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} + run: | + python setup.py sdist bdist_wheel + twine upload dist/* diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml deleted file mode 100644 index 9993084..0000000 --- a/.github/workflows/python-publish.yml +++ /dev/null @@ -1,31 +0,0 @@ -# This workflows will upload a Python Package using Twine when a release is created -# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries - -name: Upload Python Package - -on: - release: - types: [created] - -jobs: - deploy: - - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v2 - - name: Set up Python - uses: actions/setup-python@v2 - with: - python-version: '3.8' - - name: Install dependencies - run: | - python -m pip install --upgrade pip - python -m pip install setuptools wheel twine - - name: Build and publish - env: - TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} - TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} - run: | - python setup.py sdist bdist_wheel - twine upload dist/* diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index db1e5bc..0717269 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -1,23 +1,27 @@ -name: Unit Tests +name: unit tests -on: [push] +on: + workflow_dispatch: + push: + schedule: + - cron: "0 21 * * 6" jobs: build: - runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 - - name: Set up Python - uses: actions/setup-python@v2 + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 with: python-version: '3.8' - name: Install dependencies run: | - python -m pip install --upgrade pip - python -m pip install torch - python -m pip install -e '.[dev]' + python -m pip install pip --upgrade + python -m pip install -r requirements.txt + python -m pip install pytest hypothesis torchnyan + python -m pip install git+https://github.com/speedcell4/torchrua.git@develop --force-reinstall --no-deps + python -m pip install pytorch-crf torch-struct - name: Test with pytest run: | python -m pytest tests \ No newline at end of file diff --git a/README.md b/README.md index 1186c7b..3c63b53 100644 --- a/README.md +++ b/README.md @@ -1,111 +1,24 @@ -# TorchLatent +<div align="center">
-![Unit Tests](https://github.com/speedcell4/torchlatent/workflows/Unit%20Tests/badge.svg) -[![PyPI version](https://badge.fury.io/py/torchlatent.svg)](https://badge.fury.io/py/torchlatent) -[![Downloads](https://pepy.tech/badge/torchrua)](https://pepy.tech/project/torchrua) +# TorchLatent -## Requirements +![GitHub Workflow Status (with event)](https://img.shields.io/github/actions/workflow/status/speedcell4/torchlatent/unit-tests.yml?cacheSeconds=0) +![PyPI - Version](https://img.shields.io/pypi/v/torchlatent?label=pypi%20version&cacheSeconds=0) +![PyPI - Downloads](https://img.shields.io/pypi/dm/torchlatent?cacheSeconds=0) -- Python 3.8 -- PyTorch 1.10.2 +</div>
## Installation -`python3 -m pip torchlatent` - -## Performance - -``` -TorchLatent (0.109244) => 0.003781 0.017763 0.087700 0.063497 -Third (0.232487) => 0.103277 0.129209 0.145311 -``` - -## Usage - -```python -import torch -from torchrua import pack_sequence - -from torchlatent.crf import CrfDecoder - -num_tags = 3 -num_conjugates = 1 - -decoder = CrfDecoder(num_tags=num_tags, num_conjugates=num_conjugates) - -emissions = pack_sequence([ - torch.randn((5, num_conjugates, num_tags), requires_grad=True), - torch.randn((2, num_conjugates, num_tags), requires_grad=True), - torch.randn((3, num_conjugates, num_tags), requires_grad=True), -]) - -tags = pack_sequence([ - torch.randint(0, num_tags, (5, num_conjugates)), - torch.randint(0, num_tags, (2, num_conjugates)), - torch.randint(0, num_tags, (3, num_conjugates)), -]) - -print(decoder.fit(emissions=emissions, tags=tags)) -# tensor([[-6.7424], -# [-5.1288], -# [-2.7283]], grad_fn=) - -print(decoder.decode(emissions=emissions)) -# PackedSequence(data=tensor([[2], -# [0], -# [1], -# [0], -# [2], -# [0], -# [2], -# [0], -# [1], -# [2]]), -# batch_sizes=tensor([3, 3, 2, 1, 1]), -# sorted_indices=tensor([0, 2, 1]), -# unsorted_indices=tensor([0, 2, 1])) - -print(decoder.marginals(emissions=emissions)) -# tensor([[[0.1040, 0.1001, 0.7958]], -# -# [[0.5736, 0.0784, 0.3479]], -# -# [[0.0932, 0.8797, 0.0271]], -# -# [[0.6558, 0.0472, 0.2971]], -# -# [[0.2740, 0.1109, 0.6152]], -# -# [[0.4811, 0.2163, 0.3026]], -# -# [[0.2321, 0.3478, 0.4201]], -# -# [[0.4987, 0.1986, 0.3027]], -# -# [[0.2029, 0.5888, 0.2083]], -# -# [[0.2802, 0.2358, 0.4840]]], grad_fn=) -``` +`python -m pip install torchlatent` ## Latent Structures -- [ ] Conditional Random Fields (CRF) - - [x] Conjugated - - [ ] Dynamic Transition Matrix - - [ ] Second-order - - [ ] Variant-order -- [ ] Tree CRF +- [x] Conditional Random Fields (CRF) +- [x] Cocke–Kasami–Younger Algorithm (CKY) +- [ ] Probabilistic Context-free Grammars (PCFG) +- [ ] Connectionist Temporal Classification (CTC) +- [ ] Recurrent Neural Network Grammars (RNNG) - [ ] Non-Projective Dependency Tree (Matrix-tree Theorem) -- [ ] Probabilistic Context-free Grammars (PCFG) - [ ] Dependency Model with Valence (DMV) - -## Citation - -``` -@misc{wang2020torchlatent, - title={TorchLatent: High Performance Structured Prediction in PyTorch}, - author={Yiran Wang}, - year={2020}, - howpublished = "\url{https://github.com/speedcell4/torchlatent}" -} -``` +- [ ] Autoregressive Decoding (Beam Search) diff --git a/benchmark/__init__.py b/benchmark/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/benchmark/__main__.py b/benchmark/__main__.py deleted file mode 100644 index 351d087..0000000 --- a/benchmark/__main__.py +++ /dev/null @@ -1,9 +0,0 @@ -from aku import Aku - -from benchmark.crf import benchmark_crf - -aku = Aku() - -aku.option(benchmark_crf) - -aku.run() diff --git a/benchmark/crf.py b/benchmark/crf.py deleted file mode 100644 index a74dce0..0000000 --- a/benchmark/crf.py +++ /dev/null @@ -1,62 +0,0 @@ -import torch -from torchrua import pack_sequence -from tqdm import tqdm - -from benchmark.meter import TimeMeter -from tests.third_party import ThirdPartyCrfDecoder -from torchlatent.crf import CrfDecoder - - -def benchmark_crf(num_tags: int = 50, num_conjugates: int = 1, num_runs: int = 100, - batch_size: int = 32, max_token_size: int = 512): - j1, f1, b1, d1, = TimeMeter(), TimeMeter(), TimeMeter(), TimeMeter() - j2, f2, b2, d2, = TimeMeter(), TimeMeter(), TimeMeter(), TimeMeter() - - if
torch.cuda.is_available(): - device = torch.device('cuda:0') - else: - device = torch.device('cpu') - print(f'device => {device}') - - decoder = CrfDecoder(num_tags=num_tags, num_conjugates=num_conjugates).to(device=device) - print(f'decoder => {decoder}') - - third_decoder = ThirdPartyCrfDecoder(num_tags=num_tags, num_conjugates=num_conjugates).to(device=device) - print(f'third_decoder => {third_decoder}') - - for _ in tqdm(range(num_runs)): - token_sizes = torch.randint(1, max_token_size + 1, (batch_size,), device=device).detach().cpu().tolist() - - emissions = pack_sequence([ - torch.randn((token_size, num_conjugates, num_tags), device=device, requires_grad=True) - for token_size in token_sizes - ]) - - tags = pack_sequence([ - torch.randint(0, num_tags, (token_size, num_conjugates), device=device) - for token_size in token_sizes - ]) - - with j1: - indices = decoder.compile_indices(emissions=emissions, tags=tags) - - with f1: - loss = decoder.fit(emissions=emissions, tags=tags, indices=indices).neg().mean() - - with b1: - _, torch.autograd.grad(loss, emissions.data, torch.ones_like(loss)) - - with d1: - _ = decoder.decode(emissions=emissions, indices=indices) - - with f2: - loss = third_decoder.fit(emissions=emissions, tags=tags).neg().mean() - - with b2: - _, torch.autograd.grad(loss, emissions.data, torch.ones_like(loss)) - - with d2: - _ = third_decoder.decode(emissions=emissions) - - print(f'TorchLatent ({j1.merit + f1.merit + b1.merit:.6f}) => {j1} {f1} {b1} {d1}') - print(f'Third ({j2.merit + f2.merit + b2.merit:.6f}) => {j2} {f2} {b2} {d2}') diff --git a/benchmark/meter.py b/benchmark/meter.py deleted file mode 100644 index 64e1c6e..0000000 --- a/benchmark/meter.py +++ /dev/null @@ -1,23 +0,0 @@ -from datetime import datetime - - -class TimeMeter(object): - def __init__(self) -> None: - super(TimeMeter, self).__init__() - - self.seconds = 0 - self.counts = 0 - - def __enter__(self): - self.start_tm = datetime.now() - - def __exit__(self, exc_type, exc_val, exc_tb): - self.seconds += (datetime.now() - self.start_tm).total_seconds() - self.counts += 1 - - @property - def merit(self) -> float: - return self.seconds / max(1, self.counts) - - def __repr__(self) -> str: - return f'{self.merit :.6f}' diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..9443809 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,2 @@ +torch +torchrua diff --git a/setup.py b/setup.py index 6f8a71b..ba04b02 100644 --- a/setup.py +++ b/setup.py @@ -1,10 +1,16 @@ -from setuptools import setup, find_packages +from pathlib import Path + +from setuptools import find_packages, setup name = 'torchlatent' +root_dir = Path(__file__).parent.resolve() +with (root_dir / 'requirements.txt').open(mode='r', encoding='utf-8') as fp: + install_requires = [install_require.strip() for install_require in fp] + setup( name=name, - version='0.4.2', + version='0.4.3', packages=[package for package in find_packages() if package.startswith(name)], url='https://github.com/speedcell4/torchlatent', license='MIT', @@ -12,16 +18,5 @@ author_email='speedcell4@gmail.com', description='High Performance Structured Prediction in PyTorch', python_requires='>=3.8', - install_requires=[ - 'numpy', - 'torchrua>=0.4.0', - ], - extras_require={ - 'dev': [ - 'einops', - 'pytest', - 'hypothesis', - 'pytorch-crf', - ], - } + install_requires=install_requires, ) diff --git a/tests/strategies.py b/tests/strategies.py deleted file mode 100644 index 785f929..0000000 --- a/tests/strategies.py +++ /dev/null @@ -1,34 +0,0 @@ 
-import torch - -from hypothesis import strategies as st - -TINY_BATCH_SIZE = 6 -TINY_TOKEN_SIZE = 12 - -BATCH_SIZE = 24 -TOKEN_SIZE = 50 -NUM_TAGS = 8 -NUM_CONJUGATES = 5 - - -@st.composite -def devices(draw): - if not torch.cuda.is_available(): - device = torch.device('cpu') - else: - device = torch.device('cuda:0') - _ = torch.empty((1,), device=device) - return device - - -@st.composite -def sizes(draw, *size: int, min_size: int = 1): - max_size, *size = size - - if len(size) == 0: - return draw(st.integers(min_value=min_size, max_value=max_size)) - else: - return [ - draw(sizes(*size, min_size=min_size)) - for _ in range(draw(st.integers(min_value=min_size, max_value=max_size))) - ] diff --git a/tests/test_cky.py b/tests/test_cky.py new file mode 100644 index 0000000..e627e8e --- /dev/null +++ b/tests/test_cky.py @@ -0,0 +1,104 @@ +import torch +from hypothesis import given, settings, strategies as st +from torch_struct import TreeCRF +from torchnyan import BATCH_SIZE, TINY_TOKEN_SIZE, assert_close, assert_grad_close, device, sizes +from torchrua import C + +from torchlatent.cky import CkyDecoder, cky_partitions, cky_scores +from torchlatent.semiring import Log + + +def get_argmax(cky): + argmax = cky.argmax + mask = argmax > 0 + + _, t, _, n = mask.size() + index = torch.arange(t, device=mask.device) + x = torch.masked_select(index[None, :, None, None], mask=mask) + y = torch.masked_select(index[None, None, :, None], mask=mask) + + index = torch.arange(n, device=mask.device) + z = torch.masked_select(index[None, None, None, :], mask=mask) + + return argmax, x, y, z + + +@settings(deadline=None) +@given( + token_sizes=sizes(BATCH_SIZE, TINY_TOKEN_SIZE), + num_targets=sizes(TINY_TOKEN_SIZE), + rua_targets=st.sampled_from([C.cat, C.pad, C.pack]), +) +def test_cky_scores(token_sizes, num_targets, rua_targets): + emissions = torch.randn( + (len(token_sizes), max(token_sizes), max(token_sizes), num_targets), + device=device, requires_grad=True, + ) + token_sizes = torch.tensor(token_sizes, device=device) + expected_cky = TreeCRF(emissions, lengths=token_sizes) + + argmax, x, y, z = get_argmax(expected_cky) + + emissions = torch.randn_like(emissions, requires_grad=True) + + expected_cky = TreeCRF(emissions, lengths=token_sizes) + expected = expected_cky.log_prob(argmax) + expected_cky.partition + + targets = C(data=torch.stack([x, y, z], dim=-1), token_sizes=token_sizes * 2 - 1) + actual = cky_scores( + emissions=C(emissions, token_sizes), + targets=rua_targets(targets), + semiring=Log, + ) + + assert_close(actual=actual, expected=expected) + assert_grad_close(actual=actual, expected=expected, inputs=(emissions,)) + + +@settings(deadline=None) +@given( + token_sizes=sizes(BATCH_SIZE, TINY_TOKEN_SIZE), + num_targets=sizes(TINY_TOKEN_SIZE), +) +def test_cky_partitions(token_sizes, num_targets): + emissions = torch.randn( + (len(token_sizes), max(token_sizes), max(token_sizes), num_targets), + device=device, requires_grad=True, + ) + token_sizes = torch.tensor(token_sizes, device=device) + + expected = TreeCRF(emissions, lengths=token_sizes).partition + + actual_emissions = C( + data=emissions.logsumexp(dim=-1), + token_sizes=token_sizes, + ) + actual = cky_partitions(actual_emissions, Log) + + assert_close(actual=actual, expected=expected) + assert_grad_close(actual=actual, expected=expected, inputs=(emissions,)) + + +@settings(deadline=None) +@given( + token_sizes=sizes(BATCH_SIZE, TINY_TOKEN_SIZE), + num_targets=sizes(TINY_TOKEN_SIZE), +) +def test_cky_argmax(token_sizes, num_targets): + 
emissions = torch.randn( + (len(token_sizes), max(token_sizes), max(token_sizes), num_targets), + device=device, requires_grad=True, + ) + token_sizes = torch.tensor(token_sizes, device=device) + + expected_cky = TreeCRF(emissions, lengths=token_sizes) + + _, x, y, z = get_argmax(expected_cky) + + expected = C(data=torch.stack([x, y, z], dim=-1), token_sizes=token_sizes * 2 - 1) + + actual_cky = CkyDecoder(num_targets=num_targets) + actual = actual_cky(emissions=C(emissions, token_sizes)).argmax + + for actual, expected in zip(actual.tolist(), expected.tolist()): + assert set(map(tuple, actual)) == set(map(tuple, expected)) diff --git a/tests/test_crf.py b/tests/test_crf.py index 2164a7b..f148245 100644 --- a/tests/test_crf.py +++ b/tests/test_crf.py @@ -1,117 +1,111 @@ import torch -from hypothesis import given -from torchrua import pack_sequence, cat_sequence, pack_catted_sequence +from hypothesis import given, settings, strategies as st +from torchcrf import CRF +from torchnyan import BATCH_SIZE, TOKEN_SIZE, assert_close, assert_grad_close, assert_sequence_close, device, sizes +from torchrua import C, D, P -from tests.strategies import devices, sizes, BATCH_SIZE, TOKEN_SIZE, NUM_CONJUGATES, NUM_TAGS -from tests.third_party import ThirdPartyCrfDecoder -from tests.utils import assert_close, assert_grad_close, assert_packed_sequence_equal -from torchlatent.crf import CrfDecoder +from torchlatent.crf import CrfDecoder, crf_partitions, crf_scores +from torchlatent.semiring import Log +@settings(deadline=None) @given( - device=devices(), token_sizes=sizes(BATCH_SIZE, TOKEN_SIZE), - num_conjugate=sizes(NUM_CONJUGATES), - num_tags=sizes(NUM_TAGS), + num_targets=sizes(TOKEN_SIZE), + rua_emissions=st.sampled_from([C.new, D.new, P.new]), + rua_targets=st.sampled_from([C.new, D.new, P.new]), ) -def test_crf_packed_fit(device, token_sizes, num_conjugate, num_tags): - emissions = pack_sequence([ - torch.randn((token_size, num_conjugate, num_tags), device=device, requires_grad=True) +def test_crf_scores(token_sizes, num_targets, rua_emissions, rua_targets): + inputs = [ + torch.randn((token_size, num_targets), device=device, requires_grad=True) for token_size in token_sizes - ], device=device) + ] - tags = pack_sequence([ - torch.randint(0, num_tags, (token_size, num_conjugate), device=device) + targets = [ + torch.randint(0, num_targets, (token_size,), device=device) for token_size in token_sizes - ], device=device) - - actual_decoder = CrfDecoder(num_tags=num_tags, num_conjugates=num_conjugate).to(device=device) - expected_decoder = ThirdPartyCrfDecoder(num_tags=num_tags, num_conjugates=num_conjugate).to(device=device) - expected_decoder.reset_parameters_with_(decoder=actual_decoder) - - actual = actual_decoder.fit(emissions=emissions, tags=tags) - expected = expected_decoder.fit(emissions=emissions, tags=tags) - - assert_close(actual=actual, expected=expected) - assert_grad_close(actual=actual, expected=expected, inputs=(emissions.data,)) + ] + expected_crf = CRF(num_tags=num_targets, batch_first=False).to(device=device) -@given( - device=devices(), - token_sizes=sizes(BATCH_SIZE, TOKEN_SIZE), - num_conjugate=sizes(NUM_CONJUGATES), - num_tags=sizes(NUM_TAGS), -) -def test_crf_packed_decode(device, token_sizes, num_conjugate, num_tags): - emissions = pack_sequence([ - torch.randn((token_size, num_conjugate, num_tags), device=device, requires_grad=True) - for token_size in token_sizes - ], device=device) + expected_emissions = D.new(inputs) + expected_tags = D.new(targets) - actual_decoder = 
CrfDecoder(num_tags=num_tags, num_conjugates=num_conjugate).to(device=device) - expected_decoder = ThirdPartyCrfDecoder(num_tags=num_tags, num_conjugates=num_conjugate).to(device=device) - expected_decoder.reset_parameters_with_(decoder=actual_decoder) + expected = expected_crf._compute_score( + expected_emissions.data.transpose(0, 1), + expected_tags.data.transpose(0, 1), + expected_emissions.mask().transpose(0, 1), + ) - expected = expected_decoder.decode(emissions=emissions) - actual = actual_decoder.decode(emissions=emissions) + actual = crf_scores( + emissions=rua_emissions(inputs), + targets=rua_targets(targets), + transitions=(expected_crf.transitions, expected_crf.start_transitions, expected_crf.end_transitions), + semiring=Log, + ) - assert_packed_sequence_equal(actual=actual, expected=expected) + assert_close(actual=actual, expected=expected) + assert_grad_close(actual=actual, expected=expected, inputs=inputs) +@settings(deadline=None) @given( - device=devices(), token_sizes=sizes(BATCH_SIZE, TOKEN_SIZE), - num_conjugate=sizes(NUM_CONJUGATES), - num_tags=sizes(NUM_TAGS), + num_targets=sizes(TOKEN_SIZE), + rua_emissions=st.sampled_from([C.new, D.new, P.new]), ) -def test_crf_catted_fit(device, token_sizes, num_conjugate, num_tags): - emissions = [ - torch.randn((token_size, num_conjugate, num_tags), device=device, requires_grad=True) - for token_size in token_sizes - ] - tags = [ - torch.randint(0, num_tags, (token_size, num_conjugate), device=device) +def test_crf_partitions(token_sizes, num_targets, rua_emissions): + inputs = [ + torch.randn((token_size, num_targets), device=device, requires_grad=True) for token_size in token_sizes ] - catted_emissions = cat_sequence(emissions, device=device) - packed_emissions = pack_sequence(emissions, device=device) + expected_crf = CRF(num_tags=num_targets, batch_first=False).to(device=device) - catted_tags = cat_sequence(tags, device=device) - packed_tags = pack_sequence(tags, device=device) + expected_emissions = D.new(inputs) - actual_decoder = CrfDecoder(num_tags=num_tags, num_conjugates=num_conjugate).to(device=device) - expected_decoder = ThirdPartyCrfDecoder(num_tags=num_tags, num_conjugates=num_conjugate).to(device=device) - expected_decoder.reset_parameters_with_(decoder=actual_decoder) + expected = expected_crf._compute_normalizer( + expected_emissions.data.transpose(0, 1), + expected_emissions.mask().t(), + ) - actual = actual_decoder.fit(emissions=catted_emissions, tags=catted_tags) - expected = expected_decoder.fit(emissions=packed_emissions, tags=packed_tags) + actual = crf_partitions( + emissions=rua_emissions(inputs), + transitions=(expected_crf.transitions, expected_crf.start_transitions, expected_crf.end_transitions), + semiring=Log, + ) - assert_close(actual=actual, expected=expected) - assert_grad_close(actual=actual, expected=expected, inputs=tuple(emissions)) + assert_close(actual=actual, expected=expected, rtol=1e-4, atol=1e-4) + assert_grad_close(actual=actual, expected=expected, inputs=inputs, rtol=1e-4, atol=1e-4) +@settings(deadline=None) @given( - device=devices(), token_sizes=sizes(BATCH_SIZE, TOKEN_SIZE), - num_conjugate=sizes(NUM_CONJUGATES), - num_tags=sizes(NUM_TAGS), + num_targets=sizes(TOKEN_SIZE), + rua_emissions=st.sampled_from([C.new, D.new, P.new]), ) -def test_crf_catted_decode(device, token_sizes, num_conjugate, num_tags): - emissions = [ - torch.randn((token_size, num_conjugate, num_tags), device=device, requires_grad=True) +def test_crf_argmax(token_sizes, num_targets, rua_emissions): + inputs = [ + 
torch.randn((token_size, num_targets), device=device, requires_grad=True) for token_size in token_sizes ] - catted_emissions = cat_sequence(emissions, device=device) - packed_emissions = pack_sequence(emissions, device=device) + expected_crf = CRF(num_tags=num_targets, batch_first=False).to(device=device) + + expected_emissions = D.new(inputs) + + expected = expected_crf.decode( + expected_emissions.data.transpose(0, 1), + expected_emissions.mask().t(), + ) + expected = C.new([torch.tensor(tensor, device=device) for tensor in expected]) - actual_decoder = CrfDecoder(num_tags=num_tags, num_conjugates=num_conjugate).to(device=device) - expected_decoder = ThirdPartyCrfDecoder(num_tags=num_tags, num_conjugates=num_conjugate).to(device=device) - expected_decoder.reset_parameters_with_(decoder=actual_decoder) + actual_crf = CrfDecoder(num_targets=num_targets) + actual_crf.transitions = expected_crf.transitions + actual_crf.head_transitions = expected_crf.start_transitions + actual_crf.last_transitions = expected_crf.end_transitions - expected = expected_decoder.decode(emissions=packed_emissions) - actual = actual_decoder.decode(emissions=catted_emissions) - actual = pack_catted_sequence(*actual, device=device) + actual = actual_crf(rua_emissions(inputs)).argmax.cat() - assert_packed_sequence_equal(actual=actual, expected=expected) + assert_sequence_close(actual=actual, expected=expected) diff --git a/tests/test_functional.py b/tests/test_functional.py index c12b74a..2686b68 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -1,16 +1,16 @@ import torch -from hypothesis import given, strategies as st +from hypothesis import given, settings, strategies as st +from torchnyan.assertion import assert_close, assert_grad_close +from torchnyan.strategy import TINY_BATCH_SIZE, TINY_TOKEN_SIZE, device, sizes -from tests.strategies import devices, sizes, TINY_TOKEN_SIZE, TINY_BATCH_SIZE -from tests.utils import assert_close, assert_grad_close from torchlatent.functional import logaddexp, logsumexp +@settings(deadline=None) @given( - device=devices(), token_sizes=sizes(TINY_BATCH_SIZE, TINY_TOKEN_SIZE) ) -def test_logaddexp(device, token_sizes): +def test_logaddexp(token_sizes): x = torch.randn(token_sizes, device=device, requires_grad=True) y = torch.randn(token_sizes, device=device, requires_grad=True) @@ -21,12 +21,12 @@ def test_logaddexp(device, token_sizes): assert_grad_close(actual=actual, expected=expected, inputs=(x, y)) +@settings(deadline=None) @given( data=st.data(), - device=devices(), token_sizes=sizes(TINY_BATCH_SIZE, TINY_TOKEN_SIZE) ) -def test_logsumexp(data, device, token_sizes): +def test_logsumexp(data, token_sizes): tensor = torch.randn(token_sizes, device=device, requires_grad=True) dim = data.draw(st.integers(min_value=-len(token_sizes), max_value=len(token_sizes) - 1)) diff --git a/tests/third_party.py b/tests/third_party.py deleted file mode 100644 index 6172f91..0000000 --- a/tests/third_party.py +++ /dev/null @@ -1,83 +0,0 @@ -import torch -import torchcrf -from torch import Tensor, nn -from torch.nn.utils.rnn import PackedSequence -from torch.types import Device -from torchrua import pad_catted_indices, pad_packed_sequence, pack_sequence - -from torchlatent.crf import CrfDecoder - - -@torch.no_grad() -def token_sizes_to_mask(sizes: Tensor, batch_first: bool, device: Device = None) -> Tensor: - if device is None: - device = sizes.device - - size, ptr = pad_catted_indices(sizes, batch_first=batch_first, device=device) - mask = torch.zeros(size, device=device, 
dtype=torch.bool) - mask[ptr] = True - return mask - - -class ThirdPartyCrfDecoder(nn.Module): - def __init__(self, num_tags: int, num_conjugates: int) -> None: - super(ThirdPartyCrfDecoder, self).__init__() - self.num_tags = num_tags - self.num_conjugates = num_conjugates - - self.decoders = nn.ModuleList([ - torchcrf.CRF(num_tags=num_tags, batch_first=False) - for _ in range(num_conjugates) - ]) - - @torch.no_grad() - def reset_parameters_with_(self, decoder: CrfDecoder) -> None: - assert self.num_tags == decoder.num_tags - assert self.num_conjugates == decoder.num_conjugates - - for index in range(self.num_conjugates): - self.decoders[index].transitions.data[::] = decoder.transitions[:, index, :, :] - self.decoders[index].start_transitions.data[::] = decoder.head_transitions[:, index, :] - self.decoders[index].end_transitions.data[::] = decoder.last_transitions[:, index, :] - - def fit(self, emissions: PackedSequence, tags: PackedSequence, **kwargs) -> Tensor: - num_emissions_conjugates = emissions.data.size()[1] - num_decoders_conjugates = self.num_conjugates - num_conjugates = max(num_emissions_conjugates, num_decoders_conjugates) - - emissions, token_sizes = pad_packed_sequence(emissions, batch_first=False) - tags, _ = pad_packed_sequence(tags, batch_first=False) - mask = token_sizes_to_mask(sizes=token_sizes, batch_first=False) - - log_probs = [] - for index in range(num_conjugates): - decoder = self.decoders[index % num_decoders_conjugates] - emission = emissions[:, :, index % num_emissions_conjugates] - tag = tags[:, :, index % num_emissions_conjugates] - - log_probs.append(decoder(emissions=emission, tags=tag, mask=mask, reduction='none')) - - return torch.stack(log_probs, dim=-1) - - def decode(self, emissions: PackedSequence, **kwargs) -> PackedSequence: - num_emissions_conjugates = emissions.data.size()[1] - num_decoders_conjugates = self.num_conjugates - num_conjugates = max(num_emissions_conjugates, num_decoders_conjugates) - - emissions, token_sizes = pad_packed_sequence(emissions, batch_first=False) - mask = token_sizes_to_mask(sizes=token_sizes, batch_first=False) - - predictions = [] - for index in range(num_conjugates): - decoder = self.decoders[index % num_decoders_conjugates] - emission = emissions[:, :, index % num_emissions_conjugates] - - prediction = decoder.decode(emissions=emission, mask=mask) - predictions.append(pack_sequence([torch.tensor(p) for p in prediction], device=emissions.device)) - - return PackedSequence( - torch.stack([prediction.data for prediction in predictions], dim=1), - batch_sizes=predictions[0].batch_sizes, - sorted_indices=predictions[0].sorted_indices, - unsorted_indices=predictions[0].unsorted_indices, - ) diff --git a/tests/utils.py b/tests/utils.py deleted file mode 100644 index 3040db4..0000000 --- a/tests/utils.py +++ /dev/null @@ -1,90 +0,0 @@ -from typing import List, Tuple, Union - -import torch -from torch import Tensor -from torch.nn.utils.rnn import PackedSequence -from torch.testing import assert_close -from torchrua.catting import CattedSequence - -__all__ = [ - 'assert_equal', 'assert_close', 'assert_grad_close', - 'assert_catted_sequence_equal', 'assert_catted_sequence_close', - 'assert_packed_sequence_equal', 'assert_packed_sequence_close', -] - - -def assert_equal(actual: Tensor, expected: Tensor) -> None: - assert torch.equal(actual, expected) - - -def assert_grad_close( - actual: Tensor, expected: Tensor, - inputs: Union[Tensor, List[Tensor], Tuple[Tensor, ...]], - allow_unused: bool = False, - check_device: bool = 
True, check_dtype: bool = True, check_stride: bool = True) -> None: - kwargs = dict(check_device=check_device, check_dtype=check_dtype, check_stride=check_stride) - - grad = torch.rand_like(actual) - - actual_grads = torch.autograd.grad( - actual, inputs, grad, - create_graph=False, - allow_unused=allow_unused, - ) - - expected_grads = torch.autograd.grad( - expected, inputs, grad, - create_graph=False, - allow_unused=allow_unused, - ) - - for actual_grad, expected_grad in zip(actual_grads, expected_grads): - assert_close(actual=actual_grad, expected=expected_grad, **kwargs) - - -def assert_catted_sequence_close( - actual: CattedSequence, expected: CattedSequence, - check_device: bool = True, check_dtype: bool = True, check_stride: bool = True) -> None: - kwargs = dict(check_device=check_device, check_dtype=check_dtype, check_stride=check_stride) - - assert_close(actual=actual.data, expected=expected.data, **kwargs) - assert_equal(actual=actual.token_sizes, expected=expected.token_sizes) - - -def assert_catted_sequence_equal(actual: CattedSequence, expected: CattedSequence) -> None: - assert_equal(actual=actual.data, expected=expected.data) - assert_equal(actual=actual.token_sizes, expected=expected.token_sizes) - - -def assert_packed_sequence_close( - actual: PackedSequence, expected: PackedSequence, - check_device: bool = True, check_dtype: bool = True, check_stride: bool = True) -> None: - kwargs = dict(check_device=check_device, check_dtype=check_dtype, check_stride=check_stride) - - assert_close(actual=actual.data, expected=expected.data, **kwargs) - assert_equal(actual=actual.batch_sizes, expected=expected.batch_sizes) - - if actual.sorted_indices is None: - assert expected.sorted_indices is None - else: - assert_equal(actual=actual.sorted_indices, expected=expected.sorted_indices) - - if actual.unsorted_indices is None: - assert expected.unsorted_indices is None - else: - assert_equal(actual=actual.unsorted_indices, expected=expected.unsorted_indices) - - -def assert_packed_sequence_equal(actual: PackedSequence, expected: PackedSequence) -> None: - assert_equal(actual=actual.data, expected=expected.data) - assert_equal(actual=actual.batch_sizes, expected=expected.batch_sizes) - - if actual.sorted_indices is None: - assert expected.sorted_indices is None - else: - assert_equal(actual=actual.sorted_indices, expected=expected.sorted_indices) - - if actual.unsorted_indices is None: - assert expected.unsorted_indices is None - else: - assert_equal(actual=actual.unsorted_indices, expected=expected.unsorted_indices) diff --git a/torchlatent/__init__.py b/torchlatent/__init__.py index e69de29..8203d58 100644 --- a/torchlatent/__init__.py +++ b/torchlatent/__init__.py @@ -0,0 +1 @@ +from torchlatent.crf import CrfDecoder, CrfDistribution diff --git a/torchlatent/abc.py b/torchlatent/abc.py new file mode 100644 index 0000000..d5e088d --- /dev/null +++ b/torchlatent/abc.py @@ -0,0 +1,60 @@ +from abc import ABCMeta +from typing import Union + +import torch +import torch.autograd +from torch import Tensor, nn +from torch.distributions.utils import lazy_property +from torchrua import C, D, P + + +class StructuredDistribution(object, metaclass=ABCMeta): + def __init__(self, emissions: Union[C, D, P]) -> None: + super(StructuredDistribution, self).__init__() + self.emissions = emissions + + def log_scores(self, targets: Union[C, D, P]) -> Tensor: + raise NotImplementedError + + def log_probs(self, targets: Union[C, D, P]) -> Tensor: + return self.log_scores(targets=targets) - self.log_partitions + 
+ @lazy_property + def log_partitions(self) -> Tensor: + raise NotImplementedError + + @lazy_property + def marginals(self) -> Tensor: + grad, = torch.autograd.grad( + self.log_partitions, self.emissions.data, torch.ones_like(self.log_partitions), + create_graph=True, retain_graph=True, only_inputs=True, allow_unused=True, + + ) + return grad + + @lazy_property + def max(self) -> Tensor: + raise NotImplementedError + + @lazy_property + def argmax(self) -> Tensor: + grad, = torch.autograd.grad( + self.max, self.emissions.data, torch.ones_like(self.max), + create_graph=False, retain_graph=False, only_inputs=True, allow_unused=True, + ) + return grad + + +class StructuredDecoder(nn.Module): + def __init__(self, *, num_targets: int) -> None: + super(StructuredDecoder, self).__init__() + self.num_targets = num_targets + + def reset_parameters(self) -> None: + pass + + def extra_repr(self) -> str: + return f'num_targets={self.num_targets}' + + def forward(self, emissions: Union[C, D, P]) -> StructuredDistribution: + raise NotImplementedError diff --git a/torchlatent/cky.py b/torchlatent/cky.py new file mode 100644 index 0000000..b8c863d --- /dev/null +++ b/torchlatent/cky.py @@ -0,0 +1,112 @@ +from typing import Tuple, Type, Union + +import torch +from torch import Tensor +from torch.distributions.utils import lazy_property +from torchrua import C, D, P + +from torchlatent.abc import StructuredDecoder, StructuredDistribution +from torchlatent.semiring import Log, Max, Semiring + + +def cky_scores(emissions: C, targets: Union[C, D, P], semiring: Type[Semiring]) -> Tensor: + xyz, token_sizes = targets = targets.cat() + batch_ptr, _ = targets.ptr() + + emissions = emissions.data[batch_ptr, xyz[..., 0], xyz[..., 1], xyz[..., 2]] + return semiring.segment_prod(emissions, token_sizes) + + +def diag(tensor: Tensor, offset: int) -> Tensor: + return tensor.diagonal(offset=offset, dim1=1, dim2=2) + + +def diag_scatter(chart: Tensor, score: Tensor, offset: int) -> None: + chart.diagonal(offset=offset, dim1=1, dim2=2)[::] = score + + +def left(chart: Tensor, offset: int) -> Tensor: + b, t, _, *size = chart.size() + c, n, m, *stride = chart.stride() + return chart.as_strided( + size=(b, t - offset, offset, *size), + stride=(c, n + m, m, *stride), + ) + + +def right(chart: Tensor, offset: int) -> Tensor: + b, t, _, *size = chart.size() + c, n, m, *stride = chart.stride() + return chart[:, 1:, offset:].as_strided( + size=(b, t - offset, offset, *size), + stride=(c, n + m, n, *stride), + ) + + +def cky_partitions(emissions: C, semiring: Type[Semiring]) -> Tensor: + chart = torch.full_like(emissions.data, fill_value=semiring.zero, requires_grad=False) + + diag_scatter(chart, diag(emissions.data, offset=0), offset=0) + + for w in range(1, chart.size()[1]): + score = semiring.sum(semiring.mul(left(chart, offset=w), right(chart, offset=w)), dim=2) + diag_scatter(chart, semiring.mul(score, diag(emissions.data, offset=w)), offset=w) + + index = torch.arange(chart.size()[0], dtype=torch.long, device=chart.device) + return chart[index, 0, emissions.token_sizes - 1] + + +def masked_select(mask: Tensor) -> Tuple[Tensor, Tensor, Tensor]: + _, t, _, n = mask.size() + + index = torch.arange(t, device=mask.device) + x = torch.masked_select(index[None, :, None, None], mask=mask) + y = torch.masked_select(index[None, None, :, None], mask=mask) + + index = torch.arange(n, device=mask.device) + z = torch.masked_select(index[None, None, None, :], mask=mask) + + return x, y, z + + +class CkyDistribution(StructuredDistribution): + 
def __init__(self, emissions: C) -> None: + super(CkyDistribution, self).__init__(emissions=emissions) + + def log_scores(self, targets: Union[C, D, P]) -> Tensor: + return cky_scores( + emissions=self.emissions, targets=targets, + semiring=Log, + ) + + @lazy_property + def log_partitions(self) -> Tensor: + return cky_partitions( + emissions=self.emissions._replace(data=Log.sum(self.emissions.data, dim=-1)), + semiring=Log, + ) + + @lazy_property + def max(self) -> Tensor: + return cky_partitions( + emissions=self.emissions._replace(data=Max.sum(self.emissions.data, dim=-1)), + semiring=Max, + ) + + @lazy_property + def argmax(self) -> C: + argmax = super(CkyDistribution, self).argmax + x, y, z = masked_select(argmax > 0) + + return C( + data=torch.stack([x, y, z], dim=-1), + token_sizes=self.emissions.token_sizes * 2 - 1, + ) + + +class CkyDecoder(StructuredDecoder): + def __init__(self, *, num_targets: int) -> None: + super(CkyDecoder, self).__init__(num_targets=num_targets) + + def forward(self, emissions: C) -> CkyDistribution: + return CkyDistribution(emissions=emissions) diff --git a/torchlatent/crf.py b/torchlatent/crf.py new file mode 100644 index 0000000..83d1b09 --- /dev/null +++ b/torchlatent/crf.py @@ -0,0 +1,113 @@ +from typing import Tuple, Type, Union + +import torch +from torch import Tensor, nn +from torch.distributions.utils import lazy_property +from torch.nn import init +from torchrua import C, D, P + +from torchlatent.abc import StructuredDecoder, StructuredDistribution +from torchlatent.semiring import Log, Max, Semiring + +T = Tuple[Tensor, Tensor, Tensor] + + +def crf_scores(emissions: Union[C, D, P], targets: Union[C, D, P], transitions: T, semiring: Type[Semiring]) -> Tensor: + transitions, head_transitions, last_transitions = transitions + + targets = _, token_sizes = targets.cat() + head_transitions = targets.head().rua(head_transitions) + last_transitions = targets.last().rua(last_transitions) + transitions = targets.data.roll(1).rua(transitions, targets) + + emissions, _ = emissions.idx().cat().rua(emissions, targets) + emissions = semiring.segment_prod(emissions, sizes=token_sizes) + + token_sizes = torch.stack([torch.ones_like(token_sizes), token_sizes - 1], dim=-1) + transitions = semiring.segment_prod(transitions, sizes=token_sizes.view(-1))[1::2] + + return semiring.mul( + semiring.mul(head_transitions, last_transitions), + semiring.mul(emissions, transitions), + ) + + +def crf_partitions(emissions: Union[C, D, P], transitions: T, semiring: Type[Semiring]) -> Tensor: + transitions, head_transitions, last_transitions = transitions + + emissions = emissions.pack() + last_indices = emissions.idx().last() + emissions, batch_sizes, _, _ = emissions + + _, *batch_sizes = sections = batch_sizes.detach().cpu().tolist() + emission, *emissions = torch.split(emissions, sections, dim=0) + + charts = [semiring.mul(head_transitions, emission)] + for emission, batch_size in zip(emissions, batch_sizes): + charts.append(semiring.mul( + semiring.bmm(charts[-1][:batch_size], transitions), + emission, + )) + + emission = torch.cat(charts, dim=0)[last_indices] + return semiring.sum(semiring.mul(emission, last_transitions), dim=-1) + + +class CrfDistribution(StructuredDistribution): + def __init__(self, emissions: Union[C, D, P], transitions: T) -> None: + super(CrfDistribution, self).__init__(emissions=emissions) + self.transitions = transitions + + def log_scores(self, targets: Union[C, D, P]) -> Tensor: + return crf_scores( + emissions=self.emissions, targets=targets, + 
transitions=self.transitions, + semiring=Log, + ) + + @lazy_property + def log_partitions(self) -> Tensor: + return crf_partitions( + emissions=self.emissions, + transitions=self.transitions, + semiring=Log, + ) + + @lazy_property + def max(self) -> Tensor: + return crf_partitions( + emissions=self.emissions, + transitions=self.transitions, + semiring=Max, + ) + + @lazy_property + def argmax(self) -> Union[C, D, P]: + argmax = super(CrfDistribution, self).argmax.argmax(dim=-1) + return self.emissions._replace(data=argmax) + + +class CrfDecoder(StructuredDecoder): + def __init__(self, *, num_targets: int) -> None: + super(CrfDecoder, self).__init__(num_targets=num_targets) + + self.transitions = nn.Parameter(torch.empty((num_targets, num_targets))) + self.head_transitions = nn.Parameter(torch.empty((num_targets,))) + self.last_transitions = nn.Parameter(torch.empty((num_targets,))) + + self.reset_parameters() + + def reset_parameters(self) -> None: + init.zeros_(self.transitions) + init.zeros_(self.head_transitions) + init.zeros_(self.last_transitions) + + def forward(self, emissions: Union[C, D, P]) -> CrfDistribution: + return CrfDistribution( + emissions=emissions, + transitions=( + self.transitions, + self.head_transitions, + self.last_transitions, + ), + ) diff --git a/torchlatent/crf/__init__.py b/torchlatent/crf/__init__.py deleted file mode 100644 index a126167..0000000 --- a/torchlatent/crf/__init__.py +++ /dev/null @@ -1,123 +0,0 @@ -from abc import ABCMeta -from typing import Optional, Tuple, Union - -import torch -from torch import Tensor -from torch import nn -from torch.nn import init -from torchrua import ReductionIndices, PackedSequence, CattedSequence -from torchrua import reduce_packed_indices, reduce_catted_indices - -from torchlatent.crf.catting import CattedCrfDistribution -from torchlatent.crf.packing import PackedCrfDistribution - -__all__ = [ - 'CrfDecoderABC', 'CrfDecoder', - 'PackedCrfDistribution', - 'CattedCrfDistribution', - 'Sequence', -] - -Sequence = Union[ - PackedSequence, - CattedSequence, -] - - -class CrfDecoderABC(nn.Module, metaclass=ABCMeta): - def __init__(self, num_tags: int, num_conjugates: int) -> None: - super(CrfDecoderABC, self).__init__() - - self.num_tags = num_tags - self.num_conjugates = num_conjugates - - def reset_parameters(self) -> None: - raise NotImplementedError - - def extra_repr(self) -> str: - return ', '.join([ - f'num_tags={self.num_tags}', - f'num_conjugates={self.num_conjugates}', - ]) - - @staticmethod - def compile_indices(emissions: Sequence, - tags: Optional[Sequence] = None, - indices: Optional[ReductionIndices] = None, **kwargs): - assert emissions.data.dim() == 3, f'{emissions.data.dim()} != {3}' - if tags is not None: - assert tags.data.dim() == 2, f'{tags.data.dim()} != {2}' - - if indices is None: - if isinstance(emissions, PackedSequence): - batch_sizes = emissions.batch_sizes.to(device=emissions.data.device) - return reduce_packed_indices(batch_sizes=batch_sizes) - - if isinstance(emissions, CattedSequence): - token_sizes = emissions.token_sizes.to(device=emissions.data.device) - return reduce_catted_indices(token_sizes=token_sizes) - - return indices - - def obtain_parameters(self, **kwargs) -> Tuple[Tensor, Tensor, Tensor]: - return self.transitions, self.head_transitions, self.last_transitions - - def forward(self, emissions: Sequence, tags: Optional[Sequence] = None, - indices: Optional[ReductionIndices] = None, **kwargs): - indices = self.compile_indices(emissions=emissions, tags=tags, indices=indices) - 
transitions, head_transitions, last_transitions = self.obtain_parameters( - emissions=emissions, tags=tags, indices=indices, - ) - - if isinstance(emissions, PackedSequence): - dist = PackedCrfDistribution( - emissions=emissions, indices=indices, - transitions=transitions, - head_transitions=head_transitions, - last_transitions=last_transitions, - ) - return dist, tags - - if isinstance(emissions, CattedSequence): - dist = CattedCrfDistribution( - emissions=emissions, indices=indices, - transitions=transitions, - head_transitions=head_transitions, - last_transitions=last_transitions, - ) - return dist, tags - - raise TypeError(f'{type(emissions)} is not supported.') - - def fit(self, emissions: Sequence, tags: Sequence, - indices: Optional[ReductionIndices] = None, **kwargs) -> Tensor: - dist, tags = self(emissions=emissions, tags=tags, instr=indices, **kwargs) - - return dist.log_prob(tags=tags) - - def decode(self, emissions: Sequence, - indices: Optional[ReductionIndices] = None, **kwargs) -> Sequence: - dist, _ = self(emissions=emissions, tags=None, instr=indices, **kwargs) - return dist.argmax - - def marginals(self, emissions: Sequence, - indices: Optional[ReductionIndices] = None, **kwargs) -> Tensor: - dist, _ = self(emissions=emissions, tags=None, instr=indices, **kwargs) - return dist.marginals - - -class CrfDecoder(CrfDecoderABC): - def __init__(self, num_tags: int, num_conjugates: int = 1) -> None: - super(CrfDecoder, self).__init__(num_tags=num_tags, num_conjugates=num_conjugates) - - self.transitions = nn.Parameter(torch.empty((1, self.num_conjugates, self.num_tags, self.num_tags))) - self.head_transitions = nn.Parameter(torch.empty((1, self.num_conjugates, self.num_tags))) - self.last_transitions = nn.Parameter(torch.empty((1, self.num_conjugates, self.num_tags))) - - self.reset_parameters() - - @torch.no_grad() - def reset_parameters(self, bound: float = 0.01) -> None: - init.uniform_(self.transitions, -bound, +bound) - init.uniform_(self.head_transitions, -bound, +bound) - init.uniform_(self.last_transitions, -bound, +bound) diff --git a/torchlatent/crf/catting.py b/torchlatent/crf/catting.py deleted file mode 100644 index 474f704..0000000 --- a/torchlatent/crf/catting.py +++ /dev/null @@ -1,137 +0,0 @@ -from typing import Type - -import torch -from torch import Tensor, autograd -from torch.distributions.utils import lazy_property -from torchrua import CattedSequence -from torchrua import ReductionIndices, head_catted_indices -from torchrua import roll_catted_sequence, head_catted_sequence, last_catted_sequence - -from torchlatent.semiring import Semiring, Log, Max - -__all__ = [ - 'compute_catted_sequence_scores', - 'compute_catted_sequence_partitions', - 'CattedCrfDistribution', -] - - -def compute_catted_sequence_scores(semiring: Type[Semiring]): - def _compute_catted_sequence_scores( - emissions: CattedSequence, tags: CattedSequence, - transitions: Tensor, head_transitions: Tensor, last_transitions: Tensor) -> Tensor: - device = transitions.device - - emission_scores = emissions.data.gather(dim=-1, index=tags.data[..., None])[..., 0] # [t, c] - - h = emissions.token_sizes.size()[0] - t = torch.arange(transitions.size()[0], device=device) # [t] - c = torch.arange(transitions.size()[1], device=device) # [c] - - x, y = roll_catted_sequence(tags, shifts=1).data, tags.data # [t, c] - head = head_catted_sequence(tags) # [h, c] - last = last_catted_sequence(tags) # [h, c] - - transition_scores = transitions[t[:, None], c[None, :], x, y] # [t, c] - transition_head_scores = 
head_transitions[t[:h, None], c[None, :], head] # [h, c] - transition_last_scores = last_transitions[t[:h, None], c[None, :], last] # [h, c] - - head_indices = head_catted_indices(emissions.token_sizes) - transition_scores[head_indices] = transition_head_scores # [h, c] - - batch_ptr = torch.repeat_interleave(emissions.token_sizes) - scores = semiring.mul(emission_scores, transition_scores) - scores = semiring.scatter_mul(scores, index=batch_ptr) - - scores = semiring.mul(scores, transition_last_scores) - - return scores - - return _compute_catted_sequence_scores - - -def compute_catted_sequence_partitions(semiring: Type[Semiring]): - def _compute_catted_sequence_partitions( - emissions: CattedSequence, indices: ReductionIndices, - transitions: Tensor, head_transitions: Tensor, last_transitions: Tensor, eye: Tensor) -> Tensor: - h = emissions.token_sizes.size()[0] - t = torch.arange(transitions.size()[0], device=transitions.device) # [t] - c = torch.arange(transitions.size()[1], device=transitions.device) # [c] - head_indices = head_catted_indices(emissions.token_sizes) - - emission_scores = semiring.mul(transitions, emissions.data[..., None, :]) # [t, c, n, n] - emission_scores[head_indices] = eye[None, None, :, :] - emission_scores = semiring.reduce(tensor=emission_scores, indices=indices) - - emission_head_scores = emissions.data[head_indices, :, None, :] - transition_head_scores = head_transitions[t[:h, None], c[None, :], None, :] - transition_last_scores = last_transitions[t[:h, None], c[None, :], :, None] - - scores = semiring.mul(transition_head_scores, emission_head_scores) - scores = semiring.bmm(scores, emission_scores) - scores = semiring.bmm(scores, transition_last_scores)[..., 0, 0] - - return scores - - return _compute_catted_sequence_partitions - - -class CattedCrfDistribution(object): - def __init__(self, emissions: CattedSequence, indices: ReductionIndices, - transitions: Tensor, head_transitions: Tensor, last_transitions: Tensor) -> None: - super(CattedCrfDistribution, self).__init__() - self.emissions = emissions - self.indices = indices - - self.transitions = transitions - self.head_transitions = head_transitions - self.last_transitions = last_transitions - - def semiring_scores(self, semiring: Type[Semiring], tags: CattedSequence) -> Tensor: - return compute_catted_sequence_scores(semiring=semiring)( - emissions=self.emissions, tags=tags, - transitions=self.transitions, - head_transitions=self.head_transitions, - last_transitions=self.last_transitions, - ) - - def semiring_partitions(self, semiring: Type[Semiring]) -> Tensor: - return compute_catted_sequence_partitions(semiring=semiring)( - emissions=self.emissions, indices=self.indices, - transitions=self.transitions, - head_transitions=self.head_transitions, - last_transitions=self.last_transitions, - eye=semiring.eye_like(self.transitions), - ) - - def log_prob(self, tags: CattedSequence) -> Tensor: - return self.log_scores(tags=tags) - self.log_partitions - - def log_scores(self, tags: CattedSequence) -> Tensor: - return self.semiring_scores(semiring=Log, tags=tags) - - @lazy_property - def log_partitions(self) -> Tensor: - return self.semiring_partitions(semiring=Log) - - @lazy_property - def marginals(self) -> Tensor: - log_partitions = self.log_partitions - grad, = autograd.grad( - log_partitions, self.emissions.data, torch.ones_like(log_partitions), - create_graph=True, only_inputs=True, allow_unused=False, - ) - return grad - - @lazy_property - def argmax(self) -> CattedSequence: - max_partitions = 
self.semiring_partitions(semiring=Max) - - grad, = torch.autograd.grad( - max_partitions, self.emissions.data, torch.ones_like(max_partitions), - retain_graph=False, create_graph=False, allow_unused=False, - ) - return CattedSequence( - data=grad.argmax(dim=-1), - token_sizes=self.emissions.token_sizes, - ) diff --git a/torchlatent/crf/packing.py b/torchlatent/crf/packing.py deleted file mode 100644 index ec22c38..0000000 --- a/torchlatent/crf/packing.py +++ /dev/null @@ -1,143 +0,0 @@ -from typing import Type - -import torch -from torch import Tensor, autograd -from torch.distributions.utils import lazy_property -from torch.nn.utils.rnn import PackedSequence -from torchrua import head_packed_indices, ReductionIndices -from torchrua import roll_packed_sequence, head_packed_sequence, last_packed_sequence, major_sizes_to_ptr - -from torchlatent.semiring import Semiring, Log, Max - -__all__ = [ - 'compute_packed_sequence_scores', - 'compute_packed_sequence_partitions', - 'PackedCrfDistribution', -] - - -def compute_packed_sequence_scores(semiring: Type[Semiring]): - def _compute_packed_sequence_scores( - emissions: PackedSequence, tags: PackedSequence, - transitions: Tensor, head_transitions: Tensor, last_transitions: Tensor) -> Tensor: - device = transitions.device - - emission_scores = emissions.data.gather(dim=-1, index=tags.data[..., None])[..., 0] # [t, c] - - h = emissions.batch_sizes[0].item() - t = torch.arange(transitions.size()[0], device=device) # [t] - c = torch.arange(transitions.size()[1], device=device) # [c] - - x, y = roll_packed_sequence(tags, shifts=1).data, tags.data # [t, c] - head = head_packed_sequence(tags, unsort=False) # [h, c] - last = last_packed_sequence(tags, unsort=False) # [h, c] - - transition_scores = transitions[t[:, None], c[None, :], x, y] # [t, c] - transition_head_scores = head_transitions[t[:h, None], c[None, :], head] # [h, c] - transition_last_scores = last_transitions[t[:h, None], c[None, :], last] # [h, c] - - indices = head_packed_indices(tags.batch_sizes) - transition_scores[indices] = transition_head_scores # [h, c] - - batch_ptr, _ = major_sizes_to_ptr(sizes=emissions.batch_sizes) - scores = semiring.mul(emission_scores, transition_scores) - scores = semiring.scatter_mul(scores, index=batch_ptr) - - scores = semiring.mul(scores, transition_last_scores) - - if emissions.unsorted_indices is not None: - scores = scores[emissions.unsorted_indices] - - return scores - - return _compute_packed_sequence_scores - - -def compute_packed_sequence_partitions(semiring: Type[Semiring]): - def _compute_packed_sequence_partitions( - emissions: PackedSequence, indices: ReductionIndices, - transitions: Tensor, head_transitions: Tensor, last_transitions: Tensor, eye: Tensor) -> Tensor: - h = emissions.batch_sizes[0].item() - t = torch.arange(transitions.size()[0], device=transitions.device) # [t] - c = torch.arange(transitions.size()[1], device=transitions.device) # [c] - - emission_scores = semiring.mul(transitions, emissions.data[..., None, :]) # [t, c, n, n] - emission_scores[:h] = eye[None, None, :, :] - emission_scores = semiring.reduce(tensor=emission_scores, indices=indices) - - emission_head_scores = emissions.data[:h, :, None, :] - transition_head_scores = head_transitions[t[:h, None], c[None, :], None, :] - transition_last_scores = last_transitions[t[:h, None], c[None, :], :, None] - - scores = semiring.mul(transition_head_scores, emission_head_scores) - scores = semiring.bmm(scores, emission_scores) - scores = semiring.bmm(scores, 
transition_last_scores)[..., 0, 0] - - if emissions.unsorted_indices is not None: - scores = scores[emissions.unsorted_indices] - return scores - - return _compute_packed_sequence_partitions - - -class PackedCrfDistribution(object): - def __init__(self, emissions: PackedSequence, indices: ReductionIndices, - transitions: Tensor, head_transitions: Tensor, last_transitions: Tensor) -> None: - super(PackedCrfDistribution, self).__init__() - self.emissions = emissions - self.indices = indices - - self.transitions = transitions - self.head_transitions = head_transitions - self.last_transitions = last_transitions - - def semiring_scores(self, semiring: Type[Semiring], tags: PackedSequence) -> Tensor: - return compute_packed_sequence_scores(semiring=semiring)( - emissions=self.emissions, tags=tags, - transitions=self.transitions, - head_transitions=self.head_transitions, - last_transitions=self.last_transitions, - ) - - def semiring_partitions(self, semiring: Type[Semiring]) -> Tensor: - return compute_packed_sequence_partitions(semiring=semiring)( - emissions=self.emissions, indices=self.indices, - transitions=self.transitions, - head_transitions=self.head_transitions, - last_transitions=self.last_transitions, - eye=semiring.eye_like(self.transitions), - ) - - def log_prob(self, tags: PackedSequence) -> Tensor: - return self.log_scores(tags=tags) - self.log_partitions - - def log_scores(self, tags: PackedSequence) -> Tensor: - return self.semiring_scores(semiring=Log, tags=tags) - - @lazy_property - def log_partitions(self) -> Tensor: - return self.semiring_partitions(semiring=Log) - - @lazy_property - def marginals(self) -> Tensor: - log_partitions = self.log_partitions - grad, = autograd.grad( - log_partitions, self.emissions.data, torch.ones_like(log_partitions), - create_graph=True, only_inputs=True, allow_unused=False, - ) - return grad - - @lazy_property - def argmax(self) -> PackedSequence: - max_partitions = self.semiring_partitions(semiring=Max) - - grad, = torch.autograd.grad( - max_partitions, self.emissions.data, torch.ones_like(max_partitions), - retain_graph=False, create_graph=False, allow_unused=False, - ) - return PackedSequence( - data=grad.argmax(dim=-1), - batch_sizes=self.emissions.batch_sizes, - sorted_indices=self.emissions.sorted_indices, - unsorted_indices=self.emissions.unsorted_indices, - ) diff --git a/torchlatent/semiring.py b/torchlatent/semiring.py index 1f6e44a..3f11e6d 100644 --- a/torchlatent/semiring.py +++ b/torchlatent/semiring.py @@ -1,13 +1,13 @@ import torch from torch import Tensor -from torchrua.scatter import scatter_add, scatter_max, scatter_mul, scatter_logsumexp -from torchrua.reduction import reduce_sequence, ReductionIndices +from torchrua import segment_logsumexp, segment_max, segment_prod, segment_sum -from torchlatent.functional import logsumexp, logaddexp +from torchlatent.functional import logaddexp, logsumexp __all__ = [ - 'Semiring', - 'Std', 'Log', 'Max', + 'Semiring', 'ExceptionSemiring', + 'Std', 'Log', 'Max', 'Xen', 'Div', + ] @@ -41,21 +41,17 @@ def prod(cls, tensor: Tensor, dim: int, keepdim: bool = False) -> Tensor: raise NotImplementedError @classmethod - def scatter_add(cls, tensor: Tensor, index: Tensor) -> Tensor: + def segment_sum(cls, tensor: Tensor, sizes: Tensor) -> Tensor: raise NotImplementedError @classmethod - def scatter_mul(cls, tensor: Tensor, index: Tensor) -> Tensor: + def segment_prod(cls, tensor: Tensor, sizes: Tensor) -> Tensor: raise NotImplementedError @classmethod def bmm(cls, x: Tensor, y: Tensor) -> Tensor: 
return cls.sum(cls.mul(x[..., :, :, None], y[..., None, :, :]), dim=-2, keepdim=False) - @classmethod - def reduce(cls, tensor: Tensor, indices: ReductionIndices) -> Tensor: - return reduce_sequence(cls.bmm)(tensor=tensor, indices=indices) - class Std(Semiring): zero = 0. @@ -78,12 +74,12 @@ def prod(cls, tensor: Tensor, dim: int, keepdim: bool = False) -> Tensor: return torch.prod(tensor, dim=dim, keepdim=keepdim) @classmethod - def scatter_add(cls, tensor: Tensor, index: Tensor) -> Tensor: - return scatter_add(tensor=tensor, index=index) + def segment_sum(cls, tensor: Tensor, sizes: Tensor) -> Tensor: + return segment_sum(tensor, segment_sizes=sizes) @classmethod - def scatter_mul(cls, tensor: Tensor, index: Tensor) -> Tensor: - return scatter_mul(tensor=tensor, index=index) + def segment_prod(cls, tensor: Tensor, sizes: Tensor) -> Tensor: + return segment_prod(tensor, segment_sizes=sizes) class Log(Semiring): @@ -107,12 +103,12 @@ def prod(cls, tensor: Tensor, dim: int, keepdim: bool = False) -> Tensor: return torch.sum(tensor, dim=dim, keepdim=keepdim) @classmethod - def scatter_add(cls, tensor: Tensor, index: Tensor) -> Tensor: - return scatter_logsumexp(tensor=tensor, index=index) + def segment_sum(cls, tensor: Tensor, sizes: Tensor) -> Tensor: + return segment_logsumexp(tensor, segment_sizes=sizes) @classmethod - def scatter_mul(cls, tensor: Tensor, index: Tensor) -> Tensor: - return scatter_add(tensor=tensor, index=index) + def segment_prod(cls, tensor: Tensor, sizes: Tensor) -> Tensor: + return segment_sum(tensor, segment_sizes=sizes) class Max(Semiring): @@ -136,9 +132,77 @@ def prod(cls, tensor: Tensor, dim: int, keepdim: bool = False) -> Tensor: return torch.sum(tensor, dim=dim, keepdim=keepdim) @classmethod - def scatter_add(cls, tensor: Tensor, index: Tensor) -> Tensor: - return scatter_max(tensor=tensor, index=index) + def segment_sum(cls, tensor: Tensor, sizes: Tensor) -> Tensor: + return segment_max(tensor, segment_sizes=sizes) + + @classmethod + def segment_prod(cls, tensor: Tensor, sizes: Tensor) -> Tensor: + return segment_sum(tensor, segment_sizes=sizes) + + +class ExceptionSemiring(Semiring): + @classmethod + def sum(cls, tensor: Tensor, log_p: Tensor, log_q: Tensor, dim: int, keepdim: bool = False) -> Tensor: + raise NotImplementedError + + @classmethod + def segment_sum(cls, tensor: Tensor, log_p: Tensor, log_q: Tensor, sizes: Tensor) -> Tensor: + raise NotImplementedError + + +class Xen(ExceptionSemiring): + zero = 0. + one = 0. + + @classmethod + def add(cls, x: Tensor, y: Tensor) -> Tensor: + raise NotImplementedError + + @classmethod + def mul(cls, x: Tensor, y: Tensor) -> Tensor: + return x + y + + @classmethod + def sum(cls, tensor: Tensor, log_p: Tensor, log_q: Tensor, dim: int, keepdim: bool = False) -> Tensor: + return torch.sum((tensor - log_q) * log_p.exp(), dim=dim, keepdim=keepdim) + + @classmethod + def prod(cls, tensor: Tensor, dim: int, keepdim: bool = False) -> Tensor: + return torch.sum(tensor, dim=dim, keepdim=keepdim) + + @classmethod + def segment_sum(cls, tensor: Tensor, log_p: Tensor, log_q: Tensor, sizes: Tensor) -> Tensor: + return segment_sum((tensor - log_q) * log_p.exp(), segment_sizes=sizes) + + @classmethod + def segment_prod(cls, tensor: Tensor, sizes: Tensor) -> Tensor: + return segment_sum(tensor, segment_sizes=sizes) + + +class Div(ExceptionSemiring): + zero = 0. + one = 0. 
+ + @classmethod + def add(cls, x: Tensor, y: Tensor) -> Tensor: + raise NotImplementedError + + @classmethod + def mul(cls, x: Tensor, y: Tensor) -> Tensor: + return x + y + + @classmethod + def sum(cls, tensor: Tensor, log_p: Tensor, log_q: Tensor, dim: int, keepdim: bool = False) -> Tensor: + return torch.sum((tensor - log_q + log_p) * log_p.exp(), dim=dim, keepdim=keepdim) + + @classmethod + def prod(cls, tensor: Tensor, dim: int, keepdim: bool = False) -> Tensor: + return torch.sum(tensor, dim=dim, keepdim=keepdim) + + @classmethod + def segment_sum(cls, tensor: Tensor, log_p: Tensor, log_q: Tensor, sizes: Tensor) -> Tensor: + return segment_sum((tensor - log_q + log_p) * log_p.exp(), segment_sizes=sizes) @classmethod - def scatter_mul(cls, tensor: Tensor, index: Tensor) -> Tensor: - return scatter_add(tensor=tensor, index=index) + def segment_prod(cls, tensor: Tensor, sizes: Tensor) -> Tensor: + return segment_sum(tensor, segment_sizes=sizes)
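For reference, a minimal sketch of how the new `CrfDecoder` API fits together, mirroring `tests/test_crf.py` above. `P.new` is the torchrua packed-sequence constructor used in those tests; the sequence lengths and random inputs are illustrative assumptions, not part of the patch:

```python
import torch
from torchrua import P

from torchlatent.crf import CrfDecoder

num_targets = 5

# three sequences of lengths 5, 2, and 3, each with one emission score per tag
emissions = P.new([
    torch.randn((5, num_targets), requires_grad=True),
    torch.randn((2, num_targets), requires_grad=True),
    torch.randn((3, num_targets), requires_grad=True),
])
targets = P.new([
    torch.randint(0, num_targets, (5,)),
    torch.randint(0, num_targets, (2,)),
    torch.randint(0, num_targets, (3,)),
])

decoder = CrfDecoder(num_targets=num_targets)
dist = decoder(emissions)                    # CrfDistribution

loss = dist.log_probs(targets).neg().mean()  # negative log-likelihood, one term per sequence
print(dist.argmax.cat())                     # most probable tag sequences, as a catted sequence
print(dist.marginals)                        # per-token tag marginals, computed via autograd
```

Unlike the deleted `CrfDecoderABC`, which exposed `fit`, `decode`, and `marginals` methods, everything now goes through the returned distribution object.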
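The same pattern for the new `CkyDecoder`, following `tests/test_cky.py`: emissions assign one score to every `(start, end, label)` span, and `argmax` recovers the `2n - 1` spans of the best binary parse of each length-`n` sequence. The `C` catted-sequence constructor and the emission shape follow the tests; the inputs are again random placeholders:

```python
import torch
from torchrua import C

from torchlatent.cky import CkyDecoder

num_targets, token_sizes = 4, [5, 2, 3]

# span emissions: one score per (sequence, start, end, label)
emissions = torch.randn(
    (len(token_sizes), max(token_sizes), max(token_sizes), num_targets),
    requires_grad=True,
)
token_sizes = torch.tensor(token_sizes)

dist = CkyDecoder(num_targets=num_targets)(emissions=C(emissions, token_sizes))

print(dist.log_partitions)  # one log-partition value per sequence
print(dist.argmax)          # a C whose data rows are (start, end, label) triples
```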
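A note on `torchlatent/abc.py`: `StructuredDistribution` derives both `marginals` and `argmax` by differentiating a partition-like quantity with respect to the emissions, using the identity that the gradient of the log-partition function equals the marginals (and, under the max semiring, a one-hot indicator of the argmax). The identity in miniature, for a single categorical distribution:

```python
import torch

scores = torch.randn((5,), requires_grad=True)
log_z = scores.logsumexp(dim=-1)  # log-partition of p(y) proportional to exp(scores[y])

# d(log Z) / d(scores[i]) equals p(y = i)
marginals, = torch.autograd.grad(log_z, scores, torch.ones_like(log_z), create_graph=True)
print(torch.allclose(marginals, scores.softmax(dim=-1)))  # True
```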
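On the `torchlatent/semiring.py` changes: the scatter-based reductions (`scatter_add`, `scatter_mul`) are replaced by torchrua's segment reductions, so pooling per-token scores into per-sequence scores now takes segment sizes rather than an index vector. A sketch of the intended contract under the `Log` semiring, assuming flat catted scores as `crf_scores` passes them; the `[2, 3]` sizes are made up:

```python
import torch

from torchlatent.semiring import Log

scores = torch.randn((5,))    # tokens of two catted sequences
sizes = torch.tensor([2, 3])  # their lengths

# in the Log semiring, segment_prod is a segment-wise sum of log-scores
pooled = Log.segment_prod(scores, sizes=sizes)
print(torch.allclose(pooled, torch.stack([scores[:2].sum(), scores[2:].sum()])))  # True
```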