Commit abdfbf0
remove resize code from prev. trained model - Did not work with all model types
nreimers committed Apr 21, 2021
1 parent 7f62f4b commit abdfbf0
Showing 1 changed file with 1 addition and 14 deletions.
15 changes: 1 addition & 14 deletions sentence_transformers/cross_encoder/CrossEncoder.py
@@ -43,23 +43,11 @@ def __init__(self, model_name:str, num_labels:int = None, max_length:int = None,
         if num_labels is None and not classifier_trained:
             num_labels = 1
 
-        resize_num_labels = None
         if num_labels is not None:
-            if hasattr(self.config, 'num_labels') and self.config.num_labels is not None and self.config.num_labels != num_labels:
-                #Resize classifier head
-                resize_num_labels = num_labels
-            else:
-                self.config.num_labels = num_labels
+            self.config.num_labels = num_labels
 
         self.model = AutoModelForSequenceClassification.from_pretrained(model_name, config=self.config)
         self.tokenizer = AutoTokenizer.from_pretrained(model_name, **tokenizer_args)
-
-        if resize_num_labels is not None:
-            print("Warning: Loaded model was trained for {} labels. Resize CrossEncoder to {} labels. You must re-train this model to get meaningful predictions".format(self.config.num_labels, resize_num_labels))
-            self.config.num_labels = resize_num_labels
-            self.model.config.num_labels = resize_num_labels
-            self.model.classifier = torch.nn.Linear(self.config.hidden_size, resize_num_labels)
-
         self.max_length = max_length
 
         if device is None:
@@ -214,7 +202,6 @@ def fit(self,
                 logits = activation_fct(model_predictions.logits)
                 if self.config.num_labels == 1:
                     logits = logits.view(-1)
-
                 loss_value = loss_fct(logits, labels)
                 loss_value.backward()
                 torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_grad_norm)
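The deleted branch swapped in a fresh torch.nn.Linear at self.model.classifier, which likely failed for architectures whose classification head has a different attribute name (BERT and RoBERTa use classifier, but BART exposes classification_head and XLNet logits_proj) or whose config lacks hidden_size. Below is a minimal sketch of one way to get a resized head without touching model internals; the helper name load_with_num_labels is hypothetical, and ignore_mismatched_sizes is a from_pretrained argument from later transformers releases than this commit targets, so treat it as an assumption rather than the fix used here.

    # Sketch only: override the label count without hard-coding the head attribute.
    # Assumes a transformers release that supports ignore_mismatched_sizes.
    from transformers import AutoConfig, AutoModelForSequenceClassification

    def load_with_num_labels(model_name: str, num_labels: int):
        config = AutoConfig.from_pretrained(model_name)
        config.num_labels = num_labels  # request a head with `num_labels` outputs
        # ignore_mismatched_sizes=True re-initializes any checkpoint weight whose
        # shape disagrees with the config (here: the classification head), so the
        # head's attribute name never has to be known in advance.
        return AutoModelForSequenceClassification.from_pretrained(
            model_name, config=config, ignore_mismatched_sizes=True
        )

As with the removed warning, a head loaded this way is randomly initialized, so the model must be re-trained before its predictions are meaningful.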
