Hi, I'm trying to use torch2trt on a simple example, but I receive a strange segfault on nn.Linear. Details below.
Libs:
tensorrt-cu12==10.0.1 tensorrt-cu12-bindings==10.0.1 tensorrt-cu12-libs==10.0.1 torch==2.0.0 torch2trt==0.5.0
Code:
```python
class Example(nn.Module):
    def __init__(
        self,
        c_in: int,
        d_in: int = 2,
        out_channels: int = 9,
        n_layer: int = 3,
        n_head: int = 4,
        n_embd: int = 256,
        n_tokens: int = 32,
        bias: bool = False,  # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
    ):
        super().__init__()
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.pe = nn.Embedding(n_tokens, n_embd)
        self.encoder_c = nn.Linear(c_in, n_embd)

    def forward(self, features_c, d_emb, K_caches, V_caches):
        features_c = features_c.unsqueeze(0)  # 1, X, 1, c_in
        d_emb = d_emb.unsqueeze(0)  # 1, X, H
        T = K_caches.shape[3] + 1
        print(features_c.device, d_emb.device, K_caches.device, V_caches.device, self.encoder_c.weight.device)
        return features_c  # return here works fine
        features_emb = self.encoder_c(features_c)
        return features_emb  # return here fails with segfault

...

n_x = 180
n_tokens = 32
c_in = 339
d_in = 4

model = Example(c_in=c_in, d_in=d_in)
model = model.to("cuda:0").eval()
model_trt = torch2trt(
    mlp_last,
    [x_last.to("cuda:0"), d_emb.to("cuda:0"), k_cache.to("cuda:0"), v_cache.to("cuda:0")],
    log_level=trt.Logger.VERBOSE,
)
```
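For what it's worth, the `encoder_c` projection itself is fine in plain PyTorch eager mode; a minimal CPU-only sanity check (devices therefore differ from the report above, and the shapes are taken from `n_x = 180` and `c_in = 339`):

```python
import torch
import torch.nn as nn

# Minimal eager-mode check of the same Linear projection on CPU.
# This isolates nn.Linear from the torch2trt conversion path; the
# segfault in the report occurs only during conversion, not here.
c_in, n_embd = 339, 256
encoder_c = nn.Linear(c_in, n_embd)

features_c = torch.randn(180, 1, c_in).unsqueeze(0)  # 1, X, 1, c_in
features_emb = encoder_c(features_c)
print(features_emb.shape)  # torch.Size([1, 180, 1, 256])
```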
Hi @brazhenko ,
Thanks for reaching out!
Does this issue occur if you call torch2trt(..., use_onnx=True)?
John