Skip to content

Latest commit

 

History

History
30 lines (22 loc) · 1.36 KB

File metadata and controls

30 lines (22 loc) · 1.36 KB

2.6 与NumPy的互通性

尽管我们认为有使用NumPy的经验不是阅读本书的先决条件,但由于NumPy在Python数据科学生态系统中无处不在,我们强烈建议你熟悉NumPy。PyTorch张量可以方便地转换为NumPy数组,反之亦然。这样,你就可以使用围绕NumPy数组类型构建的更广泛的Python生态系统中的大量功能。

与NumPy数组的这种零拷贝互通性是由于(PyTorch的)存储是遵守Python缓冲协议的。

要从points张量创建NumPy数组,请调用

points = torch.ones(3, 4)
points_np = points.numpy()
points_np

输出:

array([[1., 1., 1., 1.],
       [1., 1., 1., 1.],
       [1., 1., 1., 1.]], dtype=float32)

它返回尺寸、形状和数值类型正确的NumPy多维数组。有趣的是,返回的数组与张量存储共享一个基础缓冲区。因此,只要数据位于CPU RAM中,numpy方法就可以几乎零花费地高效执行,并且修改得到的NumPy数组会导致原始张量发生变化。

如果在GPU上分配了张量,(调用numpy方法时)PyTorch会将张量的内容复制到在CPU上分配的NumPy数组中。

相反,你可以通过以下方式从NumPy数组创建PyTorch张量:

points = torch.from_numpy(points_np)

from_numpy使用相同的缓冲共享策略。