From 993986c3d38250c866765b151b86ff82d972b1d6 Mon Sep 17 00:00:00 2001
From: speedcell4
Date: Sun, 27 Mar 2022 00:37:38 +0900
Subject: [PATCH 001/102] Feat: Use new reduce_sequence

---
 torchlatent/semiring.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/torchlatent/semiring.py b/torchlatent/semiring.py
index 1f6e44a..67adc9c 100644
--- a/torchlatent/semiring.py
+++ b/torchlatent/semiring.py
@@ -1,7 +1,7 @@
 import torch
 from torch import Tensor
-from torchrua.scatter import scatter_add, scatter_max, scatter_mul, scatter_logsumexp
 from torchrua.reduction import reduce_sequence, ReductionIndices
+from torchrua.scatter import scatter_add, scatter_max, scatter_mul, scatter_logsumexp

 from torchlatent.functional import logsumexp, logaddexp

@@ -54,7 +54,7 @@ def bmm(cls, x: Tensor, y: Tensor) -> Tensor:

     @classmethod
     def reduce(cls, tensor: Tensor, indices: ReductionIndices) -> Tensor:
-        return reduce_sequence(cls.bmm)(tensor=tensor, indices=indices)
+        return reduce_sequence(data=tensor, indices=indices, op=cls.bmm)


 class Std(Semiring):

From 270f1923a8f5a28afa732581276dc9c1951b0131 Mon Sep 17 00:00:00 2001
From: speedcell4
Date: Sun, 27 Mar 2022 16:53:36 +0900
Subject: [PATCH 002/102] Feat: Add segment_add and segment_mul

---
 torchlatent/semiring.py | 34 ++++++++++++++++++++++++++++++++++
 1 file changed, 34 insertions(+)

diff --git a/torchlatent/semiring.py b/torchlatent/semiring.py
index 67adc9c..840a17b 100644
--- a/torchlatent/semiring.py
+++ b/torchlatent/semiring.py
@@ -48,6 +48,14 @@ def scatter_add(cls, tensor: Tensor, index: Tensor) -> Tensor:
     def scatter_mul(cls, tensor: Tensor, index: Tensor) -> Tensor:
         raise NotImplementedError

+    @classmethod
+    def segment_add(cls, tensor: Tensor, sizes: Tensor) -> Tensor:
+        raise NotImplementedError
+
+    @classmethod
+    def segment_mul(cls, tensor: Tensor, sizes: Tensor) -> Tensor:
+        raise NotImplementedError
+
     @classmethod
     def bmm(cls, x: Tensor, y: Tensor) -> Tensor:
         return cls.sum(cls.mul(x[..., :, :, None], y[..., None, :, :]), dim=-2, keepdim=False)
@@ -85,6 +93,14 @@ def scatter_add(cls, tensor: Tensor, index: Tensor) -> Tensor:
     def scatter_mul(cls, tensor: Tensor, index: Tensor) -> Tensor:
         return scatter_mul(tensor=tensor, index=index)

+    @classmethod
+    def segment_add(cls, tensor: Tensor, sizes: Tensor) -> Tensor:
+        return torch.segment_reduce(tensor, reduce='sum', lengths=sizes, unsafe=True)
+
+    @classmethod
+    def segment_mul(cls, tensor: Tensor, sizes: Tensor) -> Tensor:
+        raise NotImplementedError
+

 class Log(Semiring):
     zero = -float('inf')
@@ -114,6 +130,16 @@ def scatter_add(cls, tensor: Tensor, index: Tensor) -> Tensor:
     def scatter_mul(cls, tensor: Tensor, index: Tensor) -> Tensor:
         return scatter_add(tensor=tensor, index=index)

+    @classmethod
+    def segment_add(cls, tensor: Tensor, sizes: Tensor) -> Tensor:
+        m = torch.segment_reduce(tensor, reduce='max', lengths=sizes, unsafe=True).detach()
+        z = (tensor - torch.repeat_interleave(m, repeats=sizes)).exp()
+        return torch.segment_reduce(z, reduce='sum', lengths=sizes, unsafe=True).log() + m
+
+    @classmethod
+    def segment_mul(cls, tensor: Tensor, sizes: Tensor) -> Tensor:
+        return torch.segment_reduce(tensor, reduce='sum', lengths=sizes, unsafe=True)
+

 class Max(Semiring):
     zero = -float('inf')
@@ -142,3 +168,11 @@ def scatter_add(cls, tensor: Tensor, index: Tensor) -> Tensor:
     @classmethod
     def scatter_mul(cls, tensor: Tensor, index: Tensor) -> Tensor:
         return scatter_add(tensor=tensor, index=index)
+
+    @classmethod
+    def segment_add(cls, tensor: Tensor, sizes: Tensor) -> Tensor:
+        return torch.segment_reduce(tensor, reduce='max', lengths=sizes, unsafe=True)
+
+    @classmethod
+    def segment_mul(cls, tensor: Tensor, sizes: Tensor) -> Tensor:
+        return torch.segment_reduce(tensor, reduce='sum', lengths=sizes, unsafe=True)
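[Note, not part of the patch series: the Log.segment_add added above is a segment-wise logsumexp, computed with the usual max-shift so the exp cannot overflow. A minimal sanity check of that identity, assuming only PyTorch and the beta torch.segment_reduce API used in the patch:]

import torch

tensor = torch.randn(9)
sizes = torch.tensor([4, 2, 3])

# segment-wise logsumexp via max-shift, exactly as in Log.segment_add
m = torch.segment_reduce(tensor, reduce='max', lengths=sizes, unsafe=True).detach()
z = (tensor - torch.repeat_interleave(m, repeats=sizes)).exp()
actual = torch.segment_reduce(z, reduce='sum', lengths=sizes, unsafe=True).log() + m

# reference: plain logsumexp over each segment taken separately
expected = torch.stack([s.logsumexp(dim=0) for s in tensor.split(sizes.tolist())])
assert torch.allclose(actual, expected)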
From f88b4ef2e4ca8e27c2e26b2c567600766e10a4d2 Mon Sep 17 00:00:00 2001
From: speedcell4
Date: Sun, 27 Mar 2022 17:30:04 +0900
Subject: [PATCH 003/102] Benchmark: check CattedSequence

---
 benchmark/crf.py            | 64 +++++++++++++++++++++++++------------
 torchlatent/crf/__init__.py | 12 ++++---
 torchlatent/crf/catting.py  |  3 +-
 3 files changed, 52 insertions(+), 27 deletions(-)

diff --git a/benchmark/crf.py b/benchmark/crf.py
index a74dce0..e34e79b 100644
--- a/benchmark/crf.py
+++ b/benchmark/crf.py
@@ -1,5 +1,5 @@
 import torch
-from torchrua import pack_sequence
+from torchrua import pack_sequence, cat_sequence
 from tqdm import tqdm

 from benchmark.meter import TimeMeter
@@ -9,8 +9,9 @@

 def benchmark_crf(num_tags: int = 50, num_conjugates: int = 1, num_runs: int = 100,
                   batch_size: int = 32, max_token_size: int = 512):
-    j1, f1, b1, d1, = TimeMeter(), TimeMeter(), TimeMeter(), TimeMeter()
-    j2, f2, b2, d2, = TimeMeter(), TimeMeter(), TimeMeter(), TimeMeter()
+    jit1, fwd1, bwd1, dec1, = TimeMeter(), TimeMeter(), TimeMeter(), TimeMeter()
+    jit2, fwd2, bwd2, dec2, = TimeMeter(), TimeMeter(), TimeMeter(), TimeMeter()
+    jit3, fwd3, bwd3, dec3, = TimeMeter(), TimeMeter(), TimeMeter(), TimeMeter()

     if torch.cuda.is_available():
         device = torch.device('cuda:0')
@@ -27,36 +28,57 @@
     for _ in tqdm(range(num_runs)):
         token_sizes = torch.randint(1, max_token_size + 1, (batch_size,), device=device).detach().cpu().tolist()

-        emissions = pack_sequence([
+        catted_emissions = cat_sequence([
             torch.randn((token_size, num_conjugates, num_tags), device=device, requires_grad=True)
             for token_size in token_sizes
         ])
+        catted_tags = cat_sequence([
+            torch.randint(0, num_tags, (token_size, num_conjugates), device=device)
+            for token_size in token_sizes
+        ])

-        tags = pack_sequence([
+        packed_emissions = pack_sequence([
+            torch.randn((token_size, num_conjugates, num_tags), device=device, requires_grad=True)
+            for token_size in token_sizes
+        ])
+        packed_tags = pack_sequence([
             torch.randint(0, num_tags, (token_size, num_conjugates), device=device)
             for token_size in token_sizes
         ])

-        with j1:
-            indices = decoder.compile_indices(emissions=emissions, tags=tags)
+        with jit1:
+            indices = decoder.compile_indices(emissions=packed_emissions, tags=packed_tags)
+
+        with fwd1:
+            loss = decoder.fit(emissions=packed_emissions, tags=packed_tags, indices=indices).neg().mean()
+
+        with bwd1:
+            _, = torch.autograd.grad(loss, packed_emissions.data, torch.randn_like(loss))
+
+        with dec1:
+            _ = decoder.decode(emissions=packed_emissions, indices=indices)
+
+        with jit2:
+            indices = decoder.compile_indices(emissions=catted_emissions, tags=catted_tags)

-        with f1:
-            loss = decoder.fit(emissions=emissions, tags=tags, indices=indices).neg().mean()
+        with fwd2:
+            loss = decoder.fit(emissions=catted_emissions, tags=catted_tags, indices=indices).neg().mean()

-        with b1:
-            _, = torch.autograd.grad(loss, emissions.data, torch.ones_like(loss))
+        with bwd2:
+            _, = torch.autograd.grad(loss, catted_emissions.data, torch.randn_like(loss))

-        with d1:
-            _ = decoder.decode(emissions=emissions, indices=indices)
+        with dec2:
+            _ = decoder.decode(emissions=catted_emissions, indices=indices)

-        with f2:
-            loss = third_decoder.fit(emissions=emissions, tags=tags).neg().mean()
+        with fwd3:
+            loss = third_decoder.fit(emissions=packed_emissions, tags=packed_tags).neg().mean()

-        with b2:
-            _, = torch.autograd.grad(loss, emissions.data, torch.ones_like(loss))
+        with bwd3:
+            _, = torch.autograd.grad(loss, packed_emissions.data, torch.randn_like(loss))

-        with d2:
-            _ = third_decoder.decode(emissions=emissions)
+        with dec3:
+            _ = third_decoder.decode(emissions=packed_emissions)

-    print(f'TorchLatent ({j1.merit + f1.merit + b1.merit:.6f}) => {j1} {f1} {b1} {d1}')
-    print(f'Third ({j2.merit + f2.merit + b2.merit:.6f}) => {j2} {f2} {b2} {d2}')
+    print(f'PackedLatent ({jit1.merit + fwd1.merit + bwd1.merit:.6f}) => {jit1} {fwd1} {bwd1} {dec1}')
+    print(f'CattedLatent ({jit2.merit + fwd2.merit + bwd2.merit:.6f}) => {jit2} {fwd2} {bwd2} {dec2}')
+    print(f'Third ({jit3.merit + fwd3.merit + bwd3.merit:.6f}) => {jit3} {fwd3} {bwd3} {dec3}')

diff --git a/torchlatent/crf/__init__.py b/torchlatent/crf/__init__.py
index a126167..fb5bf78 100644
--- a/torchlatent/crf/__init__.py
+++ b/torchlatent/crf/__init__.py
@@ -50,12 +50,16 @@ def compile_indices(emissions: Sequence,

     if indices is None:
         if isinstance(emissions, PackedSequence):
-            batch_sizes = emissions.batch_sizes.to(device=emissions.data.device)
-            return reduce_packed_indices(batch_sizes=batch_sizes)
+            return reduce_packed_indices(
+                batch_sizes=emissions.batch_sizes,
+                device=emissions.data.device,
+            )

         if isinstance(emissions, CattedSequence):
-            token_sizes = emissions.token_sizes.to(device=emissions.data.device)
-            return reduce_catted_indices(token_sizes=token_sizes)
+            return reduce_catted_indices(
+                token_sizes=emissions.token_sizes,
+                device=emissions.data.device,
+            )

     return indices

diff --git a/torchlatent/crf/catting.py b/torchlatent/crf/catting.py
index 474f704..7132951 100644
--- a/torchlatent/crf/catting.py
+++ b/torchlatent/crf/catting.py
@@ -39,9 +39,8 @@ def _compute_catted_sequence_scores(
     head_indices = head_catted_indices(emissions.token_sizes)
     transition_scores[head_indices] = transition_head_scores  # [h, c]

-    batch_ptr = torch.repeat_interleave(emissions.token_sizes)
     scores = semiring.mul(emission_scores, transition_scores)
-    scores = semiring.scatter_mul(scores, index=batch_ptr)
+    scores = semiring.segment_mul(scores, sizes=emissions.token_sizes)

     scores = semiring.mul(scores, transition_last_scores)

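[Note, not part of the patch series: torch.autograd.grad returns a tuple with one gradient per input, which is why the benchmark unpacks its single element with `_, =` rather than binding the tuple itself. A two-line illustration, assuming only PyTorch:]

import torch

x = torch.randn(3, requires_grad=True)
grads = torch.autograd.grad(x.sum(), x)
assert isinstance(grads, tuple) and len(grads) == 1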
From 2a2a9f63a1c9de618bd23623810dd3a4cd1461d1 Mon Sep 17 00:00:00 2001
From: speedcell4
Date: Sun, 27 Mar 2022 23:13:12 +0900
Subject: [PATCH 004/102] Refactor: Separate ThirdPartyCrfDecoder

---
 benchmark/crf.py                     | 2 +-
 tests/test_crf.py                    | 2 +-
 third/__init__.py                    | 1 +
 tests/third_party.py => third/crf.py | 8 +++-----
 4 files changed, 6 insertions(+), 7 deletions(-)
 create mode 100644 third/__init__.py
 rename tests/third_party.py => third/crf.py (94%)

diff --git a/benchmark/crf.py b/benchmark/crf.py
index e34e79b..d60bc7d 100644
--- a/benchmark/crf.py
+++ b/benchmark/crf.py
@@ -3,7 +3,7 @@
 from tqdm import tqdm

 from benchmark.meter import TimeMeter
-from tests.third_party import ThirdPartyCrfDecoder
+from third.crf import CrfDecoder as ThirdPartyCrfDecoder
 from torchlatent.crf import CrfDecoder

diff --git a/tests/test_crf.py b/tests/test_crf.py
index 2164a7b..9c3a0af 100644
--- a/tests/test_crf.py
+++ b/tests/test_crf.py
@@ -3,8 +3,8 @@
 from torchrua import pack_sequence, cat_sequence, pack_catted_sequence

 from tests.strategies import devices, sizes, BATCH_SIZE, TOKEN_SIZE, NUM_CONJUGATES, NUM_TAGS
-from tests.third_party import ThirdPartyCrfDecoder
 from tests.utils import assert_close, 
assert_grad_close, assert_packed_sequence_equal +from third.crf import CrfDecoder as ThirdPartyCrfDecoder from torchlatent.crf import CrfDecoder diff --git a/third/__init__.py b/third/__init__.py new file mode 100644 index 0000000..1b8004e --- /dev/null +++ b/third/__init__.py @@ -0,0 +1 @@ +from third.crf import CrfDecoder diff --git a/tests/third_party.py b/third/crf.py similarity index 94% rename from tests/third_party.py rename to third/crf.py index 6172f91..5fb565f 100644 --- a/tests/third_party.py +++ b/third/crf.py @@ -5,8 +5,6 @@ from torch.types import Device from torchrua import pad_catted_indices, pad_packed_sequence, pack_sequence -from torchlatent.crf import CrfDecoder - @torch.no_grad() def token_sizes_to_mask(sizes: Tensor, batch_first: bool, device: Device = None) -> Tensor: @@ -19,9 +17,9 @@ def token_sizes_to_mask(sizes: Tensor, batch_first: bool, device: Device = None) return mask -class ThirdPartyCrfDecoder(nn.Module): +class CrfDecoder(nn.Module): def __init__(self, num_tags: int, num_conjugates: int) -> None: - super(ThirdPartyCrfDecoder, self).__init__() + super(CrfDecoder, self).__init__() self.num_tags = num_tags self.num_conjugates = num_conjugates @@ -31,7 +29,7 @@ def __init__(self, num_tags: int, num_conjugates: int) -> None: ]) @torch.no_grad() - def reset_parameters_with_(self, decoder: CrfDecoder) -> None: + def reset_parameters_with_(self, decoder) -> None: assert self.num_tags == decoder.num_tags assert self.num_conjugates == decoder.num_conjugates From e0a216d7b51d272d6190df101d00f335c078ffeb Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Thu, 31 Mar 2022 19:04:58 +0900 Subject: [PATCH 005/102] Feat: Add Sequence type annotation --- torchlatent/types.py | 6 ++++++ 1 file changed, 6 insertions(+) create mode 100644 torchlatent/types.py diff --git a/torchlatent/types.py b/torchlatent/types.py new file mode 100644 index 0000000..117cb50 --- /dev/null +++ b/torchlatent/types.py @@ -0,0 +1,6 @@ +from typing import Union + +from torch.nn.utils.rnn import PackedSequence +from torchrua import CattedSequence, PaddedSequence + +Sequence = Union[CattedSequence, PackedSequence, PaddedSequence] From 39a8cc759a0b37f8f1ee4777dfd2db9df40f460b Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Thu, 31 Mar 2022 22:48:39 +0900 Subject: [PATCH 006/102] Feat: Add cky_partition --- torchlatent/cky.py | 58 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) create mode 100644 torchlatent/cky.py diff --git a/torchlatent/cky.py b/torchlatent/cky.py new file mode 100644 index 0000000..dc83aea --- /dev/null +++ b/torchlatent/cky.py @@ -0,0 +1,58 @@ +from typing import NamedTuple, Tuple, Type + +import torch +from torch import Tensor +from torch.types import Device +from torchrua import major_sizes_to_ptr, accumulate_sizes + +from torchlatent.semiring import Semiring + + +class CkyIndices(NamedTuple): + width_size: int + cache_size: int + + src: Tuple[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor, Tensor]] + tgt: Tuple[Tensor, Tensor] + + +@torch.no_grad() +def cky_indices(token_sizes: Tensor, device: Device = None): + if device is None: + device = token_sizes.device + + token_sizes = token_sizes.to(device=device) + acc_token_sizes = accumulate_sizes(sizes=token_sizes) + + token_ptr, batch_ptr = major_sizes_to_ptr(sizes=token_sizes) + x_ptr, z_ptr = major_sizes_to_ptr(sizes=token_ptr + 1) + batch_ptr = batch_ptr[z_ptr] + y_ptr = z_ptr - acc_token_sizes[batch_ptr] + + width_size = token_sizes.max().item() + cache_size, = token_ptr.size() + + return CkyIndices( 
+ width_size=width_size, cache_size=cache_size, + src=((y_ptr - x_ptr, z_ptr), (batch_ptr, x_ptr, y_ptr)), + tgt=(token_sizes - 1, acc_token_sizes), + ) + + +def cky_partition(data: Tensor, indices: CkyIndices, semiring: Type[Semiring]) -> Tensor: + width_size, cache_size, (src1, src2), tgt = indices + + tensor0 = torch.full((width_size, cache_size, *data.size()[3:]), fill_value=semiring.zero, requires_grad=False) + tensor1 = torch.full((width_size, cache_size, *data.size()[3:]), fill_value=semiring.zero, requires_grad=False) + tensor2 = torch.full((width_size, cache_size, *data.size()[3:]), fill_value=semiring.zero, requires_grad=False) + + tensor0[src1] = data[src2] + tensor1[0, :] = tensor2[-1, :] = tensor0[0, :] + + for w in range(1, width_size): + tensor1[w, :-w] = tensor2[-w - 1, w:] = semiring.mul( + semiring.sum(semiring.mul(tensor1[:w, :-w], tensor2[-w:, w:]), dim=0), + tensor0[w, w:], + ) + + return tensor1[tgt] From 1061ebfafe972684c1f89b5944a64bb713785df6 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Thu, 31 Mar 2022 23:18:39 +0900 Subject: [PATCH 007/102] Feat: Add DistributionABC --- torchlatent/abc.py | 39 ++++++++++++++++++++++++++++ torchlatent/cky.py | 63 ++++++++++++++++++++++++++++++++++++++-------- 2 files changed, 92 insertions(+), 10 deletions(-) create mode 100644 torchlatent/abc.py diff --git a/torchlatent/abc.py b/torchlatent/abc.py new file mode 100644 index 0000000..cfa15f6 --- /dev/null +++ b/torchlatent/abc.py @@ -0,0 +1,39 @@ +from abc import ABCMeta + +from torch import Tensor +from torch.distributions import Distribution +from torch.distributions.utils import lazy_property + +from torchlatent.types import Sequence + +__all__ = [ + 'DistributionABC', +] + + +class DistributionABC(Distribution, metaclass=ABCMeta): + def log_scores(self, targets: Sequence) -> Tensor: + raise NotImplementedError + + @lazy_property + def log_partitions(self) -> Tensor: + raise NotImplementedError + + def log_prob(self, targets: Sequence) -> Tensor: + return self.log_scores(targets=targets) - self.log_partitions + + @lazy_property + def max(self) -> Tensor: + raise NotImplementedError + + @lazy_property + def argmax(self) -> Tensor: + raise NotImplementedError + + @lazy_property + def log_marginals(self) -> Tensor: + raise NotImplementedError + + @lazy_property + def entropy(self) -> Tensor: + raise NotImplementedError diff --git a/torchlatent/cky.py b/torchlatent/cky.py index dc83aea..2509aeb 100644 --- a/torchlatent/cky.py +++ b/torchlatent/cky.py @@ -1,15 +1,19 @@ -from typing import NamedTuple, Tuple, Type +from typing import Tuple, NamedTuple +from typing import Type import torch from torch import Tensor +from torch.distributions.utils import lazy_property from torch.types import Device from torchrua import major_sizes_to_ptr, accumulate_sizes -from torchlatent.semiring import Semiring +from torchlatent.abc import DistributionABC +from torchlatent.semiring import Semiring, Log, Max +from torchlatent.types import Sequence class CkyIndices(NamedTuple): - width_size: int + token_size: int cache_size: int src: Tuple[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor, Tensor]] @@ -29,30 +33,69 @@ def cky_indices(token_sizes: Tensor, device: Device = None): batch_ptr = batch_ptr[z_ptr] y_ptr = z_ptr - acc_token_sizes[batch_ptr] - width_size = token_sizes.max().item() + token_size = token_sizes.max().item() cache_size, = token_ptr.size() return CkyIndices( - width_size=width_size, cache_size=cache_size, + token_size=token_size, cache_size=cache_size, src=((y_ptr - x_ptr, z_ptr), 
(batch_ptr, x_ptr, y_ptr)), tgt=(token_sizes - 1, acc_token_sizes), ) def cky_partition(data: Tensor, indices: CkyIndices, semiring: Type[Semiring]) -> Tensor: - width_size, cache_size, (src1, src2), tgt = indices + token_size, cache_size, (src1, src2), tgt = indices - tensor0 = torch.full((width_size, cache_size, *data.size()[3:]), fill_value=semiring.zero, requires_grad=False) - tensor1 = torch.full((width_size, cache_size, *data.size()[3:]), fill_value=semiring.zero, requires_grad=False) - tensor2 = torch.full((width_size, cache_size, *data.size()[3:]), fill_value=semiring.zero, requires_grad=False) + tensor0 = torch.full((token_size, cache_size, *data.size()[3:]), fill_value=semiring.zero, requires_grad=False) + tensor1 = torch.full((token_size, cache_size, *data.size()[3:]), fill_value=semiring.zero, requires_grad=False) + tensor2 = torch.full((token_size, cache_size, *data.size()[3:]), fill_value=semiring.zero, requires_grad=False) tensor0[src1] = data[src2] tensor1[0, :] = tensor2[-1, :] = tensor0[0, :] - for w in range(1, width_size): + for w in range(1, token_size): tensor1[w, :-w] = tensor2[-w - 1, w:] = semiring.mul( semiring.sum(semiring.mul(tensor1[:w, :-w], tensor2[-w:, w:]), dim=0), tensor0[w, w:], ) return tensor1[tgt] + + +class CkyDistribution(DistributionABC): + def __init__(self, scores: Tensor, indices: CkyIndices) -> None: + super(CkyDistribution, self).__init__(validate_args=False) + + self.scores = scores + self.indices = indices + + def log_scores(self, targets: Sequence) -> Tensor: + raise NotImplementedError + + @lazy_property + def log_partitions(self) -> Tensor: + return cky_partition(data=self.scores, indices=self.indices, semiring=Log) + + @lazy_property + def max(self) -> Tensor: + return cky_partition(data=self.scores, indices=self.indices, semiring=Max) + + @lazy_property + def argmax(self) -> Tensor: + pass + + @lazy_property + def log_marginals(self) -> Tensor: + pass + + @lazy_property + def entropy(self) -> Tensor: + pass + + +if __name__ == '__main__': + ans = CkyDistribution( + torch.randn((3, 5, 5), requires_grad=True), + cky_indices(torch.tensor([5, 2, 3])), + ).max + print(ans) From 059b44c5c64eec02af3ae574d27e464706c2bda9 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Thu, 31 Mar 2022 23:42:58 +0900 Subject: [PATCH 008/102] Feat: Add CkyDecoder --- torchlatent/cky.py | 61 ++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 56 insertions(+), 5 deletions(-) diff --git a/torchlatent/cky.py b/torchlatent/cky.py index 2509aeb..beb0e33 100644 --- a/torchlatent/cky.py +++ b/torchlatent/cky.py @@ -1,11 +1,16 @@ +from abc import ABCMeta from typing import Tuple, NamedTuple from typing import Type import torch from torch import Tensor +from torch import nn from torch.distributions.utils import lazy_property +from torch.nn.utils.rnn import PackedSequence from torch.types import Device +from torchrua import CattedSequence, pack_sequence from torchrua import major_sizes_to_ptr, accumulate_sizes +from torchrua import pad_packed_sequence, pad_catted_sequence from torchlatent.abc import DistributionABC from torchlatent.semiring import Semiring, Log, Max @@ -93,9 +98,55 @@ def entropy(self) -> Tensor: pass +class CkyDecoderABC(nn.Module, metaclass=ABCMeta): + def obtain_scores(self, *args, **kwargs) -> Tensor: + raise NotImplementedError + + def forward(self, sequence: Sequence, indices: CkyIndices = None) -> CkyDistribution: + if isinstance(sequence, CattedSequence): + features, token_sizes = pad_catted_sequence(sequence, batch_first=True) + elif 
isinstance(sequence, PackedSequence):
            features, token_sizes = pad_packed_sequence(sequence, batch_first=True)
        elif isinstance(sequence, tuple) and torch.is_tensor(sequence[0]) and torch.is_tensor(sequence[1]):
            features, token_sizes = sequence
        else:
            raise KeyError(f'type {type(sequence)} is not supported')

        if indices is None:
            indices = cky_indices(token_sizes=token_sizes, device=features.device)

        return CkyDistribution(
            scores=self.obtain_scores(features=features),
            indices=indices,
        )


class CkyDecoder(CkyDecoderABC):
    def __init__(self, in_features: int, bias: bool = True) -> None:
        super(CkyDecoder, self).__init__()

        self.score = nn.Bilinear(
            in1_features=in_features,
            in2_features=in_features,
            out_features=1,
            bias=bias,
        )

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}({self.score.extra_repr()})'

    def obtain_scores(self, features: Tensor, *args, **kwargs) -> Tensor:
        x, y = torch.broadcast_tensors(features[:, :, None], features[:, None, :])
        return self.score(x, y)[..., 0]


if __name__ == '__main__':
    e = pack_sequence([
        torch.randn((5, 3), requires_grad=True),
        torch.randn((2, 3), requires_grad=True),
        torch.randn((3, 3), requires_grad=True),
    ])
    layer = CkyDecoder(in_features=3, bias=True)
    dist = layer(e)
    print(f'layer => {layer}')
    print(dist.log_partitions)

From f679b6ff718fd60d32a1f0ccb90dc23c0904a420 Mon Sep 17 00:00:00 2001
From: speedcell4
Date: Thu, 31 Mar 2022 23:54:48 +0900
Subject: [PATCH 009/102] Feat: Update CkyDecoder

---
 torchlatent/cky.py | 35 ++++++++++++++++++++++------------
 1 file changed, 22 insertions(+), 13 deletions(-)

diff --git a/torchlatent/cky.py b/torchlatent/cky.py
index beb0e33..8d6a782 100644
--- a/torchlatent/cky.py
+++ b/torchlatent/cky.py
@@ -99,7 +99,13 @@ def entropy(self) -> Tensor:


 class CkyDecoderABC(nn.Module, metaclass=ABCMeta):
-    def obtain_scores(self, *args, **kwargs) -> Tensor:
+    def reset_parameters(self) -> None:
+        raise NotImplementedError
+
+    def extra_repr(self) -> str:
+        raise NotImplementedError
+
+    def forward_scores(self, *args, **kwargs) -> Tensor:
         raise NotImplementedError

     def forward(self, sequence: Sequence, indices: CkyIndices = None) -> CkyDistribution:
@@ -116,7 +122,7 @@ def forward(self, sequence: Sequence, indices: CkyIndices = None) -> CkyDistribu
         indices = cky_indices(token_sizes=token_sizes, device=features.device)

     return CkyDistribution(
-        scores=self.obtain_scores(features=features),
+        scores=self.forward_scores(features=features),
         indices=indices,
     )

 class CkyDecoder(CkyDecoderABC):
     def __init__(self, in_features: int, bias: bool = True) -> None:
         super(CkyDecoder, self).__init__()

-        self.score = nn.Bilinear(
-            in1_features=in_features,
-            in2_features=in_features,
-            out_features=1,
-            bias=bias,
-        )
+        self.fc1 = nn.Linear(in_features, in_features, bias=bias)
+        self.fc2 = nn.Linear(in_features, in_features, bias=bias)

     def __repr__(self) -> str:
-        return f'{self.__class__.__name__}({self.score.extra_repr()})'
+        return f'{self.__class__.__name__}({self.extra_repr()})'
+
+    def extra_repr(self) -> str:
+        return ', '.join([
+            f'in_features={self.fc1.in_features}',
+            f'bias={self.fc1.bias is not None}',
+        ])
+
+    
def forward_scores(self, features: Tensor, *args, **kwargs) -> Tensor: + x = self.fc1(features[..., :, None, :]) + y = self.fc2(features[..., None, :, :]) + return (x[..., None, :] @ y[..., :, None])[..., 0, 0] if __name__ == '__main__': From dec0be4d500fb502e89ed89250922c5b492e496d Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Fri, 1 Apr 2022 00:18:01 +0900 Subject: [PATCH 010/102] Feat: Add .argmax --- torchlatent/abc.py | 17 ++++++++++++----- torchlatent/cky.py | 41 +++++++++++++++++++++-------------------- 2 files changed, 33 insertions(+), 25 deletions(-) diff --git a/torchlatent/abc.py b/torchlatent/abc.py index cfa15f6..0346302 100644 --- a/torchlatent/abc.py +++ b/torchlatent/abc.py @@ -1,5 +1,6 @@ from abc import ABCMeta +import torch.autograd from torch import Tensor from torch.distributions import Distribution from torch.distributions.utils import lazy_property @@ -12,15 +13,17 @@ class DistributionABC(Distribution, metaclass=ABCMeta): - def log_scores(self, targets: Sequence) -> Tensor: + scores: Tensor + + def log_scores(self, value: Sequence) -> Tensor: raise NotImplementedError @lazy_property def log_partitions(self) -> Tensor: raise NotImplementedError - def log_prob(self, targets: Sequence) -> Tensor: - return self.log_scores(targets=targets) - self.log_partitions + def log_prob(self, value: Sequence) -> Tensor: + return self.log_scores(value=value) - self.log_partitions @lazy_property def max(self) -> Tensor: @@ -28,10 +31,14 @@ def max(self) -> Tensor: @lazy_property def argmax(self) -> Tensor: - raise NotImplementedError + grad, = torch.autograd.grad( + self.max, self.scores, torch.ones_like(self.max), + create_graph=False, only_inputs=True, allow_unused=False, + ) + return grad @lazy_property - def log_marginals(self) -> Tensor: + def marginals(self) -> Tensor: raise NotImplementedError @lazy_property diff --git a/torchlatent/cky.py b/torchlatent/cky.py index 8d6a782..a14a1fc 100644 --- a/torchlatent/cky.py +++ b/torchlatent/cky.py @@ -8,7 +8,7 @@ from torch.distributions.utils import lazy_property from torch.nn.utils.rnn import PackedSequence from torch.types import Device -from torchrua import CattedSequence, pack_sequence +from torchrua import CattedSequence from torchrua import major_sizes_to_ptr, accumulate_sizes from torchrua import pad_packed_sequence, pad_catted_sequence @@ -74,28 +74,41 @@ def __init__(self, scores: Tensor, indices: CkyIndices) -> None: self.scores = scores self.indices = indices - def log_scores(self, targets: Sequence) -> Tensor: + def log_scores(self, value: Sequence) -> Tensor: raise NotImplementedError @lazy_property def log_partitions(self) -> Tensor: - return cky_partition(data=self.scores, indices=self.indices, semiring=Log) + return cky_partition(data=Log.sum(self.scores, dim=-1), indices=self.indices, semiring=Log) @lazy_property def max(self) -> Tensor: - return cky_partition(data=self.scores, indices=self.indices, semiring=Max) + return cky_partition(data=Max.sum(self.scores, dim=-1), indices=self.indices, semiring=Max) @lazy_property def argmax(self) -> Tensor: - pass + mask = super(CkyDistribution, self).argmax > 0 + b, n, _, m = self.scores + + index = torch.arange(n, device=mask.device) + x = torch.masked_select(index[None, :, None, None], mask=mask) + y = torch.masked_select(index[None, None, :, None], mask=mask) + + index = torch.arange(m, device=mask.device) + z = torch.masked_select(index[None, None, None, :], mask=mask) + return torch.stack([x, y, z], dim=0) @lazy_property - def log_marginals(self) -> Tensor: - pass + 
def marginals(self) -> Tensor: + grad, = torch.autograd.grad( + self.log_partitions, self.scores, torch.ones_like(self.log_partitions), + create_graph=True, only_inputs=True, allow_unused=False, + ) + return grad @lazy_property def entropy(self) -> Tensor: - pass + raise NotImplementedError class CkyDecoderABC(nn.Module, metaclass=ABCMeta): @@ -147,15 +160,3 @@ def forward_scores(self, features: Tensor, *args, **kwargs) -> Tensor: x = self.fc1(features[..., :, None, :]) y = self.fc2(features[..., None, :, :]) return (x[..., None, :] @ y[..., :, None])[..., 0, 0] - - -if __name__ == '__main__': - e = pack_sequence([ - torch.randn((5, 3), requires_grad=True), - torch.randn((2, 3), requires_grad=True), - torch.randn((3, 3), requires_grad=True), - ]) - layer = CkyDecoder(in_features=3, bias=True) - dist = layer(e) - print(f'layer => {layer}') - print(dist.log_partitions) From 5544683d7191bb63379f628dc4665d70749e6c86 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sat, 2 Apr 2022 20:47:44 +0900 Subject: [PATCH 011/102] Test: Add unit test for cky_log_partitions --- tests/test_cky.py | 27 +++++++++++++++++++++++++++ tests/utils.py | 4 ++-- torchlatent/cky.py | 1 + 3 files changed, 30 insertions(+), 2 deletions(-) create mode 100644 tests/test_cky.py diff --git a/tests/test_cky.py b/tests/test_cky.py new file mode 100644 index 0000000..0bcf039 --- /dev/null +++ b/tests/test_cky.py @@ -0,0 +1,27 @@ +import torch +from hypothesis import given +from torch.testing import assert_close +from torch_struct import TreeCRF + +from tests.strategies import sizes, BATCH_SIZE, TOKEN_SIZE, devices, TINY_BATCH_SIZE +from tests.utils import assert_grad_close +from torchlatent.cky import CkyDistribution, cky_indices + + +@given( + device=devices(), + token_sizes=sizes(BATCH_SIZE, TOKEN_SIZE), + num_tags=sizes(TOKEN_SIZE), +) +def test_cky_log_partitions(device, token_sizes, num_tags): + scores = torch.randn( + (len(token_sizes), max(token_sizes), max(token_sizes), num_tags), + requires_grad=True, device=device, + ) + token_sizes = torch.tensor(token_sizes, device=device) + + excepted = TreeCRF(log_potentials=scores, lengths=token_sizes) + actual = CkyDistribution(scores=scores, indices=cky_indices(token_sizes=token_sizes, device=device)) + + assert_close(actual=actual.log_partitions, expected=excepted.partition) + assert_grad_close(actual=actual.log_partitions, expected=excepted.partition, inputs=scores, rtol=1e-5, atol=1e-5) diff --git a/tests/utils.py b/tests/utils.py index 3040db4..5981653 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -21,8 +21,8 @@ def assert_grad_close( actual: Tensor, expected: Tensor, inputs: Union[Tensor, List[Tensor], Tuple[Tensor, ...]], allow_unused: bool = False, - check_device: bool = True, check_dtype: bool = True, check_stride: bool = True) -> None: - kwargs = dict(check_device=check_device, check_dtype=check_dtype, check_stride=check_stride) + check_device: bool = True, check_dtype: bool = True, check_stride: bool = True, **kwargs) -> None: + kwargs = dict(check_device=check_device, check_dtype=check_dtype, check_stride=check_stride, **kwargs) grad = torch.rand_like(actual) diff --git a/torchlatent/cky.py b/torchlatent/cky.py index a14a1fc..09edafa 100644 --- a/torchlatent/cky.py +++ b/torchlatent/cky.py @@ -160,3 +160,4 @@ def forward_scores(self, features: Tensor, *args, **kwargs) -> Tensor: x = self.fc1(features[..., :, None, :]) y = self.fc2(features[..., None, :, :]) return (x[..., None, :] @ y[..., :, None])[..., 0, 0] + From 0aa03cc40acf8f95380b3113930e48e64c98c738 Mon 
Sep 17 00:00:00 2001 From: speedcell4 Date: Sat, 2 Apr 2022 21:53:36 +0900 Subject: [PATCH 012/102] Test: Add unit test for test_cky_log_scores --- tests/test_cky.py | 21 ++++++++++++++++++++- torchlatent/cky.py | 24 +++++++++++++++++++----- 2 files changed, 39 insertions(+), 6 deletions(-) diff --git a/tests/test_cky.py b/tests/test_cky.py index 0bcf039..d2ab038 100644 --- a/tests/test_cky.py +++ b/tests/test_cky.py @@ -2,8 +2,9 @@ from hypothesis import given from torch.testing import assert_close from torch_struct import TreeCRF +from torchrua import CattedSequence -from tests.strategies import sizes, BATCH_SIZE, TOKEN_SIZE, devices, TINY_BATCH_SIZE +from tests.strategies import sizes, BATCH_SIZE, TOKEN_SIZE, devices from tests.utils import assert_grad_close from torchlatent.cky import CkyDistribution, cky_indices @@ -25,3 +26,21 @@ def test_cky_log_partitions(device, token_sizes, num_tags): assert_close(actual=actual.log_partitions, expected=excepted.partition) assert_grad_close(actual=actual.log_partitions, expected=excepted.partition, inputs=scores, rtol=1e-5, atol=1e-5) + + +@given( + device=devices(), + token_sizes=sizes(BATCH_SIZE, TOKEN_SIZE), + num_tags=sizes(TOKEN_SIZE), +) +def test_cky_log_scores(device, token_sizes, num_tags): + scores = torch.randn( + (len(token_sizes), max(token_sizes), max(token_sizes), num_tags), + requires_grad=True, device=device, + ) + token_sizes = torch.tensor(token_sizes, device=device) + + cky = CkyDistribution(scores=scores, indices=cky_indices(token_sizes=token_sizes, device=device)) + argmax = CattedSequence(data=cky.argmax, token_sizes=token_sizes * 2 - 1) + + assert_close(actual=cky.max, expected=cky.log_scores(argmax)) diff --git a/torchlatent/cky.py b/torchlatent/cky.py index 09edafa..684ed8c 100644 --- a/torchlatent/cky.py +++ b/torchlatent/cky.py @@ -75,6 +75,13 @@ def __init__(self, scores: Tensor, indices: CkyIndices) -> None: self.indices = indices def log_scores(self, value: Sequence) -> Tensor: + if isinstance(value, CattedSequence): + (x_ptr, y_ptr, z_ptr), token_sizes = value + _, batch_ptr = major_sizes_to_ptr(sizes=token_sizes) + return torch.segment_reduce( + self.scores[batch_ptr, x_ptr, y_ptr, z_ptr], + reduce='sum', lengths=token_sizes, + ) raise NotImplementedError @lazy_property @@ -88,7 +95,7 @@ def max(self) -> Tensor: @lazy_property def argmax(self) -> Tensor: mask = super(CkyDistribution, self).argmax > 0 - b, n, _, m = self.scores + b, n, _, m = mask.size() index = torch.arange(n, device=mask.device) x = torch.masked_select(index[None, :, None, None], mask=mask) @@ -118,9 +125,6 @@ def reset_parameters(self) -> None: def extra_repr(self) -> str: raise NotImplementedError - def forward_scores(self, *args, **kwargs) -> Tensor: - raise NotImplementedError - def forward(self, sequence: Sequence, indices: CkyIndices = None) -> CkyDistribution: if isinstance(sequence, CattedSequence): features, token_sizes = pad_catted_sequence(sequence, batch_first=True) @@ -139,6 +143,17 @@ def forward(self, sequence: Sequence, indices: CkyIndices = None) -> CkyDistribu indices=indices, ) + def fit(self, sequence: Sequence, value: Sequence, indices: CkyIndices = None) -> Tensor: + dist = self.forward(sequence=sequence, indices=indices) + return dist.log_partitions - dist.log_scores(value=value) + + def decode(self, sequence: Sequence, indices: CkyIndices) -> Sequence: + dist = self.forward(sequence=sequence, indices=indices) + if isinstance(sequence, CattedSequence): + return CattedSequence(data=dist.argmax, 
token_sizes=sequence.token_sizes * 2 - 1) + else: + raise NotImplementedError + class CkyDecoder(CkyDecoderABC): def __init__(self, in_features: int, bias: bool = True) -> None: @@ -160,4 +175,3 @@ def forward_scores(self, features: Tensor, *args, **kwargs) -> Tensor: x = self.fc1(features[..., :, None, :]) y = self.fc2(features[..., None, :, :]) return (x[..., None, :] @ y[..., :, None])[..., 0, 0] - From a40bb89d97d7a52b07338f89cdf3466645133a35 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sat, 2 Apr 2022 22:22:46 +0900 Subject: [PATCH 013/102] Feat: Support PackedSequence --- torchlatent/cky.py | 82 +++++++++++++++++++++++++++++++++------------- 1 file changed, 60 insertions(+), 22 deletions(-) diff --git a/torchlatent/cky.py b/torchlatent/cky.py index 684ed8c..64cb614 100644 --- a/torchlatent/cky.py +++ b/torchlatent/cky.py @@ -8,7 +8,7 @@ from torch.distributions.utils import lazy_property from torch.nn.utils.rnn import PackedSequence from torch.types import Device -from torchrua import CattedSequence +from torchrua import CattedSequence, cat_packed_indices, transpose_sizes, pack_catted_sequence, pack_sequence from torchrua import major_sizes_to_ptr, accumulate_sizes from torchrua import pad_packed_sequence, pad_catted_sequence @@ -42,7 +42,8 @@ def cky_indices(token_sizes: Tensor, device: Device = None): cache_size, = token_ptr.size() return CkyIndices( - token_size=token_size, cache_size=cache_size, + token_size=token_size, + cache_size=cache_size, src=((y_ptr - x_ptr, z_ptr), (batch_ptr, x_ptr, y_ptr)), tgt=(token_sizes - 1, acc_token_sizes), ) @@ -51,9 +52,10 @@ def cky_indices(token_sizes: Tensor, device: Device = None): def cky_partition(data: Tensor, indices: CkyIndices, semiring: Type[Semiring]) -> Tensor: token_size, cache_size, (src1, src2), tgt = indices - tensor0 = torch.full((token_size, cache_size, *data.size()[3:]), fill_value=semiring.zero, requires_grad=False) - tensor1 = torch.full((token_size, cache_size, *data.size()[3:]), fill_value=semiring.zero, requires_grad=False) - tensor2 = torch.full((token_size, cache_size, *data.size()[3:]), fill_value=semiring.zero, requires_grad=False) + size = (token_size, cache_size, *data.size()[3:]) + tensor0 = torch.full(size, fill_value=semiring.zero, requires_grad=False) + tensor1 = torch.full(size, fill_value=semiring.zero, requires_grad=False) + tensor2 = torch.full(size, fill_value=semiring.zero, requires_grad=False) tensor0[src1] = data[src2] tensor1[0, :] = tensor2[-1, :] = tensor0[0, :] @@ -76,13 +78,25 @@ def __init__(self, scores: Tensor, indices: CkyIndices) -> None: def log_scores(self, value: Sequence) -> Tensor: if isinstance(value, CattedSequence): - (x_ptr, y_ptr, z_ptr), token_sizes = value - _, batch_ptr = major_sizes_to_ptr(sizes=token_sizes) - return torch.segment_reduce( - self.scores[batch_ptr, x_ptr, y_ptr, z_ptr], - reduce='sum', lengths=token_sizes, + ptr, token_sizes = value + batch_ptr = torch.repeat_interleave(repeats=token_sizes) + return Log.segment_mul( + self.scores[batch_ptr, ptr[..., 0], ptr[..., 1], ptr[..., 2]], + sizes=token_sizes, + ) + + if isinstance(value, PackedSequence): + ptr, batch_sizes, _, unsorted_indices = value + indices, token_sizes = cat_packed_indices(batch_sizes=batch_sizes, unsorted_indices=unsorted_indices) + batch_ptr = torch.repeat_interleave(repeats=token_sizes) + ptr = ptr[indices] + + return Log.segment_mul( + self.scores[batch_ptr, ptr[..., 0], ptr[..., 1], ptr[..., 2]], + sizes=token_sizes, ) - raise NotImplementedError + + raise KeyError(f'type {type(value)} is not 
supported')

    @lazy_property
    def log_partitions(self) -> Tensor:
        return cky_partition(data=Log.sum(self.scores, dim=-1), indices=self.indices, semiring=Log)

    @lazy_property
    def max(self) -> Tensor:
        return cky_partition(data=Max.sum(self.scores, dim=-1), indices=self.indices, semiring=Max)

    @lazy_property
    def argmax(self) -> Tensor:
        mask = super(CkyDistribution, self).argmax > 0
        b, n, _, m = mask.size()

        index = torch.arange(n, device=mask.device)
        x = torch.masked_select(index[None, :, None, None], mask=mask)
        y = torch.masked_select(index[None, None, :, None], mask=mask)

        index = torch.arange(m, device=mask.device)
        z = torch.masked_select(index[None, None, None, :], mask=mask)
-        return torch.stack([x, y, z], dim=0)
+        return torch.stack([x, y, z], dim=-1)

    @lazy_property
    def marginals(self) -> Tensor:

     def extra_repr(self) -> str:
         raise NotImplementedError

+    def forward_scores(self, features: Tensor, *args, **kwargs) -> Tensor:
+        raise NotImplementedError
+
     def forward(self, sequence: Sequence, indices: CkyIndices = None) -> CkyDistribution:
         if isinstance(sequence, CattedSequence):
             features, token_sizes = pad_catted_sequence(sequence, batch_first=True)

     def fit(self, sequence: Sequence, value: Sequence, indices: CkyIndices = None) -> Tensor:
         dist = self.forward(sequence=sequence, indices=indices)
         return dist.log_partitions - dist.log_scores(value=value)

-    def decode(self, sequence: Sequence, indices: CkyIndices) -> Sequence:
+    def decode(self, sequence: Sequence, indices: CkyIndices = None) -> Sequence:
         dist = self.forward(sequence=sequence, indices=indices)
+
         if isinstance(sequence, CattedSequence):
             return CattedSequence(data=dist.argmax, token_sizes=sequence.token_sizes * 2 - 1)
-        else:
-            raise NotImplementedError
+
+        if isinstance(sequence, PackedSequence):
+            token_sizes = transpose_sizes(sizes=sequence.batch_sizes)[sequence.unsorted_indices] * 2 - 1
+            return pack_catted_sequence(sequence=dist.argmax, token_sizes=token_sizes)
+
+        raise KeyError(f'type {type(sequence)} is not supported')


 class CkyDecoder(CkyDecoderABC):
-    def __init__(self, in_features: int, bias: bool = True) -> None:
+    def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
         super(CkyDecoder, self).__init__()

-        self.fc1 = nn.Linear(in_features, in_features, bias=bias)
-        self.fc2 = nn.Linear(in_features, in_features, bias=bias)
+        self.fc = nn.Bilinear(
+            in1_features=in_features,
+            in2_features=in_features,
+            out_features=out_features,
+            bias=bias,
+        )

     def __repr__(self) -> str:
         return f'{self.__class__.__name__}({self.extra_repr()})'

     def extra_repr(self) -> str:
         return ', '.join([
-            f'in_features={self.fc1.in_features}',
-            f'bias={self.fc1.bias is not None}',
+            f'in_features={self.fc.in1_features}',
+            f'out_features={self.fc.out_features}',
+            f'bias={self.fc.bias is not None}',
         ])

     def forward_scores(self, features: Tensor, *args, **kwargs) -> Tensor:
-        x = self.fc1(features[..., :, None, :])
-        y = self.fc2(features[..., None, :, :])
-        return (x[..., None, :] @ y[..., :, None])[..., 0, 0]
+        x, y = torch.broadcast_tensors(features[..., :, None, :], features[..., None, :, :])
+        return self.fc(x, y)
+
+
+if __name__ == '__main__':
+    decoder = CkyDecoder(2, 3)
+    e = pack_sequence([
+        torch.randn((5, 2), requires_grad=True),
+        torch.randn((2, 2), requires_grad=True),
+        torch.randn((3, 2), requires_grad=True),
+    ])
+    dist = decoder.forward(e)
+    print(dist.max)
+    print(dist.log_scores(decoder.decode(e)))

From 6a6da12f15229d8630d3ca4bdddf5643ff661616 Mon Sep 17 00:00:00 2001
From: speedcell4
Date: Sun, 3 Apr 2022 16:09:51 +0900
Subject: [PATCH 014/102] Test: Add unit test for test_cky_log_scores

---
 tests/strategies.py |  2 ++
 tests/test_cky.py   | 39 ++++++++++++++++++++-------------------
 2 files changed, 22 insertions(+), 19 deletions(-)

diff --git a/tests/strategies.py b/tests/strategies.py
index 785f929..980e6da 100644
--- a/tests/strategies.py
+++ b/tests/strategies.py
@@ -10,6 +10,8 @@
 NUM_TAGS = 8
 NUM_CONJUGATES = 5

+EMBEDDING_DIM = 
25 + @st.composite def devices(draw): diff --git a/tests/test_cky.py b/tests/test_cky.py index d2ab038..f64c578 100644 --- a/tests/test_cky.py +++ b/tests/test_cky.py @@ -1,31 +1,31 @@ import torch -from hypothesis import given +from hypothesis import given, strategies as st from torch.testing import assert_close from torch_struct import TreeCRF -from torchrua import CattedSequence +from torchrua import pack_sequence -from tests.strategies import sizes, BATCH_SIZE, TOKEN_SIZE, devices +from tests.strategies import sizes, BATCH_SIZE, TOKEN_SIZE, EMBEDDING_DIM, devices, TINY_BATCH_SIZE from tests.utils import assert_grad_close -from torchlatent.cky import CkyDistribution, cky_indices +from torchlatent.cky import CkyDistribution, cky_indices, CkyDecoder @given( device=devices(), - token_sizes=sizes(BATCH_SIZE, TOKEN_SIZE), + token_sizes=sizes(TINY_BATCH_SIZE, TOKEN_SIZE), + embedding_dim=sizes(EMBEDDING_DIM), num_tags=sizes(TOKEN_SIZE), + bias=st.booleans(), ) -def test_cky_log_partitions(device, token_sizes, num_tags): - scores = torch.randn( - (len(token_sizes), max(token_sizes), max(token_sizes), num_tags), - requires_grad=True, device=device, - ) - token_sizes = torch.tensor(token_sizes, device=device) +def test_cky_log_scores(device, token_sizes, embedding_dim, num_tags, bias): + sequence = pack_sequence([ + torch.randn((token_size, embedding_dim), requires_grad=True, device=device) + for token_size in token_sizes + ]) - excepted = TreeCRF(log_potentials=scores, lengths=token_sizes) - actual = CkyDistribution(scores=scores, indices=cky_indices(token_sizes=token_sizes, device=device)) + decoder = CkyDecoder(in_features=embedding_dim, out_features=num_tags, bias=bias) + cky = decoder.forward(sequence=sequence) - assert_close(actual=actual.log_partitions, expected=excepted.partition) - assert_grad_close(actual=actual.log_partitions, expected=excepted.partition, inputs=scores, rtol=1e-5, atol=1e-5) + assert_close(actual=cky.max, expected=cky.log_scores(decoder.decode(sequence=sequence))) @given( @@ -33,14 +33,15 @@ def test_cky_log_partitions(device, token_sizes, num_tags): token_sizes=sizes(BATCH_SIZE, TOKEN_SIZE), num_tags=sizes(TOKEN_SIZE), ) -def test_cky_log_scores(device, token_sizes, num_tags): +def test_cky_log_partitions(device, token_sizes, num_tags): scores = torch.randn( (len(token_sizes), max(token_sizes), max(token_sizes), num_tags), requires_grad=True, device=device, ) token_sizes = torch.tensor(token_sizes, device=device) - cky = CkyDistribution(scores=scores, indices=cky_indices(token_sizes=token_sizes, device=device)) - argmax = CattedSequence(data=cky.argmax, token_sizes=token_sizes * 2 - 1) + excepted = TreeCRF(log_potentials=scores, lengths=token_sizes) + actual = CkyDistribution(scores=scores, indices=cky_indices(token_sizes=token_sizes, device=device)) - assert_close(actual=cky.max, expected=cky.log_scores(argmax)) + assert_close(actual=actual.log_partitions, expected=excepted.partition) + assert_grad_close(actual=actual.log_partitions, expected=excepted.partition, inputs=scores) From d43447115d63397fc28c66865351f5aecb7efc01 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sun, 3 Apr 2022 16:10:32 +0900 Subject: [PATCH 015/102] Test: Add unit test, test_cky_packed_max --- tests/test_cky.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_cky.py b/tests/test_cky.py index f64c578..1cc9b97 100644 --- a/tests/test_cky.py +++ b/tests/test_cky.py @@ -16,7 +16,7 @@ num_tags=sizes(TOKEN_SIZE), bias=st.booleans(), ) -def test_cky_log_scores(device, 
token_sizes, embedding_dim, num_tags, bias): +def test_cky_packed_max(device, token_sizes, embedding_dim, num_tags, bias): sequence = pack_sequence([ torch.randn((token_size, embedding_dim), requires_grad=True, device=device) for token_size in token_sizes From 576bcec45d18359de2d69ef92eeeb72c73dad918 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sun, 3 Apr 2022 16:12:54 +0900 Subject: [PATCH 016/102] Test: Add unit test, test_cky_catted_max --- tests/test_cky.py | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/tests/test_cky.py b/tests/test_cky.py index 1cc9b97..145b7c7 100644 --- a/tests/test_cky.py +++ b/tests/test_cky.py @@ -2,13 +2,32 @@ from hypothesis import given, strategies as st from torch.testing import assert_close from torch_struct import TreeCRF -from torchrua import pack_sequence +from torchrua import pack_sequence, cat_sequence from tests.strategies import sizes, BATCH_SIZE, TOKEN_SIZE, EMBEDDING_DIM, devices, TINY_BATCH_SIZE from tests.utils import assert_grad_close from torchlatent.cky import CkyDistribution, cky_indices, CkyDecoder +@given( + device=devices(), + token_sizes=sizes(TINY_BATCH_SIZE, TOKEN_SIZE), + embedding_dim=sizes(EMBEDDING_DIM), + num_tags=sizes(TOKEN_SIZE), + bias=st.booleans(), +) +def test_cky_catted_max(device, token_sizes, embedding_dim, num_tags, bias): + sequence = cat_sequence([ + torch.randn((token_size, embedding_dim), requires_grad=True, device=device) + for token_size in token_sizes + ]) + + decoder = CkyDecoder(in_features=embedding_dim, out_features=num_tags, bias=bias) + cky = decoder.forward(sequence=sequence) + + assert_close(actual=cky.max, expected=cky.log_scores(decoder.decode(sequence=sequence))) + + @given( device=devices(), token_sizes=sizes(TINY_BATCH_SIZE, TOKEN_SIZE), @@ -44,4 +63,4 @@ def test_cky_log_partitions(device, token_sizes, num_tags): actual = CkyDistribution(scores=scores, indices=cky_indices(token_sizes=token_sizes, device=device)) assert_close(actual=actual.log_partitions, expected=excepted.partition) - assert_grad_close(actual=actual.log_partitions, expected=excepted.partition, inputs=scores) + assert_grad_close(actual=actual.log_partitions, expected=excepted.partition, inputs=scores, rtol=1e-5, atol=1e-5) From 0e52b78d3e03f9a77d91fa1594e3ad1b3fc2c727 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sun, 3 Apr 2022 16:56:20 +0900 Subject: [PATCH 017/102] Feat: Add segment_indices --- torchlatent/cky.py | 46 ++++----------- torchlatent/crf/catting.py | 2 +- torchlatent/semiring.py | 111 ++++++++++++++++++++++++++----------- 3 files changed, 89 insertions(+), 70 deletions(-) diff --git a/torchlatent/cky.py b/torchlatent/cky.py index 64cb614..b3f3b33 100644 --- a/torchlatent/cky.py +++ b/torchlatent/cky.py @@ -8,12 +8,12 @@ from torch.distributions.utils import lazy_property from torch.nn.utils.rnn import PackedSequence from torch.types import Device -from torchrua import CattedSequence, cat_packed_indices, transpose_sizes, pack_catted_sequence, pack_sequence +from torchrua import CattedSequence, transpose_sizes, pack_catted_sequence from torchrua import major_sizes_to_ptr, accumulate_sizes from torchrua import pad_packed_sequence, pad_catted_sequence from torchlatent.abc import DistributionABC -from torchlatent.semiring import Semiring, Log, Max +from torchlatent.semiring import Semiring, Log, Max, segment_indices from torchlatent.types import Sequence @@ -76,27 +76,13 @@ def __init__(self, scores: Tensor, indices: CkyIndices) -> None: self.scores = scores self.indices = 
indices - def log_scores(self, value: Sequence) -> Tensor: - if isinstance(value, CattedSequence): - ptr, token_sizes = value - batch_ptr = torch.repeat_interleave(repeats=token_sizes) - return Log.segment_mul( - self.scores[batch_ptr, ptr[..., 0], ptr[..., 1], ptr[..., 2]], - sizes=token_sizes, - ) - - if isinstance(value, PackedSequence): - ptr, batch_sizes, _, unsorted_indices = value - indices, token_sizes = cat_packed_indices(batch_sizes=batch_sizes, unsorted_indices=unsorted_indices) - batch_ptr = torch.repeat_interleave(repeats=token_sizes) - ptr = ptr[indices] - - return Log.segment_mul( - self.scores[batch_ptr, ptr[..., 0], ptr[..., 1], ptr[..., 2]], - sizes=token_sizes, - ) - - raise KeyError(f'type {type(value)} is not supported') + def log_scores(self, sequence: Sequence) -> Tensor: + indices, batch_ptr, sizes = segment_indices(sequence=sequence) + data = sequence.data[indices] + return Log.segment_prod( + tensor=self.scores[batch_ptr, data[..., 0], data[..., 1], data[..., 2]], + sizes=sizes, + ) @lazy_property def log_partitions(self) -> Tensor: @@ -162,7 +148,7 @@ def forward(self, sequence: Sequence, indices: CkyIndices = None) -> CkyDistribu def fit(self, sequence: Sequence, value: Sequence, indices: CkyIndices = None) -> Tensor: dist = self.forward(sequence=sequence, indices=indices) - return dist.log_partitions - dist.log_scores(value=value) + return dist.log_partitions - dist.log_scores(sequence=value) def decode(self, sequence: Sequence, indices: CkyIndices = None) -> Sequence: dist = self.forward(sequence=sequence, indices=indices) @@ -201,15 +187,3 @@ def extra_repr(self) -> str: def forward_scores(self, features: Tensor, *args, **kwargs) -> Tensor: x, y = torch.broadcast_tensors(features[..., :, None, :], features[..., None, :, :]) return self.fc(x, y) - - -if __name__ == '__main__': - decoder = CkyDecoder(2, 3) - e = pack_sequence([ - torch.randn((5, 2), requires_grad=True), - torch.randn((2, 2), requires_grad=True), - torch.randn((3, 2), requires_grad=True), - ]) - dist = decoder.forward(e) - print(dist.max) - print(dist.log_scores(decoder.decode(e))) diff --git a/torchlatent/crf/catting.py b/torchlatent/crf/catting.py index 7132951..f81fc62 100644 --- a/torchlatent/crf/catting.py +++ b/torchlatent/crf/catting.py @@ -40,7 +40,7 @@ def _compute_catted_sequence_scores( transition_scores[head_indices] = transition_head_scores # [h, c] scores = semiring.mul(emission_scores, transition_scores) - scores = semiring.segment_mul(scores, sizes=emissions.token_sizes) + scores = semiring.segment_prod(scores, sizes=emissions.token_sizes) scores = semiring.mul(scores, transition_last_scores) diff --git a/torchlatent/semiring.py b/torchlatent/semiring.py index 840a17b..4602d73 100644 --- a/torchlatent/semiring.py +++ b/torchlatent/semiring.py @@ -1,9 +1,13 @@ import torch from torch import Tensor +from torch.nn.utils.rnn import PackedSequence +from torch.types import Device +from torchrua import cat_packed_indices, cat_padded_indices, CattedSequence from torchrua.reduction import reduce_sequence, ReductionIndices -from torchrua.scatter import scatter_add, scatter_max, scatter_mul, scatter_logsumexp +from torchrua.scatter import scatter_add, scatter_logsumexp from torchlatent.functional import logsumexp, logaddexp +from torchlatent.types import Sequence __all__ = [ 'Semiring', @@ -11,6 +15,71 @@ ] +@torch.no_grad() +def segment_indices(sequence: Sequence, batch_first: bool = True, device: Device = None): + if isinstance(sequence, CattedSequence): + data, token_sizes = sequence + 
return segment_catted_indices(token_sizes=token_sizes, device=data.device) + + if isinstance(sequence, PackedSequence): + data, batch_sizes, _, unsorted_indices = sequence + return segment_packed_indices(batch_sizes=batch_sizes, unsorted_indices=unsorted_indices, device=data.device) + + if isinstance(sequence, tuple) and torch.is_tensor(sequence[0]) and torch.is_tensor(sequence[1]): + data, token_sizes = sequence + return segment_padded_indices(token_sizes=token_sizes, batch_first=batch_first, device=device) + + raise KeyError(f'type {type(sequence)} is not supported') + + +@torch.no_grad() +def segment_catted_indices(token_sizes: Tensor, device: Device = None): + if device is None: + device = token_sizes.device + + token_sizes = token_sizes.to(device=device) + + batch_ptr = torch.repeat_interleave(repeats=token_sizes) + return ..., batch_ptr, token_sizes + + +@torch.no_grad() +def segment_packed_indices(batch_sizes: Tensor, unsorted_indices: Tensor, device: Device = None): + if device is None: + if unsorted_indices is not None: + device = unsorted_indices.device + else: + device = batch_sizes.device + + batch_sizes = batch_sizes.to(device=device) + unsorted_indices = unsorted_indices.to(device=device) + + indices, token_sizes = cat_packed_indices( + batch_sizes=batch_sizes, + unsorted_indices=unsorted_indices, + device=device, + ) + batch_ptr = torch.repeat_interleave(repeats=token_sizes) + return indices, batch_ptr, token_sizes + + +@torch.no_grad() +def segment_padded_indices(token_sizes: Tensor, batch_first: bool, device: Device = None): + if device is None: + device = token_sizes.device + + token_sizes = token_sizes.to(device=device) + + if batch_first: + (batch_ptr, token_ptr), _ = cat_padded_indices( + token_sizes=token_sizes, batch_first=batch_first, device=device) + return (batch_ptr, token_ptr), batch_ptr, token_sizes + else: + (token_ptr, batch_ptr), _ = cat_padded_indices( + token_sizes=token_sizes, batch_first=batch_first, device=device) + return (token_ptr, batch_ptr), batch_ptr, token_sizes + + class Semiring(object): zero: float one: float @@ -41,19 +110,11 @@ def prod(cls, tensor: Tensor, dim: int, keepdim: bool = False) -> Tensor: raise NotImplementedError @classmethod - def scatter_add(cls, tensor: Tensor, index: Tensor) -> Tensor: + def segment_sum(cls, tensor: Tensor, sizes: Tensor) -> Tensor: raise NotImplementedError @classmethod - def scatter_mul(cls, tensor: Tensor, index: Tensor) -> Tensor: - raise NotImplementedError - - @classmethod - def segment_add(cls, tensor: Tensor, sizes: Tensor) -> Tensor: - raise NotImplementedError - - @classmethod - def segment_mul(cls, tensor: Tensor, sizes: Tensor) -> Tensor: + def segment_prod(cls, tensor: Tensor, sizes: Tensor) -> Tensor: raise NotImplementedError @classmethod @@ -86,19 +147,11 @@ def prod(cls, tensor: Tensor, dim: int, keepdim: bool = False) -> Tensor: return torch.prod(tensor, dim=dim, keepdim=keepdim) @classmethod - def scatter_add(cls, tensor: Tensor, index: Tensor) -> Tensor: - return scatter_add(tensor=tensor, index=index) - - @classmethod - def scatter_mul(cls, tensor: Tensor, index: Tensor) -> Tensor: - return scatter_mul(tensor=tensor, index=index) - - @classmethod - def segment_add(cls, tensor: Tensor, sizes: Tensor) -> Tensor: + def segment_sum(cls, tensor: Tensor, sizes: Tensor) -> Tensor: return torch.segment_reduce(tensor, reduce='sum', lengths=sizes, unsafe=True) @classmethod - def segment_mul(cls, tensor: Tensor, sizes: Tensor) -> Tensor: + def segment_prod(cls, tensor: Tensor, sizes: Tensor) -> 
Tensor: raise NotImplementedError @@ -131,13 +184,13 @@ def scatter_mul(cls, tensor: Tensor, index: Tensor) -> Tensor: return scatter_add(tensor=tensor, index=index) @classmethod - def segment_add(cls, tensor: Tensor, sizes: Tensor) -> Tensor: + def segment_sum(cls, tensor: Tensor, sizes: Tensor) -> Tensor: m = torch.segment_reduce(tensor, reduce='max', lengths=sizes, unsafe=True).detach() z = (tensor - torch.repeat_interleave(m, repeats=sizes)).exp() return torch.segment_reduce(z, reduce='sum', lengths=sizes, unsafe=True).log() + m @classmethod - def segment_mul(cls, tensor: Tensor, sizes: Tensor) -> Tensor: + def segment_prod(cls, tensor: Tensor, sizes: Tensor) -> Tensor: return torch.segment_reduce(tensor, reduce='sum', lengths=sizes, unsafe=True) @@ -162,17 +215,9 @@ def prod(cls, tensor: Tensor, dim: int, keepdim: bool = False) -> Tensor: return torch.sum(tensor, dim=dim, keepdim=keepdim) @classmethod - def scatter_add(cls, tensor: Tensor, index: Tensor) -> Tensor: - return scatter_max(tensor=tensor, index=index) - - @classmethod - def scatter_mul(cls, tensor: Tensor, index: Tensor) -> Tensor: - return scatter_add(tensor=tensor, index=index) - - @classmethod - def segment_add(cls, tensor: Tensor, sizes: Tensor) -> Tensor: + def segment_sum(cls, tensor: Tensor, sizes: Tensor) -> Tensor: return torch.segment_reduce(tensor, reduce='max', lengths=sizes, unsafe=True) @classmethod - def segment_mul(cls, tensor: Tensor, sizes: Tensor) -> Tensor: + def segment_prod(cls, tensor: Tensor, sizes: Tensor) -> Tensor: return torch.segment_reduce(tensor, reduce='sum', lengths=sizes, unsafe=True) From afd4455783acd98472e2e47cf94971efb0817999 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sun, 3 Apr 2022 17:48:57 +0900 Subject: [PATCH 018/102] Feat: Add crf_segment_reduce --- torchlatent/crf2.py | 117 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 117 insertions(+) create mode 100644 torchlatent/crf2.py diff --git a/torchlatent/crf2.py b/torchlatent/crf2.py new file mode 100644 index 0000000..696f812 --- /dev/null +++ b/torchlatent/crf2.py @@ -0,0 +1,117 @@ +from typing import Tuple, Sequence, Type + +import torch +from torch import Tensor +from torch.nn.utils.rnn import PackedSequence +from torch.types import Device +from torchrua import roll_catted_indices, CattedSequence, head_catted_indices, last_catted_indices, head_packed_indices, \ + last_packed_indices, accumulate_sizes + +from torchlatent.semiring import segment_catted_indices, segment_packed_indices, Semiring + + +@torch.no_grad() +def broadcast_catted_shapes(sequence: CattedSequence, transitions: Tuple[Tensor, Tensor, Tensor]): + sequence, token_sizes = sequence + transitions, head_transitions, last_transitions = transitions + + t1, c1, *_ = sequence.size() + h1, = token_sizes.size() + + t2, c2, _, _ = transitions.size() + h3, c3, _ = head_transitions.size() + h4, c4, _ = last_transitions.size() + + return torch.broadcast_shapes((t1, c1, h1), (t2, c2, 1), (1, c3, h3), (1, c4, h4)) + + +@torch.no_grad() +def broadcast_packed_shapes(sequence: PackedSequence, transitions: Tuple[Tensor, Tensor, Tensor]): + sequence, batch_sizes, _, _ = sequence + transitions, head_transitions, last_transitions = transitions + + t1, c1, *_ = sequence.size() + h1 = batch_sizes[0].item() + + t2, c2, _, _ = transitions.size() + h3, c3, _ = head_transitions.size() + h4, c4, _ = last_transitions.size() + + return torch.broadcast_shapes((t1, c1, h1), (t2, c2, 1), (1, c3, h3), (1, c4, h4)) + + +@torch.no_grad() +def crf_segment_catted_indices(token_sizes: 
Tensor, device: Device = None): + if device is None: + device = token_sizes.device + + token_sizes = token_sizes.to(device=device) + + curr, _, token_sizes = segment_catted_indices(token_sizes=token_sizes, device=device) + prev = roll_catted_indices(token_sizes=token_sizes, shifts=1, device=device) + head = head_catted_indices(token_sizes=token_sizes, device=device) + last = last_catted_indices(token_sizes=token_sizes, device=device) + return prev, curr, torch.arange(token_sizes.size()[0], device=device), head, last, token_sizes + + +@torch.no_grad() +def crf_segment_packed_indices(batch_sizes: Tensor, unsorted_indices: Tensor, device: Device): + if device is None: + if unsorted_indices is not None: + device = unsorted_indices.device + else: + device = batch_sizes.device + + batch_sizes = batch_sizes.to(device=device) + unsorted_indices = unsorted_indices.to(device=device) + + curr, _, token_sizes = segment_packed_indices( + batch_sizes=batch_sizes, unsorted_indices=unsorted_indices, device=device, + ) + prev = roll_catted_indices(token_sizes=token_sizes, shifts=1, device=device) + head = head_packed_indices(token_sizes=token_sizes, unsorted_indices=unsorted_indices, device=device) + last = last_packed_indices(token_sizes=token_sizes, unsorted_indices=unsorted_indices, device=device) + return curr[prev], curr, unsorted_indices, head, last, token_sizes + + +def crf_segment_reduce(emissions: Sequence, targets: Sequence, + transitions: Tuple[Tensor, Tensor, Tensor], semiring: Type[Semiring]) -> Tensor: + if isinstance(emissions, CattedSequence): + emissions, token_sizes = emissions + prev, curr, unsorted_indices, head, last, sizes = crf_segment_catted_indices( + token_sizes=token_sizes, device=emissions.device, + ) + elif isinstance(emissions, PackedSequence): + emissions, batch_sizes, _, unsorted_indices = emissions + prev, curr, unsorted_indices, head, last, sizes = crf_segment_packed_indices( + batch_sizes=batch_sizes, unsorted_indices=unsorted_indices, device=emissions.device, + ) + else: + raise NotImplementedError + + if isinstance(targets, CattedSequence): + t, c, h = broadcast_catted_shapes(targets, transitions=transitions) + targets, _ = targets + elif isinstance(targets, PackedSequence): + t, c, h = broadcast_packed_shapes(targets, transitions=transitions) + targets, _, _, _ = targets + else: + raise NotImplementedError + + c = torch.arange(c, device=emissions.device) + + transitions, head_transitions, last_transitions = transitions + emissions = emissions.expand((t, c, -1)) + targets = targets.expand((t, c)) + transitions = transitions.expand((t, c, -1, -1)) + head_transitions = head_transitions.expand((h, c, -1)) + last_transitions = last_transitions.expand((h, c, -1)) + + emissions = emissions[curr[:, None], c[None, :], targets[curr]] + transitions = transitions[curr[:, None], c[None, :], targets[prev], targets[curr]] + transitions[accumulate_sizes(sizes=sizes)] = semiring.one + head_transitions = head_transitions[unsorted_indices[:, None], c[None, :], targets[head]] + last_transitions = last_transitions[unsorted_indices[:, None], c[None, :], targets[last]] + + emissions = semiring.segment_prod(semiring.mul(emissions, transitions), sizes=sizes) + return semiring.mul(emissions, semiring.mul(head_transitions, last_transitions)) From ac2f784ef0b60707edb821fc0c1f383a8c123c1f Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sun, 3 Apr 2022 18:02:09 +0900 Subject: [PATCH 019/102] Feat: Impl compute_catted_sequence_scores by using crf_segment_reduce --- torchlatent/crf/catting.py | 32 
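# For reference, the per-sequence quantity the reduction above computes in the
# Log semiring (mul is +, segment_prod is a segment sum), written naively for
# a single sequence; names and shapes here are illustrative only:
def crf_score_single(emissions, tags, transitions, head_transitions, last_transitions):
    # emissions: [t, n], tags: [t], transitions: [n, n], head/last_transitions: [n]
    score = head_transitions[tags[0]] + emissions[0, tags[0]]
    for i in range(1, len(tags)):
        score = score + transitions[tags[i - 1], tags[i]] + emissions[i, tags[i]]
    return score + last_transitions[tags[-1]]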
++++++-------------------------- torchlatent/crf2.py | 5 +++-- torchlatent/semiring.py | 2 +- 3 files changed, 10 insertions(+), 29 deletions(-) diff --git a/torchlatent/crf/catting.py b/torchlatent/crf/catting.py index f81fc62..d4e57de 100644 --- a/torchlatent/crf/catting.py +++ b/torchlatent/crf/catting.py @@ -5,8 +5,8 @@ from torch.distributions.utils import lazy_property from torchrua import CattedSequence from torchrua import ReductionIndices, head_catted_indices -from torchrua import roll_catted_sequence, head_catted_sequence, last_catted_sequence +from torchlatent.crf2 import crf_segment_reduce from torchlatent.semiring import Semiring, Log, Max __all__ = [ @@ -20,31 +20,11 @@ def compute_catted_sequence_scores(semiring: Type[Semiring]): def _compute_catted_sequence_scores( emissions: CattedSequence, tags: CattedSequence, transitions: Tensor, head_transitions: Tensor, last_transitions: Tensor) -> Tensor: - device = transitions.device - - emission_scores = emissions.data.gather(dim=-1, index=tags.data[..., None])[..., 0] # [t, c] - - h = emissions.token_sizes.size()[0] - t = torch.arange(transitions.size()[0], device=device) # [t] - c = torch.arange(transitions.size()[1], device=device) # [c] - - x, y = roll_catted_sequence(tags, shifts=1).data, tags.data # [t, c] - head = head_catted_sequence(tags) # [h, c] - last = last_catted_sequence(tags) # [h, c] - - transition_scores = transitions[t[:, None], c[None, :], x, y] # [t, c] - transition_head_scores = head_transitions[t[:h, None], c[None, :], head] # [h, c] - transition_last_scores = last_transitions[t[:h, None], c[None, :], last] # [h, c] - - head_indices = head_catted_indices(emissions.token_sizes) - transition_scores[head_indices] = transition_head_scores # [h, c] - - scores = semiring.mul(emission_scores, transition_scores) - scores = semiring.segment_prod(scores, sizes=emissions.token_sizes) - - scores = semiring.mul(scores, transition_last_scores) - - return scores + return crf_segment_reduce( + emissions=emissions, targets=tags, + transitions=(transitions, head_transitions, last_transitions), + semiring=semiring, + ) return _compute_catted_sequence_scores diff --git a/torchlatent/crf2.py b/torchlatent/crf2.py index 696f812..bd08072 100644 --- a/torchlatent/crf2.py +++ b/torchlatent/crf2.py @@ -48,6 +48,7 @@ def crf_segment_catted_indices(token_sizes: Tensor, device: Device = None): token_sizes = token_sizes.to(device=device) curr, _, token_sizes = segment_catted_indices(token_sizes=token_sizes, device=device) + prev = roll_catted_indices(token_sizes=token_sizes, shifts=1, device=device) head = head_catted_indices(token_sizes=token_sizes, device=device) last = last_catted_indices(token_sizes=token_sizes, device=device) @@ -98,8 +99,6 @@ def crf_segment_reduce(emissions: Sequence, targets: Sequence, else: raise NotImplementedError - c = torch.arange(c, device=emissions.device) - transitions, head_transitions, last_transitions = transitions emissions = emissions.expand((t, c, -1)) targets = targets.expand((t, c)) @@ -107,6 +106,8 @@ def crf_segment_reduce(emissions: Sequence, targets: Sequence, head_transitions = head_transitions.expand((h, c, -1)) last_transitions = last_transitions.expand((h, c, -1)) + c = torch.arange(c, device=emissions.device) + emissions = emissions[curr[:, None], c[None, :], targets[curr]] transitions = transitions[curr[:, None], c[None, :], targets[prev], targets[curr]] transitions[accumulate_sizes(sizes=sizes)] = semiring.one diff --git a/torchlatent/semiring.py b/torchlatent/semiring.py index 
4602d73..62f7b4b 100644 --- a/torchlatent/semiring.py +++ b/torchlatent/semiring.py @@ -40,7 +40,7 @@ def segment_catted_indices(token_sizes: Tensor, device: Device = None): token_sizes = token_sizes.to(device=device) batch_ptr = torch.repeat_interleave(repeats=token_sizes) - return ..., batch_ptr, token_sizes + return torch.arange(batch_ptr.size()[0], device=device), batch_ptr, token_sizes @torch.no_grad() From a67090be92ea32684dccae920d38382eebca2c83 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sun, 3 Apr 2022 18:03:07 +0900 Subject: [PATCH 020/102] Feat: Impl compute_packed_sequence_scores by using crf_segment_reduce --- torchlatent/crf/packing.py | 39 ++++++++------------------------------ 1 file changed, 8 insertions(+), 31 deletions(-) diff --git a/torchlatent/crf/packing.py b/torchlatent/crf/packing.py index ec22c38..2fc6279 100644 --- a/torchlatent/crf/packing.py +++ b/torchlatent/crf/packing.py @@ -4,9 +4,9 @@ from torch import Tensor, autograd from torch.distributions.utils import lazy_property from torch.nn.utils.rnn import PackedSequence -from torchrua import head_packed_indices, ReductionIndices -from torchrua import roll_packed_sequence, head_packed_sequence, last_packed_sequence, major_sizes_to_ptr +from torchrua import ReductionIndices +from torchlatent.crf2 import crf_segment_reduce from torchlatent.semiring import Semiring, Log, Max __all__ = [ @@ -20,35 +20,12 @@ def compute_packed_sequence_scores(semiring: Type[Semiring]): def _compute_packed_sequence_scores( emissions: PackedSequence, tags: PackedSequence, transitions: Tensor, head_transitions: Tensor, last_transitions: Tensor) -> Tensor: - device = transitions.device - - emission_scores = emissions.data.gather(dim=-1, index=tags.data[..., None])[..., 0] # [t, c] - - h = emissions.batch_sizes[0].item() - t = torch.arange(transitions.size()[0], device=device) # [t] - c = torch.arange(transitions.size()[1], device=device) # [c] - - x, y = roll_packed_sequence(tags, shifts=1).data, tags.data # [t, c] - head = head_packed_sequence(tags, unsort=False) # [h, c] - last = last_packed_sequence(tags, unsort=False) # [h, c] - - transition_scores = transitions[t[:, None], c[None, :], x, y] # [t, c] - transition_head_scores = head_transitions[t[:h, None], c[None, :], head] # [h, c] - transition_last_scores = last_transitions[t[:h, None], c[None, :], last] # [h, c] - - indices = head_packed_indices(tags.batch_sizes) - transition_scores[indices] = transition_head_scores # [h, c] - - batch_ptr, _ = major_sizes_to_ptr(sizes=emissions.batch_sizes) - scores = semiring.mul(emission_scores, transition_scores) - scores = semiring.scatter_mul(scores, index=batch_ptr) - - scores = semiring.mul(scores, transition_last_scores) - - if emissions.unsorted_indices is not None: - scores = scores[emissions.unsorted_indices] - - return scores + return crf_segment_reduce( + emissions=emissions, + targets=tags, + transitions=(transitions, head_transitions, last_transitions), + semiring=semiring, + ) return _compute_packed_sequence_scores From 8fd2c0c35b57a8d3682434372ce9e65b306bd10b Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sun, 3 Apr 2022 19:10:31 +0900 Subject: [PATCH 021/102] Feat: Add crf_partition --- torchlatent/crf/catting.py | 25 ++++++----------------- torchlatent/crf2.py | 41 +++++++++++++++++++++++++++++++++++++- 2 files changed, 46 insertions(+), 20 deletions(-) diff --git a/torchlatent/crf/catting.py b/torchlatent/crf/catting.py index d4e57de..cb3576b 100644 --- a/torchlatent/crf/catting.py +++ b/torchlatent/crf/catting.py @@ -6,7 
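# Quick illustration of the batch_ptr used by segment_catted_indices above:
# the single-argument form of torch.repeat_interleave maps every position of
# a concatenated (catted) batch back to the index of the sequence it came from.
import torch

token_sizes = torch.tensor([3, 2])
batch_ptr = torch.repeat_interleave(token_sizes)
print(batch_ptr)  # tensor([0, 0, 0, 1, 1])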
+6,7 @@ from torchrua import CattedSequence from torchrua import ReductionIndices, head_catted_indices -from torchlatent.crf2 import crf_segment_reduce +from torchlatent.crf2 import crf_segment_reduce, crf_partition from torchlatent.semiring import Semiring, Log, Max __all__ = [ @@ -33,24 +33,11 @@ def compute_catted_sequence_partitions(semiring: Type[Semiring]): def _compute_catted_sequence_partitions( emissions: CattedSequence, indices: ReductionIndices, transitions: Tensor, head_transitions: Tensor, last_transitions: Tensor, eye: Tensor) -> Tensor: - h = emissions.token_sizes.size()[0] - t = torch.arange(transitions.size()[0], device=transitions.device) # [t] - c = torch.arange(transitions.size()[1], device=transitions.device) # [c] - head_indices = head_catted_indices(emissions.token_sizes) - - emission_scores = semiring.mul(transitions, emissions.data[..., None, :]) # [t, c, n, n] - emission_scores[head_indices] = eye[None, None, :, :] - emission_scores = semiring.reduce(tensor=emission_scores, indices=indices) - - emission_head_scores = emissions.data[head_indices, :, None, :] - transition_head_scores = head_transitions[t[:h, None], c[None, :], None, :] - transition_last_scores = last_transitions[t[:h, None], c[None, :], :, None] - - scores = semiring.mul(transition_head_scores, emission_head_scores) - scores = semiring.bmm(scores, emission_scores) - scores = semiring.bmm(scores, transition_last_scores)[..., 0, 0] - - return scores + return crf_partition( + emissions=emissions, indices=indices, + transitions=(transitions, head_transitions, last_transitions), + semiring=semiring, + ) return _compute_catted_sequence_partitions diff --git a/torchlatent/crf2.py b/torchlatent/crf2.py index bd08072..f045a69 100644 --- a/torchlatent/crf2.py +++ b/torchlatent/crf2.py @@ -5,7 +5,7 @@ from torch.nn.utils.rnn import PackedSequence from torch.types import Device from torchrua import roll_catted_indices, CattedSequence, head_catted_indices, last_catted_indices, head_packed_indices, \ - last_packed_indices, accumulate_sizes + last_packed_indices, accumulate_sizes, ReductionIndices from torchlatent.semiring import segment_catted_indices, segment_packed_indices, Semiring @@ -116,3 +116,42 @@ def crf_segment_reduce(emissions: Sequence, targets: Sequence, emissions = semiring.segment_prod(semiring.mul(emissions, transitions), sizes=sizes) return semiring.mul(emissions, semiring.mul(head_transitions, last_transitions)) + + +def crf_partition(emissions: Sequence, indices: ReductionIndices, + transitions: Tuple[Tensor, Tensor, Tensor], semiring: Type[Semiring]): + if isinstance(emissions, CattedSequence): + t, c, h = broadcast_catted_shapes(emissions, transitions=transitions) + emissions, token_sizes = emissions + prev, curr, unsorted_indices, head, last, sizes = crf_segment_catted_indices( + token_sizes=token_sizes, device=emissions.device, + ) + elif isinstance(emissions, PackedSequence): + t, c, h = broadcast_packed_shapes(emissions, transitions=transitions) + emissions, batch_sizes, _, unsorted_indices = emissions + prev, curr, unsorted_indices, head, last, sizes = crf_segment_packed_indices( + batch_sizes=batch_sizes, unsorted_indices=unsorted_indices, device=emissions.device, + ) + else: + raise NotImplementedError + + transitions, head_transitions, last_transitions = transitions + emissions = emissions.expand((t, c, -1)) + transitions = transitions.expand((t, c, -1, -1)) + head_transitions = head_transitions.expand((h, c, -1)) + last_transitions = last_transitions.expand((h, c, -1)) + + c = 
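# Brute-force cross-check of the partition computed by crf_partition above
# (a sketch with toy sizes, not library code): the log partition of a
# linear-chain CRF is the logsumexp of the path score over all tag sequences.
import itertools
import torch

n, t = 3, 4  # num_tags, sequence length
emissions = torch.randn(t, n)
transitions = torch.randn(n, n)
head_transitions = torch.randn(n)
last_transitions = torch.randn(n)

scores = []
for tags in itertools.product(range(n), repeat=t):
    score = head_transitions[tags[0]] + emissions[0, tags[0]]
    for i in range(1, t):
        score = score + transitions[tags[i - 1], tags[i]] + emissions[i, tags[i]]
    scores.append(score + last_transitions[tags[-1]])

log_partition = torch.stack(scores).logsumexp(dim=0)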
torch.arange(c, device=emissions.device) + + transitions = semiring.mul(emissions[:, :, None, :], transitions) + transitions[head] = semiring.eye_like(transitions)[None, None, :, :] + + head_emissions = emissions[head[:, None], c[None, :], None, :] + head_transitions = head_transitions[unsorted_indices[:, None], c[None, :], None, :] + last_transitions = last_transitions[unsorted_indices[:, None], c[None, :], :, None] + + scores = semiring.mul(head_emissions, head_transitions) + scores = semiring.bmm(scores, semiring.reduce(transitions, indices=indices)) + scores = semiring.bmm(scores, last_transitions) + + return scores[..., 0, 0] From db98438a4c7741f772d994efa0b32d60e5e31e89 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sun, 3 Apr 2022 19:39:34 +0900 Subject: [PATCH 022/102] Feat: Add CrfDecoder --- tests/test_cky.py | 2 +- torchlatent/abc.py | 17 ++++-- torchlatent/cky.py | 24 ++++---- torchlatent/crf2.py | 139 ++++++++++++++++++++++++++++++++++++++++++-- 4 files changed, 158 insertions(+), 24 deletions(-) diff --git a/tests/test_cky.py b/tests/test_cky.py index 145b7c7..b2fbcd7 100644 --- a/tests/test_cky.py +++ b/tests/test_cky.py @@ -60,7 +60,7 @@ def test_cky_log_partitions(device, token_sizes, num_tags): token_sizes = torch.tensor(token_sizes, device=device) excepted = TreeCRF(log_potentials=scores, lengths=token_sizes) - actual = CkyDistribution(scores=scores, indices=cky_indices(token_sizes=token_sizes, device=device)) + actual = CkyDistribution(log_potentials=scores, indices=cky_indices(token_sizes=token_sizes, device=device)) assert_close(actual=actual.log_partitions, expected=excepted.partition) assert_grad_close(actual=actual.log_partitions, expected=excepted.partition, inputs=scores, rtol=1e-5, atol=1e-5) diff --git a/torchlatent/abc.py b/torchlatent/abc.py index 0346302..7268106 100644 --- a/torchlatent/abc.py +++ b/torchlatent/abc.py @@ -1,5 +1,6 @@ from abc import ABCMeta +import torch import torch.autograd from torch import Tensor from torch.distributions import Distribution @@ -13,17 +14,17 @@ class DistributionABC(Distribution, metaclass=ABCMeta): - scores: Tensor + log_potentials: Tensor - def log_scores(self, value: Sequence) -> Tensor: + def log_scores(self, targets: Sequence) -> Tensor: raise NotImplementedError @lazy_property def log_partitions(self) -> Tensor: raise NotImplementedError - def log_prob(self, value: Sequence) -> Tensor: - return self.log_scores(value=value) - self.log_partitions + def log_prob(self, targets: Sequence) -> Tensor: + return self.log_scores(targets=targets) - self.log_partitions @lazy_property def max(self) -> Tensor: @@ -32,14 +33,18 @@ def max(self) -> Tensor: @lazy_property def argmax(self) -> Tensor: grad, = torch.autograd.grad( - self.max, self.scores, torch.ones_like(self.max), + self.max, self.log_potentials, torch.ones_like(self.max), create_graph=False, only_inputs=True, allow_unused=False, ) return grad @lazy_property def marginals(self) -> Tensor: - raise NotImplementedError + grad, = torch.autograd.grad( + self.log_partitions, self.log_potentials, torch.ones_like(self.log_partitions), + create_graph=False, only_inputs=True, allow_unused=False, + ) + return grad @lazy_property def entropy(self) -> Tensor: diff --git a/torchlatent/cky.py b/torchlatent/cky.py index b3f3b33..714f66b 100644 --- a/torchlatent/cky.py +++ b/torchlatent/cky.py @@ -53,9 +53,9 @@ def cky_partition(data: Tensor, indices: CkyIndices, semiring: Type[Semiring]) - token_size, cache_size, (src1, src2), tgt = indices size = (token_size, cache_size, 
*data.size()[3:])

     tensor0 = torch.full(size, fill_value=semiring.zero, requires_grad=False)
     tensor1 = torch.full(size, fill_value=semiring.zero, requires_grad=False)
     tensor2 = torch.full(size, fill_value=semiring.zero, requires_grad=False)

     tensor0[src1] = data[src2]
     tensor1[0, :] = tensor2[-1, :] = tensor0[0, :]
@@ -70,27 +70,27 @@ class CkyDistribution(DistributionABC):
-    def __init__(self, scores: Tensor, indices: CkyIndices) -> None:
+    def __init__(self, log_potentials: Tensor, indices: CkyIndices) -> None:
         super(CkyDistribution, self).__init__(validate_args=False)

-        self.scores = scores
+        self.log_potentials = log_potentials
         self.indices = indices

     def log_scores(self, sequence: Sequence) -> Tensor:
         indices, batch_ptr, sizes = segment_indices(sequence=sequence)
         data = sequence.data[indices]
         return Log.segment_prod(
-            tensor=self.scores[batch_ptr, data[..., 0], data[..., 1], data[..., 2]],
+            tensor=self.log_potentials[batch_ptr, data[..., 0], data[..., 1], data[..., 2]],
             sizes=sizes,
         )

     @lazy_property
     def log_partitions(self) -> Tensor:
-        return cky_partition(data=Log.sum(self.scores, dim=-1), indices=self.indices, semiring=Log)
+        return cky_partition(data=Log.sum(self.log_potentials, dim=-1), indices=self.indices, semiring=Log)

     @lazy_property
     def max(self) -> Tensor:
-        return cky_partition(data=Max.sum(self.scores, dim=-1), indices=self.indices, semiring=Max)
+        return cky_partition(data=Max.sum(self.log_potentials, dim=-1), indices=self.indices, semiring=Max)

     @lazy_property
     def argmax(self) -> Tensor:
@@ -108,7 +108,7 @@ def argmax(self) -> Tensor:
     @lazy_property
     def marginals(self) -> Tensor:
         grad, = torch.autograd.grad(
-            self.log_partitions, self.scores, torch.ones_like(self.log_partitions),
+            self.log_partitions, self.log_potentials, torch.ones_like(self.log_partitions),
             create_graph=True, only_inputs=True, allow_unused=False,
         )
         return grad
@@ -142,13 +142,13 @@ def forward(self, sequence: Sequence, indices: CkyIndices = None) -> CkyDistribution:
             indices = cky_indices(token_sizes=token_sizes, device=features.device)

         return CkyDistribution(
-            scores=self.forward_scores(features=features),
+            log_potentials=self.forward_scores(features=features),
             indices=indices,
         )

-    def fit(self, sequence: Sequence, value: Sequence, indices: CkyIndices = None) -> Tensor:
+    def fit(self, sequence: Sequence, targets: Sequence, indices: CkyIndices = None) -> Tensor:
         dist = self.forward(sequence=sequence, indices=indices)
-        return dist.log_partitions - dist.log_scores(sequence=value)
+        return dist.log_partitions - dist.log_scores(sequence=targets)

     def decode(self, sequence: Sequence, indices: CkyIndices = None) -> Sequence:
         dist = self.forward(sequence=sequence, indices=indices)
diff --git a/torchlatent/crf2.py b/torchlatent/crf2.py
index f045a69..32f37ad 100644
--- a/torchlatent/crf2.py
+++ b/torchlatent/crf2.py
@@ -1,13 +1,23 @@
-from typing import Tuple, Sequence, Type
+from typing import Sequence
+from typing import Tuple
+from typing import Type

 import torch
+import torchcrf
 from torch import Tensor
+from torch import nn
+from torch.distributions.utils import lazy_property
+from torch.nn import init
 from torch.nn.utils.rnn import PackedSequence
 from torch.types import Device
+from torchrua 
import reduce_catted_indices, reduce_packed_indices from torchrua import roll_catted_indices, CattedSequence, head_catted_indices, last_catted_indices, head_packed_indices, \ - last_packed_indices, accumulate_sizes, ReductionIndices + last_packed_indices, accumulate_sizes, ReductionIndices, pack_sequence, pad_sequence, pad_packed_indices -from torchlatent.semiring import segment_catted_indices, segment_packed_indices, Semiring +from torchlatent.abc import DistributionABC +from torchlatent.semiring import segment_catted_indices, segment_packed_indices, Semiring, Log, Max + +CrfIndices = ReductionIndices @torch.no_grad() @@ -70,8 +80,8 @@ def crf_segment_packed_indices(batch_sizes: Tensor, unsorted_indices: Tensor, de batch_sizes=batch_sizes, unsorted_indices=unsorted_indices, device=device, ) prev = roll_catted_indices(token_sizes=token_sizes, shifts=1, device=device) - head = head_packed_indices(token_sizes=token_sizes, unsorted_indices=unsorted_indices, device=device) - last = last_packed_indices(token_sizes=token_sizes, unsorted_indices=unsorted_indices, device=device) + head = head_packed_indices(batch_sizes=batch_sizes, unsorted_indices=unsorted_indices, device=device) + last = last_packed_indices(batch_sizes=batch_sizes, unsorted_indices=unsorted_indices, device=device) return curr[prev], curr, unsorted_indices, head, last, token_sizes @@ -155,3 +165,122 @@ def crf_partition(emissions: Sequence, indices: ReductionIndices, scores = semiring.bmm(scores, last_transitions) return scores[..., 0, 0] + + +class CrfDistribution(DistributionABC): + def __init__(self, log_potentials: Sequence, indices: CrfIndices, + transitions: Tuple[Tensor, Tensor, Tensor]) -> None: + super(CrfDistribution, self).__init__(validate_args=False) + + self.log_potentials = log_potentials + self.indices = indices + self.transitions = transitions + + def log_scores(self, targets: Sequence) -> Tensor: + return crf_segment_reduce( + emissions=self.log_potentials, + targets=targets, + transitions=self.transitions, + semiring=Log, + ) + + @lazy_property + def log_partitions(self) -> Tensor: + return crf_partition( + emissions=self.log_potentials, + indices=self.indices, + transitions=self.transitions, + semiring=Log, + ) + + @lazy_property + def max(self) -> Tensor: + return crf_partition( + emissions=self.log_potentials, + indices=self.indices, + transitions=self.transitions, + semiring=Max, + ) + + @lazy_property + def entropy(self) -> Tensor: + raise NotImplementedError + + +class CrfDecoder(nn.Module): + def __init__(self, num_tags: int, num_conjugates: int = 1) -> None: + super(CrfDecoder, self).__init__() + + self.num_tags = num_tags + self.num_conjugates = num_conjugates + + self.transitions = nn.Parameter(torch.empty((1, num_conjugates, num_tags, num_tags))) + self.head_transitions = nn.Parameter(torch.empty((1, num_conjugates, num_tags))) + self.last_transitions = nn.Parameter(torch.empty((1, num_conjugates, num_tags))) + + self.reset_parameters() + + @torch.no_grad() + def reset_parameters(self) -> None: + init.zeros_(self.transitions) + init.zeros_(self.head_transitions) + init.zeros_(self.last_transitions) + + def forward(self, emissions: Sequence, indices: CrfIndices = None) -> CrfDistribution: + if indices is None: + if isinstance(emissions, CattedSequence): + indices = reduce_catted_indices( + token_sizes=emissions.token_sizes, + device=emissions.data.device, + ) + elif isinstance(emissions, PackedSequence): + + indices = reduce_packed_indices( + batch_sizes=emissions.batch_sizes, + 
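# The marginals inherited from DistributionABC rest on the identity sketched
# below (standalone illustration, not library code): differentiating a log
# partition with respect to its potentials yields marginal probabilities;
# here a plain logsumexp stands in for the structured partition.
import torch

potentials = torch.randn(5, 3, requires_grad=True)
log_z = potentials.logsumexp(dim=-1).sum()  # stand-in for a real log partition
marginals, = torch.autograd.grad(log_z, potentials, torch.ones_like(log_z), create_graph=True)
print(marginals.sum(dim=-1))  # each position's marginals sum to 1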
unsorted_indices=emissions.unsorted_indices, + device=emissions.data.device, + ) + else: + raise NotImplementedError + + return CrfDistribution( + log_potentials=emissions, + indices=indices, + transitions=(self.transitions, self.head_transitions, self.last_transitions), + ) + + +if __name__ == '__main__': + num_tags = 3 + + decoder1 = CrfDecoder(num_tags) + decoder2 = torchcrf.CRF(num_tags, batch_first=False) + + decoder1.transitions.data = decoder2.transitions[None, None, :, :] + decoder1.head_transitions.data = decoder2.start_transitions[None, None, :] + decoder1.last_transitions.data = decoder2.end_transitions[None, None, :] + + sequence = [ + torch.randn((5, num_tags), requires_grad=True), + torch.randn((2, num_tags), requires_grad=True), + torch.randn((3, num_tags), requires_grad=True), + ] + + token_sizes = torch.tensor([5, 2, 3]) + + e1 = pack_sequence([s[:, None, :] for s in sequence]) + + e2, _ = pad_sequence(sequence, batch_first=False) + size, ptr, _ = pad_packed_indices( + e1.batch_sizes, False, e1.sorted_indices, e1.unsorted_indices + ) + mask = torch.zeros(size, dtype=torch.bool) + mask[ptr] = True + + dist = decoder1.forward(e1) + lhs = dist.log_partitions[:, 0] + rhs = decoder2._compute_normalizer(e2, mask) + print(f'lhs => {lhs}') + print(f'rhs => {rhs}') + + print(torch.allclose(lhs, rhs)) From ba6f89c82fc611519c8d3bc15194740268698c7f Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sun, 3 Apr 2022 19:48:29 +0900 Subject: [PATCH 023/102] Refactor: Separate broadcast operations --- torchlatent/crf2.py | 81 ++++++++++++++++++--------------------------- 1 file changed, 32 insertions(+), 49 deletions(-) diff --git a/torchlatent/crf2.py b/torchlatent/crf2.py index 32f37ad..7e6bafe 100644 --- a/torchlatent/crf2.py +++ b/torchlatent/crf2.py @@ -10,9 +10,9 @@ from torch.nn import init from torch.nn.utils.rnn import PackedSequence from torch.types import Device -from torchrua import reduce_catted_indices, reduce_packed_indices +from torchrua import reduce_catted_indices, reduce_packed_indices, pad_catted_indices, cat_sequence from torchrua import roll_catted_indices, CattedSequence, head_catted_indices, last_catted_indices, head_packed_indices, \ - last_packed_indices, accumulate_sizes, ReductionIndices, pack_sequence, pad_sequence, pad_packed_indices + last_packed_indices, accumulate_sizes, ReductionIndices, pad_sequence from torchlatent.abc import DistributionABC from torchlatent.semiring import segment_catted_indices, segment_packed_indices, Semiring, Log, Max @@ -100,23 +100,8 @@ def crf_segment_reduce(emissions: Sequence, targets: Sequence, else: raise NotImplementedError - if isinstance(targets, CattedSequence): - t, c, h = broadcast_catted_shapes(targets, transitions=transitions) - targets, _ = targets - elif isinstance(targets, PackedSequence): - t, c, h = broadcast_packed_shapes(targets, transitions=transitions) - targets, _, _, _ = targets - else: - raise NotImplementedError - transitions, head_transitions, last_transitions = transitions - emissions = emissions.expand((t, c, -1)) - targets = targets.expand((t, c)) - transitions = transitions.expand((t, c, -1, -1)) - head_transitions = head_transitions.expand((h, c, -1)) - last_transitions = last_transitions.expand((h, c, -1)) - - c = torch.arange(c, device=emissions.device) + c = torch.arange(transitions.size()[1], device=emissions.device) emissions = emissions[curr[:, None], c[None, :], targets[curr]] transitions = transitions[curr[:, None], c[None, :], targets[prev], targets[curr]] @@ -131,13 +116,11 @@ def 
crf_segment_reduce(emissions: Sequence, targets: Sequence, def crf_partition(emissions: Sequence, indices: ReductionIndices, transitions: Tuple[Tensor, Tensor, Tensor], semiring: Type[Semiring]): if isinstance(emissions, CattedSequence): - t, c, h = broadcast_catted_shapes(emissions, transitions=transitions) emissions, token_sizes = emissions prev, curr, unsorted_indices, head, last, sizes = crf_segment_catted_indices( token_sizes=token_sizes, device=emissions.device, ) elif isinstance(emissions, PackedSequence): - t, c, h = broadcast_packed_shapes(emissions, transitions=transitions) emissions, batch_sizes, _, unsorted_indices = emissions prev, curr, unsorted_indices, head, last, sizes = crf_segment_packed_indices( batch_sizes=batch_sizes, unsorted_indices=unsorted_indices, device=emissions.device, @@ -146,21 +129,15 @@ def crf_partition(emissions: Sequence, indices: ReductionIndices, raise NotImplementedError transitions, head_transitions, last_transitions = transitions - emissions = emissions.expand((t, c, -1)) - transitions = transitions.expand((t, c, -1, -1)) - head_transitions = head_transitions.expand((h, c, -1)) - last_transitions = last_transitions.expand((h, c, -1)) - - c = torch.arange(c, device=emissions.device) + c = torch.arange(transitions.size()[1], device=emissions.device) transitions = semiring.mul(emissions[:, :, None, :], transitions) transitions[head] = semiring.eye_like(transitions)[None, None, :, :] - head_emissions = emissions[head[:, None], c[None, :], None, :] head_transitions = head_transitions[unsorted_indices[:, None], c[None, :], None, :] last_transitions = last_transitions[unsorted_indices[:, None], c[None, :], :, None] - scores = semiring.mul(head_emissions, head_transitions) + scores = semiring.mul(emissions[head[:, None], c[None, :], None, :], head_transitions) scores = semiring.bmm(scores, semiring.reduce(transitions, indices=indices)) scores = semiring.bmm(scores, last_transitions) @@ -227,26 +204,31 @@ def reset_parameters(self) -> None: init.zeros_(self.last_transitions) def forward(self, emissions: Sequence, indices: CrfIndices = None) -> CrfDistribution: - if indices is None: - if isinstance(emissions, CattedSequence): - indices = reduce_catted_indices( - token_sizes=emissions.token_sizes, - device=emissions.data.device, - ) - elif isinstance(emissions, PackedSequence): - - indices = reduce_packed_indices( - batch_sizes=emissions.batch_sizes, - unsorted_indices=emissions.unsorted_indices, - device=emissions.data.device, - ) - else: - raise NotImplementedError + transitions = (self.transitions, self.head_transitions, self.last_transitions) + if isinstance(emissions, CattedSequence): + t, c, h = broadcast_catted_shapes(sequence=emissions, transitions=transitions) + indices = reduce_catted_indices( + token_sizes=emissions.token_sizes, + device=emissions.data.device, + ) + elif isinstance(emissions, PackedSequence): + t, c, h = broadcast_packed_shapes(sequence=emissions, transitions=transitions) + indices = reduce_packed_indices( + batch_sizes=emissions.batch_sizes, + unsorted_indices=emissions.unsorted_indices, + device=emissions.data.device, + ) + else: + raise NotImplementedError + + emissions = emissions._replace(data=emissions.data.expand((t, c, -1))) + transitions = self.transitions.expand((t, c, -1, -1)) + head_transitions = self.head_transitions.expand((h, c, -1)) + last_transitions = self.last_transitions.expand((h, c, -1)) return CrfDistribution( - log_potentials=emissions, - indices=indices, - transitions=(self.transitions, 
self.head_transitions, self.last_transitions), + log_potentials=emissions, indices=indices, + transitions=(transitions, head_transitions, last_transitions), ) @@ -268,11 +250,12 @@ def forward(self, emissions: Sequence, indices: CrfIndices = None) -> CrfDistrib token_sizes = torch.tensor([5, 2, 3]) - e1 = pack_sequence([s[:, None, :] for s in sequence]) + e1 = cat_sequence([s[:, None, :] for s in sequence]) e2, _ = pad_sequence(sequence, batch_first=False) - size, ptr, _ = pad_packed_indices( - e1.batch_sizes, False, e1.sorted_indices, e1.unsorted_indices + size, ptr = pad_catted_indices( + e1.token_sizes, False, + # e1.sorted_indices, e1.unsorted_indices ) mask = torch.zeros(size, dtype=torch.bool) mask[ptr] = True From a7330a206054688144a9e6849a386cd094be9108 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sun, 3 Apr 2022 20:49:21 +0900 Subject: [PATCH 024/102] Refactor: Check _compute_score --- torchlatent/crf2.py | 33 +++++++++++++++++++++------------ 1 file changed, 21 insertions(+), 12 deletions(-) diff --git a/torchlatent/crf2.py b/torchlatent/crf2.py index 7e6bafe..57d9140 100644 --- a/torchlatent/crf2.py +++ b/torchlatent/crf2.py @@ -89,11 +89,13 @@ def crf_segment_reduce(emissions: Sequence, targets: Sequence, transitions: Tuple[Tensor, Tensor, Tensor], semiring: Type[Semiring]) -> Tensor: if isinstance(emissions, CattedSequence): emissions, token_sizes = emissions + targets, _ = targets prev, curr, unsorted_indices, head, last, sizes = crf_segment_catted_indices( token_sizes=token_sizes, device=emissions.device, ) elif isinstance(emissions, PackedSequence): emissions, batch_sizes, _, unsorted_indices = emissions + targets, _, _, _ = targets prev, curr, unsorted_indices, head, last, sizes = crf_segment_packed_indices( batch_sizes=batch_sizes, unsorted_indices=unsorted_indices, device=emissions.device, ) @@ -117,12 +119,12 @@ def crf_partition(emissions: Sequence, indices: ReductionIndices, transitions: Tuple[Tensor, Tensor, Tensor], semiring: Type[Semiring]): if isinstance(emissions, CattedSequence): emissions, token_sizes = emissions - prev, curr, unsorted_indices, head, last, sizes = crf_segment_catted_indices( + _, _, unsorted_indices, head, _, _ = crf_segment_catted_indices( token_sizes=token_sizes, device=emissions.device, ) elif isinstance(emissions, PackedSequence): emissions, batch_sizes, _, unsorted_indices = emissions - prev, curr, unsorted_indices, head, last, sizes = crf_segment_packed_indices( + _, _, unsorted_indices, head, _, _ = crf_segment_packed_indices( batch_sizes=batch_sizes, unsorted_indices=unsorted_indices, device=emissions.device, ) else: @@ -242,28 +244,35 @@ def forward(self, emissions: Sequence, indices: CrfIndices = None) -> CrfDistrib decoder1.head_transitions.data = decoder2.start_transitions[None, None, :] decoder1.last_transitions.data = decoder2.end_transitions[None, None, :] - sequence = [ + ss = [ torch.randn((5, num_tags), requires_grad=True), torch.randn((2, num_tags), requires_grad=True), torch.randn((3, num_tags), requires_grad=True), ] + tt = [ + torch.randint(0, num_tags, (5,)), + torch.randint(0, num_tags, (2,)), + torch.randint(0, num_tags, (3,)), + ] - token_sizes = torch.tensor([5, 2, 3]) - - e1 = cat_sequence([s[:, None, :] for s in sequence]) + e1 = cat_sequence([s[:, None] for s in ss]) + t1 = cat_sequence([t[:, None] for t in tt]) - e2, _ = pad_sequence(sequence, batch_first=False) - size, ptr = pad_catted_indices( - e1.token_sizes, False, - # e1.sorted_indices, e1.unsorted_indices - ) + e2, _ = pad_sequence(ss, batch_first=False) 
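# The mask built just below follows the usual bridge from torchrua to padded
# third-party APIs: pad_catted_indices returns the padded size together with
# the (token, batch) pointer of every real token, so scattering True produces
# the padding mask that is then handed to the torchcrf comparison.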
+ t2, _ = pad_sequence(tt, batch_first=False) + size, ptr = pad_catted_indices(e1.token_sizes, batch_first=False) mask = torch.zeros(size, dtype=torch.bool) mask[ptr] = True dist = decoder1.forward(e1) lhs = dist.log_partitions[:, 0] - rhs = decoder2._compute_normalizer(e2, mask) + rhs = decoder2._compute_normalizer(e2, mask=mask) print(f'lhs => {lhs}') print(f'rhs => {rhs}') + print(torch.allclose(lhs, rhs)) + lhs = dist.log_scores(t1)[:, 0] + rhs = decoder2._compute_score(e2, t2, mask=mask) + print(f'lhs => {lhs}') + print(f'rhs => {rhs}') print(torch.allclose(lhs, rhs)) From 9822a563b0d828545104d487bdf250f77545000c Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sun, 3 Apr 2022 21:00:05 +0900 Subject: [PATCH 025/102] Refactor: Add CrfIndices --- torchlatent/crf2.py | 108 +++++++++++++++++++++++++------------------- 1 file changed, 61 insertions(+), 47 deletions(-) diff --git a/torchlatent/crf2.py b/torchlatent/crf2.py index 57d9140..3b0b7c9 100644 --- a/torchlatent/crf2.py +++ b/torchlatent/crf2.py @@ -1,4 +1,4 @@ -from typing import Sequence +from typing import Sequence, NamedTuple from typing import Tuple from typing import Type @@ -10,14 +10,22 @@ from torch.nn import init from torch.nn.utils.rnn import PackedSequence from torch.types import Device -from torchrua import reduce_catted_indices, reduce_packed_indices, pad_catted_indices, cat_sequence +from torchrua import reduce_catted_indices, reduce_packed_indices, pad_catted_indices, cat_sequence, cat_packed_indices from torchrua import roll_catted_indices, CattedSequence, head_catted_indices, last_catted_indices, head_packed_indices, \ last_packed_indices, accumulate_sizes, ReductionIndices, pad_sequence from torchlatent.abc import DistributionABC -from torchlatent.semiring import segment_catted_indices, segment_packed_indices, Semiring, Log, Max +from torchlatent.semiring import Semiring, Log, Max -CrfIndices = ReductionIndices + +class CrfIndices(NamedTuple): + head: Tensor + last: Tensor + prev: Tensor + curr: Tensor + token_sizes: Tensor + unsorted_indices: Tensor + indices: ReductionIndices @torch.no_grad() @@ -56,13 +64,13 @@ def crf_segment_catted_indices(token_sizes: Tensor, device: Device = None): device = token_sizes.device token_sizes = token_sizes.to(device=device) + curr = torch.arange(token_sizes.sum().item(), device=device) + unsorted_indices = torch.arange(token_sizes.size()[0], device=device) - curr, _, token_sizes = segment_catted_indices(token_sizes=token_sizes, device=device) - - prev = roll_catted_indices(token_sizes=token_sizes, shifts=1, device=device) + prev = roll_catted_indices(token_sizes=token_sizes, device=device, shifts=1) head = head_catted_indices(token_sizes=token_sizes, device=device) last = last_catted_indices(token_sizes=token_sizes, device=device) - return prev, curr, torch.arange(token_sizes.size()[0], device=device), head, last, token_sizes + return head, last, prev, curr, token_sizes, unsorted_indices @torch.no_grad() @@ -75,33 +83,54 @@ def crf_segment_packed_indices(batch_sizes: Tensor, unsorted_indices: Tensor, de batch_sizes = batch_sizes.to(device=device) unsorted_indices = unsorted_indices.to(device=device) + curr, token_sizes = cat_packed_indices(batch_sizes=batch_sizes, unsorted_indices=unsorted_indices, device=device) - curr, _, token_sizes = segment_packed_indices( - batch_sizes=batch_sizes, unsorted_indices=unsorted_indices, device=device, - ) - prev = roll_catted_indices(token_sizes=token_sizes, shifts=1, device=device) - head = head_packed_indices(batch_sizes=batch_sizes, 
unsorted_indices=unsorted_indices, device=device) - last = last_packed_indices(batch_sizes=batch_sizes, unsorted_indices=unsorted_indices, device=device) - return curr[prev], curr, unsorted_indices, head, last, token_sizes + prev = roll_catted_indices(token_sizes=token_sizes, device=device, shifts=1) + head = head_packed_indices(batch_sizes=batch_sizes, device=device, unsorted_indices=unsorted_indices) + last = last_packed_indices(batch_sizes=batch_sizes, device=device, unsorted_indices=unsorted_indices) + return head, last, curr[prev], curr, token_sizes, unsorted_indices -def crf_segment_reduce(emissions: Sequence, targets: Sequence, - transitions: Tuple[Tensor, Tensor, Tensor], semiring: Type[Semiring]) -> Tensor: +@torch.no_grad() +def crf_indices(emissions: Sequence): if isinstance(emissions, CattedSequence): - emissions, token_sizes = emissions - targets, _ = targets - prev, curr, unsorted_indices, head, last, sizes = crf_segment_catted_indices( - token_sizes=token_sizes, device=emissions.device, + head, last, prev, curr, token_sizes, unsorted_indices = crf_segment_catted_indices( + token_sizes=emissions.token_sizes, + device=emissions.data.device, + ) + indices = reduce_catted_indices( + token_sizes=emissions.token_sizes, + device=emissions.data.device, ) elif isinstance(emissions, PackedSequence): - emissions, batch_sizes, _, unsorted_indices = emissions - targets, _, _, _ = targets - prev, curr, unsorted_indices, head, last, sizes = crf_segment_packed_indices( - batch_sizes=batch_sizes, unsorted_indices=unsorted_indices, device=emissions.device, + head, last, prev, curr, token_sizes, unsorted_indices = crf_segment_packed_indices( + batch_sizes=emissions.batch_sizes, + unsorted_indices=emissions.unsorted_indices, + device=emissions.data.device, + ) + indices = reduce_packed_indices( + batch_sizes=emissions.batch_sizes, + unsorted_indices=emissions.unsorted_indices, + device=emissions.data.device, ) else: raise NotImplementedError + return CrfIndices( + head=head, last=last, + prev=prev, curr=curr, + token_sizes=token_sizes, + unsorted_indices=unsorted_indices, + indices=indices, + ) + + +def crf_segment_reduce(emissions: Sequence, targets: Sequence, indices: CrfIndices, + transitions: Tuple[Tensor, Tensor, Tensor], semiring: Type[Semiring]) -> Tensor: + head, last, prev, curr, sizes, unsorted_indices, _ = indices + + emissions = emissions.data + targets = targets.data transitions, head_transitions, last_transitions = transitions c = torch.arange(transitions.size()[1], device=emissions.device) @@ -115,21 +144,11 @@ def crf_segment_reduce(emissions: Sequence, targets: Sequence, return semiring.mul(emissions, semiring.mul(head_transitions, last_transitions)) -def crf_partition(emissions: Sequence, indices: ReductionIndices, +def crf_partition(emissions: Sequence, indices: CrfIndices, transitions: Tuple[Tensor, Tensor, Tensor], semiring: Type[Semiring]): - if isinstance(emissions, CattedSequence): - emissions, token_sizes = emissions - _, _, unsorted_indices, head, _, _ = crf_segment_catted_indices( - token_sizes=token_sizes, device=emissions.device, - ) - elif isinstance(emissions, PackedSequence): - emissions, batch_sizes, _, unsorted_indices = emissions - _, _, unsorted_indices, head, _, _ = crf_segment_packed_indices( - batch_sizes=batch_sizes, unsorted_indices=unsorted_indices, device=emissions.device, - ) - else: - raise NotImplementedError + head, _, _, _, _, unsorted_indices, indices = indices + emissions = emissions.data transitions, head_transitions, last_transitions = 
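# Reading guide for the CrfIndices fields above (inferred from their use in
# crf_reduce and crf_partition): head/last index the first/last token of each
# sequence in the flat batch, prev/curr pair every token with its predecessor
# for transition lookup, token_sizes drives the segment reductions,
# unsorted_indices restores batch order, and indices feeds the
# ReductionIndices-driven semiring.reduce used for partitions.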
transitions c = torch.arange(transitions.size()[1], device=emissions.device) @@ -161,6 +180,7 @@ def log_scores(self, targets: Sequence) -> Tensor: targets=targets, transitions=self.transitions, semiring=Log, + indices=self.indices, ) @lazy_property @@ -209,17 +229,8 @@ def forward(self, emissions: Sequence, indices: CrfIndices = None) -> CrfDistrib transitions = (self.transitions, self.head_transitions, self.last_transitions) if isinstance(emissions, CattedSequence): t, c, h = broadcast_catted_shapes(sequence=emissions, transitions=transitions) - indices = reduce_catted_indices( - token_sizes=emissions.token_sizes, - device=emissions.data.device, - ) elif isinstance(emissions, PackedSequence): t, c, h = broadcast_packed_shapes(sequence=emissions, transitions=transitions) - indices = reduce_packed_indices( - batch_sizes=emissions.batch_sizes, - unsorted_indices=emissions.unsorted_indices, - device=emissions.data.device, - ) else: raise NotImplementedError @@ -228,6 +239,9 @@ def forward(self, emissions: Sequence, indices: CrfIndices = None) -> CrfDistrib head_transitions = self.head_transitions.expand((h, c, -1)) last_transitions = self.last_transitions.expand((h, c, -1)) + if indices is None: + indices = crf_indices(emissions=emissions) + return CrfDistribution( log_potentials=emissions, indices=indices, transitions=(transitions, head_transitions, last_transitions), From 1cec9401b637c4d5e54d751406be5f0f2b7a277e Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sun, 3 Apr 2022 21:35:12 +0900 Subject: [PATCH 026/102] Refactor: Add forward_parameters --- torchlatent/abc.py | 6 ++-- torchlatent/crf/catting.py | 4 +-- torchlatent/crf/packing.py | 4 +-- torchlatent/crf2.py | 66 ++++++++++++++++++++------------------ 4 files changed, 42 insertions(+), 38 deletions(-) diff --git a/torchlatent/abc.py b/torchlatent/abc.py index 7268106..7cfc891 100644 --- a/torchlatent/abc.py +++ b/torchlatent/abc.py @@ -14,7 +14,7 @@ class DistributionABC(Distribution, metaclass=ABCMeta): - log_potentials: Tensor + emissions: Tensor def log_scores(self, targets: Sequence) -> Tensor: raise NotImplementedError @@ -33,7 +33,7 @@ def max(self) -> Tensor: @lazy_property def argmax(self) -> Tensor: grad, = torch.autograd.grad( - self.max, self.log_potentials, torch.ones_like(self.max), + self.max, self.emissions, torch.ones_like(self.max), create_graph=False, only_inputs=True, allow_unused=False, ) return grad @@ -41,7 +41,7 @@ def argmax(self) -> Tensor: @lazy_property def marginals(self) -> Tensor: grad, = torch.autograd.grad( - self.log_partitions, self.log_potentials, torch.ones_like(self.log_partitions), + self.log_partitions, self.emissions, torch.ones_like(self.log_partitions), create_graph=False, only_inputs=True, allow_unused=False, ) return grad diff --git a/torchlatent/crf/catting.py b/torchlatent/crf/catting.py index cb3576b..618f4ae 100644 --- a/torchlatent/crf/catting.py +++ b/torchlatent/crf/catting.py @@ -6,7 +6,7 @@ from torchrua import CattedSequence from torchrua import ReductionIndices, head_catted_indices -from torchlatent.crf2 import crf_segment_reduce, crf_partition +from torchlatent.crf2 import crf_reduce, crf_partition from torchlatent.semiring import Semiring, Log, Max __all__ = [ @@ -20,7 +20,7 @@ def compute_catted_sequence_scores(semiring: Type[Semiring]): def _compute_catted_sequence_scores( emissions: CattedSequence, tags: CattedSequence, transitions: Tensor, head_transitions: Tensor, last_transitions: Tensor) -> Tensor: - return crf_segment_reduce( + return crf_reduce( emissions=emissions, 
targets=tags, transitions=(transitions, head_transitions, last_transitions), semiring=semiring, diff --git a/torchlatent/crf/packing.py b/torchlatent/crf/packing.py index 2fc6279..21beac9 100644 --- a/torchlatent/crf/packing.py +++ b/torchlatent/crf/packing.py @@ -6,7 +6,7 @@ from torch.nn.utils.rnn import PackedSequence from torchrua import ReductionIndices -from torchlatent.crf2 import crf_segment_reduce +from torchlatent.crf2 import crf_reduce from torchlatent.semiring import Semiring, Log, Max __all__ = [ @@ -20,7 +20,7 @@ def compute_packed_sequence_scores(semiring: Type[Semiring]): def _compute_packed_sequence_scores( emissions: PackedSequence, tags: PackedSequence, transitions: Tensor, head_transitions: Tensor, last_transitions: Tensor) -> Tensor: - return crf_segment_reduce( + return crf_reduce( emissions=emissions, targets=tags, transitions=(transitions, head_transitions, last_transitions), diff --git a/torchlatent/crf2.py b/torchlatent/crf2.py index 3b0b7c9..d90084c 100644 --- a/torchlatent/crf2.py +++ b/torchlatent/crf2.py @@ -10,9 +10,10 @@ from torch.nn import init from torch.nn.utils.rnn import PackedSequence from torch.types import Device -from torchrua import reduce_catted_indices, reduce_packed_indices, pad_catted_indices, cat_sequence, cat_packed_indices -from torchrua import roll_catted_indices, CattedSequence, head_catted_indices, last_catted_indices, head_packed_indices, \ - last_packed_indices, accumulate_sizes, ReductionIndices, pad_sequence +from torchrua import ReductionIndices, reduce_catted_indices, reduce_packed_indices +from torchrua import head_catted_indices, last_catted_indices, head_packed_indices, last_packed_indices, \ + accumulate_sizes, pad_sequence +from torchrua import pad_catted_indices, cat_sequence, cat_packed_indices, roll_catted_indices, CattedSequence from torchlatent.abc import DistributionABC from torchlatent.semiring import Semiring, Log, Max @@ -59,7 +60,7 @@ def broadcast_packed_shapes(sequence: PackedSequence, transitions: Tuple[Tensor, @torch.no_grad() -def crf_segment_catted_indices(token_sizes: Tensor, device: Device = None): +def crf_reduce_catted_indices(token_sizes: Tensor, device: Device = None): if device is None: device = token_sizes.device @@ -74,7 +75,7 @@ def crf_segment_catted_indices(token_sizes: Tensor, device: Device = None): @torch.no_grad() -def crf_segment_packed_indices(batch_sizes: Tensor, unsorted_indices: Tensor, device: Device): +def crf_reduce_packed_indices(batch_sizes: Tensor, unsorted_indices: Tensor, device: Device): if device is None: if unsorted_indices is not None: device = unsorted_indices.device @@ -92,9 +93,9 @@ def crf_segment_packed_indices(batch_sizes: Tensor, unsorted_indices: Tensor, de @torch.no_grad() -def crf_indices(emissions: Sequence): +def crf_indices(emissions: Sequence) -> CrfIndices: if isinstance(emissions, CattedSequence): - head, last, prev, curr, token_sizes, unsorted_indices = crf_segment_catted_indices( + head, last, prev, curr, token_sizes, unsorted_indices = crf_reduce_catted_indices( token_sizes=emissions.token_sizes, device=emissions.data.device, ) @@ -103,7 +104,7 @@ def crf_indices(emissions: Sequence): device=emissions.data.device, ) elif isinstance(emissions, PackedSequence): - head, last, prev, curr, token_sizes, unsorted_indices = crf_segment_packed_indices( + head, last, prev, curr, token_sizes, unsorted_indices = crf_reduce_packed_indices( batch_sizes=emissions.batch_sizes, unsorted_indices=emissions.unsorted_indices, device=emissions.data.device, @@ -114,7 +115,7 @@ def 
crf_indices(emissions: Sequence): device=emissions.data.device, ) else: - raise NotImplementedError + raise KeyError(f'type {type(emissions)} is not supported') return CrfIndices( head=head, last=last, @@ -125,12 +126,10 @@ def crf_indices(emissions: Sequence): ) -def crf_segment_reduce(emissions: Sequence, targets: Sequence, indices: CrfIndices, - transitions: Tuple[Tensor, Tensor, Tensor], semiring: Type[Semiring]) -> Tensor: +def crf_reduce(emissions: Tensor, targets: Tensor, transitions: Tuple[Tensor, Tensor, Tensor], + indices: CrfIndices, semiring: Type[Semiring]) -> Tensor: head, last, prev, curr, sizes, unsorted_indices, _ = indices - emissions = emissions.data - targets = targets.data transitions, head_transitions, last_transitions = transitions c = torch.arange(transitions.size()[1], device=emissions.device) @@ -144,11 +143,10 @@ def crf_segment_reduce(emissions: Sequence, targets: Sequence, indices: CrfIndic return semiring.mul(emissions, semiring.mul(head_transitions, last_transitions)) -def crf_partition(emissions: Sequence, indices: CrfIndices, - transitions: Tuple[Tensor, Tensor, Tensor], semiring: Type[Semiring]): +def crf_partition(emissions: Tensor, transitions: Tuple[Tensor, Tensor, Tensor], + indices: CrfIndices, semiring: Type[Semiring]): head, _, _, _, _, unsorted_indices, indices = indices - emissions = emissions.data transitions, head_transitions, last_transitions = transitions c = torch.arange(transitions.size()[1], device=emissions.device) @@ -166,38 +164,37 @@ def crf_partition(emissions: Sequence, indices: CrfIndices, class CrfDistribution(DistributionABC): - def __init__(self, log_potentials: Sequence, indices: CrfIndices, - transitions: Tuple[Tensor, Tensor, Tensor]) -> None: + def __init__(self, emissions: Tensor, transitions: Tuple[Tensor, Tensor, Tensor], indices: CrfIndices) -> None: super(CrfDistribution, self).__init__(validate_args=False) - self.log_potentials = log_potentials + self.emissions = emissions self.indices = indices self.transitions = transitions def log_scores(self, targets: Sequence) -> Tensor: - return crf_segment_reduce( - emissions=self.log_potentials, - targets=targets, + return crf_reduce( + emissions=self.emissions, + targets=targets.data, transitions=self.transitions, - semiring=Log, indices=self.indices, + semiring=Log, ) @lazy_property def log_partitions(self) -> Tensor: return crf_partition( - emissions=self.log_potentials, - indices=self.indices, + emissions=self.emissions, transitions=self.transitions, + indices=self.indices, semiring=Log, ) @lazy_property def max(self) -> Tensor: return crf_partition( - emissions=self.log_potentials, - indices=self.indices, + emissions=self.emissions, transitions=self.transitions, + indices=self.indices, semiring=Max, ) @@ -225,8 +222,9 @@ def reset_parameters(self) -> None: init.zeros_(self.head_transitions) init.zeros_(self.last_transitions) - def forward(self, emissions: Sequence, indices: CrfIndices = None) -> CrfDistribution: + def forward_parameters(self, emissions: Sequence): transitions = (self.transitions, self.head_transitions, self.last_transitions) + if isinstance(emissions, CattedSequence): t, c, h = broadcast_catted_shapes(sequence=emissions, transitions=transitions) elif isinstance(emissions, PackedSequence): @@ -234,17 +232,23 @@ def forward(self, emissions: Sequence, indices: CrfIndices = None) -> CrfDistrib else: raise NotImplementedError - emissions = emissions._replace(data=emissions.data.expand((t, c, -1))) + emissions = emissions.data.expand((t, c, -1)) transitions = 
self.transitions.expand((t, c, -1, -1)) head_transitions = self.head_transitions.expand((h, c, -1)) last_transitions = self.last_transitions.expand((h, c, -1)) + return emissions, (transitions, head_transitions, last_transitions) + + def forward(self, emissions: Sequence, indices: CrfIndices = None) -> CrfDistribution: if indices is None: indices = crf_indices(emissions=emissions) + emissions, transitions = self.forward_parameters(emissions=emissions) + return CrfDistribution( - log_potentials=emissions, indices=indices, - transitions=(transitions, head_transitions, last_transitions), + emissions=emissions, + transitions=transitions, + indices=indices, ) From 3331870ad1c7029a5a4cc803649c38faa5f7a767 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sun, 3 Apr 2022 21:45:22 +0900 Subject: [PATCH 027/102] Test: Add test_crf_catted_fit --- tests/test_crf2.py | 48 ++++++++++++++++++++++++++++++++++++++++++ torchlatent/crf2.py | 51 ++++++--------------------------------------- 2 files changed, 54 insertions(+), 45 deletions(-) create mode 100644 tests/test_crf2.py diff --git a/tests/test_crf2.py b/tests/test_crf2.py new file mode 100644 index 0000000..d846416 --- /dev/null +++ b/tests/test_crf2.py @@ -0,0 +1,48 @@ +import torch +import torchcrf +from hypothesis import given +from torch.testing import assert_close +from torchrua import cat_sequence, pad_sequence, pad_catted_indices + +from tests.strategies import devices, sizes, TOKEN_SIZE, TINY_BATCH_SIZE +from tests.utils import assert_grad_close +from torchlatent.crf2 import CrfDecoder + + +@given( + device=devices(), + token_sizes=sizes(TINY_BATCH_SIZE, TOKEN_SIZE), + num_tags=sizes(TOKEN_SIZE), +) +def test_crf_catted_fit(device, token_sizes, num_tags): + decoder1 = CrfDecoder(num_tags) + decoder2 = torchcrf.CRF(num_tags, batch_first=False) + + decoder1.transitions.data = decoder2.transitions[None, None, :, :] + decoder1.head_transitions.data = decoder2.start_transitions[None, None, :] + decoder1.last_transitions.data = decoder2.end_transitions[None, None, :] + + emissions = [ + torch.randn((token_size, num_tags), requires_grad=True, device=device) + for token_size in token_sizes + ] + targets = [ + torch.randint(0, num_tags, (token_size,), device=device) + for token_size in token_sizes + ] + + catted_emissions = cat_sequence([x[:, None] for x in emissions]) + catted_targets = cat_sequence([x[:, None] for x in targets]) + + padded_emissions, _ = pad_sequence(emissions, batch_first=False) + padded_targets, _ = pad_sequence(targets, batch_first=False) + + size, ptr = pad_catted_indices(catted_emissions.token_sizes, batch_first=False) + mask = torch.zeros(size, dtype=torch.bool, device=device) + mask[ptr] = True + + actual = decoder1.fit(catted_emissions, catted_targets)[:, 0] + excepted = decoder2.forward(padded_emissions, padded_targets, mask=mask, reduction='none').neg() + + assert_close(actual=actual, expected=excepted) + assert_grad_close(actual=actual, expected=excepted, inputs=emissions) diff --git a/torchlatent/crf2.py b/torchlatent/crf2.py index d90084c..bfb801b 100644 --- a/torchlatent/crf2.py +++ b/torchlatent/crf2.py @@ -3,7 +3,6 @@ from typing import Type import torch -import torchcrf from torch import Tensor from torch import nn from torch.distributions.utils import lazy_property @@ -12,8 +11,8 @@ from torch.types import Device from torchrua import ReductionIndices, reduce_catted_indices, reduce_packed_indices from torchrua import head_catted_indices, last_catted_indices, head_packed_indices, last_packed_indices, \ - 
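# Note on forward_parameters above: Tensor.expand only builds broadcast views,
# so the (t, c, ...) transition tables share storage with the (1, c, ...)
# parameters and no copies are made; a quick check (illustrative):
import torch

x = torch.zeros(1, 2, 3)
y = x.expand(4, 2, 3)
assert y.data_ptr() == x.data_ptr()  # same storage, no copy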
accumulate_sizes, pad_sequence -from torchrua import pad_catted_indices, cat_sequence, cat_packed_indices, roll_catted_indices, CattedSequence + accumulate_sizes +from torchrua import cat_packed_indices, roll_catted_indices, CattedSequence from torchlatent.abc import DistributionABC from torchlatent.semiring import Semiring, Log, Max @@ -251,46 +250,8 @@ def forward(self, emissions: Sequence, indices: CrfIndices = None) -> CrfDistrib indices=indices, ) + def fit(self, emissions: Sequence, targets: Sequence, indices: CrfIndices = None) -> Tensor: + dist = self.forward(emissions=emissions, indices=indices) + return dist.log_partitions - dist.log_scores(targets=targets) + -if __name__ == '__main__': - num_tags = 3 - - decoder1 = CrfDecoder(num_tags) - decoder2 = torchcrf.CRF(num_tags, batch_first=False) - - decoder1.transitions.data = decoder2.transitions[None, None, :, :] - decoder1.head_transitions.data = decoder2.start_transitions[None, None, :] - decoder1.last_transitions.data = decoder2.end_transitions[None, None, :] - - ss = [ - torch.randn((5, num_tags), requires_grad=True), - torch.randn((2, num_tags), requires_grad=True), - torch.randn((3, num_tags), requires_grad=True), - ] - tt = [ - torch.randint(0, num_tags, (5,)), - torch.randint(0, num_tags, (2,)), - torch.randint(0, num_tags, (3,)), - ] - - e1 = cat_sequence([s[:, None] for s in ss]) - t1 = cat_sequence([t[:, None] for t in tt]) - - e2, _ = pad_sequence(ss, batch_first=False) - t2, _ = pad_sequence(tt, batch_first=False) - size, ptr = pad_catted_indices(e1.token_sizes, batch_first=False) - mask = torch.zeros(size, dtype=torch.bool) - mask[ptr] = True - - dist = decoder1.forward(e1) - lhs = dist.log_partitions[:, 0] - rhs = decoder2._compute_normalizer(e2, mask=mask) - print(f'lhs => {lhs}') - print(f'rhs => {rhs}') - print(torch.allclose(lhs, rhs)) - - lhs = dist.log_scores(t1)[:, 0] - rhs = decoder2._compute_score(e2, t2, mask=mask) - print(f'lhs => {lhs}') - print(f'rhs => {rhs}') - print(torch.allclose(lhs, rhs)) From b984f41b46637326127e4207bf22e7693ecd84d9 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sun, 3 Apr 2022 21:46:43 +0900 Subject: [PATCH 028/102] Test: Add test_crf_packed_fit --- tests/test_crf2.py | 66 ++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 58 insertions(+), 8 deletions(-) diff --git a/tests/test_crf2.py b/tests/test_crf2.py index d846416..4ae8b74 100644 --- a/tests/test_crf2.py +++ b/tests/test_crf2.py @@ -2,7 +2,7 @@ import torchcrf from hypothesis import given from torch.testing import assert_close -from torchrua import cat_sequence, pad_sequence, pad_catted_indices +from torchrua import cat_sequence, pad_sequence, pad_catted_indices, pad_packed_indices, pack_sequence from tests.strategies import devices, sizes, TOKEN_SIZE, TINY_BATCH_SIZE from tests.utils import assert_grad_close @@ -15,12 +15,12 @@ num_tags=sizes(TOKEN_SIZE), ) def test_crf_catted_fit(device, token_sizes, num_tags): - decoder1 = CrfDecoder(num_tags) - decoder2 = torchcrf.CRF(num_tags, batch_first=False) + actual_decoder = CrfDecoder(num_tags) + excepted_decoder = torchcrf.CRF(num_tags, batch_first=False) - decoder1.transitions.data = decoder2.transitions[None, None, :, :] - decoder1.head_transitions.data = decoder2.start_transitions[None, None, :] - decoder1.last_transitions.data = decoder2.end_transitions[None, None, :] + actual_decoder.transitions.data = excepted_decoder.transitions[None, None, :, :] + actual_decoder.head_transitions.data = excepted_decoder.start_transitions[None, None, :] + 
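# Hypothetical end-to-end use of CrfDecoder.fit defined above (sizes and
# shapes are made up for illustration; emissions carry a conjugation axis,
# i.e. [token_size, num_conjugates, num_tags]):
import torch
from torchrua import cat_sequence

from torchlatent.crf2 import CrfDecoder, crf_indices

decoder = CrfDecoder(num_tags=5, num_conjugates=1)

emissions = cat_sequence([torch.randn(n, 1, 5, requires_grad=True) for n in (4, 2, 3)])
targets = cat_sequence([torch.randint(0, 5, (n, 1)) for n in (4, 2, 3)])

indices = crf_indices(emissions=emissions)  # build once, reuse across calls
loss = decoder.fit(emissions=emissions, targets=targets, indices=indices).mean()
loss.backward()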
actual_decoder.last_transitions.data = excepted_decoder.end_transitions[None, None, :] emissions = [ torch.randn((token_size, num_tags), requires_grad=True, device=device) @@ -41,8 +41,58 @@ def test_crf_catted_fit(device, token_sizes, num_tags): mask = torch.zeros(size, dtype=torch.bool, device=device) mask[ptr] = True - actual = decoder1.fit(catted_emissions, catted_targets)[:, 0] - excepted = decoder2.forward(padded_emissions, padded_targets, mask=mask, reduction='none').neg() + actual = actual_decoder.fit(emissions=catted_emissions, targets=catted_targets)[:, 0] + excepted = excepted_decoder.forward( + emissions=padded_emissions, tags=padded_targets.long(), + mask=mask.byte(), reduction='none', + ).neg() + + assert_close(actual=actual, expected=excepted) + assert_grad_close(actual=actual, expected=excepted, inputs=emissions) + + +@given( + device=devices(), + token_sizes=sizes(TINY_BATCH_SIZE, TOKEN_SIZE), + num_tags=sizes(TOKEN_SIZE), +) +def test_crf_packed_fit(device, token_sizes, num_tags): + actual_decoder = CrfDecoder(num_tags) + excepted_decoder = torchcrf.CRF(num_tags, batch_first=False) + + actual_decoder.transitions.data = excepted_decoder.transitions[None, None, :, :] + actual_decoder.head_transitions.data = excepted_decoder.start_transitions[None, None, :] + actual_decoder.last_transitions.data = excepted_decoder.end_transitions[None, None, :] + + emissions = [ + torch.randn((token_size, num_tags), requires_grad=True, device=device) + for token_size in token_sizes + ] + targets = [ + torch.randint(0, num_tags, (token_size,), device=device) + for token_size in token_sizes + ] + + packed_emissions = pack_sequence([x[:, None] for x in emissions]) + packed_targets = pack_sequence([x[:, None] for x in targets]) + + padded_emissions, _ = pad_sequence(emissions, batch_first=False) + padded_targets, _ = pad_sequence(targets, batch_first=False) + + size, ptr, _ = pad_packed_indices( + batch_sizes=packed_emissions.batch_sizes, + sorted_indices=packed_emissions.sorted_indices, + unsorted_indices=packed_emissions.unsorted_indices, + batch_first=False, + ) + mask = torch.zeros(size, dtype=torch.bool, device=device) + mask[ptr] = True + + actual = actual_decoder.fit(emissions=packed_emissions, targets=packed_targets)[:, 0] + excepted = excepted_decoder.forward( + emissions=padded_emissions, tags=padded_targets.long(), + mask=mask.byte(), reduction='none', + ).neg() assert_close(actual=actual, expected=excepted) assert_grad_close(actual=actual, expected=excepted, inputs=emissions) From e575cdb3de7053ce3a1bc23df19e590de4b4e0e0 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sun, 3 Apr 2022 22:01:18 +0900 Subject: [PATCH 029/102] Test: Add test_crf_catted_decode --- tests/test_crf2.py | 40 ++++++++++++++++++++++++++++++++++++++-- torchlatent/crf2.py | 14 +++++++++++--- 2 files changed, 49 insertions(+), 5 deletions(-) diff --git a/tests/test_crf2.py b/tests/test_crf2.py index 4ae8b74..63553c0 100644 --- a/tests/test_crf2.py +++ b/tests/test_crf2.py @@ -2,10 +2,11 @@ import torchcrf from hypothesis import given from torch.testing import assert_close -from torchrua import cat_sequence, pad_sequence, pad_catted_indices, pad_packed_indices, pack_sequence +from torchrua import cat_sequence, pad_sequence, pad_catted_indices, pad_packed_indices, pack_sequence, \ + pad_catted_sequence from tests.strategies import devices, sizes, TOKEN_SIZE, TINY_BATCH_SIZE -from tests.utils import assert_grad_close +from tests.utils import assert_grad_close, assert_equal from torchlatent.crf2 import CrfDecoder @@ 
-51,6 +52,41 @@ def test_crf_catted_fit(device, token_sizes, num_tags): assert_grad_close(actual=actual, expected=excepted, inputs=emissions) +@given( + device=devices(), + token_sizes=sizes(TINY_BATCH_SIZE, TOKEN_SIZE), + num_tags=sizes(TOKEN_SIZE), +) +def test_crf_catted_decode(device, token_sizes, num_tags): + actual_decoder = CrfDecoder(num_tags) + excepted_decoder = torchcrf.CRF(num_tags, batch_first=False) + + actual_decoder.transitions.data = excepted_decoder.transitions[None, None, :, :] + actual_decoder.head_transitions.data = excepted_decoder.start_transitions[None, None, :] + actual_decoder.last_transitions.data = excepted_decoder.end_transitions[None, None, :] + + emissions = [ + torch.randn((token_size, num_tags), requires_grad=True, device=device) + for token_size in token_sizes + ] + + catted_emissions = cat_sequence([x[:, None] for x in emissions]) + + padded_emissions, _ = pad_sequence(emissions, batch_first=False) + + size, ptr = pad_catted_indices(catted_emissions.token_sizes, batch_first=False) + mask = torch.zeros(size, dtype=torch.bool, device=device) + mask[ptr] = True + + actual, actual_token_sizes = actual_decoder.decode(emissions=catted_emissions) + + excepted = excepted_decoder.decode(emissions=padded_emissions, mask=mask.byte()) + excepted, excepted_token_sizes = cat_sequence([torch.tensor(x, device=device) for x in excepted]) + + assert_equal(actual=actual[:, 0], expected=excepted) + assert_equal(actual=actual_token_sizes, expected=excepted_token_sizes) + + @given( device=devices(), token_sizes=sizes(TINY_BATCH_SIZE, TOKEN_SIZE), diff --git a/torchlatent/crf2.py b/torchlatent/crf2.py index bfb801b..34ba7a7 100644 --- a/torchlatent/crf2.py +++ b/torchlatent/crf2.py @@ -1,4 +1,4 @@ -from typing import Sequence, NamedTuple +from typing import NamedTuple, Union from typing import Tuple from typing import Type @@ -10,13 +10,15 @@ from torch.nn.utils.rnn import PackedSequence from torch.types import Device from torchrua import ReductionIndices, reduce_catted_indices, reduce_packed_indices +from torchrua import cat_packed_indices, roll_catted_indices, CattedSequence from torchrua import head_catted_indices, last_catted_indices, head_packed_indices, last_packed_indices, \ accumulate_sizes -from torchrua import cat_packed_indices, roll_catted_indices, CattedSequence from torchlatent.abc import DistributionABC from torchlatent.semiring import Semiring, Log, Max +Sequence = Union[CattedSequence, PackedSequence] + class CrfIndices(NamedTuple): head: Tensor @@ -197,6 +199,10 @@ def max(self) -> Tensor: semiring=Max, ) + @lazy_property + def argmax(self) -> Tensor: + return super(CrfDistribution, self).argmax.argmax(dim=-1) + @lazy_property def entropy(self) -> Tensor: raise NotImplementedError @@ -254,4 +260,6 @@ def fit(self, emissions: Sequence, targets: Sequence, indices: CrfIndices = None dist = self.forward(emissions=emissions, indices=indices) return dist.log_partitions - dist.log_scores(targets=targets) - + def decode(self, emissions: Sequence, indices: CrfIndices = None) -> Sequence: + dist = self.forward(emissions=emissions, indices=indices) + return emissions._replace(data=dist.argmax) From 438abd5eebc9e45df482f43c133fd9bfa6b1c42b Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sun, 3 Apr 2022 22:03:49 +0900 Subject: [PATCH 030/102] Test: Add test_crf_packed_decode --- tests/test_crf2.py | 49 ++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 45 insertions(+), 4 deletions(-) diff --git a/tests/test_crf2.py b/tests/test_crf2.py index 63553c0..4441203 
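The argmax and decode added above lean on a standard trick: differentiate the Max-semiring partition with respect to the emissions, and the nonzero entries of the gradient mark the best-scoring path; decode then re-wraps the tag indices in the input sequence type via _replace. The trick in isolation, on hypothetical tensors rather than the library's code path:

    import torch

    scores = torch.randn((4, 5), requires_grad=True)  # per-token tag scores
    best = scores.max(dim=-1).values.sum()            # stand-in for the Max-semiring partition
    grad, = torch.autograd.grad(best, scores)         # one-hot rows marking the maximisers
    tags = grad.argmax(dim=-1)                        # best tag index per token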
100644 --- a/tests/test_crf2.py +++ b/tests/test_crf2.py @@ -2,8 +2,7 @@ import torchcrf from hypothesis import given from torch.testing import assert_close -from torchrua import cat_sequence, pad_sequence, pad_catted_indices, pad_packed_indices, pack_sequence, \ - pad_catted_sequence +from torchrua import cat_sequence, pad_sequence, pad_catted_indices, pad_packed_indices, pack_sequence from tests.strategies import devices, sizes, TOKEN_SIZE, TINY_BATCH_SIZE from tests.utils import assert_grad_close, assert_equal @@ -38,7 +37,7 @@ def test_crf_catted_fit(device, token_sizes, num_tags): padded_emissions, _ = pad_sequence(emissions, batch_first=False) padded_targets, _ = pad_sequence(targets, batch_first=False) - size, ptr = pad_catted_indices(catted_emissions.token_sizes, batch_first=False) + size, ptr = pad_catted_indices(token_sizes=catted_emissions.token_sizes, batch_first=False) mask = torch.zeros(size, dtype=torch.bool, device=device) mask[ptr] = True @@ -74,7 +73,7 @@ def test_crf_catted_decode(device, token_sizes, num_tags): padded_emissions, _ = pad_sequence(emissions, batch_first=False) - size, ptr = pad_catted_indices(catted_emissions.token_sizes, batch_first=False) + size, ptr = pad_catted_indices(token_sizes=catted_emissions.token_sizes, batch_first=False) mask = torch.zeros(size, dtype=torch.bool, device=device) mask[ptr] = True @@ -132,3 +131,45 @@ def test_crf_packed_fit(device, token_sizes, num_tags): assert_close(actual=actual, expected=excepted) assert_grad_close(actual=actual, expected=excepted, inputs=emissions) + + +@given( + device=devices(), + token_sizes=sizes(TINY_BATCH_SIZE, TOKEN_SIZE), + num_tags=sizes(TOKEN_SIZE), +) +def test_crf_packed_decode(device, token_sizes, num_tags): + actual_decoder = CrfDecoder(num_tags) + excepted_decoder = torchcrf.CRF(num_tags, batch_first=False) + + actual_decoder.transitions.data = excepted_decoder.transitions[None, None, :, :] + actual_decoder.head_transitions.data = excepted_decoder.start_transitions[None, None, :] + actual_decoder.last_transitions.data = excepted_decoder.end_transitions[None, None, :] + + emissions = [ + torch.randn((token_size, num_tags), requires_grad=True, device=device) + for token_size in token_sizes + ] + + packed_emissions = pack_sequence([x[:, None] for x in emissions]) + + padded_emissions, _ = pad_sequence(emissions, batch_first=False) + + size, ptr, _ = pad_packed_indices( + batch_sizes=packed_emissions.batch_sizes, + sorted_indices=packed_emissions.sorted_indices, + unsorted_indices=packed_emissions.unsorted_indices, + batch_first=False, + ) + mask = torch.zeros(size, dtype=torch.bool, device=device) + mask[ptr] = True + + actual = actual_decoder.decode(emissions=packed_emissions) + + excepted = excepted_decoder.decode(emissions=padded_emissions, mask=mask.byte()) + excepted = pack_sequence([torch.tensor(x, device=device) for x in excepted]) + + assert_equal(actual=actual.data[:, 0], expected=excepted.data) + assert_equal(actual=actual.batch_sizes, expected=excepted.batch_sizes) + assert_equal(actual=actual.sorted_indices, expected=excepted.sorted_indices) + assert_equal(actual=actual.unsorted_indices, expected=excepted.unsorted_indices) From 4f23f8e4c1711533451a8634b08d05d432009cde Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sun, 3 Apr 2022 22:16:27 +0900 Subject: [PATCH 031/102] Refactor: Add CrfDecoderABC --- torchlatent/crf2.py | 38 ++++++++++++++++++++++++++++---------- 1 file changed, 28 insertions(+), 10 deletions(-) diff --git a/torchlatent/crf2.py b/torchlatent/crf2.py index 
34ba7a7..ce0befc 100644 --- a/torchlatent/crf2.py +++ b/torchlatent/crf2.py @@ -7,12 +7,11 @@ from torch import nn from torch.distributions.utils import lazy_property from torch.nn import init -from torch.nn.utils.rnn import PackedSequence from torch.types import Device -from torchrua import ReductionIndices, reduce_catted_indices, reduce_packed_indices -from torchrua import cat_packed_indices, roll_catted_indices, CattedSequence -from torchrua import head_catted_indices, last_catted_indices, head_packed_indices, last_packed_indices, \ - accumulate_sizes +from torchrua import ReductionIndices, accumulate_sizes +from torchrua import head_catted_indices, last_catted_indices, reduce_catted_indices +from torchrua import head_packed_indices, last_packed_indices, reduce_packed_indices +from torchrua import roll_catted_indices, cat_packed_indices, CattedSequence, PackedSequence from torchlatent.abc import DistributionABC from torchlatent.semiring import Semiring, Log, Max @@ -129,18 +128,18 @@ def crf_indices(emissions: Sequence) -> CrfIndices: def crf_reduce(emissions: Tensor, targets: Tensor, transitions: Tuple[Tensor, Tensor, Tensor], indices: CrfIndices, semiring: Type[Semiring]) -> Tensor: - head, last, prev, curr, sizes, unsorted_indices, _ = indices + head, last, prev, curr, token_sizes, unsorted_indices, _ = indices transitions, head_transitions, last_transitions = transitions c = torch.arange(transitions.size()[1], device=emissions.device) emissions = emissions[curr[:, None], c[None, :], targets[curr]] transitions = transitions[curr[:, None], c[None, :], targets[prev], targets[curr]] - transitions[accumulate_sizes(sizes=sizes)] = semiring.one + transitions[accumulate_sizes(sizes=token_sizes)] = semiring.one head_transitions = head_transitions[unsorted_indices[:, None], c[None, :], targets[head]] last_transitions = last_transitions[unsorted_indices[:, None], c[None, :], targets[last]] - emissions = semiring.segment_prod(semiring.mul(emissions, transitions), sizes=sizes) + emissions = semiring.segment_prod(semiring.mul(emissions, transitions), sizes=token_sizes) return semiring.mul(emissions, semiring.mul(head_transitions, last_transitions)) @@ -208,7 +207,21 @@ def entropy(self) -> Tensor: raise NotImplementedError -class CrfDecoder(nn.Module): +class CrfDecoderABC(nn.Module): + def reset_parameters(self) -> None: + raise NotImplementedError + + def __repr__(self) -> str: + return f'{self.__class__.__name__}({self.extra_repr()})' + + def extra_repr(self) -> str: + return '' + + def forward_parameters(self, emissions: Sequence): + raise NotImplementedError + + +class CrfDecoder(CrfDecoderABC): def __init__(self, num_tags: int, num_conjugates: int = 1) -> None: super(CrfDecoder, self).__init__() @@ -221,12 +234,17 @@ def __init__(self, num_tags: int, num_conjugates: int = 1) -> None: self.reset_parameters() - @torch.no_grad() def reset_parameters(self) -> None: init.zeros_(self.transitions) init.zeros_(self.head_transitions) init.zeros_(self.last_transitions) + def extra_repr(self) -> str: + return ', '.join([ + f'num_tags={self.num_tags}', + f'num_conjugates={self.num_conjugates}', + ]) + def forward_parameters(self, emissions: Sequence): transitions = (self.transitions, self.head_transitions, self.last_transitions) From 3359ec1faa49cd77e4d903b2c9dffd861bce3f83 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sun, 3 Apr 2022 22:22:38 +0900 Subject: [PATCH 032/102] Feat: Add entropy --- torchlatent/crf2.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git 
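The entropy filled in by the diff below is the sum of per-token marginal entropies, accumulated per sequence: Log.segment_prod behaves as a plain segment sum here, since multiplication in the log semiring is addition. A hedged numeric sketch of the per-token term, with hypothetical marginals:

    import torch

    marginals = torch.tensor([[0.7, 0.3], [0.5, 0.5]])      # two tokens, two tags
    token_entropy = -(marginals * marginals.log()).sum(-1)  # -sum_y p(y) log p(y) per token
    sequence_entropy = token_entropy.sum()                  # segment sum over a one-sequence batch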
a/torchlatent/crf2.py b/torchlatent/crf2.py index ce0befc..07c775c 100644 --- a/torchlatent/crf2.py +++ b/torchlatent/crf2.py @@ -204,7 +204,11 @@ def argmax(self) -> Tensor: @lazy_property def entropy(self) -> Tensor: - raise NotImplementedError + tensor = (self.marginals * self.marginals.log()).sum(dim=-1) + return -Log.segment_prod( + tensor=tensor[self.indices.curr], + sizes=self.indices.token_sizes, + ) class CrfDecoderABC(nn.Module): @@ -253,7 +257,7 @@ def forward_parameters(self, emissions: Sequence): elif isinstance(emissions, PackedSequence): t, c, h = broadcast_packed_shapes(sequence=emissions, transitions=transitions) else: - raise NotImplementedError + raise TypeError(f'type {type(emissions)} is not supported') emissions = emissions.data.expand((t, c, -1)) transitions = self.transitions.expand((t, c, -1, -1)) @@ -268,11 +272,7 @@ def forward(self, emissions: Sequence, indices: CrfIndices = None) -> CrfDistrib emissions, transitions = self.forward_parameters(emissions=emissions) - return CrfDistribution( - emissions=emissions, - transitions=transitions, - indices=indices, - ) + return CrfDistribution(emissions=emissions, transitions=transitions, indices=indices) def fit(self, emissions: Sequence, targets: Sequence, indices: CrfIndices = None) -> Tensor: dist = self.forward(emissions=emissions, indices=indices) From 1bebae2badac1f0ffcac7a7c8e6be4483fb0f07a Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sun, 3 Apr 2022 22:40:11 +0900 Subject: [PATCH 033/102] Test: Add test_conjugated_*_fit --- tests/test_crf2.py | 86 ++++++++++++++++++++++++++++++++++++++++++++- torchlatent/crf2.py | 2 +- 2 files changed, 86 insertions(+), 2 deletions(-) diff --git a/tests/test_crf2.py b/tests/test_crf2.py index 4441203..7b7a73b 100644 --- a/tests/test_crf2.py +++ b/tests/test_crf2.py @@ -4,7 +4,7 @@ from torch.testing import assert_close from torchrua import cat_sequence, pad_sequence, pad_catted_indices, pad_packed_indices, pack_sequence -from tests.strategies import devices, sizes, TOKEN_SIZE, TINY_BATCH_SIZE +from tests.strategies import devices, sizes, TOKEN_SIZE, TINY_BATCH_SIZE, NUM_CONJUGATES, TINY_TOKEN_SIZE from tests.utils import assert_grad_close, assert_equal from torchlatent.crf2 import CrfDecoder @@ -173,3 +173,87 @@ def test_crf_packed_decode(device, token_sizes, num_tags): assert_equal(actual=actual.batch_sizes, expected=excepted.batch_sizes) assert_equal(actual=actual.sorted_indices, expected=excepted.sorted_indices) assert_equal(actual=actual.unsorted_indices, expected=excepted.unsorted_indices) + + +@given( + device=devices(), + token_sizes=sizes(TINY_BATCH_SIZE, TINY_TOKEN_SIZE), + num_conjugates=sizes(NUM_CONJUGATES), + num_tags=sizes(TINY_TOKEN_SIZE), +) +def test_conjugated_catted_fit(device, token_sizes, num_conjugates, num_tags): + decoder = CrfDecoder(num_tags=num_tags, num_conjugates=num_conjugates) + decoders = [CrfDecoder(num_tags=num_tags, num_conjugates=1) for _ in range(num_conjugates)] + + for index in range(num_conjugates): + decoder.transitions.data[:, index] = decoders[index].transitions + decoder.head_transitions.data[:, index] = decoders[index].head_transitions + decoder.last_transitions.data[:, index] = decoders[index].last_transitions + + emissions = [[ + torch.randn((token_size, 1, num_tags), requires_grad=True, device=device) + for token_size in token_sizes + ] for _ in range(num_conjugates)] + + targets = [[ + torch.randint(0, num_tags, (token_size, 1), device=device) + for token_size in token_sizes + ] for _ in range(num_conjugates)] + + actual = 
decoder.fit( + emissions=cat_sequence([torch.cat(sequences, dim=1) for sequences in zip(*emissions)], device=device), + targets=cat_sequence([torch.cat(sequences, dim=1) for sequences in zip(*targets)], device=device), + ) + + expected = torch.cat([ + decoders[index].fit( + emissions=cat_sequence(emissions[index], device=device), + targets=cat_sequence(targets[index], device=device), + ) + for index in range(num_conjugates) + ], dim=1) + + assert_close(actual=actual, expected=expected) + assert_grad_close(actual=actual, expected=expected, inputs=[x for xs in emissions for x in xs], check_stride=False) + + +@given( + device=devices(), + token_sizes=sizes(TINY_BATCH_SIZE, TINY_TOKEN_SIZE), + num_conjugates=sizes(NUM_CONJUGATES), + num_tags=sizes(TINY_TOKEN_SIZE), +) +def test_conjugated_packed_fit(device, token_sizes, num_conjugates, num_tags): + decoder = CrfDecoder(num_tags=num_tags, num_conjugates=num_conjugates) + decoders = [CrfDecoder(num_tags=num_tags, num_conjugates=1) for _ in range(num_conjugates)] + + for index in range(num_conjugates): + decoder.transitions.data[:, index] = decoders[index].transitions + decoder.head_transitions.data[:, index] = decoders[index].head_transitions + decoder.last_transitions.data[:, index] = decoders[index].last_transitions + + emissions = [[ + torch.randn((token_size, 1, num_tags), requires_grad=True, device=device) + for token_size in token_sizes + ] for _ in range(num_conjugates)] + + targets = [[ + torch.randint(0, num_tags, (token_size, 1), device=device) + for token_size in token_sizes + ] for _ in range(num_conjugates)] + + actual = decoder.fit( + emissions=pack_sequence([torch.cat(sequences, dim=1) for sequences in zip(*emissions)], device=device), + targets=pack_sequence([torch.cat(sequences, dim=1) for sequences in zip(*targets)], device=device), + ) + + expected = torch.cat([ + decoders[index].fit( + emissions=pack_sequence(emissions[index], device=device), + targets=pack_sequence(targets[index], device=device), + ) + for index in range(num_conjugates) + ], dim=1) + + assert_close(actual=actual, expected=expected) + assert_grad_close(actual=actual, expected=expected, inputs=[x for xs in emissions for x in xs], check_stride=False) diff --git a/torchlatent/crf2.py b/torchlatent/crf2.py index 07c775c..9ff9373 100644 --- a/torchlatent/crf2.py +++ b/torchlatent/crf2.py @@ -276,7 +276,7 @@ def forward(self, emissions: Sequence, indices: CrfIndices = None) -> CrfDistrib def fit(self, emissions: Sequence, targets: Sequence, indices: CrfIndices = None) -> Tensor: dist = self.forward(emissions=emissions, indices=indices) - return dist.log_partitions - dist.log_scores(targets=targets) + return dist.log_prob(targets=targets).neg() def decode(self, emissions: Sequence, indices: CrfIndices = None) -> Sequence: dist = self.forward(emissions=emissions, indices=indices) From d6122fe200e9085237e1571d7cfb1f4a635f4106 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sun, 3 Apr 2022 22:56:02 +0900 Subject: [PATCH 034/102] Refactor: Remove old crf.py --- tests/test_crf.py | 276 ++++++++++++++++++++++++-------- tests/test_crf2.py | 259 ------------------------------ torchlatent/{crf2.py => crf.py} | 0 torchlatent/crf/__init__.py | 127 --------------- torchlatent/crf/catting.py | 103 ------------ torchlatent/crf/packing.py | 120 -------------- 6 files changed, 209 insertions(+), 676 deletions(-) delete mode 100644 tests/test_crf2.py rename torchlatent/{crf2.py => crf.py} (100%) delete mode 100644 torchlatent/crf/__init__.py delete mode 100644 
torchlatent/crf/catting.py delete mode 100644 torchlatent/crf/packing.py diff --git a/tests/test_crf.py b/tests/test_crf.py index 9c3a0af..cc5f453 100644 --- a/tests/test_crf.py +++ b/tests/test_crf.py @@ -1,117 +1,259 @@ import torch +import torchcrf from hypothesis import given -from torchrua import pack_sequence, cat_sequence, pack_catted_sequence +from torch.testing import assert_close +from torchrua import cat_sequence, pad_sequence, pad_catted_indices, pad_packed_indices, pack_sequence -from tests.strategies import devices, sizes, BATCH_SIZE, TOKEN_SIZE, NUM_CONJUGATES, NUM_TAGS -from tests.utils import assert_close, assert_grad_close, assert_packed_sequence_equal -from third.crf import CrfDecoder as ThirdPartyCrfDecoder +from tests.strategies import devices, sizes, TOKEN_SIZE, TINY_BATCH_SIZE, NUM_CONJUGATES, TINY_TOKEN_SIZE +from tests.utils import assert_grad_close, assert_equal from torchlatent.crf import CrfDecoder @given( device=devices(), - token_sizes=sizes(BATCH_SIZE, TOKEN_SIZE), - num_conjugate=sizes(NUM_CONJUGATES), - num_tags=sizes(NUM_TAGS), + token_sizes=sizes(TINY_BATCH_SIZE, TOKEN_SIZE), + num_tags=sizes(TOKEN_SIZE), ) -def test_crf_packed_fit(device, token_sizes, num_conjugate, num_tags): - emissions = pack_sequence([ - torch.randn((token_size, num_conjugate, num_tags), device=device, requires_grad=True) - for token_size in token_sizes - ], device=device) +def test_crf_catted_fit(device, token_sizes, num_tags): + actual_decoder = CrfDecoder(num_tags) + excepted_decoder = torchcrf.CRF(num_tags, batch_first=False) + + actual_decoder.transitions.data = excepted_decoder.transitions[None, None, :, :] + actual_decoder.head_transitions.data = excepted_decoder.start_transitions[None, None, :] + actual_decoder.last_transitions.data = excepted_decoder.end_transitions[None, None, :] - tags = pack_sequence([ - torch.randint(0, num_tags, (token_size, num_conjugate), device=device) + emissions = [ + torch.randn((token_size, num_tags), requires_grad=True, device=device) for token_size in token_sizes - ], device=device) + ] + targets = [ + torch.randint(0, num_tags, (token_size,), device=device) + for token_size in token_sizes + ] - actual_decoder = CrfDecoder(num_tags=num_tags, num_conjugates=num_conjugate).to(device=device) - expected_decoder = ThirdPartyCrfDecoder(num_tags=num_tags, num_conjugates=num_conjugate).to(device=device) - expected_decoder.reset_parameters_with_(decoder=actual_decoder) + catted_emissions = cat_sequence([x[:, None] for x in emissions]) + catted_targets = cat_sequence([x[:, None] for x in targets]) - actual = actual_decoder.fit(emissions=emissions, tags=tags) - expected = expected_decoder.fit(emissions=emissions, tags=tags) + padded_emissions, _ = pad_sequence(emissions, batch_first=False) + padded_targets, _ = pad_sequence(targets, batch_first=False) - assert_close(actual=actual, expected=expected) - assert_grad_close(actual=actual, expected=expected, inputs=(emissions.data,)) + size, ptr = pad_catted_indices(token_sizes=catted_emissions.token_sizes, batch_first=False) + mask = torch.zeros(size, dtype=torch.bool, device=device) + mask[ptr] = True + + actual = actual_decoder.fit(emissions=catted_emissions, targets=catted_targets)[:, 0] + excepted = excepted_decoder.forward( + emissions=padded_emissions, tags=padded_targets.long(), + mask=mask.byte(), reduction='none', + ).neg() + + assert_close(actual=actual, expected=excepted) + assert_grad_close(actual=actual, expected=excepted, inputs=emissions) @given( device=devices(), - token_sizes=sizes(BATCH_SIZE, 
TOKEN_SIZE), - num_conjugate=sizes(NUM_CONJUGATES), - num_tags=sizes(NUM_TAGS), + token_sizes=sizes(TINY_BATCH_SIZE, TOKEN_SIZE), + num_tags=sizes(TOKEN_SIZE), ) -def test_crf_packed_decode(device, token_sizes, num_conjugate, num_tags): - emissions = pack_sequence([ - torch.randn((token_size, num_conjugate, num_tags), device=device, requires_grad=True) +def test_crf_catted_decode(device, token_sizes, num_tags): + actual_decoder = CrfDecoder(num_tags) + excepted_decoder = torchcrf.CRF(num_tags, batch_first=False) + + actual_decoder.transitions.data = excepted_decoder.transitions[None, None, :, :] + actual_decoder.head_transitions.data = excepted_decoder.start_transitions[None, None, :] + actual_decoder.last_transitions.data = excepted_decoder.end_transitions[None, None, :] + + emissions = [ + torch.randn((token_size, num_tags), requires_grad=True, device=device) for token_size in token_sizes - ], device=device) + ] - actual_decoder = CrfDecoder(num_tags=num_tags, num_conjugates=num_conjugate).to(device=device) - expected_decoder = ThirdPartyCrfDecoder(num_tags=num_tags, num_conjugates=num_conjugate).to(device=device) - expected_decoder.reset_parameters_with_(decoder=actual_decoder) + catted_emissions = cat_sequence([x[:, None] for x in emissions]) - expected = expected_decoder.decode(emissions=emissions) - actual = actual_decoder.decode(emissions=emissions) + padded_emissions, _ = pad_sequence(emissions, batch_first=False) - assert_packed_sequence_equal(actual=actual, expected=expected) + size, ptr = pad_catted_indices(token_sizes=catted_emissions.token_sizes, batch_first=False) + mask = torch.zeros(size, dtype=torch.bool, device=device) + mask[ptr] = True + + actual, actual_token_sizes = actual_decoder.decode(emissions=catted_emissions) + + excepted = excepted_decoder.decode(emissions=padded_emissions, mask=mask.byte()) + excepted, excepted_token_sizes = cat_sequence([torch.tensor(x, device=device) for x in excepted]) + + assert_equal(actual=actual[:, 0], expected=excepted) + assert_equal(actual=actual_token_sizes, expected=excepted_token_sizes) @given( device=devices(), - token_sizes=sizes(BATCH_SIZE, TOKEN_SIZE), - num_conjugate=sizes(NUM_CONJUGATES), - num_tags=sizes(NUM_TAGS), + token_sizes=sizes(TINY_BATCH_SIZE, TOKEN_SIZE), + num_tags=sizes(TOKEN_SIZE), ) -def test_crf_catted_fit(device, token_sizes, num_conjugate, num_tags): +def test_crf_packed_fit(device, token_sizes, num_tags): + actual_decoder = CrfDecoder(num_tags) + excepted_decoder = torchcrf.CRF(num_tags, batch_first=False) + + actual_decoder.transitions.data = excepted_decoder.transitions[None, None, :, :] + actual_decoder.head_transitions.data = excepted_decoder.start_transitions[None, None, :] + actual_decoder.last_transitions.data = excepted_decoder.end_transitions[None, None, :] + emissions = [ - torch.randn((token_size, num_conjugate, num_tags), device=device, requires_grad=True) + torch.randn((token_size, num_tags), requires_grad=True, device=device) for token_size in token_sizes ] - tags = [ - torch.randint(0, num_tags, (token_size, num_conjugate), device=device) + targets = [ + torch.randint(0, num_tags, (token_size,), device=device) for token_size in token_sizes ] - catted_emissions = cat_sequence(emissions, device=device) - packed_emissions = pack_sequence(emissions, device=device) + packed_emissions = pack_sequence([x[:, None] for x in emissions]) + packed_targets = pack_sequence([x[:, None] for x in targets]) - catted_tags = cat_sequence(tags, device=device) - packed_tags = pack_sequence(tags, device=device) + 
padded_emissions, _ = pad_sequence(emissions, batch_first=False) + padded_targets, _ = pad_sequence(targets, batch_first=False) - actual_decoder = CrfDecoder(num_tags=num_tags, num_conjugates=num_conjugate).to(device=device) - expected_decoder = ThirdPartyCrfDecoder(num_tags=num_tags, num_conjugates=num_conjugate).to(device=device) - expected_decoder.reset_parameters_with_(decoder=actual_decoder) + size, ptr, _ = pad_packed_indices( + batch_sizes=packed_emissions.batch_sizes, + sorted_indices=packed_emissions.sorted_indices, + unsorted_indices=packed_emissions.unsorted_indices, + batch_first=False, + ) + mask = torch.zeros(size, dtype=torch.bool, device=device) + mask[ptr] = True - actual = actual_decoder.fit(emissions=catted_emissions, tags=catted_tags) - expected = expected_decoder.fit(emissions=packed_emissions, tags=packed_tags) + actual = actual_decoder.fit(emissions=packed_emissions, targets=packed_targets)[:, 0] + excepted = excepted_decoder.forward( + emissions=padded_emissions, tags=padded_targets.long(), + mask=mask.byte(), reduction='none', + ).neg() - assert_close(actual=actual, expected=expected) - assert_grad_close(actual=actual, expected=expected, inputs=tuple(emissions)) + assert_close(actual=actual, expected=excepted) + assert_grad_close(actual=actual, expected=excepted, inputs=emissions) @given( device=devices(), - token_sizes=sizes(BATCH_SIZE, TOKEN_SIZE), - num_conjugate=sizes(NUM_CONJUGATES), - num_tags=sizes(NUM_TAGS), + token_sizes=sizes(TINY_BATCH_SIZE, TOKEN_SIZE), + num_tags=sizes(TOKEN_SIZE), ) -def test_crf_catted_decode(device, token_sizes, num_conjugate, num_tags): +def test_crf_packed_decode(device, token_sizes, num_tags): + actual_decoder = CrfDecoder(num_tags) + excepted_decoder = torchcrf.CRF(num_tags, batch_first=False) + + actual_decoder.transitions.data = excepted_decoder.transitions[None, None, :, :] + actual_decoder.head_transitions.data = excepted_decoder.start_transitions[None, None, :] + actual_decoder.last_transitions.data = excepted_decoder.end_transitions[None, None, :] + emissions = [ - torch.randn((token_size, num_conjugate, num_tags), device=device, requires_grad=True) + torch.randn((token_size, num_tags), requires_grad=True, device=device) for token_size in token_sizes ] - catted_emissions = cat_sequence(emissions, device=device) - packed_emissions = pack_sequence(emissions, device=device) + packed_emissions = pack_sequence([x[:, None] for x in emissions]) + + padded_emissions, _ = pad_sequence(emissions, batch_first=False) + + size, ptr, _ = pad_packed_indices( + batch_sizes=packed_emissions.batch_sizes, + sorted_indices=packed_emissions.sorted_indices, + unsorted_indices=packed_emissions.unsorted_indices, + batch_first=False, + ) + mask = torch.zeros(size, dtype=torch.bool, device=device) + mask[ptr] = True - actual_decoder = CrfDecoder(num_tags=num_tags, num_conjugates=num_conjugate).to(device=device) - expected_decoder = ThirdPartyCrfDecoder(num_tags=num_tags, num_conjugates=num_conjugate).to(device=device) - expected_decoder.reset_parameters_with_(decoder=actual_decoder) + actual = actual_decoder.decode(emissions=packed_emissions) - expected = expected_decoder.decode(emissions=packed_emissions) - actual = actual_decoder.decode(emissions=catted_emissions) - actual = pack_catted_sequence(*actual, device=device) + excepted = excepted_decoder.decode(emissions=padded_emissions, mask=mask.byte()) + excepted = pack_sequence([torch.tensor(x, device=device) for x in excepted]) - assert_packed_sequence_equal(actual=actual, expected=expected) + 
assert_equal(actual=actual.data[:, 0], expected=excepted.data) + assert_equal(actual=actual.batch_sizes, expected=excepted.batch_sizes) + assert_equal(actual=actual.sorted_indices, expected=excepted.sorted_indices) + assert_equal(actual=actual.unsorted_indices, expected=excepted.unsorted_indices) + + +@given( + device=devices(), + token_sizes=sizes(TINY_BATCH_SIZE, TINY_TOKEN_SIZE), + num_conjugates=sizes(NUM_CONJUGATES), + num_tags=sizes(TINY_TOKEN_SIZE), +) +def test_conjugated_catted_fit(device, token_sizes, num_conjugates, num_tags): + decoder = CrfDecoder(num_tags=num_tags, num_conjugates=num_conjugates) + decoders = [CrfDecoder(num_tags=num_tags, num_conjugates=1) for _ in range(num_conjugates)] + + for index in range(num_conjugates): + decoder.transitions.data[:, index] = decoders[index].transitions + decoder.head_transitions.data[:, index] = decoders[index].head_transitions + decoder.last_transitions.data[:, index] = decoders[index].last_transitions + + emissions = [[ + torch.randn((token_size, 1, num_tags), requires_grad=True, device=device) + for token_size in token_sizes + ] for _ in range(num_conjugates)] + + targets = [[ + torch.randint(0, num_tags, (token_size, 1), device=device) + for token_size in token_sizes + ] for _ in range(num_conjugates)] + + actual = decoder.fit( + emissions=cat_sequence([torch.cat(sequences, dim=1) for sequences in zip(*emissions)], device=device), + targets=cat_sequence([torch.cat(sequences, dim=1) for sequences in zip(*targets)], device=device), + ) + + expected = torch.cat([ + decoders[index].fit( + emissions=cat_sequence(emissions[index], device=device), + targets=cat_sequence(targets[index], device=device), + ) + for index in range(num_conjugates) + ], dim=1) + + assert_close(actual=actual, expected=expected) + assert_grad_close(actual=actual, expected=expected, inputs=[x for xs in emissions for x in xs], check_stride=False) + + +@given( + device=devices(), + token_sizes=sizes(TINY_BATCH_SIZE, TINY_TOKEN_SIZE), + num_conjugates=sizes(NUM_CONJUGATES), + num_tags=sizes(TINY_TOKEN_SIZE), +) +def test_conjugated_packed_fit(device, token_sizes, num_conjugates, num_tags): + decoder = CrfDecoder(num_tags=num_tags, num_conjugates=num_conjugates) + decoders = [CrfDecoder(num_tags=num_tags, num_conjugates=1) for _ in range(num_conjugates)] + + for index in range(num_conjugates): + decoder.transitions.data[:, index] = decoders[index].transitions + decoder.head_transitions.data[:, index] = decoders[index].head_transitions + decoder.last_transitions.data[:, index] = decoders[index].last_transitions + + emissions = [[ + torch.randn((token_size, 1, num_tags), requires_grad=True, device=device) + for token_size in token_sizes + ] for _ in range(num_conjugates)] + + targets = [[ + torch.randint(0, num_tags, (token_size, 1), device=device) + for token_size in token_sizes + ] for _ in range(num_conjugates)] + + actual = decoder.fit( + emissions=pack_sequence([torch.cat(sequences, dim=1) for sequences in zip(*emissions)], device=device), + targets=pack_sequence([torch.cat(sequences, dim=1) for sequences in zip(*targets)], device=device), + ) + + expected = torch.cat([ + decoders[index].fit( + emissions=pack_sequence(emissions[index], device=device), + targets=pack_sequence(targets[index], device=device), + ) + for index in range(num_conjugates) + ], dim=1) + + assert_close(actual=actual, expected=expected) + assert_grad_close(actual=actual, expected=expected, inputs=[x for xs in emissions for x in xs], check_stride=False) diff --git a/tests/test_crf2.py 
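This patch promotes the crf2 scaffolding to the real module: torchlatent/crf2.py is renamed to torchlatent/crf.py, while the old torchlatent/crf/ package and tests/test_crf2.py are deleted in full below. Downstream code keeps the same class under a new import path:

    from torchlatent.crf import CrfDecoder  # previously: from torchlatent.crf2 import CrfDecoder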
b/tests/test_crf2.py deleted file mode 100644 index 7b7a73b..0000000 --- a/tests/test_crf2.py +++ /dev/null @@ -1,259 +0,0 @@ -import torch -import torchcrf -from hypothesis import given -from torch.testing import assert_close -from torchrua import cat_sequence, pad_sequence, pad_catted_indices, pad_packed_indices, pack_sequence - -from tests.strategies import devices, sizes, TOKEN_SIZE, TINY_BATCH_SIZE, NUM_CONJUGATES, TINY_TOKEN_SIZE -from tests.utils import assert_grad_close, assert_equal -from torchlatent.crf2 import CrfDecoder - - -@given( - device=devices(), - token_sizes=sizes(TINY_BATCH_SIZE, TOKEN_SIZE), - num_tags=sizes(TOKEN_SIZE), -) -def test_crf_catted_fit(device, token_sizes, num_tags): - actual_decoder = CrfDecoder(num_tags) - excepted_decoder = torchcrf.CRF(num_tags, batch_first=False) - - actual_decoder.transitions.data = excepted_decoder.transitions[None, None, :, :] - actual_decoder.head_transitions.data = excepted_decoder.start_transitions[None, None, :] - actual_decoder.last_transitions.data = excepted_decoder.end_transitions[None, None, :] - - emissions = [ - torch.randn((token_size, num_tags), requires_grad=True, device=device) - for token_size in token_sizes - ] - targets = [ - torch.randint(0, num_tags, (token_size,), device=device) - for token_size in token_sizes - ] - - catted_emissions = cat_sequence([x[:, None] for x in emissions]) - catted_targets = cat_sequence([x[:, None] for x in targets]) - - padded_emissions, _ = pad_sequence(emissions, batch_first=False) - padded_targets, _ = pad_sequence(targets, batch_first=False) - - size, ptr = pad_catted_indices(token_sizes=catted_emissions.token_sizes, batch_first=False) - mask = torch.zeros(size, dtype=torch.bool, device=device) - mask[ptr] = True - - actual = actual_decoder.fit(emissions=catted_emissions, targets=catted_targets)[:, 0] - excepted = excepted_decoder.forward( - emissions=padded_emissions, tags=padded_targets.long(), - mask=mask.byte(), reduction='none', - ).neg() - - assert_close(actual=actual, expected=excepted) - assert_grad_close(actual=actual, expected=excepted, inputs=emissions) - - -@given( - device=devices(), - token_sizes=sizes(TINY_BATCH_SIZE, TOKEN_SIZE), - num_tags=sizes(TOKEN_SIZE), -) -def test_crf_catted_decode(device, token_sizes, num_tags): - actual_decoder = CrfDecoder(num_tags) - excepted_decoder = torchcrf.CRF(num_tags, batch_first=False) - - actual_decoder.transitions.data = excepted_decoder.transitions[None, None, :, :] - actual_decoder.head_transitions.data = excepted_decoder.start_transitions[None, None, :] - actual_decoder.last_transitions.data = excepted_decoder.end_transitions[None, None, :] - - emissions = [ - torch.randn((token_size, num_tags), requires_grad=True, device=device) - for token_size in token_sizes - ] - - catted_emissions = cat_sequence([x[:, None] for x in emissions]) - - padded_emissions, _ = pad_sequence(emissions, batch_first=False) - - size, ptr = pad_catted_indices(token_sizes=catted_emissions.token_sizes, batch_first=False) - mask = torch.zeros(size, dtype=torch.bool, device=device) - mask[ptr] = True - - actual, actual_token_sizes = actual_decoder.decode(emissions=catted_emissions) - - excepted = excepted_decoder.decode(emissions=padded_emissions, mask=mask.byte()) - excepted, excepted_token_sizes = cat_sequence([torch.tensor(x, device=device) for x in excepted]) - - assert_equal(actual=actual[:, 0], expected=excepted) - assert_equal(actual=actual_token_sizes, expected=excepted_token_sizes) - - -@given( - device=devices(), - 
token_sizes=sizes(TINY_BATCH_SIZE, TOKEN_SIZE), - num_tags=sizes(TOKEN_SIZE), -) -def test_crf_packed_fit(device, token_sizes, num_tags): - actual_decoder = CrfDecoder(num_tags) - excepted_decoder = torchcrf.CRF(num_tags, batch_first=False) - - actual_decoder.transitions.data = excepted_decoder.transitions[None, None, :, :] - actual_decoder.head_transitions.data = excepted_decoder.start_transitions[None, None, :] - actual_decoder.last_transitions.data = excepted_decoder.end_transitions[None, None, :] - - emissions = [ - torch.randn((token_size, num_tags), requires_grad=True, device=device) - for token_size in token_sizes - ] - targets = [ - torch.randint(0, num_tags, (token_size,), device=device) - for token_size in token_sizes - ] - - packed_emissions = pack_sequence([x[:, None] for x in emissions]) - packed_targets = pack_sequence([x[:, None] for x in targets]) - - padded_emissions, _ = pad_sequence(emissions, batch_first=False) - padded_targets, _ = pad_sequence(targets, batch_first=False) - - size, ptr, _ = pad_packed_indices( - batch_sizes=packed_emissions.batch_sizes, - sorted_indices=packed_emissions.sorted_indices, - unsorted_indices=packed_emissions.unsorted_indices, - batch_first=False, - ) - mask = torch.zeros(size, dtype=torch.bool, device=device) - mask[ptr] = True - - actual = actual_decoder.fit(emissions=packed_emissions, targets=packed_targets)[:, 0] - excepted = excepted_decoder.forward( - emissions=padded_emissions, tags=padded_targets.long(), - mask=mask.byte(), reduction='none', - ).neg() - - assert_close(actual=actual, expected=excepted) - assert_grad_close(actual=actual, expected=excepted, inputs=emissions) - - -@given( - device=devices(), - token_sizes=sizes(TINY_BATCH_SIZE, TOKEN_SIZE), - num_tags=sizes(TOKEN_SIZE), -) -def test_crf_packed_decode(device, token_sizes, num_tags): - actual_decoder = CrfDecoder(num_tags) - excepted_decoder = torchcrf.CRF(num_tags, batch_first=False) - - actual_decoder.transitions.data = excepted_decoder.transitions[None, None, :, :] - actual_decoder.head_transitions.data = excepted_decoder.start_transitions[None, None, :] - actual_decoder.last_transitions.data = excepted_decoder.end_transitions[None, None, :] - - emissions = [ - torch.randn((token_size, num_tags), requires_grad=True, device=device) - for token_size in token_sizes - ] - - packed_emissions = pack_sequence([x[:, None] for x in emissions]) - - padded_emissions, _ = pad_sequence(emissions, batch_first=False) - - size, ptr, _ = pad_packed_indices( - batch_sizes=packed_emissions.batch_sizes, - sorted_indices=packed_emissions.sorted_indices, - unsorted_indices=packed_emissions.unsorted_indices, - batch_first=False, - ) - mask = torch.zeros(size, dtype=torch.bool, device=device) - mask[ptr] = True - - actual = actual_decoder.decode(emissions=packed_emissions) - - excepted = excepted_decoder.decode(emissions=padded_emissions, mask=mask.byte()) - excepted = pack_sequence([torch.tensor(x, device=device) for x in excepted]) - - assert_equal(actual=actual.data[:, 0], expected=excepted.data) - assert_equal(actual=actual.batch_sizes, expected=excepted.batch_sizes) - assert_equal(actual=actual.sorted_indices, expected=excepted.sorted_indices) - assert_equal(actual=actual.unsorted_indices, expected=excepted.unsorted_indices) - - -@given( - device=devices(), - token_sizes=sizes(TINY_BATCH_SIZE, TINY_TOKEN_SIZE), - num_conjugates=sizes(NUM_CONJUGATES), - num_tags=sizes(TINY_TOKEN_SIZE), -) -def test_conjugated_catted_fit(device, token_sizes, num_conjugates, num_tags): - decoder = 
CrfDecoder(num_tags=num_tags, num_conjugates=num_conjugates) - decoders = [CrfDecoder(num_tags=num_tags, num_conjugates=1) for _ in range(num_conjugates)] - - for index in range(num_conjugates): - decoder.transitions.data[:, index] = decoders[index].transitions - decoder.head_transitions.data[:, index] = decoders[index].head_transitions - decoder.last_transitions.data[:, index] = decoders[index].last_transitions - - emissions = [[ - torch.randn((token_size, 1, num_tags), requires_grad=True, device=device) - for token_size in token_sizes - ] for _ in range(num_conjugates)] - - targets = [[ - torch.randint(0, num_tags, (token_size, 1), device=device) - for token_size in token_sizes - ] for _ in range(num_conjugates)] - - actual = decoder.fit( - emissions=cat_sequence([torch.cat(sequences, dim=1) for sequences in zip(*emissions)], device=device), - targets=cat_sequence([torch.cat(sequences, dim=1) for sequences in zip(*targets)], device=device), - ) - - expected = torch.cat([ - decoders[index].fit( - emissions=cat_sequence(emissions[index], device=device), - targets=cat_sequence(targets[index], device=device), - ) - for index in range(num_conjugates) - ], dim=1) - - assert_close(actual=actual, expected=expected) - assert_grad_close(actual=actual, expected=expected, inputs=[x for xs in emissions for x in xs], check_stride=False) - - -@given( - device=devices(), - token_sizes=sizes(TINY_BATCH_SIZE, TINY_TOKEN_SIZE), - num_conjugates=sizes(NUM_CONJUGATES), - num_tags=sizes(TINY_TOKEN_SIZE), -) -def test_conjugated_packed_fit(device, token_sizes, num_conjugates, num_tags): - decoder = CrfDecoder(num_tags=num_tags, num_conjugates=num_conjugates) - decoders = [CrfDecoder(num_tags=num_tags, num_conjugates=1) for _ in range(num_conjugates)] - - for index in range(num_conjugates): - decoder.transitions.data[:, index] = decoders[index].transitions - decoder.head_transitions.data[:, index] = decoders[index].head_transitions - decoder.last_transitions.data[:, index] = decoders[index].last_transitions - - emissions = [[ - torch.randn((token_size, 1, num_tags), requires_grad=True, device=device) - for token_size in token_sizes - ] for _ in range(num_conjugates)] - - targets = [[ - torch.randint(0, num_tags, (token_size, 1), device=device) - for token_size in token_sizes - ] for _ in range(num_conjugates)] - - actual = decoder.fit( - emissions=pack_sequence([torch.cat(sequences, dim=1) for sequences in zip(*emissions)], device=device), - targets=pack_sequence([torch.cat(sequences, dim=1) for sequences in zip(*targets)], device=device), - ) - - expected = torch.cat([ - decoders[index].fit( - emissions=pack_sequence(emissions[index], device=device), - targets=pack_sequence(targets[index], device=device), - ) - for index in range(num_conjugates) - ], dim=1) - - assert_close(actual=actual, expected=expected) - assert_grad_close(actual=actual, expected=expected, inputs=[x for xs in emissions for x in xs], check_stride=False) diff --git a/torchlatent/crf2.py b/torchlatent/crf.py similarity index 100% rename from torchlatent/crf2.py rename to torchlatent/crf.py diff --git a/torchlatent/crf/__init__.py b/torchlatent/crf/__init__.py deleted file mode 100644 index fb5bf78..0000000 --- a/torchlatent/crf/__init__.py +++ /dev/null @@ -1,127 +0,0 @@ -from abc import ABCMeta -from typing import Optional, Tuple, Union - -import torch -from torch import Tensor -from torch import nn -from torch.nn import init -from torchrua import ReductionIndices, PackedSequence, CattedSequence -from torchrua import reduce_packed_indices, 
reduce_catted_indices - -from torchlatent.crf.catting import CattedCrfDistribution -from torchlatent.crf.packing import PackedCrfDistribution - -__all__ = [ - 'CrfDecoderABC', 'CrfDecoder', - 'PackedCrfDistribution', - 'CattedCrfDistribution', - 'Sequence', -] - -Sequence = Union[ - PackedSequence, - CattedSequence, -] - - -class CrfDecoderABC(nn.Module, metaclass=ABCMeta): - def __init__(self, num_tags: int, num_conjugates: int) -> None: - super(CrfDecoderABC, self).__init__() - - self.num_tags = num_tags - self.num_conjugates = num_conjugates - - def reset_parameters(self) -> None: - raise NotImplementedError - - def extra_repr(self) -> str: - return ', '.join([ - f'num_tags={self.num_tags}', - f'num_conjugates={self.num_conjugates}', - ]) - - @staticmethod - def compile_indices(emissions: Sequence, - tags: Optional[Sequence] = None, - indices: Optional[ReductionIndices] = None, **kwargs): - assert emissions.data.dim() == 3, f'{emissions.data.dim()} != {3}' - if tags is not None: - assert tags.data.dim() == 2, f'{tags.data.dim()} != {2}' - - if indices is None: - if isinstance(emissions, PackedSequence): - return reduce_packed_indices( - batch_sizes=emissions.batch_sizes, - device=emissions.data.device, - ) - - if isinstance(emissions, CattedSequence): - return reduce_catted_indices( - token_sizes=emissions.token_sizes, - device=emissions.data.device, - ) - - return indices - - def obtain_parameters(self, **kwargs) -> Tuple[Tensor, Tensor, Tensor]: - return self.transitions, self.head_transitions, self.last_transitions - - def forward(self, emissions: Sequence, tags: Optional[Sequence] = None, - indices: Optional[ReductionIndices] = None, **kwargs): - indices = self.compile_indices(emissions=emissions, tags=tags, indices=indices) - transitions, head_transitions, last_transitions = self.obtain_parameters( - emissions=emissions, tags=tags, indices=indices, - ) - - if isinstance(emissions, PackedSequence): - dist = PackedCrfDistribution( - emissions=emissions, indices=indices, - transitions=transitions, - head_transitions=head_transitions, - last_transitions=last_transitions, - ) - return dist, tags - - if isinstance(emissions, CattedSequence): - dist = CattedCrfDistribution( - emissions=emissions, indices=indices, - transitions=transitions, - head_transitions=head_transitions, - last_transitions=last_transitions, - ) - return dist, tags - - raise TypeError(f'{type(emissions)} is not supported.') - - def fit(self, emissions: Sequence, tags: Sequence, - indices: Optional[ReductionIndices] = None, **kwargs) -> Tensor: - dist, tags = self(emissions=emissions, tags=tags, instr=indices, **kwargs) - - return dist.log_prob(tags=tags) - - def decode(self, emissions: Sequence, - indices: Optional[ReductionIndices] = None, **kwargs) -> Sequence: - dist, _ = self(emissions=emissions, tags=None, instr=indices, **kwargs) - return dist.argmax - - def marginals(self, emissions: Sequence, - indices: Optional[ReductionIndices] = None, **kwargs) -> Tensor: - dist, _ = self(emissions=emissions, tags=None, instr=indices, **kwargs) - return dist.marginals - - -class CrfDecoder(CrfDecoderABC): - def __init__(self, num_tags: int, num_conjugates: int = 1) -> None: - super(CrfDecoder, self).__init__(num_tags=num_tags, num_conjugates=num_conjugates) - - self.transitions = nn.Parameter(torch.empty((1, self.num_conjugates, self.num_tags, self.num_tags))) - self.head_transitions = nn.Parameter(torch.empty((1, self.num_conjugates, self.num_tags))) - self.last_transitions = nn.Parameter(torch.empty((1, 
self.num_conjugates, self.num_tags))) - - self.reset_parameters() - - @torch.no_grad() - def reset_parameters(self, bound: float = 0.01) -> None: - init.uniform_(self.transitions, -bound, +bound) - init.uniform_(self.head_transitions, -bound, +bound) - init.uniform_(self.last_transitions, -bound, +bound) diff --git a/torchlatent/crf/catting.py b/torchlatent/crf/catting.py deleted file mode 100644 index 618f4ae..0000000 --- a/torchlatent/crf/catting.py +++ /dev/null @@ -1,103 +0,0 @@ -from typing import Type - -import torch -from torch import Tensor, autograd -from torch.distributions.utils import lazy_property -from torchrua import CattedSequence -from torchrua import ReductionIndices, head_catted_indices - -from torchlatent.crf2 import crf_reduce, crf_partition -from torchlatent.semiring import Semiring, Log, Max - -__all__ = [ - 'compute_catted_sequence_scores', - 'compute_catted_sequence_partitions', - 'CattedCrfDistribution', -] - - -def compute_catted_sequence_scores(semiring: Type[Semiring]): - def _compute_catted_sequence_scores( - emissions: CattedSequence, tags: CattedSequence, - transitions: Tensor, head_transitions: Tensor, last_transitions: Tensor) -> Tensor: - return crf_reduce( - emissions=emissions, targets=tags, - transitions=(transitions, head_transitions, last_transitions), - semiring=semiring, - ) - - return _compute_catted_sequence_scores - - -def compute_catted_sequence_partitions(semiring: Type[Semiring]): - def _compute_catted_sequence_partitions( - emissions: CattedSequence, indices: ReductionIndices, - transitions: Tensor, head_transitions: Tensor, last_transitions: Tensor, eye: Tensor) -> Tensor: - return crf_partition( - emissions=emissions, indices=indices, - transitions=(transitions, head_transitions, last_transitions), - semiring=semiring, - ) - - return _compute_catted_sequence_partitions - - -class CattedCrfDistribution(object): - def __init__(self, emissions: CattedSequence, indices: ReductionIndices, - transitions: Tensor, head_transitions: Tensor, last_transitions: Tensor) -> None: - super(CattedCrfDistribution, self).__init__() - self.emissions = emissions - self.indices = indices - - self.transitions = transitions - self.head_transitions = head_transitions - self.last_transitions = last_transitions - - def semiring_scores(self, semiring: Type[Semiring], tags: CattedSequence) -> Tensor: - return compute_catted_sequence_scores(semiring=semiring)( - emissions=self.emissions, tags=tags, - transitions=self.transitions, - head_transitions=self.head_transitions, - last_transitions=self.last_transitions, - ) - - def semiring_partitions(self, semiring: Type[Semiring]) -> Tensor: - return compute_catted_sequence_partitions(semiring=semiring)( - emissions=self.emissions, indices=self.indices, - transitions=self.transitions, - head_transitions=self.head_transitions, - last_transitions=self.last_transitions, - eye=semiring.eye_like(self.transitions), - ) - - def log_prob(self, tags: CattedSequence) -> Tensor: - return self.log_scores(tags=tags) - self.log_partitions - - def log_scores(self, tags: CattedSequence) -> Tensor: - return self.semiring_scores(semiring=Log, tags=tags) - - @lazy_property - def log_partitions(self) -> Tensor: - return self.semiring_partitions(semiring=Log) - - @lazy_property - def marginals(self) -> Tensor: - log_partitions = self.log_partitions - grad, = autograd.grad( - log_partitions, self.emissions.data, torch.ones_like(log_partitions), - create_graph=True, only_inputs=True, allow_unused=False, - ) - return grad - - @lazy_property - def 
argmax(self) -> CattedSequence: - max_partitions = self.semiring_partitions(semiring=Max) - - grad, = torch.autograd.grad( - max_partitions, self.emissions.data, torch.ones_like(max_partitions), - retain_graph=False, create_graph=False, allow_unused=False, - ) - return CattedSequence( - data=grad.argmax(dim=-1), - token_sizes=self.emissions.token_sizes, - ) diff --git a/torchlatent/crf/packing.py b/torchlatent/crf/packing.py deleted file mode 100644 index 21beac9..0000000 --- a/torchlatent/crf/packing.py +++ /dev/null @@ -1,120 +0,0 @@ -from typing import Type - -import torch -from torch import Tensor, autograd -from torch.distributions.utils import lazy_property -from torch.nn.utils.rnn import PackedSequence -from torchrua import ReductionIndices - -from torchlatent.crf2 import crf_reduce -from torchlatent.semiring import Semiring, Log, Max - -__all__ = [ - 'compute_packed_sequence_scores', - 'compute_packed_sequence_partitions', - 'PackedCrfDistribution', -] - - -def compute_packed_sequence_scores(semiring: Type[Semiring]): - def _compute_packed_sequence_scores( - emissions: PackedSequence, tags: PackedSequence, - transitions: Tensor, head_transitions: Tensor, last_transitions: Tensor) -> Tensor: - return crf_reduce( - emissions=emissions, - targets=tags, - transitions=(transitions, head_transitions, last_transitions), - semiring=semiring, - ) - - return _compute_packed_sequence_scores - - -def compute_packed_sequence_partitions(semiring: Type[Semiring]): - def _compute_packed_sequence_partitions( - emissions: PackedSequence, indices: ReductionIndices, - transitions: Tensor, head_transitions: Tensor, last_transitions: Tensor, eye: Tensor) -> Tensor: - h = emissions.batch_sizes[0].item() - t = torch.arange(transitions.size()[0], device=transitions.device) # [t] - c = torch.arange(transitions.size()[1], device=transitions.device) # [c] - - emission_scores = semiring.mul(transitions, emissions.data[..., None, :]) # [t, c, n, n] - emission_scores[:h] = eye[None, None, :, :] - emission_scores = semiring.reduce(tensor=emission_scores, indices=indices) - - emission_head_scores = emissions.data[:h, :, None, :] - transition_head_scores = head_transitions[t[:h, None], c[None, :], None, :] - transition_last_scores = last_transitions[t[:h, None], c[None, :], :, None] - - scores = semiring.mul(transition_head_scores, emission_head_scores) - scores = semiring.bmm(scores, emission_scores) - scores = semiring.bmm(scores, transition_last_scores)[..., 0, 0] - - if emissions.unsorted_indices is not None: - scores = scores[emissions.unsorted_indices] - return scores - - return _compute_packed_sequence_partitions - - -class PackedCrfDistribution(object): - def __init__(self, emissions: PackedSequence, indices: ReductionIndices, - transitions: Tensor, head_transitions: Tensor, last_transitions: Tensor) -> None: - super(PackedCrfDistribution, self).__init__() - self.emissions = emissions - self.indices = indices - - self.transitions = transitions - self.head_transitions = head_transitions - self.last_transitions = last_transitions - - def semiring_scores(self, semiring: Type[Semiring], tags: PackedSequence) -> Tensor: - return compute_packed_sequence_scores(semiring=semiring)( - emissions=self.emissions, tags=tags, - transitions=self.transitions, - head_transitions=self.head_transitions, - last_transitions=self.last_transitions, - ) - - def semiring_partitions(self, semiring: Type[Semiring]) -> Tensor: - return compute_packed_sequence_partitions(semiring=semiring)( - emissions=self.emissions, 
indices=self.indices, - transitions=self.transitions, - head_transitions=self.head_transitions, - last_transitions=self.last_transitions, - eye=semiring.eye_like(self.transitions), - ) - - def log_prob(self, tags: PackedSequence) -> Tensor: - return self.log_scores(tags=tags) - self.log_partitions - - def log_scores(self, tags: PackedSequence) -> Tensor: - return self.semiring_scores(semiring=Log, tags=tags) - - @lazy_property - def log_partitions(self) -> Tensor: - return self.semiring_partitions(semiring=Log) - - @lazy_property - def marginals(self) -> Tensor: - log_partitions = self.log_partitions - grad, = autograd.grad( - log_partitions, self.emissions.data, torch.ones_like(log_partitions), - create_graph=True, only_inputs=True, allow_unused=False, - ) - return grad - - @lazy_property - def argmax(self) -> PackedSequence: - max_partitions = self.semiring_partitions(semiring=Max) - - grad, = torch.autograd.grad( - max_partitions, self.emissions.data, torch.ones_like(max_partitions), - retain_graph=False, create_graph=False, allow_unused=False, - ) - return PackedSequence( - data=grad.argmax(dim=-1), - batch_sizes=self.emissions.batch_sizes, - sorted_indices=self.emissions.sorted_indices, - unsorted_indices=self.emissions.unsorted_indices, - ) From 490e232a2d60e8dc6d3eed295c4ab0b1a4ddb89d Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sun, 3 Apr 2022 23:01:17 +0900 Subject: [PATCH 035/102] Test: Init transitions w/ randn --- tests/test_crf.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/tests/test_crf.py b/tests/test_crf.py index cc5f453..d6bd4fe 100644 --- a/tests/test_crf.py +++ b/tests/test_crf.py @@ -18,6 +18,10 @@ def test_crf_catted_fit(device, token_sizes, num_tags): actual_decoder = CrfDecoder(num_tags) excepted_decoder = torchcrf.CRF(num_tags, batch_first=False) + excepted_decoder.transitions.data = torch.randn_like(excepted_decoder.transitions) + excepted_decoder.start_transitions.data = torch.randn_like(excepted_decoder.start_transitions) + excepted_decoder.end_transitions.data = torch.randn_like(excepted_decoder.end_transitions) + actual_decoder.transitions.data = excepted_decoder.transitions[None, None, :, :] actual_decoder.head_transitions.data = excepted_decoder.start_transitions[None, None, :] actual_decoder.last_transitions.data = excepted_decoder.end_transitions[None, None, :] @@ -60,6 +64,10 @@ def test_crf_catted_decode(device, token_sizes, num_tags): actual_decoder = CrfDecoder(num_tags) excepted_decoder = torchcrf.CRF(num_tags, batch_first=False) + excepted_decoder.transitions.data = torch.randn_like(excepted_decoder.transitions) + excepted_decoder.start_transitions.data = torch.randn_like(excepted_decoder.start_transitions) + excepted_decoder.end_transitions.data = torch.randn_like(excepted_decoder.end_transitions) + actual_decoder.transitions.data = excepted_decoder.transitions[None, None, :, :] actual_decoder.head_transitions.data = excepted_decoder.start_transitions[None, None, :] actual_decoder.last_transitions.data = excepted_decoder.end_transitions[None, None, :] @@ -95,6 +103,10 @@ def test_crf_packed_fit(device, token_sizes, num_tags): actual_decoder = CrfDecoder(num_tags) excepted_decoder = torchcrf.CRF(num_tags, batch_first=False) + excepted_decoder.transitions.data = torch.randn_like(excepted_decoder.transitions) + excepted_decoder.start_transitions.data = torch.randn_like(excepted_decoder.start_transitions) + excepted_decoder.end_transitions.data = torch.randn_like(excepted_decoder.end_transitions) + 
actual_decoder.transitions.data = excepted_decoder.transitions[None, None, :, :] actual_decoder.head_transitions.data = excepted_decoder.start_transitions[None, None, :] actual_decoder.last_transitions.data = excepted_decoder.end_transitions[None, None, :] @@ -142,6 +154,10 @@ def test_crf_packed_decode(device, token_sizes, num_tags): actual_decoder = CrfDecoder(num_tags) excepted_decoder = torchcrf.CRF(num_tags, batch_first=False) + excepted_decoder.transitions.data = torch.randn_like(excepted_decoder.transitions) + excepted_decoder.start_transitions.data = torch.randn_like(excepted_decoder.start_transitions) + excepted_decoder.end_transitions.data = torch.randn_like(excepted_decoder.end_transitions) + actual_decoder.transitions.data = excepted_decoder.transitions[None, None, :, :] actual_decoder.head_transitions.data = excepted_decoder.start_transitions[None, None, :] actual_decoder.last_transitions.data = excepted_decoder.end_transitions[None, None, :] @@ -186,6 +202,10 @@ def test_conjugated_catted_fit(device, token_sizes, num_conjugates, num_tags): decoders = [CrfDecoder(num_tags=num_tags, num_conjugates=1) for _ in range(num_conjugates)] for index in range(num_conjugates): + decoders[index].transitions.data = torch.randn_like(decoders[index].transitions) + decoders[index].head_transitions.data = torch.randn_like(decoders[index].head_transitions) + decoders[index].last_transitions.data = torch.randn_like(decoders[index].last_transitions) + decoder.transitions.data[:, index] = decoders[index].transitions decoder.head_transitions.data[:, index] = decoders[index].head_transitions decoder.last_transitions.data[:, index] = decoders[index].last_transitions @@ -228,6 +248,10 @@ def test_conjugated_packed_fit(device, token_sizes, num_conjugates, num_tags): decoders = [CrfDecoder(num_tags=num_tags, num_conjugates=1) for _ in range(num_conjugates)] for index in range(num_conjugates): + decoders[index].transitions.data = torch.randn_like(decoders[index].transitions) + decoders[index].head_transitions.data = torch.randn_like(decoders[index].head_transitions) + decoders[index].last_transitions.data = torch.randn_like(decoders[index].last_transitions) + decoder.transitions.data[:, index] = decoders[index].transitions decoder.head_transitions.data[:, index] = decoders[index].head_transitions decoder.last_transitions.data[:, index] = decoders[index].last_transitions From adf78318571d9522de838575118ecfdcef8710e4 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sun, 3 Apr 2022 23:10:01 +0900 Subject: [PATCH 036/102] Test: Add test_dynamic_fit --- tests/test_crf.py | 48 ++++++++++++++++++++++++++++++++++++++++++++++- tests/utils.py | 2 ++ 2 files changed, 49 insertions(+), 1 deletion(-) diff --git a/tests/test_crf.py b/tests/test_crf.py index d6bd4fe..f32bd27 100644 --- a/tests/test_crf.py +++ b/tests/test_crf.py @@ -2,7 +2,8 @@ import torchcrf from hypothesis import given from torch.testing import assert_close -from torchrua import cat_sequence, pad_sequence, pad_catted_indices, pad_packed_indices, pack_sequence +from torchrua import cat_sequence, pad_catted_indices, pack_catted_indices +from torchrua import pad_sequence, pad_packed_indices, pack_sequence from tests.strategies import devices, sizes, TOKEN_SIZE, TINY_BATCH_SIZE, NUM_CONJUGATES, TINY_TOKEN_SIZE from tests.utils import assert_grad_close, assert_equal @@ -281,3 +282,48 @@ def test_conjugated_packed_fit(device, token_sizes, num_conjugates, num_tags): assert_close(actual=actual, expected=expected) assert_grad_close(actual=actual, 
expected=expected, inputs=[x for xs in emissions for x in xs], check_stride=False)
+
+
+@given(
+    device=devices(),
+    token_sizes=sizes(TINY_BATCH_SIZE, TINY_TOKEN_SIZE),
+    num_conjugates=sizes(NUM_CONJUGATES),
+    num_tags=sizes(TINY_TOKEN_SIZE),
+)
+def test_dynamic_fit(device, token_sizes, num_conjugates, num_tags):
+    packed_decoder = CrfDecoder(num_tags=num_tags, num_conjugates=num_conjugates)
+    catted_decoder = CrfDecoder(num_tags=num_tags, num_conjugates=num_conjugates)
+
+    emissions = [
+        torch.randn((token_size, 1, num_tags), requires_grad=True, device=device)
+        for token_size in token_sizes
+    ]
+
+    targets = [
+        torch.randint(0, num_tags, (token_size, 1), device=device)
+        for token_size in token_sizes
+    ]
+
+    catted_decoder.transitions.data = torch.randn((sum(token_sizes), num_conjugates, num_tags, num_tags), device=device)
+    catted_decoder.head_transitions.data = torch.randn((len(token_sizes), num_conjugates, num_tags), device=device)
+    catted_decoder.last_transitions.data = torch.randn((len(token_sizes), num_conjugates, num_tags), device=device)
+
+    token_sizes = torch.tensor(token_sizes, device=device)
+    indices, _, sorted_indices, _ = pack_catted_indices(token_sizes=token_sizes, device=device)
+
+    packed_decoder.transitions.data = catted_decoder.transitions[indices]
+    packed_decoder.head_transitions.data = catted_decoder.head_transitions[sorted_indices]
+    packed_decoder.last_transitions.data = catted_decoder.last_transitions[sorted_indices]
+
+    packed_fit = packed_decoder.fit(
+        emissions=pack_sequence(emissions, device=device),
+        targets=pack_sequence(targets, device=device),
+    )
+
+    catted_fit = catted_decoder.fit(
+        emissions=cat_sequence(emissions, device=device),
+        targets=cat_sequence(targets, device=device),
+    )
+
+    assert_close(actual=packed_fit, expected=catted_fit)
+    assert_grad_close(actual=packed_fit, expected=catted_fit, inputs=emissions)
diff --git a/tests/utils.py b/tests/utils.py
index 5981653..a369ec9 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -30,12 +30,14 @@ def assert_grad_close(
         actual, inputs, grad,
         create_graph=False,
         allow_unused=allow_unused,
+        retain_graph=True,
    )
 
     expected_grads = torch.autograd.grad(
         expected, inputs, grad,
         create_graph=False,
         allow_unused=allow_unused,
+        retain_graph=True,
     )
 
     for actual_grad, expected_grad in zip(actual_grads, expected_grads):
From ed380d5d40bbeb318817255109bfefc2976a12f9 Mon Sep 17 00:00:00 2001
From: speedcell4
Date: Sun, 3 Apr 2022 23:26:07 +0900
Subject: [PATCH 037/102] Test: Update rtol and atol

---
 tests/test_cky.py | 132 +++++++++++++++++++++++-----------------------
 tests/test_crf.py |  26 +++++----
 2 files changed, 83 insertions(+), 75 deletions(-)

diff --git a/tests/test_cky.py b/tests/test_cky.py
index b2fbcd7..168ac83 100644
--- a/tests/test_cky.py
+++ b/tests/test_cky.py
@@ -1,66 +1,66 @@
-import torch
-from hypothesis import given, strategies as st
-from torch.testing import assert_close
-from torch_struct import TreeCRF
-from torchrua import pack_sequence, cat_sequence
-
-from tests.strategies import sizes, BATCH_SIZE, TOKEN_SIZE, EMBEDDING_DIM, devices, TINY_BATCH_SIZE
-from tests.utils import assert_grad_close
-from torchlatent.cky import CkyDistribution, cky_indices, CkyDecoder
-
-
-@given(
-    device=devices(),
-    token_sizes=sizes(TINY_BATCH_SIZE, TOKEN_SIZE),
-    embedding_dim=sizes(EMBEDDING_DIM),
-    num_tags=sizes(TOKEN_SIZE),
-    bias=st.booleans(),
-)
-def test_cky_catted_max(device, token_sizes, embedding_dim, num_tags, bias):
-    sequence = cat_sequence([
-        torch.randn((token_size, embedding_dim), requires_grad=True, device=device)
-        for 
token_size in token_sizes - ]) - - decoder = CkyDecoder(in_features=embedding_dim, out_features=num_tags, bias=bias) - cky = decoder.forward(sequence=sequence) - - assert_close(actual=cky.max, expected=cky.log_scores(decoder.decode(sequence=sequence))) - - -@given( - device=devices(), - token_sizes=sizes(TINY_BATCH_SIZE, TOKEN_SIZE), - embedding_dim=sizes(EMBEDDING_DIM), - num_tags=sizes(TOKEN_SIZE), - bias=st.booleans(), -) -def test_cky_packed_max(device, token_sizes, embedding_dim, num_tags, bias): - sequence = pack_sequence([ - torch.randn((token_size, embedding_dim), requires_grad=True, device=device) - for token_size in token_sizes - ]) - - decoder = CkyDecoder(in_features=embedding_dim, out_features=num_tags, bias=bias) - cky = decoder.forward(sequence=sequence) - - assert_close(actual=cky.max, expected=cky.log_scores(decoder.decode(sequence=sequence))) - - -@given( - device=devices(), - token_sizes=sizes(BATCH_SIZE, TOKEN_SIZE), - num_tags=sizes(TOKEN_SIZE), -) -def test_cky_log_partitions(device, token_sizes, num_tags): - scores = torch.randn( - (len(token_sizes), max(token_sizes), max(token_sizes), num_tags), - requires_grad=True, device=device, - ) - token_sizes = torch.tensor(token_sizes, device=device) - - excepted = TreeCRF(log_potentials=scores, lengths=token_sizes) - actual = CkyDistribution(log_potentials=scores, indices=cky_indices(token_sizes=token_sizes, device=device)) - - assert_close(actual=actual.log_partitions, expected=excepted.partition) - assert_grad_close(actual=actual.log_partitions, expected=excepted.partition, inputs=scores, rtol=1e-5, atol=1e-5) +# import torch +# from hypothesis import given, strategies as st +# from torch.testing import assert_close +# from torch_struct import TreeCRF +# from torchrua import pack_sequence, cat_sequence +# +# from tests.strategies import sizes, BATCH_SIZE, TOKEN_SIZE, EMBEDDING_DIM, devices, TINY_BATCH_SIZE +# from tests.utils import assert_grad_close +# from torchlatent.cky import CkyDistribution, cky_indices, CkyDecoder +# +# +# @given( +# device=devices(), +# token_sizes=sizes(TINY_BATCH_SIZE, TOKEN_SIZE), +# embedding_dim=sizes(EMBEDDING_DIM), +# num_tags=sizes(TOKEN_SIZE), +# bias=st.booleans(), +# ) +# def test_cky_catted_max(device, token_sizes, embedding_dim, num_tags, bias): +# sequence = cat_sequence([ +# torch.randn((token_size, embedding_dim), requires_grad=True, device=device) +# for token_size in token_sizes +# ]) +# +# decoder = CkyDecoder(in_features=embedding_dim, out_features=num_tags, bias=bias) +# cky = decoder.forward(sequence=sequence) +# +# assert_close(actual=cky.max, expected=cky.log_scores(decoder.decode(sequence=sequence))) +# +# +# @given( +# device=devices(), +# token_sizes=sizes(TINY_BATCH_SIZE, TOKEN_SIZE), +# embedding_dim=sizes(EMBEDDING_DIM), +# num_tags=sizes(TOKEN_SIZE), +# bias=st.booleans(), +# ) +# def test_cky_packed_max(device, token_sizes, embedding_dim, num_tags, bias): +# sequence = pack_sequence([ +# torch.randn((token_size, embedding_dim), requires_grad=True, device=device) +# for token_size in token_sizes +# ]) +# +# decoder = CkyDecoder(in_features=embedding_dim, out_features=num_tags, bias=bias) +# cky = decoder.forward(sequence=sequence) +# +# assert_close(actual=cky.max, expected=cky.log_scores(decoder.decode(sequence=sequence))) +# +# +# @given( +# device=devices(), +# token_sizes=sizes(BATCH_SIZE, TOKEN_SIZE), +# num_tags=sizes(TOKEN_SIZE), +# ) +# def test_cky_log_partitions(device, token_sizes, num_tags): +# scores = torch.randn( +# (len(token_sizes), max(token_sizes), 
max(token_sizes), num_tags),
+#         requires_grad=True, device=device,
+#     )
+#     token_sizes = torch.tensor(token_sizes, device=device)
+#
+#     excepted = TreeCRF(log_potentials=scores, lengths=token_sizes)
+#     actual = CkyDistribution(log_potentials=scores, indices=cky_indices(token_sizes=token_sizes, device=device))
+#
+#     assert_close(actual=actual.log_partitions, expected=excepted.partition)
+#     assert_grad_close(actual=actual.log_partitions, expected=excepted.partition, inputs=scores, rtol=1e-5, atol=1e-5)
diff --git a/tests/test_crf.py b/tests/test_crf.py
index f32bd27..e8a4c9c 100644
--- a/tests/test_crf.py
+++ b/tests/test_crf.py
@@ -52,8 +52,8 @@ def test_crf_catted_fit(device, token_sizes, num_tags):
         mask=mask.byte(), reduction='none',
     ).neg()
 
-    assert_close(actual=actual, expected=excepted)
-    assert_grad_close(actual=actual, expected=excepted, inputs=emissions)
+    assert_close(actual=actual, expected=excepted, rtol=1e-4, atol=1e-4)
+    assert_grad_close(actual=actual, expected=excepted, inputs=emissions, rtol=1e-4, atol=1e-4)
 
 
 @given(
@@ -142,8 +142,8 @@ def test_crf_packed_fit(device, token_sizes, num_tags):
         mask=mask.byte(), reduction='none',
     ).neg()
 
-    assert_close(actual=actual, expected=excepted)
-    assert_grad_close(actual=actual, expected=excepted, inputs=emissions)
+    assert_close(actual=actual, expected=excepted, rtol=1e-4, atol=1e-4)
+    assert_grad_close(actual=actual, expected=excepted, inputs=emissions, rtol=1e-4, atol=1e-4)
 
 
 @given(
@@ -234,8 +234,12 @@ def test_conjugated_catted_fit(device, token_sizes, num_conjugates, num_tags):
         for index in range(num_conjugates)
     ], dim=1)
 
-    assert_close(actual=actual, expected=expected)
-    assert_grad_close(actual=actual, expected=expected, inputs=[x for xs in emissions for x in xs], check_stride=False)
+    assert_close(actual=actual, expected=expected, rtol=1e-4, atol=1e-4)
+    assert_grad_close(
+        actual=actual, expected=expected,
+        inputs=[x for xs in emissions for x in xs],
+        rtol=1e-4, atol=1e-4, check_stride=False,
+    )
 
 
 @given(
@@ -281,7 +285,11 @@ def test_conjugated_packed_fit(device, token_sizes, num_conjugates, num_tags):
     ], dim=1)
 
     assert_close(actual=actual, expected=expected)
-    assert_grad_close(actual=actual, expected=expected, inputs=[x for xs in emissions for x in xs], check_stride=False)
+    assert_grad_close(
+        actual=actual, expected=expected,
+        inputs=[x for xs in emissions for x in xs],
+        rtol=1e-4, atol=1e-4, check_stride=False,
+    )
 
 
 @given(
@@ -325,5 +333,5 @@ def test_dynamic_fit(device, token_sizes, num_conjugates, num_tags):
         targets=cat_sequence(targets, device=device),
     )
 
-    assert_close(actual=packed_fit, expected=catted_fit)
-    assert_grad_close(actual=packed_fit, expected=catted_fit, inputs=emissions)
+    assert_close(actual=packed_fit, expected=catted_fit, rtol=1e-4, atol=1e-4)
+    assert_grad_close(actual=packed_fit, expected=catted_fit, inputs=emissions, rtol=1e-4, atol=1e-4)
From 4b054766250877d6be14452b598b8386cd3de759 Mon Sep 17 00:00:00 2001
From: speedcell4
Date: Mon, 4 Apr 2022 09:20:02 +0900
Subject: [PATCH 038/102] Test: Update device strategy

---
 tests/strategies.py      | 14 +++++---------
 tests/test_cky.py        | 11 ++++-------
 tests/test_crf.py        | 23 ++++++++---------------
 tests/test_functional.py |  8 +++-----
 4 files changed, 20 insertions(+), 36 deletions(-)

diff --git a/tests/strategies.py b/tests/strategies.py
index 980e6da..c1f8975 100644
--- a/tests/strategies.py
+++ b/tests/strategies.py
@@ -12,15 +12,11 @@
 EMBEDDING_DIM = 25
 
-
-@st.composite
-def devices(draw):
-    if not torch.cuda.is_available():
-        device = 
torch.device('cpu') - else: - device = torch.device('cuda:0') - _ = torch.empty((1,), device=device) - return device +if torch.cuda.is_available(): + device = torch.device('cuda:0') +else: + device = torch.device('cpu') +_ = torch.empty((1,), device=device) @st.composite diff --git a/tests/test_cky.py b/tests/test_cky.py index 168ac83..e1c11b1 100644 --- a/tests/test_cky.py +++ b/tests/test_cky.py @@ -4,19 +4,18 @@ # from torch_struct import TreeCRF # from torchrua import pack_sequence, cat_sequence # -# from tests.strategies import sizes, BATCH_SIZE, TOKEN_SIZE, EMBEDDING_DIM, devices, TINY_BATCH_SIZE +# from tests.strategies import sizes, BATCH_SIZE, TOKEN_SIZE, EMBEDDING_DIM, device, TINY_BATCH_SIZE # from tests.utils import assert_grad_close # from torchlatent.cky import CkyDistribution, cky_indices, CkyDecoder # # # @given( -# device=devices(), # token_sizes=sizes(TINY_BATCH_SIZE, TOKEN_SIZE), # embedding_dim=sizes(EMBEDDING_DIM), # num_tags=sizes(TOKEN_SIZE), # bias=st.booleans(), # ) -# def test_cky_catted_max(device, token_sizes, embedding_dim, num_tags, bias): +# def test_cky_catted_max(token_sizes, embedding_dim, num_tags, bias): # sequence = cat_sequence([ # torch.randn((token_size, embedding_dim), requires_grad=True, device=device) # for token_size in token_sizes @@ -29,13 +28,12 @@ # # # @given( -# device=devices(), # token_sizes=sizes(TINY_BATCH_SIZE, TOKEN_SIZE), # embedding_dim=sizes(EMBEDDING_DIM), # num_tags=sizes(TOKEN_SIZE), # bias=st.booleans(), # ) -# def test_cky_packed_max(device, token_sizes, embedding_dim, num_tags, bias): +# def test_cky_packed_max(token_sizes, embedding_dim, num_tags, bias): # sequence = pack_sequence([ # torch.randn((token_size, embedding_dim), requires_grad=True, device=device) # for token_size in token_sizes @@ -48,11 +46,10 @@ # # # @given( -# device=devices(), # token_sizes=sizes(BATCH_SIZE, TOKEN_SIZE), # num_tags=sizes(TOKEN_SIZE), # ) -# def test_cky_log_partitions(device, token_sizes, num_tags): +# def test_cky_log_partitions(token_sizes, num_tags): # scores = torch.randn( # (len(token_sizes), max(token_sizes), max(token_sizes), num_tags), # requires_grad=True, device=device, diff --git a/tests/test_crf.py b/tests/test_crf.py index e8a4c9c..57251cf 100644 --- a/tests/test_crf.py +++ b/tests/test_crf.py @@ -5,17 +5,16 @@ from torchrua import cat_sequence, pad_catted_indices, pack_catted_indices from torchrua import pad_sequence, pad_packed_indices, pack_sequence -from tests.strategies import devices, sizes, TOKEN_SIZE, TINY_BATCH_SIZE, NUM_CONJUGATES, TINY_TOKEN_SIZE +from tests.strategies import device, sizes, TOKEN_SIZE, TINY_BATCH_SIZE, NUM_CONJUGATES, TINY_TOKEN_SIZE from tests.utils import assert_grad_close, assert_equal from torchlatent.crf import CrfDecoder @given( - device=devices(), token_sizes=sizes(TINY_BATCH_SIZE, TOKEN_SIZE), num_tags=sizes(TOKEN_SIZE), ) -def test_crf_catted_fit(device, token_sizes, num_tags): +def test_crf_catted_fit(token_sizes, num_tags): actual_decoder = CrfDecoder(num_tags) excepted_decoder = torchcrf.CRF(num_tags, batch_first=False) @@ -57,11 +56,10 @@ def test_crf_catted_fit(device, token_sizes, num_tags): @given( - device=devices(), token_sizes=sizes(TINY_BATCH_SIZE, TOKEN_SIZE), num_tags=sizes(TOKEN_SIZE), ) -def test_crf_catted_decode(device, token_sizes, num_tags): +def test_crf_catted_decode(token_sizes, num_tags): actual_decoder = CrfDecoder(num_tags) excepted_decoder = torchcrf.CRF(num_tags, batch_first=False) @@ -96,11 +94,10 @@ def test_crf_catted_decode(device, token_sizes, num_tags): 
@given( - device=devices(), token_sizes=sizes(TINY_BATCH_SIZE, TOKEN_SIZE), num_tags=sizes(TOKEN_SIZE), ) -def test_crf_packed_fit(device, token_sizes, num_tags): +def test_crf_packed_fit(token_sizes, num_tags): actual_decoder = CrfDecoder(num_tags) excepted_decoder = torchcrf.CRF(num_tags, batch_first=False) @@ -147,11 +144,10 @@ def test_crf_packed_fit(device, token_sizes, num_tags): @given( - device=devices(), token_sizes=sizes(TINY_BATCH_SIZE, TOKEN_SIZE), num_tags=sizes(TOKEN_SIZE), ) -def test_crf_packed_decode(device, token_sizes, num_tags): +def test_crf_packed_decode(token_sizes, num_tags): actual_decoder = CrfDecoder(num_tags) excepted_decoder = torchcrf.CRF(num_tags, batch_first=False) @@ -193,12 +189,11 @@ def test_crf_packed_decode(device, token_sizes, num_tags): @given( - device=devices(), token_sizes=sizes(TINY_BATCH_SIZE, TINY_TOKEN_SIZE), num_conjugates=sizes(NUM_CONJUGATES), num_tags=sizes(TINY_TOKEN_SIZE), ) -def test_conjugated_catted_fit(device, token_sizes, num_conjugates, num_tags): +def test_conjugated_catted_fit(token_sizes, num_conjugates, num_tags): decoder = CrfDecoder(num_tags=num_tags, num_conjugates=num_conjugates) decoders = [CrfDecoder(num_tags=num_tags, num_conjugates=1) for _ in range(num_conjugates)] @@ -243,12 +238,11 @@ def test_conjugated_catted_fit(device, token_sizes, num_conjugates, num_tags): @given( - device=devices(), token_sizes=sizes(TINY_BATCH_SIZE, TINY_TOKEN_SIZE), num_conjugates=sizes(NUM_CONJUGATES), num_tags=sizes(TINY_TOKEN_SIZE), ) -def test_conjugated_packed_fit(device, token_sizes, num_conjugates, num_tags): +def test_conjugated_packed_fit(token_sizes, num_conjugates, num_tags): decoder = CrfDecoder(num_tags=num_tags, num_conjugates=num_conjugates) decoders = [CrfDecoder(num_tags=num_tags, num_conjugates=1) for _ in range(num_conjugates)] @@ -293,12 +287,11 @@ def test_conjugated_packed_fit(device, token_sizes, num_conjugates, num_tags): @given( - device=devices(), token_sizes=sizes(TINY_BATCH_SIZE, TINY_TOKEN_SIZE), num_conjugates=sizes(NUM_CONJUGATES), num_tags=sizes(TINY_TOKEN_SIZE), ) -def test_dynamic_fit(device, token_sizes, num_conjugates, num_tags): +def test_dynamic_fit(token_sizes, num_conjugates, num_tags): packed_decoder = CrfDecoder(num_tags=num_tags, num_conjugates=num_conjugates) catted_decoder = CrfDecoder(num_tags=num_tags, num_conjugates=num_conjugates) diff --git a/tests/test_functional.py b/tests/test_functional.py index c12b74a..abad128 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -1,16 +1,15 @@ import torch from hypothesis import given, strategies as st -from tests.strategies import devices, sizes, TINY_TOKEN_SIZE, TINY_BATCH_SIZE +from tests.strategies import device, sizes, TINY_TOKEN_SIZE, TINY_BATCH_SIZE from tests.utils import assert_close, assert_grad_close from torchlatent.functional import logaddexp, logsumexp @given( - device=devices(), token_sizes=sizes(TINY_BATCH_SIZE, TINY_TOKEN_SIZE) ) -def test_logaddexp(device, token_sizes): +def test_logaddexp(token_sizes): x = torch.randn(token_sizes, device=device, requires_grad=True) y = torch.randn(token_sizes, device=device, requires_grad=True) @@ -23,10 +22,9 @@ def test_logaddexp(device, token_sizes): @given( data=st.data(), - device=devices(), token_sizes=sizes(TINY_BATCH_SIZE, TINY_TOKEN_SIZE) ) -def test_logsumexp(data, device, token_sizes): +def test_logsumexp(data, token_sizes): tensor = torch.randn(token_sizes, device=device, requires_grad=True) dim = data.draw(st.integers(min_value=-len(token_sizes), 
max_value=len(token_sizes) - 1)) From 2fb08427a18c8a48a0c64b49c199a6e040e7a9f4 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Mon, 4 Apr 2022 09:26:15 +0900 Subject: [PATCH 039/102] Test: Update strategies --- tests/strategies.py | 29 ++++++++++++++--------------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/tests/strategies.py b/tests/strategies.py index c1f8975..5386d4d 100644 --- a/tests/strategies.py +++ b/tests/strategies.py @@ -2,15 +2,17 @@ from hypothesis import strategies as st -TINY_BATCH_SIZE = 6 -TINY_TOKEN_SIZE = 12 - -BATCH_SIZE = 24 +BATCH_SIZE = 25 TOKEN_SIZE = 50 -NUM_TAGS = 8 NUM_CONJUGATES = 5 +NUM_TAGS = 15 +EMBEDDING_DIM = 16 -EMBEDDING_DIM = 25 +TINY_BATCH_SIZE = 5 +TINY_TOKEN_SIZE = 10 +TINY_NUM_CONJUGATES = 3 +TINY_NUM_TAGS = 3 +TINY_EMBEDDING_DIM = 4 if torch.cuda.is_available(): device = torch.device('cuda:0') @@ -20,13 +22,10 @@ @st.composite -def sizes(draw, *size: int, min_size: int = 1): - max_size, *size = size +def sizes(draw, *max_sizes: int, min_size: int = 1): + max_size, *max_sizes = max_sizes + n = draw(st.integers(min_value=min_size, max_value=max_size)) - if len(size) == 0: - return draw(st.integers(min_value=min_size, max_value=max_size)) - else: - return [ - draw(sizes(*size, min_size=min_size)) - for _ in range(draw(st.integers(min_value=min_size, max_value=max_size))) - ] + if len(max_sizes) == 0: + return n + return [draw(sizes(*max_sizes, min_size=min_size)) for _ in range(n)] From 8fe940a542d4df0c837018afdd13115c2f961ad0 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Thu, 7 Apr 2022 20:10:41 +0900 Subject: [PATCH 040/102] Refactor: Update Cky --- tests/test_cky.py | 126 ++++++++++++++++++++-------------------- torchlatent/cky.py | 90 +++++++++++++++++++++------- torchlatent/crf.py | 24 ++++---- torchlatent/semiring.py | 73 +---------------------- torchlatent/types.py | 1 + 5 files changed, 147 insertions(+), 167 deletions(-) diff --git a/tests/test_cky.py b/tests/test_cky.py index e1c11b1..684bb26 100644 --- a/tests/test_cky.py +++ b/tests/test_cky.py @@ -1,63 +1,63 @@ -# import torch -# from hypothesis import given, strategies as st -# from torch.testing import assert_close -# from torch_struct import TreeCRF -# from torchrua import pack_sequence, cat_sequence -# -# from tests.strategies import sizes, BATCH_SIZE, TOKEN_SIZE, EMBEDDING_DIM, device, TINY_BATCH_SIZE -# from tests.utils import assert_grad_close -# from torchlatent.cky import CkyDistribution, cky_indices, CkyDecoder -# -# -# @given( -# token_sizes=sizes(TINY_BATCH_SIZE, TOKEN_SIZE), -# embedding_dim=sizes(EMBEDDING_DIM), -# num_tags=sizes(TOKEN_SIZE), -# bias=st.booleans(), -# ) -# def test_cky_catted_max(token_sizes, embedding_dim, num_tags, bias): -# sequence = cat_sequence([ -# torch.randn((token_size, embedding_dim), requires_grad=True, device=device) -# for token_size in token_sizes -# ]) -# -# decoder = CkyDecoder(in_features=embedding_dim, out_features=num_tags, bias=bias) -# cky = decoder.forward(sequence=sequence) -# -# assert_close(actual=cky.max, expected=cky.log_scores(decoder.decode(sequence=sequence))) -# -# -# @given( -# token_sizes=sizes(TINY_BATCH_SIZE, TOKEN_SIZE), -# embedding_dim=sizes(EMBEDDING_DIM), -# num_tags=sizes(TOKEN_SIZE), -# bias=st.booleans(), -# ) -# def test_cky_packed_max(token_sizes, embedding_dim, num_tags, bias): -# sequence = pack_sequence([ -# torch.randn((token_size, embedding_dim), requires_grad=True, device=device) -# for token_size in token_sizes -# ]) -# -# decoder = CkyDecoder(in_features=embedding_dim, 
out_features=num_tags, bias=bias) -# cky = decoder.forward(sequence=sequence) -# -# assert_close(actual=cky.max, expected=cky.log_scores(decoder.decode(sequence=sequence))) -# -# -# @given( -# token_sizes=sizes(BATCH_SIZE, TOKEN_SIZE), -# num_tags=sizes(TOKEN_SIZE), -# ) -# def test_cky_log_partitions(token_sizes, num_tags): -# scores = torch.randn( -# (len(token_sizes), max(token_sizes), max(token_sizes), num_tags), -# requires_grad=True, device=device, -# ) -# token_sizes = torch.tensor(token_sizes, device=device) -# -# excepted = TreeCRF(log_potentials=scores, lengths=token_sizes) -# actual = CkyDistribution(log_potentials=scores, indices=cky_indices(token_sizes=token_sizes, device=device)) -# -# assert_close(actual=actual.log_partitions, expected=excepted.partition) -# assert_grad_close(actual=actual.log_partitions, expected=excepted.partition, inputs=scores, rtol=1e-5, atol=1e-5) +import torch +from hypothesis import given, strategies as st +from torch.testing import assert_close +from torch_struct import TreeCRF +from torchrua import pack_sequence, cat_sequence + +from tests.strategies import sizes, BATCH_SIZE, TOKEN_SIZE, EMBEDDING_DIM, device, TINY_BATCH_SIZE +from tests.utils import assert_grad_close +from torchlatent.cky import CkyDistribution, cky_partition_indices, CkyDecoder + + +@given( + token_sizes=sizes(TINY_BATCH_SIZE, TOKEN_SIZE), + embedding_dim=sizes(EMBEDDING_DIM), + num_tags=sizes(TOKEN_SIZE), + bias=st.booleans(), +) +def test_cky_catted_max(token_sizes, embedding_dim, num_tags, bias): + sequence = cat_sequence([ + torch.randn((token_size, embedding_dim), requires_grad=True, device=device) + for token_size in token_sizes + ]) + + decoder = CkyDecoder(in_features=embedding_dim, out_features=num_tags, bias=bias) + cky = decoder.forward(sequence=sequence) + + assert_close(actual=cky.max, expected=cky.log_scores(decoder.decode(sequence=sequence))) + + +@given( + token_sizes=sizes(TINY_BATCH_SIZE, TOKEN_SIZE), + embedding_dim=sizes(EMBEDDING_DIM), + num_tags=sizes(TOKEN_SIZE), + bias=st.booleans(), +) +def test_cky_packed_max(token_sizes, embedding_dim, num_tags, bias): + sequence = pack_sequence([ + torch.randn((token_size, embedding_dim), requires_grad=True, device=device) + for token_size in token_sizes + ]) + + decoder = CkyDecoder(in_features=embedding_dim, out_features=num_tags, bias=bias) + cky = decoder.forward(sequence=sequence) + + assert_close(actual=cky.max, expected=cky.log_scores(decoder.decode(sequence=sequence))) + + +@given( + token_sizes=sizes(BATCH_SIZE, TOKEN_SIZE), + num_tags=sizes(TOKEN_SIZE), +) +def test_cky_log_partitions(token_sizes, num_tags): + scores = torch.randn( + (len(token_sizes), max(token_sizes), max(token_sizes), num_tags), + requires_grad=True, device=device, + ) + token_sizes = torch.tensor(token_sizes, device=device) + + excepted = TreeCRF(log_potentials=scores, lengths=token_sizes) + actual = CkyDistribution(emissions=scores, indices=cky_partition_indices(token_sizes=token_sizes, device=device)) + + assert_close(actual=actual.log_partitions, expected=excepted.partition) + assert_grad_close(actual=actual.log_partitions, expected=excepted.partition, inputs=scores, rtol=1e-5, atol=1e-5) diff --git a/torchlatent/cky.py b/torchlatent/cky.py index 714f66b..ee09f1c 100644 --- a/torchlatent/cky.py +++ b/torchlatent/cky.py @@ -1,4 +1,5 @@ from abc import ABCMeta +from functools import singledispatch from typing import Tuple, NamedTuple from typing import Type @@ -8,13 +9,61 @@ from torch.distributions.utils import lazy_property from 
torch.nn.utils.rnn import PackedSequence
 from torch.types import Device
-from torchrua import CattedSequence, transpose_sizes, pack_catted_sequence
-from torchrua import major_sizes_to_ptr, accumulate_sizes
-from torchrua import pad_packed_sequence, pad_catted_sequence
 
 from torchlatent.abc import DistributionABC
-from torchlatent.semiring import Semiring, Log, Max, segment_indices
+from torchlatent.semiring import Semiring, Log, Max
 from torchlatent.types import Sequence
+from torchrua import CattedSequence, transpose_sizes, pack_catted_sequence, cat_packed_indices
+from torchrua import major_sizes_to_ptr, accumulate_sizes
+from torchrua import pad_packed_sequence, pad_catted_sequence
+
+__all__ = [
+    'cky_scores_indices',
+    'cky_scores_catted_indices',
+    'cky_scores_packed_indices',
+
+    'CkyIndices',
+    'cky_partition_indices',
+    'cky_partition',
+
+    'CkyDistribution',
+    'CkyDecoderABC',
+    'CkyDecoder',
+]
+
+
+@singledispatch
+def cky_scores_indices(sequence: Sequence, device: Device = None):
+    raise KeyError(f'type {type(sequence)} is not supported')
+
+
+@cky_scores_indices.register
+def cky_scores_catted_indices(sequence: CattedSequence, device: Device = None):
+    if device is None:
+        device = sequence.data.device
+
+    token_sizes = sequence.token_sizes.to(device=device)
+
+    batch_ptr = torch.repeat_interleave(repeats=token_sizes)
+    return ..., batch_ptr, token_sizes
+
+
+@cky_scores_indices.register
+def cky_scores_packed_indices(sequence: PackedSequence, device: Device = None):
+    if device is None:
+        device = sequence.data.device
+
+    batch_sizes = sequence.batch_sizes.to(device=device)
+    unsorted_indices = sequence.unsorted_indices.to(device=device)
+
+    indices, token_sizes = cat_packed_indices(
+        batch_sizes=batch_sizes,
+        unsorted_indices=unsorted_indices,
+        device=device,
+    )
+
+    batch_ptr = torch.repeat_interleave(repeats=token_sizes)
+    return indices, batch_ptr, token_sizes
 
 
 class CkyIndices(NamedTuple):
@@ -26,7 +75,7 @@ class CkyIndices(NamedTuple):
 
 
 @torch.no_grad()
-def cky_indices(token_sizes: Tensor, device: Device = None):
+def cky_partition_indices(token_sizes: Tensor, device: Device = None):
     if device is None:
         device = token_sizes.device
 
@@ -35,8 +84,7 @@ def cky_indices(token_sizes: Tensor, device: Device = None):
     token_ptr, batch_ptr = major_sizes_to_ptr(sizes=token_sizes)
     x_ptr, z_ptr = major_sizes_to_ptr(sizes=token_ptr + 1)
 
-    batch_ptr = batch_ptr[z_ptr]
-    y_ptr = z_ptr - acc_token_sizes[batch_ptr]
+    y_ptr = token_ptr[z_ptr]
 
     token_size = token_sizes.max().item()
     cache_size, = token_ptr.size()
 
@@ -44,7 +92,7 @@ def cky_indices(token_sizes: Tensor, device: Device = None):
     return CkyIndices(
         token_size=token_size,
         cache_size=cache_size,
-        src=((y_ptr - x_ptr, z_ptr), (batch_ptr, x_ptr, y_ptr)),
+        src=((y_ptr - x_ptr, z_ptr), (batch_ptr[z_ptr], x_ptr, y_ptr)),
         tgt=(token_sizes - 1, acc_token_sizes),
     )
 
@@ -53,9 +101,9 @@ def cky_partition(data: Tensor, indices: CkyIndices, semiring: Type[Semiring]) -> Tensor:
     token_size, cache_size, (src1, src2), tgt = indices
 
     size = (token_size, cache_size, *data.size()[3:])
-    tensor0 = torch.full(size, fill_targets=semiring.zero, requires_grad=False)
-    tensor1 = torch.full(size, fill_targets=semiring.zero, requires_grad=False)
-    tensor2 = torch.full(size, fill_targets=semiring.zero, requires_grad=False)
+    tensor0 = torch.full(size, fill_value=semiring.zero, device=data.device, requires_grad=False)
+    tensor1 = torch.full(size, fill_value=semiring.zero, device=data.device, requires_grad=False)
+    tensor2 = torch.full(size, fill_value=semiring.zero, device=data.device, requires_grad=False)
 
     tensor0[src1] = 
data[src2] tensor1[0, :] = tensor2[-1, :] = tensor0[0, :] @@ -70,27 +118,27 @@ def cky_partition(data: Tensor, indices: CkyIndices, semiring: Type[Semiring]) - class CkyDistribution(DistributionABC): - def __init__(self, log_potentials: Tensor, indices: CkyIndices) -> None: + def __init__(self, emissions: Tensor, indices: CkyIndices) -> None: super(CkyDistribution, self).__init__(validate_args=False) - self.log_potentials = log_potentials + self.emissions = emissions self.indices = indices def log_scores(self, sequence: Sequence) -> Tensor: - indices, batch_ptr, sizes = segment_indices(sequence=sequence) + indices, batch_ptr, sizes = cky_scores_indices(sequence) data = sequence.data[indices] return Log.segment_prod( - tensor=self.log_potentials[batch_ptr, data[..., 0], data[..., 1], data[..., 2]], + tensor=self.emissions[batch_ptr, data[..., 0], data[..., 1], data[..., 2]], sizes=sizes, ) @lazy_property def log_partitions(self) -> Tensor: - return cky_partition(data=Log.sum(self.log_potentials, dim=-1), indices=self.indices, semiring=Log) + return cky_partition(data=Log.sum(self.emissions, dim=-1), indices=self.indices, semiring=Log) @lazy_property def max(self) -> Tensor: - return cky_partition(data=Max.sum(self.log_potentials, dim=-1), indices=self.indices, semiring=Max) + return cky_partition(data=Max.sum(self.emissions, dim=-1), indices=self.indices, semiring=Max) @lazy_property def argmax(self) -> Tensor: @@ -108,7 +156,7 @@ def argmax(self) -> Tensor: @lazy_property def marginals(self) -> Tensor: grad, = torch.autograd.grad( - self.log_partitions, self.log_potentials, torch.ones_like(self.log_partitions), + self.log_partitions, self.emissions, torch.ones_like(self.log_partitions), create_graph=True, only_inputs=True, allow_unused=False, ) return grad @@ -139,10 +187,10 @@ def forward(self, sequence: Sequence, indices: CkyIndices = None) -> CkyDistribu raise KeyError(f'type {type(sequence)} is not supported') if indices is None: - indices = cky_indices(token_sizes=token_sizes, device=features.device) + indices = cky_partition_indices(token_sizes=token_sizes, device=features.device) return CkyDistribution( - log_potentials=self.forward_scores(features=features), + emissions=self.forward_scores(features=features), indices=indices, ) @@ -158,7 +206,7 @@ def decode(self, sequence: Sequence, indices: CkyIndices = None) -> Sequence: if isinstance(sequence, PackedSequence): token_sizes = transpose_sizes(sizes=sequence.batch_sizes)[sequence.unsorted_indices] * 2 - 1 - return pack_catted_sequence(sequence=dist.argmax, token_sizes=token_sizes) + return pack_catted_sequence(sequence=CattedSequence(dist.argmax, token_sizes=token_sizes)) raise KeyError(f'type {type(sequence)} is not supported') diff --git a/torchlatent/crf.py b/torchlatent/crf.py index 9ff9373..66ef1ae 100644 --- a/torchlatent/crf.py +++ b/torchlatent/crf.py @@ -8,13 +8,14 @@ from torch.distributions.utils import lazy_property from torch.nn import init from torch.types import Device + +from torchlatent.abc import DistributionABC +from torchlatent.semiring import Semiring, Log, Max from torchrua import ReductionIndices, accumulate_sizes from torchrua import head_catted_indices, last_catted_indices, reduce_catted_indices from torchrua import head_packed_indices, last_packed_indices, reduce_packed_indices from torchrua import roll_catted_indices, cat_packed_indices, CattedSequence, PackedSequence - -from torchlatent.abc import DistributionABC -from torchlatent.semiring import Semiring, Log, Max +from functools import singledispatch 
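+# singledispatch dispatches on the runtime type of the first argument, so the CattedSequence
+# and PackedSequence variants registered below replace the isinstance chains that previously
+# lived in forward_parameters.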
Sequence = Union[CattedSequence, PackedSequence] @@ -29,7 +30,12 @@ class CrfIndices(NamedTuple): indices: ReductionIndices -@torch.no_grad() +@singledispatch +def broadcast_shapes(sequence: Sequence, transitions: Tuple[Tensor, Tensor, Tensor]) -> Sequence: + raise TypeError(f'type {type(sequence)} is not supported') + + +@broadcast_shapes.register def broadcast_catted_shapes(sequence: CattedSequence, transitions: Tuple[Tensor, Tensor, Tensor]): sequence, token_sizes = sequence transitions, head_transitions, last_transitions = transitions @@ -44,7 +50,7 @@ def broadcast_catted_shapes(sequence: CattedSequence, transitions: Tuple[Tensor, return torch.broadcast_shapes((t1, c1, h1), (t2, c2, 1), (1, c3, h3), (1, c4, h4)) -@torch.no_grad() +@broadcast_shapes.register def broadcast_packed_shapes(sequence: PackedSequence, transitions: Tuple[Tensor, Tensor, Tensor]): sequence, batch_sizes, _, _ = sequence transitions, head_transitions, last_transitions = transitions @@ -251,13 +257,7 @@ def extra_repr(self) -> str: def forward_parameters(self, emissions: Sequence): transitions = (self.transitions, self.head_transitions, self.last_transitions) - - if isinstance(emissions, CattedSequence): - t, c, h = broadcast_catted_shapes(sequence=emissions, transitions=transitions) - elif isinstance(emissions, PackedSequence): - t, c, h = broadcast_packed_shapes(sequence=emissions, transitions=transitions) - else: - raise KeyError(f'type {type(emissions)} is not supported') + t, c, h = broadcast_shapes(emissions, transitions=transitions) emissions = emissions.data.expand((t, c, -1)) transitions = self.transitions.expand((t, c, -1, -1)) diff --git a/torchlatent/semiring.py b/torchlatent/semiring.py index 62f7b4b..5c8c708 100644 --- a/torchlatent/semiring.py +++ b/torchlatent/semiring.py @@ -1,13 +1,9 @@ import torch from torch import Tensor -from torch.nn.utils.rnn import PackedSequence -from torch.types import Device -from torchrua import cat_packed_indices, cat_padded_indices, CattedSequence -from torchrua.reduction import reduce_sequence, ReductionIndices -from torchrua.scatter import scatter_add, scatter_logsumexp from torchlatent.functional import logsumexp, logaddexp -from torchlatent.types import Sequence +from torchrua.reduction import reduce_sequence, ReductionIndices +from torchrua.scatter import scatter_add, scatter_logsumexp __all__ = [ 'Semiring', @@ -15,71 +11,6 @@ ] -@torch.no_grad() -def segment_indices(sequence: Sequence, batch_first: bool = True, device: Device = None): - if isinstance(sequence, CattedSequence): - data, token_sizes = sequence - return segment_catted_indices(token_sizes=token_sizes, device=data.device) - - if isinstance(sequence, PackedSequence): - data, batch_sizes, _, unsorted_indices = sequence - return segment_packed_indices(batch_sizes=batch_sizes, unsorted_indices=unsorted_indices, device=data.device) - - if isinstance(sequence, tuple) and torch.is_tensor(sequence[0]) and torch.is_tensor(sequence[1]): - data, token_sizes = sequence - return segment_padded_indices(token_sizes=token_sizes, batch_first=batch_first, device=device) - - raise KeyError(f'type {type(sequence)} is not supported') - - -@torch.no_grad() -def segment_catted_indices(token_sizes: Tensor, device: Device = None): - if device is None: - device = token_sizes.device - - token_sizes = token_sizes.to(device=device) - - batch_ptr = torch.repeat_interleave(repeats=token_sizes) - return torch.arange(batch_ptr.size()[0], device=device), batch_ptr, token_sizes - - -@torch.no_grad() -def 
segment_packed_indices(batch_sizes: Tensor, unsorted_indices: Tensor, device: Device = None): - if device is None: - if unsorted_indices is not None: - device = unsorted_indices.device - else: - device = batch_sizes.device - - batch_sizes = batch_sizes.to(device=device) - unsorted_indices = unsorted_indices.to(device=device) - - indices, token_sizes = cat_packed_indices( - batch_sizes=batch_sizes, - unsorted_indices=unsorted_indices, - device=device, - ) - batch_ptr = torch.repeat_interleave(repeats=token_sizes) - return indices, batch_ptr, token_sizes - - -@torch.no_grad() -def segment_padded_indices(token_sizes: Tensor, batch_first: bool, device: Device = None): - if device is None: - device = token_sizes.device - - token_sizes = token_sizes.to(device=device) - - if batch_first: - (batch_ptr, token_ptr), _ = cat_padded_indices( - token_sizes=token_sizes, batch_first=batch_first, device=device) - return (batch_ptr, token_ptr), batch_ptr, token_sizes - else: - (token_ptr, batch_ptr), _ = cat_padded_indices( - token_sizes=token_sizes, batch_first=batch_first, device=device) - return (token_ptr, batch_ptr), batch_ptr, token_sizes - - class Semiring(object): zero: float one: float diff --git a/torchlatent/types.py b/torchlatent/types.py index 117cb50..6cd724a 100644 --- a/torchlatent/types.py +++ b/torchlatent/types.py @@ -1,6 +1,7 @@ from typing import Union from torch.nn.utils.rnn import PackedSequence + from torchrua import CattedSequence, PaddedSequence Sequence = Union[CattedSequence, PackedSequence, PaddedSequence] From 245c54e4e9fff83b3f688dc5377a85b2635284b2 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Thu, 7 Apr 2022 20:19:39 +0900 Subject: [PATCH 041/102] Refactor: Add crf_scores_indices --- torchlatent/crf.py | 78 +++++++++++++++++++++------------------------- 1 file changed, 36 insertions(+), 42 deletions(-) diff --git a/torchlatent/crf.py b/torchlatent/crf.py index 66ef1ae..6b0a5dd 100644 --- a/torchlatent/crf.py +++ b/torchlatent/crf.py @@ -1,3 +1,4 @@ +from functools import singledispatch from typing import NamedTuple, Union from typing import Tuple from typing import Type @@ -15,7 +16,6 @@ from torchrua import head_catted_indices, last_catted_indices, reduce_catted_indices from torchrua import head_packed_indices, last_packed_indices, reduce_packed_indices from torchrua import roll_catted_indices, cat_packed_indices, CattedSequence, PackedSequence -from functools import singledispatch Sequence = Union[CattedSequence, PackedSequence] @@ -65,12 +65,17 @@ def broadcast_packed_shapes(sequence: PackedSequence, transitions: Tuple[Tensor, return torch.broadcast_shapes((t1, c1, h1), (t2, c2, 1), (1, c3, h3), (1, c4, h4)) -@torch.no_grad() -def crf_reduce_catted_indices(token_sizes: Tensor, device: Device = None): +@singledispatch +def crf_scores_indices(sequence: Sequence, device: Device = None): + raise TypeError(f'type {type(sequence)} is not supported') + + +@crf_scores_indices.register +def crf_scores_catted_indices(sequence: CattedSequence, device: Device = None): if device is None: - device = token_sizes.device + device = sequence.data.device - token_sizes = token_sizes.to(device=device) + token_sizes = sequence.token_sizes.to(device=device) curr = torch.arange(token_sizes.sum().item(), device=device) unsorted_indices = torch.arange(token_sizes.size()[0], device=device) @@ -80,16 +85,13 @@ def crf_reduce_catted_indices(token_sizes: Tensor, device: Device = None): return head, last, prev, curr, token_sizes, unsorted_indices -@torch.no_grad() -def 
crf_reduce_packed_indices(batch_sizes: Tensor, unsorted_indices: Tensor, device: Device): +@crf_scores_indices.register +def crf_scores_packed_indices(sequence: PackedSequence, device: Device = None): if device is None: - if unsorted_indices is not None: - device = unsorted_indices.device - else: - device = batch_sizes.device + device = sequence.data.device - batch_sizes = batch_sizes.to(device=device) - unsorted_indices = unsorted_indices.to(device=device) + batch_sizes = sequence.batch_sizes.to(device=device) + unsorted_indices = sequence.unsorted_indices.to(device=device) curr, token_sizes = cat_packed_indices(batch_sizes=batch_sizes, unsorted_indices=unsorted_indices, device=device) prev = roll_catted_indices(token_sizes=token_sizes, device=device, shifts=1) @@ -98,23 +100,33 @@ def crf_reduce_packed_indices(batch_sizes: Tensor, unsorted_indices: Tensor, dev return head, last, curr[prev], curr, token_sizes, unsorted_indices +def crf_scores(sequence: Sequence, emissions: Tensor, + transitions: Tuple[Tensor, Tensor, Tensor], semiring: Type[Semiring]) -> Tensor: + head, last, prev, curr, token_sizes, unsorted_indices = crf_scores_indices(sequence) + + sequence, *_ = sequence + transitions, head_transitions, last_transitions = transitions + c = torch.arange(transitions.size()[1], device=emissions.device) + + emissions = emissions[curr[:, None], c[None, :], sequence[curr]] + transitions = transitions[curr[:, None], c[None, :], sequence[prev], sequence[curr]] + transitions[accumulate_sizes(sizes=token_sizes)] = semiring.one + head_transitions = head_transitions[unsorted_indices[:, None], c[None, :], sequence[head]] + last_transitions = last_transitions[unsorted_indices[:, None], c[None, :], sequence[last]] + + emissions = semiring.segment_prod(semiring.mul(emissions, transitions), sizes=token_sizes) + return semiring.mul(emissions, semiring.mul(head_transitions, last_transitions)) + + @torch.no_grad() def crf_indices(emissions: Sequence) -> CrfIndices: + head, last, prev, curr, token_sizes, unsorted_indices = crf_scores_indices(emissions) if isinstance(emissions, CattedSequence): - head, last, prev, curr, token_sizes, unsorted_indices = crf_reduce_catted_indices( - token_sizes=emissions.token_sizes, - device=emissions.data.device, - ) indices = reduce_catted_indices( token_sizes=emissions.token_sizes, device=emissions.data.device, ) elif isinstance(emissions, PackedSequence): - head, last, prev, curr, token_sizes, unsorted_indices = crf_reduce_packed_indices( - batch_sizes=emissions.batch_sizes, - unsorted_indices=emissions.unsorted_indices, - device=emissions.data.device, - ) indices = reduce_packed_indices( batch_sizes=emissions.batch_sizes, unsorted_indices=emissions.unsorted_indices, @@ -132,23 +144,6 @@ def crf_indices(emissions: Sequence) -> CrfIndices: ) -def crf_reduce(emissions: Tensor, targets: Tensor, transitions: Tuple[Tensor, Tensor, Tensor], - indices: CrfIndices, semiring: Type[Semiring]) -> Tensor: - head, last, prev, curr, token_sizes, unsorted_indices, _ = indices - - transitions, head_transitions, last_transitions = transitions - c = torch.arange(transitions.size()[1], device=emissions.device) - - emissions = emissions[curr[:, None], c[None, :], targets[curr]] - transitions = transitions[curr[:, None], c[None, :], targets[prev], targets[curr]] - transitions[accumulate_sizes(sizes=token_sizes)] = semiring.one - head_transitions = head_transitions[unsorted_indices[:, None], c[None, :], targets[head]] - last_transitions = last_transitions[unsorted_indices[:, None], c[None, 
:], targets[last]]
-
-    emissions = semiring.segment_prod(semiring.mul(emissions, transitions), sizes=token_sizes)
-    return semiring.mul(emissions, semiring.mul(head_transitions, last_transitions))
-
-
 def crf_partition(emissions: Tensor, transitions: Tuple[Tensor, Tensor, Tensor],
                   indices: CrfIndices, semiring: Type[Semiring]):
     head, _, _, _, _, unsorted_indices, indices = indices
@@ -178,11 +173,10 @@ def __init__(self, emissions: Tensor, transitions: Tuple[Tensor, Tensor, Tensor]
         self.transitions = transitions
 
     def log_scores(self, targets: Sequence) -> Tensor:
-        return crf_reduce(
+        return crf_scores(
             emissions=self.emissions,
-            targets=targets.data,
+            sequence=targets,
             transitions=self.transitions,
-            indices=self.indices,
             semiring=Log,
         )
 
From 50b2a871244729cf52dda168ca613d113c894704 Mon Sep 17 00:00:00 2001
From: speedcell4
Date: Thu, 7 Apr 2022 20:40:43 +0900
Subject: [PATCH 042/102] Refactor: Add test_crf_catted_scores and test_crf_packed_scores

---
 tests/test_crf.py | 99 ++++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 97 insertions(+), 2 deletions(-)

diff --git a/tests/test_crf.py b/tests/test_crf.py
index 57251cf..f5c3a68 100644
--- a/tests/test_crf.py
+++ b/tests/test_crf.py
@@ -2,12 +2,57 @@
 import torchcrf
 from hypothesis import given
 from torch.testing import assert_close
-from torchrua import cat_sequence, pad_catted_indices, pack_catted_indices
-from torchrua import pad_sequence, pad_packed_indices, pack_sequence
 
 from tests.strategies import device, sizes, TOKEN_SIZE, TINY_BATCH_SIZE, NUM_CONJUGATES, TINY_TOKEN_SIZE
 from tests.utils import assert_grad_close, assert_equal
 from torchlatent.crf import CrfDecoder
+from torchrua import cat_sequence, pad_catted_indices, pack_catted_indices
+from torchrua import pad_sequence, pad_packed_indices, pack_sequence
+
+
+@given(
+    token_sizes=sizes(TINY_BATCH_SIZE, TOKEN_SIZE),
+    num_tags=sizes(TOKEN_SIZE),
+)
+def test_crf_catted_scores(token_sizes, num_tags):
+    actual_decoder = CrfDecoder(num_tags)
+    excepted_decoder = torchcrf.CRF(num_tags, batch_first=False)
+
+    excepted_decoder.transitions.data = torch.randn_like(excepted_decoder.transitions)
+    excepted_decoder.start_transitions.data = torch.randn_like(excepted_decoder.start_transitions)
+    excepted_decoder.end_transitions.data = torch.randn_like(excepted_decoder.end_transitions)
+
+    actual_decoder.transitions.data = excepted_decoder.transitions[None, None, :, :]
+    actual_decoder.head_transitions.data = excepted_decoder.start_transitions[None, None, :]
+    actual_decoder.last_transitions.data = excepted_decoder.end_transitions[None, None, :]
+
+    emissions = [
+        torch.randn((token_size, num_tags), requires_grad=True, device=device)
+        for token_size in token_sizes
+    ]
+    targets = [
+        torch.randint(0, num_tags, (token_size,), device=device)
+        for token_size in token_sizes
+    ]
+
+    catted_emissions = cat_sequence([x[:, None] for x in emissions])
+    catted_targets = cat_sequence([x[:, None] for x in targets])
+
+    padded_emissions, _ = pad_sequence(emissions, batch_first=False)
+    padded_targets, _ = pad_sequence(targets, batch_first=False)
+
+    size, ptr = pad_catted_indices(token_sizes=catted_emissions.token_sizes, batch_first=False)
+    mask = torch.zeros(size, dtype=torch.bool, device=device)
+    mask[ptr] = True
+
+    actual = actual_decoder.forward(emissions=catted_emissions).log_scores(targets=catted_targets)[:, 0]
+    excepted = excepted_decoder._compute_score(
+        emissions=padded_emissions, tags=padded_targets.long(),
+        mask=mask.byte(),
+    )
+
+    assert_close(actual=actual, 
expected=excepted, rtol=1e-4, atol=1e-4) + assert_grad_close(actual=actual, expected=excepted, inputs=emissions, rtol=1e-4, atol=1e-4) @given( @@ -93,6 +138,56 @@ def test_crf_catted_decode(token_sizes, num_tags): assert_equal(actual=actual_token_sizes, expected=excepted_token_sizes) +@given( + token_sizes=sizes(TINY_BATCH_SIZE, TOKEN_SIZE), + num_tags=sizes(TOKEN_SIZE), +) +def test_crf_packed_scores(token_sizes, num_tags): + actual_decoder = CrfDecoder(num_tags) + excepted_decoder = torchcrf.CRF(num_tags, batch_first=False) + + excepted_decoder.transitions.data = torch.randn_like(excepted_decoder.transitions) + excepted_decoder.start_transitions.data = torch.randn_like(excepted_decoder.start_transitions) + excepted_decoder.end_transitions.data = torch.randn_like(excepted_decoder.end_transitions) + + actual_decoder.transitions.data = excepted_decoder.transitions[None, None, :, :] + actual_decoder.head_transitions.data = excepted_decoder.start_transitions[None, None, :] + actual_decoder.last_transitions.data = excepted_decoder.end_transitions[None, None, :] + + emissions = [ + torch.randn((token_size, num_tags), requires_grad=True, device=device) + for token_size in token_sizes + ] + targets = [ + torch.randint(0, num_tags, (token_size,), device=device) + for token_size in token_sizes + ] + + packed_emissions = pack_sequence([x[:, None] for x in emissions]) + packed_targets = pack_sequence([x[:, None] for x in targets]) + + padded_emissions, _ = pad_sequence(emissions, batch_first=False) + padded_targets, _ = pad_sequence(targets, batch_first=False) + + size, ptr, _ = pad_packed_indices( + batch_sizes=packed_emissions.batch_sizes, + sorted_indices=packed_emissions.sorted_indices, + unsorted_indices=packed_emissions.unsorted_indices, + batch_first=False, + ) + mask = torch.zeros(size, dtype=torch.bool, device=device) + mask[ptr] = True + + actual = actual_decoder.forward(emissions=packed_emissions).log_scores(targets=packed_targets)[:, 0] + excepted = excepted_decoder._compute_score( + emissions=padded_emissions, tags=padded_targets.long(), + mask=mask.byte(), + ) + + assert_close(actual=actual, expected=excepted, rtol=1e-4, atol=1e-4) + assert_grad_close(actual=actual, expected=excepted, inputs=emissions, rtol=1e-4, atol=1e-4) + + @given( token_sizes=sizes(TINY_BATCH_SIZE, TOKEN_SIZE), num_tags=sizes(TOKEN_SIZE), From bf5cd3646d76cc7bf3c52030ea6027a9d0cb1270 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Thu, 7 Apr 2022 20:50:00 +0900 Subject: [PATCH 043/102] Refactor: Simplify crf_scores_catted_indices --- torchlatent/crf.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/torchlatent/crf.py b/torchlatent/crf.py index 6b0a5dd..cf71255 100644 --- a/torchlatent/crf.py +++ b/torchlatent/crf.py @@ -8,6 +8,7 @@ from torch import nn from torch.distributions.utils import lazy_property from torch.nn import init +from torch.nn import functional as F from torch.types import Device from torchlatent.abc import DistributionABC @@ -76,13 +77,12 @@ def crf_scores_catted_indices(sequence: CattedSequence, device: Device = None): device = sequence.data.device token_sizes = sequence.token_sizes.to(device=device) - curr = torch.arange(token_sizes.sum().item(), device=device) + acc_token_sizes = token_sizes.cumsum(dim=0) + + index = torch.arange(token_sizes.sum().item(), device=device) unsorted_indices = torch.arange(token_sizes.size()[0], device=device) - prev = roll_catted_indices(token_sizes=token_sizes, device=device, shifts=1) - head = 
head_catted_indices(token_sizes=token_sizes, device=device) - last = last_catted_indices(token_sizes=token_sizes, device=device) - return head, last, prev, curr, token_sizes, unsorted_indices + return F.pad(acc_token_sizes, [1, -1]), acc_token_sizes - 1, index - 1, index, token_sizes, unsorted_indices @crf_scores_indices.register From 9aa744c6856b62091f3cdb553bf9ee727ca2c7b0 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Thu, 7 Apr 2022 21:05:10 +0900 Subject: [PATCH 044/102] Refactor: Simplify crf_scores_packed_indices --- torchlatent/crf.py | 37 ++++++++++++++++++++----------------- 1 file changed, 20 insertions(+), 17 deletions(-) diff --git a/torchlatent/crf.py b/torchlatent/crf.py index cf71255..877b7af 100644 --- a/torchlatent/crf.py +++ b/torchlatent/crf.py @@ -7,16 +7,16 @@ from torch import Tensor from torch import nn from torch.distributions.utils import lazy_property -from torch.nn import init from torch.nn import functional as F +from torch.nn import init from torch.types import Device from torchlatent.abc import DistributionABC from torchlatent.semiring import Semiring, Log, Max -from torchrua import ReductionIndices, accumulate_sizes -from torchrua import head_catted_indices, last_catted_indices, reduce_catted_indices -from torchrua import head_packed_indices, last_packed_indices, reduce_packed_indices -from torchrua import roll_catted_indices, cat_packed_indices, CattedSequence, PackedSequence +from torchrua import CattedSequence, PackedSequence +from torchrua import ReductionIndices, accumulate_sizes, minor_sizes_to_ptr +from torchrua import reduce_catted_indices +from torchrua import reduce_packed_indices Sequence = Union[CattedSequence, PackedSequence] @@ -92,27 +92,30 @@ def crf_scores_packed_indices(sequence: PackedSequence, device: Device = None): batch_sizes = sequence.batch_sizes.to(device=device) unsorted_indices = sequence.unsorted_indices.to(device=device) - curr, token_sizes = cat_packed_indices(batch_sizes=batch_sizes, unsorted_indices=unsorted_indices, device=device) + acc_batch_sizes = F.pad(batch_sizes.cumsum(dim=0), [2, -1]) + + batch_ptr, token_ptr, token_sizes = minor_sizes_to_ptr( + token_sizes=batch_sizes, token_ptr=unsorted_indices, + ) + prev = acc_batch_sizes[token_ptr + 0] + batch_ptr + curr = acc_batch_sizes[token_ptr + 1] + batch_ptr + last = acc_batch_sizes[token_sizes] + unsorted_indices - prev = roll_catted_indices(token_sizes=token_sizes, device=device, shifts=1) - head = head_packed_indices(batch_sizes=batch_sizes, device=device, unsorted_indices=unsorted_indices) - last = last_packed_indices(batch_sizes=batch_sizes, device=device, unsorted_indices=unsorted_indices) - return head, last, curr[prev], curr, token_sizes, unsorted_indices + return unsorted_indices, last, prev, curr, token_sizes, unsorted_indices -def crf_scores(sequence: Sequence, emissions: Tensor, - transitions: Tuple[Tensor, Tensor, Tensor], semiring: Type[Semiring]) -> Tensor: +def crf_scores(sequence: Sequence, emissions: Tensor, transitions: Tuple[Tensor, Tensor, Tensor], + semiring: Type[Semiring]) -> Tensor: head, last, prev, curr, token_sizes, unsorted_indices = crf_scores_indices(sequence) - sequence, *_ = sequence transitions, head_transitions, last_transitions = transitions c = torch.arange(transitions.size()[1], device=emissions.device) - emissions = emissions[curr[:, None], c[None, :], sequence[curr]] - transitions = transitions[curr[:, None], c[None, :], sequence[prev], sequence[curr]] + emissions = emissions[curr[:, None], c[None, :], sequence.data[curr]] + 
transitions = transitions[curr[:, None], c[None, :], sequence.data[prev], sequence.data[curr]] transitions[accumulate_sizes(sizes=token_sizes)] = semiring.one - head_transitions = head_transitions[unsorted_indices[:, None], c[None, :], sequence[head]] - last_transitions = last_transitions[unsorted_indices[:, None], c[None, :], sequence[last]] + head_transitions = head_transitions[unsorted_indices[:, None], c[None, :], sequence.data[head]] + last_transitions = last_transitions[unsorted_indices[:, None], c[None, :], sequence.data[last]] emissions = semiring.segment_prod(semiring.mul(emissions, transitions), sizes=token_sizes) return semiring.mul(emissions, semiring.mul(head_transitions, last_transitions)) From 605bffd15c2047c541e6f348db1fb74c56755020 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sun, 17 Apr 2022 15:21:22 +0900 Subject: [PATCH 045/102] Refactor: Unify tensor0 and tensor1 member allocation --- torchlatent/cky.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/torchlatent/cky.py b/torchlatent/cky.py index ee09f1c..df60a4f 100644 --- a/torchlatent/cky.py +++ b/torchlatent/cky.py @@ -92,12 +92,12 @@ def cky_partition_indices(token_sizes: Tensor, device: Device = None): return CkyIndices( token_size=token_size, cache_size=cache_size, - src=((y_ptr - x_ptr, z_ptr), (batch_ptr[z_ptr], x_ptr, y_ptr)), + src=((y_ptr - x_ptr, x_ptr + z_ptr - y_ptr), (batch_ptr[z_ptr], x_ptr, y_ptr)), tgt=(token_sizes - 1, acc_token_sizes), ) -def cky_partition(data: Tensor, indices: CkyIndices, semiring: Type[Semiring]) -> Tensor: +def cky_partition(data: Tensor, indices: CkyIndices, *, semiring: Type[Semiring]) -> Tensor: token_size, cache_size, (src1, src2), tgt = indices size = (token_size, cache_size, *data.size()[3:]) @@ -111,7 +111,7 @@ def cky_partition(data: Tensor, indices: CkyIndices, semiring: Type[Semiring]) - for w in range(1, token_size): tensor1[w, :-w] = tensor2[-w - 1, w:] = semiring.mul( semiring.sum(semiring.mul(tensor1[:w, :-w], tensor2[-w:, w:]), dim=0), - tensor0[w, w:], + tensor0[w, :-w], ) return tensor1[tgt] @@ -205,8 +205,9 @@ def decode(self, sequence: Sequence, indices: CkyIndices = None) -> Sequence: return CattedSequence(data=dist.argmax, token_sizes=sequence.token_sizes * 2 - 1) if isinstance(sequence, PackedSequence): - token_sizes = transpose_sizes(sizes=sequence.batch_sizes)[sequence.unsorted_indices] * 2 - 1 - return pack_catted_sequence(sequence=CattedSequence(dist.argmax, token_sizes=token_sizes)) + token_sizes = transpose_sizes(sizes=sequence.batch_sizes) + token_sizes = token_sizes[sequence.unsorted_indices] * 2 - 1 + return pack_catted_sequence(sequence=dist.argmax, token_sizes=token_sizes) raise KeyError(f'type {type(sequence)} is not supported') From 7e9fd44074894d26fe27350807b38b0239d08316 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sun, 17 Apr 2022 19:45:31 +0900 Subject: [PATCH 046/102] Feat: Add ExceptionSemiring --- torchlatent/semiring.py | 84 +++++++++++++++++++++++++++++++++++------ 1 file changed, 72 insertions(+), 12 deletions(-) diff --git a/torchlatent/semiring.py b/torchlatent/semiring.py index 5c8c708..6f0b2fb 100644 --- a/torchlatent/semiring.py +++ b/torchlatent/semiring.py @@ -1,13 +1,13 @@ import torch from torch import Tensor +from torchrua.reduction import reduce_sequence, ReductionIndices from torchlatent.functional import logsumexp, logaddexp -from torchrua.reduction import reduce_sequence, ReductionIndices -from torchrua.scatter import scatter_add, scatter_logsumexp __all__ = [ - 'Semiring', - 'Std', 
'Log', 'Max', + 'Semiring', 'ExceptionSemiring', + 'Std', 'Log', 'Max', 'Xen', 'Div', + ] @@ -106,14 +106,6 @@ def sum(cls, tensor: Tensor, dim: int, keepdim: bool = False) -> Tensor: def prod(cls, tensor: Tensor, dim: int, keepdim: bool = False) -> Tensor: return torch.sum(tensor, dim=dim, keepdim=keepdim) - @classmethod - def scatter_add(cls, tensor: Tensor, index: Tensor) -> Tensor: - return scatter_logsumexp(tensor=tensor, index=index) - - @classmethod - def scatter_mul(cls, tensor: Tensor, index: Tensor) -> Tensor: - return scatter_add(tensor=tensor, index=index) - @classmethod def segment_sum(cls, tensor: Tensor, sizes: Tensor) -> Tensor: m = torch.segment_reduce(tensor, reduce='max', lengths=sizes, unsafe=True).detach() @@ -152,3 +144,71 @@ def segment_sum(cls, tensor: Tensor, sizes: Tensor) -> Tensor: @classmethod def segment_prod(cls, tensor: Tensor, sizes: Tensor) -> Tensor: return torch.segment_reduce(tensor, reduce='sum', lengths=sizes, unsafe=True) + + +class ExceptionSemiring(Semiring): + @classmethod + def sum(cls, tensor: Tensor, log_p: Tensor, log_q: Tensor, dim: int, keepdim: bool = False) -> Tensor: + raise NotImplementedError + + @classmethod + def segment_sum(cls, tensor: Tensor, log_p: Tensor, log_q: Tensor, sizes: Tensor) -> Tensor: + raise NotImplementedError + + +class Xen(ExceptionSemiring): + zero = 0. + one = 0. + + @classmethod + def add(cls, x: Tensor, y: Tensor) -> Tensor: + raise NotImplementedError + + @classmethod + def mul(cls, x: Tensor, y: Tensor) -> Tensor: + return x + y + + @classmethod + def sum(cls, tensor: Tensor, log_p: Tensor, log_q: Tensor, dim: int, keepdim: bool = False) -> Tensor: + return torch.sum((tensor - log_q) * log_p.exp(), dim=dim, keepdim=keepdim) + + @classmethod + def prod(cls, tensor: Tensor, dim: int, keepdim: bool = False) -> Tensor: + return torch.sum(tensor, dim=dim, keepdim=keepdim) + + @classmethod + def segment_sum(cls, tensor: Tensor, log_p: Tensor, log_q: Tensor, sizes: Tensor) -> Tensor: + return torch.segment_reduce((tensor - log_q) * log_p.exp(), reduce='sum', lengths=sizes, unsafe=True) + + @classmethod + def segment_prod(cls, tensor: Tensor, sizes: Tensor) -> Tensor: + return torch.segment_reduce(tensor, reduce='sum', lengths=sizes, unsafe=True) + + +class Div(ExceptionSemiring): + zero = 0. + one = 0. 
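+
+    # A sketch of the intended semantics (assuming log_p and log_q hold the
+    # log-marginals of two distributions p and q): Xen accumulates the cross
+    # entropy -E_p[log q], while Div accumulates the KL divergence
+    # E_p[log p - log q]. Both keep `mul`/`prod` from the Log semiring and
+    # only reweight `sum`, reading directly off the definitions below:
+    #
+    #   xen = ((tensor - log_q) * log_p.exp()).sum(dim)          # H(p, q)
+    #   div = ((tensor - log_q + log_p) * log_p.exp()).sum(dim)  # KL(p || q)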
+ + @classmethod + def add(cls, x: Tensor, y: Tensor) -> Tensor: + raise NotImplementedError + + @classmethod + def mul(cls, x: Tensor, y: Tensor) -> Tensor: + return x + y + + @classmethod + def sum(cls, tensor: Tensor, log_p: Tensor, log_q: Tensor, dim: int, keepdim: bool = False) -> Tensor: + return torch.sum((tensor - log_q + log_p) * log_p.exp(), dim=dim, keepdim=keepdim) + + @classmethod + def prod(cls, tensor: Tensor, dim: int, keepdim: bool = False) -> Tensor: + return torch.sum(tensor, dim=dim, keepdim=keepdim) + + @classmethod + def segment_sum(cls, tensor: Tensor, log_p: Tensor, log_q: Tensor, sizes: Tensor) -> Tensor: + return torch.segment_reduce((tensor - log_q + log_p) * log_p.exp(), reduce='sum', lengths=sizes, unsafe=True) + + @classmethod + def segment_prod(cls, tensor: Tensor, sizes: Tensor) -> Tensor: + return torch.segment_reduce(tensor, reduce='sum', lengths=sizes, unsafe=True) From c4e0d92b56052ff62b47168e59b1c788ee279f3d Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Fri, 27 May 2022 21:27:26 +0900 Subject: [PATCH 047/102] Refactor: Rename to assertion.py --- tests/assertion.py | 43 +++++++++++++++++++ tests/strategies.py | 31 -------------- tests/strategy.py | 39 +++++++++++++++++ tests/test_cky.py | 5 +-- tests/test_crf.py | 21 +++++---- tests/test_functional.py | 4 +- tests/utils.py | 92 ---------------------------------------- torchlatent/cky.py | 2 +- 8 files changed, 97 insertions(+), 140 deletions(-) create mode 100644 tests/assertion.py delete mode 100644 tests/strategies.py create mode 100644 tests/strategy.py delete mode 100644 tests/utils.py diff --git a/tests/assertion.py b/tests/assertion.py new file mode 100644 index 0000000..f5d8664 --- /dev/null +++ b/tests/assertion.py @@ -0,0 +1,43 @@ +import torch +from torch import Tensor +from torch.nn.utils.rnn import PackedSequence +from torch.testing import assert_close + +from torchrua.catting import CattedSequence + +__all__ = [ + 'assert_close', + 'assert_grad_close', + 'assert_catted_sequence_close', + 'assert_packed_sequence_close', +] + + +def assert_grad_close(actual: Tensor, expected: Tensor, inputs, **kwargs) -> None: + grad = torch.rand_like(actual) + + actual_grads = torch.autograd.grad(actual, inputs, grad, retain_graph=True, allow_unused=False) + expected_grads = torch.autograd.grad(expected, inputs, grad, retain_graph=True, allow_unused=False) + + for actual_grad, expected_grad in zip(actual_grads, expected_grads): + assert_close(actual=actual_grad, expected=expected_grad, **kwargs) + + +def assert_catted_sequence_close(actual: CattedSequence, expected: CattedSequence, **kwargs) -> None: + assert_close(actual=actual.data, expected=expected.data, **kwargs) + assert_close(actual=actual.token_sizes, expected=expected.token_sizes, **kwargs) + + +def assert_packed_sequence_close(actual: PackedSequence, expected: PackedSequence, **kwargs) -> None: + assert_close(actual=actual.data, expected=expected.data, **kwargs) + assert_close(actual=actual.batch_sizes, expected=expected.batch_sizes, **kwargs) + + if actual.sorted_indices is None: + assert expected.sorted_indices is None + else: + assert_close(actual=actual.sorted_indices, expected=expected.sorted_indices, **kwargs) + + if actual.unsorted_indices is None: + assert expected.unsorted_indices is None + else: + assert_close(actual=actual.unsorted_indices, expected=expected.unsorted_indices, **kwargs) diff --git a/tests/strategies.py b/tests/strategies.py deleted file mode 100644 index 5386d4d..0000000 --- a/tests/strategies.py +++ /dev/null @@ -1,31 
+0,0 @@ -import torch - -from hypothesis import strategies as st - -BATCH_SIZE = 25 -TOKEN_SIZE = 50 -NUM_CONJUGATES = 5 -NUM_TAGS = 15 -EMBEDDING_DIM = 16 - -TINY_BATCH_SIZE = 5 -TINY_TOKEN_SIZE = 10 -TINY_NUM_CONJUGATES = 3 -TINY_NUM_TAGS = 3 -TINY_EMBEDDING_DIM = 4 - -if torch.cuda.is_available(): - device = torch.device('cuda:0') -else: - device = torch.device('cpu') -_ = torch.empty((1,), device=device) - - -@st.composite -def sizes(draw, *max_sizes: int, min_size: int = 1): - max_size, *max_sizes = max_sizes - n = draw(st.integers(min_value=min_size, max_value=max_size)) - - if len(max_sizes) == 0: - return n - return [draw(sizes(*max_sizes, min_size=min_size)) for _ in range(n)] diff --git a/tests/strategy.py b/tests/strategy.py new file mode 100644 index 0000000..19b90fc --- /dev/null +++ b/tests/strategy.py @@ -0,0 +1,39 @@ +import torch +from hypothesis import strategies as st + +TINY_BATCH_SIZE = 5 +TINY_TOKEN_SIZE = 11 +TINY_EMBEDDING_DIM = 13 +NUM_CONJUGATES = 5 +NUM_TAGS = 7 + +if torch.cuda.is_available(): + BATCH_SIZE = 53 + TOKEN_SIZE = 83 + EMBEDDING_DIM = 107 + NUM_CONJUGATES = 5 + NUM_TAGS = 17 +else: + BATCH_SIZE = 37 + TOKEN_SIZE = 53 + EMBEDDING_DIM = 61 + NUM_CONJUGATES = 5 + NUM_TAGS = 17 + +if torch.cuda.is_available(): + device = torch.device('cuda:0') +else: + device = torch.device('cpu') + +torch.empty((1,), device=device) + + +@st.composite +def sizes(draw, *size: int, min_size: int = 1): + max_size, *size = size + n = draw(st.integers(min_value=min_size, max_value=max_size)) + + if len(size) == 0: + return n + else: + return draw(st.lists(sizes(*size), min_size=n, max_size=n)) diff --git a/tests/test_cky.py b/tests/test_cky.py index 684bb26..8210e06 100644 --- a/tests/test_cky.py +++ b/tests/test_cky.py @@ -1,11 +1,10 @@ import torch from hypothesis import given, strategies as st -from torch.testing import assert_close from torch_struct import TreeCRF from torchrua import pack_sequence, cat_sequence -from tests.strategies import sizes, BATCH_SIZE, TOKEN_SIZE, EMBEDDING_DIM, device, TINY_BATCH_SIZE -from tests.utils import assert_grad_close +from tests.assertion import assert_close, assert_grad_close +from tests.strategy import sizes, BATCH_SIZE, TOKEN_SIZE, EMBEDDING_DIM, device, TINY_BATCH_SIZE from torchlatent.cky import CkyDistribution, cky_partition_indices, CkyDecoder diff --git a/tests/test_crf.py b/tests/test_crf.py index f5c3a68..97b48e2 100644 --- a/tests/test_crf.py +++ b/tests/test_crf.py @@ -1,14 +1,13 @@ import torch import torchcrf from hypothesis import given -from torch.testing import assert_close - -from tests.strategies import device, sizes, TOKEN_SIZE, TINY_BATCH_SIZE, NUM_CONJUGATES, TINY_TOKEN_SIZE -from tests.utils import assert_grad_close, assert_equal -from torchlatent.crf import CrfDecoder from torchrua import cat_sequence, pad_catted_indices, pack_catted_indices from torchrua import pad_sequence, pad_packed_indices, pack_sequence +from tests.assertion import assert_close, assert_grad_close +from tests.strategy import device, sizes, TOKEN_SIZE, TINY_BATCH_SIZE, NUM_CONJUGATES, TINY_TOKEN_SIZE +from torchlatent.crf import CrfDecoder + @given( token_sizes=sizes(TINY_BATCH_SIZE, TOKEN_SIZE), @@ -134,8 +133,8 @@ def test_crf_catted_decode(token_sizes, num_tags): excepted = excepted_decoder.decode(emissions=padded_emissions, mask=mask.byte()) excepted, excepted_token_sizes = cat_sequence([torch.tensor(x, device=device) for x in excepted]) - assert_equal(actual=actual[:, 0], expected=excepted) - assert_equal(actual=actual_token_sizes, 
expected=excepted_token_sizes) + assert_close(actual=actual[:, 0], expected=excepted) + assert_close(actual=actual_token_sizes, expected=excepted_token_sizes) @given( @@ -277,10 +276,10 @@ def test_crf_packed_decode(token_sizes, num_tags): excepted = excepted_decoder.decode(emissions=padded_emissions, mask=mask.byte()) excepted = pack_sequence([torch.tensor(x, device=device) for x in excepted]) - assert_equal(actual=actual.data[:, 0], expected=excepted.data) - assert_equal(actual=actual.batch_sizes, expected=excepted.batch_sizes) - assert_equal(actual=actual.sorted_indices, expected=excepted.sorted_indices) - assert_equal(actual=actual.unsorted_indices, expected=excepted.unsorted_indices) + assert_close(actual=actual.data[:, 0], expected=excepted.data) + assert_close(actual=actual.batch_sizes, expected=excepted.batch_sizes) + assert_close(actual=actual.sorted_indices, expected=excepted.sorted_indices) + assert_close(actual=actual.unsorted_indices, expected=excepted.unsorted_indices) @given( diff --git a/tests/test_functional.py b/tests/test_functional.py index abad128..5d9ce3a 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -1,8 +1,8 @@ import torch from hypothesis import given, strategies as st -from tests.strategies import device, sizes, TINY_TOKEN_SIZE, TINY_BATCH_SIZE -from tests.utils import assert_close, assert_grad_close +from tests.assertion import assert_close, assert_grad_close +from tests.strategy import device, sizes, TINY_TOKEN_SIZE, TINY_BATCH_SIZE from torchlatent.functional import logaddexp, logsumexp diff --git a/tests/utils.py b/tests/utils.py deleted file mode 100644 index a369ec9..0000000 --- a/tests/utils.py +++ /dev/null @@ -1,92 +0,0 @@ -from typing import List, Tuple, Union - -import torch -from torch import Tensor -from torch.nn.utils.rnn import PackedSequence -from torch.testing import assert_close -from torchrua.catting import CattedSequence - -__all__ = [ - 'assert_equal', 'assert_close', 'assert_grad_close', - 'assert_catted_sequence_equal', 'assert_catted_sequence_close', - 'assert_packed_sequence_equal', 'assert_packed_sequence_close', -] - - -def assert_equal(actual: Tensor, expected: Tensor) -> None: - assert torch.equal(actual, expected) - - -def assert_grad_close( - actual: Tensor, expected: Tensor, - inputs: Union[Tensor, List[Tensor], Tuple[Tensor, ...]], - allow_unused: bool = False, - check_device: bool = True, check_dtype: bool = True, check_stride: bool = True, **kwargs) -> None: - kwargs = dict(check_device=check_device, check_dtype=check_dtype, check_stride=check_stride, **kwargs) - - grad = torch.rand_like(actual) - - actual_grads = torch.autograd.grad( - actual, inputs, grad, - create_graph=False, - allow_unused=allow_unused, - retain_graph=True, - ) - - expected_grads = torch.autograd.grad( - expected, inputs, grad, - create_graph=False, - allow_unused=allow_unused, - retain_graph=True, - ) - - for actual_grad, expected_grad in zip(actual_grads, expected_grads): - assert_close(actual=actual_grad, expected=expected_grad, **kwargs) - - -def assert_catted_sequence_close( - actual: CattedSequence, expected: CattedSequence, - check_device: bool = True, check_dtype: bool = True, check_stride: bool = True) -> None: - kwargs = dict(check_device=check_device, check_dtype=check_dtype, check_stride=check_stride) - - assert_close(actual=actual.data, expected=expected.data, **kwargs) - assert_equal(actual=actual.token_sizes, expected=expected.token_sizes) - - -def assert_catted_sequence_equal(actual: CattedSequence, expected: 
CattedSequence) -> None: - assert_equal(actual=actual.data, expected=expected.data) - assert_equal(actual=actual.token_sizes, expected=expected.token_sizes) - - -def assert_packed_sequence_close( - actual: PackedSequence, expected: PackedSequence, - check_device: bool = True, check_dtype: bool = True, check_stride: bool = True) -> None: - kwargs = dict(check_device=check_device, check_dtype=check_dtype, check_stride=check_stride) - - assert_close(actual=actual.data, expected=expected.data, **kwargs) - assert_equal(actual=actual.batch_sizes, expected=expected.batch_sizes) - - if actual.sorted_indices is None: - assert expected.sorted_indices is None - else: - assert_equal(actual=actual.sorted_indices, expected=expected.sorted_indices) - - if actual.unsorted_indices is None: - assert expected.unsorted_indices is None - else: - assert_equal(actual=actual.unsorted_indices, expected=expected.unsorted_indices) - - -def assert_packed_sequence_equal(actual: PackedSequence, expected: PackedSequence) -> None: - assert_equal(actual=actual.data, expected=expected.data) - assert_equal(actual=actual.batch_sizes, expected=expected.batch_sizes) - - if actual.sorted_indices is None: - assert expected.sorted_indices is None - else: - assert_equal(actual=actual.sorted_indices, expected=expected.sorted_indices) - - if actual.unsorted_indices is None: - assert expected.unsorted_indices is None - else: - assert_equal(actual=actual.unsorted_indices, expected=expected.unsorted_indices) diff --git a/torchlatent/cky.py b/torchlatent/cky.py index df60a4f..d7d37e6 100644 --- a/torchlatent/cky.py +++ b/torchlatent/cky.py @@ -207,7 +207,7 @@ def decode(self, sequence: Sequence, indices: CkyIndices = None) -> Sequence: if isinstance(sequence, PackedSequence): token_sizes = transpose_sizes(sizes=sequence.batch_sizes) token_sizes = token_sizes[sequence.unsorted_indices] * 2 - 1 - return pack_catted_sequence(sequence=dist.argmax, token_sizes=token_sizes) + return pack_catted_sequence(CattedSequence(data=dist.argmax, token_sizes=token_sizes)) raise KeyError(f'type {type(sequence)} is not supported') From 4c488c3f5b3df2dc3843b9f5a2582ab01661dce0 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Mon, 13 Jun 2022 19:47:00 +0900 Subject: [PATCH 048/102] Fix: resolve unit test bugs --- tests/test_cky.py | 4 ++-- tests/test_crf.py | 42 +++++++++++++++++++++--------------------- 2 files changed, 23 insertions(+), 23 deletions(-) diff --git a/tests/test_cky.py b/tests/test_cky.py index 8210e06..821222c 100644 --- a/tests/test_cky.py +++ b/tests/test_cky.py @@ -20,7 +20,7 @@ def test_cky_catted_max(token_sizes, embedding_dim, num_tags, bias): for token_size in token_sizes ]) - decoder = CkyDecoder(in_features=embedding_dim, out_features=num_tags, bias=bias) + decoder = CkyDecoder(in_features=embedding_dim, out_features=num_tags, bias=bias).to(device=device) cky = decoder.forward(sequence=sequence) assert_close(actual=cky.max, expected=cky.log_scores(decoder.decode(sequence=sequence))) @@ -38,7 +38,7 @@ def test_cky_packed_max(token_sizes, embedding_dim, num_tags, bias): for token_size in token_sizes ]) - decoder = CkyDecoder(in_features=embedding_dim, out_features=num_tags, bias=bias) + decoder = CkyDecoder(in_features=embedding_dim, out_features=num_tags, bias=bias).to(device=device) cky = decoder.forward(sequence=sequence) assert_close(actual=cky.max, expected=cky.log_scores(decoder.decode(sequence=sequence))) diff --git a/tests/test_crf.py b/tests/test_crf.py index 97b48e2..63b2249 100644 --- a/tests/test_crf.py +++ 
b/tests/test_crf.py @@ -14,8 +14,8 @@ num_tags=sizes(TOKEN_SIZE), ) def test_crf_catted_scores(token_sizes, num_tags): - actual_decoder = CrfDecoder(num_tags) - excepted_decoder = torchcrf.CRF(num_tags, batch_first=False) + actual_decoder = CrfDecoder(num_tags).to(device=device) + excepted_decoder = torchcrf.CRF(num_tags, batch_first=False).to(device=device) excepted_decoder.transitions.data = torch.randn_like(excepted_decoder.transitions) excepted_decoder.start_transitions.data = torch.randn_like(excepted_decoder.start_transitions) @@ -59,8 +59,8 @@ def test_crf_catted_scores(token_sizes, num_tags): num_tags=sizes(TOKEN_SIZE), ) def test_crf_catted_fit(token_sizes, num_tags): - actual_decoder = CrfDecoder(num_tags) - excepted_decoder = torchcrf.CRF(num_tags, batch_first=False) + actual_decoder = CrfDecoder(num_tags).to(device=device) + excepted_decoder = torchcrf.CRF(num_tags, batch_first=False).to(device=device) excepted_decoder.transitions.data = torch.randn_like(excepted_decoder.transitions) excepted_decoder.start_transitions.data = torch.randn_like(excepted_decoder.start_transitions) @@ -104,8 +104,8 @@ def test_crf_catted_fit(token_sizes, num_tags): num_tags=sizes(TOKEN_SIZE), ) def test_crf_catted_decode(token_sizes, num_tags): - actual_decoder = CrfDecoder(num_tags) - excepted_decoder = torchcrf.CRF(num_tags, batch_first=False) + actual_decoder = CrfDecoder(num_tags).to(device=device) + excepted_decoder = torchcrf.CRF(num_tags, batch_first=False).to(device=device) excepted_decoder.transitions.data = torch.randn_like(excepted_decoder.transitions) excepted_decoder.start_transitions.data = torch.randn_like(excepted_decoder.start_transitions) @@ -142,8 +142,8 @@ def test_crf_catted_decode(token_sizes, num_tags): num_tags=sizes(TOKEN_SIZE), ) def test_crf_packed_scores(token_sizes, num_tags): - actual_decoder = CrfDecoder(num_tags) - excepted_decoder = torchcrf.CRF(num_tags, batch_first=False) + actual_decoder = CrfDecoder(num_tags).to(device=device) + excepted_decoder = torchcrf.CRF(num_tags, batch_first=False).to(device=device) excepted_decoder.transitions.data = torch.randn_like(excepted_decoder.transitions) excepted_decoder.start_transitions.data = torch.randn_like(excepted_decoder.start_transitions) @@ -192,8 +192,8 @@ def test_crf_packed_scores(token_sizes, num_tags): num_tags=sizes(TOKEN_SIZE), ) def test_crf_packed_fit(token_sizes, num_tags): - actual_decoder = CrfDecoder(num_tags) - excepted_decoder = torchcrf.CRF(num_tags, batch_first=False) + actual_decoder = CrfDecoder(num_tags).to(device=device) + excepted_decoder = torchcrf.CRF(num_tags, batch_first=False).to(device=device) excepted_decoder.transitions.data = torch.randn_like(excepted_decoder.transitions) excepted_decoder.start_transitions.data = torch.randn_like(excepted_decoder.start_transitions) @@ -242,8 +242,8 @@ def test_crf_packed_fit(token_sizes, num_tags): num_tags=sizes(TOKEN_SIZE), ) def test_crf_packed_decode(token_sizes, num_tags): - actual_decoder = CrfDecoder(num_tags) - excepted_decoder = torchcrf.CRF(num_tags, batch_first=False) + actual_decoder = CrfDecoder(num_tags).to(device=device) + excepted_decoder = torchcrf.CRF(num_tags, batch_first=False).to(device=device) excepted_decoder.transitions.data = torch.randn_like(excepted_decoder.transitions) excepted_decoder.start_transitions.data = torch.randn_like(excepted_decoder.start_transitions) @@ -288,8 +288,8 @@ def test_crf_packed_decode(token_sizes, num_tags): num_tags=sizes(TINY_TOKEN_SIZE), ) def test_conjugated_catted_fit(token_sizes, num_conjugates, 
num_tags): - decoder = CrfDecoder(num_tags=num_tags, num_conjugates=num_conjugates) - decoders = [CrfDecoder(num_tags=num_tags, num_conjugates=1) for _ in range(num_conjugates)] + decoder = CrfDecoder(num_tags=num_tags, num_conjugates=num_conjugates).to(device=device) + decoders = [CrfDecoder(num_tags=num_tags, num_conjugates=1).to(device=device) for _ in range(num_conjugates)] for index in range(num_conjugates): decoders[index].transitions.data = torch.randn_like(decoders[index].transitions) @@ -337,8 +337,8 @@ def test_conjugated_catted_fit(token_sizes, num_conjugates, num_tags): num_tags=sizes(TINY_TOKEN_SIZE), ) def test_conjugated_packed_fit(token_sizes, num_conjugates, num_tags): - decoder = CrfDecoder(num_tags=num_tags, num_conjugates=num_conjugates) - decoders = [CrfDecoder(num_tags=num_tags, num_conjugates=1) for _ in range(num_conjugates)] + decoder = CrfDecoder(num_tags=num_tags, num_conjugates=num_conjugates).to(device=device) + decoders = [CrfDecoder(num_tags=num_tags, num_conjugates=1).to(device=device) for _ in range(num_conjugates)] for index in range(num_conjugates): decoders[index].transitions.data = torch.randn_like(decoders[index].transitions) @@ -386,8 +386,8 @@ def test_conjugated_packed_fit(token_sizes, num_conjugates, num_tags): num_tags=sizes(TINY_TOKEN_SIZE), ) def test_dynamic_fit(token_sizes, num_conjugates, num_tags): - packed_decoder = CrfDecoder(num_tags=num_tags, num_conjugates=num_conjugates) - catted_decoder = CrfDecoder(num_tags=num_tags, num_conjugates=num_conjugates) + packed_decoder = CrfDecoder(num_tags=num_tags, num_conjugates=num_conjugates).to(device=device) + catted_decoder = CrfDecoder(num_tags=num_tags, num_conjugates=num_conjugates).to(device=device) emissions = [ torch.randn((token_size, 1, num_tags), requires_grad=True, device=device) @@ -399,9 +399,9 @@ def test_dynamic_fit(token_sizes, num_conjugates, num_tags): for token_size in token_sizes ] - catted_decoder.transitions.data = torch.randn((sum(token_sizes), num_conjugates, num_tags, num_tags)) - catted_decoder.head_transitions.data = torch.randn((len(token_sizes), num_conjugates, num_tags)) - catted_decoder.last_transitions.data = torch.randn((len(token_sizes), num_conjugates, num_tags)) + catted_decoder.transitions.data = torch.randn((sum(token_sizes), num_conjugates, num_tags, num_tags), device=device) + catted_decoder.head_transitions.data = torch.randn((len(token_sizes), num_conjugates, num_tags), device=device) + catted_decoder.last_transitions.data = torch.randn((len(token_sizes), num_conjugates, num_tags), device=device) token_sizes = torch.tensor(token_sizes, device=device) indices, _, sorted_indices, _ = pack_catted_indices(token_sizes=token_sizes, device=device) From 202d252a911bc0a5c3f9c0232d26f5e29d187407 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sat, 28 Jan 2023 15:29:45 +0900 Subject: [PATCH 049/102] Feat: Add Classifier --- torchlatent/nn/__init__.py | 0 torchlatent/nn/classifier.py | 41 ++++++++++++++++++++++++++++++++++++ 2 files changed, 41 insertions(+) create mode 100644 torchlatent/nn/__init__.py create mode 100644 torchlatent/nn/classifier.py diff --git a/torchlatent/nn/__init__.py b/torchlatent/nn/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/torchlatent/nn/classifier.py b/torchlatent/nn/classifier.py new file mode 100644 index 0000000..4f4708b --- /dev/null +++ b/torchlatent/nn/classifier.py @@ -0,0 +1,41 @@ +import torch +from torch import Tensor +from torch import nn +from torch.nn import init + + +class Classifier(nn.Module): + def 
__init__(self, bias: bool = False, *, in_features: int, out_features: int, num_conjugates: int) -> None: + super(Classifier, self).__init__() + + self.in_features = in_features + self.out_features = out_features + self.num_conjugates = num_conjugates + + self.weight = nn.Parameter(torch.empty((num_conjugates, out_features, in_features))) + self.bias = nn.Parameter(torch.empty((num_conjugates, out_features,))) if bias else None + + self.reset_parameters() + + @torch.no_grad() + def reset_parameters(self) -> None: + bound = (6.0 / self.in_features) ** 0.5 + init.uniform_(self.weight, a=-bound, b=+bound) + + if self.bias is not None: + init.zeros_(self.bias) + + def extra_repr(self) -> str: + return ', '.join([ + f'in_features={self.in_features}', + f'out_features={self.out_features}', + f'num_conjugates={self.num_conjugates}', + f'bias={self.bias is not None}', + ]) + + def forward(self, tensor: Tensor) -> Tensor: + tensor = torch.einsum('...cx,cyx->...cy', tensor, self.weight) + if self.bias is not None: + tensor = tensor + self.bias + + return tensor From d4f107088623132c2493a764aa70048f232cff7e Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sat, 28 Jan 2023 15:43:54 +0900 Subject: [PATCH 050/102] Feat: Add BiaffineClassifier --- torchlatent/nn/classifier.py | 59 ++++++++++++++++++++++++++++++------ 1 file changed, 49 insertions(+), 10 deletions(-) diff --git a/torchlatent/nn/classifier.py b/torchlatent/nn/classifier.py index 4f4708b..d2b1888 100644 --- a/torchlatent/nn/classifier.py +++ b/torchlatent/nn/classifier.py @@ -5,7 +5,8 @@ class Classifier(nn.Module): - def __init__(self, bias: bool = False, *, in_features: int, out_features: int, num_conjugates: int) -> None: + def __init__(self, bias: bool = False, *, num_conjugates: int, + in_features: int, out_features: int) -> None: super(Classifier, self).__init__() self.in_features = in_features @@ -13,16 +14,15 @@ def __init__(self, bias: bool = False, *, in_features: int, out_features: int, n self.num_conjugates = num_conjugates self.weight = nn.Parameter(torch.empty((num_conjugates, out_features, in_features))) - self.bias = nn.Parameter(torch.empty((num_conjugates, out_features,))) if bias else None + self.bias = nn.Parameter(torch.empty((num_conjugates, out_features,))) if bias else 0 self.reset_parameters() @torch.no_grad() def reset_parameters(self) -> None: - bound = (6.0 / self.in_features) ** 0.5 - init.uniform_(self.weight, a=-bound, b=+bound) + init.zeros_(self.weight) - if self.bias is not None: + if torch.is_tensor(self.bias): init.zeros_(self.bias) def extra_repr(self) -> str: @@ -30,12 +30,51 @@ def extra_repr(self) -> str: f'in_features={self.in_features}', f'out_features={self.out_features}', f'num_conjugates={self.num_conjugates}', - f'bias={self.bias is not None}', + f'bias={torch.is_tensor(self.bias)}', ]) def forward(self, tensor: Tensor) -> Tensor: - tensor = torch.einsum('...cx,cyx->...cy', tensor, self.weight) - if self.bias is not None: - tensor = tensor + self.bias + return torch.einsum('cox,...cx->...co', self.weight, tensor) + self.bias - return tensor + +class BiaffineClassifier(nn.Module): + def __init__(self, bias: bool = False, *, num_conjugates: int, + in_features1: int, in_features2: int, out_features: int) -> None: + super(BiaffineClassifier, self).__init__() + + self.in_features1 = in_features1 + self.in_features2 = in_features2 + self.out_features = out_features + self.num_conjugates = num_conjugates + + self.weight0 = nn.Parameter(torch.empty((num_conjugates, out_features, in_features1, 
in_features2))) + self.weight1 = nn.Parameter(torch.empty((num_conjugates, out_features, in_features1))) + self.weight2 = nn.Parameter(torch.empty((num_conjugates, out_features, in_features2))) + self.bias = nn.Parameter(torch.empty((num_conjugates, out_features))) if bias else 0 + + self.reset_parameters() + + @torch.no_grad() + def reset_parameters(self) -> None: + init.zeros_(self.weight0) + init.zeros_(self.weight1) + init.zeros_(self.weight2) + + if torch.is_tensor(self.bias): + init.zeros_(self.bias) + + def extra_repr(self) -> str: + return ', '.join([ + f'in_features1={self.in_features1}', + f'in_features2={self.in_features2}', + f'out_features={self.out_features}', + f'num_conjugates={self.num_conjugates}', + f'bias={torch.is_tensor(self.bias)}', + ]) + + def forward(self, tensor1: Tensor, tensor2: Tensor) -> Tensor: + tensor0 = torch.einsum('coxy,...cx,...cy->...co', self.weight0, tensor1, tensor2) + tensor1 = torch.einsum('cox,...cx->...co', self.weight1, tensor1) + tensor2 = torch.einsum('coy,...cy->...co', self.weight2, tensor2) + + return tensor0 + tensor1 + tensor2 + self.bias From 3b647e47501b5c26fa1a27161d20f99d24ac7a1b Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sat, 28 Jan 2023 15:58:09 +0900 Subject: [PATCH 051/102] Feat: Update marginals --- torchlatent/abc.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torchlatent/abc.py b/torchlatent/abc.py index 7cfc891..fd711bb 100644 --- a/torchlatent/abc.py +++ b/torchlatent/abc.py @@ -34,7 +34,7 @@ def max(self) -> Tensor: def argmax(self) -> Tensor: grad, = torch.autograd.grad( self.max, self.emissions, torch.ones_like(self.max), - create_graph=False, only_inputs=True, allow_unused=False, + create_graph=False, only_inputs=True, allow_unused=True, ) return grad @@ -42,7 +42,8 @@ def argmax(self) -> Tensor: def marginals(self) -> Tensor: grad, = torch.autograd.grad( self.log_partitions, self.emissions, torch.ones_like(self.log_partitions), - create_graph=False, only_inputs=True, allow_unused=False, + create_graph=True, only_inputs=True, allow_unused=True, + ) return grad From 81854c3a91bb21207b9f0fa7af60e8eaa688ad31 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sat, 28 Jan 2023 18:38:22 +0900 Subject: [PATCH 052/102] Refactor: Rename to CrfLayer --- README.md | 16 ++++++++-------- benchmark/crf.py | 4 ++-- tests/test_crf.py | 26 +++++++++++++------------- third/crf.py | 2 +- torchlatent/crf.py | 30 ++++++++++++------------------ 5 files changed, 36 insertions(+), 42 deletions(-) diff --git a/README.md b/README.md index e6d403b..363740d 100644 --- a/README.md +++ b/README.md @@ -26,23 +26,23 @@ Third (0.232487) => 0.103277 0.129209 0.145311 import torch from torchrua import pack_sequence -from torchlatent.crf import CrfDecoder +from torchlatent.crf import CrfLayer num_tags = 3 num_conjugates = 1 -decoder = CrfDecoder(num_tags=num_tags, num_conjugates=num_conjugates) +decoder = CrfLayer(num_targets=num_tags, num_conjugates=num_conjugates) emissions = pack_sequence([ - torch.randn((5, num_conjugates, num_tags), requires_grad=True), - torch.randn((2, num_conjugates, num_tags), requires_grad=True), - torch.randn((3, num_conjugates, num_tags), requires_grad=True), + torch.randn((5, num_conjugates, num_tags), requires_grad=True), + torch.randn((2, num_conjugates, num_tags), requires_grad=True), + torch.randn((3, num_conjugates, num_tags), requires_grad=True), ]) tags = pack_sequence([ - torch.randint(0, num_tags, (5, num_conjugates)), - torch.randint(0, num_tags, (2, num_conjugates)), - torch.randint(0, 
num_tags, (3, num_conjugates)), + torch.randint(0, num_tags, (5, num_conjugates)), + torch.randint(0, num_tags, (2, num_conjugates)), + torch.randint(0, num_tags, (3, num_conjugates)), ]) print(decoder.fit(emissions=emissions, tags=tags)) diff --git a/benchmark/crf.py b/benchmark/crf.py index d60bc7d..971e936 100644 --- a/benchmark/crf.py +++ b/benchmark/crf.py @@ -4,7 +4,7 @@ from benchmark.meter import TimeMeter from third.crf import CrfDecoder as ThirdPartyCrfDecoder -from torchlatent.crf import CrfDecoder +from torchlatent.crf import CrfLayer def benchmark_crf(num_tags: int = 50, num_conjugates: int = 1, num_runs: int = 100, @@ -19,7 +19,7 @@ def benchmark_crf(num_tags: int = 50, num_conjugates: int = 1, num_runs: int = 1 device = torch.device('cpu') print(f'device => {device}') - decoder = CrfDecoder(num_tags=num_tags, num_conjugates=num_conjugates).to(device=device) + decoder = CrfLayer(num_targets=num_tags, num_conjugates=num_conjugates).to(device=device) print(f'decoder => {decoder}') third_decoder = ThirdPartyCrfDecoder(num_tags=num_tags, num_conjugates=num_conjugates).to(device=device) diff --git a/tests/test_crf.py b/tests/test_crf.py index 63b2249..711645d 100644 --- a/tests/test_crf.py +++ b/tests/test_crf.py @@ -6,7 +6,7 @@ from tests.assertion import assert_close, assert_grad_close from tests.strategy import device, sizes, TOKEN_SIZE, TINY_BATCH_SIZE, NUM_CONJUGATES, TINY_TOKEN_SIZE -from torchlatent.crf import CrfDecoder +from torchlatent.crf import CrfLayer @given( @@ -14,7 +14,7 @@ num_tags=sizes(TOKEN_SIZE), ) def test_crf_catted_scores(token_sizes, num_tags): - actual_decoder = CrfDecoder(num_tags).to(device=device) + actual_decoder = CrfLayer(num_tags).to(device=device) excepted_decoder = torchcrf.CRF(num_tags, batch_first=False).to(device=device) excepted_decoder.transitions.data = torch.randn_like(excepted_decoder.transitions) @@ -59,7 +59,7 @@ def test_crf_catted_scores(token_sizes, num_tags): num_tags=sizes(TOKEN_SIZE), ) def test_crf_catted_fit(token_sizes, num_tags): - actual_decoder = CrfDecoder(num_tags).to(device=device) + actual_decoder = CrfLayer(num_tags).to(device=device) excepted_decoder = torchcrf.CRF(num_tags, batch_first=False).to(device=device) excepted_decoder.transitions.data = torch.randn_like(excepted_decoder.transitions) @@ -104,7 +104,7 @@ def test_crf_catted_fit(token_sizes, num_tags): num_tags=sizes(TOKEN_SIZE), ) def test_crf_catted_decode(token_sizes, num_tags): - actual_decoder = CrfDecoder(num_tags).to(device=device) + actual_decoder = CrfLayer(num_tags).to(device=device) excepted_decoder = torchcrf.CRF(num_tags, batch_first=False).to(device=device) excepted_decoder.transitions.data = torch.randn_like(excepted_decoder.transitions) @@ -142,7 +142,7 @@ def test_crf_catted_decode(token_sizes, num_tags): num_tags=sizes(TOKEN_SIZE), ) def test_crf_packed_scores(token_sizes, num_tags): - actual_decoder = CrfDecoder(num_tags).to(device=device) + actual_decoder = CrfLayer(num_tags).to(device=device) excepted_decoder = torchcrf.CRF(num_tags, batch_first=False).to(device=device) excepted_decoder.transitions.data = torch.randn_like(excepted_decoder.transitions) @@ -192,7 +192,7 @@ def test_crf_packed_scores(token_sizes, num_tags): num_tags=sizes(TOKEN_SIZE), ) def test_crf_packed_fit(token_sizes, num_tags): - actual_decoder = CrfDecoder(num_tags).to(device=device) + actual_decoder = CrfLayer(num_tags).to(device=device) excepted_decoder = torchcrf.CRF(num_tags, batch_first=False).to(device=device) excepted_decoder.transitions.data = 
torch.randn_like(excepted_decoder.transitions) @@ -242,7 +242,7 @@ def test_crf_packed_fit(token_sizes, num_tags): num_tags=sizes(TOKEN_SIZE), ) def test_crf_packed_decode(token_sizes, num_tags): - actual_decoder = CrfDecoder(num_tags).to(device=device) + actual_decoder = CrfLayer(num_tags).to(device=device) excepted_decoder = torchcrf.CRF(num_tags, batch_first=False).to(device=device) excepted_decoder.transitions.data = torch.randn_like(excepted_decoder.transitions) @@ -288,8 +288,8 @@ def test_crf_packed_decode(token_sizes, num_tags): num_tags=sizes(TINY_TOKEN_SIZE), ) def test_conjugated_catted_fit(token_sizes, num_conjugates, num_tags): - decoder = CrfDecoder(num_tags=num_tags, num_conjugates=num_conjugates).to(device=device) - decoders = [CrfDecoder(num_tags=num_tags, num_conjugates=1).to(device=device) for _ in range(num_conjugates)] + decoder = CrfLayer(num_targets=num_tags, num_conjugates=num_conjugates).to(device=device) + decoders = [CrfLayer(num_targets=num_tags, num_conjugates=1).to(device=device) for _ in range(num_conjugates)] for index in range(num_conjugates): decoders[index].transitions.data = torch.randn_like(decoders[index].transitions) @@ -337,8 +337,8 @@ def test_conjugated_catted_fit(token_sizes, num_conjugates, num_tags): num_tags=sizes(TINY_TOKEN_SIZE), ) def test_conjugated_packed_fit(token_sizes, num_conjugates, num_tags): - decoder = CrfDecoder(num_tags=num_tags, num_conjugates=num_conjugates).to(device=device) - decoders = [CrfDecoder(num_tags=num_tags, num_conjugates=1).to(device=device) for _ in range(num_conjugates)] + decoder = CrfLayer(num_targets=num_tags, num_conjugates=num_conjugates).to(device=device) + decoders = [CrfLayer(num_targets=num_tags, num_conjugates=1).to(device=device) for _ in range(num_conjugates)] for index in range(num_conjugates): decoders[index].transitions.data = torch.randn_like(decoders[index].transitions) @@ -386,8 +386,8 @@ def test_conjugated_packed_fit(token_sizes, num_conjugates, num_tags): num_tags=sizes(TINY_TOKEN_SIZE), ) def test_dynamic_fit(token_sizes, num_conjugates, num_tags): - packed_decoder = CrfDecoder(num_tags=num_tags, num_conjugates=num_conjugates).to(device=device) - catted_decoder = CrfDecoder(num_tags=num_tags, num_conjugates=num_conjugates).to(device=device) + packed_decoder = CrfLayer(num_targets=num_tags, num_conjugates=num_conjugates).to(device=device) + catted_decoder = CrfLayer(num_targets=num_tags, num_conjugates=num_conjugates).to(device=device) emissions = [ torch.randn((token_size, 1, num_tags), requires_grad=True, device=device) diff --git a/third/crf.py b/third/crf.py index 5fb565f..baef860 100644 --- a/third/crf.py +++ b/third/crf.py @@ -30,7 +30,7 @@ def __init__(self, num_tags: int, num_conjugates: int) -> None: @torch.no_grad() def reset_parameters_with_(self, decoder) -> None: - assert self.num_tags == decoder.num_tags + assert self.num_tags == decoder.num_targets assert self.num_conjugates == decoder.num_conjugates for index in range(self.num_conjugates): diff --git a/torchlatent/crf.py b/torchlatent/crf.py index 877b7af..3989712 100644 --- a/torchlatent/crf.py +++ b/torchlatent/crf.py @@ -214,30 +214,24 @@ def entropy(self) -> Tensor: ) -class CrfDecoderABC(nn.Module): +class CrfLayerABC(nn.Module): def reset_parameters(self) -> None: raise NotImplementedError - def __repr__(self) -> str: - return f'{self.__class__.__name__}({self.extra_repr()})' - - def extra_repr(self) -> str: - return '' - def forward_parameters(self, emissions: Sequence): raise NotImplementedError -class 
CrfDecoder(CrfDecoderABC): - def __init__(self, num_tags: int, num_conjugates: int = 1) -> None: - super(CrfDecoder, self).__init__() +class CrfLayer(CrfLayerABC): + def __init__(self, num_targets: int, num_conjugates: int = 1) -> None: + super(CrfLayer, self).__init__() - self.num_tags = num_tags + self.num_targets = num_targets self.num_conjugates = num_conjugates - self.transitions = nn.Parameter(torch.empty((1, num_conjugates, num_tags, num_tags))) - self.head_transitions = nn.Parameter(torch.empty((1, num_conjugates, num_tags))) - self.last_transitions = nn.Parameter(torch.empty((1, num_conjugates, num_tags))) + self.transitions = nn.Parameter(torch.empty((1, num_conjugates, num_targets, num_targets))) + self.head_transitions = nn.Parameter(torch.empty((1, num_conjugates, num_targets))) + self.last_transitions = nn.Parameter(torch.empty((1, num_conjugates, num_targets))) self.reset_parameters() @@ -248,7 +242,7 @@ def reset_parameters(self) -> None: def extra_repr(self) -> str: return ', '.join([ - f'num_tags={self.num_tags}', + f'num_targets={self.num_targets}', f'num_conjugates={self.num_conjugates}', ]) @@ -272,9 +266,9 @@ def forward(self, emissions: Sequence, indices: CrfIndices = None) -> CrfDistrib return CrfDistribution(emissions=emissions, transitions=transitions, indices=indices) def fit(self, emissions: Sequence, targets: Sequence, indices: CrfIndices = None) -> Tensor: - dist = self.forward(emissions=emissions, indices=indices) - return dist.log_prob(targets=targets).neg() + dist: CrfDistribution = self.forward(emissions=emissions, indices=indices) + return dist.log_partitions - dist.log_scores(targets=targets) def decode(self, emissions: Sequence, indices: CrfIndices = None) -> Sequence: - dist = self.forward(emissions=emissions, indices=indices) + dist: CrfDistribution = self.forward(emissions=emissions, indices=indices) return emissions._replace(data=dist.argmax) From bc021e9864e6c6ff185d8199a40a6723dc4c24b7 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sat, 28 Jan 2023 19:11:35 +0900 Subject: [PATCH 053/102] Feat: Add CrfDecoder --- torchlatent/crf.py | 57 ++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 53 insertions(+), 4 deletions(-) diff --git a/torchlatent/crf.py b/torchlatent/crf.py index 3989712..6ee4329 100644 --- a/torchlatent/crf.py +++ b/torchlatent/crf.py @@ -10,14 +10,15 @@ from torch.nn import functional as F from torch.nn import init from torch.types import Device - -from torchlatent.abc import DistributionABC -from torchlatent.semiring import Semiring, Log, Max -from torchrua import CattedSequence, PackedSequence +from torchrua import CattedSequence, PackedSequence, RuaSequential from torchrua import ReductionIndices, accumulate_sizes, minor_sizes_to_ptr from torchrua import reduce_catted_indices from torchrua import reduce_packed_indices +from torchlatent.abc import DistributionABC +from torchlatent.nn.classifier import Classifier +from torchlatent.semiring import Semiring, Log, Max + Sequence = Union[CattedSequence, PackedSequence] @@ -272,3 +273,51 @@ def fit(self, emissions: Sequence, targets: Sequence, indices: CrfIndices = None def decode(self, emissions: Sequence, indices: CrfIndices = None) -> Sequence: dist: CrfDistribution = self.forward(emissions=emissions, indices=indices) return emissions._replace(data=dist.argmax) + + +class CrfDecoder(nn.Module): + def __init__(self, in_features: int, num_targets: int, num_conjugates: int, dropout: float) -> None: + super(CrfDecoder, self).__init__() + + self.in_features = in_features + 
self.num_targets = num_targets + self.num_conjugates = num_conjugates + num_conjugates = max(1, num_conjugates) + + self.classifier = RuaSequential( + nn.Dropout(dropout), + Classifier( + num_conjugates=num_conjugates, + in_features=in_features, + out_features=num_targets, + bias=False, + ) + ) + + self.crf = CrfLayer( + num_targets=num_targets, + num_conjugates=num_conjugates, + ) + + def forward(self, sequence: Sequence) -> CrfDistribution: + if self.num_conjugates == 0: + sequence = sequence._replace(data=sequence.data[..., None, :]) + + emissions = self.classifier(sequence) + return self.crf(emissions) + + def fit(self, sequence: Sequence, targets: Sequence) -> Tensor: + dist: CrfDistribution = self(sequence=sequence) + loss = dist.log_partitions - dist.log_scores(targets=targets) + + if self.num_conjugates == 0: + loss = loss[..., 0] + return loss + + def decode(self, sequence: Sequence) -> Sequence: + dist: CrfDistribution = self(sequence=sequence) + argmax = dist.argmax + + if self.num_conjugates == 0: + argmax = argmax[..., 0] + return sequence._replace(data=argmax) From b2fc23cbf2c0cbfb84c155f7b2ab582e0592ecec Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sat, 28 Jan 2023 20:13:02 +0900 Subject: [PATCH 054/102] Feat: Add CkyDecoder --- tests/test_cky.py | 6 +-- torchlatent/cky.py | 87 +++++++++++++++++++++--------------- torchlatent/nn/classifier.py | 32 +++---------- 3 files changed, 59 insertions(+), 66 deletions(-) diff --git a/tests/test_cky.py b/tests/test_cky.py index 821222c..c8ec55c 100644 --- a/tests/test_cky.py +++ b/tests/test_cky.py @@ -5,7 +5,7 @@ from tests.assertion import assert_close, assert_grad_close from tests.strategy import sizes, BATCH_SIZE, TOKEN_SIZE, EMBEDDING_DIM, device, TINY_BATCH_SIZE -from torchlatent.cky import CkyDistribution, cky_partition_indices, CkyDecoder +from torchlatent.cky import CkyDistribution, cky_partition_indices, CkyLayer @given( @@ -20,7 +20,7 @@ def test_cky_catted_max(token_sizes, embedding_dim, num_tags, bias): for token_size in token_sizes ]) - decoder = CkyDecoder(in_features=embedding_dim, out_features=num_tags, bias=bias).to(device=device) + decoder = CkyLayer(in_features=embedding_dim, out_features=num_tags, bias=bias).to(device=device) cky = decoder.forward(sequence=sequence) assert_close(actual=cky.max, expected=cky.log_scores(decoder.decode(sequence=sequence))) @@ -38,7 +38,7 @@ def test_cky_packed_max(token_sizes, embedding_dim, num_tags, bias): for token_size in token_sizes ]) - decoder = CkyDecoder(in_features=embedding_dim, out_features=num_tags, bias=bias).to(device=device) + decoder = CkyLayer(in_features=embedding_dim, out_features=num_tags, bias=bias).to(device=device) cky = decoder.forward(sequence=sequence) assert_close(actual=cky.max, expected=cky.log_scores(decoder.decode(sequence=sequence))) diff --git a/torchlatent/cky.py b/torchlatent/cky.py index d7d37e6..579a403 100644 --- a/torchlatent/cky.py +++ b/torchlatent/cky.py @@ -11,26 +11,13 @@ from torch.types import Device from torchlatent.abc import DistributionABC +from torchlatent.nn.classifier import BiaffineClassifier from torchlatent.semiring import Semiring, Log, Max from torchlatent.types import Sequence -from torchrua import CattedSequence, transpose_sizes, pack_catted_sequence, cat_packed_indices +from torchrua import CattedSequence, transpose_sizes, pack_catted_sequence, cat_packed_indices, RuaSequential from torchrua import major_sizes_to_ptr, accumulate_sizes from torchrua import pad_packed_sequence, pad_catted_sequence -__all__ = [ - 
'cky_scores_indices', - 'cky_scores_catted_indices', - 'cky_scores_packed_indices', - - 'CkyIndices', - 'cky_partition_indices', - 'cky_partition', - - 'CkyDistribution', - 'CkyDecoderABC', - 'CkyDecoder', -] - @singledispatch def cky_scores_indices(sequence: Sequence, device: Device = None): @@ -166,13 +153,10 @@ def entropy(self) -> Tensor: raise NotImplementedError -class CkyDecoderABC(nn.Module, metaclass=ABCMeta): +class CkyLayerABC(nn.Module, metaclass=ABCMeta): def reset_parameters(self) -> None: raise NotImplementedError - def extra_repr(self) -> str: - raise NotImplementedError - def forward_scores(self, features: Tensor, *args, **kwargs) -> Tensor: raise NotImplementedError @@ -212,27 +196,56 @@ def decode(self, sequence: Sequence, indices: CkyIndices = None) -> Sequence: raise KeyError(f'type {type(sequence)} is not supported') -class CkyDecoder(CkyDecoderABC): - def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None: - super(CkyDecoder, self).__init__() - - self.fc = nn.Bilinear( - in1_features=in_features, - in2_features=in_features, - out_features=out_features, - bias=bias, - ) +class CkyLayer(CkyLayerABC): + def __init__(self, num_targets: int, num_conjugates: int) -> None: + super(CkyLayer, self).__init__() - def __repr__(self) -> str: - return f'{self.__class__.__name__}({self.extra_repr()})' + self.num_targets = num_targets + self.num_conjugates = num_conjugates def extra_repr(self) -> str: return ', '.join([ - f'in_features={self.fc.in_features}', - f'in_features={self.fc.out_features}', - f'bias={self.fc1.bias is not None}', + f'num_targets={self.num_targets}', + f'num_conjugates={self.num_conjugates}', ]) - def forward_scores(self, features: Tensor, *args, **kwargs) -> Tensor: - x, y = torch.broadcast_tensors(features[..., :, None, :], features[..., None, :, :]) - return self.fc(x, y) + +class CkyDecoder(nn.Module): + def __init__(self, in_features: int, hidden_features: int, + num_targets: int, num_conjugates: int, dropout: float) -> None: + super(CkyDecoder, self).__init__() + + self.in_features = in_features + self.hidden_features = hidden_features + self.num_targets = num_targets + self.num_conjugates = num_conjugates + + self.ffn1 = RuaSequential( + nn.Linear(in_features, hidden_features, bias=True), + nn.GELU(), + nn.Dropout(dropout), + ) + self.ffn2 = RuaSequential( + nn.Linear(in_features, hidden_features, bias=True), + nn.GELU(), + nn.Dropout(dropout), + ) + self.classifier = BiaffineClassifier( + num_conjugates=num_conjugates, + in_features1=hidden_features, + in_features2=hidden_features, + out_features=num_targets, + bias=False, + ) + + self.cky = CkyLayer( + num_targets=num_targets, + num_conjugates=num_conjugates, + ) + + def forward(self, sequence: Sequence) -> CkyDistribution: + features1, _ = self.ffn1(sequence) + features2, _ = self.ffn2(sequence) + + emissions = self.classifier(features1, features2) + return self.cky(emissions) diff --git a/torchlatent/nn/classifier.py b/torchlatent/nn/classifier.py index d2b1888..5cd1f1b 100644 --- a/torchlatent/nn/classifier.py +++ b/torchlatent/nn/classifier.py @@ -13,17 +13,8 @@ def __init__(self, bias: bool = False, *, num_conjugates: int, self.out_features = out_features self.num_conjugates = num_conjugates - self.weight = nn.Parameter(torch.empty((num_conjugates, out_features, in_features))) - self.bias = nn.Parameter(torch.empty((num_conjugates, out_features,))) if bias else 0 - - self.reset_parameters() - - @torch.no_grad() - def reset_parameters(self) -> None: - 
init.zeros_(self.weight) - - if torch.is_tensor(self.bias): - init.zeros_(self.bias) + self.weight = nn.Parameter(torch.zeros((num_conjugates, out_features, in_features))) + self.bias = nn.Parameter(torch.zeros((num_conjugates, out_features,))) if bias else 0 def extra_repr(self) -> str: return ', '.join([ @@ -47,21 +38,10 @@ def __init__(self, bias: bool = False, *, num_conjugates: int, self.out_features = out_features self.num_conjugates = num_conjugates - self.weight0 = nn.Parameter(torch.empty((num_conjugates, out_features, in_features1, in_features2))) - self.weight1 = nn.Parameter(torch.empty((num_conjugates, out_features, in_features1))) - self.weight2 = nn.Parameter(torch.empty((num_conjugates, out_features, in_features2))) - self.bias = nn.Parameter(torch.empty((num_conjugates, out_features))) if bias else 0 - - self.reset_parameters() - - @torch.no_grad() - def reset_parameters(self) -> None: - init.zeros_(self.weight0) - init.zeros_(self.weight1) - init.zeros_(self.weight2) - - if torch.is_tensor(self.bias): - init.zeros_(self.bias) + self.weight0 = nn.Parameter(torch.zeros((num_conjugates, out_features, in_features1, in_features2))) + self.weight1 = nn.Parameter(torch.zeros((num_conjugates, out_features, in_features1))) + self.weight2 = nn.Parameter(torch.zeros((num_conjugates, out_features, in_features2))) + self.bias = nn.Parameter(torch.zeros((num_conjugates, out_features))) if bias else 0 def extra_repr(self) -> str: return ', '.join([ From 95fae8af5658853c2c46af1bd04f8592028729f3 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sat, 18 Feb 2023 18:29:34 +0900 Subject: [PATCH 055/102] Feat: Update Classifier --- torchlatent/nn/classifier.py | 38 ++++++++++++++++++++++++++---------- 1 file changed, 28 insertions(+), 10 deletions(-) diff --git a/torchlatent/nn/classifier.py b/torchlatent/nn/classifier.py index 5cd1f1b..d355ab7 100644 --- a/torchlatent/nn/classifier.py +++ b/torchlatent/nn/classifier.py @@ -13,8 +13,16 @@ def __init__(self, bias: bool = False, *, num_conjugates: int, self.out_features = out_features self.num_conjugates = num_conjugates - self.weight = nn.Parameter(torch.zeros((num_conjugates, out_features, in_features))) - self.bias = nn.Parameter(torch.zeros((num_conjugates, out_features,))) if bias else 0 + self.weight = nn.Parameter(torch.empty((num_conjugates, out_features, in_features))) + self.bias = nn.Parameter(torch.empty((num_conjugates, out_features,))) if bias else 0 + + self.reset_parameters() + + def reset_parameters(self) -> None: + init.zeros_(self.weight) + + if torch.is_tensor(self.bias): + init.zeros_(self.bias) def extra_repr(self) -> str: return ', '.join([ @@ -25,7 +33,7 @@ def extra_repr(self) -> str: ]) def forward(self, tensor: Tensor) -> Tensor: - return torch.einsum('cox,...cx->...co', self.weight, tensor) + self.bias + return torch.einsum('nzx,...nx->...nz', self.weight, tensor) + self.bias class BiaffineClassifier(nn.Module): @@ -38,10 +46,20 @@ def __init__(self, bias: bool = False, *, num_conjugates: int, self.out_features = out_features self.num_conjugates = num_conjugates - self.weight0 = nn.Parameter(torch.zeros((num_conjugates, out_features, in_features1, in_features2))) - self.weight1 = nn.Parameter(torch.zeros((num_conjugates, out_features, in_features1))) - self.weight2 = nn.Parameter(torch.zeros((num_conjugates, out_features, in_features2))) - self.bias = nn.Parameter(torch.zeros((num_conjugates, out_features))) if bias else 0 + self.weight0 = nn.Parameter(torch.empty((num_conjugates, out_features, in_features1, 
in_features2))) + self.weight1 = nn.Parameter(torch.empty((num_conjugates, out_features, in_features1))) + self.weight2 = nn.Parameter(torch.empty((num_conjugates, out_features, in_features2))) + self.bias = nn.Parameter(torch.empty((num_conjugates, out_features))) if bias else 0 + + self.reset_parameters() + + def reset_parameters(self) -> None: + init.zeros_(self.weight0) + init.zeros_(self.weight1) + init.zeros_(self.weight2) + + if torch.is_tensor(self.bias): + init.zeros_(self.bias) def extra_repr(self) -> str: return ', '.join([ @@ -53,8 +71,8 @@ def extra_repr(self) -> str: ]) def forward(self, tensor1: Tensor, tensor2: Tensor) -> Tensor: - tensor0 = torch.einsum('coxy,...cx,...cy->...co', self.weight0, tensor1, tensor2) - tensor1 = torch.einsum('cox,...cx->...co', self.weight1, tensor1) - tensor2 = torch.einsum('coy,...cy->...co', self.weight2, tensor2) + tensor0 = torch.einsum('nzxy,...nx,...ny->...nz', self.weight0, tensor1, tensor2) + tensor1 = torch.einsum('nzx,...nx->...nz', self.weight1, tensor1) + tensor2 = torch.einsum('nzy,...ny->...nz', self.weight2, tensor2) return tensor0 + tensor1 + tensor2 + self.bias From e19dbb93346d59fda9eeb05352b49f2cd737117a Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sat, 11 Mar 2023 18:39:42 +0900 Subject: [PATCH 056/102] Refactor: Remove types.py --- torchlatent/abc.py | 9 ++++----- torchlatent/cky.py | 11 ++++++----- torchlatent/types.py | 7 ------- 3 files changed, 10 insertions(+), 17 deletions(-) delete mode 100644 torchlatent/types.py diff --git a/torchlatent/abc.py b/torchlatent/abc.py index fd711bb..0f7767f 100644 --- a/torchlatent/abc.py +++ b/torchlatent/abc.py @@ -1,16 +1,15 @@ from abc import ABCMeta +from typing import Union import torch import torch.autograd from torch import Tensor from torch.distributions import Distribution from torch.distributions.utils import lazy_property +from torch.nn.utils.rnn import PackedSequence +from torchrua import CattedSequence -from torchlatent.types import Sequence - -__all__ = [ - 'DistributionABC', -] +Sequence = Union[CattedSequence, PackedSequence] class DistributionABC(Distribution, metaclass=ABCMeta): diff --git a/torchlatent/cky.py b/torchlatent/cky.py index 579a403..0b19058 100644 --- a/torchlatent/cky.py +++ b/torchlatent/cky.py @@ -1,6 +1,6 @@ from abc import ABCMeta from functools import singledispatch -from typing import Tuple, NamedTuple +from typing import Tuple, NamedTuple, Union from typing import Type import torch @@ -9,14 +9,15 @@ from torch.distributions.utils import lazy_property from torch.nn.utils.rnn import PackedSequence from torch.types import Device +from torchrua import CattedSequence, transpose_sizes, pack_catted_sequence, cat_packed_indices, RuaSequential +from torchrua import major_sizes_to_ptr, accumulate_sizes +from torchrua import pad_packed_sequence, pad_catted_sequence from torchlatent.abc import DistributionABC from torchlatent.nn.classifier import BiaffineClassifier from torchlatent.semiring import Semiring, Log, Max -from torchlatent.types import Sequence -from torchrua import CattedSequence, transpose_sizes, pack_catted_sequence, cat_packed_indices, RuaSequential -from torchrua import major_sizes_to_ptr, accumulate_sizes -from torchrua import pad_packed_sequence, pad_catted_sequence + +Sequence = Union[CattedSequence, PackedSequence] @singledispatch diff --git a/torchlatent/types.py b/torchlatent/types.py deleted file mode 100644 index 6cd724a..0000000 --- a/torchlatent/types.py +++ /dev/null @@ -1,7 +0,0 @@ -from typing import Union - -from 
torch.nn.utils.rnn import PackedSequence - -from torchrua import CattedSequence, PaddedSequence - -Sequence = Union[CattedSequence, PackedSequence, PaddedSequence] From 20499e2c35076e84c27b148f7a5d18556b78f6d7 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sun, 12 Mar 2023 22:06:20 +0900 Subject: [PATCH 057/102] Fix: Resolve torchrua issue --- tests/test_crf.py | 6 +++--- torchlatent/crf.py | 5 ++--- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/tests/test_crf.py b/tests/test_crf.py index 711645d..ed4becf 100644 --- a/tests/test_crf.py +++ b/tests/test_crf.py @@ -40,7 +40,7 @@ def test_crf_catted_scores(token_sizes, num_tags): padded_emissions, _ = pad_sequence(emissions, batch_first=False) padded_targets, _ = pad_sequence(targets, batch_first=False) - size, ptr = pad_catted_indices(token_sizes=catted_emissions.token_sizes, batch_first=False) + size, ptr, _ = pad_catted_indices(token_sizes=catted_emissions.token_sizes, batch_first=False) mask = torch.zeros(size, dtype=torch.bool, device=device) mask[ptr] = True @@ -85,7 +85,7 @@ def test_crf_catted_fit(token_sizes, num_tags): padded_emissions, _ = pad_sequence(emissions, batch_first=False) padded_targets, _ = pad_sequence(targets, batch_first=False) - size, ptr = pad_catted_indices(token_sizes=catted_emissions.token_sizes, batch_first=False) + size, ptr, _ = pad_catted_indices(token_sizes=catted_emissions.token_sizes, batch_first=False) mask = torch.zeros(size, dtype=torch.bool, device=device) mask[ptr] = True @@ -124,7 +124,7 @@ def test_crf_catted_decode(token_sizes, num_tags): padded_emissions, _ = pad_sequence(emissions, batch_first=False) - size, ptr = pad_catted_indices(token_sizes=catted_emissions.token_sizes, batch_first=False) + size, ptr, _ = pad_catted_indices(token_sizes=catted_emissions.token_sizes, batch_first=False) mask = torch.zeros(size, dtype=torch.bool, device=device) mask[ptr] = True diff --git a/torchlatent/crf.py b/torchlatent/crf.py index 6ee4329..8325d0b 100644 --- a/torchlatent/crf.py +++ b/torchlatent/crf.py @@ -12,8 +12,7 @@ from torch.types import Device from torchrua import CattedSequence, PackedSequence, RuaSequential from torchrua import ReductionIndices, accumulate_sizes, minor_sizes_to_ptr -from torchrua import reduce_catted_indices -from torchrua import reduce_packed_indices +from torchrua import reduce_catted_indices, reduce_packed_indices from torchlatent.abc import DistributionABC from torchlatent.nn.classifier import Classifier @@ -96,7 +95,7 @@ def crf_scores_packed_indices(sequence: PackedSequence, device: Device = None): acc_batch_sizes = F.pad(batch_sizes.cumsum(dim=0), [2, -1]) batch_ptr, token_ptr, token_sizes = minor_sizes_to_ptr( - token_sizes=batch_sizes, token_ptr=unsorted_indices, + sizes=batch_sizes, minor_ptr=unsorted_indices, ) prev = acc_batch_sizes[token_ptr + 0] + batch_ptr curr = acc_batch_sizes[token_ptr + 1] + batch_ptr From dabc6475d28fb1362efb8b1299c862d09f82d69a Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sun, 12 Mar 2023 22:39:48 +0900 Subject: [PATCH 058/102] Refactor: Use new segment functions --- torchlatent/semiring.py | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/torchlatent/semiring.py b/torchlatent/semiring.py index 6f0b2fb..aa097a2 100644 --- a/torchlatent/semiring.py +++ b/torchlatent/semiring.py @@ -1,8 +1,9 @@ import torch from torch import Tensor -from torchrua.reduction import reduce_sequence, ReductionIndices from torchlatent.functional import logsumexp, logaddexp +from torchrua import 
segment_sum, segment_prod, segment_max, segment_logsumexp +from torchrua.reduction import reduce_sequence, ReductionIndices __all__ = [ 'Semiring', 'ExceptionSemiring', @@ -79,11 +80,11 @@ def prod(cls, tensor: Tensor, dim: int, keepdim: bool = False) -> Tensor: @classmethod def segment_sum(cls, tensor: Tensor, sizes: Tensor) -> Tensor: - return torch.segment_reduce(tensor, reduce='sum', lengths=sizes, unsafe=True) + return segment_sum(tensor, segment_sizes=sizes) @classmethod def segment_prod(cls, tensor: Tensor, sizes: Tensor) -> Tensor: - raise NotImplementedError + return segment_prod(tensor, segment_sizes=sizes) class Log(Semiring): @@ -108,13 +109,11 @@ def prod(cls, tensor: Tensor, dim: int, keepdim: bool = False) -> Tensor: @classmethod def segment_sum(cls, tensor: Tensor, sizes: Tensor) -> Tensor: - m = torch.segment_reduce(tensor, reduce='max', lengths=sizes, unsafe=True).detach() - z = (tensor - torch.repeat_interleave(m, repeats=sizes)).exp() - return torch.segment_reduce(z, reduce='sum', lengths=sizes, unsafe=True).log() + m + return segment_logsumexp(tensor, segment_sizes=sizes) @classmethod def segment_prod(cls, tensor: Tensor, sizes: Tensor) -> Tensor: - return torch.segment_reduce(tensor, reduce='sum', lengths=sizes, unsafe=True) + return segment_sum(tensor, segment_sizes=sizes) class Max(Semiring): @@ -139,11 +138,11 @@ def prod(cls, tensor: Tensor, dim: int, keepdim: bool = False) -> Tensor: @classmethod def segment_sum(cls, tensor: Tensor, sizes: Tensor) -> Tensor: - return torch.segment_reduce(tensor, reduce='max', lengths=sizes, unsafe=True) + return segment_max(tensor, segment_sizes=sizes) @classmethod def segment_prod(cls, tensor: Tensor, sizes: Tensor) -> Tensor: - return torch.segment_reduce(tensor, reduce='sum', lengths=sizes, unsafe=True) + return segment_sum(tensor, segment_sizes=sizes) class ExceptionSemiring(Semiring): @@ -178,11 +177,11 @@ def prod(cls, tensor: Tensor, dim: int, keepdim: bool = False) -> Tensor: @classmethod def segment_sum(cls, tensor: Tensor, log_p: Tensor, log_q: Tensor, sizes: Tensor) -> Tensor: - return torch.segment_reduce((tensor - log_q) * log_p.exp(), reduce='sum', lengths=sizes, unsafe=True) + return segment_sum((tensor - log_q) * log_p.exp(), segment_sizes=sizes) @classmethod def segment_prod(cls, tensor: Tensor, sizes: Tensor) -> Tensor: - return torch.segment_reduce(tensor, reduce='sum', lengths=sizes, unsafe=True) + return segment_sum(tensor, segment_sizes=sizes) class Div(ExceptionSemiring): @@ -207,8 +206,8 @@ def prod(cls, tensor: Tensor, dim: int, keepdim: bool = False) -> Tensor: @classmethod def segment_sum(cls, tensor: Tensor, log_p: Tensor, log_q: Tensor, sizes: Tensor) -> Tensor: - return torch.segment_reduce((tensor - log_q + log_p) * log_p.exp(), reduce='sum', lengths=sizes, unsafe=True) + return segment_sum((tensor - log_q + log_p) * log_p.exp(), segment_sizes=sizes) @classmethod def segment_prod(cls, tensor: Tensor, sizes: Tensor) -> Tensor: - return torch.segment_reduce(tensor, reduce='sum', lengths=sizes, unsafe=True) + return segment_sum(tensor, segment_sizes=sizes) From db70af339999d8a55cf94b3427b608625c9cc88b Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Mon, 13 Mar 2023 23:29:07 +0900 Subject: [PATCH 059/102] Refactor: Rename --- tests/test_cky.py | 4 ++-- torchlatent/cky.py | 22 +++++++++++----------- torchlatent/crf.py | 8 ++++---- 3 files changed, 17 insertions(+), 17 deletions(-) diff --git a/tests/test_cky.py b/tests/test_cky.py index c8ec55c..c17759e 100644 --- a/tests/test_cky.py +++ 
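
As an aside on PATCH 058: the segmented logsumexp that Log.segment_sum now delegates to can be emulated with plain torch.split. The sketch below is only a statement of the expected semantics, assuming the segment_logsumexp(tensor, segment_sizes=...) contract visible in the hunk above (contiguous runs reduced along dim 0); torchrua's actual implementation is vectorized and max-shifted like the torch.segment_reduce code being removed here.

import torch

def naive_segment_logsumexp(tensor, segment_sizes):
    # one logsumexp per contiguous segment of length segment_sizes[i]
    return torch.stack([
        chunk.logsumexp(dim=0)
        for chunk in torch.split(tensor, segment_sizes.tolist(), dim=0)
    ], dim=0)
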
b/tests/test_cky.py @@ -5,7 +5,7 @@ from tests.assertion import assert_close, assert_grad_close from tests.strategy import sizes, BATCH_SIZE, TOKEN_SIZE, EMBEDDING_DIM, device, TINY_BATCH_SIZE -from torchlatent.cky import CkyDistribution, cky_partition_indices, CkyLayer +from torchlatent.cky import CkyDistribution, cky_partitions_indices, CkyLayer @given( @@ -56,7 +56,7 @@ def test_cky_log_partitions(token_sizes, num_tags): token_sizes = torch.tensor(token_sizes, device=device) excepted = TreeCRF(log_potentials=scores, lengths=token_sizes) - actual = CkyDistribution(emissions=scores, indices=cky_partition_indices(token_sizes=token_sizes, device=device)) + actual = CkyDistribution(emissions=scores, indices=cky_partitions_indices(token_sizes=token_sizes, device=device)) assert_close(actual=actual.log_partitions, expected=excepted.partition) assert_grad_close(actual=actual.log_partitions, expected=excepted.partition, inputs=scores, rtol=1e-5, atol=1e-5) diff --git a/torchlatent/cky.py b/torchlatent/cky.py index 0b19058..aba6aa4 100644 --- a/torchlatent/cky.py +++ b/torchlatent/cky.py @@ -9,13 +9,13 @@ from torch.distributions.utils import lazy_property from torch.nn.utils.rnn import PackedSequence from torch.types import Device -from torchrua import CattedSequence, transpose_sizes, pack_catted_sequence, cat_packed_indices, RuaSequential -from torchrua import major_sizes_to_ptr, accumulate_sizes -from torchrua import pad_packed_sequence, pad_catted_sequence from torchlatent.abc import DistributionABC from torchlatent.nn.classifier import BiaffineClassifier from torchlatent.semiring import Semiring, Log, Max +from torchrua import CattedSequence, transpose_sizes, pack_catted_sequence, cat_packed_indices, RuaSequential +from torchrua import major_sizes_to_ptr, accumulate_sizes +from torchrua import pad_packed_sequence, pad_catted_sequence Sequence = Union[CattedSequence, PackedSequence] @@ -63,7 +63,7 @@ class CkyIndices(NamedTuple): @torch.no_grad() -def cky_partition_indices(token_sizes: Tensor, device: Device = None): +def cky_partitions_indices(token_sizes: Tensor, device: Device = None): if device is None: device = token_sizes.device @@ -85,13 +85,13 @@ def cky_partition_indices(token_sizes: Tensor, device: Device = None): ) -def cky_partition(data: Tensor, indices: CkyIndices, *, semiring: Type[Semiring]) -> Tensor: +def cky_partitions(data: Tensor, indices: CkyIndices, *, semiring: Type[Semiring]) -> Tensor: token_size, cache_size, (src1, src2), tgt = indices size = (token_size, cache_size, *data.size()[3:]) - tensor0 = torch.full(size, fill_value=semiring.zero, requires_grad=False) - tensor1 = torch.full(size, fill_value=semiring.zero, requires_grad=False) - tensor2 = torch.full(size, fill_value=semiring.zero, requires_grad=False) + tensor0 = torch.full(size, fill_value=semiring.zero, device=data.device, requires_grad=False) + tensor1 = torch.full(size, fill_value=semiring.zero, device=data.device, requires_grad=False) + tensor2 = torch.full(size, fill_value=semiring.zero, device=data.device, requires_grad=False) tensor0[src1] = data[src2] tensor1[0, :] = tensor2[-1, :] = tensor0[0, :] @@ -122,11 +122,11 @@ def log_scores(self, sequence: Sequence) -> Tensor: @lazy_property def log_partitions(self) -> Tensor: - return cky_partition(data=Log.sum(self.emissions, dim=-1), indices=self.indices, semiring=Log) + return cky_partitions(data=Log.sum(self.emissions, dim=-1), indices=self.indices, semiring=Log) @lazy_property def max(self) -> Tensor: - return 
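
For readers of PATCH 059: cky_partitions computes the same quantity as torch_struct's TreeCRF.partition that the test above compares against, i.e. the inside (log-partition) score over binary bracketings, after the label dimension has been summed out via Log.sum(self.emissions, dim=-1). Below is a deliberately naive, unbatched O(n^3) reference of that recursion, for documentation only; the real code fills the diagonal-indexed tensor0/tensor1/tensor2 buffers instead.

import torch

def naive_inside(score):
    # score[i, j]: log-potential of span (i, j), labels already summed out
    n = score.size(0)
    beta = torch.full((n, n), float('-inf'))
    for i in range(n):
        beta[i, i] = score[i, i]  # width-1 spans
    for w in range(1, n):
        for i in range(n - w):
            j = i + w
            # logsumexp over every split point k, plus the span's own score
            splits = torch.stack([beta[i, k] + beta[k + 1, j] for k in range(i, j)])
            beta[i, j] = splits.logsumexp(dim=0) + score[i, j]
    return beta[0, n - 1]  # log-partition over all binary trees
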
cky_partition(data=Max.sum(self.emissions, dim=-1), indices=self.indices, semiring=Max) + return cky_partitions(data=Max.sum(self.emissions, dim=-1), indices=self.indices, semiring=Max) @lazy_property def argmax(self) -> Tensor: @@ -172,7 +172,7 @@ def forward(self, sequence: Sequence, indices: CkyIndices = None) -> CkyDistribu raise KeyError(f'type {type(sequence)} is not supported') if indices is None: - indices = cky_partition_indices(token_sizes=token_sizes, device=features.device) + indices = cky_partitions_indices(token_sizes=token_sizes, device=features.device) return CkyDistribution( emissions=self.forward_scores(features=features), diff --git a/torchlatent/crf.py b/torchlatent/crf.py index 8325d0b..f04c111 100644 --- a/torchlatent/crf.py +++ b/torchlatent/crf.py @@ -147,8 +147,8 @@ def crf_indices(emissions: Sequence) -> CrfIndices: ) -def crf_partition(emissions: Tensor, transitions: Tuple[Tensor, Tensor, Tensor], - indices: CrfIndices, semiring: Type[Semiring]): +def crf_partitions(emissions: Tensor, transitions: Tuple[Tensor, Tensor, Tensor], + indices: CrfIndices, semiring: Type[Semiring]): head, _, _, _, _, unsorted_indices, indices = indices transitions, head_transitions, last_transitions = transitions @@ -185,7 +185,7 @@ def log_scores(self, targets: Sequence) -> Tensor: @lazy_property def log_partitions(self) -> Tensor: - return crf_partition( + return crf_partitions( emissions=self.emissions, transitions=self.transitions, indices=self.indices, @@ -194,7 +194,7 @@ def log_partitions(self) -> Tensor: @lazy_property def max(self) -> Tensor: - return crf_partition( + return crf_partitions( emissions=self.emissions, transitions=self.transitions, indices=self.indices, From 80c4727da24f230236c23d3a045d66241ea21654 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Mon, 13 Mar 2023 23:44:10 +0900 Subject: [PATCH 060/102] Refactor: Remove num_conjugates from BiaffineClassifier --- torchlatent/nn/classifier.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/torchlatent/nn/classifier.py b/torchlatent/nn/classifier.py index d355ab7..92ee02b 100644 --- a/torchlatent/nn/classifier.py +++ b/torchlatent/nn/classifier.py @@ -37,19 +37,18 @@ def forward(self, tensor: Tensor) -> Tensor: class BiaffineClassifier(nn.Module): - def __init__(self, bias: bool = False, *, num_conjugates: int, + def __init__(self, bias: bool = False, *, in_features1: int, in_features2: int, out_features: int) -> None: super(BiaffineClassifier, self).__init__() self.in_features1 = in_features1 self.in_features2 = in_features2 self.out_features = out_features - self.num_conjugates = num_conjugates - self.weight0 = nn.Parameter(torch.empty((num_conjugates, out_features, in_features1, in_features2))) - self.weight1 = nn.Parameter(torch.empty((num_conjugates, out_features, in_features1))) - self.weight2 = nn.Parameter(torch.empty((num_conjugates, out_features, in_features2))) - self.bias = nn.Parameter(torch.empty((num_conjugates, out_features))) if bias else 0 + self.weight0 = nn.Parameter(torch.empty((out_features, in_features1, in_features2))) + self.weight1 = nn.Parameter(torch.empty((out_features, in_features1))) + self.weight2 = nn.Parameter(torch.empty((out_features, in_features2))) + self.bias = nn.Parameter(torch.empty((out_features,))) if bias else 0 self.reset_parameters() @@ -66,13 +65,12 @@ def extra_repr(self) -> str: f'in_features1={self.in_features1}', f'in_features2={self.in_features2}', f'out_features={self.out_features}', - f'num_conjugates={self.num_conjugates}', 
f'bias={torch.is_tensor(self.bias)}', ]) def forward(self, tensor1: Tensor, tensor2: Tensor) -> Tensor: - tensor0 = torch.einsum('nzxy,...nx,...ny->...nz', self.weight0, tensor1, tensor2) - tensor1 = torch.einsum('nzx,...nx->...nz', self.weight1, tensor1) - tensor2 = torch.einsum('nzy,...ny->...nz', self.weight2, tensor2) + tensor0 = torch.einsum('zxy,...x,...y->...z', self.weight0, tensor1, tensor2) + tensor1 = torch.einsum('zx,...x->...z', self.weight1, tensor1) + tensor2 = torch.einsum('zy,...y->...z', self.weight2, tensor2) return tensor0 + tensor1 + tensor2 + self.bias From c69277b1654953f98648946703a2c130803ac73a Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Mon, 13 Mar 2023 23:54:34 +0900 Subject: [PATCH 061/102] Refactor: Use different data format --- torchlatent/cky.py | 72 ++++++++++++++++++++-------------------------- 1 file changed, 31 insertions(+), 41 deletions(-) diff --git a/torchlatent/cky.py b/torchlatent/cky.py index aba6aa4..66a2d44 100644 --- a/torchlatent/cky.py +++ b/torchlatent/cky.py @@ -13,9 +13,9 @@ from torchlatent.abc import DistributionABC from torchlatent.nn.classifier import BiaffineClassifier from torchlatent.semiring import Semiring, Log, Max -from torchrua import CattedSequence, transpose_sizes, pack_catted_sequence, cat_packed_indices, RuaSequential +from torchrua import CattedSequence, pack_catted_sequence, cat_packed_indices, RuaSequential from torchrua import major_sizes_to_ptr, accumulate_sizes -from torchrua import pad_packed_sequence, pad_catted_sequence +from torchrua import pad_sequence, pad_indices Sequence = Union[CattedSequence, PackedSequence] @@ -93,6 +93,11 @@ def cky_partitions(data: Tensor, indices: CkyIndices, *, semiring: Type[Semiring tensor1 = torch.full(size, fill_value=semiring.zero, device=data.device, requires_grad=False) tensor2 = torch.full(size, fill_value=semiring.zero, device=data.device, requires_grad=False) + print(f'src1 => {src1}') + print(f'src2 => {src2}') + print(f'tensor0.size() => {tensor0.size()}') + print(f'tensor1.size() => {tensor1.size()}') + tensor0[src1] = data[src2] tensor1[0, :] = tensor2[-1, :] = tensor0[0, :] @@ -161,65 +166,52 @@ def reset_parameters(self) -> None: def forward_scores(self, features: Tensor, *args, **kwargs) -> Tensor: raise NotImplementedError - def forward(self, sequence: Sequence, indices: CkyIndices = None) -> CkyDistribution: - if isinstance(sequence, CattedSequence): - features, token_sizes = pad_catted_sequence(sequence, batch_first=True) - elif isinstance(sequence, PackedSequence): - features, token_sizes = pad_packed_sequence(sequence, batch_first=True) - elif isinstance(sequence, tuple) and torch.tensor(sequence[0]) and torch.is_tensor(sequence[1]): - features, token_sizes = sequence - else: - raise KeyError(f'type {type(sequence)} is not supported') + def forward(self, emissions: Sequence, indices: CkyIndices = None) -> CkyDistribution: + _, _, token_sizes = pad_indices(emissions, batch_first=True) if indices is None: - indices = cky_partitions_indices(token_sizes=token_sizes, device=features.device) + indices = cky_partitions_indices(token_sizes=token_sizes, device=emissions.data.device) - return CkyDistribution( - emissions=self.forward_scores(features=features), - indices=indices, - ) + return CkyDistribution(emissions=emissions.data, indices=indices) - def fit(self, sequence: Sequence, targets: Sequence, indices: CkyIndices = None) -> Tensor: - dist = self.forward(sequence=sequence, indices=indices) + def fit(self, emissions: Sequence, targets: Sequence, indices: 
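
The un-conjugated biaffine scorer left by PATCH 060 computes, per target label z, s_z(x, y) = x^T W0[z] y + W1[z] x + W2[z] y + b[z], which is exactly what the three einsums express. A minimal usage sketch with hypothetical feature sizes follows; note that reset_parameters zeroes all tables, so a fresh module scores everything as zero until trained.

import torch
from torchlatent.nn.classifier import BiaffineClassifier

classifier = BiaffineClassifier(in_features1=16, in_features2=16, out_features=7)

x = torch.randn(2, 5, 16)  # (batch, tokens, features), broadcast over '...'
y = torch.randn(2, 5, 16)
scores = classifier(x, y)  # -> (2, 5, 7), one score per target label
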
CkyIndices = None) -> Tensor: + dist = self.forward(emissions=emissions, indices=indices) return dist.log_partitions - dist.log_scores(sequence=targets) - def decode(self, sequence: Sequence, indices: CkyIndices = None) -> Sequence: - dist = self.forward(sequence=sequence, indices=indices) + def decode(self, emissions: Sequence, indices: CkyIndices = None) -> Sequence: + dist = self.forward(emissions=emissions, indices=indices) + _, _, token_sizes = pad_indices(emissions, batch_first=True) - if isinstance(sequence, CattedSequence): - return CattedSequence(data=dist.argmax, token_sizes=sequence.token_sizes * 2 - 1) + if isinstance(emissions, CattedSequence): + sequence = CattedSequence(data=dist.argmax, token_sizes=token_sizes * 2 - 1) + return sequence - if isinstance(sequence, PackedSequence): - token_sizes = transpose_sizes(sizes=sequence.batch_sizes) - token_sizes = token_sizes[sequence.unsorted_indices] * 2 - 1 - return pack_catted_sequence(CattedSequence(data=dist.argmax, token_sizes=token_sizes)) + if isinstance(emissions, PackedSequence): + sequence = CattedSequence(data=dist.argmax, token_sizes=token_sizes * 2 - 1) + return pack_catted_sequence(sequence) - raise KeyError(f'type {type(sequence)} is not supported') + raise KeyError(f'type {type(emissions)} is not supported') class CkyLayer(CkyLayerABC): - def __init__(self, num_targets: int, num_conjugates: int) -> None: + def __init__(self, num_targets: int) -> None: super(CkyLayer, self).__init__() self.num_targets = num_targets - self.num_conjugates = num_conjugates def extra_repr(self) -> str: return ', '.join([ f'num_targets={self.num_targets}', - f'num_conjugates={self.num_conjugates}', ]) class CkyDecoder(nn.Module): - def __init__(self, in_features: int, hidden_features: int, - num_targets: int, num_conjugates: int, dropout: float) -> None: + def __init__(self, in_features: int, hidden_features: int, num_targets: int, dropout: float) -> None: super(CkyDecoder, self).__init__() self.in_features = in_features self.hidden_features = hidden_features self.num_targets = num_targets - self.num_conjugates = num_conjugates self.ffn1 = RuaSequential( nn.Linear(in_features, hidden_features, bias=True), @@ -232,21 +224,19 @@ def __init__(self, in_features: int, hidden_features: int, nn.Dropout(dropout), ) self.classifier = BiaffineClassifier( - num_conjugates=num_conjugates, in_features1=hidden_features, in_features2=hidden_features, out_features=num_targets, bias=False, ) - self.cky = CkyLayer( - num_targets=num_targets, - num_conjugates=num_conjugates, - ) + self.cky = CkyLayer(num_targets=num_targets) def forward(self, sequence: Sequence) -> CkyDistribution: - features1, _ = self.ffn1(sequence) - features2, _ = self.ffn2(sequence) + features, _ = pad_sequence(sequence, batch_first=True) + + features1 = self.ffn1(features)[:, :, None, :] + features2 = self.ffn2(features)[:, None, :, :] emissions = self.classifier(features1, features2) - return self.cky(emissions) + return self.cky(sequence._replace(data=emissions)) From 9302f94e02ce81e4529da2f0798151d9451e2fb7 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Tue, 14 Mar 2023 00:18:09 +0900 Subject: [PATCH 062/102] Fix: Resolve test_cky_catted_max --- tests/test_cky.py | 50 +++++++++++++++++++++++++++------------------- torchlatent/cky.py | 8 ++++---- 2 files changed, 33 insertions(+), 25 deletions(-) diff --git a/tests/test_cky.py b/tests/test_cky.py index c17759e..0a8a4fc 100644 --- a/tests/test_cky.py +++ b/tests/test_cky.py @@ -5,43 +5,51 @@ from tests.assertion import assert_close, 
assert_grad_close from tests.strategy import sizes, BATCH_SIZE, TOKEN_SIZE, EMBEDDING_DIM, device, TINY_BATCH_SIZE -from torchlatent.cky import CkyDistribution, cky_partitions_indices, CkyLayer +from torchlatent.cky import CkyDistribution, cky_partitions_indices, CkyLayer, CkyDecoder @given( token_sizes=sizes(TINY_BATCH_SIZE, TOKEN_SIZE), embedding_dim=sizes(EMBEDDING_DIM), num_tags=sizes(TOKEN_SIZE), - bias=st.booleans(), + dropout=st.floats(0, 1), ) -def test_cky_catted_max(token_sizes, embedding_dim, num_tags, bias): +def test_cky_catted_max(token_sizes, embedding_dim, num_tags, dropout): sequence = cat_sequence([ torch.randn((token_size, embedding_dim), requires_grad=True, device=device) for token_size in token_sizes ]) - decoder = CkyLayer(in_features=embedding_dim, out_features=num_tags, bias=bias).to(device=device) - cky = decoder.forward(sequence=sequence) - - assert_close(actual=cky.max, expected=cky.log_scores(decoder.decode(sequence=sequence))) - - -@given( - token_sizes=sizes(TINY_BATCH_SIZE, TOKEN_SIZE), - embedding_dim=sizes(EMBEDDING_DIM), - num_tags=sizes(TOKEN_SIZE), - bias=st.booleans(), -) -def test_cky_packed_max(token_sizes, embedding_dim, num_tags, bias): - sequence = pack_sequence([ - torch.randn((token_size, embedding_dim), requires_grad=True, device=device) + targets = cat_sequence([ + torch.empty((token_size * 2 - 1,), dtype=torch.long, device=device) for token_size in token_sizes ]) - decoder = CkyLayer(in_features=embedding_dim, out_features=num_tags, bias=bias).to(device=device) - cky = decoder.forward(sequence=sequence) + decoder = CkyDecoder( + in_features=embedding_dim, hidden_features=embedding_dim, + num_targets=num_tags, dropout=dropout, + ).to(device=device) + dist = decoder(sequence) + + assert_close(actual=dist.max, expected=dist.log_scores(targets=targets._replace(data=dist.argmax))) + - assert_close(actual=cky.max, expected=cky.log_scores(decoder.decode(sequence=sequence))) +# @given( +# token_sizes=sizes(TINY_BATCH_SIZE, TOKEN_SIZE), +# embedding_dim=sizes(EMBEDDING_DIM), +# num_tags=sizes(TOKEN_SIZE), +# bias=st.booleans(), +# ) +# def test_cky_packed_max(token_sizes, embedding_dim, num_tags, bias): +# sequence = pack_sequence([ +# torch.randn((token_size, embedding_dim), requires_grad=True, device=device) +# for token_size in token_sizes +# ]) +# +# decoder = CkyLayer(in_features=embedding_dim, out_features=num_tags, bias=bias).to(device=device) +# cky = decoder.forward(sequence=sequence) +# +# assert_close(actual=cky.max, expected=cky.log_scores(decoder.decode(sequence=sequence))) @given( diff --git a/torchlatent/cky.py b/torchlatent/cky.py index 66a2d44..b47f21d 100644 --- a/torchlatent/cky.py +++ b/torchlatent/cky.py @@ -117,9 +117,9 @@ def __init__(self, emissions: Tensor, indices: CkyIndices) -> None: self.emissions = emissions self.indices = indices - def log_scores(self, sequence: Sequence) -> Tensor: - indices, batch_ptr, sizes = cky_scores_indices(sequence) - data = sequence.data[indices] + def log_scores(self, targets: Sequence) -> Tensor: + indices, batch_ptr, sizes = cky_scores_indices(targets) + data = targets.data[indices] return Log.segment_prod( tensor=self.emissions[batch_ptr, data[..., 0], data[..., 1], data[..., 2]], sizes=sizes, @@ -176,7 +176,7 @@ def forward(self, emissions: Sequence, indices: CkyIndices = None) -> CkyDistrib def fit(self, emissions: Sequence, targets: Sequence, indices: CkyIndices = None) -> Tensor: dist = self.forward(emissions=emissions, indices=indices) - return dist.log_partitions - 
dist.log_scores(sequence=targets) + return dist.log_partitions - dist.log_scores(targets=targets) def decode(self, emissions: Sequence, indices: CkyIndices = None) -> Sequence: dist = self.forward(emissions=emissions, indices=indices) From f7659805b26997da7429f656346a71eff74cfe02 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sun, 28 May 2023 16:35:50 +0900 Subject: [PATCH 063/102] Refactor: Remove benchmark --- benchmark/__init__.py | 0 benchmark/__main__.py | 9 ---- benchmark/crf.py | 84 ------------------------------------ benchmark/meter.py | 23 ---------- tests/test_cky.py | 6 +-- tests/test_crf.py | 6 +-- tests/test_functional.py | 2 +- third/crf.py | 5 ++- torchlatent/abc.py | 1 + torchlatent/cky.py | 18 +++----- torchlatent/crf.py | 17 +++----- torchlatent/nn/classifier.py | 3 +- torchlatent/semiring.py | 4 +- 13 files changed, 25 insertions(+), 153 deletions(-) delete mode 100644 benchmark/__init__.py delete mode 100644 benchmark/__main__.py delete mode 100644 benchmark/crf.py delete mode 100644 benchmark/meter.py diff --git a/benchmark/__init__.py b/benchmark/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/benchmark/__main__.py b/benchmark/__main__.py deleted file mode 100644 index 351d087..0000000 --- a/benchmark/__main__.py +++ /dev/null @@ -1,9 +0,0 @@ -from aku import Aku - -from benchmark.crf import benchmark_crf - -aku = Aku() - -aku.option(benchmark_crf) - -aku.run() diff --git a/benchmark/crf.py b/benchmark/crf.py deleted file mode 100644 index 971e936..0000000 --- a/benchmark/crf.py +++ /dev/null @@ -1,84 +0,0 @@ -import torch -from torchrua import pack_sequence, cat_sequence -from tqdm import tqdm - -from benchmark.meter import TimeMeter -from third.crf import CrfDecoder as ThirdPartyCrfDecoder -from torchlatent.crf import CrfLayer - - -def benchmark_crf(num_tags: int = 50, num_conjugates: int = 1, num_runs: int = 100, - batch_size: int = 32, max_token_size: int = 512): - jit1, fwd1, bwd1, dec1, = TimeMeter(), TimeMeter(), TimeMeter(), TimeMeter() - jit2, fwd2, bwd2, dec2, = TimeMeter(), TimeMeter(), TimeMeter(), TimeMeter() - jit3, fwd3, bwd3, dec3, = TimeMeter(), TimeMeter(), TimeMeter(), TimeMeter() - - if torch.cuda.is_available(): - device = torch.device('cuda:0') - else: - device = torch.device('cpu') - print(f'device => {device}') - - decoder = CrfLayer(num_targets=num_tags, num_conjugates=num_conjugates).to(device=device) - print(f'decoder => {decoder}') - - third_decoder = ThirdPartyCrfDecoder(num_tags=num_tags, num_conjugates=num_conjugates).to(device=device) - print(f'third_decoder => {third_decoder}') - - for _ in tqdm(range(num_runs)): - token_sizes = torch.randint(1, max_token_size + 1, (batch_size,), device=device).detach().cpu().tolist() - - catted_emissions = cat_sequence([ - torch.randn((token_size, num_conjugates, num_tags), device=device, requires_grad=True) - for token_size in token_sizes - ]) - catted_tags = cat_sequence([ - torch.randint(0, num_tags, (token_size, num_conjugates), device=device) - for token_size in token_sizes - ]) - - packed_emissions = pack_sequence([ - torch.randn((token_size, num_conjugates, num_tags), device=device, requires_grad=True) - for token_size in token_sizes - ]) - packed_tags = pack_sequence([ - torch.randint(0, num_tags, (token_size, num_conjugates), device=device) - for token_size in token_sizes - ]) - - with jit1: - indices = decoder.compile_indices(emissions=packed_emissions, tags=packed_tags) - - with fwd1: - loss = decoder.fit(emissions=packed_emissions, tags=packed_tags, 
indices=indices).neg().mean() - - with bwd1: - _, torch.autograd.grad(loss, packed_emissions.data, torch.randn_like(loss)) - - with dec1: - _ = decoder.decode(emissions=packed_emissions, indices=indices) - - with jit2: - indices = decoder.compile_indices(emissions=catted_emissions, tags=catted_tags) - - with fwd2: - loss = decoder.fit(emissions=catted_emissions, tags=catted_tags, indices=indices).neg().mean() - - with bwd2: - _, torch.autograd.grad(loss, catted_emissions.data, torch.randn_like(loss)) - - with dec2: - _ = decoder.decode(emissions=catted_emissions, indices=indices) - - with fwd3: - loss = third_decoder.fit(emissions=packed_emissions, tags=packed_tags).neg().mean() - - with bwd3: - _, torch.autograd.grad(loss, packed_emissions.data, torch.randn_like(loss)) - - with dec3: - _ = third_decoder.decode(emissions=packed_emissions) - - print(f'PackedLatent ({jit1.merit + fwd1.merit + bwd1.merit:.6f}) => {jit1} {fwd1} {bwd1} {dec1}') - print(f'CattedLatent ({jit2.merit + fwd2.merit + bwd2.merit:.6f}) => {jit2} {fwd2} {bwd2} {dec2}') - print(f'Third ({jit3.merit + fwd3.merit + bwd3.merit:.6f}) => {jit3} {fwd3} {bwd3} {dec3}') diff --git a/benchmark/meter.py b/benchmark/meter.py deleted file mode 100644 index 64e1c6e..0000000 --- a/benchmark/meter.py +++ /dev/null @@ -1,23 +0,0 @@ -from datetime import datetime - - -class TimeMeter(object): - def __init__(self) -> None: - super(TimeMeter, self).__init__() - - self.seconds = 0 - self.counts = 0 - - def __enter__(self): - self.start_tm = datetime.now() - - def __exit__(self, exc_type, exc_val, exc_tb): - self.seconds += (datetime.now() - self.start_tm).total_seconds() - self.counts += 1 - - @property - def merit(self) -> float: - return self.seconds / max(1, self.counts) - - def __repr__(self) -> str: - return f'{self.merit :.6f}' diff --git a/tests/test_cky.py b/tests/test_cky.py index 0a8a4fc..e545901 100644 --- a/tests/test_cky.py +++ b/tests/test_cky.py @@ -1,11 +1,11 @@ import torch from hypothesis import given, strategies as st from torch_struct import TreeCRF -from torchrua import pack_sequence, cat_sequence from tests.assertion import assert_close, assert_grad_close -from tests.strategy import sizes, BATCH_SIZE, TOKEN_SIZE, EMBEDDING_DIM, device, TINY_BATCH_SIZE -from torchlatent.cky import CkyDistribution, cky_partitions_indices, CkyLayer, CkyDecoder +from tests.strategy import BATCH_SIZE, device, EMBEDDING_DIM, sizes, TINY_BATCH_SIZE, TOKEN_SIZE +from torchlatent.cky import cky_partitions_indices, CkyDecoder, CkyDistribution +from torchrua import cat_sequence @given( diff --git a/tests/test_crf.py b/tests/test_crf.py index ed4becf..2d72dab 100644 --- a/tests/test_crf.py +++ b/tests/test_crf.py @@ -1,12 +1,12 @@ import torch import torchcrf from hypothesis import given -from torchrua import cat_sequence, pad_catted_indices, pack_catted_indices -from torchrua import pad_sequence, pad_packed_indices, pack_sequence from tests.assertion import assert_close, assert_grad_close -from tests.strategy import device, sizes, TOKEN_SIZE, TINY_BATCH_SIZE, NUM_CONJUGATES, TINY_TOKEN_SIZE +from tests.strategy import device, NUM_CONJUGATES, sizes, TINY_BATCH_SIZE, TINY_TOKEN_SIZE, TOKEN_SIZE from torchlatent.crf import CrfLayer +from torchrua import cat_sequence, pack_catted_indices, pack_sequence, pad_catted_indices, pad_packed_indices, \ + pad_sequence @given( diff --git a/tests/test_functional.py b/tests/test_functional.py index 5d9ce3a..f6610d6 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -2,7 +2,7 @@ from 
hypothesis import given, strategies as st from tests.assertion import assert_close, assert_grad_close -from tests.strategy import device, sizes, TINY_TOKEN_SIZE, TINY_BATCH_SIZE +from tests.strategy import device, sizes, TINY_BATCH_SIZE, TINY_TOKEN_SIZE from torchlatent.functional import logaddexp, logsumexp diff --git a/third/crf.py b/third/crf.py index baef860..282f7a9 100644 --- a/third/crf.py +++ b/third/crf.py @@ -1,9 +1,10 @@ import torch import torchcrf -from torch import Tensor, nn +from torch import nn, Tensor from torch.nn.utils.rnn import PackedSequence from torch.types import Device -from torchrua import pad_catted_indices, pad_packed_sequence, pack_sequence + +from torchrua import pack_sequence, pad_catted_indices, pad_packed_sequence @torch.no_grad() diff --git a/torchlatent/abc.py b/torchlatent/abc.py index 0f7767f..169b9d1 100644 --- a/torchlatent/abc.py +++ b/torchlatent/abc.py @@ -7,6 +7,7 @@ from torch.distributions import Distribution from torch.distributions.utils import lazy_property from torch.nn.utils.rnn import PackedSequence + from torchrua import CattedSequence Sequence = Union[CattedSequence, PackedSequence] diff --git a/torchlatent/cky.py b/torchlatent/cky.py index b47f21d..3d199e3 100644 --- a/torchlatent/cky.py +++ b/torchlatent/cky.py @@ -1,21 +1,18 @@ from abc import ABCMeta from functools import singledispatch -from typing import Tuple, NamedTuple, Union -from typing import Type +from typing import NamedTuple, Tuple, Type, Union import torch -from torch import Tensor -from torch import nn +from torch import nn, Tensor from torch.distributions.utils import lazy_property from torch.nn.utils.rnn import PackedSequence from torch.types import Device from torchlatent.abc import DistributionABC from torchlatent.nn.classifier import BiaffineClassifier -from torchlatent.semiring import Semiring, Log, Max -from torchrua import CattedSequence, pack_catted_sequence, cat_packed_indices, RuaSequential -from torchrua import major_sizes_to_ptr, accumulate_sizes -from torchrua import pad_sequence, pad_indices +from torchlatent.semiring import Log, Max, Semiring +from torchrua import accumulate_sizes, cat_packed_indices, CattedSequence, major_sizes_to_ptr, pack_catted_sequence, \ + pad_indices, pad_sequence, RuaSequential Sequence = Union[CattedSequence, PackedSequence] @@ -93,11 +90,6 @@ def cky_partitions(data: Tensor, indices: CkyIndices, *, semiring: Type[Semiring tensor1 = torch.full(size, fill_value=semiring.zero, device=data.device, requires_grad=False) tensor2 = torch.full(size, fill_value=semiring.zero, device=data.device, requires_grad=False) - print(f'src1 => {src1}') - print(f'src2 => {src2}') - print(f'tensor0.size() => {tensor0.size()}') - print(f'tensor1.size() => {tensor1.size()}') - tensor0[src1] = data[src2] tensor1[0, :] = tensor2[-1, :] = tensor0[0, :] diff --git a/torchlatent/crf.py b/torchlatent/crf.py index f04c111..c26f9f2 100644 --- a/torchlatent/crf.py +++ b/torchlatent/crf.py @@ -1,22 +1,17 @@ from functools import singledispatch -from typing import NamedTuple, Union -from typing import Tuple -from typing import Type +from typing import NamedTuple, Tuple, Type, Union import torch -from torch import Tensor -from torch import nn +from torch import nn, Tensor from torch.distributions.utils import lazy_property -from torch.nn import functional as F -from torch.nn import init +from torch.nn import functional as F, init from torch.types import Device -from torchrua import CattedSequence, PackedSequence, RuaSequential -from torchrua import 
ReductionIndices, accumulate_sizes, minor_sizes_to_ptr -from torchrua import reduce_catted_indices, reduce_packed_indices from torchlatent.abc import DistributionABC from torchlatent.nn.classifier import Classifier -from torchlatent.semiring import Semiring, Log, Max +from torchlatent.semiring import Log, Max, Semiring +from torchrua import accumulate_sizes, CattedSequence, minor_sizes_to_ptr, PackedSequence, reduce_catted_indices, \ + reduce_packed_indices, ReductionIndices, RuaSequential Sequence = Union[CattedSequence, PackedSequence] diff --git a/torchlatent/nn/classifier.py b/torchlatent/nn/classifier.py index 92ee02b..4efe3b0 100644 --- a/torchlatent/nn/classifier.py +++ b/torchlatent/nn/classifier.py @@ -1,6 +1,5 @@ import torch -from torch import Tensor -from torch import nn +from torch import nn, Tensor from torch.nn import init diff --git a/torchlatent/semiring.py b/torchlatent/semiring.py index aa097a2..630b61b 100644 --- a/torchlatent/semiring.py +++ b/torchlatent/semiring.py @@ -1,8 +1,8 @@ import torch from torch import Tensor -from torchlatent.functional import logsumexp, logaddexp -from torchrua import segment_sum, segment_prod, segment_max, segment_logsumexp +from torchlatent.functional import logaddexp, logsumexp +from torchrua import segment_logsumexp, segment_max, segment_prod, segment_sum from torchrua.reduction import reduce_sequence, ReductionIndices __all__ = [ From 78481a1f366355867f5b8fbb55e87baeaff229ec Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Tue, 11 Jul 2023 14:07:19 +0900 Subject: [PATCH 064/102] Style: PEP8 them all --- .github/workflows/publish-package.yml | 28 ++++++++++++++++++++++++ .github/workflows/python-publish.yml | 31 --------------------------- .github/workflows/unit-tests.yml | 11 +++++----- README.md | 14 ++++++------ setup.py | 3 ++- tests/test_cky.py | 19 +++++++++++----- tests/test_crf.py | 20 ++++++++++++----- tests/test_functional.py | 16 +++++++++----- third/crf.py | 8 ++++--- torchlatent/abc.py | 1 - torchlatent/cky.py | 22 ++++++++++++++----- torchlatent/crf.py | 25 +++++++++++++++------ torchlatent/nn/classifier.py | 3 ++- torchlatent/semiring.py | 13 +++++++---- 14 files changed, 135 insertions(+), 79 deletions(-) create mode 100644 .github/workflows/publish-package.yml delete mode 100644 .github/workflows/python-publish.yml diff --git a/.github/workflows/publish-package.yml b/.github/workflows/publish-package.yml new file mode 100644 index 0000000..c4e0338 --- /dev/null +++ b/.github/workflows/publish-package.yml @@ -0,0 +1,28 @@ +name: publish package + +on: + release: + types: [ created ] + +jobs: + deploy: + + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v3 + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.8' + - name: Install dependencies + run: | + python -m pip install pip setuptools wheel --upgrade + python -m pip install twine + - name: Build and publish + env: + TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} + TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} + run: | + python setup.py sdist bdist_wheel + twine upload dist/* diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml deleted file mode 100644 index 9993084..0000000 --- a/.github/workflows/python-publish.yml +++ /dev/null @@ -1,31 +0,0 @@ -# This workflows will upload a Python Package using Twine when a release is created -# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries - 
-name: Upload Python Package - -on: - release: - types: [created] - -jobs: - deploy: - - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v2 - - name: Set up Python - uses: actions/setup-python@v2 - with: - python-version: '3.8' - - name: Install dependencies - run: | - python -m pip install --upgrade pip - python -m pip install setuptools wheel twine - - name: Build and publish - env: - TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} - TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} - run: | - python setup.py sdist bdist_wheel - twine upload dist/* diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index db1e5bc..1085139 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -1,6 +1,6 @@ -name: Unit Tests +name: unit tests -on: [push] +on: [ push ] jobs: build: @@ -8,16 +8,17 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - name: Set up Python - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 with: python-version: '3.8' - name: Install dependencies run: | - python -m pip install --upgrade pip + python -m pip install pip setuptools wheel --upgrade python -m pip install torch python -m pip install -e '.[dev]' + python -m pip install git+https://github.com/speedcell4/torchrua.git@develop --force-reinstall --no-deps - name: Test with pytest run: | python -m pytest tests \ No newline at end of file diff --git a/README.md b/README.md index 363740d..21edae0 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ ## Requirements - Python 3.8 -- PyTorch 1.10.2 +- PyTorch 2.0 ## Installation @@ -34,15 +34,15 @@ num_conjugates = 1 decoder = CrfLayer(num_targets=num_tags, num_conjugates=num_conjugates) emissions = pack_sequence([ - torch.randn((5, num_conjugates, num_tags), requires_grad=True), - torch.randn((2, num_conjugates, num_tags), requires_grad=True), - torch.randn((3, num_conjugates, num_tags), requires_grad=True), + torch.randn((5, num_conjugates, num_tags), requires_grad=True), + torch.randn((2, num_conjugates, num_tags), requires_grad=True), + torch.randn((3, num_conjugates, num_tags), requires_grad=True), ]) tags = pack_sequence([ - torch.randint(0, num_tags, (5, num_conjugates)), - torch.randint(0, num_tags, (2, num_conjugates)), - torch.randint(0, num_tags, (3, num_conjugates)), + torch.randint(0, num_tags, (5, num_conjugates)), + torch.randint(0, num_tags, (2, num_conjugates)), + torch.randint(0, num_tags, (3, num_conjugates)), ]) print(decoder.fit(emissions=emissions, tags=tags)) diff --git a/setup.py b/setup.py index 6f8a71b..5dd691b 100644 --- a/setup.py +++ b/setup.py @@ -1,4 +1,5 @@ -from setuptools import setup, find_packages +from setuptools import find_packages +from setuptools import setup name = 'torchlatent' diff --git a/tests/test_cky.py b/tests/test_cky.py index e545901..b223291 100644 --- a/tests/test_cky.py +++ b/tests/test_cky.py @@ -1,12 +1,21 @@ import torch -from hypothesis import given, strategies as st +from hypothesis import given +from hypothesis import strategies as st from torch_struct import TreeCRF - -from tests.assertion import assert_close, assert_grad_close -from tests.strategy import BATCH_SIZE, device, EMBEDDING_DIM, sizes, TINY_BATCH_SIZE, TOKEN_SIZE -from torchlatent.cky import cky_partitions_indices, CkyDecoder, CkyDistribution from torchrua import cat_sequence +from tests.assertion import assert_close +from tests.assertion import assert_grad_close +from tests.strategy import BATCH_SIZE +from tests.strategy import 
device +from tests.strategy import EMBEDDING_DIM +from tests.strategy import sizes +from tests.strategy import TINY_BATCH_SIZE +from tests.strategy import TOKEN_SIZE +from torchlatent.cky import cky_partitions_indices +from torchlatent.cky import CkyDecoder +from torchlatent.cky import CkyDistribution + @given( token_sizes=sizes(TINY_BATCH_SIZE, TOKEN_SIZE), diff --git a/tests/test_crf.py b/tests/test_crf.py index 2d72dab..0053328 100644 --- a/tests/test_crf.py +++ b/tests/test_crf.py @@ -1,12 +1,22 @@ import torch import torchcrf from hypothesis import given - -from tests.assertion import assert_close, assert_grad_close -from tests.strategy import device, NUM_CONJUGATES, sizes, TINY_BATCH_SIZE, TINY_TOKEN_SIZE, TOKEN_SIZE +from torchrua import cat_sequence +from torchrua import pack_catted_indices +from torchrua import pack_sequence +from torchrua import pad_catted_indices +from torchrua import pad_packed_indices +from torchrua import pad_sequence + +from tests.assertion import assert_close +from tests.assertion import assert_grad_close +from tests.strategy import device +from tests.strategy import NUM_CONJUGATES +from tests.strategy import sizes +from tests.strategy import TINY_BATCH_SIZE +from tests.strategy import TINY_TOKEN_SIZE +from tests.strategy import TOKEN_SIZE from torchlatent.crf import CrfLayer -from torchrua import cat_sequence, pack_catted_indices, pack_sequence, pad_catted_indices, pad_packed_indices, \ - pad_sequence @given( diff --git a/tests/test_functional.py b/tests/test_functional.py index f6610d6..5a8e05c 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -1,9 +1,15 @@ import torch -from hypothesis import given, strategies as st - -from tests.assertion import assert_close, assert_grad_close -from tests.strategy import device, sizes, TINY_BATCH_SIZE, TINY_TOKEN_SIZE -from torchlatent.functional import logaddexp, logsumexp +from hypothesis import given +from hypothesis import strategies as st + +from tests.assertion import assert_close +from tests.assertion import assert_grad_close +from tests.strategy import device +from tests.strategy import sizes +from tests.strategy import TINY_BATCH_SIZE +from tests.strategy import TINY_TOKEN_SIZE +from torchlatent.functional import logaddexp +from torchlatent.functional import logsumexp @given( diff --git a/third/crf.py b/third/crf.py index 282f7a9..e3fa4e6 100644 --- a/third/crf.py +++ b/third/crf.py @@ -1,10 +1,12 @@ import torch import torchcrf -from torch import nn, Tensor +from torch import nn +from torch import Tensor from torch.nn.utils.rnn import PackedSequence from torch.types import Device - -from torchrua import pack_sequence, pad_catted_indices, pad_packed_sequence +from torchrua import pack_sequence +from torchrua import pad_catted_indices +from torchrua import pad_packed_sequence @torch.no_grad() diff --git a/torchlatent/abc.py b/torchlatent/abc.py index 169b9d1..0f7767f 100644 --- a/torchlatent/abc.py +++ b/torchlatent/abc.py @@ -7,7 +7,6 @@ from torch.distributions import Distribution from torch.distributions.utils import lazy_property from torch.nn.utils.rnn import PackedSequence - from torchrua import CattedSequence Sequence = Union[CattedSequence, PackedSequence] diff --git a/torchlatent/cky.py b/torchlatent/cky.py index 3d199e3..e81eab7 100644 --- a/torchlatent/cky.py +++ b/torchlatent/cky.py @@ -1,18 +1,30 @@ from abc import ABCMeta from functools import singledispatch -from typing import NamedTuple, Tuple, Type, Union +from typing import NamedTuple +from typing import Tuple +from typing 
import Type +from typing import Union import torch -from torch import nn, Tensor +from torch import nn +from torch import Tensor from torch.distributions.utils import lazy_property from torch.nn.utils.rnn import PackedSequence from torch.types import Device +from torchrua import accumulate_sizes +from torchrua import cat_packed_indices +from torchrua import CattedSequence +from torchrua import major_sizes_to_ptr +from torchrua import pack_catted_sequence +from torchrua import pad_indices +from torchrua import pad_sequence +from torchrua import RuaSequential from torchlatent.abc import DistributionABC from torchlatent.nn.classifier import BiaffineClassifier -from torchlatent.semiring import Log, Max, Semiring -from torchrua import accumulate_sizes, cat_packed_indices, CattedSequence, major_sizes_to_ptr, pack_catted_sequence, \ - pad_indices, pad_sequence, RuaSequential +from torchlatent.semiring import Log +from torchlatent.semiring import Max +from torchlatent.semiring import Semiring Sequence = Union[CattedSequence, PackedSequence] diff --git a/torchlatent/crf.py b/torchlatent/crf.py index c26f9f2..46b125e 100644 --- a/torchlatent/crf.py +++ b/torchlatent/crf.py @@ -1,17 +1,30 @@ from functools import singledispatch -from typing import NamedTuple, Tuple, Type, Union +from typing import NamedTuple +from typing import Tuple +from typing import Type +from typing import Union import torch -from torch import nn, Tensor +from torch import nn +from torch import Tensor from torch.distributions.utils import lazy_property -from torch.nn import functional as F, init +from torch.nn import functional as F +from torch.nn import init from torch.types import Device +from torchrua import accumulate_sizes +from torchrua import CattedSequence +from torchrua import minor_sizes_to_ptr +from torchrua import PackedSequence +from torchrua import reduce_catted_indices +from torchrua import reduce_packed_indices +from torchrua import ReductionIndices +from torchrua import RuaSequential from torchlatent.abc import DistributionABC from torchlatent.nn.classifier import Classifier -from torchlatent.semiring import Log, Max, Semiring -from torchrua import accumulate_sizes, CattedSequence, minor_sizes_to_ptr, PackedSequence, reduce_catted_indices, \ - reduce_packed_indices, ReductionIndices, RuaSequential +from torchlatent.semiring import Log +from torchlatent.semiring import Max +from torchlatent.semiring import Semiring Sequence = Union[CattedSequence, PackedSequence] diff --git a/torchlatent/nn/classifier.py b/torchlatent/nn/classifier.py index 4efe3b0..a1a26ff 100644 --- a/torchlatent/nn/classifier.py +++ b/torchlatent/nn/classifier.py @@ -1,5 +1,6 @@ import torch -from torch import nn, Tensor +from torch import nn +from torch import Tensor from torch.nn import init diff --git a/torchlatent/semiring.py b/torchlatent/semiring.py index 630b61b..be859f7 100644 --- a/torchlatent/semiring.py +++ b/torchlatent/semiring.py @@ -1,9 +1,14 @@ import torch from torch import Tensor - -from torchlatent.functional import logaddexp, logsumexp -from torchrua import segment_logsumexp, segment_max, segment_prod, segment_sum -from torchrua.reduction import reduce_sequence, ReductionIndices +from torchrua import segment_logsumexp +from torchrua import segment_max +from torchrua import segment_prod +from torchrua import segment_sum +from torchrua.reduction import reduce_sequence +from torchrua.reduction import ReductionIndices + +from torchlatent.functional import logaddexp +from torchlatent.functional import logsumexp __all__ = [ 
'Semiring', 'ExceptionSemiring', From f5d3f83b78b2ba4e31cc78c8a0891bd526893af7 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sun, 13 Aug 2023 19:31:56 +0900 Subject: [PATCH 065/102] Feat: Add abc2.py --- torchlatent/abc2.py | 63 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) create mode 100644 torchlatent/abc2.py diff --git a/torchlatent/abc2.py b/torchlatent/abc2.py new file mode 100644 index 0000000..5c636fc --- /dev/null +++ b/torchlatent/abc2.py @@ -0,0 +1,63 @@ +from abc import ABCMeta +from typing import Tuple +from typing import Union + +import torch +import torch.autograd +from torch import Tensor +from torch.distributions.utils import lazy_property +from torch.nn.utils.rnn import PackedSequence +from torchrua import CattedSequence + +Sequence = Union[CattedSequence, PackedSequence, Tuple[Tensor, Tensor]] + + +class StructuredDistribution(object, metaclass=ABCMeta): + @property + def emissions(self) -> Sequence: + raise NotImplementedError + + @lazy_property + def log_emissions(self) -> Sequence: + return self.emissions + + @lazy_property + def max_emissions(self) -> Sequence: + return self.emissions + + def log_scores(self, targets: Sequence) -> Tensor: + raise NotImplementedError + + def log_prob(self, targets: Sequence) -> Tensor: + return self.log_scores(targets=targets) - self.log_partitions + + @lazy_property + def log_partitions(self) -> Tensor: + raise NotImplementedError + + @lazy_property + def marginals(self) -> Tensor: + emissions, *_ = self.log_emissions + grad, = torch.autograd.grad( + self.log_partitions, emissions, torch.ones_like(self.log_partitions), + create_graph=True, retain_graph=True, only_inputs=True, allow_unused=True, + + ) + return grad + + @lazy_property + def max(self) -> Tensor: + raise NotImplementedError + + @lazy_property + def argmax(self) -> Tensor: + emissions, *_ = self.max_emissions + grad, = torch.autograd.grad( + self.max, emissions, torch.ones_like(self.max), + create_graph=False, retain_graph=False, only_inputs=True, allow_unused=True, + ) + return grad + + @lazy_property + def entropy(self) -> Tensor: + raise NotImplementedError From 1d361cff483be69e46226e2c76cd3e838d6a897c Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sun, 13 Aug 2023 19:36:28 +0900 Subject: [PATCH 066/102] Feat: Add broadcast_devices [WIP] --- torchlatent/utils.py | 10 ++++++++++ 1 file changed, 10 insertions(+) create mode 100644 torchlatent/utils.py diff --git a/torchlatent/utils.py b/torchlatent/utils.py new file mode 100644 index 0000000..4337306 --- /dev/null +++ b/torchlatent/utils.py @@ -0,0 +1,10 @@ +from torch import Tensor +from torch.types import Device + + +def get_device(*tensors: Tensor, device: Device = None) -> Device: + raise NotImplementedError + + +def broadcast_devices(*tensors: Tensor, device: Device = None) -> Device: + raise NotImplementedError From 2fac68e4129084cac3d186436fe5d5f9c14fa7df Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sun, 13 Aug 2023 20:29:40 +0900 Subject: [PATCH 067/102] Feat: Add crf_packed_partitions --- torchlatent/linear_crf.py | 47 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) create mode 100644 torchlatent/linear_crf.py diff --git a/torchlatent/linear_crf.py b/torchlatent/linear_crf.py new file mode 100644 index 0000000..c171acc --- /dev/null +++ b/torchlatent/linear_crf.py @@ -0,0 +1,47 @@ +from typing import Tuple +from typing import Type +from typing import Union + +import torch +from torch import Tensor +from torch.nn.utils.rnn import PackedSequence +from 
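
The marginals and argmax properties in abc2.py lean on a standard exponential-family identity: when every structure's score is linear in the emissions, the gradient of log Z with respect to the emissions equals the expected sufficient statistics (the marginals), and the gradient of the max-semiring score is a one-hot encoding of the argmax structure. A self-contained one-node illustration of both facts, using nothing beyond PyTorch:

import torch

emissions = torch.randn(5, requires_grad=True)  # scores of 5 candidate structures
log_z = emissions.logsumexp(dim=0)              # log-partition

marginals, = torch.autograd.grad(log_z, emissions, create_graph=True)
print(torch.allclose(marginals, emissions.softmax(dim=0)))  # True

one_hot, = torch.autograd.grad(emissions.max(), emissions)  # one-hot at the argmax
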
torchrua import CattedSequence +from torchrua import last_packed_indices +from torchrua import pack_catted_sequence +from torchrua import pack_padded_sequence + +from torchlatent.semiring import Semiring + +Sequence = Union[CattedSequence, PackedSequence, Tuple[Tensor, Tensor]] +Transitions = Tuple[Tensor, Tensor, Tensor] + + +def crf_partitions(emissions: Sequence, transitions: Transitions, semiring: Type[Semiring]) -> Tensor: + if isinstance(emissions, CattedSequence): + emissions = pack_catted_sequence(emissions) + elif isinstance(emissions, tuple) and len(emissions) == 2: + emissions = pack_padded_sequence(*emissions, batch_first=True) + + assert isinstance(emissions, PackedSequence) + + transitions, head_transitions, last_transitions = transitions + emissions, batch_sizes, _, unsorted_indices = emissions + + last_indices = last_packed_indices( + batch_sizes=batch_sizes, + unsorted_indices=unsorted_indices, + device=emissions.device, + ) + + _, *batch_sizes = sections = batch_sizes.detach().cpu().tolist() + emission, *emissions = torch.split(emissions, sections, dim=0) + + chunks = [semiring.mul(head_transitions, emission)] + for emission, batch_size in zip(emissions, batch_sizes): + chunks.append(semiring.mul( + semiring.bmm(chunks[-1][:batch_size], transitions), + emission, + )) + + emission = torch.cat(chunks, dim=0)[last_indices] + return semiring.sum(semiring.mul(emission, last_transitions), dim=-1) From 4495276b2d0ba94892509fc1259b7307b8524f4c Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sun, 13 Aug 2023 21:39:32 +0900 Subject: [PATCH 068/102] Test: Add unit test for crf_partitions --- tests/test_linear_crf.py | 96 ++++++++++++++++++++++++++++++++++++++++ torchlatent/utils.py | 10 ----- 2 files changed, 96 insertions(+), 10 deletions(-) create mode 100644 tests/test_linear_crf.py delete mode 100644 torchlatent/utils.py diff --git a/tests/test_linear_crf.py b/tests/test_linear_crf.py new file mode 100644 index 0000000..a576e66 --- /dev/null +++ b/tests/test_linear_crf.py @@ -0,0 +1,96 @@ +import torch +from hypothesis import given +from torch.testing import assert_close +from torchcrf import CRF +from torchnyan import BATCH_SIZE +from torchnyan import TOKEN_SIZE +from torchnyan import assert_grad_close +from torchnyan import device +from torchnyan import sizes +from torchrua import cat_sequence +from torchrua import pack_sequence +from torchrua import pad_sequence + +from torchlatent.linear_crf import crf_partitions +from torchlatent.semiring import Log + + +@given( + token_sizes=sizes(BATCH_SIZE, TOKEN_SIZE), + num_targets=sizes(TOKEN_SIZE), +) +def test_crf_catted_partitions(token_sizes, num_targets): + inputs = [ + torch.randn((token_size, num_targets), device=device, requires_grad=True) + for token_size in token_sizes + ] + + excepted_emissions, token_sizes = pad_sequence(inputs, batch_first=True) + index = torch.arange(token_sizes.max().detach().cpu().item(), device=token_sizes.device) + mask = index[None, :] < token_sizes[:, None] + + excepted_crf = CRF(num_tags=num_targets, batch_first=False) + excepted = excepted_crf._compute_normalizer(excepted_emissions.transpose(0, 1), mask.t()) + + actual = crf_partitions( + emissions=cat_sequence(inputs), + transitions=(excepted_crf.transitions, excepted_crf.start_transitions, excepted_crf.end_transitions), + semiring=Log, + ) + + assert_close(actual=actual, expected=excepted, rtol=1e-4, atol=1e-4) + assert_grad_close(actual=actual, expected=excepted, inputs=inputs, rtol=1e-4, atol=1e-4) + + +@given( + token_sizes=sizes(BATCH_SIZE, 
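
The chunked scan in crf_partitions above walks the PackedSequence time-major: the running scores are contracted with the transition table through semiring.bmm, the next emission slice is multiplied in, and the per-sequence finals are gathered with last_packed_indices. A minimal call sketch mirroring the unit tests that follow, with hypothetical shapes (two sequences, five tags) and untrained zero transitions:

import torch
from torchrua import pack_sequence

from torchlatent.linear_crf import crf_partitions
from torchlatent.semiring import Log

num_tags = 5
emissions = pack_sequence([
    torch.randn(3, num_tags, requires_grad=True),
    torch.randn(2, num_tags, requires_grad=True),
])
transitions = (
    torch.zeros(num_tags, num_tags),  # tag-to-tag
    torch.zeros(num_tags),            # start
    torch.zeros(num_tags),            # end
)

log_z = crf_partitions(emissions=emissions, transitions=transitions, semiring=Log)
# log_z has one log-partition per sequence, here shape (2,)
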
TOKEN_SIZE), + num_targets=sizes(TOKEN_SIZE), +) +def test_crf_packed_partitions(token_sizes, num_targets): + inputs = [ + torch.randn((token_size, num_targets), device=device, requires_grad=True) + for token_size in token_sizes + ] + + excepted_emissions, token_sizes = pad_sequence(inputs, batch_first=True) + index = torch.arange(token_sizes.max().detach().cpu().item(), device=token_sizes.device) + mask = index[None, :] < token_sizes[:, None] + + excepted_crf = CRF(num_tags=num_targets, batch_first=False) + excepted = excepted_crf._compute_normalizer(excepted_emissions.transpose(0, 1), mask.t()) + + actual = crf_partitions( + emissions=pack_sequence(inputs), + transitions=(excepted_crf.transitions, excepted_crf.start_transitions, excepted_crf.end_transitions), + semiring=Log, + ) + + assert_close(actual=actual, expected=excepted, rtol=1e-4, atol=1e-4) + assert_grad_close(actual=actual, expected=excepted, inputs=inputs, rtol=1e-4, atol=1e-4) + + +@given( + token_sizes=sizes(BATCH_SIZE, TOKEN_SIZE), + num_targets=sizes(TOKEN_SIZE), +) +def test_crf_padded_partitions(token_sizes, num_targets): + inputs = [ + torch.randn((token_size, num_targets), device=device, requires_grad=True) + for token_size in token_sizes + ] + + excepted_emissions, token_sizes = pad_sequence(inputs, batch_first=True) + index = torch.arange(token_sizes.max().detach().cpu().item(), device=token_sizes.device) + mask = index[None, :] < token_sizes[:, None] + + excepted_crf = CRF(num_tags=num_targets, batch_first=False) + excepted = excepted_crf._compute_normalizer(excepted_emissions.transpose(0, 1), mask.t()) + + actual = crf_partitions( + emissions=pad_sequence(inputs, batch_first=True), + transitions=(excepted_crf.transitions, excepted_crf.start_transitions, excepted_crf.end_transitions), + semiring=Log, + ) + + assert_close(actual=actual, expected=excepted, rtol=1e-4, atol=1e-4) + assert_grad_close(actual=actual, expected=excepted, inputs=inputs, rtol=1e-4, atol=1e-4) diff --git a/torchlatent/utils.py b/torchlatent/utils.py deleted file mode 100644 index 4337306..0000000 --- a/torchlatent/utils.py +++ /dev/null @@ -1,10 +0,0 @@ -from torch import Tensor -from torch.types import Device - - -def get_device(*tensors: Tensor, device: Device = None) -> Device: - raise NotImplementedError - - -def broadcast_devices(*tensors: Tensor, device: Device = None) -> Device: - raise NotImplementedError From 54db3b8e21889136fa85a2f8acb9b0ccc2d6487f Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sun, 13 Aug 2023 21:47:09 +0900 Subject: [PATCH 069/102] Feat: Add CrfDecoder --- torchlatent/abc2.py | 10 ++----- torchlatent/linear_crf.py | 63 +++++++++++++++++++++++++++++++++++++-- 2 files changed, 63 insertions(+), 10 deletions(-) diff --git a/torchlatent/abc2.py b/torchlatent/abc2.py index 5c636fc..72603a3 100644 --- a/torchlatent/abc2.py +++ b/torchlatent/abc2.py @@ -13,10 +13,6 @@ class StructuredDistribution(object, metaclass=ABCMeta): - @property - def emissions(self) -> Sequence: - raise NotImplementedError - @lazy_property def log_emissions(self) -> Sequence: return self.emissions @@ -37,9 +33,8 @@ def log_partitions(self) -> Tensor: @lazy_property def marginals(self) -> Tensor: - emissions, *_ = self.log_emissions grad, = torch.autograd.grad( - self.log_partitions, emissions, torch.ones_like(self.log_partitions), + self.log_partitions, self.emissions, torch.ones_like(self.log_partitions), create_graph=True, retain_graph=True, only_inputs=True, allow_unused=True, ) @@ -51,9 +46,8 @@ def max(self) -> Tensor: @lazy_property def 
argmax(self) -> Tensor: - emissions, *_ = self.max_emissions grad, = torch.autograd.grad( - self.max, emissions, torch.ones_like(self.max), + self.max, self.emissions, torch.ones_like(self.max), create_graph=False, retain_graph=False, only_inputs=True, allow_unused=True, ) return grad diff --git a/torchlatent/linear_crf.py b/torchlatent/linear_crf.py index c171acc..ec19d7f 100644 --- a/torchlatent/linear_crf.py +++ b/torchlatent/linear_crf.py @@ -1,18 +1,22 @@ from typing import Tuple from typing import Type -from typing import Union import torch from torch import Tensor +from torch import nn +from torch.nn import init from torch.nn.utils.rnn import PackedSequence from torchrua import CattedSequence from torchrua import last_packed_indices from torchrua import pack_catted_sequence from torchrua import pack_padded_sequence +from torchlatent.abc2 import Sequence +from torchlatent.abc2 import StructuredDistribution +from torchlatent.semiring import Log +from torchlatent.semiring import Max from torchlatent.semiring import Semiring -Sequence = Union[CattedSequence, PackedSequence, Tuple[Tensor, Tensor]] Transitions = Tuple[Tensor, Tensor, Tensor] @@ -45,3 +49,58 @@ def crf_partitions(emissions: Sequence, transitions: Transitions, semiring: Type emission = torch.cat(chunks, dim=0)[last_indices] return semiring.sum(semiring.mul(emission, last_transitions), dim=-1) + + +class CrfDistribution(StructuredDistribution): + def __init__(self, emissions: Sequence, transitions: Transitions) -> None: + super(CrfDistribution, self).__init__() + self.emissions = emissions + self.transitions = transitions + + def log_scores(self, targets: Sequence) -> Tensor: + raise NotImplementedError + + def log_partitions(self) -> Tensor: + return crf_partitions( + emissions=self.emissions, + transitions=self.transitions, + semiring=Log, + ) + + def max(self) -> Tensor: + return crf_partitions( + emissions=self.emissions, + transitions=self.transitions, + semiring=Max, + ) + + +class CrfDecoder(nn.Module): + def __init__(self, num_targets: int) -> None: + super(CrfDecoder, self).__init__() + + self.num_targets = num_targets + + self.transitions = nn.Parameter(torch.empty((num_targets, num_targets))) + self.head_transitions = nn.Parameter(torch.empty((num_targets,))) + self.last_transitions = nn.Parameter(torch.empty((num_targets,))) + + self.reset_parameters() + + def reset_parameters(self) -> None: + init.zeros_(self.transitions) + init.zeros_(self.head_transitions) + init.zeros_(self.last_transitions) + + def extra_repr(self) -> str: + return f'num_targets={self.num_targets}' + + def forward(self, emissions: Sequence) -> CrfDistribution: + return CrfDistribution( + emissions=emissions, + transitions=( + self.transitions, + self.head_transitions, + self.last_transitions, + ), + ) From 4306060aef7b3bdebc84b036621d51e911a01581 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Mon, 14 Aug 2023 20:44:58 +0900 Subject: [PATCH 070/102] Feat: Use torchnyan --- .github/workflows/unit-tests.yml | 2 +- setup.py | 10 +------- tests/assertion.py | 43 -------------------------------- tests/strategy.py | 39 ----------------------------- tests/test_cky.py | 22 ++++++++-------- tests/test_crf.py | 21 ++++++++-------- tests/test_functional.py | 12 ++++----- tests/test_linear_crf.py | 2 +- torchlatent/linear_crf.py | 10 ++++---- 9 files changed, 35 insertions(+), 126 deletions(-) delete mode 100644 tests/assertion.py delete mode 100644 tests/strategy.py diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index 
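
Putting PATCH 069 together: CrfDecoder owns the three zero-initialized transition tables and wraps incoming emissions into a CrfDistribution. A minimal sketch of the intended flow; note that in this revision log_partitions and max are plain methods rather than lazy properties (and log_scores still raises NotImplementedError), so they are called:

import torch
from torchrua import pack_sequence

from torchlatent.linear_crf import CrfDecoder

decoder = CrfDecoder(num_targets=5)
emissions = pack_sequence([
    torch.randn(3, 5, requires_grad=True),
    torch.randn(2, 5, requires_grad=True),
])

dist = decoder(emissions)
log_z = dist.log_partitions()  # per-sequence log-partition
best = dist.max()              # per-sequence best-path score
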
1085139..b6df8b1 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -17,7 +17,7 @@ jobs: run: | python -m pip install pip setuptools wheel --upgrade python -m pip install torch - python -m pip install -e '.[dev]' + python -m pip install pytest hypothesis torchnyan pytorch-crf python -m pip install git+https://github.com/speedcell4/torchrua.git@develop --force-reinstall --no-deps - name: Test with pytest run: | diff --git a/setup.py b/setup.py index 5dd691b..841e603 100644 --- a/setup.py +++ b/setup.py @@ -15,14 +15,6 @@ python_requires='>=3.8', install_requires=[ 'numpy', - 'torchrua>=0.4.0', + 'torchrua', ], - extras_require={ - 'dev': [ - 'einops', - 'pytest', - 'hypothesis', - 'pytorch-crf', - ], - } ) diff --git a/tests/assertion.py b/tests/assertion.py deleted file mode 100644 index f5d8664..0000000 --- a/tests/assertion.py +++ /dev/null @@ -1,43 +0,0 @@ -import torch -from torch import Tensor -from torch.nn.utils.rnn import PackedSequence -from torch.testing import assert_close - -from torchrua.catting import CattedSequence - -__all__ = [ - 'assert_close', - 'assert_grad_close', - 'assert_catted_sequence_close', - 'assert_packed_sequence_close', -] - - -def assert_grad_close(actual: Tensor, expected: Tensor, inputs, **kwargs) -> None: - grad = torch.rand_like(actual) - - actual_grads = torch.autograd.grad(actual, inputs, grad, retain_graph=True, allow_unused=False) - expected_grads = torch.autograd.grad(expected, inputs, grad, retain_graph=True, allow_unused=False) - - for actual_grad, expected_grad in zip(actual_grads, expected_grads): - assert_close(actual=actual_grad, expected=expected_grad, **kwargs) - - -def assert_catted_sequence_close(actual: CattedSequence, expected: CattedSequence, **kwargs) -> None: - assert_close(actual=actual.data, expected=expected.data, **kwargs) - assert_close(actual=actual.token_sizes, expected=expected.token_sizes, **kwargs) - - -def assert_packed_sequence_close(actual: PackedSequence, expected: PackedSequence, **kwargs) -> None: - assert_close(actual=actual.data, expected=expected.data, **kwargs) - assert_close(actual=actual.batch_sizes, expected=expected.batch_sizes, **kwargs) - - if actual.sorted_indices is None: - assert expected.sorted_indices is None - else: - assert_close(actual=actual.sorted_indices, expected=expected.sorted_indices, **kwargs) - - if actual.unsorted_indices is None: - assert expected.unsorted_indices is None - else: - assert_close(actual=actual.unsorted_indices, expected=expected.unsorted_indices, **kwargs) diff --git a/tests/strategy.py b/tests/strategy.py deleted file mode 100644 index 19b90fc..0000000 --- a/tests/strategy.py +++ /dev/null @@ -1,39 +0,0 @@ -import torch -from hypothesis import strategies as st - -TINY_BATCH_SIZE = 5 -TINY_TOKEN_SIZE = 11 -TINY_EMBEDDING_DIM = 13 -NUM_CONJUGATES = 5 -NUM_TAGS = 7 - -if torch.cuda.is_available(): - BATCH_SIZE = 53 - TOKEN_SIZE = 83 - EMBEDDING_DIM = 107 - NUM_CONJUGATES = 5 - NUM_TAGS = 17 -else: - BATCH_SIZE = 37 - TOKEN_SIZE = 53 - EMBEDDING_DIM = 61 - NUM_CONJUGATES = 5 - NUM_TAGS = 17 - -if torch.cuda.is_available(): - device = torch.device('cuda:0') -else: - device = torch.device('cpu') - -torch.empty((1,), device=device) - - -@st.composite -def sizes(draw, *size: int, min_size: int = 1): - max_size, *size = size - n = draw(st.integers(min_value=min_size, max_value=max_size)) - - if len(size) == 0: - return n - else: - return draw(st.lists(sizes(*size), min_size=n, max_size=n)) diff --git a/tests/test_cky.py b/tests/test_cky.py index 
b223291..6e78125 100644 --- a/tests/test_cky.py +++ b/tests/test_cky.py @@ -2,24 +2,24 @@ from hypothesis import given from hypothesis import strategies as st from torch_struct import TreeCRF +from torchnyan.assertion import assert_close +from torchnyan.assertion import assert_grad_close +from torchnyan.strategy import BATCH_SIZE +from torchnyan.strategy import FEATURE_DIM +from torchnyan.strategy import TINY_BATCH_SIZE +from torchnyan.strategy import TOKEN_SIZE +from torchnyan.strategy import device +from torchnyan.strategy import sizes from torchrua import cat_sequence -from tests.assertion import assert_close -from tests.assertion import assert_grad_close -from tests.strategy import BATCH_SIZE -from tests.strategy import device -from tests.strategy import EMBEDDING_DIM -from tests.strategy import sizes -from tests.strategy import TINY_BATCH_SIZE -from tests.strategy import TOKEN_SIZE -from torchlatent.cky import cky_partitions_indices from torchlatent.cky import CkyDecoder from torchlatent.cky import CkyDistribution +from torchlatent.cky import cky_partitions_indices @given( token_sizes=sizes(TINY_BATCH_SIZE, TOKEN_SIZE), - embedding_dim=sizes(EMBEDDING_DIM), + embedding_dim=sizes(FEATURE_DIM), num_tags=sizes(TOKEN_SIZE), dropout=st.floats(0, 1), ) @@ -45,7 +45,7 @@ def test_cky_catted_max(token_sizes, embedding_dim, num_tags, dropout): # @given( # token_sizes=sizes(TINY_BATCH_SIZE, TOKEN_SIZE), -# embedding_dim=sizes(EMBEDDING_DIM), +# embedding_dim=sizes(FEATURE_DIM), # num_tags=sizes(TOKEN_SIZE), # bias=st.booleans(), # ) diff --git a/tests/test_crf.py b/tests/test_crf.py index 0053328..b5dbab9 100644 --- a/tests/test_crf.py +++ b/tests/test_crf.py @@ -1,6 +1,13 @@ import torch import torchcrf from hypothesis import given +from torchnyan.assertion import assert_close +from torchnyan.assertion import assert_grad_close +from torchnyan.strategy import TINY_BATCH_SIZE +from torchnyan.strategy import TINY_TOKEN_SIZE +from torchnyan.strategy import TOKEN_SIZE +from torchnyan.strategy import device +from torchnyan.strategy import sizes from torchrua import cat_sequence from torchrua import pack_catted_indices from torchrua import pack_sequence @@ -8,14 +15,6 @@ from torchrua import pad_packed_indices from torchrua import pad_sequence -from tests.assertion import assert_close -from tests.assertion import assert_grad_close -from tests.strategy import device -from tests.strategy import NUM_CONJUGATES -from tests.strategy import sizes -from tests.strategy import TINY_BATCH_SIZE -from tests.strategy import TINY_TOKEN_SIZE -from tests.strategy import TOKEN_SIZE from torchlatent.crf import CrfLayer @@ -294,7 +293,7 @@ def test_crf_packed_decode(token_sizes, num_tags): @given( token_sizes=sizes(TINY_BATCH_SIZE, TINY_TOKEN_SIZE), - num_conjugates=sizes(NUM_CONJUGATES), + num_conjugates=sizes(TOKEN_SIZE), num_tags=sizes(TINY_TOKEN_SIZE), ) def test_conjugated_catted_fit(token_sizes, num_conjugates, num_tags): @@ -343,7 +342,7 @@ def test_conjugated_catted_fit(token_sizes, num_conjugates, num_tags): @given( token_sizes=sizes(TINY_BATCH_SIZE, TINY_TOKEN_SIZE), - num_conjugates=sizes(NUM_CONJUGATES), + num_conjugates=sizes(TOKEN_SIZE), num_tags=sizes(TINY_TOKEN_SIZE), ) def test_conjugated_packed_fit(token_sizes, num_conjugates, num_tags): @@ -392,7 +391,7 @@ def test_conjugated_packed_fit(token_sizes, num_conjugates, num_tags): @given( token_sizes=sizes(TINY_BATCH_SIZE, TINY_TOKEN_SIZE), - num_conjugates=sizes(NUM_CONJUGATES), + num_conjugates=sizes(TOKEN_SIZE), num_tags=sizes(TINY_TOKEN_SIZE), ) def 
test_dynamic_fit(token_sizes, num_conjugates, num_tags): diff --git a/tests/test_functional.py b/tests/test_functional.py index 5a8e05c..1af9370 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -1,13 +1,13 @@ import torch from hypothesis import given from hypothesis import strategies as st +from torchnyan.assertion import assert_close +from torchnyan.assertion import assert_grad_close +from torchnyan.strategy import TINY_BATCH_SIZE +from torchnyan.strategy import TINY_TOKEN_SIZE +from torchnyan.strategy import device +from torchnyan.strategy import sizes -from tests.assertion import assert_close -from tests.assertion import assert_grad_close -from tests.strategy import device -from tests.strategy import sizes -from tests.strategy import TINY_BATCH_SIZE -from tests.strategy import TINY_TOKEN_SIZE from torchlatent.functional import logaddexp from torchlatent.functional import logsumexp diff --git a/tests/test_linear_crf.py b/tests/test_linear_crf.py index a576e66..1934e77 100644 --- a/tests/test_linear_crf.py +++ b/tests/test_linear_crf.py @@ -1,9 +1,9 @@ import torch from hypothesis import given -from torch.testing import assert_close from torchcrf import CRF from torchnyan import BATCH_SIZE from torchnyan import TOKEN_SIZE +from torchnyan import assert_close from torchnyan import assert_grad_close from torchnyan import device from torchnyan import sizes diff --git a/torchlatent/linear_crf.py b/torchlatent/linear_crf.py index ec19d7f..8573577 100644 --- a/torchlatent/linear_crf.py +++ b/torchlatent/linear_crf.py @@ -28,8 +28,8 @@ def crf_partitions(emissions: Sequence, transitions: Transitions, semiring: Type assert isinstance(emissions, PackedSequence) - transitions, head_transitions, last_transitions = transitions emissions, batch_sizes, _, unsorted_indices = emissions + transitions, head_transitions, last_transitions = transitions last_indices = last_packed_indices( batch_sizes=batch_sizes, @@ -40,14 +40,14 @@ def crf_partitions(emissions: Sequence, transitions: Transitions, semiring: Type _, *batch_sizes = sections = batch_sizes.detach().cpu().tolist() emission, *emissions = torch.split(emissions, sections, dim=0) - chunks = [semiring.mul(head_transitions, emission)] + charts = [semiring.mul(head_transitions, emission)] for emission, batch_size in zip(emissions, batch_sizes): - chunks.append(semiring.mul( - semiring.bmm(chunks[-1][:batch_size], transitions), + charts.append(semiring.mul( + semiring.bmm(charts[-1][:batch_size], transitions), emission, )) - emission = torch.cat(chunks, dim=0)[last_indices] + emission = torch.cat(charts, dim=0)[last_indices] return semiring.sum(semiring.mul(emission, last_transitions), dim=-1) From 349eef43eb82a525e95e12ec69e8c758d9e1f4b1 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Mon, 21 Aug 2023 20:07:22 +0900 Subject: [PATCH 071/102] Feat: Add crf_scores --- tests/test_linear_crf.py | 39 +++++++++++++++++++++++++++++++++ torchlatent/linear_crf.py | 45 +++++++++++++++++++++++++++++++++++++-- 2 files changed, 82 insertions(+), 2 deletions(-) diff --git a/tests/test_linear_crf.py b/tests/test_linear_crf.py index 1934e77..14d5c04 100644 --- a/tests/test_linear_crf.py +++ b/tests/test_linear_crf.py @@ -12,9 +12,48 @@ from torchrua import pad_sequence from torchlatent.linear_crf import crf_partitions +from torchlatent.linear_crf import crf_scores from torchlatent.semiring import Log +@given( + token_sizes=sizes(BATCH_SIZE, TOKEN_SIZE), + num_targets=sizes(TOKEN_SIZE), +) +def test_crf_catted_partitions(token_sizes, num_targets): + 
emissions = [ + torch.randn((token_size, num_targets), device=device, requires_grad=True) + for token_size in token_sizes + ] + + tags = [ + torch.randint(0, num_targets, (token_size,), device=device) + for token_size in token_sizes + ] + + excepted_emissions, token_sizes = pad_sequence(emissions, batch_first=True) + index = torch.arange(token_sizes.max().detach().cpu().item(), device=token_sizes.device) + mask = index[None, :] < token_sizes[:, None] + + excepted_tags, _ = pad_sequence(tags, batch_first=True) + + excepted_crf = CRF(num_tags=num_targets, batch_first=False) + excepted = excepted_crf._compute_score( + excepted_emissions.transpose(0, 1), + excepted_tags.transpose(0, 1), mask.t(), + ) + + actual = crf_scores( + emissions=pad_sequence(emissions, batch_first=True), + targets=cat_sequence(tags), + transitions=(excepted_crf.transitions, excepted_crf.start_transitions, excepted_crf.end_transitions), + semiring=Log, + ) + + assert_close(actual=actual, expected=excepted) + assert_grad_close(actual=actual, expected=excepted, inputs=emissions) + + @given( token_sizes=sizes(BATCH_SIZE, TOKEN_SIZE), num_targets=sizes(TOKEN_SIZE), diff --git a/torchlatent/linear_crf.py b/torchlatent/linear_crf.py index 8573577..5da4b57 100644 --- a/torchlatent/linear_crf.py +++ b/torchlatent/linear_crf.py @@ -4,12 +4,18 @@ import torch from torch import Tensor from torch import nn +from torch.distributions.utils import lazy_property from torch.nn import init from torch.nn.utils.rnn import PackedSequence from torchrua import CattedSequence +from torchrua import head_catted_sequence +from torchrua import last_catted_sequence from torchrua import last_packed_indices from torchrua import pack_catted_sequence from torchrua import pack_padded_sequence +from torchrua import pad_catted_sequence +from torchrua import pad_indices +from torchrua import pad_packed_sequence from torchlatent.abc2 import Sequence from torchlatent.abc2 import StructuredDistribution @@ -20,6 +26,34 @@ Transitions = Tuple[Tensor, Tensor, Tensor] +def crf_scores(emissions: Sequence, targets: Sequence, + transitions: Transitions, semiring: Type[Semiring]) -> Tensor: + if isinstance(emissions, CattedSequence): + emissions, _ = pad_catted_sequence(emissions, batch_first=True) + elif isinstance(emissions, PackedSequence): + emissions, _ = pad_packed_sequence(emissions, batch_first=True) + else: + emissions, _ = emissions + + transitions, head_transitions, last_transitions = transitions + + head_transitions = head_transitions[head_catted_sequence(targets)] + last_transitions = last_transitions[last_catted_sequence(targets)] + transitions = transitions[targets.data.roll(1, dims=[0]), targets.data] + + _, (batch_ptr, token_ptr), _ = pad_indices(targets, batch_first=True) + emissions = emissions[batch_ptr, token_ptr, targets.data] + emissions = semiring.segment_prod(emissions, sizes=targets.token_sizes) + + token_sizes = torch.stack([torch.ones_like(targets.token_sizes), targets.token_sizes - 1], dim=-1) + transitions = semiring.segment_prod(transitions, sizes=token_sizes.view(-1))[1::2] + + return semiring.mul( + semiring.mul(head_transitions, last_transitions), + semiring.mul(emissions, transitions), + ) + + def crf_partitions(emissions: Sequence, transitions: Transitions, semiring: Type[Semiring]) -> Tensor: if isinstance(emissions, CattedSequence): emissions = pack_catted_sequence(emissions) @@ -58,8 +92,14 @@ def __init__(self, emissions: Sequence, transitions: Transitions) -> None: self.transitions = transitions def log_scores(self, targets: 
Sequence) -> Tensor: - raise NotImplementedError + return crf_scores( + emissions=self.emissions, + targets=targets, + transitions=self.transitions, + semiring=Log, + ) + @lazy_property def log_partitions(self) -> Tensor: return crf_partitions( emissions=self.emissions, @@ -67,6 +107,7 @@ def log_partitions(self) -> Tensor: semiring=Log, ) + @lazy_property def max(self) -> Tensor: return crf_partitions( emissions=self.emissions, @@ -76,7 +117,7 @@ def max(self) -> Tensor: class CrfDecoder(nn.Module): - def __init__(self, num_targets: int) -> None: + def __init__(self, *, num_targets: int) -> None: super(CrfDecoder, self).__init__() self.num_targets = num_targets From d3998df6311b0ed30b7af1bfde6997073b18eb6f Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Thu, 24 Aug 2023 00:48:42 +0900 Subject: [PATCH 072/102] Feat: Update crf_scores --- tests/test_linear_crf.py | 196 ++++++++++++++++----------------- torchlatent/abc2.py | 33 ++---- torchlatent/linear_crf.py | 226 +++++++++++++++++--------------------- torchlatent/semiring.py | 12 +- 4 files changed, 215 insertions(+), 252 deletions(-) diff --git a/tests/test_linear_crf.py b/tests/test_linear_crf.py index 14d5c04..9d92d11 100644 --- a/tests/test_linear_crf.py +++ b/tests/test_linear_crf.py @@ -1,5 +1,6 @@ import torch from hypothesis import given +from hypothesis import strategies as st from torchcrf import CRF from torchnyan import BATCH_SIZE from torchnyan import TOKEN_SIZE @@ -7,129 +8,128 @@ from torchnyan import assert_grad_close from torchnyan import device from torchnyan import sizes -from torchrua import cat_sequence -from torchrua import pack_sequence -from torchrua import pad_sequence -from torchlatent.linear_crf import crf_partitions from torchlatent.linear_crf import crf_scores from torchlatent.semiring import Log +from torchrua import cat_sequence +from torchrua import pack_sequence +from torchrua import pad_sequence @given( token_sizes=sizes(BATCH_SIZE, TOKEN_SIZE), num_targets=sizes(TOKEN_SIZE), + rua_emissions=st.sampled_from([cat_sequence, pad_sequence, pack_sequence]), + rua_targets=st.sampled_from([cat_sequence, pad_sequence, pack_sequence]), ) -def test_crf_catted_partitions(token_sizes, num_targets): - emissions = [ +def test_crf_scores(token_sizes, num_targets, rua_emissions, rua_targets): + inputs = [ torch.randn((token_size, num_targets), device=device, requires_grad=True) for token_size in token_sizes ] - tags = [ + targets = [ torch.randint(0, num_targets, (token_size,), device=device) for token_size in token_sizes ] - excepted_emissions, token_sizes = pad_sequence(emissions, batch_first=True) - index = torch.arange(token_sizes.max().detach().cpu().item(), device=token_sizes.device) - mask = index[None, :] < token_sizes[:, None] + excepted_crf = CRF(num_tags=num_targets, batch_first=False) - excepted_tags, _ = pad_sequence(tags, batch_first=True) + excepted_emissions = pad_sequence(inputs) + excepted_tags = pad_sequence(targets) - excepted_crf = CRF(num_tags=num_targets, batch_first=False) excepted = excepted_crf._compute_score( - excepted_emissions.transpose(0, 1), - excepted_tags.transpose(0, 1), mask.t(), + excepted_emissions.data.transpose(0, 1), + excepted_tags.data.transpose(0, 1), + excepted_emissions.mask().transpose(0, 1), ) actual = crf_scores( - emissions=pad_sequence(emissions, batch_first=True), - targets=cat_sequence(tags), + emissions=rua_emissions(inputs), + targets=rua_targets(targets), transitions=(excepted_crf.transitions, excepted_crf.start_transitions, excepted_crf.end_transitions), 
semiring=Log, ) assert_close(actual=actual, expected=excepted) - assert_grad_close(actual=actual, expected=excepted, inputs=emissions) - - -@given( - token_sizes=sizes(BATCH_SIZE, TOKEN_SIZE), - num_targets=sizes(TOKEN_SIZE), -) -def test_crf_catted_partitions(token_sizes, num_targets): - inputs = [ - torch.randn((token_size, num_targets), device=device, requires_grad=True) - for token_size in token_sizes - ] - - excepted_emissions, token_sizes = pad_sequence(inputs, batch_first=True) - index = torch.arange(token_sizes.max().detach().cpu().item(), device=token_sizes.device) - mask = index[None, :] < token_sizes[:, None] - - excepted_crf = CRF(num_tags=num_targets, batch_first=False) - excepted = excepted_crf._compute_normalizer(excepted_emissions.transpose(0, 1), mask.t()) - - actual = crf_partitions( - emissions=cat_sequence(inputs), - transitions=(excepted_crf.transitions, excepted_crf.start_transitions, excepted_crf.end_transitions), - semiring=Log, - ) - - assert_close(actual=actual, expected=excepted, rtol=1e-4, atol=1e-4) - assert_grad_close(actual=actual, expected=excepted, inputs=inputs, rtol=1e-4, atol=1e-4) - - -@given( - token_sizes=sizes(BATCH_SIZE, TOKEN_SIZE), - num_targets=sizes(TOKEN_SIZE), -) -def test_crf_packed_partitions(token_sizes, num_targets): - inputs = [ - torch.randn((token_size, num_targets), device=device, requires_grad=True) - for token_size in token_sizes - ] - - excepted_emissions, token_sizes = pad_sequence(inputs, batch_first=True) - index = torch.arange(token_sizes.max().detach().cpu().item(), device=token_sizes.device) - mask = index[None, :] < token_sizes[:, None] - - excepted_crf = CRF(num_tags=num_targets, batch_first=False) - excepted = excepted_crf._compute_normalizer(excepted_emissions.transpose(0, 1), mask.t()) - - actual = crf_partitions( - emissions=pack_sequence(inputs), - transitions=(excepted_crf.transitions, excepted_crf.start_transitions, excepted_crf.end_transitions), - semiring=Log, - ) - - assert_close(actual=actual, expected=excepted, rtol=1e-4, atol=1e-4) - assert_grad_close(actual=actual, expected=excepted, inputs=inputs, rtol=1e-4, atol=1e-4) - - -@given( - token_sizes=sizes(BATCH_SIZE, TOKEN_SIZE), - num_targets=sizes(TOKEN_SIZE), -) -def test_crf_padded_partitions(token_sizes, num_targets): - inputs = [ - torch.randn((token_size, num_targets), device=device, requires_grad=True) - for token_size in token_sizes - ] - - excepted_emissions, token_sizes = pad_sequence(inputs, batch_first=True) - index = torch.arange(token_sizes.max().detach().cpu().item(), device=token_sizes.device) - mask = index[None, :] < token_sizes[:, None] - - excepted_crf = CRF(num_tags=num_targets, batch_first=False) - excepted = excepted_crf._compute_normalizer(excepted_emissions.transpose(0, 1), mask.t()) - - actual = crf_partitions( - emissions=pad_sequence(inputs, batch_first=True), - transitions=(excepted_crf.transitions, excepted_crf.start_transitions, excepted_crf.end_transitions), - semiring=Log, - ) - - assert_close(actual=actual, expected=excepted, rtol=1e-4, atol=1e-4) - assert_grad_close(actual=actual, expected=excepted, inputs=inputs, rtol=1e-4, atol=1e-4) + assert_grad_close(actual=actual, expected=excepted, inputs=inputs) + +# @given( +# token_sizes=sizes(BATCH_SIZE, TOKEN_SIZE), +# num_targets=sizes(TOKEN_SIZE), +# ) +# def test_crf_catted_partitions(token_sizes, num_targets): +# inputs = [ +# torch.randn((token_size, num_targets), device=device, requires_grad=True) +# for token_size in token_sizes +# ] +# +# excepted_emissions, token_sizes = 
pad_sequence(inputs) +# index = torch.arange(token_sizes.max().detach().cpu().item(), device=token_sizes.device) +# mask = index[None, :] < token_sizes[:, None] +# +# excepted_crf = CRF(num_tags=num_targets, batch_first=False) +# excepted = excepted_crf._compute_normalizer(excepted_emissions.transpose(0, 1), mask.t()) +# +# actual = crf_partitions( +# emissions=cat_sequence(inputs), +# transitions=(excepted_crf.transitions, excepted_crf.start_transitions, excepted_crf.end_transitions), +# semiring=Log, +# ) +# +# assert_close(actual=actual, expected=excepted, rtol=1e-4, atol=1e-4) +# assert_grad_close(actual=actual, expected=excepted, inputs=inputs, rtol=1e-4, atol=1e-4) +# +# +# @given( +# token_sizes=sizes(BATCH_SIZE, TOKEN_SIZE), +# num_targets=sizes(TOKEN_SIZE), +# ) +# def test_crf_packed_partitions(token_sizes, num_targets): +# inputs = [ +# torch.randn((token_size, num_targets), device=device, requires_grad=True) +# for token_size in token_sizes +# ] +# +# excepted_emissions, token_sizes = pad_sequence(inputs) +# index = torch.arange(token_sizes.max().detach().cpu().item(), device=token_sizes.device) +# mask = index[None, :] < token_sizes[:, None] +# +# excepted_crf = CRF(num_tags=num_targets, batch_first=False) +# excepted = excepted_crf._compute_normalizer(excepted_emissions.transpose(0, 1), mask.t()) +# +# actual = crf_partitions( +# emissions=pack_sequence(inputs), +# transitions=(excepted_crf.transitions, excepted_crf.start_transitions, excepted_crf.end_transitions), +# semiring=Log, +# ) +# +# assert_close(actual=actual, expected=excepted, rtol=1e-4, atol=1e-4) +# assert_grad_close(actual=actual, expected=excepted, inputs=inputs, rtol=1e-4, atol=1e-4) +# +# +# @given( +# token_sizes=sizes(BATCH_SIZE, TOKEN_SIZE), +# num_targets=sizes(TOKEN_SIZE), +# ) +# def test_crf_padded_partitions(token_sizes, num_targets): +# inputs = [ +# torch.randn((token_size, num_targets), device=device, requires_grad=True) +# for token_size in token_sizes +# ] +# +# excepted_emissions, token_sizes = pad_sequence(inputs) +# index = torch.arange(token_sizes.max().detach().cpu().item(), device=token_sizes.device) +# mask = index[None, :] < token_sizes[:, None] +# +# excepted_crf = CRF(num_tags=num_targets, batch_first=False) +# excepted = excepted_crf._compute_normalizer(excepted_emissions.transpose(0, 1), mask.t()) +# +# actual = crf_partitions( +# emissions=pad_sequence(inputs), +# transitions=(excepted_crf.transitions, excepted_crf.start_transitions, excepted_crf.end_transitions), +# semiring=Log, +# ) +# +# assert_close(actual=actual, expected=excepted, rtol=1e-4, atol=1e-4) +# assert_grad_close(actual=actual, expected=excepted, inputs=inputs, rtol=1e-4, atol=1e-4) diff --git a/torchlatent/abc2.py b/torchlatent/abc2.py index 72603a3..70ed998 100644 --- a/torchlatent/abc2.py +++ b/torchlatent/abc2.py @@ -1,30 +1,25 @@ from abc import ABCMeta -from typing import Tuple from typing import Union import torch import torch.autograd from torch import Tensor from torch.distributions.utils import lazy_property -from torch.nn.utils.rnn import PackedSequence -from torchrua import CattedSequence -Sequence = Union[CattedSequence, PackedSequence, Tuple[Tensor, Tensor]] +from torchrua import C +from torchrua import D +from torchrua import P class StructuredDistribution(object, metaclass=ABCMeta): - @lazy_property - def log_emissions(self) -> Sequence: - return self.emissions - - @lazy_property - def max_emissions(self) -> Sequence: - return self.emissions + def __init__(self, emissions: Union[C, D, P]) -> None: + 
super(StructuredDistribution, self).__init__() + self.emissions = emissions - def log_scores(self, targets: Sequence) -> Tensor: + def log_scores(self, targets: Union[C, D, P]) -> Tensor: raise NotImplementedError - def log_prob(self, targets: Sequence) -> Tensor: + def log_probs(self, targets: Union[C, D, P]) -> Tensor: return self.log_scores(targets=targets) - self.log_partitions @lazy_property @@ -34,7 +29,7 @@ def log_partitions(self) -> Tensor: @lazy_property def marginals(self) -> Tensor: grad, = torch.autograd.grad( - self.log_partitions, self.emissions, torch.ones_like(self.log_partitions), + self.log_partitions, self.emissions.data, torch.ones_like(self.log_partitions), create_graph=True, retain_graph=True, only_inputs=True, allow_unused=True, ) @@ -45,13 +40,9 @@ def max(self) -> Tensor: raise NotImplementedError @lazy_property - def argmax(self) -> Tensor: + def argmax(self) -> Union[C, D, P]: grad, = torch.autograd.grad( - self.max, self.emissions, torch.ones_like(self.max), + self.max, self.emissions.data, torch.ones_like(self.max), create_graph=False, retain_graph=False, only_inputs=True, allow_unused=True, ) - return grad - - @lazy_property - def entropy(self) -> Tensor: - raise NotImplementedError + return self.emissions._replace(data=grad.argmax(dim=-1)) diff --git a/torchlatent/linear_crf.py b/torchlatent/linear_crf.py index 5da4b57..b7ae90a 100644 --- a/torchlatent/linear_crf.py +++ b/torchlatent/linear_crf.py @@ -1,48 +1,27 @@ from typing import Tuple from typing import Type +from typing import Union import torch from torch import Tensor -from torch import nn -from torch.distributions.utils import lazy_property -from torch.nn import init -from torch.nn.utils.rnn import PackedSequence -from torchrua import CattedSequence -from torchrua import head_catted_sequence -from torchrua import last_catted_sequence -from torchrua import last_packed_indices -from torchrua import pack_catted_sequence -from torchrua import pack_padded_sequence -from torchrua import pad_catted_sequence -from torchrua import pad_indices -from torchrua import pad_packed_sequence -from torchlatent.abc2 import Sequence -from torchlatent.abc2 import StructuredDistribution -from torchlatent.semiring import Log -from torchlatent.semiring import Max from torchlatent.semiring import Semiring +from torchrua import C +from torchrua import D +from torchrua import P -Transitions = Tuple[Tensor, Tensor, Tensor] +T = Tuple[Tensor, Tensor, Tensor] -def crf_scores(emissions: Sequence, targets: Sequence, - transitions: Transitions, semiring: Type[Semiring]) -> Tensor: - if isinstance(emissions, CattedSequence): - emissions, _ = pad_catted_sequence(emissions, batch_first=True) - elif isinstance(emissions, PackedSequence): - emissions, _ = pad_packed_sequence(emissions, batch_first=True) - else: - emissions, _ = emissions - +def crf_scores(emissions: Union[C, D, P], targets: Union[C, D, P], transitions: T, semiring: Type[Semiring]) -> Tensor: transitions, head_transitions, last_transitions = transitions - head_transitions = head_transitions[head_catted_sequence(targets)] - last_transitions = last_transitions[last_catted_sequence(targets)] - transitions = transitions[targets.data.roll(1, dims=[0]), targets.data] + targets = targets.cat() + head_transitions = head_transitions[targets.head().data] + last_transitions = last_transitions[targets.last().data] + transitions = transitions[targets.roll(1).data, targets.data] - _, (batch_ptr, token_ptr), _ = pad_indices(targets, batch_first=True) - emissions = emissions[batch_ptr, 
token_ptr, targets.data] + emissions, _ = emissions.idx().cat().rua(emissions, targets) emissions = semiring.segment_prod(emissions, sizes=targets.token_sizes) token_sizes = torch.stack([torch.ones_like(targets.token_sizes), targets.token_sizes - 1], dim=-1) @@ -53,95 +32,94 @@ def crf_scores(emissions: Sequence, targets: Sequence, semiring.mul(emissions, transitions), ) - -def crf_partitions(emissions: Sequence, transitions: Transitions, semiring: Type[Semiring]) -> Tensor: - if isinstance(emissions, CattedSequence): - emissions = pack_catted_sequence(emissions) - elif isinstance(emissions, tuple) and len(emissions) == 2: - emissions = pack_padded_sequence(*emissions, batch_first=True) - - assert isinstance(emissions, PackedSequence) - - emissions, batch_sizes, _, unsorted_indices = emissions - transitions, head_transitions, last_transitions = transitions - - last_indices = last_packed_indices( - batch_sizes=batch_sizes, - unsorted_indices=unsorted_indices, - device=emissions.device, - ) - - _, *batch_sizes = sections = batch_sizes.detach().cpu().tolist() - emission, *emissions = torch.split(emissions, sections, dim=0) - - charts = [semiring.mul(head_transitions, emission)] - for emission, batch_size in zip(emissions, batch_sizes): - charts.append(semiring.mul( - semiring.bmm(charts[-1][:batch_size], transitions), - emission, - )) - - emission = torch.cat(charts, dim=0)[last_indices] - return semiring.sum(semiring.mul(emission, last_transitions), dim=-1) - - -class CrfDistribution(StructuredDistribution): - def __init__(self, emissions: Sequence, transitions: Transitions) -> None: - super(CrfDistribution, self).__init__() - self.emissions = emissions - self.transitions = transitions - - def log_scores(self, targets: Sequence) -> Tensor: - return crf_scores( - emissions=self.emissions, - targets=targets, - transitions=self.transitions, - semiring=Log, - ) - - @lazy_property - def log_partitions(self) -> Tensor: - return crf_partitions( - emissions=self.emissions, - transitions=self.transitions, - semiring=Log, - ) - - @lazy_property - def max(self) -> Tensor: - return crf_partitions( - emissions=self.emissions, - transitions=self.transitions, - semiring=Max, - ) - - -class CrfDecoder(nn.Module): - def __init__(self, *, num_targets: int) -> None: - super(CrfDecoder, self).__init__() - - self.num_targets = num_targets - - self.transitions = nn.Parameter(torch.empty((num_targets, num_targets))) - self.head_transitions = nn.Parameter(torch.empty((num_targets,))) - self.last_transitions = nn.Parameter(torch.empty((num_targets,))) - - self.reset_parameters() - - def reset_parameters(self) -> None: - init.zeros_(self.transitions) - init.zeros_(self.head_transitions) - init.zeros_(self.last_transitions) - - def extra_repr(self) -> str: - return f'num_targets={self.num_targets}' - - def forward(self, emissions: Sequence) -> CrfDistribution: - return CrfDistribution( - emissions=emissions, - transitions=( - self.transitions, - self.head_transitions, - self.last_transitions, - ), - ) +# def crf_partitions(emissions: Sequence, transitions: Transitions, semiring: Type[Semiring]) -> Tensor: +# if isinstance(emissions, CattedSequence): +# emissions = pack_catted_sequence(emissions) +# elif isinstance(emissions, tuple) and len(emissions) == 2: +# emissions = pack_padded_sequence(*emissions, batch_first=True) +# +# assert isinstance(emissions, PackedSequence) +# +# emissions, batch_sizes, _, unsorted_indices = emissions +# transitions, head_transitions, last_transitions = transitions +# +# last_indices = 
last_packed_indices( +# batch_sizes=batch_sizes, +# unsorted_indices=unsorted_indices, +# device=emissions.device, +# ) +# +# _, *batch_sizes = sections = batch_sizes.detach().cpu().tolist() +# emission, *emissions = torch.split(emissions, sections, dim=0) +# +# charts = [semiring.mul(head_transitions, emission)] +# for emission, batch_size in zip(emissions, batch_sizes): +# charts.append(semiring.mul( +# semiring.bmm(charts[-1][:batch_size], transitions), +# emission, +# )) +# +# emission = torch.cat(charts, dim=0)[last_indices] +# return semiring.sum(semiring.mul(emission, last_transitions), dim=-1) +# +# +# class CrfDistribution(StructuredDistribution): +# def __init__(self, emissions: Sequence, transitions: Transitions) -> None: +# super(CrfDistribution, self).__init__() +# self.emissions = emissions +# self.transitions = transitions +# +# def log_scores(self, targets: Sequence) -> Tensor: +# return crf_scores( +# emissions=self.emissions, +# targets=targets, +# transitions=self.transitions, +# semiring=Log, +# ) +# +# @lazy_property +# def log_partitions(self) -> Tensor: +# return crf_partitions( +# emissions=self.emissions, +# transitions=self.transitions, +# semiring=Log, +# ) +# +# @lazy_property +# def max(self) -> Tensor: +# return crf_partitions( +# emissions=self.emissions, +# transitions=self.transitions, +# semiring=Max, +# ) +# +# +# class CrfDecoder(nn.Module): +# def __init__(self, *, num_targets: int) -> None: +# super(CrfDecoder, self).__init__() +# +# self.num_targets = num_targets +# +# self.transitions = nn.Parameter(torch.empty((num_targets, num_targets))) +# self.head_transitions = nn.Parameter(torch.empty((num_targets,))) +# self.last_transitions = nn.Parameter(torch.empty((num_targets,))) +# +# self.reset_parameters() +# +# def reset_parameters(self) -> None: +# init.zeros_(self.transitions) +# init.zeros_(self.head_transitions) +# init.zeros_(self.last_transitions) +# +# def extra_repr(self) -> str: +# return f'num_targets={self.num_targets}' +# +# def forward(self, emissions: Sequence) -> CrfDistribution: +# return CrfDistribution( +# emissions=emissions, +# transitions=( +# self.transitions, +# self.head_transitions, +# self.last_transitions, +# ), +# ) diff --git a/torchlatent/semiring.py b/torchlatent/semiring.py index be859f7..b6f0b64 100644 --- a/torchlatent/semiring.py +++ b/torchlatent/semiring.py @@ -1,14 +1,12 @@ import torch from torch import Tensor + +from torchlatent.functional import logaddexp +from torchlatent.functional import logsumexp from torchrua import segment_logsumexp from torchrua import segment_max from torchrua import segment_prod from torchrua import segment_sum -from torchrua.reduction import reduce_sequence -from torchrua.reduction import ReductionIndices - -from torchlatent.functional import logaddexp -from torchlatent.functional import logsumexp __all__ = [ 'Semiring', 'ExceptionSemiring', @@ -58,10 +56,6 @@ def segment_prod(cls, tensor: Tensor, sizes: Tensor) -> Tensor: def bmm(cls, x: Tensor, y: Tensor) -> Tensor: return cls.sum(cls.mul(x[..., :, :, None], y[..., None, :, :]), dim=-2, keepdim=False) - @classmethod - def reduce(cls, tensor: Tensor, indices: ReductionIndices) -> Tensor: - return reduce_sequence(data=tensor, indices=indices, op=cls.bmm) - class Std(Semiring): zero = 0. 
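
A note on the Log-semiring scores above: `mul` is ordinary addition and `segment_prod` is a per-sequence sum over the catted tokens, so `crf_scores` decomposes each sequence's score into the head transition of its first tag, the emission score of every chosen tag, the transition score of every adjacent tag pair, and the last transition of its final tag. Below is a minimal single-sequence reference in plain PyTorch; the function name and the dense-tensor layout are illustrative assumptions, not the library API.

import torch


def crf_score_reference(emissions: torch.Tensor, targets: torch.Tensor, transitions: torch.Tensor,
                        head_transitions: torch.Tensor, last_transitions: torch.Tensor) -> torch.Tensor:
    # emissions: [num_tokens, num_targets], targets: [num_tokens]
    index = torch.arange(targets.size(0))

    # emission score of the chosen tag at every token
    emission_score = emissions[index, targets].sum()

    # transition score of every adjacent tag pair, indexed as [prev, next]
    transition_score = transitions[targets[:-1], targets[1:]].sum()

    return emission_score + transition_score + head_transitions[targets[0]] + last_transitions[targets[-1]]

Batching over a catted sequence replaces the two `.sum()` calls with segment sums keyed by `token_sizes`; the `[1::2]` slice in `crf_scores` exists only to drop the wrapped-around transition that `roll` introduces at each sequence boundary.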
From 5a6355c7183a52e04a41f2a5cb349398c9b48399 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Thu, 24 Aug 2023 00:54:20 +0900 Subject: [PATCH 073/102] Feat: Update crf_partitions --- tests/test_linear_crf.py | 109 +++++++++++--------------------------- torchlatent/linear_crf.py | 53 ++++++++---------- 2 files changed, 52 insertions(+), 110 deletions(-) diff --git a/tests/test_linear_crf.py b/tests/test_linear_crf.py index 9d92d11..4ecd7c5 100644 --- a/tests/test_linear_crf.py +++ b/tests/test_linear_crf.py @@ -9,6 +9,7 @@ from torchnyan import device from torchnyan import sizes +from torchlatent.linear_crf import crf_partitions from torchlatent.linear_crf import crf_scores from torchlatent.semiring import Log from torchrua import cat_sequence @@ -54,82 +55,32 @@ def test_crf_scores(token_sizes, num_targets, rua_emissions, rua_targets): assert_close(actual=actual, expected=excepted) assert_grad_close(actual=actual, expected=excepted, inputs=inputs) -# @given( -# token_sizes=sizes(BATCH_SIZE, TOKEN_SIZE), -# num_targets=sizes(TOKEN_SIZE), -# ) -# def test_crf_catted_partitions(token_sizes, num_targets): -# inputs = [ -# torch.randn((token_size, num_targets), device=device, requires_grad=True) -# for token_size in token_sizes -# ] -# -# excepted_emissions, token_sizes = pad_sequence(inputs) -# index = torch.arange(token_sizes.max().detach().cpu().item(), device=token_sizes.device) -# mask = index[None, :] < token_sizes[:, None] -# -# excepted_crf = CRF(num_tags=num_targets, batch_first=False) -# excepted = excepted_crf._compute_normalizer(excepted_emissions.transpose(0, 1), mask.t()) -# -# actual = crf_partitions( -# emissions=cat_sequence(inputs), -# transitions=(excepted_crf.transitions, excepted_crf.start_transitions, excepted_crf.end_transitions), -# semiring=Log, -# ) -# -# assert_close(actual=actual, expected=excepted, rtol=1e-4, atol=1e-4) -# assert_grad_close(actual=actual, expected=excepted, inputs=inputs, rtol=1e-4, atol=1e-4) -# -# -# @given( -# token_sizes=sizes(BATCH_SIZE, TOKEN_SIZE), -# num_targets=sizes(TOKEN_SIZE), -# ) -# def test_crf_packed_partitions(token_sizes, num_targets): -# inputs = [ -# torch.randn((token_size, num_targets), device=device, requires_grad=True) -# for token_size in token_sizes -# ] -# -# excepted_emissions, token_sizes = pad_sequence(inputs) -# index = torch.arange(token_sizes.max().detach().cpu().item(), device=token_sizes.device) -# mask = index[None, :] < token_sizes[:, None] -# -# excepted_crf = CRF(num_tags=num_targets, batch_first=False) -# excepted = excepted_crf._compute_normalizer(excepted_emissions.transpose(0, 1), mask.t()) -# -# actual = crf_partitions( -# emissions=pack_sequence(inputs), -# transitions=(excepted_crf.transitions, excepted_crf.start_transitions, excepted_crf.end_transitions), -# semiring=Log, -# ) -# -# assert_close(actual=actual, expected=excepted, rtol=1e-4, atol=1e-4) -# assert_grad_close(actual=actual, expected=excepted, inputs=inputs, rtol=1e-4, atol=1e-4) -# -# -# @given( -# token_sizes=sizes(BATCH_SIZE, TOKEN_SIZE), -# num_targets=sizes(TOKEN_SIZE), -# ) -# def test_crf_padded_partitions(token_sizes, num_targets): -# inputs = [ -# torch.randn((token_size, num_targets), device=device, requires_grad=True) -# for token_size in token_sizes -# ] -# -# excepted_emissions, token_sizes = pad_sequence(inputs) -# index = torch.arange(token_sizes.max().detach().cpu().item(), device=token_sizes.device) -# mask = index[None, :] < token_sizes[:, None] -# -# excepted_crf = CRF(num_tags=num_targets, batch_first=False) -# excepted = 
excepted_crf._compute_normalizer(excepted_emissions.transpose(0, 1), mask.t()) -# -# actual = crf_partitions( -# emissions=pad_sequence(inputs), -# transitions=(excepted_crf.transitions, excepted_crf.start_transitions, excepted_crf.end_transitions), -# semiring=Log, -# ) -# -# assert_close(actual=actual, expected=excepted, rtol=1e-4, atol=1e-4) -# assert_grad_close(actual=actual, expected=excepted, inputs=inputs, rtol=1e-4, atol=1e-4) + +@given( + token_sizes=sizes(BATCH_SIZE, TOKEN_SIZE), + num_targets=sizes(TOKEN_SIZE), + rua_emissions=st.sampled_from([cat_sequence, pad_sequence, pack_sequence]), +) +def test_crf_partitions(token_sizes, num_targets, rua_emissions): + inputs = [ + torch.randn((token_size, num_targets), device=device, requires_grad=True) + for token_size in token_sizes + ] + + excepted_crf = CRF(num_tags=num_targets, batch_first=False) + + excepted_emissions = pad_sequence(inputs) + + excepted = excepted_crf._compute_normalizer( + excepted_emissions.data.transpose(0, 1), + excepted_emissions.mask().t(), + ) + + actual = crf_partitions( + emissions=rua_emissions(inputs), + transitions=(excepted_crf.transitions, excepted_crf.start_transitions, excepted_crf.end_transitions), + semiring=Log, + ) + + assert_close(actual=actual, expected=excepted, rtol=1e-4, atol=1e-4) + assert_grad_close(actual=actual, expected=excepted, inputs=inputs, rtol=1e-4, atol=1e-4) diff --git a/torchlatent/linear_crf.py b/torchlatent/linear_crf.py index b7ae90a..d28d081 100644 --- a/torchlatent/linear_crf.py +++ b/torchlatent/linear_crf.py @@ -32,37 +32,28 @@ def crf_scores(emissions: Union[C, D, P], targets: Union[C, D, P], transitions: semiring.mul(emissions, transitions), ) -# def crf_partitions(emissions: Sequence, transitions: Transitions, semiring: Type[Semiring]) -> Tensor: -# if isinstance(emissions, CattedSequence): -# emissions = pack_catted_sequence(emissions) -# elif isinstance(emissions, tuple) and len(emissions) == 2: -# emissions = pack_padded_sequence(*emissions, batch_first=True) -# -# assert isinstance(emissions, PackedSequence) -# -# emissions, batch_sizes, _, unsorted_indices = emissions -# transitions, head_transitions, last_transitions = transitions -# -# last_indices = last_packed_indices( -# batch_sizes=batch_sizes, -# unsorted_indices=unsorted_indices, -# device=emissions.device, -# ) -# -# _, *batch_sizes = sections = batch_sizes.detach().cpu().tolist() -# emission, *emissions = torch.split(emissions, sections, dim=0) -# -# charts = [semiring.mul(head_transitions, emission)] -# for emission, batch_size in zip(emissions, batch_sizes): -# charts.append(semiring.mul( -# semiring.bmm(charts[-1][:batch_size], transitions), -# emission, -# )) -# -# emission = torch.cat(charts, dim=0)[last_indices] -# return semiring.sum(semiring.mul(emission, last_transitions), dim=-1) -# -# + +def crf_partitions(emissions: Union[C, D, P], transitions: T, semiring: Type[Semiring]) -> Tensor: + transitions, head_transitions, last_transitions = transitions + + emissions = emissions.pack() + last_indices = emissions.idx().last() + emissions, batch_sizes, _, unsorted_indices = emissions + + _, *batch_sizes = sections = batch_sizes.detach().cpu().tolist() + emission, *emissions = torch.split(emissions, sections, dim=0) + + charts = [semiring.mul(head_transitions, emission)] + for emission, batch_size in zip(emissions, batch_sizes): + charts.append(semiring.mul( + semiring.bmm(charts[-1][:batch_size], transitions), + emission, + )) + + emission = torch.cat(charts, dim=0)[last_indices] + return 
semiring.sum(semiring.mul(emission, last_transitions), dim=-1) + + # class CrfDistribution(StructuredDistribution): # def __init__(self, emissions: Sequence, transitions: Transitions) -> None: # super(CrfDistribution, self).__init__() From c092b9c1d6a650504737fcb063b52679ccaff546 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Thu, 24 Aug 2023 01:06:15 +0900 Subject: [PATCH 074/102] Test: Add CrfDecoder --- tests/test_linear_crf.py | 41 +++++++++++-- torchlatent/linear_crf.py | 124 ++++++++++++++++++++------------------ 2 files changed, 101 insertions(+), 64 deletions(-) diff --git a/tests/test_linear_crf.py b/tests/test_linear_crf.py index 4ecd7c5..86b35bd 100644 --- a/tests/test_linear_crf.py +++ b/tests/test_linear_crf.py @@ -2,16 +2,18 @@ from hypothesis import given from hypothesis import strategies as st from torchcrf import CRF + +from torchlatent.linear_crf import CrfDecoder +from torchlatent.linear_crf import crf_partitions +from torchlatent.linear_crf import crf_scores +from torchlatent.semiring import Log from torchnyan import BATCH_SIZE from torchnyan import TOKEN_SIZE from torchnyan import assert_close from torchnyan import assert_grad_close +from torchnyan import assert_sequence_close from torchnyan import device from torchnyan import sizes - -from torchlatent.linear_crf import crf_partitions -from torchlatent.linear_crf import crf_scores -from torchlatent.semiring import Log from torchrua import cat_sequence from torchrua import pack_sequence from torchrua import pad_sequence @@ -84,3 +86,34 @@ def test_crf_partitions(token_sizes, num_targets, rua_emissions): assert_close(actual=actual, expected=excepted, rtol=1e-4, atol=1e-4) assert_grad_close(actual=actual, expected=excepted, inputs=inputs, rtol=1e-4, atol=1e-4) + + +@given( + token_sizes=sizes(BATCH_SIZE, TOKEN_SIZE), + num_targets=sizes(TOKEN_SIZE), + rua_emissions=st.sampled_from([cat_sequence, pad_sequence, pack_sequence]), +) +def test_crf_argmax(token_sizes, num_targets, rua_emissions): + inputs = [ + torch.randn((token_size, num_targets), device=device, requires_grad=True) + for token_size in token_sizes + ] + + excepted_crf = CRF(num_tags=num_targets, batch_first=False) + + excepted_emissions = pad_sequence(inputs) + + excepted = excepted_crf.decode( + excepted_emissions.data.transpose(0, 1), + excepted_emissions.mask().t(), + ) + excepted = cat_sequence([torch.tensor(tensor, device=device) for tensor in excepted]) + + actual_crf = CrfDecoder(num_targets=num_targets) + actual_crf.transitions = excepted_crf.transitions + actual_crf.head_transitions = excepted_crf.start_transitions + actual_crf.last_transitions = excepted_crf.end_transitions + + actual = actual_crf(rua_emissions(inputs)).argmax.cat() + + assert_sequence_close(actual=actual, expected=excepted) diff --git a/torchlatent/linear_crf.py b/torchlatent/linear_crf.py index d28d081..e0d9257 100644 --- a/torchlatent/linear_crf.py +++ b/torchlatent/linear_crf.py @@ -4,7 +4,13 @@ import torch from torch import Tensor +from torch import nn +from torch.distributions.utils import lazy_property +from torch.nn import init +from torchlatent.abc2 import StructuredDistribution +from torchlatent.semiring import Log +from torchlatent.semiring import Max from torchlatent.semiring import Semiring from torchrua import C from torchrua import D @@ -54,63 +60,61 @@ def crf_partitions(emissions: Union[C, D, P], transitions: T, semiring: Type[Sem return semiring.sum(semiring.mul(emission, last_transitions), dim=-1) -# class CrfDistribution(StructuredDistribution): -# def 
__init__(self, emissions: Sequence, transitions: Transitions) -> None: -# super(CrfDistribution, self).__init__() -# self.emissions = emissions -# self.transitions = transitions -# -# def log_scores(self, targets: Sequence) -> Tensor: -# return crf_scores( -# emissions=self.emissions, -# targets=targets, -# transitions=self.transitions, -# semiring=Log, -# ) -# -# @lazy_property -# def log_partitions(self) -> Tensor: -# return crf_partitions( -# emissions=self.emissions, -# transitions=self.transitions, -# semiring=Log, -# ) -# -# @lazy_property -# def max(self) -> Tensor: -# return crf_partitions( -# emissions=self.emissions, -# transitions=self.transitions, -# semiring=Max, -# ) -# -# -# class CrfDecoder(nn.Module): -# def __init__(self, *, num_targets: int) -> None: -# super(CrfDecoder, self).__init__() -# -# self.num_targets = num_targets -# -# self.transitions = nn.Parameter(torch.empty((num_targets, num_targets))) -# self.head_transitions = nn.Parameter(torch.empty((num_targets,))) -# self.last_transitions = nn.Parameter(torch.empty((num_targets,))) -# -# self.reset_parameters() -# -# def reset_parameters(self) -> None: -# init.zeros_(self.transitions) -# init.zeros_(self.head_transitions) -# init.zeros_(self.last_transitions) -# -# def extra_repr(self) -> str: -# return f'num_targets={self.num_targets}' -# -# def forward(self, emissions: Sequence) -> CrfDistribution: -# return CrfDistribution( -# emissions=emissions, -# transitions=( -# self.transitions, -# self.head_transitions, -# self.last_transitions, -# ), -# ) +class CrfDistribution(StructuredDistribution): + def __init__(self, emissions: Union[C, D, P], transitions: T) -> None: + super(CrfDistribution, self).__init__(emissions=emissions) + self.transitions = transitions + + def log_scores(self, targets: Union[C, D, P]) -> Tensor: + return crf_scores( + emissions=self.emissions, targets=targets, + transitions=self.transitions, + semiring=Log, + ) + + @lazy_property + def log_partitions(self) -> Tensor: + return crf_partitions( + emissions=self.emissions, + transitions=self.transitions, + semiring=Log, + ) + + @lazy_property + def max(self) -> Tensor: + return crf_partitions( + emissions=self.emissions, + transitions=self.transitions, + semiring=Max, + ) + + +class CrfDecoder(nn.Module): + def __init__(self, *, num_targets: int) -> None: + super(CrfDecoder, self).__init__() + + self.num_targets = num_targets + + self.transitions = nn.Parameter(torch.empty((num_targets, num_targets))) + self.head_transitions = nn.Parameter(torch.empty((num_targets,))) + self.last_transitions = nn.Parameter(torch.empty((num_targets,))) + + self.reset_parameters() + + def reset_parameters(self) -> None: + init.zeros_(self.transitions) + init.zeros_(self.head_transitions) + init.zeros_(self.last_transitions) + + def extra_repr(self) -> str: + return f'num_targets={self.num_targets}' + + def forward(self, emissions: Union[C, D, P]) -> CrfDistribution: + return CrfDistribution( + emissions=emissions, + transitions=( + self.transitions, + self.head_transitions, + self.last_transitions, + ), + ) From 976164b6c6668915f5f822f8bb1790a2b01f550c Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Thu, 24 Aug 2023 01:11:21 +0900 Subject: [PATCH 075/102] Refactor: Renaming --- README.md | 82 +------ tests/test_cky.py | 8 +- tests/test_crf.py | 456 ++++++----------------------------- tests/test_functional.py | 6 +- tests/test_linear_crf.py | 119 --------- torchlatent/__init__.py | 2 + torchlatent/abc.py | 45 ++-- torchlatent/abc2.py | 48 ---- torchlatent/cky.py | 
18 +- torchlatent/crf.py | 316 ++++-------------------- torchlatent/linear_crf.py | 120 --------- torchlatent/nn/__init__.py | 0 torchlatent/nn/classifier.py | 76 ------ 13 files changed, 161 insertions(+), 1135 deletions(-) delete mode 100644 tests/test_linear_crf.py delete mode 100644 torchlatent/abc2.py delete mode 100644 torchlatent/linear_crf.py delete mode 100644 torchlatent/nn/__init__.py delete mode 100644 torchlatent/nn/classifier.py diff --git a/README.md b/README.md index 21edae0..4ec95ef 100644 --- a/README.md +++ b/README.md @@ -13,88 +13,10 @@ `python3 -m pip torchlatent` -## Performance - -``` -TorchLatent (0.109244) => 0.003781 0.017763 0.087700 0.063497 -Third (0.232487) => 0.103277 0.129209 0.145311 -``` - -## Usage - -```python -import torch -from torchrua import pack_sequence - -from torchlatent.crf import CrfLayer - -num_tags = 3 -num_conjugates = 1 - -decoder = CrfLayer(num_targets=num_tags, num_conjugates=num_conjugates) - -emissions = pack_sequence([ - torch.randn((5, num_conjugates, num_tags), requires_grad=True), - torch.randn((2, num_conjugates, num_tags), requires_grad=True), - torch.randn((3, num_conjugates, num_tags), requires_grad=True), -]) - -tags = pack_sequence([ - torch.randint(0, num_tags, (5, num_conjugates)), - torch.randint(0, num_tags, (2, num_conjugates)), - torch.randint(0, num_tags, (3, num_conjugates)), -]) - -print(decoder.fit(emissions=emissions, tags=tags)) -# tensor([[-6.7424], -# [-5.1288], -# [-2.7283]], grad_fn=) - -print(decoder.decode(emissions=emissions)) -# PackedSequence(data=tensor([[2], -# [0], -# [1], -# [0], -# [2], -# [0], -# [2], -# [0], -# [1], -# [2]]), -# batch_sizes=tensor([3, 3, 2, 1, 1]), -# sorted_indices=tensor([0, 2, 1]), -# unsorted_indices=tensor([0, 2, 1])) - -print(decoder.marginals(emissions=emissions)) -# tensor([[[0.1040, 0.1001, 0.7958]], -# -# [[0.5736, 0.0784, 0.3479]], -# -# [[0.0932, 0.8797, 0.0271]], -# -# [[0.6558, 0.0472, 0.2971]], -# -# [[0.2740, 0.1109, 0.6152]], -# -# [[0.4811, 0.2163, 0.3026]], -# -# [[0.2321, 0.3478, 0.4201]], -# -# [[0.4987, 0.1986, 0.3027]], -# -# [[0.2029, 0.5888, 0.2083]], -# -# [[0.2802, 0.2358, 0.4840]]], grad_fn=) -``` - ## Latent Structures -- [ ] Conditional Random Fields (CRF) - - [x] Conjugated - - [ ] Dynamic Transition Matrix - - [ ] Second-order - - [ ] Variant-order -- [ ] Tree CRF +- [x] Conditional Random Fields (CRF) +- [x] Tree CRF - [ ] Non-Projective Dependency Tree (Matrix-tree Theorem) - [ ] Probabilistic Context-free Grammars (PCFG) - [ ] Dependency Model with Valence (DMV) \ No newline at end of file diff --git a/tests/test_cky.py b/tests/test_cky.py index 6e78125..b9c3d31 100644 --- a/tests/test_cky.py +++ b/tests/test_cky.py @@ -2,6 +2,10 @@ from hypothesis import given from hypothesis import strategies as st from torch_struct import TreeCRF + +from torchlatent.cky import CkyDecoder +from torchlatent.cky import CkyDistribution +from torchlatent.cky import cky_partitions_indices from torchnyan.assertion import assert_close from torchnyan.assertion import assert_grad_close from torchnyan.strategy import BATCH_SIZE @@ -12,10 +16,6 @@ from torchnyan.strategy import sizes from torchrua import cat_sequence -from torchlatent.cky import CkyDecoder -from torchlatent.cky import CkyDistribution -from torchlatent.cky import cky_partitions_indices - @given( token_sizes=sizes(TINY_BATCH_SIZE, TOKEN_SIZE), diff --git a/tests/test_crf.py b/tests/test_crf.py index b5dbab9..af9f30c 100644 --- a/tests/test_crf.py +++ b/tests/test_crf.py @@ -1,433 +1,119 @@ import torch 
-import torchcrf
 from hypothesis import given
-from torchnyan.assertion import assert_close
-from torchnyan.assertion import assert_grad_close
-from torchnyan.strategy import TINY_BATCH_SIZE
-from torchnyan.strategy import TINY_TOKEN_SIZE
-from torchnyan.strategy import TOKEN_SIZE
-from torchnyan.strategy import device
-from torchnyan.strategy import sizes
+from hypothesis import strategies as st
+from torchcrf import CRF
+
+from torchlatent.crf import CrfDecoder
+from torchlatent.crf import crf_partitions
+from torchlatent.crf import crf_scores
+from torchlatent.semiring import Log
+from torchnyan import BATCH_SIZE
+from torchnyan import TOKEN_SIZE
+from torchnyan import assert_close
+from torchnyan import assert_grad_close
+from torchnyan import assert_sequence_close
+from torchnyan import device
+from torchnyan import sizes
 from torchrua import cat_sequence
-from torchrua import pack_catted_indices
 from torchrua import pack_sequence
-from torchrua import pad_catted_indices
-from torchrua import pad_packed_indices
 from torchrua import pad_sequence
-
-from torchlatent.crf import CrfLayer
-
-
-@given(
-    token_sizes=sizes(TINY_BATCH_SIZE, TOKEN_SIZE),
-    num_tags=sizes(TOKEN_SIZE),
-)
-def test_crf_catted_scores(token_sizes, num_tags):
-    actual_decoder = CrfLayer(num_tags).to(device=device)
-    excepted_decoder = torchcrf.CRF(num_tags, batch_first=False).to(device=device)
-
-    excepted_decoder.transitions.data = torch.randn_like(excepted_decoder.transitions)
-    excepted_decoder.start_transitions.data = torch.randn_like(excepted_decoder.start_transitions)
-    excepted_decoder.end_transitions.data = torch.randn_like(excepted_decoder.end_transitions)
-
-    actual_decoder.transitions.data = excepted_decoder.transitions[None, None, :, :]
-    actual_decoder.head_transitions.data = excepted_decoder.start_transitions[None, None, :]
-    actual_decoder.last_transitions.data = excepted_decoder.end_transitions[None, None, :]
-
-    emissions = [
-        torch.randn((token_size, num_tags), requires_grad=True, device=device)
-        for token_size in token_sizes
-    ]
-    targets = [
-        torch.randint(0, num_tags, (token_size,), device=device)
-        for token_size in token_sizes
-    ]
-
-    catted_emissions = cat_sequence([x[:, None] for x in emissions])
-    catted_targets = cat_sequence([x[:, None] for x in targets])
-
-    padded_emissions, _ = pad_sequence(emissions, batch_first=False)
-    padded_targets, _ = pad_sequence(targets, batch_first=False)
-
-    size, ptr, _ = pad_catted_indices(token_sizes=catted_emissions.token_sizes, batch_first=False)
-    mask = torch.zeros(size, dtype=torch.bool, device=device)
-    mask[ptr] = True
-
-    actual = actual_decoder.forward(emissions=catted_emissions).log_scores(targets=catted_targets)[:, 0]
-    excepted = excepted_decoder._compute_score(
-        emissions=padded_emissions, tags=padded_targets.long(),
-        mask=mask.byte(),
-    )
-
-    assert_close(actual=actual, expected=excepted, rtol=1e-4, atol=1e-4)
-    assert_grad_close(actual=actual, expected=excepted, inputs=emissions, rtol=1e-4, atol=1e-4)
-
-
-@given(
-    token_sizes=sizes(TINY_BATCH_SIZE, TOKEN_SIZE),
-    num_tags=sizes(TOKEN_SIZE),
-)
-def test_crf_catted_fit(token_sizes, num_tags):
-    actual_decoder = CrfLayer(num_tags).to(device=device)
-    excepted_decoder = torchcrf.CRF(num_tags, batch_first=False).to(device=device)
-
-    excepted_decoder.transitions.data = torch.randn_like(excepted_decoder.transitions)
-    excepted_decoder.start_transitions.data = torch.randn_like(excepted_decoder.start_transitions)
-    excepted_decoder.end_transitions.data = torch.randn_like(excepted_decoder.end_transitions)
-
-    actual_decoder.transitions.data = excepted_decoder.transitions[None, None, :, :]
-    actual_decoder.head_transitions.data = excepted_decoder.start_transitions[None, None, :]
-    actual_decoder.last_transitions.data = excepted_decoder.end_transitions[None, None, :]
-
-    emissions = [
-        torch.randn((token_size, num_tags), requires_grad=True, device=device)
-        for token_size in token_sizes
-    ]
-    targets = [
-        torch.randint(0, num_tags, (token_size,), device=device)
-        for token_size in token_sizes
-    ]
-
-    catted_emissions = cat_sequence([x[:, None] for x in emissions])
-    catted_targets = cat_sequence([x[:, None] for x in targets])
-
-    padded_emissions, _ = pad_sequence(emissions, batch_first=False)
-    padded_targets, _ = pad_sequence(targets, batch_first=False)
-
-    size, ptr, _ = pad_catted_indices(token_sizes=catted_emissions.token_sizes, batch_first=False)
-    mask = torch.zeros(size, dtype=torch.bool, device=device)
-    mask[ptr] = True
-
-    actual = actual_decoder.fit(emissions=catted_emissions, targets=catted_targets)[:, 0]
-    excepted = excepted_decoder.forward(
-        emissions=padded_emissions, tags=padded_targets.long(),
-        mask=mask.byte(), reduction='none',
-    ).neg()
-
-    assert_close(actual=actual, expected=excepted, rtol=1e-4, atol=1e-4)
-    assert_grad_close(actual=actual, expected=excepted, inputs=emissions, rtol=1e-4, atol=1e-4)
-
 
 @given(
-    token_sizes=sizes(TINY_BATCH_SIZE, TOKEN_SIZE),
-    num_tags=sizes(TOKEN_SIZE),
+    token_sizes=sizes(BATCH_SIZE, TOKEN_SIZE),
+    num_targets=sizes(TOKEN_SIZE),
+    rua_emissions=st.sampled_from([cat_sequence, pad_sequence, pack_sequence]),
+    rua_targets=st.sampled_from([cat_sequence, pad_sequence, pack_sequence]),
 )
-def test_crf_catted_decode(token_sizes, num_tags):
-    actual_decoder = CrfLayer(num_tags).to(device=device)
-    excepted_decoder = torchcrf.CRF(num_tags, batch_first=False).to(device=device)
-
-    excepted_decoder.transitions.data = torch.randn_like(excepted_decoder.transitions)
-    excepted_decoder.start_transitions.data = torch.randn_like(excepted_decoder.start_transitions)
-    excepted_decoder.end_transitions.data = torch.randn_like(excepted_decoder.end_transitions)
-
-    actual_decoder.transitions.data = excepted_decoder.transitions[None, None, :, :]
-    actual_decoder.head_transitions.data = excepted_decoder.start_transitions[None, None, :]
-    actual_decoder.last_transitions.data = excepted_decoder.end_transitions[None, None, :]
-
-    emissions = [
-        torch.randn((token_size, num_tags), requires_grad=True, device=device)
+def test_crf_scores(token_sizes, num_targets, rua_emissions, rua_targets):
+    inputs = [
+        torch.randn((token_size, num_targets), device=device, requires_grad=True)
         for token_size in token_sizes
     ]
 
-    catted_emissions = cat_sequence([x[:, None] for x in emissions])
-
-    padded_emissions, _ = pad_sequence(emissions, batch_first=False)
-
-    size, ptr, _ = pad_catted_indices(token_sizes=catted_emissions.token_sizes, batch_first=False)
-    mask = torch.zeros(size, dtype=torch.bool, device=device)
-    mask[ptr] = True
-
-    actual, actual_token_sizes = actual_decoder.decode(emissions=catted_emissions)
-
-    excepted = excepted_decoder.decode(emissions=padded_emissions, mask=mask.byte())
-    excepted, excepted_token_sizes = cat_sequence([torch.tensor(x, device=device) for x in excepted])
-
-    assert_close(actual=actual[:, 0], expected=excepted)
-    assert_close(actual=actual_token_sizes, expected=excepted_token_sizes)
-
-
-@given(
-    token_sizes=sizes(TINY_BATCH_SIZE, TOKEN_SIZE),
-    num_tags=sizes(TOKEN_SIZE),
-)
-def test_crf_packed_scores(token_sizes, num_tags):
-    actual_decoder = CrfLayer(num_tags).to(device=device)
-    excepted_decoder = torchcrf.CRF(num_tags, batch_first=False).to(device=device)
-
-    excepted_decoder.transitions.data = torch.randn_like(excepted_decoder.transitions)
-    excepted_decoder.start_transitions.data = torch.randn_like(excepted_decoder.start_transitions)
-    excepted_decoder.end_transitions.data = torch.randn_like(excepted_decoder.end_transitions)
-
-    actual_decoder.transitions.data = excepted_decoder.transitions[None, None, :, :]
-    actual_decoder.head_transitions.data = excepted_decoder.start_transitions[None, None, :]
-    actual_decoder.last_transitions.data = excepted_decoder.end_transitions[None, None, :]
-
-    emissions = [
-        torch.randn((token_size, num_tags), requires_grad=True, device=device)
-        for token_size in token_sizes
-    ]
     targets = [
-        torch.randint(0, num_tags, (token_size,), device=device)
+        torch.randint(0, num_targets, (token_size,), device=device)
         for token_size in token_sizes
     ]
 
-    packed_emissions = pack_sequence([x[:, None] for x in emissions])
-    packed_targets = pack_sequence([x[:, None] for x in targets])
+    excepted_crf = CRF(num_tags=num_targets, batch_first=False)
 
-    padded_emissions, _ = pad_sequence(emissions, batch_first=False)
-    padded_targets, _ = pad_sequence(targets, batch_first=False)
+    excepted_emissions = pad_sequence(inputs)
+    excepted_tags = pad_sequence(targets)
 
-    size, ptr, _ = pad_packed_indices(
-        batch_sizes=packed_emissions.batch_sizes,
-        sorted_indices=packed_emissions.sorted_indices,
-        unsorted_indices=packed_emissions.unsorted_indices,
-        batch_first=False,
+    excepted = excepted_crf._compute_score(
+        excepted_emissions.data.transpose(0, 1),
+        excepted_tags.data.transpose(0, 1),
+        excepted_emissions.mask().transpose(0, 1),
     )
-    mask = torch.zeros(size, dtype=torch.bool, device=device)
-    mask[ptr] = True
 
-    actual = actual_decoder.forward(emissions=packed_emissions).log_scores(targets=packed_targets)[:, 0]
-    excepted = excepted_decoder._compute_score(
-        emissions=padded_emissions, tags=padded_targets.long(),
-        mask=mask.byte(),
+    actual = crf_scores(
+        emissions=rua_emissions(inputs),
+        targets=rua_targets(targets),
+        transitions=(excepted_crf.transitions, excepted_crf.start_transitions, excepted_crf.end_transitions),
+        semiring=Log,
     )
 
-    assert_close(actual=actual, expected=excepted, rtol=1e-4, atol=1e-4)
-    assert_grad_close(actual=actual, expected=excepted, inputs=emissions, rtol=1e-4, atol=1e-4)
+    assert_close(actual=actual, expected=excepted)
+    assert_grad_close(actual=actual, expected=excepted, inputs=inputs)
 
 
 @given(
-    token_sizes=sizes(TINY_BATCH_SIZE, TOKEN_SIZE),
-    num_tags=sizes(TOKEN_SIZE),
+    token_sizes=sizes(BATCH_SIZE, TOKEN_SIZE),
+    num_targets=sizes(TOKEN_SIZE),
+    rua_emissions=st.sampled_from([cat_sequence, pad_sequence, pack_sequence]),
 )
-def test_crf_packed_fit(token_sizes, num_tags):
-    actual_decoder = CrfLayer(num_tags).to(device=device)
-    excepted_decoder = torchcrf.CRF(num_tags, batch_first=False).to(device=device)
-
-    excepted_decoder.transitions.data = torch.randn_like(excepted_decoder.transitions)
-    excepted_decoder.start_transitions.data = torch.randn_like(excepted_decoder.start_transitions)
-    excepted_decoder.end_transitions.data = torch.randn_like(excepted_decoder.end_transitions)
-
-    actual_decoder.transitions.data = excepted_decoder.transitions[None, None, :, :]
-    actual_decoder.head_transitions.data = excepted_decoder.start_transitions[None, None, :]
-    actual_decoder.last_transitions.data = excepted_decoder.end_transitions[None, None, :]
-
-    emissions = [
-        torch.randn((token_size, num_tags), requires_grad=True, device=device)
-        for token_size in token_sizes
-    ]
-    targets = [
-        torch.randint(0, num_tags, (token_size,), device=device)
+def test_crf_partitions(token_sizes, num_targets, rua_emissions):
+    inputs = [
+        torch.randn((token_size, num_targets), device=device, requires_grad=True)
         for token_size in token_sizes
     ]
 
-    packed_emissions = pack_sequence([x[:, None] for x in emissions])
-    packed_targets = pack_sequence([x[:, None] for x in targets])
+    excepted_crf = CRF(num_tags=num_targets, batch_first=False)
 
-    padded_emissions, _ = pad_sequence(emissions, batch_first=False)
-    padded_targets, _ = pad_sequence(targets, batch_first=False)
+    excepted_emissions = pad_sequence(inputs)
 
-    size, ptr, _ = pad_packed_indices(
-        batch_sizes=packed_emissions.batch_sizes,
-        sorted_indices=packed_emissions.sorted_indices,
-        unsorted_indices=packed_emissions.unsorted_indices,
-        batch_first=False,
+    excepted = excepted_crf._compute_normalizer(
+        excepted_emissions.data.transpose(0, 1),
+        excepted_emissions.mask().t(),
     )
-    mask = torch.zeros(size, dtype=torch.bool, device=device)
-    mask[ptr] = True
 
-    actual = actual_decoder.fit(emissions=packed_emissions, targets=packed_targets)[:, 0]
-    excepted = excepted_decoder.forward(
-        emissions=padded_emissions, tags=padded_targets.long(),
-        mask=mask.byte(), reduction='none',
-    ).neg()
+    actual = crf_partitions(
+        emissions=rua_emissions(inputs),
+        transitions=(excepted_crf.transitions, excepted_crf.start_transitions, excepted_crf.end_transitions),
+        semiring=Log,
+    )
 
     assert_close(actual=actual, expected=excepted, rtol=1e-4, atol=1e-4)
-    assert_grad_close(actual=actual, expected=excepted, inputs=emissions, rtol=1e-4, atol=1e-4)
+    assert_grad_close(actual=actual, expected=excepted, inputs=inputs, rtol=1e-4, atol=1e-4)
 
 
 @given(
-    token_sizes=sizes(TINY_BATCH_SIZE, TOKEN_SIZE),
-    num_tags=sizes(TOKEN_SIZE),
+    token_sizes=sizes(BATCH_SIZE, TOKEN_SIZE),
+    num_targets=sizes(TOKEN_SIZE),
+    rua_emissions=st.sampled_from([cat_sequence, pad_sequence, pack_sequence]),
 )
-def test_crf_packed_decode(token_sizes, num_tags):
-    actual_decoder = CrfLayer(num_tags).to(device=device)
-    excepted_decoder = torchcrf.CRF(num_tags, batch_first=False).to(device=device)
-
-    excepted_decoder.transitions.data = torch.randn_like(excepted_decoder.transitions)
-    excepted_decoder.start_transitions.data = torch.randn_like(excepted_decoder.start_transitions)
-    excepted_decoder.end_transitions.data = torch.randn_like(excepted_decoder.end_transitions)
-
-    actual_decoder.transitions.data = excepted_decoder.transitions[None, None, :, :]
-    actual_decoder.head_transitions.data = excepted_decoder.start_transitions[None, None, :]
-    actual_decoder.last_transitions.data = excepted_decoder.end_transitions[None, None, :]
-
-    emissions = [
-        torch.randn((token_size, num_tags), requires_grad=True, device=device)
+def test_crf_argmax(token_sizes, num_targets, rua_emissions):
+    inputs = [
+        torch.randn((token_size, num_targets), device=device, requires_grad=True)
         for token_size in token_sizes
     ]
 
-    packed_emissions = pack_sequence([x[:, None] for x in emissions])
+    excepted_crf = CRF(num_tags=num_targets, batch_first=False)
 
-    padded_emissions, _ = pad_sequence(emissions, batch_first=False)
+    excepted_emissions = pad_sequence(inputs)
 
-    size, ptr, _ = pad_packed_indices(
-        batch_sizes=packed_emissions.batch_sizes,
-        sorted_indices=packed_emissions.sorted_indices,
-        unsorted_indices=packed_emissions.unsorted_indices,
-        batch_first=False,
+    excepted = excepted_crf.decode(
+        excepted_emissions.data.transpose(0, 1),
+        excepted_emissions.mask().t(),
     )
-    mask = torch.zeros(size, dtype=torch.bool, device=device)
-    mask[ptr] = True
-
-    actual = actual_decoder.decode(emissions=packed_emissions)
-
-    excepted = excepted_decoder.decode(emissions=padded_emissions, mask=mask.byte())
-    excepted = pack_sequence([torch.tensor(x, device=device) for x in excepted])
-
-    assert_close(actual=actual.data[:, 0], expected=excepted.data)
-    assert_close(actual=actual.batch_sizes, expected=excepted.batch_sizes)
-    assert_close(actual=actual.sorted_indices, expected=excepted.sorted_indices)
-    assert_close(actual=actual.unsorted_indices, expected=excepted.unsorted_indices)
-
-
-@given(
-    token_sizes=sizes(TINY_BATCH_SIZE, TINY_TOKEN_SIZE),
-    num_conjugates=sizes(TOKEN_SIZE),
-    num_tags=sizes(TINY_TOKEN_SIZE),
-)
-def test_conjugated_catted_fit(token_sizes, num_conjugates, num_tags):
-    decoder = CrfLayer(num_targets=num_tags, num_conjugates=num_conjugates).to(device=device)
-    decoders = [CrfLayer(num_targets=num_tags, num_conjugates=1).to(device=device) for _ in range(num_conjugates)]
+    excepted = cat_sequence([torch.tensor(tensor, device=device) for tensor in excepted])
 
-    for index in range(num_conjugates):
-        decoders[index].transitions.data = torch.randn_like(decoders[index].transitions)
-        decoders[index].head_transitions.data = torch.randn_like(decoders[index].head_transitions)
-        decoders[index].last_transitions.data = torch.randn_like(decoders[index].last_transitions)
-
-        decoder.transitions.data[:, index] = decoders[index].transitions
-        decoder.head_transitions.data[:, index] = decoders[index].head_transitions
-        decoder.last_transitions.data[:, index] = decoders[index].last_transitions
-
-    emissions = [[
-        torch.randn((token_size, 1, num_tags), requires_grad=True, device=device)
-        for token_size in token_sizes
-    ] for _ in range(num_conjugates)]
-
-    targets = [[
-        torch.randint(0, num_tags, (token_size, 1), device=device)
-        for token_size in token_sizes
-    ] for _ in range(num_conjugates)]
-
-    actual = decoder.fit(
-        emissions=cat_sequence([torch.cat(sequences, dim=1) for sequences in zip(*emissions)], device=device),
-        targets=cat_sequence([torch.cat(sequences, dim=1) for sequences in zip(*targets)], device=device),
-    )
-
-    expected = torch.cat([
-        decoders[index].fit(
-            emissions=cat_sequence(emissions[index], device=device),
-            targets=cat_sequence(targets[index], device=device),
-        )
-        for index in range(num_conjugates)
-    ], dim=1)
-
-    assert_close(actual=actual, expected=expected, rtol=1e-4, atol=1e-4)
-    assert_grad_close(
-        actual=actual, expected=expected,
-        inputs=[x for xs in emissions for x in xs],
-        rtol=1e-4, atol=1e-4, check_stride=False,
-    )
-
-
-@given(
-    token_sizes=sizes(TINY_BATCH_SIZE, TINY_TOKEN_SIZE),
-    num_conjugates=sizes(TOKEN_SIZE),
-    num_tags=sizes(TINY_TOKEN_SIZE),
-)
-def test_conjugated_packed_fit(token_sizes, num_conjugates, num_tags):
-    decoder = CrfLayer(num_targets=num_tags, num_conjugates=num_conjugates).to(device=device)
-    decoders = [CrfLayer(num_targets=num_tags, num_conjugates=1).to(device=device) for _ in range(num_conjugates)]
+    actual_crf = CrfDecoder(num_targets=num_targets)
+    actual_crf.transitions = excepted_crf.transitions
+    actual_crf.head_transitions = excepted_crf.start_transitions
+    actual_crf.last_transitions = excepted_crf.end_transitions
 
-    for index in range(num_conjugates):
-        decoders[index].transitions.data = torch.randn_like(decoders[index].transitions)
-        decoders[index].head_transitions.data = torch.randn_like(decoders[index].head_transitions)
-        decoders[index].last_transitions.data = torch.randn_like(decoders[index].last_transitions)
-
-        decoder.transitions.data[:, index] = decoders[index].transitions
-        decoder.head_transitions.data[:, index] = decoders[index].head_transitions
-        decoder.last_transitions.data[:, index] = decoders[index].last_transitions
-
-    emissions = [[
-        torch.randn((token_size, 1, num_tags), requires_grad=True, device=device)
-        for token_size in token_sizes
-    ] for _ in range(num_conjugates)]
-
-    targets = [[
-        torch.randint(0, num_tags, (token_size, 1), device=device)
-        for token_size in token_sizes
-    ] for _ in range(num_conjugates)]
-
-    actual = decoder.fit(
-        emissions=pack_sequence([torch.cat(sequences, dim=1) for sequences in zip(*emissions)], device=device),
-        targets=pack_sequence([torch.cat(sequences, dim=1) for sequences in zip(*targets)], device=device),
-    )
-
-    expected = torch.cat([
-        decoders[index].fit(
-            emissions=pack_sequence(emissions[index], device=device),
-            targets=pack_sequence(targets[index], device=device),
-        )
-        for index in range(num_conjugates)
-    ], dim=1)
-
-    assert_close(actual=actual, expected=expected)
-    assert_grad_close(
-        actual=actual, expected=expected,
-        inputs=[x for xs in emissions for x in xs],
-        rtol=1e-4, atol=1e-4, check_stride=False,
-    )
-
-
-@given(
-    token_sizes=sizes(TINY_BATCH_SIZE, TINY_TOKEN_SIZE),
-    num_conjugates=sizes(TOKEN_SIZE),
-    num_tags=sizes(TINY_TOKEN_SIZE),
-)
-def test_dynamic_fit(token_sizes, num_conjugates, num_tags):
-    packed_decoder = CrfLayer(num_targets=num_tags, num_conjugates=num_conjugates).to(device=device)
-    catted_decoder = CrfLayer(num_targets=num_tags, num_conjugates=num_conjugates).to(device=device)
-
-    emissions = [
-        torch.randn((token_size, 1, num_tags), requires_grad=True, device=device)
-        for token_size in token_sizes
-    ]
-
-    targets = [
-        torch.randint(0, num_tags, (token_size, 1), device=device)
-        for token_size in token_sizes
-    ]
-
-    catted_decoder.transitions.data = torch.randn((sum(token_sizes), num_conjugates, num_tags, num_tags), device=device)
-    catted_decoder.head_transitions.data = torch.randn((len(token_sizes), num_conjugates, num_tags), device=device)
-    catted_decoder.last_transitions.data = torch.randn((len(token_sizes), num_conjugates, num_tags), device=device)
-
-    token_sizes = torch.tensor(token_sizes, device=device)
-    indices, _, sorted_indices, _ = pack_catted_indices(token_sizes=token_sizes, device=device)
-
-    packed_decoder.transitions.data = catted_decoder.transitions[indices]
-    packed_decoder.head_transitions.data = catted_decoder.head_transitions[sorted_indices]
-    packed_decoder.last_transitions.data = catted_decoder.last_transitions[sorted_indices]
-
-    packed_fit = packed_decoder.fit(
-        emissions=pack_sequence(emissions, device=device),
-        targets=pack_sequence(targets, device=device),
-    )
-
-    catted_fit = catted_decoder.fit(
-        emissions=cat_sequence(emissions, device=device),
-        targets=cat_sequence(targets, device=device),
-    )
+    actual = actual_crf(rua_emissions(inputs)).argmax.cat()
 
-    assert_close(actual=packed_fit, expected=catted_fit, rtol=1e-4, atol=1e-4)
-    assert_grad_close(actual=catted_fit, expected=catted_fit, inputs=emissions, rtol=1e-4, atol=1e-4)
+    assert_sequence_close(actual=actual, expected=excepted)
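[Annotation] The rewritten test_crf.py above treats torchcrf's CRF as a reference oracle and lets Hypothesis sample the sequence container itself, so each property is exercised for catted, padded, and packed layouts. A minimal sketch of the invariant being encoded, assuming only that torchrua's three constructors wrap the same Python list of tensors (as in the tests above):

    import torch
    from torchrua import cat_sequence, pack_sequence, pad_sequence

    inputs = [torch.randn(5, 7), torch.randn(3, 7)]
    # crf_scores and crf_partitions are expected to agree across all three
    # layouts, since they normalize the container internally via .cat()/.pack().
    variants = [cat_sequence(inputs), pad_sequence(inputs), pack_sequence(inputs)]
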
diff --git a/tests/test_functional.py b/tests/test_functional.py
index 1af9370..b06221f 100644
--- a/tests/test_functional.py
+++ b/tests/test_functional.py
@@ -1,6 +1,9 @@
 import torch
 from hypothesis import given
 from hypothesis import strategies as st
+
+from torchlatent.functional import logaddexp
+from torchlatent.functional import logsumexp
 from torchnyan.assertion import assert_close
 from torchnyan.assertion import assert_grad_close
 from torchnyan.strategy import TINY_BATCH_SIZE
@@ -8,9 +11,6 @@
 from torchnyan.strategy import device
 from torchnyan.strategy import sizes
 
-from torchlatent.functional import logaddexp
-from torchlatent.functional import logsumexp
-
 
 @given(
     token_sizes=sizes(TINY_BATCH_SIZE, TINY_TOKEN_SIZE)
diff --git a/tests/test_linear_crf.py b/tests/test_linear_crf.py
deleted file mode 100644
index 86b35bd..0000000
--- a/tests/test_linear_crf.py
+++ /dev/null
@@ -1,119 +0,0 @@
-import torch
-from hypothesis import given
-from hypothesis import strategies as st
-from torchcrf import CRF
-
-from torchlatent.linear_crf import CrfDecoder
-from torchlatent.linear_crf import crf_partitions
-from torchlatent.linear_crf import crf_scores
-from torchlatent.semiring import Log
-from torchnyan import BATCH_SIZE
-from torchnyan import TOKEN_SIZE
-from torchnyan import assert_close
-from torchnyan import assert_grad_close
-from torchnyan import assert_sequence_close
-from torchnyan import device
-from torchnyan import sizes
-from torchrua import cat_sequence
-from torchrua import pack_sequence
-from torchrua import pad_sequence
-
-
-@given(
-    token_sizes=sizes(BATCH_SIZE, TOKEN_SIZE),
-    num_targets=sizes(TOKEN_SIZE),
-    rua_emissions=st.sampled_from([cat_sequence, pad_sequence, pack_sequence]),
-    rua_targets=st.sampled_from([cat_sequence, pad_sequence, pack_sequence]),
-)
-def test_crf_scores(token_sizes, num_targets, rua_emissions, rua_targets):
-    inputs = [
-        torch.randn((token_size, num_targets), device=device, requires_grad=True)
-        for token_size in token_sizes
-    ]
-
-    targets = [
-        torch.randint(0, num_targets, (token_size,), device=device)
-        for token_size in token_sizes
-    ]
-
-    excepted_crf = CRF(num_tags=num_targets, batch_first=False)
-
-    excepted_emissions = pad_sequence(inputs)
-    excepted_tags = pad_sequence(targets)
-
-    excepted = excepted_crf._compute_score(
-        excepted_emissions.data.transpose(0, 1),
-        excepted_tags.data.transpose(0, 1),
-        excepted_emissions.mask().transpose(0, 1),
-    )
-
-    actual = crf_scores(
-        emissions=rua_emissions(inputs),
-        targets=rua_targets(targets),
-        transitions=(excepted_crf.transitions, excepted_crf.start_transitions, excepted_crf.end_transitions),
-        semiring=Log,
-    )
-
-    assert_close(actual=actual, expected=excepted)
-    assert_grad_close(actual=actual, expected=excepted, inputs=inputs)
-
-
-@given(
-    token_sizes=sizes(BATCH_SIZE, TOKEN_SIZE),
-    num_targets=sizes(TOKEN_SIZE),
-    rua_emissions=st.sampled_from([cat_sequence, pad_sequence, pack_sequence]),
-)
-def test_crf_partitions(token_sizes, num_targets, rua_emissions):
-    inputs = [
-        torch.randn((token_size, num_targets), device=device, requires_grad=True)
-        for token_size in token_sizes
-    ]
-
-    excepted_crf = CRF(num_tags=num_targets, batch_first=False)
-
-    excepted_emissions = pad_sequence(inputs)
-
-    excepted = excepted_crf._compute_normalizer(
-        excepted_emissions.data.transpose(0, 1),
-        excepted_emissions.mask().t(),
-    )
-
-    actual = crf_partitions(
-        emissions=rua_emissions(inputs),
-        transitions=(excepted_crf.transitions, excepted_crf.start_transitions, excepted_crf.end_transitions),
-        semiring=Log,
-    )
-
-    assert_close(actual=actual, expected=excepted, rtol=1e-4, atol=1e-4)
-    assert_grad_close(actual=actual, expected=excepted, inputs=inputs, rtol=1e-4, atol=1e-4)
-
-
-@given(
-    token_sizes=sizes(BATCH_SIZE, TOKEN_SIZE),
-    num_targets=sizes(TOKEN_SIZE),
-    rua_emissions=st.sampled_from([cat_sequence, pad_sequence, pack_sequence]),
-)
-def test_crf_argmax(token_sizes, num_targets, rua_emissions):
-    inputs = [
-        torch.randn((token_size, num_targets), device=device, requires_grad=True)
-        for token_size in token_sizes
-    ]
-
-    excepted_crf = CRF(num_tags=num_targets, batch_first=False)
-
-    excepted_emissions = pad_sequence(inputs)
-
-    excepted = excepted_crf.decode(
-        excepted_emissions.data.transpose(0, 1),
-        excepted_emissions.mask().t(),
-    )
-    excepted = cat_sequence([torch.tensor(tensor, device=device) for tensor in excepted])
-
-    actual_crf = CrfDecoder(num_targets=num_targets)
-    actual_crf.transitions = excepted_crf.transitions
-    actual_crf.head_transitions = excepted_crf.start_transitions
-    actual_crf.last_transitions = excepted_crf.end_transitions
-
-    actual = actual_crf(rua_emissions(inputs)).argmax.cat()
-
-    assert_sequence_close(actual=actual, expected=excepted)
diff --git a/torchlatent/__init__.py b/torchlatent/__init__.py
index e69de29..99084d9 100644
--- a/torchlatent/__init__.py
+++ b/torchlatent/__init__.py
@@ -0,0 +1,2 @@
+from torchlatent.crf import CrfDecoder
+from torchlatent.crf import CrfDistribution
diff --git a/torchlatent/abc.py b/torchlatent/abc.py
index 0f7767f..70ed998 100644
--- a/torchlatent/abc.py
+++ b/torchlatent/abc.py
@@ -4,48 +4,45 @@
 import torch
 import torch.autograd
 from torch import Tensor
-from torch.distributions import Distribution
 from torch.distributions.utils import lazy_property
-from torch.nn.utils.rnn import PackedSequence
 
-from torchrua import CattedSequence
-
-Sequence = Union[CattedSequence, PackedSequence]
+from torchrua import C
+from torchrua import D
+from torchrua import P
 
 
-class DistributionABC(Distribution, metaclass=ABCMeta):
-    emissions: Tensor
+class StructuredDistribution(object, metaclass=ABCMeta):
+    def __init__(self, emissions: Union[C, D, P]) -> None:
+        super(StructuredDistribution, self).__init__()
+        self.emissions = emissions
 
-    def log_scores(self, targets: Sequence) -> Tensor:
+    def log_scores(self, targets: Union[C, D, P]) -> Tensor:
         raise NotImplementedError
 
-    @lazy_property
-    def log_partitions(self) -> Tensor:
-        raise NotImplementedError
-
-    def log_prob(self, targets: Sequence) -> Tensor:
+    def log_probs(self, targets: Union[C, D, P]) -> Tensor:
         return self.log_scores(targets=targets) - self.log_partitions
 
     @lazy_property
-    def max(self) -> Tensor:
+    def log_partitions(self) -> Tensor:
         raise NotImplementedError
 
-    @lazy_property
-    def argmax(self) -> Tensor:
-        grad, = torch.autograd.grad(
-            self.max, self.emissions, torch.ones_like(self.max),
-            create_graph=False, only_inputs=True, allow_unused=True,
-        )
-        return grad
-
     @lazy_property
     def marginals(self) -> Tensor:
         grad, = torch.autograd.grad(
-            self.log_partitions, self.emissions, torch.ones_like(self.log_partitions),
-            create_graph=True, only_inputs=True, allow_unused=True,
+            self.log_partitions, self.emissions.data, torch.ones_like(self.log_partitions),
+            create_graph=True, retain_graph=True, only_inputs=True, allow_unused=True,
         )
         return grad
 
     @lazy_property
-    def entropy(self) -> Tensor:
+    def max(self) -> Tensor:
         raise NotImplementedError
+
+    @lazy_property
+    def argmax(self) -> Union[C, D, P]:
+        grad, = torch.autograd.grad(
+            self.max, self.emissions.data, torch.ones_like(self.max),
+            create_graph=False, retain_graph=False, only_inputs=True, allow_unused=True,
+        )
+        return self.emissions._replace(data=grad.argmax(dim=-1))
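[Annotation] The refactored StructuredDistribution above derives both marginals and the best structure from autograd instead of bespoke backward passes: the gradient of the log-partition function with respect to the emissions gives the marginals, and the gradient of the Max-semiring value is a one-hot indicator of the maximizing structure, which is why `argmax` simply reads `grad.argmax(dim=-1)`. A self-contained toy illustration of that second identity (plain PyTorch, not library code):

    import torch

    # The gradient of a hard max is a one-hot indicator of its argmax,
    # so taking argmax over the gradient recovers the maximizing index.
    x = torch.randn(4, requires_grad=True)
    g, = torch.autograd.grad(x.max(), x, torch.ones(()))
    assert g.argmax().item() == x.argmax().item()
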
diff --git a/torchlatent/abc2.py b/torchlatent/abc2.py
deleted file mode 100644
index 70ed998..0000000
--- a/torchlatent/abc2.py
+++ /dev/null
@@ -1,48 +0,0 @@
-from abc import ABCMeta
-from typing import Union
-
-import torch
-import torch.autograd
-from torch import Tensor
-from torch.distributions.utils import lazy_property
-
-from torchrua import C
-from torchrua import D
-from torchrua import P
-
-
-class StructuredDistribution(object, metaclass=ABCMeta):
-    def __init__(self, emissions: Union[C, D, P]) -> None:
-        super(StructuredDistribution, self).__init__()
-        self.emissions = emissions
-
-    def log_scores(self, targets: Union[C, D, P]) -> Tensor:
-        raise NotImplementedError
-
-    def log_probs(self, targets: Union[C, D, P]) -> Tensor:
-        return self.log_scores(targets=targets) - self.log_partitions
-
-    @lazy_property
-    def log_partitions(self) -> Tensor:
-        raise NotImplementedError
-
-    @lazy_property
-    def marginals(self) -> Tensor:
-        grad, = torch.autograd.grad(
-            self.log_partitions, self.emissions.data, torch.ones_like(self.log_partitions),
-            create_graph=True, retain_graph=True, only_inputs=True, allow_unused=True,
-
-        )
-        return grad
-
-    @lazy_property
-    def max(self) -> Tensor:
-        raise NotImplementedError
-
-    @lazy_property
-    def argmax(self) -> Union[C, D, P]:
-        grad, = torch.autograd.grad(
-            self.max, self.emissions.data, torch.ones_like(self.max),
-            create_graph=False, retain_graph=False, only_inputs=True, allow_unused=True,
-        )
-        return self.emissions._replace(data=grad.argmax(dim=-1))
diff --git a/torchlatent/cky.py b/torchlatent/cky.py
index e81eab7..45de7ee 100644
--- a/torchlatent/cky.py
+++ b/torchlatent/cky.py
@@ -6,25 +6,17 @@
 from typing import Union
 
 import torch
-from torch import nn
 from torch import Tensor
+from torch import nn
 from torch.distributions.utils import lazy_property
 from torch.nn.utils.rnn import PackedSequence
 from torch.types import Device
-from torchrua import accumulate_sizes
-from torchrua import cat_packed_indices
-from torchrua import CattedSequence
-from torchrua import major_sizes_to_ptr
-from torchrua import pack_catted_sequence
-from torchrua import pad_indices
-from torchrua import pad_sequence
-from torchrua import RuaSequential
-
-from torchlatent.abc import DistributionABC
-from torchlatent.nn.classifier import BiaffineClassifier
+
+from torchlatent.abc2 import StructuredDistribution
+
 from torchlatent.semiring import Log
 from torchlatent.semiring import Max
 from torchlatent.semiring import Semiring
+from torchrua import CattedSequence
 
 Sequence = Union[CattedSequence, PackedSequence]
 
@@ -114,7 +106,7 @@ def cky_partitions(data: Tensor, indices: CkyIndices, *, semiring: Type[Semiring
     return tensor1[tgt]
 
 
-class CkyDistribution(DistributionABC):
+class CkyDistribution(StructuredDistribution):
     def __init__(self, emissions: Tensor, indices: CkyIndices) -> None:
         super(CkyDistribution, self).__init__(validate_args=False)
 
diff --git a/torchlatent/crf.py b/torchlatent/crf.py
index 46b125e..d396678 100644
--- a/torchlatent/crf.py
+++ b/torchlatent/crf.py
@@ -1,192 +1,73 @@
-from functools import singledispatch
-from typing import NamedTuple
 from typing import Tuple
 from typing import Type
 from typing import Union
 
 import torch
-from torch import nn
 from torch import Tensor
+from torch import nn
 from torch.distributions.utils import lazy_property
-from torch.nn import functional as F
 from torch.nn import init
-from torch.types import Device
-from torchrua import accumulate_sizes
-from torchrua import CattedSequence
-from torchrua import minor_sizes_to_ptr
-from torchrua import PackedSequence
-from torchrua import reduce_catted_indices
-from torchrua import reduce_packed_indices
-from torchrua import ReductionIndices
-from torchrua import RuaSequential
 
-from torchlatent.abc import DistributionABC
-from torchlatent.nn.classifier import Classifier
+from torchlatent.abc import StructuredDistribution
 from torchlatent.semiring import Log
 from torchlatent.semiring import Max
 from torchlatent.semiring import Semiring
+from torchrua import C
+from torchrua import D
+from torchrua import P
 
-Sequence = Union[CattedSequence, PackedSequence]
-
-
-class CrfIndices(NamedTuple):
-    head: Tensor
-    last: Tensor
-    prev: Tensor
-    curr: Tensor
-    token_sizes: Tensor
-    unsorted_indices: Tensor
-    indices: ReductionIndices
-
-
-@singledispatch
-def broadcast_shapes(sequence: Sequence, transitions: Tuple[Tensor, Tensor, Tensor]) -> Sequence:
-    raise TypeError(f'type {type(sequence)} is not supported')
-
-
-@broadcast_shapes.register
-def broadcast_catted_shapes(sequence: CattedSequence, transitions: Tuple[Tensor, Tensor, Tensor]):
-    sequence, token_sizes = sequence
-    transitions, head_transitions, last_transitions = transitions
-
-    t1, c1, *_ = sequence.size()
-    h1, = token_sizes.size()
-
-    t2, c2, _, _ = transitions.size()
-    h3, c3, _ = head_transitions.size()
-    h4, c4, _ = last_transitions.size()
-
-    return torch.broadcast_shapes((t1, c1, h1), (t2, c2, 1), (1, c3, h3), (1, c4, h4))
-
-
-@broadcast_shapes.register
-def broadcast_packed_shapes(sequence: PackedSequence, transitions: Tuple[Tensor, Tensor, Tensor]):
-    sequence, batch_sizes, _, _ = sequence
-    transitions, head_transitions, last_transitions = transitions
-
-    t1, c1, *_ = sequence.size()
-    h1 = batch_sizes[0].item()
-
-    t2, c2, _, _ = transitions.size()
-    h3, c3, _ = head_transitions.size()
-    h4, c4, _ = last_transitions.size()
-
-    return torch.broadcast_shapes((t1, c1, h1), (t2, c2, 1), (1, c3, h3), (1, c4, h4))
+T = Tuple[Tensor, Tensor, Tensor]
 
 
-@singledispatch
-def crf_scores_indices(sequence: Sequence, device: Device = None):
-    raise TypeError(f'type {type(sequence)} is not supported')
-
-
-@crf_scores_indices.register
-def crf_scores_catted_indices(sequence: CattedSequence, device: Device = None):
-    if device is None:
-        device = sequence.data.device
-
-    token_sizes = sequence.token_sizes.to(device=device)
-    acc_token_sizes = token_sizes.cumsum(dim=0)
-
-    index = torch.arange(token_sizes.sum().item(), device=device)
-    unsorted_indices = torch.arange(token_sizes.size()[0], device=device)
-
-    return F.pad(acc_token_sizes, [1, -1]), acc_token_sizes - 1, index - 1, index, token_sizes, unsorted_indices
-
-
-@crf_scores_indices.register
-def crf_scores_packed_indices(sequence: PackedSequence, device: Device = None):
-    if device is None:
-        device = sequence.data.device
-
-    batch_sizes = sequence.batch_sizes.to(device=device)
-    unsorted_indices = sequence.unsorted_indices.to(device=device)
-    acc_batch_sizes = F.pad(batch_sizes.cumsum(dim=0), [2, -1])
-
-    batch_ptr, token_ptr, token_sizes = minor_sizes_to_ptr(
-        sizes=batch_sizes, minor_ptr=unsorted_indices,
-    )
-    prev = acc_batch_sizes[token_ptr + 0] + batch_ptr
-    curr = acc_batch_sizes[token_ptr + 1] + batch_ptr
-    last = acc_batch_sizes[token_sizes] + unsorted_indices
-
-    return unsorted_indices, last, prev, curr, token_sizes, unsorted_indices
-
-
-def crf_scores(sequence: Sequence, emissions: Tensor, transitions: Tuple[Tensor, Tensor, Tensor],
-               semiring: Type[Semiring]) -> Tensor:
-    head, last, prev, curr, token_sizes, unsorted_indices = crf_scores_indices(sequence)
-
+def crf_scores(emissions: Union[C, D, P], targets: Union[C, D, P], transitions: T, semiring: Type[Semiring]) -> Tensor:
     transitions, head_transitions, last_transitions = transitions
 
-    c = torch.arange(transitions.size()[1], device=emissions.device)
-    emissions = emissions[curr[:, None], c[None, :], sequence.data[curr]]
-    transitions = transitions[curr[:, None], c[None, :], sequence.data[prev], sequence.data[curr]]
-    transitions[accumulate_sizes(sizes=token_sizes)] = semiring.one
-    head_transitions = head_transitions[unsorted_indices[:, None], c[None, :], sequence.data[head]]
-    last_transitions = last_transitions[unsorted_indices[:, None], c[None, :], sequence.data[last]]
+    targets = targets.cat()
+    head_transitions = head_transitions[targets.head().data]
+    last_transitions = last_transitions[targets.last().data]
+    transitions = transitions[targets.roll(1).data, targets.data]
 
-    emissions = semiring.segment_prod(semiring.mul(emissions, transitions), sizes=token_sizes)
-    return semiring.mul(emissions, semiring.mul(head_transitions, last_transitions))
+    emissions, _ = emissions.idx().cat().rua(emissions, targets)
+    emissions = semiring.segment_prod(emissions, sizes=targets.token_sizes)
 
+    token_sizes = torch.stack([torch.ones_like(targets.token_sizes), targets.token_sizes - 1], dim=-1)
+    transitions = semiring.segment_prod(transitions, sizes=token_sizes.view(-1))[1::2]
 
-@torch.no_grad()
-def crf_indices(emissions: Sequence) -> CrfIndices:
-    head, last, prev, curr, token_sizes, unsorted_indices = crf_scores_indices(emissions)
-
-    if isinstance(emissions, CattedSequence):
-        indices = reduce_catted_indices(
-            token_sizes=emissions.token_sizes,
-            device=emissions.data.device,
-        )
-    elif isinstance(emissions, PackedSequence):
-        indices = reduce_packed_indices(
-            batch_sizes=emissions.batch_sizes,
-            unsorted_indices=emissions.unsorted_indices,
-            device=emissions.data.device,
-        )
-    else:
-        raise KeyError(f'type {type(emissions)} is not supported')
-
-    return CrfIndices(
-        head=head, last=last,
-        prev=prev, curr=curr,
-        token_sizes=token_sizes,
-        unsorted_indices=unsorted_indices,
-        indices=indices,
+    return semiring.mul(
+        semiring.mul(head_transitions, last_transitions),
+        semiring.mul(emissions, transitions),
     )
 
 
-def crf_partitions(emissions: Tensor, transitions: Tuple[Tensor, Tensor, Tensor],
-                   indices: CrfIndices, semiring: Type[Semiring]):
-    head, _, _, _, _, unsorted_indices, indices = indices
-
+def crf_partitions(emissions: Union[C, D, P], transitions: T, semiring: Type[Semiring]) -> Tensor:
     transitions, head_transitions, last_transitions = transitions
 
-    c = torch.arange(transitions.size()[1], device=emissions.device)
-    transitions = semiring.mul(emissions[:, :, None, :], transitions)
-    transitions[head] = semiring.eye_like(transitions)[None, None, :, :]
+    emissions = emissions.pack()
+    last_indices = emissions.idx().last()
+    emissions, batch_sizes, _, unsorted_indices = emissions
 
-    head_transitions = head_transitions[unsorted_indices[:, None], c[None, :], None, :]
-    last_transitions = last_transitions[unsorted_indices[:, None], c[None, :], :, None]
+    _, *batch_sizes = sections = batch_sizes.detach().cpu().tolist()
+    emission, *emissions = torch.split(emissions, sections, dim=0)
 
-    scores = semiring.mul(emissions[head[:, None], c[None, :], None, :], head_transitions)
-    scores = semiring.bmm(scores, semiring.reduce(transitions, indices=indices))
-    scores = semiring.bmm(scores, last_transitions)
+    charts = [semiring.mul(head_transitions, emission)]
+    for emission, batch_size in zip(emissions, batch_sizes):
+        charts.append(semiring.mul(
+            semiring.bmm(charts[-1][:batch_size], transitions),
+            emission,
+        ))
 
-    return scores[..., 0, 0]
+    emission = torch.cat(charts, dim=0)[last_indices]
+    return semiring.sum(semiring.mul(emission, last_transitions), dim=-1)
 
 
-class CrfDistribution(DistributionABC):
-    def __init__(self, emissions: Tensor, transitions: Tuple[Tensor, Tensor, Tensor], indices: CrfIndices) -> None:
-        super(CrfDistribution, self).__init__(validate_args=False)
-
-        self.emissions = emissions
-        self.indices = indices
+class CrfDistribution(StructuredDistribution):
+    def __init__(self, emissions: Union[C, D, P], transitions: T) -> None:
+        super(CrfDistribution, self).__init__(emissions=emissions)
         self.transitions = transitions
 
-    def log_scores(self, targets: Sequence) -> Tensor:
+    def log_scores(self, targets: Union[C, D, P]) -> Tensor:
         return crf_scores(
-            emissions=self.emissions,
-            sequence=targets,
+            emissions=self.emissions, targets=targets,
             transitions=self.transitions,
             semiring=Log,
         )
@@ -196,7 +77,6 @@ def log_partitions(self) -> Tensor:
         return crf_partitions(
             emissions=self.emissions,
             transitions=self.transitions,
-            indices=self.indices,
             semiring=Log,
         )
 
@@ -205,41 +85,19 @@ def max(self) -> Tensor:
         return crf_partitions(
             emissions=self.emissions,
             transitions=self.transitions,
-            indices=self.indices,
             semiring=Max,
         )
 
-    @lazy_property
-    def argmax(self) -> Tensor:
-        return super(CrfDistribution, self).argmax.argmax(dim=-1)
-
-    @lazy_property
-    def entropy(self) -> Tensor:
-        tensor = (self.marginals * self.marginals.log()).sum(dim=-1)
-        return -Log.segment_prod(
-            tensor=tensor[self.indices.curr],
-            sizes=self.indices.token_sizes,
-        )
-
-
-class CrfLayerABC(nn.Module):
-    def reset_parameters(self) -> None:
-        raise NotImplementedError
-
-    def forward_parameters(self, emissions: Sequence):
-        raise NotImplementedError
-
 
-class CrfLayer(CrfLayerABC):
-    def __init__(self, num_targets: int, num_conjugates: int = 1) -> None:
-        super(CrfLayer, self).__init__()
+class CrfDecoder(nn.Module):
+    def __init__(self, *, num_targets: int) -> None:
+        super(CrfDecoder, self).__init__()
 
         self.num_targets = num_targets
-        self.num_conjugates = num_conjugates
 
-        self.transitions = nn.Parameter(torch.empty((1, num_conjugates, num_targets, num_targets)))
-        self.head_transitions = nn.Parameter(torch.empty((1, num_conjugates, num_targets)))
-        self.last_transitions = nn.Parameter(torch.empty((1, num_conjugates, num_targets)))
+        self.transitions = nn.Parameter(torch.empty((num_targets, num_targets)))
+        self.head_transitions = nn.Parameter(torch.empty((num_targets,)))
+        self.last_transitions = nn.Parameter(torch.empty((num_targets,)))
 
         self.reset_parameters()
 
@@ -249,82 +107,14 @@ def reset_parameters(self) -> None:
         init.zeros_(self.last_transitions)
 
     def extra_repr(self) -> str:
-        return ', '.join([
-            f'num_targets={self.num_targets}',
-            f'num_conjugates={self.num_conjugates}',
-        ])
-
-    def forward_parameters(self, emissions: Sequence):
-        transitions = (self.transitions, self.head_transitions, self.last_transitions)
-        t, c, h = broadcast_shapes(emissions, transitions=transitions)
-
-        emissions = emissions.data.expand((t, c, -1))
-        transitions = self.transitions.expand((t, c, -1, -1))
-        head_transitions = self.head_transitions.expand((h, c, -1))
-        last_transitions = self.last_transitions.expand((h, c, -1))
-
-        return emissions, (transitions, head_transitions, last_transitions)
-
-    def forward(self, emissions: Sequence, indices: CrfIndices = None) -> CrfDistribution:
-        if indices is None:
-            indices = crf_indices(emissions=emissions)
-
-        emissions, transitions = self.forward_parameters(emissions=emissions)
-
-        return CrfDistribution(emissions=emissions, transitions=transitions, indices=indices)
-
-    def fit(self, emissions: Sequence, targets: Sequence, indices: CrfIndices = None) -> Tensor:
-        dist: CrfDistribution = self.forward(emissions=emissions, indices=indices)
-        return dist.log_partitions - dist.log_scores(targets=targets)
-
-    def decode(self, emissions: Sequence, indices: CrfIndices = None) -> Sequence:
-        dist: CrfDistribution = self.forward(emissions=emissions, indices=indices)
-        return emissions._replace(data=dist.argmax)
-
-
-class CrfDecoder(nn.Module):
-    def __init__(self, in_features: int, num_targets: int, num_conjugates: int, dropout: float) -> None:
-        super(CrfDecoder, self).__init__()
-
-        self.in_features = in_features
-        self.num_targets = num_targets
-        self.num_conjugates = num_conjugates
-        num_conjugates = max(1, num_conjugates)
-
-        self.classifier = RuaSequential(
-            nn.Dropout(dropout),
-            Classifier(
-                num_conjugates=num_conjugates,
-                in_features=in_features,
-                out_features=num_targets,
-                bias=False,
-            )
-        )
-
-        self.crf = CrfLayer(
-            num_targets=num_targets,
-            num_conjugates=num_conjugates,
+        return f'num_targets={self.num_targets}'
+
+    def forward(self, emissions: Union[C, D, P]) -> CrfDistribution:
+        return CrfDistribution(
+            emissions=emissions,
+            transitions=(
+                self.transitions,
+                self.head_transitions,
+                self.last_transitions,
+            ),
         )
-
-    def forward(self, sequence: Sequence) -> CrfDistribution:
-        if self.num_conjugates == 0:
-            sequence = sequence._replace(data=sequence.data[..., None, :])
-
-        emissions = self.classifier(sequence)
-        return self.crf(emissions)
-
-    def fit(self, sequence: Sequence, targets: Sequence) -> Tensor:
-        dist: CrfDistribution = self(sequence=sequence)
-        loss = dist.log_partitions - dist.log_scores(targets=targets)
-
-        if self.num_conjugates == 0:
-            loss = loss[..., 0]
-        return loss
-
-    def decode(self, sequence: Sequence) -> Sequence:
-        dist: CrfDistribution = self(sequence=sequence)
-        argmax = dist.argmax
-
-        if self.num_conjugates == 0:
-            argmax = argmax[..., 0]
-        return sequence._replace(data=argmax)
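[Annotation] The new crf_partitions above runs the classical forward algorithm over a packed sequence: the chart for time step t is advanced only for the `batch_size` sequences still alive at that step, via a semiring batched matrix product. A reference sketch of the recursion it vectorizes, written for a single sequence in plain log-space PyTorch (not the packed library code):

    import torch

    def forward_algorithm(emissions, transitions, head, last):
        # emissions: [T, N]; transitions: [N, N]; head, last: [N]
        # alpha_t[j] = logsumexp_i(alpha_{t-1}[i] + transitions[i, j]) + emissions[t, j]
        alpha = head + emissions[0]
        for emission in emissions[1:]:
            alpha = torch.logsumexp(alpha[:, None] + transitions, dim=0) + emission
        # log partition: sum out the final tag together with the end transitions
        return torch.logsumexp(alpha + last, dim=-1)

Swapping logsumexp for max in this recursion yields the Max-semiring variant used by the `max`/`argmax` path.
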
diff --git a/torchlatent/linear_crf.py b/torchlatent/linear_crf.py
deleted file mode 100644
index e0d9257..0000000
--- a/torchlatent/linear_crf.py
+++ /dev/null
@@ -1,120 +0,0 @@
-from typing import Tuple
-from typing import Type
-from typing import Union
-
-import torch
-from torch import Tensor
-from torch import nn
-from torch.distributions.utils import lazy_property
-from torch.nn import init
-
-from torchlatent.abc2 import StructuredDistribution
-from torchlatent.semiring import Log
-from torchlatent.semiring import Max
-from torchlatent.semiring import Semiring
-from torchrua import C
-from torchrua import D
-from torchrua import P
-
-T = Tuple[Tensor, Tensor, Tensor]
-
-
-def crf_scores(emissions: Union[C, D, P], targets: Union[C, D, P], transitions: T, semiring: Type[Semiring]) -> Tensor:
-    transitions, head_transitions, last_transitions = transitions
-
-    targets = targets.cat()
-    head_transitions = head_transitions[targets.head().data]
-    last_transitions = last_transitions[targets.last().data]
-    transitions = transitions[targets.roll(1).data, targets.data]
-
-    emissions, _ = emissions.idx().cat().rua(emissions, targets)
-    emissions = semiring.segment_prod(emissions, sizes=targets.token_sizes)
-
-    token_sizes = torch.stack([torch.ones_like(targets.token_sizes), targets.token_sizes - 1], dim=-1)
-    transitions = semiring.segment_prod(transitions, sizes=token_sizes.view(-1))[1::2]
-
-    return semiring.mul(
-        semiring.mul(head_transitions, last_transitions),
-        semiring.mul(emissions, transitions),
-    )
-
-
-def crf_partitions(emissions: Union[C, D, P], transitions: T, semiring: Type[Semiring]) -> Tensor:
-    transitions, head_transitions, last_transitions = transitions
-
-    emissions = emissions.pack()
-    last_indices = emissions.idx().last()
-    emissions, batch_sizes, _, unsorted_indices = emissions
-
-    _, *batch_sizes = sections = batch_sizes.detach().cpu().tolist()
-    emission, *emissions = torch.split(emissions, sections, dim=0)
-
-    charts = [semiring.mul(head_transitions, emission)]
-    for emission, batch_size in zip(emissions, batch_sizes):
-        charts.append(semiring.mul(
-            semiring.bmm(charts[-1][:batch_size], transitions),
-            emission,
-        ))
-
-    emission = torch.cat(charts, dim=0)[last_indices]
-    return semiring.sum(semiring.mul(emission, last_transitions), dim=-1)
-
-
-class CrfDistribution(StructuredDistribution):
-    def __init__(self, emissions: Union[C, D, P], transitions: T) -> None:
-        super(CrfDistribution, self).__init__(emissions=emissions)
-        self.transitions = transitions
-
-    def log_scores(self, targets: Union[C, D, P]) -> Tensor:
-        return crf_scores(
-            emissions=self.emissions, targets=targets,
-            transitions=self.transitions,
-            semiring=Log,
-        )
-
-    @lazy_property
-    def log_partitions(self) -> Tensor:
-        return crf_partitions(
-            emissions=self.emissions,
-            transitions=self.transitions,
-            semiring=Log,
-        )
-
-    @lazy_property
-    def max(self) -> Tensor:
-        return crf_partitions(
-            emissions=self.emissions,
-            transitions=self.transitions,
-            semiring=Max,
-        )
-
-
-class CrfDecoder(nn.Module):
-    def __init__(self, *, num_targets: int) -> None:
-        super(CrfDecoder, self).__init__()
-
-        self.num_targets = num_targets
-
-        self.transitions = nn.Parameter(torch.empty((num_targets, num_targets)))
-        self.head_transitions = nn.Parameter(torch.empty((num_targets,)))
-        self.last_transitions = nn.Parameter(torch.empty((num_targets,)))
-
-        self.reset_parameters()
-
-    def reset_parameters(self) -> None:
-        init.zeros_(self.transitions)
-        init.zeros_(self.head_transitions)
-        init.zeros_(self.last_transitions)
-
-    def extra_repr(self) -> str:
-        return f'num_targets={self.num_targets}'
-
-    def forward(self, emissions: Union[C, D, P]) -> CrfDistribution:
-        return CrfDistribution(
-            emissions=emissions,
-            transitions=(
-                self.transitions,
-                self.head_transitions,
-                self.last_transitions,
-            ),
-        )
diff --git a/torchlatent/nn/__init__.py b/torchlatent/nn/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/torchlatent/nn/classifier.py b/torchlatent/nn/classifier.py
deleted file mode 100644
index a1a26ff..0000000
--- a/torchlatent/nn/classifier.py
+++ /dev/null
@@ -1,76 +0,0 @@
-import torch
-from torch import nn
-from torch import Tensor
-from torch.nn import init
-
-
-class Classifier(nn.Module):
-    def __init__(self, bias: bool = False, *, num_conjugates: int,
-                 in_features: int, out_features: int) -> None:
-        super(Classifier, self).__init__()
-
-        self.in_features = in_features
-        self.out_features = out_features
-        self.num_conjugates = num_conjugates
-
-        self.weight = nn.Parameter(torch.empty((num_conjugates, out_features, in_features)))
-        self.bias = nn.Parameter(torch.empty((num_conjugates, out_features,))) if bias else 0
-
-        self.reset_parameters()
-
-    def reset_parameters(self) -> None:
-        init.zeros_(self.weight)
-
-        if torch.is_tensor(self.bias):
-            init.zeros_(self.bias)
-
-    def extra_repr(self) -> str:
-        return ', '.join([
-            f'in_features={self.in_features}',
-            f'out_features={self.out_features}',
-            f'num_conjugates={self.num_conjugates}',
-            f'bias={torch.is_tensor(self.bias)}',
-        ])
-
-    def forward(self, tensor: Tensor) -> Tensor:
-        return torch.einsum('nzx,...nx->...nz', self.weight, tensor) + self.bias
-
-
-class BiaffineClassifier(nn.Module):
-    def __init__(self, bias: bool = False, *,
-                 in_features1: int, in_features2: int, out_features: int) -> None:
-        super(BiaffineClassifier, self).__init__()
-
-        self.in_features1 = in_features1
-        self.in_features2 = in_features2
-        self.out_features = out_features
-
-        self.weight0 = nn.Parameter(torch.empty((out_features, in_features1, in_features2)))
-        self.weight1 = nn.Parameter(torch.empty((out_features, in_features1)))
-        self.weight2 = nn.Parameter(torch.empty((out_features, in_features2)))
-        self.bias = nn.Parameter(torch.empty((out_features,))) if bias else 0
-
-        self.reset_parameters()
-
-    def reset_parameters(self) -> None:
-        init.zeros_(self.weight0)
-        init.zeros_(self.weight1)
-        init.zeros_(self.weight2)
-
-        if torch.is_tensor(self.bias):
-            init.zeros_(self.bias)
-
-    def extra_repr(self) -> str:
-        return ', '.join([
-            f'in_features1={self.in_features1}',
-            f'in_features2={self.in_features2}',
-            f'out_features={self.out_features}',
-            f'bias={torch.is_tensor(self.bias)}',
-        ])
-
-    def forward(self, tensor1: Tensor, tensor2: Tensor) -> Tensor:
-        tensor0 = torch.einsum('zxy,...x,...y->...z', self.weight0, tensor1, tensor2)
-        tensor1 = torch.einsum('zx,...x->...z', self.weight1, tensor1)
-        tensor2 = torch.einsum('zy,...y->...z', self.weight2, tensor2)
-
-        return tensor0 + tensor1 + tensor2 + self.bias
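[Annotation] After this patch the public API is a plain CrfDecoder over raw emission tensors; the projection layers (Classifier, BiaffineClassifier) and the conjugation dimension are gone. A minimal usage sketch, assuming CPU tensors and torchrua's cat_sequence as used in the tests; the shapes follow the post-refactor constructor above:

    import torch
    from torchrua import cat_sequence

    from torchlatent.crf import CrfDecoder

    decoder = CrfDecoder(num_targets=5)
    emissions = cat_sequence([torch.randn(7, 5), torch.randn(3, 5)])
    targets = cat_sequence([torch.randint(0, 5, (7,)), torch.randint(0, 5, (3,))])

    dist = decoder(emissions)
    loss = dist.log_probs(targets).neg().mean()  # negative log-likelihood
    loss.backward()
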
From 96e4380aaf7cabd1e43b606da20d6f9451d1c9b9 Mon Sep 17 00:00:00 2001
From: speedcell4
Date: Thu, 24 Aug 2023 01:24:45 +0900
Subject: [PATCH 076/102] Feat: Update CrfDecoder

---
 torchlatent/crf.py | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/torchlatent/crf.py b/torchlatent/crf.py
index d396678..3c3a847 100644
--- a/torchlatent/crf.py
+++ b/torchlatent/crf.py
@@ -22,15 +22,15 @@ def crf_scores(emissions: Union[C, D, P], targets: Union[C, D, P], transitions:
     transitions, head_transitions, last_transitions = transitions
 
-    targets = targets.cat()
-    head_transitions = head_transitions[targets.head().data]
-    last_transitions = last_transitions[targets.last().data]
-    transitions = transitions[targets.roll(1).data, targets.data]
+    targets = _, token_sizes = targets.cat()
+    head_transitions = targets.head().rua(head_transitions)
+    last_transitions = targets.last().rua(last_transitions)
+    transitions = targets.data.roll(1).rua(transitions, targets)
 
     emissions, _ = emissions.idx().cat().rua(emissions, targets)
-    emissions = semiring.segment_prod(emissions, sizes=targets.token_sizes)
+    emissions = semiring.segment_prod(emissions, sizes=token_sizes)
 
-    token_sizes = torch.stack([torch.ones_like(targets.token_sizes), targets.token_sizes - 1], dim=-1)
+    token_sizes = torch.stack([torch.ones_like(token_sizes), token_sizes - 1], dim=-1)
     transitions = semiring.segment_prod(transitions, sizes=token_sizes.view(-1))[1::2]
 
     return semiring.mul(
@@ -44,7 +44,7 @@ def crf_partitions(emissions: Union[C, D, P], transitions: T, semiring: Type[Sem
 
     emissions = emissions.pack()
     last_indices = emissions.idx().last()
-    emissions, batch_sizes, _, unsorted_indices = emissions
+    emissions, batch_sizes, _, _ = emissions
 
     _, *batch_sizes = sections = batch_sizes.detach().cpu().tolist()
     emission, *emissions = torch.split(emissions, sections, dim=0)
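[Annotation] The interleaved-sizes trick in crf_scores above is easy to miss: the transition into each sequence head is meaningless, so the segment sizes are stacked as (1, n - 1) per sequence and only the odd-indexed segment products are kept. A toy illustration of the index bookkeeping (plain PyTorch, not library code):

    import torch

    token_sizes = torch.tensor([3, 2])
    sizes = torch.stack([torch.ones_like(token_sizes), token_sizes - 1], dim=-1).view(-1)
    # sizes == tensor([1, 2, 1, 1]): segment-reducing the flattened transition
    # scores over these sizes and slicing [1::2] yields one product per
    # sequence that covers positions 1..n-1 only, skipping each head.
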
From 3bfba281c613c202113d605f2ec1320a6a68e492 Mon Sep 17 00:00:00 2001
From: speedcell4
Date: Thu, 24 Aug 2023 02:00:50 +0900
Subject: [PATCH 077/102] Feat: Add cky2.py

---
 tests/test_cky.py   |  6 ++--
 tests/test_cky2.py  | 67 +++++++++++++++++++++++++++++++++++++++++++++
 tests/test_crf.py   | 56 ++++++++++++++++++------------------
 torchlatent/cky2.py | 47 +++++++++++++++++++++++++++++++
 4 files changed, 145 insertions(+), 31 deletions(-)
 create mode 100644 tests/test_cky2.py
 create mode 100644 torchlatent/cky2.py

diff --git a/tests/test_cky.py b/tests/test_cky.py
index b9c3d31..d89f717 100644
--- a/tests/test_cky.py
+++ b/tests/test_cky.py
@@ -72,8 +72,8 @@ def test_cky_log_partitions(token_sizes, num_tags):
     )
     token_sizes = torch.tensor(token_sizes, device=device)
 
-    excepted = TreeCRF(log_potentials=scores, lengths=token_sizes)
+    expected = TreeCRF(log_potentials=scores, lengths=token_sizes)
     actual = CkyDistribution(emissions=scores, indices=cky_partitions_indices(token_sizes=token_sizes, device=device))
 
-    assert_close(actual=actual.log_partitions, expected=excepted.partition)
-    assert_grad_close(actual=actual.log_partitions, expected=excepted.partition, inputs=scores, rtol=1e-5, atol=1e-5)
+    assert_close(actual=actual.log_partitions, expected=expected.partition)
+    assert_grad_close(actual=actual.log_partitions, expected=expected.partition, inputs=scores, rtol=1e-5, atol=1e-5)
diff --git a/tests/test_cky2.py b/tests/test_cky2.py
new file mode 100644
index 0000000..0f158a3
--- /dev/null
+++ b/tests/test_cky2.py
@@ -0,0 +1,67 @@
+import torch
+from hypothesis import given
+from hypothesis import strategies as st
+from torch.testing import assert_close
+from torch_struct import TreeCRF
+
+from torchlatent.cky2 import cky_partitions
+from torchlatent.cky2 import cky_scores
+from torchlatent.semiring import Log
+from torchnyan import BATCH_SIZE
+from torchnyan import TOKEN_SIZE
+from torchnyan import device
+from torchnyan import sizes
+from torchrua import C
+from torchrua import CattedSequence
+
+
+@given(
+    token_sizes=sizes(BATCH_SIZE, TOKEN_SIZE),
+    num_targets=sizes(TOKEN_SIZE),
+    rua_targets=st.sampled_from([C.cat, C.pad, C.pack]),
+)
+def test_cky_scores(token_sizes, num_targets, rua_targets):
+    emissions = torch.randn((len(token_sizes), max(token_sizes), max(token_sizes), num_targets), requires_grad=True)
+    token_sizes = torch.tensor(token_sizes, device=device)
+
+    expected_cky = TreeCRF(emissions, lengths=token_sizes)
+
+    mask = expected_cky.argmax > 0
+    _, t, _, n = mask.size()
+
+    index = torch.arange(t, device=mask.device)
+    x = torch.masked_select(index[None, :, None, None], mask=mask)
+    y = torch.masked_select(index[None, None, :, None], mask=mask)
+
+    index = torch.arange(n, device=mask.device)
+    z = torch.masked_select(index[None, None, None, :], mask=mask)
+
+    expected = expected_cky.max
+
+    targets = CattedSequence(data=torch.stack([x, y, z], dim=-1), token_sizes=token_sizes * 2 - 1)
+    actual = cky_scores(
+        emissions=CattedSequence(emissions, token_sizes),
+        targets=rua_targets(targets),
+        semiring=Log,
+    )
+
+    assert_close(actual=actual, expected=expected)
+
+
+@given(
+    token_sizes=sizes(BATCH_SIZE, TOKEN_SIZE),
+    num_targets=sizes(TOKEN_SIZE),
+)
+def test_cky_partitions(token_sizes, num_targets):
+    emissions = torch.randn((len(token_sizes), max(token_sizes), max(token_sizes), num_targets), requires_grad=True)
+    token_sizes = torch.tensor(token_sizes, device=device)
+
+    expected = TreeCRF(emissions, lengths=token_sizes).partition
+
+    actual_emissions = CattedSequence(
+        data=emissions.logsumexp(dim=-1),
+        token_sizes=token_sizes,
+    )
+    actual = cky_partitions(actual_emissions, Log)
+
+    assert_close(actual=actual, expected=expected)
diff --git a/tests/test_crf.py b/tests/test_crf.py
index af9f30c..2827645 100644
--- a/tests/test_crf.py
+++ b/tests/test_crf.py
@@ -36,26 +36,26 @@ def test_crf_scores(token_sizes, num_targets, rua_emissions, rua_targets):
         for token_size in token_sizes
     ]
 
-    excepted_crf = CRF(num_tags=num_targets, batch_first=False)
+    expected_crf = CRF(num_tags=num_targets, batch_first=False)
 
-    excepted_emissions = pad_sequence(inputs)
-    excepted_tags = pad_sequence(targets)
+    expected_emissions = pad_sequence(inputs)
+    expected_tags = pad_sequence(targets)
 
-    excepted = excepted_crf._compute_score(
-        excepted_emissions.data.transpose(0, 1),
-        excepted_tags.data.transpose(0, 1),
-        excepted_emissions.mask().transpose(0, 1),
+    expected = expected_crf._compute_score(
+        expected_emissions.data.transpose(0, 1),
+        expected_tags.data.transpose(0, 1),
+        expected_emissions.mask().transpose(0, 1),
     )
 
     actual = crf_scores(
         emissions=rua_emissions(inputs),
         targets=rua_targets(targets),
-        transitions=(excepted_crf.transitions, excepted_crf.start_transitions, excepted_crf.end_transitions),
+        transitions=(expected_crf.transitions, expected_crf.start_transitions, expected_crf.end_transitions),
         semiring=Log,
     )
 
-    assert_close(actual=actual, expected=excepted)
-    assert_grad_close(actual=actual, expected=excepted, inputs=inputs)
+    assert_close(actual=actual, expected=expected)
+    assert_grad_close(actual=actual, expected=expected, inputs=inputs)
 
 
 @given(
@@ -69,23 +69,23 @@ def test_crf_partitions(token_sizes, num_targets, rua_emissions):
         for token_size in token_sizes
     ]
 
-    excepted_crf = CRF(num_tags=num_targets, batch_first=False)
+    expected_crf = CRF(num_tags=num_targets, batch_first=False)
 
-    excepted_emissions = pad_sequence(inputs)
+    expected_emissions = pad_sequence(inputs)
 
-    excepted = excepted_crf._compute_normalizer(
-        excepted_emissions.data.transpose(0, 1),
-        excepted_emissions.mask().t(),
+    expected = expected_crf._compute_normalizer(
+        expected_emissions.data.transpose(0, 1),
+        expected_emissions.mask().t(),
     )
 
     actual = crf_partitions(
         emissions=rua_emissions(inputs),
-        transitions=(excepted_crf.transitions, excepted_crf.start_transitions, excepted_crf.end_transitions),
+        transitions=(expected_crf.transitions, expected_crf.start_transitions, expected_crf.end_transitions),
         semiring=Log,
     )
 
-    assert_close(actual=actual, expected=excepted, rtol=1e-4, atol=1e-4)
-    assert_grad_close(actual=actual, expected=excepted, inputs=inputs, rtol=1e-4, atol=1e-4)
+    assert_close(actual=actual, expected=expected, rtol=1e-4, atol=1e-4)
+    assert_grad_close(actual=actual, expected=expected, inputs=inputs, rtol=1e-4, atol=1e-4)
 
 
 @given(
@@ -99,21 +99,21 @@ def test_crf_argmax(token_sizes, num_targets, rua_emissions):
         for token_size in token_sizes
     ]
 
-    excepted_crf = CRF(num_tags=num_targets, batch_first=False)
+    expected_crf = CRF(num_tags=num_targets, batch_first=False)
 
-    excepted_emissions = pad_sequence(inputs)
+    expected_emissions = pad_sequence(inputs)
 
-    excepted = excepted_crf.decode(
-        excepted_emissions.data.transpose(0, 1),
-        excepted_emissions.mask().t(),
+    expected = expected_crf.decode(
+        expected_emissions.data.transpose(0, 1),
+        expected_emissions.mask().t(),
     )
-    excepted = cat_sequence([torch.tensor(tensor, device=device) for tensor in excepted])
+    expected = cat_sequence([torch.tensor(tensor, device=device) for tensor in expected])
 
     actual_crf = CrfDecoder(num_targets=num_targets)
-    actual_crf.transitions = excepted_crf.transitions
-    actual_crf.head_transitions = excepted_crf.start_transitions
-    actual_crf.last_transitions = excepted_crf.end_transitions
+    actual_crf.transitions = expected_crf.transitions
+    actual_crf.head_transitions = expected_crf.start_transitions
+    actual_crf.last_transitions = expected_crf.end_transitions
 
     actual = actual_crf(rua_emissions(inputs)).argmax.cat()
 
-    assert_sequence_close(actual=actual, expected=excepted)
+    assert_sequence_close(actual=actual, expected=expected)
diff --git a/torchlatent/cky2.py b/torchlatent/cky2.py
new file mode 100644
index 0000000..325adbc
--- /dev/null
+++ b/torchlatent/cky2.py
@@ -0,0 +1,47 @@
+from typing import Type
+from typing import Union
+
+import torch
+from torch import Tensor
+
+from torchlatent.semiring import Semiring
+from torchrua import C
+from torchrua import D
+from torchrua import P
+
+
+def cky_scores(emissions: C, targets: Union[C, D, P], semiring: Type[Semiring]) -> Tensor:
+    xyz, token_sizes = targets = targets.cat()
+    batch_ptr, _ = targets.ptr()
+
+    emissions = emissions.data[batch_ptr, xyz[..., 0], xyz[..., 1], xyz[..., 2]]
+    return semiring.segment_prod(emissions, token_sizes)
+
+
+def cky_partitions(emissions: C, semiring: Type[Semiring]) -> Tensor:
+    batch_ptr, token_ptr = emissions.ptr()
+    z_ptr, x_ptr = emissions._replace(token_sizes=token_ptr + 1).ptr()
+    y_ptr = token_ptr[z_ptr]
+
+    _, token_size, *_ = emissions.size()
+    cache_size, = y_ptr.size()
+
+    src1 = y_ptr - x_ptr, x_ptr + z_ptr - y_ptr
+    src2 = batch_ptr[z_ptr], x_ptr, y_ptr
+    tgt = emissions.token_sizes - 1, emissions.offsets()
+
+    size = (token_size, cache_size, *emissions.data.size()[3:])
+    tensor0 = torch.full(size, fill_value=semiring.zero, device=emissions.data.device, requires_grad=False)
+    tensor1 = torch.full(size, fill_value=semiring.zero, device=emissions.data.device, requires_grad=False)
+    tensor2 = torch.full(size, fill_value=semiring.zero, device=emissions.data.device, requires_grad=False)
+
+    tensor0[src1] = emissions.data[src2]
+    tensor1[0, :] = tensor2[-1, :] = tensor0[0, :]
+
+    for w in range(1, token_size):
+        tensor1[w, :-w] = tensor2[-w - 1, w:] = semiring.mul(
+            semiring.sum(semiring.mul(tensor1[:w, :-w], tensor2[-w:, w:]), dim=0),
+            tensor0[w, :-w],
+        )
+
+    return tensor1[tgt]
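[Annotation] cky_partitions above stores the CKY chart along two shifted diagonal views (tensor1 indexed by span start, tensor2 by span end) so that each width-w step is a single vectorized reduction. A reference sketch of the naive O(n^3) inside recursion it vectorizes, for one sequence with per-span log-scores s[i, j] (plain PyTorch, not library code):

    import torch

    def naive_inside(s):
        # s: [n, n]; inside[i, j] = s[i, j] + logsumexp_k(inside[i, k] + inside[k + 1, j])
        n = s.size(0)
        inside = torch.full_like(s, -float('inf'))
        inside[torch.arange(n), torch.arange(n)] = s.diagonal()
        for w in range(1, n):
            for i in range(n - w):
                j = i + w
                k = torch.arange(i, j)
                inside[i, j] = s[i, j] + torch.logsumexp(inside[i, k] + inside[k + 1, j], dim=0)
        return inside[0, n - 1]  # log partition over the full span
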
From a929b9d38c4c4630d23bbe76b29368fe8fa89407 Mon Sep 17 00:00:00 2001
From: speedcell4
Date: Thu, 24 Aug 2023 02:27:24 +0900
Subject: [PATCH 078/102] Feat: Add CkyDecoder

---
 tests/test_cky.py   | 139 +++++++++++++++-----------
 tests/test_cky2.py  |  67 -------------
 third/__init__.py   |   1 -
 third/crf.py        |  84 ----------------
 torchlatent/abc.py  |  20 +++-
 torchlatent/cky.py  | 235 ++++++++------------------------------
 torchlatent/cky2.py |  47 ---------
 torchlatent/crf.py  |  15 +--
 8 files changed, 155 insertions(+), 453 deletions(-)
 delete mode 100644 tests/test_cky2.py
 delete mode 100644 third/__init__.py
 delete mode 100644 third/crf.py
 delete mode 100644 torchlatent/cky2.py

diff --git a/tests/test_cky.py b/tests/test_cky.py
index d89f717..4fbb99e 100644
--- a/tests/test_cky.py
+++ b/tests/test_cky.py
@@ -4,76 +4,95 @@
 from torch_struct import TreeCRF
 
 from torchlatent.cky import CkyDecoder
-from torchlatent.cky import CkyDistribution
-from torchlatent.cky import cky_partitions_indices
-from torchnyan.assertion import assert_close
-from torchnyan.assertion import assert_grad_close
-from torchnyan.strategy import BATCH_SIZE
-from torchnyan.strategy import FEATURE_DIM
-from torchnyan.strategy import TINY_BATCH_SIZE
-from torchnyan.strategy import TOKEN_SIZE
-from torchnyan.strategy import device
-from torchnyan.strategy import sizes
-from torchrua import cat_sequence
+from torchlatent.cky import cky_partitions
+from torchlatent.cky import cky_scores
+from torchlatent.semiring import Log
+from torchnyan import BATCH_SIZE
+from torchnyan import TINY_TOKEN_SIZE
+from torchnyan import TOKEN_SIZE
+from torchnyan import assert_close
+from torchnyan import device
+from torchnyan import sizes
+from torchrua import C
+from torchrua import CattedSequence
 
 
 @given(
-    token_sizes=sizes(TINY_BATCH_SIZE, TOKEN_SIZE),
-    embedding_dim=sizes(FEATURE_DIM),
-    num_tags=sizes(TOKEN_SIZE),
-    dropout=st.floats(0, 1),
+    token_sizes=sizes(BATCH_SIZE, TOKEN_SIZE),
+    num_targets=sizes(TOKEN_SIZE),
+    rua_targets=st.sampled_from([C.cat, C.pad, C.pack]),
 )
-def test_cky_catted_max(token_sizes, embedding_dim, num_tags, dropout):
-    sequence = cat_sequence([
-        torch.randn((token_size, embedding_dim), requires_grad=True, device=device)
-        for token_size in token_sizes
-    ])
-
-    targets = cat_sequence([
-        torch.empty((token_size * 2 - 1,), dtype=torch.long, device=device)
-        for token_size in token_sizes
-    ])
-
-    decoder = CkyDecoder(
-        in_features=embedding_dim, hidden_features=embedding_dim,
-        num_targets=num_tags, dropout=dropout,
-    ).to(device=device)
-    dist = decoder(sequence)
-
-    assert_close(actual=dist.max, expected=dist.log_scores(targets=targets._replace(data=dist.argmax)))
-
-
-# @given(
-#     token_sizes=sizes(TINY_BATCH_SIZE, TOKEN_SIZE),
-#     embedding_dim=sizes(FEATURE_DIM),
-#     num_tags=sizes(TOKEN_SIZE),
-#     bias=st.booleans(),
-# )
-# def test_cky_packed_max(token_sizes, embedding_dim, num_tags, bias):
-#     sequence = pack_sequence([
-#         torch.randn((token_size, embedding_dim), requires_grad=True, device=device)
-#         for token_size in token_sizes
-#     ])
-#
-#     decoder = CkyLayer(in_features=embedding_dim, out_features=num_tags, bias=bias).to(device=device)
-#     cky = decoder.forward(sequence=sequence)
-#
-#     assert_close(actual=cky.max, expected=cky.log_scores(decoder.decode(sequence=sequence)))
+def test_cky_scores(token_sizes, num_targets, rua_targets):
+    emissions = torch.randn((len(token_sizes), max(token_sizes), max(token_sizes), num_targets), requires_grad=True)
+    token_sizes = torch.tensor(token_sizes, device=device)
+
+    expected_cky = TreeCRF(emissions, lengths=token_sizes)
+
+    mask = expected_cky.argmax > 0
+    _, t, _, n = mask.size()
+
+    index = torch.arange(t, device=mask.device)
+    x = torch.masked_select(index[None, :, None, None], mask=mask)
+    y = torch.masked_select(index[None, None, :, None], mask=mask)
+
+    index = torch.arange(n, device=mask.device)
+    z = torch.masked_select(index[None, None, None, :], mask=mask)
+
+    expected = expected_cky.max
+
+    targets = CattedSequence(data=torch.stack([x, y, z], dim=-1), token_sizes=token_sizes * 2 - 1)
+    actual = cky_scores(
+        emissions=CattedSequence(emissions, token_sizes),
+        targets=rua_targets(targets),
+        semiring=Log,
+    )
+
+    assert_close(actual=actual, expected=expected)
 
 
 @given(
     token_sizes=sizes(BATCH_SIZE, TOKEN_SIZE),
-    num_tags=sizes(TOKEN_SIZE),
+    num_targets=sizes(TOKEN_SIZE),
 )
-def test_cky_log_partitions(token_sizes, num_tags):
-    scores = torch.randn(
-        (len(token_sizes), max(token_sizes), max(token_sizes), num_tags),
-        requires_grad=True, device=device,
+def test_cky_partitions(token_sizes, num_targets):
+    emissions = torch.randn((len(token_sizes), max(token_sizes), max(token_sizes), num_targets), requires_grad=True)
+    token_sizes = torch.tensor(token_sizes, device=device)
+
+    expected = TreeCRF(emissions, lengths=token_sizes).partition
+
+    actual_emissions = CattedSequence(
+        data=emissions.logsumexp(dim=-1),
+        token_sizes=token_sizes,
     )
+    actual = cky_partitions(actual_emissions, Log)
+
+    assert_close(actual=actual, expected=expected)
+
+
+@given(
+    token_sizes=sizes(BATCH_SIZE, TINY_TOKEN_SIZE),
+    num_targets=sizes(TINY_TOKEN_SIZE),
+)
+def test_cky_argmax(token_sizes, num_targets):
+    emissions = torch.randn((len(token_sizes), max(token_sizes), max(token_sizes), num_targets), requires_grad=True)
     token_sizes = torch.tensor(token_sizes, device=device)
 
-    expected = TreeCRF(log_potentials=scores, lengths=token_sizes)
-    actual = CkyDistribution(emissions=scores, indices=cky_partitions_indices(token_sizes=token_sizes, device=device))
+    expected_cky = TreeCRF(emissions, lengths=token_sizes)
+
+    mask = expected_cky.argmax > 0
+    _, t, _, n = mask.size()
+
+    index = torch.arange(t, device=mask.device)
+    x = torch.masked_select(index[None, :, None, None], mask=mask)
+    y = torch.masked_select(index[None, None, :, None], mask=mask)
+
+    index = torch.arange(n, device=mask.device)
+    z = torch.masked_select(index[None, None, None, :], mask=mask)
+
+    expected = CattedSequence(data=torch.stack([x, y, z], dim=-1), token_sizes=token_sizes * 2 - 1)
+
+    actual_cky = CkyDecoder(num_targets=num_targets)
+    actual = actual_cky(emissions=CattedSequence(emissions, token_sizes)).argmax
 
-    assert_close(actual=actual.log_partitions, expected=expected.partition)
-    assert_grad_close(actual=actual.log_partitions, expected=expected.partition, inputs=scores, rtol=1e-5, atol=1e-5)
+    for actual, expected in zip(actual.tolist(), expected.tolist()):
+        assert set(map(tuple, actual)) == set(map(tuple, expected))
diff --git a/tests/test_cky2.py b/tests/test_cky2.py
deleted file mode 100644
index 0f158a3..0000000
--- a/tests/test_cky2.py
+++ /dev/null
@@ -1,67 +0,0 @@
-import torch
-from hypothesis import given
-from hypothesis import strategies as st
-from torch.testing import assert_close
-from torch_struct import TreeCRF
-
-from torchlatent.cky2 import cky_partitions
-from torchlatent.cky2 import cky_scores
-from torchlatent.semiring import Log
-from torchnyan import BATCH_SIZE
-from torchnyan import TOKEN_SIZE
-from torchnyan import device
-from torchnyan import sizes
-from torchrua import C
-from torchrua import CattedSequence
-
-
-@given(
-    token_sizes=sizes(BATCH_SIZE, TOKEN_SIZE),
-    num_targets=sizes(TOKEN_SIZE),
-    rua_targets=st.sampled_from([C.cat, C.pad, C.pack]),
-)
-def test_cky_scores(token_sizes, num_targets, rua_targets):
-    emissions = torch.randn((len(token_sizes), max(token_sizes), max(token_sizes), num_targets), requires_grad=True)
-    token_sizes = torch.tensor(token_sizes, device=device)
-
-    expected_cky = TreeCRF(emissions, lengths=token_sizes)
-
-    mask = expected_cky.argmax > 0
-    _, t, _, n = mask.size()
-
-    index = torch.arange(t, device=mask.device)
-    x = torch.masked_select(index[None, :, None, None], mask=mask)
-    y = torch.masked_select(index[None, None, :, None], mask=mask)
-
-    index = torch.arange(n, device=mask.device)
-    z = torch.masked_select(index[None, None, None, :], mask=mask)
-
-    expected = expected_cky.max
-
-    targets = CattedSequence(data=torch.stack([x, y, z], dim=-1), token_sizes=token_sizes * 2 - 1)
-    actual = cky_scores(
-        emissions=CattedSequence(emissions, token_sizes),
-        targets=rua_targets(targets),
-        semiring=Log,
-    )
-
-    assert_close(actual=actual, expected=expected)
-
-
-@given(
-    token_sizes=sizes(BATCH_SIZE, TOKEN_SIZE),
-    num_targets=sizes(TOKEN_SIZE),
-)
-def test_cky_partitions(token_sizes, num_targets):
-    emissions = torch.randn((len(token_sizes), max(token_sizes), max(token_sizes), num_targets), requires_grad=True)
-    token_sizes = torch.tensor(token_sizes, device=device)
-
-    expected = TreeCRF(emissions, lengths=token_sizes).partition
-
-    actual_emissions = CattedSequence(
-        data=emissions.logsumexp(dim=-1),
-        token_sizes=token_sizes,
-    )
-    actual = cky_partitions(actual_emissions, Log)
-
-    assert_close(actual=actual, expected=expected)
diff --git a/third/__init__.py b/third/__init__.py
deleted file mode 100644
index 1b8004e..0000000
--- a/third/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from third.crf import CrfDecoder
diff --git a/third/crf.py b/third/crf.py
deleted file mode 100644
index e3fa4e6..0000000
--- a/third/crf.py
+++ /dev/null
@@ -1,84 +0,0 @@
-import torch
-import torchcrf
-from torch import nn
-from torch import Tensor
-from torch.nn.utils.rnn import PackedSequence
-from torch.types import Device
-from torchrua import pack_sequence
-from torchrua import pad_catted_indices
-from torchrua import pad_packed_sequence
-
-
-@torch.no_grad()
-def token_sizes_to_mask(sizes: Tensor, batch_first: bool, device: Device = None) -> Tensor:
-    if device is None:
-        device = sizes.device
-
-    size, ptr = pad_catted_indices(sizes, batch_first=batch_first, device=device)
-    mask = torch.zeros(size, device=device, dtype=torch.bool)
-    mask[ptr] = True
-    return mask
-
-
-class CrfDecoder(nn.Module):
-    def __init__(self, num_tags: int, num_conjugates: int) -> None:
-        super(CrfDecoder, self).__init__()
-        self.num_tags = num_tags
-        self.num_conjugates = num_conjugates
-
-        self.decoders = nn.ModuleList([
-            torchcrf.CRF(num_tags=num_tags, batch_first=False)
-            for _ in range(num_conjugates)
-        ])
-
-    @torch.no_grad()
-    def reset_parameters_with_(self, decoder) -> None:
-        assert self.num_tags == decoder.num_targets
-        assert self.num_conjugates == decoder.num_conjugates
-
-        for index in range(self.num_conjugates):
-            self.decoders[index].transitions.data[::] = decoder.transitions[:, index, :, :]
-            self.decoders[index].start_transitions.data[::] = decoder.head_transitions[:, index, :]
-            self.decoders[index].end_transitions.data[::] = decoder.last_transitions[:, index, :]
-
-    def fit(self, emissions: PackedSequence, tags: PackedSequence, **kwargs) -> Tensor:
-        num_emissions_conjugates = emissions.data.size()[1]
-        num_decoders_conjugates = self.num_conjugates
-        num_conjugates = max(num_emissions_conjugates, num_decoders_conjugates)
-
-        emissions, token_sizes = pad_packed_sequence(emissions, batch_first=False)
-        tags, _ = pad_packed_sequence(tags, batch_first=False)
-        mask = token_sizes_to_mask(sizes=token_sizes, batch_first=False)
-
-        log_probs = []
-        for index in range(num_conjugates):
-            decoder = self.decoders[index % num_decoders_conjugates]
-            emission = emissions[:, :, index % num_emissions_conjugates]
-            tag = tags[:, :, index % num_emissions_conjugates]
-
-            log_probs.append(decoder(emissions=emission, tags=tag, mask=mask, reduction='none'))
-
-        return torch.stack(log_probs, dim=-1)
-
-    def decode(self, emissions: PackedSequence, **kwargs) -> PackedSequence:
-        num_emissions_conjugates = emissions.data.size()[1]
-        num_decoders_conjugates = self.num_conjugates
-        num_conjugates = max(num_emissions_conjugates, num_decoders_conjugates)
-
-        emissions, token_sizes = 
pad_packed_sequence(emissions, batch_first=False) - mask = token_sizes_to_mask(sizes=token_sizes, batch_first=False) - - predictions = [] - for index in range(num_conjugates): - decoder = self.decoders[index % num_decoders_conjugates] - emission = emissions[:, :, index % num_emissions_conjugates] - - prediction = decoder.decode(emissions=emission, mask=mask) - predictions.append(pack_sequence([torch.tensor(p) for p in prediction], device=emissions.device)) - - return PackedSequence( - torch.stack([prediction.data for prediction in predictions], dim=1), - batch_sizes=predictions[0].batch_sizes, - sorted_indices=predictions[0].sorted_indices, - unsorted_indices=predictions[0].unsorted_indices, - ) diff --git a/torchlatent/abc.py b/torchlatent/abc.py index 70ed998..f963d27 100644 --- a/torchlatent/abc.py +++ b/torchlatent/abc.py @@ -4,6 +4,7 @@ import torch import torch.autograd from torch import Tensor +from torch import nn from torch.distributions.utils import lazy_property from torchrua import C @@ -40,9 +41,24 @@ def max(self) -> Tensor: raise NotImplementedError @lazy_property - def argmax(self) -> Union[C, D, P]: + def argmax(self) -> Tensor: grad, = torch.autograd.grad( self.max, self.emissions.data, torch.ones_like(self.max), create_graph=False, retain_graph=False, only_inputs=True, allow_unused=True, ) - return self.emissions._replace(data=grad.argmax(dim=-1)) + return grad + + +class StructuredDecoder(nn.Module): + def __init__(self, *, num_targets: int) -> None: + super(StructuredDecoder, self).__init__() + self.num_targets = num_targets + + def reset_parameters(self) -> None: + pass + + def extra_repr(self) -> str: + return f'num_targets={self.num_targets}' + + def forward(self, emissions: Union[C, D, P]) -> StructuredDistribution: + raise NotImplementedError diff --git a/torchlatent/cky.py b/torchlatent/cky.py index 45de7ee..622f0c5 100644 --- a/torchlatent/cky.py +++ b/torchlatent/cky.py @@ -1,100 +1,47 @@ -from abc import ABCMeta -from functools import singledispatch -from typing import NamedTuple -from typing import Tuple from typing import Type from typing import Union import torch from torch import Tensor -from torch import nn from torch.distributions.utils import lazy_property -from torch.nn.utils.rnn import PackedSequence -from torch.types import Device -from torchlatent.abc2 import StructuredDistribution +from torchlatent.abc import StructuredDecoder +from torchlatent.abc import StructuredDistribution from torchlatent.semiring import Log from torchlatent.semiring import Max from torchlatent.semiring import Semiring +from torchrua import C from torchrua import CattedSequence +from torchrua import D +from torchrua import P -Sequence = Union[CattedSequence, PackedSequence] +def cky_scores(emissions: C, targets: Union[C, D, P], semiring: Type[Semiring]) -> Tensor: + xyz, token_sizes = targets = targets.cat() + batch_ptr, _ = targets.ptr() -@singledispatch -def cky_scores_indices(sequence: Sequence, device: Device = None): - raise KeyError(f'type {type(sequence)} is not supported') + emissions = emissions.data[batch_ptr, xyz[..., 0], xyz[..., 1], xyz[..., 2]] + return semiring.segment_prod(emissions, token_sizes) -@cky_scores_indices.register -def cky_scores_catted_indices(sequence: CattedSequence, device: Device = None): - if device is None: - device = sequence.data.device - - token_sizes = sequence.token_sizes.to(device=device) - - batch_ptr = torch.repeat_interleave(repeats=token_sizes) - return ..., batch_ptr, token_sizes - - -@cky_scores_indices.register -def 
cky_scores_packed_indices(sequence: PackedSequence, device: Device = None): - if device is None: - device = sequence.data.device - - batch_sizes = sequence.batch_sizes.to(device=device) - unsorted_indices = sequence.unsorted_indices.to(device=device) - - indices, token_sizes = cat_packed_indices( - batch_sizes=batch_sizes, - unsorted_indices=unsorted_indices, - device=device, - ) - - batch_ptr = torch.repeat_interleave(repeats=token_sizes) - return indices, batch_ptr, token_sizes - - -class CkyIndices(NamedTuple): - token_size: int - cache_size: int - - src: Tuple[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor, Tensor]] - tgt: Tuple[Tensor, Tensor] - - -@torch.no_grad() -def cky_partitions_indices(token_sizes: Tensor, device: Device = None): - if device is None: - device = token_sizes.device - - token_sizes = token_sizes.to(device=device) - acc_token_sizes = accumulate_sizes(sizes=token_sizes) - - token_ptr, batch_ptr = major_sizes_to_ptr(sizes=token_sizes) - x_ptr, z_ptr = major_sizes_to_ptr(sizes=token_ptr + 1) +def cky_partitions(emissions: C, semiring: Type[Semiring]) -> Tensor: + batch_ptr, token_ptr = emissions.ptr() + z_ptr, x_ptr = emissions._replace(token_sizes=token_ptr + 1).ptr() y_ptr = token_ptr[z_ptr] - token_size = token_sizes.max().item() - cache_size, = token_ptr.size() - - return CkyIndices( - token_size=token_size, - cache_size=cache_size, - src=((y_ptr - x_ptr, x_ptr + z_ptr - y_ptr), (batch_ptr[z_ptr], x_ptr, y_ptr)), - tgt=(token_sizes - 1, acc_token_sizes), - ) + _, token_size, *_ = emissions.size() + cache_size, = y_ptr.size() + src1 = y_ptr - x_ptr, x_ptr + z_ptr - y_ptr + src2 = batch_ptr[z_ptr], x_ptr, y_ptr + tgt = emissions.token_sizes - 1, emissions.offsets() -def cky_partitions(data: Tensor, indices: CkyIndices, *, semiring: Type[Semiring]) -> Tensor: - token_size, cache_size, (src1, src2), tgt = indices + size = (token_size, cache_size, *emissions.data.size()[3:]) + tensor0 = torch.full(size, fill_value=semiring.zero, device=emissions.data.device, requires_grad=False) + tensor1 = torch.full(size, fill_value=semiring.zero, device=emissions.data.device, requires_grad=False) + tensor2 = torch.full(size, fill_value=semiring.zero, device=emissions.data.device, requires_grad=False) - size = (token_size, cache_size, *data.size()[3:]) - tensor0 = torch.full(size, fill_value=semiring.zero, device=data.device, requires_grad=False) - tensor1 = torch.full(size, fill_value=semiring.zero, device=data.device, requires_grad=False) - tensor2 = torch.full(size, fill_value=semiring.zero, device=data.device, requires_grad=False) - - tensor0[src1] = data[src2] + tensor0[src1] = emissions.data[src2] tensor1[0, :] = tensor2[-1, :] = tensor0[0, :] for w in range(1, token_size): @@ -107,132 +54,50 @@ def cky_partitions(data: Tensor, indices: CkyIndices, *, semiring: Type[Semiring class CkyDistribution(StructuredDistribution): - def __init__(self, emissions: Tensor, indices: CkyIndices) -> None: - super(CkyDistribution, self).__init__(validate_args=False) - - self.emissions = emissions - self.indices = indices - - def log_scores(self, targets: Sequence) -> Tensor: - indices, batch_ptr, sizes = cky_scores_indices(targets) - data = targets.data[indices] - return Log.segment_prod( - tensor=self.emissions[batch_ptr, data[..., 0], data[..., 1], data[..., 2]], - sizes=sizes, + def __init__(self, emissions: C) -> None: + super(CkyDistribution, self).__init__(emissions=emissions) + + def log_scores(self, targets: Union[C, D, P]) -> Tensor: + return cky_scores( + emissions=self.emissions, 
targets=targets, + semiring=Log, ) @lazy_property def log_partitions(self) -> Tensor: - return cky_partitions(data=Log.sum(self.emissions, dim=-1), indices=self.indices, semiring=Log) + return cky_partitions( + emissions=self.emissions._replace(data=Log.sum(self.emissions.data, dim=-1)), + semiring=Log, + ) @lazy_property def max(self) -> Tensor: - return cky_partitions(data=Max.sum(self.emissions, dim=-1), indices=self.indices, semiring=Max) + return cky_partitions( + emissions=self.emissions._replace(data=Max.sum(self.emissions.data, dim=-1)), + semiring=Max, + ) @lazy_property - def argmax(self) -> Tensor: + def argmax(self) -> C: mask = super(CkyDistribution, self).argmax > 0 - b, n, _, m = mask.size() + _, t, _, n = mask.size() - index = torch.arange(n, device=mask.device) + index = torch.arange(t, device=mask.device) x = torch.masked_select(index[None, :, None, None], mask=mask) y = torch.masked_select(index[None, None, :, None], mask=mask) - index = torch.arange(m, device=mask.device) + index = torch.arange(n, device=mask.device) z = torch.masked_select(index[None, None, None, :], mask=mask) - return torch.stack([x, y, z], dim=-1) - @lazy_property - def marginals(self) -> Tensor: - grad, = torch.autograd.grad( - self.log_partitions, self.emissions, torch.ones_like(self.log_partitions), - create_graph=True, only_inputs=True, allow_unused=False, + return CattedSequence( + data=torch.stack([x, y, z], dim=-1), + token_sizes=self.emissions.token_sizes * 2 - 1, ) - return grad - - @lazy_property - def entropy(self) -> Tensor: - raise NotImplementedError - - -class CkyLayerABC(nn.Module, metaclass=ABCMeta): - def reset_parameters(self) -> None: - raise NotImplementedError - - def forward_scores(self, features: Tensor, *args, **kwargs) -> Tensor: - raise NotImplementedError - - def forward(self, emissions: Sequence, indices: CkyIndices = None) -> CkyDistribution: - _, _, token_sizes = pad_indices(emissions, batch_first=True) - - if indices is None: - indices = cky_partitions_indices(token_sizes=token_sizes, device=emissions.data.device) - - return CkyDistribution(emissions=emissions.data, indices=indices) - - def fit(self, emissions: Sequence, targets: Sequence, indices: CkyIndices = None) -> Tensor: - dist = self.forward(emissions=emissions, indices=indices) - return dist.log_partitions - dist.log_scores(targets=targets) - - def decode(self, emissions: Sequence, indices: CkyIndices = None) -> Sequence: - dist = self.forward(emissions=emissions, indices=indices) - _, _, token_sizes = pad_indices(emissions, batch_first=True) - - if isinstance(emissions, CattedSequence): - sequence = CattedSequence(data=dist.argmax, token_sizes=token_sizes * 2 - 1) - return sequence - - if isinstance(emissions, PackedSequence): - sequence = CattedSequence(data=dist.argmax, token_sizes=token_sizes * 2 - 1) - return pack_catted_sequence(sequence) - - raise KeyError(f'type {type(emissions)} is not supported') - - -class CkyLayer(CkyLayerABC): - def __init__(self, num_targets: int) -> None: - super(CkyLayer, self).__init__() - - self.num_targets = num_targets - - def extra_repr(self) -> str: - return ', '.join([ - f'num_targets={self.num_targets}', - ]) - - -class CkyDecoder(nn.Module): - def __init__(self, in_features: int, hidden_features: int, num_targets: int, dropout: float) -> None: - super(CkyDecoder, self).__init__() - - self.in_features = in_features - self.hidden_features = hidden_features - self.num_targets = num_targets - - self.ffn1 = RuaSequential( - nn.Linear(in_features, hidden_features, 
bias=True), - nn.GELU(), - nn.Dropout(dropout), - ) - self.ffn2 = RuaSequential( - nn.Linear(in_features, hidden_features, bias=True), - nn.GELU(), - nn.Dropout(dropout), - ) - self.classifier = BiaffineClassifier( - in_features1=hidden_features, - in_features2=hidden_features, - out_features=num_targets, - bias=False, - ) - - self.cky = CkyLayer(num_targets=num_targets) - def forward(self, sequence: Sequence) -> CkyDistribution: - features, _ = pad_sequence(sequence, batch_first=True) - features1 = self.ffn1(features)[:, :, None, :] - features2 = self.ffn2(features)[:, None, :, :] +class CkyDecoder(StructuredDecoder): + def __init__(self, *, num_targets: int) -> None: + super(CkyDecoder, self).__init__(num_targets=num_targets) - emissions = self.classifier(features1, features2) - return self.cky(sequence._replace(data=emissions)) + def forward(self, emissions: C) -> CkyDistribution: + return CkyDistribution(emissions=emissions) diff --git a/torchlatent/cky2.py b/torchlatent/cky2.py deleted file mode 100644 index 325adbc..0000000 --- a/torchlatent/cky2.py +++ /dev/null @@ -1,47 +0,0 @@ -from typing import Type -from typing import Union - -import torch -from torch import Tensor - -from torchlatent.semiring import Semiring -from torchrua import C -from torchrua import D -from torchrua import P - - -def cky_scores(emissions: C, targets: Union[C, D, P], semiring: Type[Semiring]) -> Tensor: - xyz, token_sizes = targets = targets.cat() - batch_ptr, _ = targets.ptr() - - emissions = emissions.data[batch_ptr, xyz[..., 0], xyz[..., 1], xyz[..., 2]] - return semiring.segment_prod(emissions, token_sizes) - - -def cky_partitions(emissions: C, semiring: Type[Semiring]) -> Tensor: - batch_ptr, token_ptr = emissions.ptr() - z_ptr, x_ptr = emissions._replace(token_sizes=token_ptr + 1).ptr() - y_ptr = token_ptr[z_ptr] - - _, token_size, *_ = emissions.size() - cache_size, = y_ptr.size() - - src1 = y_ptr - x_ptr, x_ptr + z_ptr - y_ptr - src2 = batch_ptr[z_ptr], x_ptr, y_ptr - tgt = emissions.token_sizes - 1, emissions.offsets() - - size = (token_size, cache_size, *emissions.data.size()[3:]) - tensor0 = torch.full(size, fill_value=semiring.zero, device=emissions.data.device, requires_grad=False) - tensor1 = torch.full(size, fill_value=semiring.zero, device=emissions.data.device, requires_grad=False) - tensor2 = torch.full(size, fill_value=semiring.zero, device=emissions.data.device, requires_grad=False) - - tensor0[src1] = emissions.data[src2] - tensor1[0, :] = tensor2[-1, :] = tensor0[0, :] - - for w in range(1, token_size): - tensor1[w, :-w] = tensor2[-w - 1, w:] = semiring.mul( - semiring.sum(semiring.mul(tensor1[:w, :-w], tensor2[-w:, w:]), dim=0), - tensor0[w, :-w], - ) - - return tensor1[tgt] diff --git a/torchlatent/crf.py b/torchlatent/crf.py index 3c3a847..e0e261f 100644 --- a/torchlatent/crf.py +++ b/torchlatent/crf.py @@ -8,6 +8,7 @@ from torch.distributions.utils import lazy_property from torch.nn import init +from torchlatent.abc import StructuredDecoder from torchlatent.abc import StructuredDistribution from torchlatent.semiring import Log from torchlatent.semiring import Max @@ -88,12 +89,15 @@ def max(self) -> Tensor: semiring=Max, ) + @lazy_property + def argmax(self) -> Union[C, D, P]: + argmax = super(CrfDistribution, self).argmax.argmax(dim=-1) + return self.emissions._replace(data=argmax) -class CrfDecoder(nn.Module): - def __init__(self, *, num_targets: int) -> None: - super(CrfDecoder, self).__init__() - self.num_targets = num_targets +class CrfDecoder(StructuredDecoder): + def 
__init__(self, *, num_targets: int) -> None: + super(CrfDecoder, self).__init__(num_targets=num_targets) self.transitions = nn.Parameter(torch.empty((num_targets, num_targets))) self.head_transitions = nn.Parameter(torch.empty((num_targets,))) @@ -106,9 +110,6 @@ def reset_parameters(self) -> None: init.zeros_(self.head_transitions) init.zeros_(self.last_transitions) - def extra_repr(self) -> str: - return f'num_targets={self.num_targets}' - def forward(self, emissions: Union[C, D, P]) -> CrfDistribution: return CrfDistribution( emissions=emissions, From c0c0c21b2d00e78994fb0337c8f303a5bd41c70e Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Thu, 24 Aug 2023 02:29:48 +0900 Subject: [PATCH 079/102] Chore: Update config --- .github/workflows/publish-package.yml | 2 +- .github/workflows/unit-tests.yml | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/publish-package.yml b/.github/workflows/publish-package.yml index c4e0338..2100175 100644 --- a/.github/workflows/publish-package.yml +++ b/.github/workflows/publish-package.yml @@ -14,7 +14,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v4 with: - python-version: '3.8' + python-version: '3.9' - name: Install dependencies run: | python -m pip install pip setuptools wheel --upgrade diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index b6df8b1..168a8e2 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -12,12 +12,12 @@ jobs: - name: Set up Python uses: actions/setup-python@v4 with: - python-version: '3.8' + python-version: '3.9' - name: Install dependencies run: | python -m pip install pip setuptools wheel --upgrade python -m pip install torch - python -m pip install pytest hypothesis torchnyan pytorch-crf + python -m pip install pytest hypothesis torchnyan pytorch-crf torch-struct python -m pip install git+https://github.com/speedcell4/torchrua.git@develop --force-reinstall --no-deps - name: Test with pytest run: | From cbf4824e75a0a5c0e5e337eeaa7ae5d2110fe296 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Thu, 24 Aug 2023 02:34:18 +0900 Subject: [PATCH 080/102] Doc: Update README.md --- README.md | 7 +------ setup.py | 2 +- 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 4ec95ef..2887e89 100644 --- a/README.md +++ b/README.md @@ -1,14 +1,9 @@ # TorchLatent -![Unit Tests](https://github.com/speedcell4/torchlatent/workflows/Unit%20Tests/badge.svg) +[![unit tests](https://github.com/speedcell4/torchlatent/actions/workflows/unit-tests.yml/badge.svg)](https://github.com/speedcell4/torchlatent/actions/workflows/unit-tests.yml) [![PyPI version](https://badge.fury.io/py/torchlatent.svg)](https://badge.fury.io/py/torchlatent) [![Downloads](https://pepy.tech/badge/torchrua)](https://pepy.tech/project/torchrua) -## Requirements - -- Python 3.8 -- PyTorch 2.0 - ## Installation `python3 -m pip torchlatent` diff --git a/setup.py b/setup.py index 841e603..da1a9f9 100644 --- a/setup.py +++ b/setup.py @@ -12,7 +12,7 @@ author='speedcell4', author_email='speedcell4@gmail.com', description='High Performance Structured Prediction in PyTorch', - python_requires='>=3.8', + python_requires='>=3.9', install_requires=[ 'numpy', 'torchrua', From 632a1b4de7b5bfbae9906f5aed1345d3cc1b1bdd Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Thu, 24 Aug 2023 18:25:04 +0900 Subject: [PATCH 081/102] Doc: Update README.md --- README.md | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git 
a/README.md b/README.md
index 2887e89..a319a97 100644
--- a/README.md
+++ b/README.md
@@ -1,17 +1,20 @@
+
# TorchLatent -[![unit tests](https://github.com/speedcell4/torchlatent/actions/workflows/unit-tests.yml/badge.svg)](https://github.com/speedcell4/torchlatent/actions/workflows/unit-tests.yml) -[![PyPI version](https://badge.fury.io/py/torchlatent.svg)](https://badge.fury.io/py/torchlatent) -[![Downloads](https://pepy.tech/badge/torchrua)](https://pepy.tech/project/torchrua) +![GitHub Workflow Status (with event)](https://img.shields.io/github/actions/workflow/status/speedcell4/torchlatent/unit-tests.yml?cacheSeconds=0) +![PyPI - Version](https://img.shields.io/pypi/v/torchlatent?label=pypi%20version&cacheSeconds=0) +![PyPI - Downloads](https://img.shields.io/pypi/dm/torchlatent?cacheSeconds=0) + +
## Installation -`python3 -m pip torchlatent` +`python -m pip torchlatent` ## Latent Structures - [x] Conditional Random Fields (CRF) -- [x] Tree CRF -- [ ] Non-Projective Dependency Tree (Matrix-tree Theorem) +- [x] Cocke–Kasami-Younger algorithm (CKY) - [ ] Probabilistic Context-free Grammars (PCFG) +- [ ] Non-Projective Dependency Tree (Matrix-tree Theorem) - [ ] Dependency Model with Valence (DMV) \ No newline at end of file From a4e3c1167da6bd4e7cf9ac29e438e0f929db6f5c Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Thu, 24 Aug 2023 18:27:52 +0900 Subject: [PATCH 082/102] Fix: Resolve device bug --- tests/test_cky.py | 15 ++++++++++++--- tests/test_crf.py | 6 +++--- torchlatent/cky.py | 6 +++--- 3 files changed, 18 insertions(+), 9 deletions(-) diff --git a/tests/test_cky.py b/tests/test_cky.py index 4fbb99e..75a7c8b 100644 --- a/tests/test_cky.py +++ b/tests/test_cky.py @@ -23,7 +23,10 @@ rua_targets=st.sampled_from([C.cat, C.pad, C.pack]), ) def test_cky_scores(token_sizes, num_targets, rua_targets): - emissions = torch.randn((len(token_sizes), max(token_sizes), max(token_sizes), num_targets), requires_grad=True) + emissions = torch.randn( + (len(token_sizes), max(token_sizes), max(token_sizes), num_targets), + device=device, requires_grad=True, + ) token_sizes = torch.tensor(token_sizes, device=device) expected_cky = TreeCRF(emissions, lengths=token_sizes) @@ -55,7 +58,10 @@ def test_cky_scores(token_sizes, num_targets, rua_targets): num_targets=sizes(TOKEN_SIZE), ) def test_cky_partitions(token_sizes, num_targets): - emissions = torch.randn((len(token_sizes), max(token_sizes), max(token_sizes), num_targets), requires_grad=True) + emissions = torch.randn( + (len(token_sizes), max(token_sizes), max(token_sizes), num_targets), + device=device, requires_grad=True, + ) token_sizes = torch.tensor(token_sizes, device=device) expected = TreeCRF(emissions, lengths=token_sizes).partition @@ -74,7 +80,10 @@ def test_cky_partitions(token_sizes, num_targets): num_targets=sizes(TINY_TOKEN_SIZE), ) def test_cky_argmax(token_sizes, num_targets): - emissions = torch.randn((len(token_sizes), max(token_sizes), max(token_sizes), num_targets), requires_grad=True) + emissions = torch.randn( + (len(token_sizes), max(token_sizes), max(token_sizes), num_targets), + device=device, requires_grad=True, + ) token_sizes = torch.tensor(token_sizes, device=device) expected_cky = TreeCRF(emissions, lengths=token_sizes) diff --git a/tests/test_crf.py b/tests/test_crf.py index 2827645..07f68d1 100644 --- a/tests/test_crf.py +++ b/tests/test_crf.py @@ -36,7 +36,7 @@ def test_crf_scores(token_sizes, num_targets, rua_emissions, rua_targets): for token_size in token_sizes ] - expected_crf = CRF(num_tags=num_targets, batch_first=False) + expected_crf = CRF(num_tags=num_targets, batch_first=False).to(device=device) expected_emissions = pad_sequence(inputs) expected_tags = pad_sequence(targets) @@ -69,7 +69,7 @@ def test_crf_partitions(token_sizes, num_targets, rua_emissions): for token_size in token_sizes ] - expected_crf = CRF(num_tags=num_targets, batch_first=False) + expected_crf = CRF(num_tags=num_targets, batch_first=False).to(device=device) expected_emissions = pad_sequence(inputs) @@ -99,7 +99,7 @@ def test_crf_argmax(token_sizes, num_targets, rua_emissions): for token_size in token_sizes ] - expected_crf = CRF(num_tags=num_targets, batch_first=False) + expected_crf = CRF(num_tags=num_targets, batch_first=False).to(device=device) expected_emissions = pad_sequence(inputs) diff --git a/torchlatent/cky.py 
b/torchlatent/cky.py index 622f0c5..796d494 100644 --- a/torchlatent/cky.py +++ b/torchlatent/cky.py @@ -37,9 +37,9 @@ def cky_partitions(emissions: C, semiring: Type[Semiring]) -> Tensor: tgt = emissions.token_sizes - 1, emissions.offsets() size = (token_size, cache_size, *emissions.data.size()[3:]) - tensor0 = torch.full(size, fill_value=semiring.zero, device=emissions.data.device, requires_grad=False) - tensor1 = torch.full(size, fill_value=semiring.zero, device=emissions.data.device, requires_grad=False) - tensor2 = torch.full(size, fill_value=semiring.zero, device=emissions.data.device, requires_grad=False) + tensor0 = emissions.data.new_full(size, fill_value=semiring.zero, requires_grad=False) + tensor1 = emissions.data.new_full(size, fill_value=semiring.zero, requires_grad=False) + tensor2 = emissions.data.new_full(size, fill_value=semiring.zero, requires_grad=False) tensor0[src1] = emissions.data[src2] tensor1[0, :] = tensor2[-1, :] = tensor0[0, :] From b9321e0382cf8172e821cc323b0513342e80099e Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Thu, 24 Aug 2023 18:28:56 +0900 Subject: [PATCH 083/102] Doc: Update README.md --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index a319a97..d4b7449 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,5 @@
+ # TorchLatent ![GitHub Workflow Status (with event)](https://img.shields.io/github/actions/workflow/status/speedcell4/torchlatent/unit-tests.yml?cacheSeconds=0) From 66a86294a1e374e82f3f967ed30e00116f7bf87a Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Thu, 24 Aug 2023 19:55:26 +0900 Subject: [PATCH 084/102] Chore: Upgrade setup.py --- .github/workflows/publish-package.yml | 2 +- .github/workflows/unit-tests.yml | 9 +++++---- requirements.txt | 2 ++ setup.py | 15 +++++++++------ 4 files changed, 17 insertions(+), 11 deletions(-) create mode 100644 requirements.txt diff --git a/.github/workflows/publish-package.yml b/.github/workflows/publish-package.yml index 2100175..c4e0338 100644 --- a/.github/workflows/publish-package.yml +++ b/.github/workflows/publish-package.yml @@ -14,7 +14,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v4 with: - python-version: '3.9' + python-version: '3.8' - name: Install dependencies run: | python -m pip install pip setuptools wheel --upgrade diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index 168a8e2..d18d6cd 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -12,13 +12,14 @@ jobs: - name: Set up Python uses: actions/setup-python@v4 with: - python-version: '3.9' + python-version: '3.8' - name: Install dependencies run: | - python -m pip install pip setuptools wheel --upgrade - python -m pip install torch - python -m pip install pytest hypothesis torchnyan pytorch-crf torch-struct + python -m pip install pip --upgrade + python -m pip install -r requirements.txt python -m pip install git+https://github.com/speedcell4/torchrua.git@develop --force-reinstall --no-deps + python -m pip install pytest hypothesis torchnyan + python -m pip install pytorch-crf torch-struct - name: Test with pytest run: | python -m pytest tests \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..9443809 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,2 @@ +torch +torchrua diff --git a/setup.py b/setup.py index da1a9f9..136fab2 100644 --- a/setup.py +++ b/setup.py @@ -1,20 +1,23 @@ +from pathlib import Path + from setuptools import find_packages from setuptools import setup name = 'torchlatent' +root_dir = Path(__file__).parent.resolve() +with (root_dir / 'requirements.txt').open(mode='r', encoding='utf-8') as fp: + install_requires = [install_require.strip() for install_require in fp] + setup( name=name, - version='0.4.2', + version='0.5.0', packages=[package for package in find_packages() if package.startswith(name)], url='https://github.com/speedcell4/torchlatent', license='MIT', author='speedcell4', author_email='speedcell4@gmail.com', description='High Performance Structured Prediction in PyTorch', - python_requires='>=3.9', - install_requires=[ - 'numpy', - 'torchrua', - ], + python_requires='>=3.8', + install_requires=install_requires, ) From 75132dd4bd0e7bc4f40162c93fc01cce728b4e6c Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Thu, 24 Aug 2023 20:01:49 +0900 Subject: [PATCH 085/102] Fix: Resolve CKY unit tests --- tests/test_cky.py | 19 +++++++++---------- tests/test_crf.py | 10 +++++----- tests/test_functional.py | 6 +++--- torchlatent/abc.py | 1 - torchlatent/cky.py | 8 ++++---- torchlatent/crf.py | 6 +++--- torchlatent/semiring.py | 6 +++--- 7 files changed, 27 insertions(+), 29 deletions(-) diff --git a/tests/test_cky.py b/tests/test_cky.py index 75a7c8b..0d5ec94 100644 --- a/tests/test_cky.py +++ b/tests/test_cky.py @@ -2,24 
+2,23 @@ from hypothesis import given from hypothesis import strategies as st from torch_struct import TreeCRF - -from torchlatent.cky import CkyDecoder -from torchlatent.cky import cky_partitions -from torchlatent.cky import cky_scores -from torchlatent.semiring import Log from torchnyan import BATCH_SIZE from torchnyan import TINY_TOKEN_SIZE -from torchnyan import TOKEN_SIZE from torchnyan import assert_close from torchnyan import device from torchnyan import sizes from torchrua import C from torchrua import CattedSequence +from torchlatent.cky import CkyDecoder +from torchlatent.cky import cky_partitions +from torchlatent.cky import cky_scores +from torchlatent.semiring import Log + @given( - token_sizes=sizes(BATCH_SIZE, TOKEN_SIZE), - num_targets=sizes(TOKEN_SIZE), + token_sizes=sizes(BATCH_SIZE, TINY_TOKEN_SIZE), + num_targets=sizes(TINY_TOKEN_SIZE), rua_targets=st.sampled_from([C.cat, C.pad, C.pack]), ) def test_cky_scores(token_sizes, num_targets, rua_targets): @@ -54,8 +53,8 @@ def test_cky_scores(token_sizes, num_targets, rua_targets): @given( - token_sizes=sizes(BATCH_SIZE, TOKEN_SIZE), - num_targets=sizes(TOKEN_SIZE), + token_sizes=sizes(BATCH_SIZE, TINY_TOKEN_SIZE), + num_targets=sizes(TINY_TOKEN_SIZE), ) def test_cky_partitions(token_sizes, num_targets): emissions = torch.randn( diff --git a/tests/test_crf.py b/tests/test_crf.py index 07f68d1..849212d 100644 --- a/tests/test_crf.py +++ b/tests/test_crf.py @@ -2,11 +2,6 @@ from hypothesis import given from hypothesis import strategies as st from torchcrf import CRF - -from torchlatent.crf import CrfDecoder -from torchlatent.crf import crf_partitions -from torchlatent.crf import crf_scores -from torchlatent.semiring import Log from torchnyan import BATCH_SIZE from torchnyan import TOKEN_SIZE from torchnyan import assert_close @@ -18,6 +13,11 @@ from torchrua import pack_sequence from torchrua import pad_sequence +from torchlatent.crf import CrfDecoder +from torchlatent.crf import crf_partitions +from torchlatent.crf import crf_scores +from torchlatent.semiring import Log + @given( token_sizes=sizes(BATCH_SIZE, TOKEN_SIZE), diff --git a/tests/test_functional.py b/tests/test_functional.py index b06221f..1af9370 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -1,9 +1,6 @@ import torch from hypothesis import given from hypothesis import strategies as st - -from torchlatent.functional import logaddexp -from torchlatent.functional import logsumexp from torchnyan.assertion import assert_close from torchnyan.assertion import assert_grad_close from torchnyan.strategy import TINY_BATCH_SIZE @@ -11,6 +8,9 @@ from torchnyan.strategy import device from torchnyan.strategy import sizes +from torchlatent.functional import logaddexp +from torchlatent.functional import logsumexp + @given( token_sizes=sizes(TINY_BATCH_SIZE, TINY_TOKEN_SIZE) diff --git a/torchlatent/abc.py b/torchlatent/abc.py index f963d27..cbc874a 100644 --- a/torchlatent/abc.py +++ b/torchlatent/abc.py @@ -6,7 +6,6 @@ from torch import Tensor from torch import nn from torch.distributions.utils import lazy_property - from torchrua import C from torchrua import D from torchrua import P diff --git a/torchlatent/cky.py b/torchlatent/cky.py index 796d494..f9ec06f 100644 --- a/torchlatent/cky.py +++ b/torchlatent/cky.py @@ -4,16 +4,16 @@ import torch from torch import Tensor from torch.distributions.utils import lazy_property +from torchrua import C +from torchrua import CattedSequence +from torchrua import D +from torchrua import P from torchlatent.abc import 
StructuredDecoder from torchlatent.abc import StructuredDistribution from torchlatent.semiring import Log from torchlatent.semiring import Max from torchlatent.semiring import Semiring -from torchrua import C -from torchrua import CattedSequence -from torchrua import D -from torchrua import P def cky_scores(emissions: C, targets: Union[C, D, P], semiring: Type[Semiring]) -> Tensor: diff --git a/torchlatent/crf.py b/torchlatent/crf.py index e0e261f..88d8bb8 100644 --- a/torchlatent/crf.py +++ b/torchlatent/crf.py @@ -7,15 +7,15 @@ from torch import nn from torch.distributions.utils import lazy_property from torch.nn import init +from torchrua import C +from torchrua import D +from torchrua import P from torchlatent.abc import StructuredDecoder from torchlatent.abc import StructuredDistribution from torchlatent.semiring import Log from torchlatent.semiring import Max from torchlatent.semiring import Semiring -from torchrua import C -from torchrua import D -from torchrua import P T = Tuple[Tensor, Tensor, Tensor] diff --git a/torchlatent/semiring.py b/torchlatent/semiring.py index b6f0b64..24bfaf0 100644 --- a/torchlatent/semiring.py +++ b/torchlatent/semiring.py @@ -1,13 +1,13 @@ import torch from torch import Tensor - -from torchlatent.functional import logaddexp -from torchlatent.functional import logsumexp from torchrua import segment_logsumexp from torchrua import segment_max from torchrua import segment_prod from torchrua import segment_sum +from torchlatent.functional import logaddexp +from torchlatent.functional import logsumexp + __all__ = [ 'Semiring', 'ExceptionSemiring', 'Std', 'Log', 'Max', 'Xen', 'Div', From 9f40b3690a9d665c2a3495bad9aca17860d78a71 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Thu, 24 Aug 2023 22:00:16 +0900 Subject: [PATCH 086/102] Doc: Update README.md --- README.md | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index d4b7449..3c63b53 100644 --- a/README.md +++ b/README.md @@ -15,7 +15,10 @@ ## Latent Structures - [x] Conditional Random Fields (CRF) -- [x] Cocke–Kasami-Younger algorithm (CKY) -- [ ] Probabilistic Context-free Grammars (PCFG) +- [x] Cocke–Kasami-Younger Algorithm (CKY) +- [ ] Probabilistic Context-free Grammars (CFG) +- [ ] Connectionist Temporal Classification (CTC) +- [ ] Recurrent Neural Network Grammars (RNNG) - [ ] Non-Projective Dependency Tree (Matrix-tree Theorem) -- [ ] Dependency Model with Valence (DMV) \ No newline at end of file +- [ ] Dependency Model with Valence (DMV) +- [ ] Autoregressive Decoding (Beam Search) From b762953ff3c4b22f7a3e3a4c93dff192ec3a721d Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Mon, 28 Aug 2023 03:32:52 +0900 Subject: [PATCH 087/102] Feat: Update torchrua --- tests/test_cky.py | 11 +++++------ tests/test_crf.py | 24 ++++++++++++------------ torchlatent/cky.py | 3 +-- 3 files changed, 18 insertions(+), 20 deletions(-) diff --git a/tests/test_cky.py b/tests/test_cky.py index 0d5ec94..526b66d 100644 --- a/tests/test_cky.py +++ b/tests/test_cky.py @@ -8,7 +8,6 @@ from torchnyan import device from torchnyan import sizes from torchrua import C -from torchrua import CattedSequence from torchlatent.cky import CkyDecoder from torchlatent.cky import cky_partitions @@ -42,9 +41,9 @@ def test_cky_scores(token_sizes, num_targets, rua_targets): expected = expected_cky.max - targets = CattedSequence(data=torch.stack([x, y, z], dim=-1), token_sizes=token_sizes * 2 - 1) + targets = C(data=torch.stack([x, y, z], dim=-1), token_sizes=token_sizes * 2 - 1) actual = 
cky_scores( - emissions=CattedSequence(emissions, token_sizes), + emissions=C(emissions, token_sizes), targets=rua_targets(targets), semiring=Log, ) @@ -65,7 +64,7 @@ def test_cky_partitions(token_sizes, num_targets): expected = TreeCRF(emissions, lengths=token_sizes).partition - actual_emissions = CattedSequence( + actual_emissions = C( data=emissions.logsumexp(dim=-1), token_sizes=token_sizes, ) @@ -97,10 +96,10 @@ def test_cky_argmax(token_sizes, num_targets): index = torch.arange(n, device=mask.device) z = torch.masked_select(index[None, None, None, :], mask=mask) - expected = CattedSequence(data=torch.stack([x, y, z], dim=-1), token_sizes=token_sizes * 2 - 1) + expected = C(data=torch.stack([x, y, z], dim=-1), token_sizes=token_sizes * 2 - 1) actual_cky = CkyDecoder(num_targets=num_targets) - actual = actual_cky(emissions=CattedSequence(emissions, token_sizes)).argmax + actual = actual_cky(emissions=C(emissions, token_sizes)).argmax for actual, expected in zip(actual.tolist(), expected.tolist()): assert set(map(tuple, actual)) == set(map(tuple, expected)) diff --git a/tests/test_crf.py b/tests/test_crf.py index 849212d..14969ac 100644 --- a/tests/test_crf.py +++ b/tests/test_crf.py @@ -9,9 +9,9 @@ from torchnyan import assert_sequence_close from torchnyan import device from torchnyan import sizes -from torchrua import cat_sequence -from torchrua import pack_sequence -from torchrua import pad_sequence +from torchrua import C +from torchrua import D +from torchrua import P from torchlatent.crf import CrfDecoder from torchlatent.crf import crf_partitions @@ -22,8 +22,8 @@ @given( token_sizes=sizes(BATCH_SIZE, TOKEN_SIZE), num_targets=sizes(TOKEN_SIZE), - rua_emissions=st.sampled_from([cat_sequence, pad_sequence, pack_sequence]), - rua_targets=st.sampled_from([cat_sequence, pad_sequence, pack_sequence]), + rua_emissions=st.sampled_from([C.new, D.new, P.new]), + rua_targets=st.sampled_from([C.new, D.new, P.new]), ) def test_crf_scores(token_sizes, num_targets, rua_emissions, rua_targets): inputs = [ @@ -38,8 +38,8 @@ def test_crf_scores(token_sizes, num_targets, rua_emissions, rua_targets): expected_crf = CRF(num_tags=num_targets, batch_first=False).to(device=device) - expected_emissions = pad_sequence(inputs) - expected_tags = pad_sequence(targets) + expected_emissions = D.new(inputs) + expected_tags = D.new(targets) expected = expected_crf._compute_score( expected_emissions.data.transpose(0, 1), @@ -61,7 +61,7 @@ def test_crf_scores(token_sizes, num_targets, rua_emissions, rua_targets): @given( token_sizes=sizes(BATCH_SIZE, TOKEN_SIZE), num_targets=sizes(TOKEN_SIZE), - rua_emissions=st.sampled_from([cat_sequence, pad_sequence, pack_sequence]), + rua_emissions=st.sampled_from([C.new, D.new, P.new]), ) def test_crf_partitions(token_sizes, num_targets, rua_emissions): inputs = [ @@ -71,7 +71,7 @@ def test_crf_partitions(token_sizes, num_targets, rua_emissions): expected_crf = CRF(num_tags=num_targets, batch_first=False).to(device=device) - expected_emissions = pad_sequence(inputs) + expected_emissions = D.new(inputs) expected = expected_crf._compute_normalizer( expected_emissions.data.transpose(0, 1), @@ -91,7 +91,7 @@ def test_crf_partitions(token_sizes, num_targets, rua_emissions): @given( token_sizes=sizes(BATCH_SIZE, TOKEN_SIZE), num_targets=sizes(TOKEN_SIZE), - rua_emissions=st.sampled_from([cat_sequence, pad_sequence, pack_sequence]), + rua_emissions=st.sampled_from([C.new, D.new, P.new]), ) def test_crf_argmax(token_sizes, num_targets, rua_emissions): inputs = [ @@ -101,13 +101,13 @@ 
def test_crf_argmax(token_sizes, num_targets, rua_emissions): expected_crf = CRF(num_tags=num_targets, batch_first=False).to(device=device) - expected_emissions = pad_sequence(inputs) + expected_emissions = D.new(inputs) expected = expected_crf.decode( expected_emissions.data.transpose(0, 1), expected_emissions.mask().t(), ) - expected = cat_sequence([torch.tensor(tensor, device=device) for tensor in expected]) + expected = C.new([torch.tensor(tensor, device=device) for tensor in expected]) actual_crf = CrfDecoder(num_targets=num_targets) actual_crf.transitions = expected_crf.transitions diff --git a/torchlatent/cky.py b/torchlatent/cky.py index f9ec06f..796b078 100644 --- a/torchlatent/cky.py +++ b/torchlatent/cky.py @@ -5,7 +5,6 @@ from torch import Tensor from torch.distributions.utils import lazy_property from torchrua import C -from torchrua import CattedSequence from torchrua import D from torchrua import P @@ -89,7 +88,7 @@ def argmax(self) -> C: index = torch.arange(n, device=mask.device) z = torch.masked_select(index[None, None, None, :], mask=mask) - return CattedSequence( + return C( data=torch.stack([x, y, z], dim=-1), token_sizes=self.emissions.token_sizes * 2 - 1, ) From a831d19de31934c95537f8442ab8127e7e8626aa Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sat, 2 Sep 2023 22:55:35 +0900 Subject: [PATCH 088/102] Chore: Add weekly job --- .github/workflows/unit-tests.yml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index d18d6cd..99ff915 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -1,6 +1,9 @@ name: unit tests -on: [ push ] +on: + push: + schedule: + - cron: "0 21 * * 6" jobs: build: From ab19583603eff2ba419c08841f1ddf84b68a5e39 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Sun, 10 Sep 2023 16:26:00 +0900 Subject: [PATCH 089/102] Chore: Add workflow_dispatch --- .github/workflows/unit-tests.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index 99ff915..a48185b 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -1,6 +1,7 @@ name: unit tests on: + workflow_dispatch: push: schedule: - cron: "0 21 * * 6" From e2c76ddf8058a946a01844b5c5faaff0e6b28be1 Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Mon, 11 Sep 2023 18:33:29 +0900 Subject: [PATCH 090/102] Test: Add assert_grad_close --- tests/test_cky.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/test_cky.py b/tests/test_cky.py index 526b66d..38ea27b 100644 --- a/tests/test_cky.py +++ b/tests/test_cky.py @@ -5,6 +5,7 @@ from torchnyan import BATCH_SIZE from torchnyan import TINY_TOKEN_SIZE from torchnyan import assert_close +from torchnyan import assert_grad_close from torchnyan import device from torchnyan import sizes from torchrua import C @@ -49,6 +50,7 @@ def test_cky_scores(token_sizes, num_targets, rua_targets): ) assert_close(actual=actual, expected=expected) + assert_grad_close(actual=actual, expected=expected, inputs=(emissions,)) @given( @@ -71,6 +73,7 @@ def test_cky_partitions(token_sizes, num_targets): actual = cky_partitions(actual_emissions, Log) assert_close(actual=actual, expected=expected) + assert_grad_close(actual=actual, expected=expected, inputs=(emissions,)) @given( From 21ef28545717a73c2796b27cd6a9707e7e1d1f1c Mon Sep 17 00:00:00 2001 From: speedcell4 Date: Mon, 11 Sep 2023 18:44:49 +0900 Subject: [PATCH 091/102] Test: Add get_argmax --- 
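NOTE: torch_struct's TreeCRF.argmax is a one-hot chart of shape (batch, t, t, n):
cell (b, x, y, z) is nonzero exactly when the best parse of sequence b contains
span [x, y] with label z. Both test_cky_scores and test_cky_argmax previously
repeated the same extraction dance, which this commit hoists into get_argmax.
The trick is to broadcast an arange against the boolean mask so that
masked_select reads off each coordinate of the active cells in row-major order.
A minimal self-contained sketch of that trick (the toy chart below is
illustrative only, not this repository's API):

    import torch

    # stand-in for TreeCRF.argmax: 1.0 marks span [x, y] carrying label z
    chart = torch.zeros((1, 3, 3, 2))
    chart[0, 0, 0, 0] = chart[0, 1, 1, 0] = chart[0, 2, 2, 0] = 1.0  # leaves
    chart[0, 1, 2, 1] = chart[0, 0, 2, 1] = 1.0                      # inner spans

    mask = chart > 0
    _, t, _, n = mask.size()

    index = torch.arange(t)
    x = torch.masked_select(index[None, :, None, None], mask=mask)  # left ends
    y = torch.masked_select(index[None, None, :, None], mask=mask)  # right ends

    index = torch.arange(n)
    z = torch.masked_select(index[None, None, None, :], mask=mask)  # labels

    triples = torch.stack([x, y, z], dim=-1)
    # rows, in row-major order: (0,0,0), (0,2,1), (1,1,0), (1,2,1), (2,2,0)

A binary tree over l tokens always has exactly 2 * l - 1 spans, which is why the
extracted target sequences below carry token_sizes * 2 - 1.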
 tests/test_cky.py | 39 +++++++++++++++++++++-------------------
 1 file changed, 20 insertions(+), 19 deletions(-)

diff --git a/tests/test_cky.py b/tests/test_cky.py
index 38ea27b..bdec363 100644
--- a/tests/test_cky.py
+++ b/tests/test_cky.py
@@ -16,6 +16,21 @@
 from torchlatent.semiring import Log
 
 
+def get_argmax(cky):
+    argmax = cky.argmax
+    mask = argmax > 0
+
+    _, t, _, n = mask.size()
+    index = torch.arange(t, device=mask.device)
+    x = torch.masked_select(index[None, :, None, None], mask=mask)
+    y = torch.masked_select(index[None, None, :, None], mask=mask)
+
+    index = torch.arange(n, device=mask.device)
+    z = torch.masked_select(index[None, None, None, :], mask=mask)
+
+    return argmax, x, y, z
+
+
 @given(
     token_sizes=sizes(BATCH_SIZE, TINY_TOKEN_SIZE),
     num_targets=sizes(TINY_TOKEN_SIZE),
@@ -27,20 +42,14 @@ def test_cky_scores(token_sizes, num_targets, rua_targets):
         device=device, requires_grad=True,
     )
     token_sizes = torch.tensor(token_sizes, device=device)
-
     expected_cky = TreeCRF(emissions, lengths=token_sizes)
 
-    mask = expected_cky.argmax > 0
-    _, t, _, n = mask.size()
-
-    index = torch.arange(t, device=mask.device)
-    x = torch.masked_select(index[None, :, None, None], mask=mask)
-    y = torch.masked_select(index[None, None, :, None], mask=mask)
+    argmax, x, y, z = get_argmax(expected_cky)
 
-    index = torch.arange(n, device=mask.device)
-    z = torch.masked_select(index[None, None, None, :], mask=mask)
+    emissions = torch.randn_like(emissions, requires_grad=True)
 
-    expected = expected_cky.max
+    expected_cky = TreeCRF(emissions, lengths=token_sizes)
+    expected = expected_cky.log_prob(argmax) + expected_cky.partition
 
     targets = C(data=torch.stack([x, y, z], dim=-1), token_sizes=token_sizes * 2 - 1)
     actual = cky_scores(
@@ -89,15 +98,7 @@ def test_cky_argmax(token_sizes, num_targets):
 
     expected_cky = TreeCRF(emissions, lengths=token_sizes)
 
-    mask = expected_cky.argmax > 0
-    _, t, _, n = mask.size()
-
-    index = torch.arange(t, device=mask.device)
-    x = torch.masked_select(index[None, :, None, None], mask=mask)
-    y = torch.masked_select(index[None, None, :, None], mask=mask)
-
-    index = torch.arange(n, device=mask.device)
-    z = torch.masked_select(index[None, None, None, :], mask=mask)
+    _, x, y, z = get_argmax(expected_cky)
 
     expected = C(data=torch.stack([x, y, z], dim=-1), token_sizes=token_sizes * 2 - 1)
 
From 31bcfe50564790d64ae45c87e9806bb6ba1250f2 Mon Sep 17 00:00:00 2001
From: speedcell4
Date: Wed, 15 Nov 2023 21:17:44 +0900
Subject: [PATCH 092/102] Feat: Update cky_partitions

---
 torchlatent/cky.py | 43 +++++++++++++++++++++----------------------
 1 file changed, 21 insertions(+), 22 deletions(-)

diff --git a/torchlatent/cky.py b/torchlatent/cky.py
index 796b078..798063a 100644
--- a/torchlatent/cky.py
+++ b/torchlatent/cky.py
@@ -1,18 +1,12 @@
-from typing import Type
-from typing import Union
+from typing import Type, Union
 
 import torch
 from torch import Tensor
 from torch.distributions.utils import lazy_property
-from torchrua import C
-from torchrua import D
-from torchrua import P
+from torchrua import C, D, P
 
-from torchlatent.abc import StructuredDecoder
-from torchlatent.abc import StructuredDistribution
-from torchlatent.semiring import Log
-from torchlatent.semiring import Max
-from torchlatent.semiring import Semiring
+from torchlatent.abc import StructuredDecoder, StructuredDistribution
+from torchlatent.semiring import Log, Max, Semiring
 
 
 def cky_scores(emissions: C, targets: Union[C, D, P], semiring: Type[Semiring]) -> Tensor:
@@ -29,27 +23,32 @@ def cky_partitions(emissions: C, semiring: Type[Semiring]) -> Tensor:
     y_ptr = token_ptr[z_ptr]
 
     _, token_size, *_ = emissions.size()
-    cache_size, = y_ptr.size()
+    cache_size, = batch_ptr.size()
 
-    src1 = y_ptr - x_ptr, x_ptr + z_ptr - y_ptr
-    src2 = batch_ptr[z_ptr], x_ptr, y_ptr
+    w_ptr = y_ptr - x_ptr
+    src1 = w_ptr, z_ptr - w_ptr
+    # src2 = -w_ptr - 1, z_ptr
+
+    src = batch_ptr[z_ptr], x_ptr, y_ptr
     tgt = emissions.token_sizes - 1, emissions.offsets()
 
     size = (token_size, cache_size, *emissions.data.size()[3:])
-    tensor0 = emissions.data.new_full(size, fill_value=semiring.zero, requires_grad=False)
-    tensor1 = emissions.data.new_full(size, fill_value=semiring.zero, requires_grad=False)
-    tensor2 = emissions.data.new_full(size, fill_value=semiring.zero, requires_grad=False)
+    score1 = emissions.data.new_full(size, fill_value=semiring.zero, requires_grad=False)
+    # score2 = emissions.data.new_full(size, fill_value=semiring.zero, requires_grad=False)
+    chart1 = emissions.data.new_full(size, fill_value=semiring.zero, requires_grad=False)
+    chart2 = emissions.data.new_full(size, fill_value=semiring.zero, requires_grad=False)
 
-    tensor0[src1] = emissions.data[src2]
-    tensor1[0, :] = tensor2[-1, :] = tensor0[0, :]
+    score1[src1] = emissions.data[src]
+    # score2[src2] = emissions.data[src]
+    chart1[0, :] = chart2[-1, :] = score1[0, :]
 
     for w in range(1, token_size):
-        tensor1[w, :-w] = tensor2[-w - 1, w:] = semiring.mul(
-            semiring.sum(semiring.mul(tensor1[:w, :-w], tensor2[-w:, w:]), dim=0),
-            tensor0[w, :-w],
+        chart1[w, :-w] = chart2[-w - 1, w:] = semiring.mul(
+            semiring.sum(semiring.mul(chart1[:w, :-w], chart2[-w:, w:]), dim=0),
+            score1[w, :-w],
         )
 
-    return tensor1[tgt]
+    return chart1[tgt]
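NOTE: the renaming above doubles as documentation: score1 holds the per-span
emission scores laid out by (width, position), chart1 holds inside scores
indexed by width and left endpoint, and chart2 holds the same values indexed by
width and right endpoint. Keeping both copies appears to be a layout choice so
that, at width w, the left children chart1[:w, :-w] and the right children
chart2[-w:, w:] are both contiguous slices and the split-point reduction
collapses into a single semiring.sum over dim 0. The commented-out score2/src2
lines read like an unfinished sketch of giving the emission scores the same
dual layout; they are dropped again two commits later when cky_partitions is
rewritten around diagonal views.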
From 27198a03c62dcabe966fbd33563c03a7a59c6aa8 Mon Sep 17 00:00:00 2001
From: speedcell4
Date: Wed, 15 Nov 2023 21:18:29 +0900
Subject: [PATCH 093/102] Refactor: PEP8 them all

---
 setup.py                 |  3 +--
 tests/test_cky.py        | 14 +++-----------
 tests/test_crf.py        | 21 +++------------------
 tests/test_functional.py | 16 +++++-----------
 torchlatent/__init__.py  |  3 +--
 torchlatent/abc.py       |  7 ++-----
 torchlatent/crf.py       | 20 ++++++--------------
 torchlatent/semiring.py  |  8 ++------
 8 files changed, 25 insertions(+), 67 deletions(-)

diff --git a/setup.py b/setup.py
index 136fab2..49c9c6a 100644
--- a/setup.py
+++ b/setup.py
@@ -1,7 +1,6 @@
 from pathlib import Path
 
-from setuptools import find_packages
-from setuptools import setup
+from setuptools import find_packages, setup
 
 name = 'torchlatent'
 
diff --git a/tests/test_cky.py b/tests/test_cky.py
index bdec363..63ff3c3 100644
--- a/tests/test_cky.py
+++ b/tests/test_cky.py
@@ -1,18 +1,10 @@
 import torch
-from hypothesis import given
-from hypothesis import strategies as st
+from hypothesis import given, strategies as st
 from torch_struct import TreeCRF
-from torchnyan import BATCH_SIZE
-from torchnyan import TINY_TOKEN_SIZE
-from torchnyan import assert_close
-from torchnyan import assert_grad_close
-from torchnyan import device
-from torchnyan import sizes
+from torchnyan import assert_close, assert_grad_close, BATCH_SIZE, device, sizes, TINY_TOKEN_SIZE
 from torchrua import C
 
-from torchlatent.cky import CkyDecoder
-from torchlatent.cky import cky_partitions
-from torchlatent.cky import cky_scores
+from torchlatent.cky import cky_partitions, cky_scores, CkyDecoder
 from torchlatent.semiring import Log
 
 
diff --git a/tests/test_crf.py b/tests/test_crf.py
index 14969ac..4b6626a 100644
--- a/tests/test_crf.py
+++ b/tests/test_crf.py
@@ -1,21 +1,10 @@
 import torch
-from hypothesis import given
-from hypothesis import strategies as st
+from hypothesis import given, strategies as st
 from torchcrf import CRF
-from torchnyan import BATCH_SIZE
-from torchnyan import TOKEN_SIZE
-from torchnyan import assert_close
-from torchnyan import assert_grad_close
-from torchnyan import assert_sequence_close
-from torchnyan import device
-from torchnyan import sizes
-from torchrua import C
-from torchrua import D
-from torchrua import P
+from torchnyan import assert_close, assert_grad_close, assert_sequence_close, BATCH_SIZE, device, sizes, TOKEN_SIZE
+from torchrua import C, D, P
 
-from torchlatent.crf import CrfDecoder
-from torchlatent.crf import crf_partitions
-from torchlatent.crf import crf_scores
+from torchlatent.crf import crf_partitions, crf_scores, CrfDecoder
 from torchlatent.semiring import Log
 
 
diff --git a/tests/test_functional.py b/tests/test_functional.py
index 1af9370..3b8b4e1 100644
--- a/tests/test_functional.py
+++ b/tests/test_functional.py
@@ -1,15 +1,9 @@
 import torch
-from hypothesis import given
-from hypothesis import strategies as st
+from hypothesis import given, strategies as st
-from torchnyan.assertion import assert_close
-from torchnyan.assertion import assert_grad_close
-from torchnyan.strategy import TINY_BATCH_SIZE
-from torchnyan.strategy import TINY_TOKEN_SIZE
-from torchnyan.strategy import device
-from torchnyan.strategy import sizes
+from torchnyan.assertion import assert_close, assert_grad_close
+from torchnyan.strategy import device, sizes, TINY_BATCH_SIZE, TINY_TOKEN_SIZE
 
-from torchlatent.functional import logaddexp
-from torchlatent.functional import logsumexp
+from torchlatent.functional import logaddexp, logsumexp
 
 
diff --git a/torchlatent/__init__.py b/torchlatent/__init__.py
index 99084d9..8203d58 100644
--- a/torchlatent/__init__.py
+++ b/torchlatent/__init__.py
@@ -1,2 +1 @@
-from torchlatent.crf import CrfDecoder
-from torchlatent.crf import CrfDistribution
+from torchlatent.crf import CrfDecoder, CrfDistribution
diff --git a/torchlatent/abc.py b/torchlatent/abc.py
index cbc874a..2c0e65c 100644
--- a/torchlatent/abc.py
+++ b/torchlatent/abc.py
@@ -3,12 +3,9 @@
 
 import torch
 import torch.autograd
-from torch import Tensor
-from torch import nn
+from torch import nn, Tensor
 from torch.distributions.utils import lazy_property
-from torchrua import C
-from torchrua import D
-from torchrua import P
+from torchrua import C, D, P
 
 
diff --git a/torchlatent/crf.py b/torchlatent/crf.py
index 88d8bb8..5a78e79 100644
--- a/torchlatent/crf.py
+++ b/torchlatent/crf.py
@@ -1,21 +1,13 @@
-from typing import Tuple
-from typing import Type
-from typing import Union
+from typing import Tuple, Type, Union
 
 import torch
-from torch import Tensor
-from torch import nn
+from torch import nn, Tensor
 from torch.distributions.utils import lazy_property
 from torch.nn import init
-from torchrua import C
-from torchrua import D
-from torchrua import P
+from torchrua import C, D, P
 
-from torchlatent.abc import StructuredDecoder
-from torchlatent.abc import StructuredDistribution
-from torchlatent.semiring import Log
-from torchlatent.semiring import Max
-from torchlatent.semiring import Semiring
+from torchlatent.abc import StructuredDecoder, StructuredDistribution
+from torchlatent.semiring import Log, Max, Semiring
 
 T = Tuple[Tensor, Tensor, Tensor]
 
diff --git a/torchlatent/semiring.py b/torchlatent/semiring.py
index 24bfaf0..3f11e6d 100644
--- a/torchlatent/semiring.py
+++ b/torchlatent/semiring.py
@@ -1,12 +1,8 @@
 import torch
 from torch import Tensor
-from torchrua import segment_logsumexp
-from torchrua import segment_max
-from torchrua import segment_prod
-from torchrua import segment_sum
+from torchrua import segment_logsumexp, segment_max, segment_prod, segment_sum
 
-from torchlatent.functional import logaddexp
-from torchlatent.functional import logsumexp
+from torchlatent.functional import logaddexp, logsumexp
 
 __all__ = [
     'Semiring', 'ExceptionSemiring',
     'Std', 'Log', 'Max', 'Xen', 'Div',
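NOTE: the next commit replaces the packed (width, position) caches with a chart
that lives directly on the (batch, t, t) plane of the emission tensor: the
offset-w diagonal stores the inside scores of all width-w spans, and as_strided
materializes, for every left endpoint, the band of w left-child/right-child
pairs in one view. A dense, loop-based sketch of the same recursion for a
single sequence in the log semiring (an illustrative layout, not the library's
exact strides):

    import torch

    t = 5
    emissions = torch.randn((t, t))  # emissions[i, j]: score of span [i, j]

    chart = torch.full((t, t), float('-inf'))
    chart.diagonal(offset=0)[:] = emissions.diagonal(offset=0)  # width-1 spans

    for w in range(1, t):
        for i in range(t - w):  # every span [i, j] of width w + 1
            j = i + w
            # split [i, j] into [i, k] and [k + 1, j] for k = i .. j - 1;
            # logsumexp plays the semiring sum, + plays the semiring mul
            split = chart[i, i:j] + chart[i + 1:j + 1, j]
            chart[i, j] = torch.logsumexp(split, dim=0) + emissions[i, j]

    log_partition = chart[0, t - 1]  # inside score of the whole sentence

The diagonal writes in the committed version vectorize the inner loop across
all left endpoints and all batch elements at once. Note that the committed
version still hard-codes Log.zero when filling the chart; a later commit
changes this to semiring.zero.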
+++ b/torchlatent/semiring.py
@@ -1,12 +1,8 @@
 import torch
 from torch import Tensor
-from torchrua import segment_logsumexp
-from torchrua import segment_max
-from torchrua import segment_prod
-from torchrua import segment_sum
+from torchrua import segment_logsumexp, segment_max, segment_prod, segment_sum
 
-from torchlatent.functional import logaddexp
-from torchlatent.functional import logsumexp
+from torchlatent.functional import logaddexp, logsumexp
 
 __all__ = [
     'Semiring', 'ExceptionSemiring',

From 0228ce413e17bcd48cd50a93cd20c7293ffad4eb Mon Sep 17 00:00:00 2001
From: speedcell4
Date: Tue, 21 Nov 2023 19:09:01 +0900
Subject: [PATCH 094/102] Feat: Update cky_partitions

---
 torchlatent/cky.py | 54 +++++++++++++++++++++-------------------------
 1 file changed, 24 insertions(+), 30 deletions(-)

diff --git a/torchlatent/cky.py b/torchlatent/cky.py
index 798063a..b17cbf7 100644
--- a/torchlatent/cky.py
+++ b/torchlatent/cky.py
@@ -18,37 +18,31 @@ def cky_scores(emissions: C, targets: Union[C, D, P], semiring: Type[Semiring])
 
 
 def cky_partitions(emissions: C, semiring: Type[Semiring]) -> Tensor:
-    batch_ptr, token_ptr = emissions.ptr()
-    z_ptr, x_ptr = emissions._replace(token_sizes=token_ptr + 1).ptr()
-    y_ptr = token_ptr[z_ptr]
-
-    _, token_size, *_ = emissions.size()
-    cache_size, = batch_ptr.size()
-
-    w_ptr = y_ptr - x_ptr
-    src1 = w_ptr, z_ptr - w_ptr
-    # src2 = -w_ptr - 1, z_ptr
-
-    src = batch_ptr[z_ptr], x_ptr, y_ptr
-    tgt = emissions.token_sizes - 1, emissions.offsets()
-
-    size = (token_size, cache_size, *emissions.data.size()[3:])
-    score1 = emissions.data.new_full(size, fill_value=semiring.zero, requires_grad=False)
-    # score2 = emissions.data.new_full(size, fill_value=semiring.zero, requires_grad=False)
-    chart1 = emissions.data.new_full(size, fill_value=semiring.zero, requires_grad=False)
-    chart2 = emissions.data.new_full(size, fill_value=semiring.zero, requires_grad=False)
-
-    score1[src1] = emissions.data[src]
-    # score2[src2] = emissions.data[src]
-    chart1[0, :] = chart2[-1, :] = score1[0, :]
-
-    for w in range(1, token_size):
-        chart1[w, :-w] = chart2[-w - 1, w:] = semiring.mul(
-            semiring.sum(semiring.mul(chart1[:w, :-w], chart2[-w:, w:]), dim=0),
-            score1[w, :-w],
-        )
+    b, t, _, *size = emissions.data.size()
+    c, n, m, *stride = emissions.data.stride()
+
+    chart = torch.full_like(emissions.data, fill_value=semiring.zero, requires_grad=False)
+
+    def diag() -> Tensor:
+        return emissions.data.diagonal(offset=w, dim1=1, dim2=2)
+
+    def diag_scatter(tensor: Tensor) -> None:
+        chart.diagonal(offset=w, dim1=1, dim2=2)[::] = tensor
+
+    def left() -> Tensor:
+        return chart.as_strided(size=(b, t - w, w, *size), stride=(c, n + m, m, *stride))
+
+    def right() -> Tensor:
+        return chart[:, 1:, w:].as_strided(size=(b, t - w, w, *size), stride=(c, n + m, n, *stride))
+
+    w = 0
+    diag_scatter(diag())
+
+    for w in range(1, t):
+        diag_scatter(semiring.mul(semiring.sum(semiring.mul(left(), right()), dim=2), diag()))
 
-    return chart1[tgt]
+    index = torch.arange(b, dtype=torch.long, device=chart.device)
+    return chart[index, 0, emissions.token_sizes - 1]
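Note: to make the recurrence in PATCH 094 easier to follow, here is a loop-based sketch of the same computation in the log semiring (batch and trailing semiring dimensions dropped; the function name and the (t, t) emission layout are illustrative only, not torchlatent API). Cell (i, i + w) of the chart combines every split point k, pairing the left sub-span [i, i + k] with the right sub-span [i + k + 1, i + w], then multiplies in the emission score of the whole span; the strided left()/right() views above gather all those pairs for one diagonal at once.

import torch

def cky_partitions_reference(emissions: torch.Tensor) -> torch.Tensor:
    # emissions: (t, t) log-potentials; emissions[i, j] scores the span [i, j].
    t = emissions.size(0)
    chart = torch.full_like(emissions, fill_value=float('-inf'))
    chart.diagonal(offset=0)[:] = emissions.diagonal(offset=0)  # width-0 spans

    for w in range(1, t):           # span width, i.e. the diagonal offset
        for i in range(t - w):      # span start
            j = i + w               # span end
            # logsumexp over split points k in [i, j): left [i, k], right [k + 1, j]
            inner = torch.logsumexp(chart[i, i:j] + chart[i + 1:j + 1, j], dim=0)
            chart[i, j] = inner + emissions[i, j]

    return chart[0, t - 1]          # log-partition of the full span [0, t - 1]

The strided version fills one whole diagonal per iteration instead of one cell, which is what makes it batch- and GPU-friendly.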
From ab027fc6ad36f3ec6f9b0f17d6dfd74623bdb34d Mon Sep 17 00:00:00 2001
From: speedcell4
Date: Mon, 27 Nov 2023 21:39:13 +0900
Subject: [PATCH 095/102] Style: PEP8 them all

---
 tests/test_cky.py        | 4 ++--
 tests/test_crf.py        | 4 ++--
 tests/test_functional.py | 2 +-
 torchlatent/abc.py       | 2 +-
 torchlatent/crf.py       | 2 +-
 5 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/tests/test_cky.py b/tests/test_cky.py
index 63ff3c3..5352218 100644
--- a/tests/test_cky.py
+++ b/tests/test_cky.py
@@ -1,10 +1,10 @@
 import torch
 from hypothesis import given, strategies as st
 from torch_struct import TreeCRF
-from torchnyan import assert_close, assert_grad_close, BATCH_SIZE, device, sizes, TINY_TOKEN_SIZE
+from torchnyan import BATCH_SIZE, TINY_TOKEN_SIZE, assert_close, assert_grad_close, device, sizes
 from torchrua import C
 
-from torchlatent.cky import cky_partitions, cky_scores, CkyDecoder
+from torchlatent.cky import CkyDecoder, cky_partitions, cky_scores
 from torchlatent.semiring import Log
 
 
diff --git a/tests/test_crf.py b/tests/test_crf.py
index 4b6626a..2a34b02 100644
--- a/tests/test_crf.py
+++ b/tests/test_crf.py
@@ -1,10 +1,10 @@
 import torch
 from hypothesis import given, strategies as st
 from torchcrf import CRF
-from torchnyan import assert_close, assert_grad_close, assert_sequence_close, BATCH_SIZE, device, sizes, TOKEN_SIZE
+from torchnyan import BATCH_SIZE, TOKEN_SIZE, assert_close, assert_grad_close, assert_sequence_close, device, sizes
 from torchrua import C, D, P
 
-from torchlatent.crf import crf_partitions, crf_scores, CrfDecoder
+from torchlatent.crf import CrfDecoder, crf_partitions, crf_scores
 from torchlatent.semiring import Log
 
 
diff --git a/tests/test_functional.py b/tests/test_functional.py
index 3b8b4e1..c4b663e 100644
--- a/tests/test_functional.py
+++ b/tests/test_functional.py
@@ -1,7 +1,7 @@
 import torch
 from hypothesis import given, strategies as st
 from torchnyan.assertion import assert_close, assert_grad_close
-from torchnyan.strategy import device, sizes, TINY_BATCH_SIZE, TINY_TOKEN_SIZE
+from torchnyan.strategy import TINY_BATCH_SIZE, TINY_TOKEN_SIZE, device, sizes
 
 from torchlatent.functional import logaddexp, logsumexp
 
diff --git a/torchlatent/abc.py b/torchlatent/abc.py
index 2c0e65c..d5e088d 100644
--- a/torchlatent/abc.py
+++ b/torchlatent/abc.py
@@ -3,7 +3,7 @@
 
 import torch
 import torch.autograd
-from torch import nn, Tensor
+from torch import Tensor, nn
 from torch.distributions.utils import lazy_property
 from torchrua import C, D, P
 
diff --git a/torchlatent/crf.py b/torchlatent/crf.py
index 5a78e79..83d1b09 100644
--- a/torchlatent/crf.py
+++ b/torchlatent/crf.py
@@ -1,7 +1,7 @@
 from typing import Tuple, Type, Union
 
 import torch
-from torch import nn, Tensor
+from torch import Tensor, nn
 from torch.distributions.utils import lazy_property
 from torch.nn import init
 from torchrua import C, D, P
From 7c2cc49be71786909cb539c82ddce358ce42a3bb Mon Sep 17 00:00:00 2001
From: speedcell4
Date: Sun, 28 Jan 2024 19:24:24 +0900
Subject: [PATCH 096/102] Refactor: Update cky

---
 torchlatent/cky.py | 30 +++++++++++++++---------------
 1 file changed, 15 insertions(+), 15 deletions(-)

diff --git a/torchlatent/cky.py b/torchlatent/cky.py
index b17cbf7..a4195ec 100644
--- a/torchlatent/cky.py
+++ b/torchlatent/cky.py
@@ -18,28 +18,28 @@ def cky_scores(emissions: C, targets: Union[C, D, P], semiring: Type[Semiring])
 
 
 def cky_partitions(emissions: C, semiring: Type[Semiring]) -> Tensor:
-    b, t, _, *size = emissions.data.size()
-    c, n, m, *stride = emissions.data.stride()
+    chart = torch.full_like(emissions.data, fill_value=semiring.zero, requires_grad=False)
+    b, t, _, *size = chart.size()
+    c, n, m, *stride = chart.stride()
 
-    chart = torch.full_like(emissions.data, fill_value=semiring.zero, requires_grad=False)
+    def diag(offset: int) -> Tensor:
+        return emissions.data.diagonal(offset=offset, dim1=1, dim2=2)
 
-    def diag() -> Tensor:
-        return emissions.data.diagonal(offset=w, dim1=1, dim2=2)
+    def diag_scatter(tensor: Tensor, offset: int) -> None:
+        chart.diagonal(offset=offset, dim1=1, dim2=2)[::] = tensor
 
-    def diag_scatter(tensor: Tensor) -> None:
-        chart.diagonal(offset=w, dim1=1, dim2=2)[::] = tensor
+    def left(offset: int) -> Tensor:
+        return chart.as_strided(size=(b, t - offset, offset, *size), stride=(c, n + m, m, *stride))
 
-    def left() -> Tensor:
-        return chart.as_strided(size=(b, t - w, w, *size), stride=(c, n + m, m, *stride))
+    def right(offset: int) -> Tensor:
+        return chart[:, 1:, offset:].as_strided(size=(b, t - offset, offset, *size), stride=(c, n + m, n, *stride))
 
-    def right() -> Tensor:
-        return chart[:, 1:, w:].as_strided(size=(b, t - w, w, *size), stride=(c, n + m, n, *stride))
-
-    w = 0
-    diag_scatter(diag())
+    score = diag(offset=0)
+    diag_scatter(score, offset=0)
 
     for w in range(1, t):
-        diag_scatter(semiring.mul(semiring.sum(semiring.mul(left(), right()), dim=2), diag()))
+        score = semiring.sum(semiring.mul(left(offset=w), right(offset=w)), dim=2)
+        diag_scatter(semiring.mul(score, diag(offset=w)), offset=w)
 
     index = torch.arange(b, dtype=torch.long, device=chart.device)
     return chart[index, 0, emissions.token_sizes - 1]

From 3ffcce342e6c72c04214eacc136832a960d1fce1 Mon Sep 17 00:00:00 2001
From: speedcell4
Date: Sat, 24 Feb 2024 21:07:56 +0900
Subject: [PATCH 097/102] Refactor: Update cky

---
 torchlatent/cky.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/torchlatent/cky.py b/torchlatent/cky.py
index a4195ec..0154336 100644
--- a/torchlatent/cky.py
+++ b/torchlatent/cky.py
@@ -18,6 +18,10 @@ def cky_scores(emissions: C, targets: Union[C, D, P], semiring: Type[Semiring])
 
 
 def cky_partitions(emissions: C, semiring: Type[Semiring]) -> Tensor:
+    if emissions.data.dim() == 4:
+        data = semiring.sum(emissions.data, dim=-1)
+        emissions = emissions._replace(data=data)
+
     chart = torch.full_like(emissions.data, fill_value=semiring.zero, requires_grad=False)
     b, t, _, *size = chart.size()
     c, n, m, *stride = chart.stride()
@@ -58,14 +62,14 @@ def log_scores(self, targets: Union[C, D, P]) -> Tensor:
     @lazy_property
     def log_partitions(self) -> Tensor:
         return cky_partitions(
-            emissions=self.emissions._replace(data=Log.sum(self.emissions.data, dim=-1)),
+            emissions=self.emissions,
             semiring=Log,
         )
 
     @lazy_property
     def max(self) -> Tensor:
         return cky_partitions(
-            emissions=self.emissions._replace(data=Max.sum(self.emissions.data, dim=-1)),
+            emissions=self.emissions,
            semiring=Max,
        )
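Note: the new dim() == 4 branch in PATCH 097 folds the label dimension into each span score with the same semiring that later drives the chart, so Log marginalizes over labels while Max keeps the single best label per span. A minimal sketch of the two reductions, with torch.logsumexp and max standing in for Log.sum and Max.sum from torchlatent.semiring (shapes are illustrative):

import torch

scores = torch.randn(2, 5, 5, 3)  # (batch, t, t, num_labels) log-potentials

log_reduced = torch.logsumexp(scores, dim=-1)  # Log semiring: marginalize labels
max_reduced = scores.max(dim=-1).values        # Max semiring: best label per span

assert log_reduced.shape == max_reduced.shape == (2, 5, 5)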
From fe1390346cf57a0ac1549f963f6699e7bc4f4be3 Mon Sep 17 00:00:00 2001
From: speedcell4
Date: Sat, 24 Feb 2024 22:06:50 +0900
Subject: [PATCH 098/102] Refactor: Update cky

---
 torchlatent/cky.py | 49 ++++++++++++++++++++++++++++------------------
 1 file changed, 30 insertions(+), 19 deletions(-)

diff --git a/torchlatent/cky.py b/torchlatent/cky.py
index 0154336..faff08e 100644
--- a/torchlatent/cky.py
+++ b/torchlatent/cky.py
@@ -17,35 +17,46 @@ def cky_scores(emissions: C, targets: Union[C, D, P], semiring: Type[Semiring])
     return semiring.segment_prod(emissions, token_sizes)
 
 
-def cky_partitions(emissions: C, semiring: Type[Semiring]) -> Tensor:
-    if emissions.data.dim() == 4:
-        data = semiring.sum(emissions.data, dim=-1)
-        emissions = emissions._replace(data=data)
-
+def diag(tensor: Tensor, offset: int) -> Tensor:
+    return tensor.diagonal(offset=offset, dim1=1, dim2=2)
 
-    chart = torch.full_like(emissions.data, fill_value=semiring.zero, requires_grad=False)
+
+def diag_scatter(chart: Tensor, score: Tensor, offset: int) -> None:
+    chart.diagonal(offset=offset, dim1=1, dim2=2)[::] = score
+
+
+def left(chart: Tensor, offset: int) -> Tensor:
     b, t, _, *size = chart.size()
     c, n, m, *stride = chart.stride()
+    return chart.as_strided(
+        size=(b, t - offset, offset, *size),
+        stride=(c, n + m, m, *stride),
+    )
 
-    def diag(offset: int) -> Tensor:
-        return emissions.data.diagonal(offset=offset, dim1=1, dim2=2)
 
-    def diag_scatter(tensor: Tensor, offset: int) -> None:
-        chart.diagonal(offset=offset, dim1=1, dim2=2)[::] = tensor
+def right(chart: Tensor, offset: int) -> Tensor:
+    b, t, _, *size = chart.size()
+    c, n, m, *stride = chart.stride()
+    return chart[:, 1:, offset:].as_strided(
+        size=(b, t - offset, offset, *size),
+        stride=(c, n + m, n, *stride),
+    )
 
-    def left(offset: int) -> Tensor:
-        return chart.as_strided(size=(b, t - offset, offset, *size), stride=(c, n + m, m, *stride))
 
-    def right(offset: int) -> Tensor:
-        return chart[:, 1:, offset:].as_strided(size=(b, t - offset, offset, *size), stride=(c, n + m, n, *stride))
+def cky_partitions(emissions: C, semiring: Type[Semiring]) -> Tensor:
+    if emissions.data.dim() == 4:
+        data = semiring.sum(emissions.data, dim=-1)
+        emissions = emissions._replace(data=data)
+
+    chart = torch.full_like(emissions.data, fill_value=semiring.zero, requires_grad=False)
 
-    score = diag(offset=0)
-    diag_scatter(score, offset=0)
+    diag_scatter(chart, diag(emissions.data, offset=0), offset=0)
 
     for w in range(1, chart.size()[1]):
-        score = semiring.sum(semiring.mul(left(offset=w), right(offset=w)), dim=2)
-        diag_scatter(semiring.mul(score, diag(offset=w)), offset=w)
+        score = semiring.sum(semiring.mul(left(chart, offset=w), right(chart, offset=w)), dim=2)
+        diag_scatter(chart, semiring.mul(score, diag(emissions.data, offset=w)), offset=w)
 
-    index = torch.arange(b, dtype=torch.long, device=chart.device)
+    index = torch.arange(chart.size()[0], dtype=torch.long, device=chart.device)
     return chart[index, 0, emissions.token_sizes - 1]
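Note: a tiny self-contained check of what the module-level left/right views in PATCH 098 index, on a toy (batch, t, t) chart without the trailing semiring dimensions (variable names are illustrative; nothing here is torchlatent API):

import torch

t, w = 4, 2  # toy sentence length and span width
chart = torch.arange(t * t, dtype=torch.float32).view(1, t, t)  # (batch, t, t)
b, _, _ = chart.size()
c, n, m = chart.stride()

left_view = chart.as_strided(size=(b, t - w, w), stride=(c, n + m, m))
right_view = chart[:, 1:, w:].as_strided(size=(b, t - w, w), stride=(c, n + m, n))

for i in range(t - w):
    for k in range(w):
        assert left_view[0, i, k] == chart[0, i, i + k]           # left sub-span [i, i + k]
        assert right_view[0, i, k] == chart[0, i + k + 1, i + w]  # right sub-span [i + k + 1, i + w]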
From 149e68211feec1c0c6a61f4222628ce69e9d6e7a Mon Sep 17 00:00:00 2001
From: speedcell4
Date: Sun, 25 Feb 2024 13:42:41 +0900
Subject: [PATCH 099/102] Test: Add @settings(deadline=None)

---
 tests/test_cky.py        | 5 ++++-
 tests/test_crf.py        | 5 ++++-
 tests/test_functional.py | 4 +++-
 3 files changed, 11 insertions(+), 3 deletions(-)

diff --git a/tests/test_cky.py b/tests/test_cky.py
index 5352218..e627e8e 100644
--- a/tests/test_cky.py
+++ b/tests/test_cky.py
@@ -1,5 +1,5 @@
 import torch
-from hypothesis import given, strategies as st
+from hypothesis import given, settings, strategies as st
 from torch_struct import TreeCRF
 from torchnyan import BATCH_SIZE, TINY_TOKEN_SIZE, assert_close, assert_grad_close, device, sizes
 from torchrua import C
@@ -23,6 +23,7 @@ def get_argmax(cky):
     return argmax, x, y, z
 
 
+@settings(deadline=None)
 @given(
     token_sizes=sizes(BATCH_SIZE, TINY_TOKEN_SIZE),
     num_targets=sizes(TINY_TOKEN_SIZE),
@@ -54,6 +55,7 @@ def test_cky_scores(token_sizes, num_targets, rua_targets):
     assert_grad_close(actual=actual, expected=expected, inputs=(emissions,))
 
 
+@settings(deadline=None)
 @given(
     token_sizes=sizes(BATCH_SIZE, TINY_TOKEN_SIZE),
     num_targets=sizes(TINY_TOKEN_SIZE),
@@ -77,6 +79,7 @@ def test_cky_partitions(token_sizes, num_targets):
     assert_grad_close(actual=actual, expected=expected, inputs=(emissions,))
 
 
+@settings(deadline=None)
 @given(
     token_sizes=sizes(BATCH_SIZE, TINY_TOKEN_SIZE),
     num_targets=sizes(TINY_TOKEN_SIZE),
diff --git a/tests/test_crf.py b/tests/test_crf.py
index 2a34b02..f148245 100644
--- a/tests/test_crf.py
+++ b/tests/test_crf.py
@@ -1,5 +1,5 @@
 import torch
-from hypothesis import given, strategies as st
+from hypothesis import given, settings, strategies as st
 from torchcrf import CRF
 from torchnyan import BATCH_SIZE, TOKEN_SIZE, assert_close, assert_grad_close, assert_sequence_close, device, sizes
 from torchrua import C, D, P
@@ -8,6 +8,7 @@
 from torchlatent.semiring import Log
 
 
+@settings(deadline=None)
 @given(
     token_sizes=sizes(BATCH_SIZE, TOKEN_SIZE),
     num_targets=sizes(TOKEN_SIZE),
@@ -47,6 +48,7 @@ def test_crf_scores(token_sizes, num_targets, rua_emissions, rua_targets):
     assert_grad_close(actual=actual, expected=expected, inputs=inputs)
 
 
+@settings(deadline=None)
 @given(
     token_sizes=sizes(BATCH_SIZE, TOKEN_SIZE),
     num_targets=sizes(TOKEN_SIZE),
@@ -77,6 +79,7 @@ def test_crf_partitions(token_sizes, num_targets, rua_emissions):
     assert_grad_close(actual=actual, expected=expected, inputs=inputs, rtol=1e-4, atol=1e-4)
 
 
+@settings(deadline=None)
 @given(
     token_sizes=sizes(BATCH_SIZE, TOKEN_SIZE),
     num_targets=sizes(TOKEN_SIZE),
diff --git a/tests/test_functional.py b/tests/test_functional.py
index c4b663e..2686b68 100644
--- a/tests/test_functional.py
+++ b/tests/test_functional.py
@@ -1,11 +1,12 @@
 import torch
-from hypothesis import given, strategies as st
+from hypothesis import given, settings, strategies as st
 from torchnyan.assertion import assert_close, assert_grad_close
 from torchnyan.strategy import TINY_BATCH_SIZE, TINY_TOKEN_SIZE, device, sizes
 
 from torchlatent.functional import logaddexp, logsumexp
 
 
+@settings(deadline=None)
 @given(
     token_sizes=sizes(TINY_BATCH_SIZE, TINY_TOKEN_SIZE)
 )
@@ -20,6 +21,7 @@ def test_logaddexp(token_sizes):
     assert_grad_close(actual=actual, expected=expected, inputs=(x, y))
 
 
+@settings(deadline=None)
 @given(
     data=st.data(),
     token_sizes=sizes(TINY_BATCH_SIZE, TINY_TOKEN_SIZE)
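Note: Hypothesis enforces a per-example deadline (200 ms by default), and the first examples of these property-based tests can easily exceed it, for instance while CUDA kernels warm up; @settings(deadline=None) in PATCH 099 simply switches that check off rather than tuning a limit per test.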
From 90c4043f5a2f51deb002d795b276ce678162641f Mon Sep 17 00:00:00 2001
From: speedcell4
Date: Sun, 25 Feb 2024 13:49:35 +0900
Subject: [PATCH 100/102] Refactor: Extract masked_select

---
 torchlatent/cky.py | 34 ++++++++++++++++++----------------
 1 file changed, 18 insertions(+), 16 deletions(-)

diff --git a/torchlatent/cky.py b/torchlatent/cky.py
index faff08e..b8c863d 100644
--- a/torchlatent/cky.py
+++ b/torchlatent/cky.py
@@ -1,4 +1,4 @@
-from typing import Type, Union
+from typing import Tuple, Type, Union
 
 import torch
 from torch import Tensor
@@ -44,10 +44,6 @@ def cky_partitions(emissions: C, semiring: Type[Semiring]) -> Tensor:
-    if emissions.data.dim() == 4:
-        data = semiring.sum(emissions.data, dim=-1)
-        emissions = emissions._replace(data=data)
-
     chart = torch.full_like(emissions.data, fill_value=semiring.zero, requires_grad=False)
 
     diag_scatter(chart, diag(emissions.data, offset=0), offset=0)
 
     for w in range(1, chart.size()[1]):
         score = semiring.sum(semiring.mul(left(chart, offset=w), right(chart, offset=w)), dim=2)
@@ -60,6 +56,19 @@ def cky_partitions(emissions: C, semiring: Type[Semiring]) -> Tensor:
     return chart[index, 0, emissions.token_sizes - 1]
 
 
+def masked_select(mask: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
+    _, t, _, n = mask.size()
+
+    index = torch.arange(t, device=mask.device)
+    x = torch.masked_select(index[None, :, None, None], mask=mask)
+    y = torch.masked_select(index[None, None, :, None], mask=mask)
+
+    index = torch.arange(n, device=mask.device)
+    z = torch.masked_select(index[None, None, None, :], mask=mask)
+
+    return x, y, z
+
+
 class CkyDistribution(StructuredDistribution):
     def __init__(self, emissions: C) -> None:
         super(CkyDistribution, self).__init__(emissions=emissions)
@@ -73,28 +82,21 @@ def log_scores(self, targets: Union[C, D, P]) -> Tensor:
     @lazy_property
     def log_partitions(self) -> Tensor:
         return cky_partitions(
-            emissions=self.emissions,
+            emissions=self.emissions._replace(data=Log.sum(self.emissions.data, dim=-1)),
             semiring=Log,
         )
 
     @lazy_property
     def max(self) -> Tensor:
         return cky_partitions(
-            emissions=self.emissions,
+            emissions=self.emissions._replace(data=Max.sum(self.emissions.data, dim=-1)),
             semiring=Max,
         )
 
     @lazy_property
     def argmax(self) -> C:
-        mask = super(CkyDistribution, self).argmax > 0
-        _, t, _, n = mask.size()
-
-        index = torch.arange(t, device=mask.device)
-        x = torch.masked_select(index[None, :, None, None], mask=mask)
-        y = torch.masked_select(index[None, None, :, None], mask=mask)
-
-        index = torch.arange(n, device=mask.device)
-        z = torch.masked_select(index[None, None, None, :], mask=mask)
+        argmax = super(CkyDistribution, self).argmax
+        x, y, z = masked_select(argmax > 0)
 
         return C(
             data=torch.stack([x, y, z], dim=-1),

From ba23b8b64f5082c813486581659b76d770b2ee8f Mon Sep 17 00:00:00 2001
From: speedcell4
Date: Sun, 25 Feb 2024 13:58:45 +0900
Subject: [PATCH 101/102] Chore: Update github workflows

---
 .github/workflows/publish-package.yml | 10 ++++------
 .github/workflows/unit-tests.yml      |  8 +++-----
 2 files changed, 7 insertions(+), 11 deletions(-)

diff --git a/.github/workflows/publish-package.yml b/.github/workflows/publish-package.yml
index c4e0338..e7f3d7f 100644
--- a/.github/workflows/publish-package.yml
+++ b/.github/workflows/publish-package.yml
@@ -5,19 +5,17 @@ on:
     types: [ created ]
 
 jobs:
-  deploy:
-
+  build:
     runs-on: ubuntu-latest
 
     steps:
-      - uses: actions/checkout@v3
-      - name: Set up Python
-        uses: actions/setup-python@v4
+      - uses: actions/checkout@v4
+      - uses: actions/setup-python@v5
         with:
          python-version: '3.8'
       - name: Install dependencies
         run: |
-          python -m pip install pip setuptools wheel --upgrade
+          python -m pip install --upgrade pip setuptools wheel
          python -m pip install twine
       - name: Build and publish
         env:
diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml
index a48185b..0717269 100644
--- a/.github/workflows/unit-tests.yml
+++ b/.github/workflows/unit-tests.yml
@@ -8,21 +8,19 @@ on:
 
 jobs:
   build:
-
     runs-on: ubuntu-latest
 
     steps:
-      - uses: actions/checkout@v3
-      - name: Set up Python
-        uses: actions/setup-python@v4
+      - uses: actions/checkout@v4
+      - uses: actions/setup-python@v5
         with:
           python-version: '3.8'
       - name: Install dependencies
         run: |
           python -m pip install pip --upgrade
           python -m pip install -r requirements.txt
-          python -m pip install git+https://github.com/speedcell4/torchrua.git@develop --force-reinstall --no-deps
           python -m pip install pytest hypothesis torchnyan
+          python -m pip install git+https://github.com/speedcell4/torchrua.git@develop --force-reinstall --no-deps
           python -m pip install pytorch-crf torch-struct
       - name: Test with pytest
         run: |

From c248730e3e5e11ab0e2984ab4b022847cb0e866d Mon Sep 17 00:00:00 2001
From: speedcell4
Date: Sun, 25 Feb 2024 14:12:55 +0900
Subject: [PATCH 102/102] Chore: Update version number

---
 setup.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/setup.py b/setup.py
index 49c9c6a..ba04b02 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@
 
 setup(
     name=name,
-    version='0.5.0',
+    version='0.4.3',
     packages=[package for package in find_packages() if package.startswith(name)],
     url='https://github.com/speedcell4/torchlatent',
     license='MIT',
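Note: as a closing illustration of the masked_select helper extracted in PATCH 100: given a boolean (batch, t, t, num_labels) mask of the spans picked by argmax, it recovers the start, end, and label index of every selected span. A toy run of the same indexing (a standalone sketch, not an additional torchlatent API; torch.masked_select broadcasts each index tensor against the mask and flattens in row-major order):

import torch

mask = torch.zeros((1, 3, 3, 2), dtype=torch.bool)
mask[0, 0, 0, 0] = True  # span [0, 0] with label 0
mask[0, 0, 2, 1] = True  # span [0, 2] with label 1
mask[0, 1, 2, 0] = True  # span [1, 2] with label 0

_, t, _, n = mask.size()
index = torch.arange(t)
x = torch.masked_select(index[None, :, None, None], mask=mask)  # span starts
y = torch.masked_select(index[None, None, :, None], mask=mask)  # span ends
index = torch.arange(n)
z = torch.masked_select(index[None, None, None, :], mask=mask)  # labels

# row-major flattening means the triples come out sorted by (start, end):
assert x.tolist() == [0, 0, 1] and y.tolist() == [0, 2, 2] and z.tolist() == [0, 1, 0]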