Commit
【Paddle Tensor No.8, 9, 14, 15】Add `__rshift__`, `__lshift__`, `__rlshift__`, `__rrshift__` to Tensor (#69348)

* update

* fix: update expected output

* fix: `_rrshift__` -> `__rrshift__`

* feat: use `paddle.to_tensor` to support int input, and raise TypeError for float (see the sketch after this list)

* 📝 docs: remove docstrings of the magic methods

* test: add tests for `__lshift__`, `__rshift__`, `__rlshift__`, `__rrshift__`; `__rshift__` cannot pass ref_right_shift_logical

`__rshift__` cannot pass TestBitwiseRightShiftAPI.test_static_api_logical and test_dygraph_api_arithmetic

* test: use `x.__rrshift__(y, False)` to pass is_arithmetic

* test: keep it the same.

Use `__rshift__` instead of `__rrshift__`, and `__lshift__` instead of `__rlshift__`.

* chore: remove `out` and `name`

* fix: inherit directly from TestCase to run the tests for different dtypes

* test: add more tests in uint8, int8, and int16 for `__rlshift__` and `__rrshift__`

* typos: distinguish the TypeError messages for `'float' and 'Tensor'` vs `'Tensor' and 'float'`

* test: add tests for the `TypeError` raised on float operands

* typo: `bool` -> `float`
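A quick sketch of the behavior these commit messages describe, assuming a Paddle build that includes this change (run in dygraph mode; the values are illustrative):

import paddle

x = paddle.to_tensor([1, 2, 4], dtype='int32')

out = x << 2      # int operand is wrapped with paddle.to_tensor, arithmetic left shift
rout = 256 >> x   # reflected form dispatches to Tensor.__rrshift__

try:
    x << 1.5      # float operand is rejected
except TypeError as err:
    print(err)    # unsupported operand type(s) for <<: 'Tensor' and 'float'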
MrXnneHang authored Nov 15, 2024
1 parent 3a8618e commit 138ac11
Showing 3 changed files with 241 additions and 0 deletions.
8 changes: 8 additions & 0 deletions python/paddle/tensor/__init__.py
@@ -221,6 +221,10 @@
vstack,
)
from .math import ( # noqa: F401
__lshift__,
__rlshift__,
__rrshift__,
__rshift__,
abs,
abs_,
acos,
@@ -862,4 +866,8 @@
('__xor__', 'bitwise_xor'),
('__invert__', 'bitwise_not'),
('__pos__', 'positive'),
('__lshift__', '__lshift__'),
('__rshift__', '__rshift__'),
('__rlshift__', '__rlshift__'),
('__rrshift__', '__rrshift__'),
]
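These four entries extend the same mapping that already wires `__xor__` and `__invert__` to their bitwise kernels, so once the methods are patched onto Tensor the operator form and the explicit API agree. A minimal sketch of that equivalence (the check itself is illustrative, not part of this PR):

import numpy as np
import paddle

x = paddle.to_tensor([1, 2, 4], dtype='int32')
shift = paddle.to_tensor(1, dtype='int32')

# The operator goes through the mapping above to __lshift__ in math.py,
# which calls paddle.bitwise_left_shift under the hood.
np.testing.assert_array_equal(
    (x << shift).numpy(),
    paddle.bitwise_left_shift(x, shift).numpy(),
)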
57 changes: 57 additions & 0 deletions python/paddle/tensor/math.py
@@ -7769,6 +7769,63 @@ def bitwise_right_shift_(
return _C_ops.bitwise_right_shift_(x, y, is_arithmetic)


def __lshift__(
    x: Tensor,
    y: Tensor | int,
    is_arithmetic: bool = True,
) -> Tensor:
    if isinstance(y, int):
        y = paddle.to_tensor(y, dtype=x.dtype)
    elif isinstance(y, float):
        raise TypeError(
            "unsupported operand type(s) for <<: 'Tensor' and 'float'"
        )
    return bitwise_left_shift(x, y, is_arithmetic, None, None)


def __rshift__(
    x: Tensor,
    y: Tensor | int,
    is_arithmetic: bool = True,
) -> Tensor:
    if isinstance(y, int):
        y = paddle.to_tensor(y, dtype=x.dtype)
    elif isinstance(y, float):
        raise TypeError(
            "unsupported operand type(s) for >>: 'Tensor' and 'float'"
        )
    return bitwise_right_shift(x, y, is_arithmetic, None, None)


def __rlshift__(
    x: Tensor,
    y: Tensor | int,
    is_arithmetic: bool = True,
) -> Tensor:
    if isinstance(y, int):
        y = paddle.to_tensor(y, dtype=x.dtype)
    elif isinstance(y, float):
        raise TypeError(
            "unsupported operand type(s) for <<: 'float' and 'Tensor'"
        )
    return bitwise_left_shift(y, x, is_arithmetic, None, None)


def __rrshift__(
    x: Tensor,
    y: Tensor | int,
    is_arithmetic: bool = True,
) -> Tensor:
    if isinstance(y, int):
        y = paddle.to_tensor(y, dtype=x.dtype)
    elif isinstance(y, float):
        raise TypeError(
            "unsupported operand type(s) for >>: 'float' and 'Tensor'"
        )
    return bitwise_right_shift(y, x, is_arithmetic, None, None)


def copysign(x: Tensor, y: Tensor | float, name: str | None = None) -> Tensor:
r"""
Create a new floating-point tensor with the magnitude of input ``x`` and the sign of ``y``, elementwise.
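The tests below exercise both the operator form and the explicit call form that passes the `is_arithmetic` flag. A short sketch of the two call styles, assuming a dygraph session (results depend on the dtype width, so no values are asserted here):

import paddle

x = paddle.to_tensor([-8, 16], dtype='int32')
y = paddle.to_tensor([1, 2], dtype='int32')

arith = x >> y                    # operator form: arithmetic (sign-propagating) right shift
logical = x.__rshift__(y, False)  # explicit call: logical right shift via is_arithmetic=False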
176 changes: 176 additions & 0 deletions test/legacy_test/test_bitwise_shift_op.py
@@ -83,10 +83,13 @@ def test_static_api_arithmetic(self):
x,
y,
)
out_ = x << y
exe = paddle.static.Executor(self.place)
res = exe.run(feed={'x': self.x, 'y': self.y}, fetch_list=[out])
res_ = exe.run(feed={'x': self.x, 'y': self.y}, fetch_list=[out_])
out_ref = ref_left_shift_arithmetic(self.x, self.y)
np.testing.assert_allclose(out_ref, res[0])
np.testing.assert_allclose(out_ref, res_[0])

def test_dygraph_api_arithmetic(self):
paddle.disable_static()
@@ -96,8 +99,10 @@ def test_dygraph_api_arithmetic(self):
x,
y,
)
out_ = x << y
out_ref = ref_left_shift_arithmetic(self.x, self.y)
np.testing.assert_allclose(out_ref, out.numpy())
np.testing.assert_allclose(out_ref, out_.numpy())
paddle.enable_static()

def test_static_api_logical(self):
@@ -106,18 +111,23 @@ def test_static_api_logical(self):
x = paddle.static.data('x', self.x.shape, dtype=self.x.dtype)
y = paddle.static.data('y', self.y.shape, dtype=self.y.dtype)
out = paddle.bitwise_left_shift(x, y, False)
out_ = x.__lshift__(y, False)
exe = paddle.static.Executor(self.place)
res = exe.run(feed={'x': self.x, 'y': self.y}, fetch_list=[out])
res_ = exe.run(feed={'x': self.x, 'y': self.y}, fetch_list=[out_])
out_ref = ref_left_shift_logical(self.x, self.y)
np.testing.assert_allclose(out_ref, res[0])
np.testing.assert_allclose(out_ref, res_[0])

def test_dygraph_api_logical(self):
paddle.disable_static()
x = paddle.to_tensor(self.x)
y = paddle.to_tensor(self.y)
out = paddle.bitwise_left_shift(x, y, False)
out_ = x.__lshift__(y, False)
out_ref = ref_left_shift_logical(self.x, self.y)
np.testing.assert_allclose(out_ref, out.numpy())
np.testing.assert_allclose(out_ref, out_.numpy())
paddle.enable_static()


@@ -235,6 +245,67 @@ def init_input(self):
self.y = np.array([10], dtype='uint8')


class TestTensorRlshiftAPI(unittest.TestCase):
def setUp(self):
self.init_input()
self.place = (
paddle.CUDAPlace(0)
if paddle.is_compiled_with_cuda()
else paddle.CPUPlace()
)

def init_input(self):
self.x = np.random.randint(-255, 256)
self.y = np.random.randint(0, 256, [200, 300]).astype('int32')

def test_dygraph_tensor_rlshift(self):
paddle.disable_static()
x = self.x
y = paddle.to_tensor(self.y, dtype=self.y.dtype)
out = x << y
expected_out = x << y.numpy()
np.testing.assert_allclose(out.numpy(), expected_out)
paddle.enable_static()

def test_static_rlshift(self):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
x = self.x
y = paddle.static.data('y', self.y.shape, dtype=self.y.dtype)
out = x << y
exe = paddle.static.Executor(self.place)
res = exe.run(
feed={'x': self.x, 'y': self.y},
fetch_list=[out],
)
out_ref = ref_left_shift_arithmetic(self.x, self.y)
np.testing.assert_allclose(out_ref, res[0])


class TestTensorRlshiftAPI_UINT8(TestTensorRlshiftAPI):
def init_input(self):
self.x = np.random.randint(0, 64)
self.y = np.random.randint(0, 64, [200, 300]).astype('uint8')


class TestTensorRlshiftAPI_INT8(TestTensorRlshiftAPI):
def init_input(self):
self.x = np.random.randint(-64, 64)
self.y = np.random.randint(0, 64, [200, 300]).astype('int8')


class TestTensorRlshiftAPI_INT16(TestTensorRlshiftAPI):
def init_input(self):
self.x = np.random.randint(-256, 256)
self.y = np.random.randint(0, 256, [200, 300]).astype('int16')


class TestTensorRlshiftAPI_INT64(TestTensorRlshiftAPI):
def init_input(self):
self.x = np.random.randint(-255, 256)
self.y = np.random.randint(0, 256, [200, 300]).astype('int64')


class TestBitwiseRightShiftAPI(unittest.TestCase):
def setUp(self):
self.init_input()
@@ -257,10 +328,13 @@ def test_static_api_arithmetic(self):
x,
y,
)
out_ = x >> y
exe = paddle.static.Executor(self.place)
res = exe.run(feed={'x': self.x, 'y': self.y}, fetch_list=[out])
res_ = exe.run(feed={'x': self.x, 'y': self.y}, fetch_list=[out_])
out_ref = ref_right_shift_arithmetic(self.x, self.y)
np.testing.assert_allclose(out_ref, res[0])
np.testing.assert_allclose(out_ref, res_[0])

def test_dygraph_api_arithmetic(self):
paddle.disable_static()
@@ -270,8 +344,10 @@ def test_dygraph_api_arithmetic(self):
x,
y,
)
out_ = x >> y
out_ref = ref_right_shift_arithmetic(self.x, self.y)
np.testing.assert_allclose(out_ref, out.numpy())
np.testing.assert_allclose(out_ref, out_.numpy())
paddle.enable_static()

def test_static_api_logical(self):
@@ -280,18 +356,23 @@ def test_static_api_logical(self):
x = paddle.static.data('x', self.x.shape, dtype=self.x.dtype)
y = paddle.static.data('y', self.y.shape, dtype=self.y.dtype)
out = paddle.bitwise_right_shift(x, y, False)
out_ = x.__rshift__(y, False)
exe = paddle.static.Executor(self.place)
res = exe.run(feed={'x': self.x, 'y': self.y}, fetch_list=[out])
res_ = exe.run(feed={'x': self.x, 'y': self.y}, fetch_list=[out_])
out_ref = ref_right_shift_logical(self.x, self.y)
np.testing.assert_allclose(out_ref, res[0])
np.testing.assert_allclose(out_ref, res_[0])

def test_dygraph_api_logical(self):
paddle.disable_static()
x = paddle.to_tensor(self.x)
y = paddle.to_tensor(self.y)
out = paddle.bitwise_right_shift(x, y, False)
out_ = x.__rshift__(y, False)
out_ref = ref_right_shift_logical(self.x, self.y)
np.testing.assert_allclose(out_ref, out.numpy())
np.testing.assert_allclose(out_ref, out_.numpy())
paddle.enable_static()


@@ -409,6 +490,101 @@ def init_input(self):
self.y = np.array([10], dtype='uint8')


class TestTensorRrshiftAPI(unittest.TestCase):
def setUp(self):
self.init_input()
self.place = (
paddle.CUDAPlace(0)
if paddle.is_compiled_with_cuda()
else paddle.CPUPlace()
)

def init_input(self):
self.x = np.random.randint(-255, 256)
self.y = np.random.randint(0, 256, [200, 300]).astype('int32')

def test_dygraph_tensor_rrshift(self):
paddle.disable_static()
x = self.x
y = paddle.to_tensor(self.y, dtype=self.y.dtype)
out = x >> y
expected_out = x >> y.numpy()
np.testing.assert_allclose(out.numpy(), expected_out)
paddle.enable_static()

def test_static_rrshift(self):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
x = self.x
y = paddle.static.data('y', self.y.shape, dtype=self.y.dtype)
out = x >> y
exe = paddle.static.Executor(self.place)
res = exe.run(
feed={'x': self.x, 'y': self.y},
fetch_list=[out],
)
out_ref = ref_right_shift_arithmetic(self.x, self.y)
np.testing.assert_allclose(out_ref, res[0])


class TestTensorRrshiftAPI_UINT8(TestTensorRrshiftAPI):
def init_input(self):
self.x = np.random.randint(0, 64)
self.y = np.random.randint(0, 64, [200, 300]).astype('uint8')


class TestTensorRrshiftAPI_INT8(TestTensorRrshiftAPI):
def init_input(self):
self.x = np.random.randint(-64, 64)
self.y = np.random.randint(0, 64, [200, 300]).astype('int8')


class TestTensorRrshiftAPI_INT16(TestTensorRrshiftAPI):
def init_input(self):
self.x = np.random.randint(-256, 256)
self.y = np.random.randint(0, 256, [200, 300]).astype('int16')


class TestTensorRrshiftAPI_INT64(TestTensorRrshiftAPI):
def init_input(self):
self.x = np.random.randint(-255, 256)
self.y = np.random.randint(0, 256, [200, 300]).astype('int64')


class TestTensorShiftAPI_FLOAT(unittest.TestCase):
def setUp(self):
paddle.disable_static()
self.place = (
paddle.CUDAPlace(0)
if paddle.is_compiled_with_cuda()
else paddle.CPUPlace()
)

def test_lshift_float(self):
x = paddle.to_tensor(np.random.randint(-255, 256, [200, 300]))
y = np.random.uniform(0, 256)
with self.assertRaises(TypeError):
x.__lshift__(y)

def test_rshift_float(self):
x = paddle.to_tensor(np.random.randint(-255, 256, [200, 300]))
y = np.random.uniform(0, 256)
with self.assertRaises(TypeError):
x.__rshift__(y)

def test_rlshift_float(self):
x = np.random.uniform(0, 256)
y = paddle.to_tensor(np.random.randint(-255, 256, [200, 300]))
with self.assertRaises(TypeError):
y.__rlshift__(x)

def test_rrshift_float(self):
x = np.random.uniform(0, 256)
y = paddle.to_tensor(np.random.randint(-255, 256, [200, 300]))
with self.assertRaises(TypeError):
y.__rrshift__(x)


if __name__ == '__main__':
paddle.enable_static()
unittest.main()
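The dtype-specific test classes above only override `init_input`, so covering another configuration is a one-class change. A hypothetical sketch (this subclass is illustrative and not part of the PR):

class TestTensorRrshiftAPI_UINT8Small(TestTensorRrshiftAPI):
    # Hypothetical extra case: narrower operand range, same uint8 dtype.
    def init_input(self):
        self.x = np.random.randint(0, 8)
        self.y = np.random.randint(0, 8, [200, 300]).astype('uint8')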
