Merge pull request #46 from DelaramRajaei/main
Merging ReQue and RePair Projects
DelaramRajaei authored Nov 17, 2023
2 parents f084fd9 + 56fb084 commit 3789932
Showing 51 changed files with 4,083 additions and 115 deletions.
2 changes: 1 addition & 1 deletion environment.yml
@@ -1,4 +1,4 @@
-name: repair
+name: Repair
channels:
- conda-forge
- pytorch
@@ -0,0 +1,167 @@
import random, os, numpy as np, platform
import torch

random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)
np.random.seed(0)
extension = '.exe' if platform.system() == 'Windows' else ""
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

settings = {
'query_refinement': True,
'cmd': ['search', 'eval', 'agg', 'box'], # pipeline steps; full set: ['pair', 'finetune', 'predict', 'search', 'eval', 'agg', 'box', 'dense_retrieve', 'stats']
'ncore': 1,
't5model': 'small.local', # 'base.gc' on Google Cloud TPU, 'small.local' on a local machine
'iter': 5, # number of finetuning iterations for t5
'nchanges': 5, # number of changes to a query
'ranker': 'bm25', # 'qld', 'bm25', 'tct_colbert'
'batch': None, # search per batch of queries for IR search using pyserini, if None, search per query
'topk': 100, # number of retrieved documents for a query
'metric': 'map', # any valid trec_eval.9.0.4 metric like map, ndcg, recip_rank, ...
'large_ds': False,
'treclib': f'"./trec_eval.9.0.4/trec_eval{extension}"', # the .exe suffix is dropped on non-Windows; set to 'pytrec' to use pytrec_eval instead
'box': {'gold': 'refined_q_metric >= original_q_metric and refined_q_metric > 0',
'platinum': 'refined_q_metric > original_q_metric',
'diamond': 'refined_q_metric > original_q_metric and refined_q_metric == 1'}
}

corpora = {
'msmarco.passage': {
'index_item': ['passage'],
'index': '../data/raw/msmarco.passage/lucene-index.msmarco-v1-passage.20220131.9ea315/',
'dense_encoder': 'castorini/tct_colbert-msmarco',
'dense_index': 'msmarco-passage-tct_colbert-hnsw',
'extcorpus': 'orcas',
'pairing': [None, 'docs', 'query'], # [context, input, output]; context=None since msmarco has no user info; input/output come from {query, doc, doc(s)}; doc(s) means concatenation of docs
'lseq': {"inputs": 32, "targets": 256}, # query length and doc length for t5 model,
},
'aol-ia': {
'index_item': ['title'], # ['url'], ['title', 'url'], ['title', 'url', 'text']
'index': '../data/raw/aol-ia/lucene-index/title/', # change based on index_item
'dense_index': '../data/raw/aol-ia/dense-index/tct_colbert.title/', # change based on index_item
'dense_encoder':'../data/raw/aol-ia/dense-encoder/tct_colbert.title/', # change based on index_item
'pairing': [None, 'docs', 'query'], # [context, input, output]; two scenarios, with and without userID; input={'userid','query','doc(s)'}, output={'query','doc(s)'}
'extcorpus': 'msmarco.passage',
'lseq': {"inputs": 32, "targets": 256}, # query length and doc length for t5 model,
'filter': {'minql': 1, 'mindocl': 10} # [min query length, min doc length]; after merging queries with the relevant 'index_item', drop the row if |query| <= minql or |'index_item'| < mindocl
},
'robust04': {
'index': '../data/raw/robust04/lucene-index.robust04.pos+docvectors+rawdocs',
'dense_index': '../data/raw/robust04/faiss_index_robust04',
'encoded': '../data/raw/robust04/encoded_robust04',
'size': 528155,
'topics': '../data/raw/robust04/topics.robust04.txt',
'prels': '', # this will be generated after a retrieval {bm25, qld}
'w_t': 2.25, # OnFields
'w_a': 1, # OnFields
'tokens': 148000000,
'qrels': '../data/raw/robust04/qrels.robust04.txt',
'extcorpus': 'gov2', # AdaptOnFields
'pairing': [None, None, None],
'index_item': [],
},
'gov2': {
'index': '../data/raw/gov2/lucene-index.gov2.pos+docvectors+rawdocs',
'size': 25000000,
'topics': '../data/raw/gov2/topics.terabyte0{}.txt', # {} is a placeholder for subtopics in main.py -> run()
'trec': ['4.701-750', '5.751-800', '6.801-850'],
'prels': '', # this will be generated after a retrieval {bm25, qld}
'w_t': 4, # OnFields
'w_a': 0.25, # OnFields
'tokens': 17000000000,
'qrels': '../data/raw/gov2/qrels.terabyte0{}.txt', # {} is a placeholder for subtopics in main.py -> run()
'extcorpus': 'robust04', # AdaptOnFields
'pairing': [None, None, None],
'index_item': [],
},
'clueweb09b': {
'index': '../data/raw/clueweb09b/lucene-index.cw09b.pos+docvectors+rawdocs',
'size': 50000000,
'topics': '../data/raw/clueweb09b/topics.web.{}.txt', # {} is a placeholder for subtopics in main.py -> run()
'trec': ['1-50', '51-100', '101-150', '151-200'],
'prels': '', # this will be generated after a retrieval {bm25, qld}
'w_t': 1, # OnFields
'w_a': 0, # OnFields
'tokens': 31000000000,
'qrels': '../data/raw/clueweb09b/qrels.web.{}.txt', # {} is a placeholder for subtopics in main.py -> run()
'extcorpus': 'gov2', # AdaptOnFields
'pairing': [None, None, None],
'index_item': [],
},
'clueweb12b13': {
'index': '../data/raw/clueweb12b13/lucene-index.cw12b13.pos+docvectors+rawdocs',
'size': 50000000,
'topics': '../data/raw/clueweb12b13/topics.web.{}.txt', # {} is a placeholder for subtopics in main.py -> run()
'trec': ['201-250', '251-300'],
'prels': '', # this will be generated after a retrieval {bm25, qld}
'w_t': 4, # OnFields
'w_a': 0, # OnFields
'tokens': 31000000000,
'qrels': '../data/raw/clueweb12b13/qrels.web.{}.txt', # {} is a placeholder for subtopics in main.py -> run()
'extcorpus': 'gov2', # AdaptOnFields
'pairing': [None, None, None],
'index_item': [],
},
'antique': {
'index': '../data/raw/antique/lucene-index-antique',
'size': 403000,
'topics': '../data/raw/antique/topics.antique.txt',
'prels': '', # this will be generated after a retrieval {bm25, qld}
'w_t': 2.25, # OnFields # to be tuned
'w_a': 1, # OnFields # to be tuned
'tokens': 16000000,
'qrels': '../data/raw/antique/qrels.antique.txt',
'extcorpus': 'gov2', # AdaptOnFields
'pairing': [None, None, None],
'index_item': [],
},
'trec09mq': {
'index': '../data/raw/clueweb09b/lucene-index.cw09b.pos+docvectors+rawdocs',
'size': 50000000,
# 'topics': '../ds/trec2009mq/prep/09.mq.topics.20001-60000.prep.tsv',
'topics': '../data/raw/trec09mq/09.mq.topics.20001-60000.prep',
'prels': '', # this will be generated after a retrieval {bm25, qld}
'w_t': 2.25, # OnFields # to be tuned
'w_a': 1, # OnFields # to be tuned
'tokens': 16000000,
'qrels': '../data/raw/trec09mq/prels.20001-60000.prep',
'extcorpus': 'gov2', # AdaptOnFields
'pairing': [None, None, None],
'index_item': [],
},
'dbpedia': {
'index': '../data/raw/dbpedia/lucene-index-dbpedia-collection',
'size': 4632359,
'topics': '../data/raw/dbpedia/topics.dbpedia.txt',
'prels': '', # this will be generated after a retrieval {bm25, qld}
'w_t': 1, # OnFields # to be tuned
'w_a': 1, # OnFields # to be tuned
'tokens': 200000000,
'qrels': '../data/raw/dbpedia/qrels.dbpedia.txt',
'extcorpus': 'gov2', # AdaptOnFields
'pairing': [None, None, None],
'index_item': [],
},
'orcas': {
'index': '../data/raw/orcas/lucene-index.msmarco-v1-doc.20220131.9ea315',
'size': 50000000,
# 'topics': '../data/raw/trec2009mq/prep/09.mq.topics.20001-60000.prep.tsv',
'topics': '../data/raw/orcas/preprocess/orcas-I-2M_topics.prep',
'prels': '', # this will be generated after a retrieval {bm25, qld}
'w_t': 2.25, # OnFields # to be tuned
'w_a': 1, # OnFields # to be tuned
'tokens': 16000000,
'qrels': '../data/raw/orcas/preprocess/orcas-doctrain-qrels.prep',
'extcorpus': 'gov2', # AdaptOnFields
'pairing': [None, None, None],
'index_item': [],
},
}

# Only for sparse indexing
anserini = {
'path': '../anserini/',
'trec_eval': '../anserini/eval/trec_eval.9.0.4/trec_eval'
}
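
The gold/platinum/diamond rules in settings['box'] above label each refined query by comparing its retrieval effectiveness against the original query's. Below is a minimal sketch of how such rule strings could be evaluated with pandas; the DataFrame and its values are made up for illustration, and only the column names are taken from the rule strings themselves, so this is an assumption about usage rather than the project's actual evaluation code.

import pandas as pd

# Hypothetical per-query scores; only the column names mirror the rule strings above.
scores = pd.DataFrame({
    'qid': [1, 2, 3, 4],
    'original_q_metric': [0.20, 0.00, 0.35, 0.50],
    'refined_q_metric': [0.20, 0.40, 1.00, 0.10],
})

box = {
    'gold': 'refined_q_metric >= original_q_metric and refined_q_metric > 0',
    'platinum': 'refined_q_metric > original_q_metric',
    'diamond': 'refined_q_metric > original_q_metric and refined_q_metric == 1',
}

# DataFrame.query accepts these boolean expressions directly, so each rule
# selects the subset of queries that falls into its box.
for label, rule in box.items():
    print(label, scores.query(rule)['qid'].tolist())
# gold [1, 2, 3], platinum [2, 3], diamond [3]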

27 changes: 21 additions & 6 deletions src/dal/aol.py
@@ -3,14 +3,15 @@
from shutil import copyfile
from ftfy import fix_text
from pyserini.search.lucene import LuceneSearcher
-from dal.ds import Dataset
+from ds import Dataset
+from query import Query
tqdm.pandas()


class Aol(Dataset):

-def __init__(self, settings, homedir, ncore):
-try: super(Aol, self).__init__(settings=settings)
+def __init__(self, settings, domain, homedir, ncore):
+try: super(Aol, self).__init__(settings=settings, domain=domain)
except: self._build_index(homedir, Dataset.settings['index_item'], Dataset.settings['index'], ncore)

@classmethod
@@ -22,7 +23,7 @@ def _build_index(cls, homedir, index_item, indexdir, ncore):
index_item_str = '.'.join(index_item)
if not os.path.isdir(f'{indexdir}/{cls.user_pairing}{index_item_str}'): os.makedirs(f'{indexdir}/{cls.user_pairing}{index_item_str}')
import ir_datasets
-from cmn.lucenex import lucenex
+from src.cmn.lucenex import lucenex

print(f"Setting up aol corpus using ir-datasets at {homedir}...")
aolia = ir_datasets.load("aol-ia")
@@ -48,6 +49,17 @@ def _build_index(cls, homedir, index_item, indexdir, ncore):
# for d in os.listdir(homedir):
# if not (d.find('aol-ia') > -1) and os.path.isdir(f'./../data/raw/{d}'): shutil.rmtree(f'./../data/raw/{d}')

def read_queries(cls, input, domain):
queries = pd.read_csv(f'{input}/queries.train.tsv', encoding='UTF-8', sep='\t', index_col=False, names=['qid', 'query'], converters={'query': str.lower}, header=None)
# note: the column order in the file is [qid, uid, did, rel]
qrels = pd.read_csv(f'{input}/qrels.train.tsv', encoding='UTF-8', sep='\t', index_col=False, names=['qid', 'uid', 'did', 'rel'], header=None)
qrels.to_csv(f'{input}/qrels.train.tsv_', index=False, sep='\t', header=False)
# docid is a hash of the URL and qid is a hash of the *normalised query*, so two uids may share the same qid and the same docid.
qrels.drop_duplicates(subset=['qid', 'did', 'uid'], inplace=True)
queries_qrels = pd.merge(queries, qrels, on='qid', how='inner', copy=False)
queries_qrels = queries_qrels.sort_values(by='qid')
cls.create_query_objects(queries_qrels, ['qid', 'uid', 'did', 'rel'], domain)

@classmethod
def create_jsonl(cls, aolia, index_item, output):
"""
@@ -69,19 +81,21 @@ def create_jsonl(cls, aolia, index_item, output):
output_jsonl_file.close()

@classmethod
-def pair(cls, input, output, cat=True):
+def pair(cls, queries, output, cat=True):
+# TODO: change the code in a way to use read_queries
queries = pd.read_csv(f'{input}/queries.train.tsv', encoding='UTF-8', sep='\t', index_col=False, names=['qid', 'query'], converters={'query': str.lower}, header=None)
# note: the column order in the file is [qid, uid, did, rel]
qrels = pd.read_csv(f'{input}/qrels.train.tsv', encoding='UTF-8', sep='\t', index_col=False, names=['qid', 'uid', 'did', 'rel'], header=None)
# docid is a hash of the URL and qid is a hash of the *normalised query*, so two uids may share the same qid and the same docid.
qrels.drop_duplicates(subset=['qid', 'did', 'uid'], inplace=True)
queries_qrels = pd.merge(queries, qrels, on='qid', how='inner', copy=False)

doccol = 'docs' if cat else 'doc'
del queries
# in the event of user oriented pairing, we simply concat qid and uid
if cls.user_pairing: queries_qrels['qid'] = queries_qrels['qid'].astype(str) + "_" + queries_qrels['uid'].astype(str)
queries_qrels = queries_qrels.astype('category')
-queries_qrels[doccol] = queries_qrels['did'].progress_apply(cls._txt)
+queries_qrels[doccol] = queries_qrels['did'].apply(cls._txt)

# queries_qrels.drop_duplicates(subset=['qid', 'did','pid'], inplace=True) # two users with same click for same query
if not cls.user_pairing: queries_qrels['uid'] = -1
@@ -111,3 +125,4 @@ def pair(cls, input, output, cat=True):
qrels_splits.to_csv(f'../output/aol-ia/{cls.user_pairing}t5.base.gc.docs.query.{index_item_str}/qrels/qrels.splits.{_}.tsv_', sep='\t',
encoding='utf-8', index=False, header=False, columns=['qid', 'uid', 'did', 'rel'])
return queries_qrels
pass
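
A note on the pairing logic above: qid is a hash of the normalised query and did a hash of the URL, so the same (qid, did) pair can occur for several users; the code deduplicates on ['qid', 'did', 'uid'] and, when user-oriented pairing is on, rewrites qid as '<qid>_<uid>' so per-user pairs stay distinct. The toy sketch below replays only those steps on made-up data; the file paths, real schemas, and the class's cls.user_pairing machinery are not reproduced.

import pandas as pd

queries = pd.DataFrame({'qid': ['q1', 'q2'], 'query': ['cheap flights', 'aol search logs']})
# toy qrels: one fully duplicated row, plus two users clicking the same doc for the same query
qrels = pd.DataFrame({
    'qid': ['q1', 'q1', 'q1', 'q2'],
    'uid': ['u1', 'u1', 'u2', 'u3'],
    'did': ['d9', 'd9', 'd9', 'd7'],
    'rel': [1, 1, 1, 1],
})

qrels = qrels.drop_duplicates(subset=['qid', 'did', 'uid'])      # as in read_queries()/pair()
queries_qrels = pd.merge(queries, qrels, on='qid', how='inner')  # attach the query text to each click

user_pairing = True  # stands in for cls.user_pairing
if user_pairing:
    queries_qrels['qid'] = queries_qrels['qid'].astype(str) + '_' + queries_qrels['uid'].astype(str)

print(queries_qrels[['qid', 'uid', 'did', 'rel', 'query']])
# q1_u1 and q1_u2 remain separate rows; the exact duplicate (q1, u1, d9) appears only once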
