docs(xla): update xla doc
chuchaoqun@megvii.com committed Feb 22, 2024
1 parent d2a2d0f commit e2ee8c1
Showing 1 changed file with 137 additions and 46 deletions: source/user-guide/model-development/jit/xla.rst
The installation command for mge_xlalib is as follows:

.. code-block:: shell

   # cuda 11.8
   python3 -m pip install mge-xlalib==0.4.7+cuda11080.cudnn860 -f https://www.megengine.org.cn/whl/mge.html

During compilation and optimization, XLA relies on tools such as nvptx for runtime compilation, so the corresponding dependencies must be installed in the environment. For CUDA 11.8, NVIDIA supports installing them via pip:

.. code-block:: shell

   pip install "nvidia-cuda-cupti-cu11>=11.8" "nvidia-cuda-nvcc-cu11>=11.8" "nvidia-cuda-runtime-cu11>=11.8"

For CUDA 11.1 and CUDA 11.4, you need to install CUDA manually and add directories such as cuda/bin to PATH. In terms of both performance and convenience, CUDA 11.8 is therefore the recommended version for using mge-xla.
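
A minimal sketch of that manual setup, assuming CUDA is installed under ``/usr/local/cuda`` (the prefix is hypothetical; point it at your actual CUDA 11.1/11.4 installation):

.. code-block:: shell

   # hypothetical install prefix; adjust to your environment
   export CUDA_HOME=/usr/local/cuda
   export PATH="$CUDA_HOME/bin:$PATH"
   export LD_LIBRARY_PATH="$CUDA_HOME/lib64:$LD_LIBRARY_PATH"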

The XLA compiler is used much like the compiler built into the MegEngine graph runtime: wrap the training function with the xla_trace decorator provided by MegEngine. The first run of the function records the executed operator sequence to capture a static graph; subsequent runs compile that static graph with XLA and call the compiled XLA executable to accelerate training.
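
A minimal sketch of a function wrapped with xla_trace — the names ``xla_fused_softmax`` and ``inp`` come from the original (partially elided) example, and the softmax body is an assumption consistent with the IR dump below:

.. code-block:: python

   import megengine
   import megengine.functional as F
   from megengine.jit import xla_trace

   @xla_trace(capture_as_const=True)
   def xla_fused_softmax(inp):
       # softmax over the last axis, written out op by op so that the
       # elementwise and reduction kernels can be fused by xla
       x = inp - F.max(inp, axis=-1, keepdims=True)
       exp = F.exp(x)
       return exp / F.sum(exp, axis=-1, keepdims=True)

   inp = megengine.random.normal(size=(256, 1000, 1000))
   xla_fused_softmax(inp)  # first run: trace, then compile with xla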
   print(xla_fused_softmax(inp))  # run in xla

To inspect the intermediate IR produced by mge and xla during optimization, set the environment variable MGE_VERBOSE_XLA_IR. With MGE_VERBOSE_XLA_IR set to 1, the graph IR traced by mge is printed; with 2, the xla HLO graph is printed; with 3, the graph after xla's compilation optimizations is printed. For example, after ``export MGE_VERBOSE_XLA_IR=1``, running the code above prints:

.. code-block:: python

   please_realize_func_name_system_1(
       0%:<256x1000x1000,f32>
   ) {
       1%:<256x1000x1000,f32> = io_mark_var(0%:<256x1000x1000,f32>)
       2%:<256x1000x1,f32> = ReduceMAX(1%:<256x1000x1000,f32>)
       3%:<256x1000x1000,f32> = SUB(1%:<256x1000x1000,f32>, 2%:<256x1000x1,f32>)
       4%:<256x1000x1000,f32> = EXP(3%:<256x1000x1000,f32>)
       5%:<256x1000x1,f32> = ReduceSUM(4%:<256x1000x1000,f32>)
       6%:<256x1000x1000,f32> = TRUE_DIV(4%:<256x1000x1000,f32>, 5%:<256x1000x1,f32>)
       7%:<256x1000x1000,f32> = io_mark_var(6%:<256x1000x1000,f32>)
       return 1 7%:<256x1000x1000,f32>
   }
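
The other verbosity levels are selected the same way; a sketch, with a hypothetical script name:

.. code-block:: shell

   MGE_VERBOSE_XLA_IR=2 python3 train.py   # 1: mge IR, 2: xla hlo, 3: optimized hlo
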
When a training iteration is fully static, you can also hand the entire iteration over to XLA with the jit.xla_trace decorator. The optimizer and module must be passed to train_func as arguments, and train_func must contain the forward, backward, and parameter-update code. A code example:

.. code-block:: python
   :emphasize-lines: 8-17, 24

   import megengine as mge
   import megengine.functional as F
   from megengine.autodiff import GradManager
   from megengine.jit import xla_trace

   # when capture_as_const is True, every external tensor that is not in
   # the argument list of train_func is captured as a constant
   @xla_trace(capture_as_const=True)
   def train_func(data, label, *, opt, net):
       gm = GradManager()
       gm.attach(net.parameters())
       with gm:
           logits = net(data)
           loss = F.loss.cross_entropy(logits, label)
           gm.backward(loss)
       opt.step().clear_grad()
       return loss

   for epoch in range(total_epochs):
       total_loss = 0
       for step, (batch_data, batch_label) in enumerate(dataloader):
           data = mge.tensor(batch_data)
           label = mge.tensor(batch_label)
           loss = train_func(data, label, opt=optimizer, net=model)
           total_loss += loss.numpy().item()
       print("epoch: {}, loss {}".format(epoch, total_loss / len(dataloader)))

A more complete, runnable example follows: it trains a small ConvNet on two GPUs through the distributed launcher and traces the whole training step (note the additional ``without_host=True`` option):

.. code-block:: python
   :emphasize-lines: 46-53, 61

   from functools import partial

   import numpy as np

   import megengine
   import megengine.autodiff as autodiff
   import megengine.functional as F
   import megengine.module as M
   from megengine import distributed as dist
   from megengine.jit import partial_trace, xla_trace  # partial_trace is used in the next example
   from megengine.optimizer import AdamW


   class ConvNet(M.Module):
       def __init__(self):
           super().__init__()
           self.conv1 = M.Conv2d(3, 6, 5, bias=False)
           self.bn1 = M.BatchNorm2d(6)
           self.conv2 = M.Conv2d(6, 16, 5, bias=False)
           self.bn2 = M.BatchNorm2d(16)
           self.fc1 = M.Linear(16 * 5 * 5, 120)
           self.fc2 = M.Linear(120, 84)
           self.classifier = M.Linear(84, 10)
           self.pool = M.AvgPool2d(2, 2)

       def forward(self, x):
           x = self.pool(self.bn1(self.conv1(x)))
           x = self.pool(self.bn2(self.conv2(x)))
           x = F.flatten(x, 1)
           x = self.fc1(x)
           x = self.fc2(x)
           x = self.classifier(x)
           return x


   @dist.launcher(n_gpus=2, device_type="gpu")
   def worker():
       def runner():
           model = ConvNet()
           model.train()
           dist.bcast_list_(model.tensors())
           cblist = [dist.make_allreduce_cb("mean")]
           gm = autodiff.GradManager().attach(model.parameters(), callbacks=cblist)
           optimizer = AdamW(model.parameters(), lr=0.01)

           @xla_trace(without_host=True, capture_as_const=True)
           def func(model, optimizer, timage, tlabel):
               with gm:
                   score = model(timage)
                   loss = F.nn.cross_entropy(score, tlabel)
                   gm.backward(loss)
               optimizer.step().clear_grad()
               return loss

           # three batches of random data; float32/int32 to match what the net expects
           image = np.random.randn(3, 8, 3, 32, 32).astype("float32")
           label = np.random.randint(0, 10, (3, 8)).astype("int32")
           for i in range(6):
               timage = megengine.Tensor(image[i % 3])
               tlabel = megengine.Tensor(label[i % 3])
               loss = func(model, optimizer, timage, tlabel)
               print(loss)

       runner()


   worker()

.. _partial_trace:

When only part of the computation should be compiled, the partial_trace decorator lets you hand just that part (for example, the model's forward pass) to XLA while the rest keeps running as before. A code example:

.. code-block:: python
   :emphasize-lines: 6-8, 18

   import megengine as mge
   import megengine.functional as F
   from megengine.autodiff import GradManager
   from megengine.jit import partial_trace

   @partial_trace(backend="xla", capture_as_const=True)
   def backbone(model, inp):
       return model(inp)

   gm = GradManager()
   gm.attach(net.parameters())
   for epoch in range(total_epochs):
       total_loss = 0
       for step, (batch_data, batch_label) in enumerate(dataloader):
           data = mge.tensor(batch_data)
           label = mge.tensor(batch_label)
           with gm:
               logits = backbone(net, data)
               loss = F.loss.cross_entropy(logits, label)
               gm.backward(loss)
           opt.step().clear_grad()
           total_loss += loss.numpy().item()
       print("epoch: {}, loss {}".format(epoch, total_loss / len(dataloader)))

The same wrapping can be applied in the distributed setting, compiling only the module's forward and the optimizer's update step with partial_trace (this reuses ConvNet and the imports from the example above):

.. code-block:: python
   :emphasize-lines: 12-27

   @dist.launcher(n_gpus=2, device_type="gpu")
   def worker():
       def runner():
           model = ConvNet()
           model.train()
           dist.bcast_list_(model.tensors())
           cblist = [dist.make_allreduce_cb("mean")]
           gm = autodiff.GradManager().attach(model.parameters(), callbacks=cblist)
           optimizer = AdamW(model.parameters(), lr=0.01)

           # compile only the module forward and the optimizer update with xla
           model.forward = partial(
               partial_trace(
                   func=type(model).forward,
                   backend="xla",
                   capture_as_const=True,
               ),
               model,
           )
           optimizer._updates = partial(
               partial_trace(
                   func=type(optimizer)._updates,
                   backend="xla",
                   capture_as_const=True,
               ),
               optimizer,
           )

           image = np.random.randn(3, 8, 3, 32, 32).astype("float32")
           label = np.random.randint(0, 10, (3, 8)).astype("int32")
           for i in range(6):
               timage = megengine.Tensor(image[i % 3])
               tlabel = megengine.Tensor(label[i % 3])
               with gm:
                   score = model(timage)
                   loss = F.nn.cross_entropy(score, tlabel)
                   gm.backward(loss)
               optimizer.step().clear_grad()
               print(loss)

       runner()


   worker()
