diff --git a/extension/training/examples/XOR/export_model.py b/extension/training/examples/XOR/export_model.py index c2cff7d428..a245361e18 100644 --- a/extension/training/examples/XOR/export_model.py +++ b/extension/training/examples/XOR/export_model.py @@ -24,7 +24,7 @@ def _export_model(): # Captures the forward graph. The graph will look similar to the model definition now. # Will move to export_for_training soon which is the api planned to be supported in the long term. - ep = export(net, (x, torch.ones(1, dtype=torch.int64))) + ep = export(net, (x, torch.ones(1, dtype=torch.int64)), strict=True) # Captures the backward graph. The exported_program now contains the joint forward and backward graph. ep = _export_forward_backward(ep) # Lower the graph to edge dialect.