Describe the bug
I am trying to deploy a BERT model I fine-tuned for text classification. The model artifacts are stored in S3. When I deploy the endpoint, the contents of my custom code/inference.py inside model.tar.gz are replaced with a different script, and the endpoint fails with a model load error in the logs.
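For reference, I deploy roughly like this (a minimal sketch — the S3 path, IAM role, and framework versions below are placeholders, not my exact values):

```python
from sagemaker.huggingface import HuggingFaceModel

# Placeholders below: substitute the real S3 path, IAM role, and versions
huggingface_model = HuggingFaceModel(
    model_data="s3://my-bucket/model.tar.gz",  # archive with code/inference.py inside
    role="arn:aws:iam::123456789012:role/SageMakerRole",
    transformers_version="4.26",
    pytorch_version="1.13",
    py_version="py39",
)

predictor = huggingface_model.deploy(
    initial_instance_count=1,
    instance_type="ml.g4dn.xlarge",
)

# The endpoint is then invoked with a JSON payload matching input_fn below
predictor.predict({"text": "example input"})
```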
To reproduce
Below is my custom inference.py:
```python
import os
import torch
import transformers
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import json
import logging

# Initialize logging
logging.basicConfig(level=logging.INFO)


def model_fn(model_dir):
    """Load model and tokenizer from the specified directory."""
    logging.info(f"Loading model and tokenizer from {model_dir}")
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = AutoModelForSequenceClassification.from_pretrained(model_dir)
    logging.info("Model and tokenizer loaded successfully")
    return model, tokenizer


def input_fn(request_body, request_content_type):
    """Parse input data from the request body."""
    if request_content_type == 'application/json':
        input_data = json.loads(request_body)
        return input_data['text']
    else:
        raise ValueError(f"Unsupported content type: {request_content_type}")


def predict_fn(input_data, model_tokenizer):
    """Make a prediction based on the input data and model."""
    model, tokenizer = model_tokenizer
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    try:
        # Tokenize input and move tensors to the model's device
        inputs = tokenizer(input_data, return_tensors="pt").to(device)
        # Perform prediction without tracking gradients
        with torch.no_grad():
            output = model(**inputs)
        return output
    except Exception as e:
        logging.error(f"Error during prediction: {str(e)}")
        raise ValueError(f"Error during prediction: {str(e)}")


def output_fn(prediction, content_type):
    """Convert the model prediction output to a JSON-serializable format."""
    if content_type == 'application/json':
        if hasattr(prediction, 'logits'):
            return json.dumps(prediction.logits[0].cpu().numpy().tolist())
        else:
            raise ValueError("Prediction output does not contain logits")
    else:
        raise ValueError(f"Unsupported content type: {content_type}")
```
It is being replaced by the following:
```python
import os
import torch
from safetensors.torch import load_file
from transformers import BertTokenizer


def model_fn(model_dir):
    # Load model
    model_path = os.path.join(model_dir, 'model.safetensors')
    model = load_file(model_path)
    model.eval()

    # Load tokenizer
    tokenizer_dir = os.path.join(model_dir, 'tokenizer')
    tokenizer = BertTokenizer.from_pretrained(tokenizer_dir)
    return model, tokenizer


def input_fn(request_body, request_content_type):
    # Assuming JSON input
    import json
    input_data = json.loads(request_body)
    return input_data['text']


def predict_fn(input_data, model_tokenizer):
    model, tokenizer = model_tokenizer
    # Tokenize input
    inputs = tokenizer(input_data, return_tensors="pt")
    # Perform prediction
    with torch.no_grad():
        output = model(**inputs)
    return output


def output_fn(prediction, content_type):
    # Assuming the model returns a tensor
    return prediction[0].numpy().tolist()
```
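Note that this replacement script cannot load the model at all, which is consistent with the error in the logs: safetensors.torch.load_file returns a plain state dict (a mapping of parameter names to tensors), not a model object, so the very first call on it fails:

```python
from safetensors.torch import load_file

state_dict = load_file("model.safetensors")  # dict: parameter name -> torch.Tensor
state_dict.eval()  # AttributeError: 'dict' object has no attribute 'eval'
```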