尽管我们认为有使用NumPy的经验不是阅读本书的先决条件,但由于NumPy在Python数据科学生态系统中无处不在,我们强烈建议你熟悉NumPy。PyTorch张量可以方便地转换为NumPy数组,反之亦然。这样,你就可以使用围绕NumPy数组类型构建的更广泛的Python生态系统中的大量功能。
与NumPy数组的这种零拷贝互通性是由于(PyTorch的)存储是遵守Python缓冲协议的。
要从points
张量创建NumPy数组,请调用
points = torch.ones(3, 4)
points_np = points.numpy()
points_np
输出:
array([[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.]], dtype=float32)
它返回尺寸、形状和数值类型正确的NumPy多维数组。有趣的是,返回的数组与张量存储共享一个基础缓冲区。因此,只要数据位于CPU RAM中,numpy
方法就可以几乎零花费地高效执行,并且修改得到的NumPy数组会导致原始张量发生变化。
如果在GPU上分配了张量,(调用numpy
方法时)PyTorch会将张量的内容复制到在CPU上分配的NumPy数组中。
相反,你可以通过以下方式从NumPy数组创建PyTorch张量:
points = torch.from_numpy(points_np)
from_numpy
使用相同的缓冲共享策略。