Skip to content

Commit

Permalink
Feat: Update select
Browse files Browse the repository at this point in the history
  • Loading branch information
speedcell4 committed Jul 20, 2024
1 parent bd732ef commit bfa2f81
Show file tree
Hide file tree
Showing 39 changed files with 523 additions and 485 deletions.
10 changes: 5 additions & 5 deletions tests/test_compose.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from torch.nn.utils.rnn import pack_sequence
from torchnyan import FEATURE_DIM, TINY_BATCH_SIZE, TINY_TOKEN_SIZE, assert_close, assert_grad_close, device, sizes

from torchrua import C, L, P, compose
from torchrua import Z, compose


@settings(deadline=None)
Expand All @@ -14,7 +14,7 @@
input_size=sizes(FEATURE_DIM),
hidden_size=sizes(FEATURE_DIM),
)
def test_compose_sequences(data, token_sizes_batch, input_size, hidden_size):
def test_compose(data, token_sizes_batch, input_size, hidden_size):
sequences = [
[
torch.randn((token_size, input_size), requires_grad=True, device=device)
Expand All @@ -29,11 +29,11 @@ def test_compose_sequences(data, token_sizes_batch, input_size, hidden_size):
bidirectional=True, bias=True,
).to(device=device)

actual_sequences = [
data.draw(st.sampled_from([C, L, P])).new(sequence).to(device=device)
actual = [
data.draw(st.sampled_from(Z.__args__)).new(sequence).to(device=device)
for sequence in sequences
]
_, (actual, _) = rnn(compose(actual_sequences))
_, (actual, _) = rnn(compose(actual))
actual = actual.transpose(-3, -2).flatten(start_dim=-2)

expected = []
Expand Down
24 changes: 24 additions & 0 deletions tests/test_detach.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import torch
from hypothesis import given, settings, strategies as st
from torchnyan import BATCH_SIZE, FEATURE_DIM, TOKEN_SIZE, device, sizes
from torchnyan.assertion import assert_close

from torchrua import Z


@settings(deadline=None)
@given(
token_sizes=sizes(BATCH_SIZE, TOKEN_SIZE),
dim=sizes(FEATURE_DIM),
rua=st.sampled_from(Z.__args__),
)
def test_split(token_sizes, dim, rua):
inputs = expected = [
torch.randn((token_size, dim), device=device, requires_grad=True)
for token_size in token_sizes
]

actual = rua.new(inputs).split()

for a, e in zip(actual, expected):
assert_close(actual=a, expected=e)
24 changes: 0 additions & 24 deletions tests/test_head.py

This file was deleted.

24 changes: 0 additions & 24 deletions tests/test_last.py

This file was deleted.

8 changes: 4 additions & 4 deletions tests/test_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
from torch.nn.utils.rnn import pad_sequence
from torchnyan import BATCH_SIZE, TOKEN_SIZE, assert_close, device, sizes

from torchrua import C, L, P
from torchrua import Z


@settings(deadline=None)
@given(
token_sizes=sizes(BATCH_SIZE, TOKEN_SIZE),
rua_sequence=st.sampled_from([C.new, L.new, P.new]),
rua=st.sampled_from(Z.__args__),
zero_one_dtype=st.sampled_from([
(False, True, torch.bool),
(-1, +2, torch.long),
Expand All @@ -18,15 +18,15 @@
(torch.finfo(torch.float64).min, torch.finfo(torch.float64).max, torch.float64),
])
)
def test_mask_sequence(token_sizes, rua_sequence, zero_one_dtype):
def test_mask(token_sizes, rua, zero_one_dtype):
inputs = [
torch.randn((token_size,), device=device, requires_grad=True)
for token_size in token_sizes
]

zero, one, dtype = zero_one_dtype

actual = rua_sequence(inputs).mask(zero=zero, one=one, dtype=dtype)
actual = rua.new(inputs).mask(zero=zero, one=one, dtype=dtype)
expected = pad_sequence([
torch.full((token_size,), fill_value=one, device=device, dtype=dtype)
for token_size in token_sizes
Expand Down
24 changes: 0 additions & 24 deletions tests/test_reverse.py

This file was deleted.

27 changes: 0 additions & 27 deletions tests/test_roll.py

This file was deleted.

11 changes: 6 additions & 5 deletions tests/test_segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
from torch import Tensor
from torchnyan import BATCH_SIZE, FEATURE_DIM, TOKEN_SIZE, assert_grad_close, assert_sequence_close, device, sizes

from torchrua import C, L, P, segment_head, segment_last, segment_logsumexp, segment_max, segment_mean, segment_min, \
from torchrua import C, L, Z
from torchrua.reduce import segment_head, segment_last, segment_logsumexp, segment_max, segment_mean, segment_min, \
segment_prod, segment_sum


Expand Down Expand Up @@ -68,10 +69,10 @@ def raw_segment(sequence, duration, fn):
(segment_head, reduce_head),
(segment_last, reduce_last),
]),
rua_sequence=st.sampled_from([C.new, L.new, P.new]),
rua_duration=st.sampled_from([C.new, L.new, P.new]),
rua_sequence=st.sampled_from(Z.__args__),
rua_duration=st.sampled_from(Z.__args__),
)
def test_segment_sequence(token_sizes, dim, fns, rua_sequence, rua_duration):
def test_seg(token_sizes, dim, fns, rua_sequence, rua_duration):
inputs = [
torch.randn((token_size, dim), device=device, requires_grad=True)
for token_size in token_sizes
Expand All @@ -84,7 +85,7 @@ def test_segment_sequence(token_sizes, dim, fns, rua_sequence, rua_duration):

fn1, fn2 = fns

actual = rua_sequence(inputs).seg(rua_duration(durations), fn1).cat()
actual = rua_sequence.new(inputs).seg(rua_duration.new(durations), fn1).cat()
expected = C.new(raw_segment(L.new(inputs).data, durations, fn2))

assert_sequence_close(actual=actual, expected=expected)
Expand Down
109 changes: 109 additions & 0 deletions tests/test_select.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import torch
from hypothesis import given, settings, strategies as st
from torchnyan import BATCH_SIZE, FEATURE_DIM, TOKEN_SIZE, device, sizes
from torchnyan.assertion import assert_close, assert_grad_close, assert_sequence_close

from torchrua import C, Z


@settings(deadline=None)
@given(
token_sizes=sizes(BATCH_SIZE, TOKEN_SIZE),
dim=sizes(FEATURE_DIM),
rua=st.sampled_from(Z.__args__),
)
def test_head(token_sizes, dim, rua):
inputs = [
torch.randn((token_size, dim), device=device, requires_grad=True)
for token_size in token_sizes
]

actual = rua.new(inputs).head()
expected = torch.stack([tensor[0] for tensor in inputs], dim=0)

assert_close(actual=actual, expected=expected)
assert_grad_close(actual=actual, expected=expected, inputs=inputs)


@settings(deadline=None)
@given(
token_sizes=sizes(BATCH_SIZE, TOKEN_SIZE),
dim=sizes(FEATURE_DIM),
rua=st.sampled_from(Z.__args__),
)
def test_last(token_sizes, dim, rua):
inputs = [
torch.randn((token_size, dim), device=device, requires_grad=True)
for token_size in token_sizes
]

actual = rua.new(inputs).last()
expected = torch.stack([tensor[-1] for tensor in inputs], dim=0)

assert_close(actual=actual, expected=expected)
assert_grad_close(actual=actual, expected=expected, inputs=inputs)


@settings(deadline=None)
@given(
token_sizes=sizes(BATCH_SIZE, TOKEN_SIZE),
dim=sizes(FEATURE_DIM),
rua=st.sampled_from(Z.__args__),
)
def test_rev(token_sizes, dim, rua):
inputs = [
torch.randn((token_size, dim), device=device, requires_grad=True)
for token_size in token_sizes
]

actual = rua.new(inputs).rev().cat()
expected = C.new([tensor.flip(dims=[0]) for tensor in inputs])

assert_sequence_close(actual=actual, expected=expected)
assert_grad_close(actual=actual.data, expected=expected.data, inputs=inputs)


@settings(deadline=None)
@given(
data=st.data(),
token_sizes=sizes(BATCH_SIZE, TOKEN_SIZE),
dim=sizes(FEATURE_DIM),
rua=st.sampled_from(Z.__args__),
)
def test_roll(data, token_sizes, dim, rua):
shifts = data.draw(st.integers(min_value=-max(token_sizes), max_value=+max(token_sizes)))

inputs = [
torch.randn((token_size, dim), device=device, requires_grad=True)
for token_size in token_sizes
]

actual = rua.new(inputs).roll(shifts=shifts).cat()
expected = C.new([tensor.roll(shifts, dims=[0]) for tensor in inputs])

assert_sequence_close(actual, expected)
assert_grad_close(actual.data, expected.data, inputs=inputs)


@settings(deadline=None)
@given(
data=st.data(),
token_sizes=sizes(BATCH_SIZE, TOKEN_SIZE),
dim=sizes(FEATURE_DIM),
rua=st.sampled_from(Z.__args__),
)
def test_trunc(data, token_sizes, dim, rua):
inputs = [
torch.randn((token_size, dim), device=device, requires_grad=True)
for token_size in token_sizes
]

s = min(token_sizes) - 1
a = data.draw(st.integers(0, max_value=s))
b = data.draw(st.integers(0, max_value=s - a))

actual = rua.new(inputs).trunc((a, b)).cat()
expected = C.new([tensor[a:tensor.size()[0] - b] for tensor in inputs])

assert_sequence_close(actual=actual, expected=expected)
assert_grad_close(actual=actual.data, expected=expected.data, inputs=inputs)
29 changes: 0 additions & 29 deletions tests/test_trunc.py

This file was deleted.

13 changes: 2 additions & 11 deletions torchrua/__init__.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,8 @@
from torchrua.compose import *
from torchrua.core import *
from torchrua.decode import *
from torchrua.get import *
from torchrua.head import *
from torchrua.last import *
from torchrua.detach import *
from torchrua.layout import *
from torchrua.mask import *
from torchrua.new import *
from torchrua.reduce import *
from torchrua.reverse import *
from torchrua.roll import *
from torchrua.segment import *
from torchrua.set import *
from torchrua.transform import *
from torchrua.trunc import *
from torchrua.view import *
from torchrua.select import *
2 changes: 1 addition & 1 deletion torchrua/compose.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

import torch

from torchrua.core import invert_permutation
from torchrua.layout import C, P, Z
from torchrua.utils import invert_permutation


def compose(sequences: List[Z]) -> P:
Expand Down
Loading

0 comments on commit bfa2f81

Please sign in to comment.