Merge pull request #371 from Winter523/master
add inference/run_classifier_mt_infer.py
Showing 1 changed file with 179 additions and 0 deletions.
""" | ||
This script provides an example to wrap UER-py for multi-task classification inference. | ||
""" | ||
import sys
import os
import torch
import argparse
import collections
import torch.nn as nn

uer_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.append(uer_dir)

from uer.embeddings import *
from uer.encoders import *
from uer.utils.constants import *
from uer.utils import *
from uer.utils.config import load_hyperparam
from uer.utils.seed import set_seed
from uer.utils.misc import pooling
from uer.model_loader import *
from uer.opts import infer_opts, tokenizer_opts, log_opts

class MultitaskClassifier(nn.Module):
    def __init__(self, args):
        super(MultitaskClassifier, self).__init__()
        self.embedding = Embedding(args)
        for embedding_name in args.embedding:
            tmp_emb = str2embedding[embedding_name](args, len(args.tokenizer.vocab))
            self.embedding.update(tmp_emb, embedding_name)
        self.encoder = str2encoder[args.encoder](args)
        self.pooling_type = args.pooling
        # One two-layer classification head per task: a hidden-size projection
        # followed by a task-specific output layer.
        self.output_layers_1 = nn.ModuleList([nn.Linear(args.hidden_size, args.hidden_size) for _ in args.labels_num_list])
        self.output_layers_2 = nn.ModuleList([nn.Linear(args.hidden_size, labels_num) for labels_num in args.labels_num_list])

    def forward(self, src, tgt, seg, soft_tgt=None):
        """
        Args:
            src: [batch_size x seq_length]
            tgt: [batch_size]
            seg: [batch_size x seq_length]
        Note: tgt and soft_tgt are unused during inference.
        """
        # Embedding.
        emb = self.embedding(src, seg)
        # Encoder.
        memory_bank = self.encoder(emb, seg)
        # Target.
        memory_bank = pooling(memory_bank, seg, self.pooling_type)
        logits = []
        for i in range(len(self.output_layers_1)):
            output_i = torch.tanh(self.output_layers_1[i](memory_bank))
            logits_i = self.output_layers_2[i](output_i)
            logits.append(logits_i)

        return None, logits
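
# Shape sketch (hypothetical sizes): with hidden_size=768 and
# labels_num_list=[3, 2], a batch of 8 inputs produces
# logits = [tensor of shape (8, 3), tensor of shape (8, 2)],
# i.e. one logits tensor per task.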


def read_dataset(args, path):
    dataset, columns = [], {}
    with open(path, mode="r", encoding="utf-8") as f:
        for line_id, line in enumerate(f):
            if line_id == 0:
                line = line.rstrip("\r\n").split("\t")
                for i, column_name in enumerate(line):
                    columns[column_name] = i
                continue
            line = line.rstrip("\r\n").split("\t")
            if "text_b" not in columns:  # Sentence classification.
                text_a = line[columns["text_a"]]
                src = args.tokenizer.convert_tokens_to_ids([CLS_TOKEN] + args.tokenizer.tokenize(text_a) + [SEP_TOKEN])
                seg = [1] * len(src)
            else:  # Sentence pair classification.
                text_a, text_b = line[columns["text_a"]], line[columns["text_b"]]
                src_a = args.tokenizer.convert_tokens_to_ids([CLS_TOKEN] + args.tokenizer.tokenize(text_a) + [SEP_TOKEN])
                src_b = args.tokenizer.convert_tokens_to_ids(args.tokenizer.tokenize(text_b) + [SEP_TOKEN])
                src = src_a + src_b
                seg = [1] * len(src_a) + [2] * len(src_b)

            if len(src) > args.seq_length:
                src = src[: args.seq_length]
                seg = seg[: args.seq_length]
            PAD_ID = args.tokenizer.convert_tokens_to_ids([PAD_TOKEN])[0]
            while len(src) < args.seq_length:
                src.append(PAD_ID)
                seg.append(0)
            dataset.append((src, seg))

    return dataset
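
# Input sketch (hypothetical file content): a tab-separated file whose first
# line names the columns, e.g.
#
#   text_a
#   this movie is great
#
# Each remaining line becomes a (src, seg) pair: src holds the token ids of
# [CLS] text_a [SEP] (plus text_b [SEP] for sentence pairs), truncated to
# seq_length and padded with PAD; seg marks the first segment with 1, the
# second with 2, and padding with 0.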


def batch_loader(batch_size, src, seg):
    instances_num = src.size()[0]
    for i in range(instances_num // batch_size):
        src_batch = src[i * batch_size : (i + 1) * batch_size, :]
        seg_batch = seg[i * batch_size : (i + 1) * batch_size, :]
        yield src_batch, seg_batch
    if instances_num > instances_num // batch_size * batch_size:
        src_batch = src[instances_num // batch_size * batch_size :, :]
        seg_batch = seg[instances_num // batch_size * batch_size :, :]
        yield src_batch, seg_batch
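
# Example (hypothetical sizes): with 10 instances and batch_size 4,
# batch_loader yields slices of 4, 4, and 2 rows; the final partial
# batch is kept rather than dropped.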


def main():
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    infer_opts(parser)

    tokenizer_opts(parser)

    parser.add_argument("--output_logits", action="store_true", help="Write logits to output file.")
    parser.add_argument("--output_prob", action="store_true", help="Write probabilities to output file.")
    parser.add_argument("--labels_num_list", default=[], nargs='+', type=int, help="Dataset labels num list.")
    log_opts(parser)

    args = parser.parse_args()

    # Load the hyperparameters from the config file.
    args = load_hyperparam(args)

    # Build tokenizer.
    args.tokenizer = str2tokenizer[args.tokenizer](args)

    # Build classification model and load parameters.
    args.soft_targets, args.soft_alpha = False, False
    model = MultitaskClassifier(args)
    model = load_model(model, args.load_model_path)

    # For simplicity, we use DataParallel wrapper to use multiple GPUs.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    if torch.cuda.device_count() > 1:
        print("{0} GPUs are available. Let's use them.".format(torch.cuda.device_count()))
        model = torch.nn.DataParallel(model)

    dataset = read_dataset(args, args.test_path)

    src = torch.LongTensor([sample[0] for sample in dataset])
    seg = torch.LongTensor([sample[1] for sample in dataset])

    batch_size = args.batch_size
    instances_num = src.size()[0]

    print("The number of prediction instances: {0}".format(instances_num))

    model.eval()

    with open(args.prediction_path, mode="w", encoding="utf-8") as f:
        f.write("label")
        if args.output_logits:
            f.write("\t" + "logits")
        if args.output_prob:
            f.write("\t" + "prob")
        f.write("\n")
        for i, (src_batch, seg_batch) in enumerate(batch_loader(batch_size, src, seg)):
            src_batch = src_batch.to(device)
            seg_batch = seg_batch.to(device)
            with torch.no_grad():
                _, logits = model(src_batch, None, seg_batch)

            pred = [torch.argmax(logits_i, dim=-1) for logits_i in logits]
            prob = [nn.Softmax(dim=-1)(logits_i) for logits_i in logits]

            logits = [x.cpu().numpy().tolist() for x in logits]
            pred = [x.cpu().numpy().tolist() for x in pred]
            prob = [x.cpu().numpy().tolist() for x in prob]

            for j in range(len(pred[0])):
                f.write("|".join([str(v[j]) for v in pred]))
                if args.output_logits:
                    f.write("\t" + "|".join([" ".join(["{0:.4f}".format(w) for w in v[j]]) for v in logits]))
                if args.output_prob:
                    f.write("\t" + "|".join([" ".join(["{0:.4f}".format(w) for w in v[j]]) for v in prob]))
                f.write("\n")


if __name__ == "__main__":
    main()