Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/TheBotiverse/Botiverse into…
Browse files Browse the repository at this point in the history
… main
  • Loading branch information
YousefAtefB committed Jul 11, 2023
2 parents afb0828 + e4fcd93 commit 1abeea5
Show file tree
Hide file tree
Showing 42 changed files with 5,463 additions and 175 deletions.
5 changes: 5 additions & 0 deletions botiverse/Theorizer/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
Datasets/*
**.txt
**.json
models/pretrained-model/*
**.bin
Empty file.
1 change: 1 addition & 0 deletions botiverse/Theorizer/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from botiverse.Theorizer import generate
27 changes: 27 additions & 0 deletions botiverse/Theorizer/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
'''
Common constants
'''
import os

THEORIZER_PATH= os.path.abspath(os.path.dirname(__file__))
BOTIVERSE_PATH=None
SQUAD_TRAIN_PATH=None

QUESTION_TYPES = [
"Who", "Where", "When", "Why", "Which", "What", "How", "Yes-No", "Other"]
INFO_QUESTION_TYPES = [
"Who", "Where", "When", "Why", "Which", "What", "How"]
YES_NO_QUESTION_TYPES = [
"Am", "Is", "Was", "Were", "Are",
"Does", "Do", "Did",
"Have", "Had", "Has",
"Could", "Can",
"Shall", "Should",
"Will", "Would",
"May", "Might"]
Q_TYPE2ID_DICT = {
"What": 0, "Who": 1, "How": 2,
"Where": 3, "When": 4, "Why": 5,
"Which": 6, "Boolean": 7, "Other": 8}

STOP_WORDS = []
94 changes: 94 additions & 0 deletions botiverse/Theorizer/generate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import torch
from botiverse.Theorizer.model.finetuned_model import MyGPT2LMHeadModel
from botiverse.Theorizer.model.dataloader import SPECIAL_TOKENS_DICT
from botiverse.Theorizer.squad.sample_data import select_with_default_sampel_probs
from transformers import GPT2Tokenizer


def __prepare(context):
sampled_infos = select_with_default_sampel_probs(context)
instances = []
for info in sampled_infos["selected_infos"]:
for style in info["styles"]:
for clue in info["clues"]:
instances.append(
{
"paragraph": sampled_infos["context"],
"clue": clue.clue_text,
"answer": info["answer"]["answer_text"],
"style": style,
}
)
return instances


def generate(context, max_length=50):
# Load the fine-tuned model and tokenizer
model_path_or_name = "botiverse/Theorizer/model/pretrained-model"

model = MyGPT2LMHeadModel.from_pretrained(
model_path_or_name, ignore_mismatched_sizes=True
)
tokenizer = GPT2Tokenizer.from_pretrained(
model_path_or_name, **SPECIAL_TOKENS_DICT
)

instances = __prepare(context)
# print(instances)
qa_dict = {}
qa_dict["context"] = context
qa_dict["qa"] = set()
for inst in instances:
paragraph = inst["paragraph"]
clue = inst["clue"]
answer = inst["answer"]
style = inst["style"]
input_sequence = (
"<sos> "
+ paragraph
+ " <clue> "
+ clue
+ " <answer> "
+ answer
+ " <style> "
+ style
+ " <question> "
+ style
)

# Tokenize the input sequence
input_ids = tokenizer.encode(input_sequence, return_tensors="pt")
with torch.no_grad():
# Generate the question
generated = model.generate(
input_ids=input_ids,
max_length=max_length,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
num_beams=5,
no_repeat_ngram_size=2,
early_stopping=True,
temperature=0.7,
)
question_start_index = (
input_ids[0].tolist().index(
tokenizer.convert_tokens_to_ids("<question>"))
)

# Slice the generated tensor to exclude the input sequence
generated_question = generated[0, question_start_index + 1:]

# Decode the generated question
question = tokenizer.decode(
generated_question, skip_special_tokens=True)

qa_dict["qa"].add((question, answer))
qa_dict["qa"] = list(qa_dict["qa"])
return qa_dict


if __name__ == "__main__":
context = "Bob is eating a delicious cake in Vancouver."
qa_dict = generate(context)
import json
print(json.dumps(qa_dict,indent=4))
Empty file.
Loading

0 comments on commit 1abeea5

Please sign in to comment.