Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Paddle Tensor 第二期 常用API复数类型支持 NO.6】 添加 full 函数复数类型支持 #70277

Open
wants to merge 6 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/manual_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@
#include "paddle/common/enforce.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_api.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/fluid/pir/dialect/operator/ir/tensorrt_op.h"
#include "paddle/phi/common/complex.h"
#include "paddle/pir/include/core/builtin_op.h"
#include "paddle/pir/include/core/parameter.h"

namespace paddle::dialect {

pir::Value builtin_combine(const std::vector<pir::Value>& x) {
Expand Down Expand Up @@ -54,6 +56,22 @@ std::vector<pir::Value> add_n_grad(const std::vector<pir::Value>& inputs,
return inputs_grad;
}

// Creates a complex-valued constant tensor of `shape` filled with
// (real + imag * i) on `place`.
//
// `dtype` must be a complex type: COMPLEX64 stores its components as
// FLOAT32, COMPLEX128 as FLOAT64. Any other dtype is rejected explicitly —
// the previous `else` branch silently mapped *every* non-COMPLEX64 dtype
// (including plain float types) to FLOAT64 components, producing an
// unrequested complex128 result.
pir::Value full(const std::vector<int64_t>& shape,
                double real,
                double imag,
                phi::DataType dtype,
                const phi::Place& place) {
  CheckDataType(dtype, "dtype", "full");
  phi::DataType component_dtype = phi::DataType::FLOAT64;
  if (dtype == phi::DataType::COMPLEX64) {
    component_dtype = phi::DataType::FLOAT32;
  } else if (dtype != phi::DataType::COMPLEX128) {
    PADDLE_THROW(common::errors::InvalidArgument(
        "full with (real, imag) only supports complex64 and complex128, "
        "but received dtype %s.",
        phi::DataTypeToString(dtype)));
  }
  // Build the real and imaginary planes via the scalar overload of full(),
  // then pack them into a single complex tensor.
  pir::Value real_part = full(shape, real, component_dtype, place);
  pir::Value imag_part = full(shape, imag, component_dtype, place);
  return paddle::dialect::complex(real_part, imag_part);
}

pir::Value zeros_like(const pir::Value& x,
const phi::DataType dtype,
const Place& place) {
Expand Down
6 changes: 6 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/manual_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,12 @@ pir::Value ones(const std::vector<int64_t>& shape,
phi::DataType dtype = phi::DataType::FLOAT32,
const Place& place = phi::CPUPlace());

// Creates a complex-valued constant tensor of `shape` filled with
// (real + imag * i) on `place`. `dtype` is expected to be a complex type
// (COMPLEX64 or COMPLEX128), whose component dtype the implementation
// selects for the real/imaginary planes.
// NOTE(review): the default of FLOAT32 is not a complex dtype — confirm
// against the definition how non-complex dtypes are handled here.
pir::Value full(const std::vector<int64_t>& shape,
                double real,
                double imag,
                phi::DataType dtype = phi::DataType::FLOAT32,
                const Place& place = phi::CPUPlace());

pir::Value ones_like(pir::Value x_,
phi::DataType dtype = phi::DataType::UNDEFINED,
const Place& place = {});
Expand Down
43 changes: 32 additions & 11 deletions paddle/fluid/pybind/manual_static_op_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -156,12 +156,23 @@ PyObject *static_api_full(PyObject *self, PyObject *args, PyObject *kwargs) {
!PyObject_CheckIRVectorOfValue(shape_obj) &&
!PyObject_CheckIRValue(value_obj)) {
std::vector<int64_t> shape = CastPyArg2Longs(shape_obj, "full", 0);
double value = CastPyArg2Double(value_obj, "full", 1);
CallStackRecorder callstack_recoder("full");
callstack_recoder.Record();
auto static_api_out = paddle::dialect::full(shape, value, dtype, place);
callstack_recoder.AttachToOps();
return ToPyObject(static_api_out);
if (PyComplex_Check(value_obj)) {
phi::dtype::complex<float> complex_value =
CastPyArg2Complex(value_obj, "full", 1);
CallStackRecorder callstack_recoder("full");
callstack_recoder.Record();
auto static_api_out = paddle::dialect::full(
shape, complex_value.real, complex_value.imag, dtype, place);
callstack_recoder.AttachToOps();
return ToPyObject(static_api_out);
} else {
double value = CastPyArg2Double(value_obj, "full", 1);
CallStackRecorder callstack_recoder("full");
callstack_recoder.Record();
auto static_api_out = paddle::dialect::full(shape, value, dtype, place);
callstack_recoder.AttachToOps();
return ToPyObject(static_api_out);
}
} else {
pir::Value shape, value;

Expand All @@ -180,11 +191,21 @@ PyObject *static_api_full(PyObject *self, PyObject *args, PyObject *kwargs) {
if (PyObject_CheckIRValue(value_obj)) {
value = CastPyArg2Value(value_obj, "full", 1, false);
} else {
double value_tmp = CastPyArg2Double(value_obj, "full", 1);
value = paddle::dialect::full(std::vector<int64_t>{1},
value_tmp,
phi::DataType::FLOAT32,
phi::CPUPlace());
if (PyComplex_Check(value_obj)) {
phi::dtype::complex<float> complex_value_tmp =
CastPyArg2Complex(value_obj, "full", 1);
value = paddle::dialect::full(std::vector<int64_t>{1},
complex_value_tmp.real,
complex_value_tmp.imag,
dtype,
place);
} else {
double value_tmp = CastPyArg2Double(value_obj, "full", 1);
value = paddle::dialect::full(std::vector<int64_t>{1},
value_tmp,
phi::DataType::FLOAT32,
phi::CPUPlace());
}
}

CallStackRecorder callstack_recoder("full_with_tensor");
Expand Down
6 changes: 5 additions & 1 deletion python/paddle/tensor/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from __future__ import annotations

import builtins
import math
import re
import warnings
Expand Down Expand Up @@ -1489,7 +1490,10 @@ def full(
"""

if dtype is None:
dtype = paddle.get_default_dtype()
if isinstance(fill_value, (builtins.complex)):
dtype = "complex128"
else:
dtype = paddle.get_default_dtype()
Copy link
Contributor Author

@MrXnneHang MrXnneHang Dec 17, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

按照规范化测试,full 默认 dtype 应该是 float64/int64/complex128。
但是似乎这个会让很多旧有的引用到full测试失败,大多数是调用时没有指定dtype最终导致与要求输入类型不匹配引起的。
image
特别是原本似乎存在这样的Bug。不指定bool的话True会被默认转成float,这就有点离谱了。

我先把我自己的代码测一遍,之后再考虑default dtype的问题。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

如果需要修改默认dtype,特别是bool的,我可以再添加。
我应该可以修复遗留的test问题。应该只要给没有指定 dtype 的调用加上一个dtype="float32"就行。

Copy link
Contributor

@HydrogenSulfate HydrogenSulfate Dec 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

按照规范化测试,full 默认 dtype 应该是 float64/int64/complex128。 但是似乎这个会让很多旧有的引用到 full 的测试失败,大多数是调用时没有指定 dtype,最终导致与要求输入类型不匹配引起的。 image 特别是原本似乎存在这样的 Bug:不指定 bool 的话 True 会被默认转成 float,这就有点离谱了。

我先把我自己的代码测一遍,之后再考虑default dtype的问题。

这个默认的dtype建议不要改动,否则会引起已有模型的兼容性问题,如果存在兼容性问题,我这边可以选择跳过这个单测。bool转float的错误情况除外

Copy link
Contributor Author

@MrXnneHang MrXnneHang Dec 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

那么默认 dtype 我就不动了。我去把bool加上去。


return fill_constant(shape=shape, dtype=dtype, value=fill_value, name=name)

Expand Down
163 changes: 162 additions & 1 deletion test/legacy_test/test_full_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,60 @@ def test_api(self):
)
out_8 = paddle.full(shape=10, dtype=np.float32, fill_value=val)

out_9 = paddle.full(
shape=10, dtype="complex64", fill_value=1.1 + 1.1j
)

out_10 = paddle.full(
shape=10, dtype="complex128", fill_value=1.1 + 1.1j
)

out_11 = paddle.full(
shape=10, dtype="complex64", fill_value=1.1 + np.inf * 1j
)

out_12 = paddle.full(
shape=10, dtype="complex128", fill_value=1.1 + np.inf * 1j
)

out_13 = paddle.full(
shape=10, dtype="complex64", fill_value=1.1 - np.inf * 1j
)

out_14 = paddle.full(
shape=10, dtype="complex128", fill_value=1.1 - np.inf * 1j
)

out_15 = paddle.full(
shape=10, dtype="complex64", fill_value=1.1 + np.nan * 1j
)

out_16 = paddle.full(
shape=10, dtype="complex128", fill_value=1.1 + np.nan * 1j
)

out_17 = paddle.full(shape=10, fill_value=1.1 + 1.1j)

exe = base.Executor(place=base.CPUPlace())
res_1, res_2, res_3, res_4, res_5, res_6, res_7, res_8 = exe.run(
(
res_1,
res_2,
res_3,
res_4,
res_5,
res_6,
res_7,
res_8,
res_9,
res_10,
res_11,
res_12,
res_13,
res_14,
res_15,
res_16,
res_17,
) = exe.run(
paddle.static.default_main_program(),
feed={
"shape_tensor_int32": np.array([1, 2]).astype("int32"),
Expand All @@ -83,6 +135,15 @@ def test_api(self):
out_6,
out_7,
out_8,
out_9,
out_10,
out_11,
out_12,
out_13,
out_14,
out_15,
out_16,
out_17,
],
)

Expand Down Expand Up @@ -110,6 +171,37 @@ def test_api(self):
np.testing.assert_array_equal(
res_8, np.full([10], 1.1, dtype="float32")
)
np.testing.assert_allclose(
res_9, np.full([10], 1.1 + 1.1j, dtype="complex64")
)
np.testing.assert_allclose(
res_10, np.full([10], 1.1 + 1.1j, dtype="complex128")
)
np.testing.assert_allclose(
res_9, np.full([10], 1.1 + 1.1j, dtype="complex64")
)
np.testing.assert_allclose(
res_10, np.full([10], 1.1 + 1.1j, dtype="complex128")
)
np.testing.assert_allclose(
res_11, np.full([10], 1.1 + np.inf * 1j, dtype="complex64")
)
np.testing.assert_allclose(
res_12, np.full([10], 1.1 + np.inf * 1j, dtype="complex128")
)
np.testing.assert_allclose(
res_13, np.full([10], 1.1 - np.inf * 1j, dtype="complex64")
)
np.testing.assert_allclose(
res_14, np.full([10], 1.1 - np.inf * 1j, dtype="complex128")
)
np.testing.assert_allclose(
res_15, np.full([10], 1.1 + np.nan * 1j, dtype="complex64")
)
np.testing.assert_allclose(
res_16, np.full([10], 1.1 + np.nan * 1j, dtype="complex128")
)
np.testing.assert_allclose(res_17, np.full([10], 1.1 + 1.1j))
paddle.disable_static()

def test_api_eager(self):
Expand Down Expand Up @@ -166,6 +258,47 @@ def test_api_eager(self):

out_11 = paddle.full(shape=10, dtype="float32", fill_value=1.1)

out_12 = paddle.full(
shape=[1, 2, 3], dtype="complex64", fill_value=1.1 + 1.1j
)

out_13 = paddle.full(
shape=[1, 2, 3], dtype="complex128", fill_value=1.1 + 1.1j
)

out_14 = paddle.full(
shape=[1, 2, 3], dtype="complex64", fill_value=1.1 + np.inf * 1j
)

out_15 = paddle.full(
shape=[1, 2, 3],
dtype="complex128",
fill_value=1.1 + np.inf * 1j,
)

out_16 = paddle.full(
shape=[1, 2, 3], dtype="complex64", fill_value=1.1 - np.inf * 1j
)

out_17 = paddle.full(
shape=[1, 2, 3],
dtype="complex128",
fill_value=1.1 - np.inf * 1j,
)

out_18 = paddle.full(
shape=[1, 2, 3], dtype="complex64", fill_value=1.1 + np.nan * 1j
)

out_19 = paddle.full(
shape=[1, 2, 3],
dtype="complex128",
fill_value=1.1 + np.nan * 1j,
)

# test without dtype input for complex
out_20 = paddle.full(shape=[1, 2, 3], fill_value=1.1 + 1.1j)

np.testing.assert_array_equal(
out_1, np.full([1, 2], 1.1, dtype="float32")
)
Expand Down Expand Up @@ -199,6 +332,34 @@ def test_api_eager(self):
np.testing.assert_array_equal(
out_11, np.full([10], 1.1, dtype="float32")
)
np.testing.assert_allclose(
out_12, np.full([1, 2, 3], 1.1 + 1.1j, dtype="complex64")
)
np.testing.assert_allclose(
out_13, np.full([1, 2, 3], 1.1 + 1.1j, dtype="complex128")
)
np.testing.assert_allclose(
out_14, np.full([1, 2, 3], 1.1 + np.inf * 1j, dtype="complex64")
)
np.testing.assert_allclose(
out_15,
np.full([1, 2, 3], 1.1 + np.inf * 1j, dtype="complex128"),
)
np.testing.assert_allclose(
out_16, np.full([1, 2, 3], 1.1 - np.inf * 1j, dtype="complex64")
)
np.testing.assert_allclose(
out_17,
np.full([1, 2, 3], 1.1 - np.inf * 1j, dtype="complex128"),
)
np.testing.assert_allclose(
out_18, np.full([1, 2, 3], 1.1 + np.nan * 1j, dtype="complex64")
)
np.testing.assert_allclose(
out_19,
np.full([1, 2, 3], 1.1 + np.nan * 1j, dtype="complex128"),
)
np.testing.assert_allclose(out_20, np.full([1, 2, 3], 1.1 + 1.1j))


class TestFullOpError(unittest.TestCase):
Expand Down