diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml new file mode 100644 index 0000000..5ae5460 --- /dev/null +++ b/.github/workflows/testing.yml @@ -0,0 +1,21 @@ +name: testing +on: [push] +env: + APPLICATION_NAME : WORKFLOW +jobs: + lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - name: Setup Python environment + uses: actions/setup-python@v4 + with: + python-version: '3.8' + - name: Install requirements + run : pip install --quiet --requirement testing_reqs.txt + - name: Lint code + run: | + flake8 --ignore=E117,E127,E128,E231,E401,E501,E722,E701,E704,F401,F523,F841 . --exclude src/cair/,src/mdl/ + # pylint --disable=C0301 --disable=C0326 *.py + # - name: Run unit tests + # run: python -m unittest --verbose --failfast diff --git a/output/toy.aol-ia/t5.small.local.docs.query.title/param.py b/output/toy.aol-ia/t5.small.local.docs.query.title/param.py index 9827d77..00a1c97 100644 --- a/output/toy.aol-ia/t5.small.local.docs.query.title/param.py +++ b/output/toy.aol-ia/t5.small.local.docs.query.title/param.py @@ -9,27 +9,27 @@ os.environ['CUDA_VISIBLE_DEVICES'] = '-1' settings = { - 'cmd': ['agg', 'box'],# steps of pipeline, ['pair', 'finetune', 'predict', 'search', 'eval','agg', 'box'] + 'cmd': ['agg', 'box'], # steps of pipeline, ['pair', 'finetune', 'predict', 'search', 'eval','agg', 'box'] 'ncore': multiprocessing.cpu_count(), - 't5model': 'small.local',#'base.gc', 'small.local' - 'iter': 5, #number of finetuning iteration for t5 - 'nchanges': 5, #number of changes to a query - 'ranker': 'bm25', #'qld', 'bm25' - 'batch': None, #search per batch of queries for IR search using pyserini, if None, search per query - 'topk': 10, #number of retrieved documents for a query + 't5model': 'small.local', # 'base.gc', 'small.local' + 'iter': 5, # number of finetuning iteration for t5 + 'nchanges': 5, # number of changes to a query + 'ranker': 'bm25', # 'qld', 'bm25' + 'batch': None, # search per batch of queries for IR search using pyserini, if None, search per query + 'topk': 10, # number of retrieved documents for a query 'metric': 'map', # any valid trec_eval.9.0.4 metric like map, ndcg, recip_rank, ... - 'treclib': f'"./trec_eval.9.0.4/trec_eval{extension}"',#in non-windows, remove .exe, also for pytrec_eval, 'pytrec' + 'treclib': f'"./trec_eval.9.0.4/trec_eval{extension}"', # in non-windows, remove .exe, also for pytrec_eval, 'pytrec' 'msmarco.passage': { 'index_item': ['passage'], 'index': '../data/raw/msmarco.passage/lucene-index.msmarco-v1-passage.20220131.9ea315/', - 'pairing': [None, 'docs', 'query'], #[context={msmarco does not have userinfo}, input={query, doc, doc(s)}, output={query, doc, doc(s)}], s means concat of docs - 'lseq':{"inputs": 32, "targets": 256}, #query length and doc length for t5 model, + 'pairing': [None, 'docs', 'query'], # [context={msmarco does not have userinfo}, input={query, doc, doc(s)}, output={query, doc, doc(s)}], s means concat of docs + 'lseq':{"inputs": 32, "targets": 256}, # query length and doc length for t5 model, }, 'aol-ia': { - 'index_item': ['title'], # ['url'], ['title', 'url'], ['title', 'url', 'text'] - 'index': f'../data/raw/aol-ia/lucene-index/title/', - 'pairing': [None, 'docs', 'query'], #[context={2 scenarios, one with userID and one without userID). input={'userid','query','doc(s)'} output={'query','doc(s)'} - 'lseq':{"inputs": 32, "targets": 256}, #query length and doc length for t5 model, - 'filter': {'minql': 1, 'mindocl': 10}# [min query length, min doc length], after merge queries with relevant 'index_item', if |query| <= minql drop the row, if |'index_item'| < mindocl, drop row + 'index_item': ['title'], # ['url'], ['title', 'url'], ['title', 'url', 'text'] + 'index': '../data/raw/aol-ia/lucene-index/title/', + 'pairing': [None, 'docs', 'query'], # [context={2 scenarios, one with userID and one without userID). input={'userid','query','doc(s)'} output={'query','doc(s)'} + 'lseq':{"inputs": 32, "targets": 256}, # query length and doc length for t5 model, + 'filter': {'minql': 1, 'mindocl': 10} # [min query length, min doc length], after merge queries with relevant 'index_item', if |query| <= minql drop the row, if |'index_item'| < mindocl, drop row } -} \ No newline at end of file +} diff --git a/output/toy.msmarco.passage/t5.small.local.docs.query.passage/param.py b/output/toy.msmarco.passage/t5.small.local.docs.query.passage/param.py index bcf2d43..8967320 100644 --- a/output/toy.msmarco.passage/t5.small.local.docs.query.passage/param.py +++ b/output/toy.msmarco.passage/t5.small.local.docs.query.passage/param.py @@ -18,7 +18,7 @@ 'batch': 100, # search per batch of queries for IR search using pyserini, if None, search per query 'topk': 10, # number of retrieved documents for a query 'metric': 'map', # any valid trec_eval.9.0.4 metric like map, ndcg, recip_rank, ... - 'treclib': f'"./trec_eval.9.0.4/trec_eval{extension}"',#in non-windows, remove .exe, also for pytrec_eval, 'pytrec' + 'treclib': f'"./trec_eval.9.0.4/trec_eval{extension}"', # in non-windows, remove .exe, also for pytrec_eval, 'pytrec' 'msmarco.passage': { 'index_item': ['passage'], 'index': '../data/raw/msmarco.passage/lucene-index.msmarco-v1-passage.20220131.9ea315/', @@ -26,7 +26,7 @@ 'lseq': {"inputs": 32, "targets": 256}, # query length and doc length for t5 model, }, 'aol-ia': { - 'index_item': ['title'], # ['url'], ['title', 'url'], ['title', 'url', 'text'] + 'index_item': ['title'], # ['url'], ['title', 'url'], ['title', 'url', 'text'] 'index': '../data/raw/aol-ia/lucene-index/title/', 'pairing': [None, 'docs', 'query'], # [context={2 scenarios, one with userID and one without userID). input={'userid','query','doc(s)'} output={'query','doc(s)'} 'lseq': {"inputs": 32, "targets": 256}, # query length and doc length for t5 model, diff --git a/requirements.txt b/requirements.txt index a65e71a..83a4f84 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,3 +10,5 @@ pandas pyserini ir_datasets filesplit +flake8 +pylint diff --git a/src/cmn/lucenex.py b/src/cmn/lucenex.py index f985cda..5eac79d 100644 --- a/src/cmn/lucenex.py +++ b/src/cmn/lucenex.py @@ -1,5 +1,6 @@ import sys, subprocess, os + def lucenex(corpus, output, ncore): """ common code to create index using the subprocess module @@ -13,4 +14,4 @@ def lucenex(corpus, output, ncore): '--index', output, '--generator', 'DefaultLuceneDocumentGenerator', '--threads', str(ncore), '--storePositions', '--storeDocvectors', '--storeRaw', '--optimize']) - print(f'Finished creating index.') + print('Finished creating index.') diff --git a/src/cmn/refiner.py b/src/cmn/refiner.py index 4429a79..fe90264 100644 --- a/src/cmn/refiner.py +++ b/src/cmn/refiner.py @@ -1,9 +1,10 @@ import pandas as pd -#creates a train_test_split using pandas +# creates a train_test_split using pandas datasets = ['diamond', 'platinum', 'gold'] -def train_test_split(input,train_split = 0.8): + +def train_test_split(input,train_split=0.8): for ds in datasets: refiner_ds = pd.read_csv(f'{input}/{ds}.tsv', sep='\t', encoding='utf-8', names=['qid', 'query', 'map', 'query_', 'map_']) train = refiner_ds.sample(frac=train_split, random_state=200) @@ -11,4 +12,3 @@ def train_test_split(input,train_split = 0.8): train.to_csv(f'{input}/{ds}.train.tsv', sep='\t', index=False, header=False, columns=['query', 'query_']) test.to_csv(f'{input}/{ds}.test.tsv', sep='\t', index=False, header=False, columns=['query', 'query_']) print(f'saving {ds} with {train_split * 100}% train split and {int(1 - train_split) * 100}% test split at {input} ') - diff --git a/src/dal/aol.py b/src/dal/aol.py index 273352c..0d9f1e7 100644 --- a/src/dal/aol.py +++ b/src/dal/aol.py @@ -2,11 +2,10 @@ from tqdm import tqdm from shutil import copyfile from ftfy import fix_text -tqdm.pandas() - from pyserini.search.lucene import LuceneSearcher - from dal.ds import Dataset +tqdm.pandas() + class Aol(Dataset): @@ -16,8 +15,8 @@ def __init__(self, settings, homedir, ncore): @classmethod def _build_index(cls, homedir, index_item, indexdir, ncore): - print(f"Creating index from scratch using ir-dataset ...") - #https://github.com/allenai/ir_datasets + print("Creating index from scratch using ir-dataset ...") + # https://github.com/allenai/ir_datasets os.environ['IR_DATASETS_HOME'] = '/'.join(homedir.split('/')[:-1]) if not os.path.isdir(os.environ['IR_DATASETS_HOME']): os.makedirs(os.environ['IR_DATASETS_HOME']) index_item_str = '.'.join(index_item) @@ -30,12 +29,12 @@ def _build_index(cls, homedir, index_item, indexdir, ncore): print('Getting queries and qrels ...') # the column order in the file is [qid, uid, did, uid]!!!! STUPID!! qrels = pd.DataFrame.from_records(aolia.qrels_iter(), columns=['qid', 'did', 'rel', 'uid'], nrows=1) # namedtuple - queries = pd.DataFrame.from_records(aolia.queries_iter(), columns=['qid', 'query'], nrows=1)# namedtuple + queries = pd.DataFrame.from_records(aolia.queries_iter(), columns=['qid', 'query'], nrows=1) # namedtuple print('Creating jsonl collections for indexing ...') print(f'Raw documents should be downloaded already at {homedir}/aol-ia/downloaded_docs/ as explained here: https://github.com/terrierteam/aolia-tools') - print(f'But it had bugs: https://github.com/allenai/ir_datasets/issues/222') - print(f'Sean MacAvaney provided us with the downloaded_docs.tar file. Thanks Sean!') + print('But it had bugs: https://github.com/allenai/ir_datasets/issues/222') + print('Sean MacAvaney provided us with the downloaded_docs.tar file. Thanks Sean!') Aol.create_jsonl(aolia, index_item, f'{homedir}/{cls.user_pairing}{index_item_str}') if len(os.listdir(f'{indexdir}/{cls.user_pairing}{index_item_str}')) == 0: @@ -87,7 +86,7 @@ def pair(cls, input, output, cat=True): # queries_qrels.drop_duplicates(subset=['qid', 'did','pid'], inplace=True) # two users with same click for same query if not cls.user_pairing: queries_qrels['uid'] = -1 queries_qrels['ctx'] = '' - queries_qrels.dropna(inplace=True) #empty doctxt, query, ... + queries_qrels.dropna(inplace=True) # empty doctxt, query, ... queries_qrels.drop(queries_qrels[queries_qrels['query'].str.strip().str.len() <= Dataset.settings['filter']['minql']].index, inplace=True) queries_qrels.drop(queries_qrels[queries_qrels[doccol].str.strip().str.len() < Dataset.settings["filter"]['mindocl']].index, inplace=True) # remove qrels whose docs are less than mindocl queries_qrels.drop_duplicates(subset=['qid', 'did'], inplace=True) @@ -100,7 +99,7 @@ def pair(cls, input, output, cat=True): if cls.user_pairing: qrels = pd.read_csv(f'{input}/{cls.user_pairing}qrels.train.tsv_', sep='\t', index_col=False, names=['qid', 'uid', 'did', 'rel']) batch_size = 1000000 # need to make this dynamic index_item_str = '.'.join(cls.settings['index_item']) - ## create dirs: + # create dirs: if not os.path.isdir(f'../output/aol-ia/{cls.user_pairing}t5.base.gc.docs.query.{index_item_str}/original_test_queries'): os.makedirs(f'../output/aol-ia/{cls.user_pairing}t5.base.gc.docs.query.{index_item_str}/original_test_queries') if not os.path.isdir(f'../output/aol-ia/{cls.user_pairing}t5.base.gc.docs.query.{index_item_str}/qrels'): os.makedirs(f'../output/aol-ia/{cls.user_pairing}t5.base.gc.docs.query.{index_item_str}/qrels') if len(queries_qrels) > batch_size: @@ -112,4 +111,3 @@ def pair(cls, input, output, cat=True): qrels_splits.to_csv(f'../output/aol-ia/{cls.user_pairing}t5.base.gc.docs.query.{index_item_str}/qrels/qrels.splits.{_}.tsv_', sep='\t', encoding='utf-8', index=False, header=False, columns=['qid', 'uid', 'did', 'rel']) return queries_qrels - diff --git a/src/dal/ds.py b/src/dal/ds.py index 377bc62..ad7b6f2 100644 --- a/src/dal/ds.py +++ b/src/dal/ds.py @@ -5,6 +5,7 @@ from pyserini.search.lucene import LuceneSearcher from pyserini.search.faiss import FaissSearcher, TctColBertQueryEncoder + class Dataset(object): searcher = None settings = None @@ -26,7 +27,7 @@ def _txt(cls, pid): # The``docid`` is overloaded: if it is of type ``str``, it is treated as an external collection ``docid``; # if it is of type ``int``, it is treated as an internal Lucene``docid``. # stupid!! try:return json.loads(cls.searcher.doc(str(pid)).raw())['contents'].lower() - except AttributeError: return '' #if Dataset.searcher.doc(str(pid)) is None + except AttributeError: return '' # if Dataset.searcher.doc(str(pid)) is None except Exception as e: raise e @classmethod @@ -87,7 +88,6 @@ def _docids(row): hits = cls.searcher.search(row.query, k=topk, remove_dups=True) for i, h in enumerate(hits): o.write(f'{qids[row.name]}\tQ0\t{h.docid:7}\t{i + 1:2}\t{h.score:.5f}\tPyserini\n') - queries.progress_apply(_docids, axis=1) @classmethod @@ -124,7 +124,6 @@ def aggregate(cls, original, changes, output, is_large_ds=False): agg_all.write(f'{row.qid}\t{change}\t{query}\t{metric_value}\n') if metric_value > 0 and metric_value >= row[f'original.{ranker}.{metric}']: agg_gold.write(f'{row.qid}\t{change}\t{query}\t{metric_value}\n') - @classmethod def box(cls, input, qrels, output, checks): ranker = input.columns[-1].split('.')[0] # e.g., bm25.success.10 => bm25 @@ -153,8 +152,3 @@ def box(cls, input, qrels, output, checks): print(f'{c} has {df.shape[0]} queries') qrels = df.merge(qrels, on='qid', how='inner') qrels.to_csv(f'{output}/{c}.qrels.tsv', sep='\t', encoding='utf-8', index=False, header=False, columns=['qid', 'did', 'pid', 'rel']) - - - - - diff --git a/src/dal/msmarco.py b/src/dal/msmarco.py index e6287f2..c6ac671 100644 --- a/src/dal/msmarco.py +++ b/src/dal/msmarco.py @@ -2,9 +2,9 @@ from os.path import isfile,join import pandas as pd from tqdm import tqdm +from dal.ds import Dataset tqdm.pandas() -from dal.ds import Dataset class MsMarcoPsg(Dataset): @@ -15,12 +15,11 @@ def pair(cls, input, output, cat=True): queries = pd.read_csv(f'{input}/queries.train.tsv', sep='\t', index_col=False, names=['qid', 'query'], converters={'query': str.lower}, header=None) qrels = pd.read_csv(f'{input}/qrels.train.tsv', sep='\t', index_col=False, names=['qid', 'did', 'pid', 'relevancy'], header=None) qrels.drop_duplicates(subset=['qid', 'pid'], inplace=True) # qrels have duplicates!! - qrels.to_csv(f'{input}/qrels.train.tsv_', index=False, sep='\t', header=False) #trec_eval.9.0.4 does not accept duplicate rows!! + qrels.to_csv(f'{input}/qrels.train.tsv_', index=False, sep='\t', header=False) # trec_eval.9.0.4 does not accept duplicate rows!! queries_qrels = pd.merge(queries, qrels, on='qid', how='inner', copy=False) doccol = 'docs' if cat else 'doc' - queries_qrels[doccol] = queries_qrels['pid'].progress_apply(cls._txt) #100%|██████████| 532761/532761 [00:32<00:00, 16448.77it/s] + queries_qrels[doccol] = queries_qrels['pid'].progress_apply(cls._txt) # 100%|██████████| 532761/532761 [00:32<00:00, 16448.77it/s] queries_qrels['ctx'] = '' if cat: queries_qrels = queries_qrels.groupby(['qid', 'query'], as_index=False, observed=True).agg({'did': list, 'pid': list, doccol: ' '.join}) queries_qrels.to_csv(output, sep='\t', encoding='utf-8', index=False) return queries_qrels - diff --git a/src/evl/metrics.py b/src/evl/metrics.py index 97f6569..6b9f405 100644 --- a/src/evl/metrics.py +++ b/src/evl/metrics.py @@ -10,8 +10,6 @@ t5_refinement = '../../output/aol-ia/t5.base.gc.docs.query.title.url/base.refinement' -##compute f1 score - def normalize_answer(s): """Lower text and remove punctuation, articles and extra whitespace.""" @@ -69,81 +67,80 @@ def f1_score(prediction, ground_truth): print(f'BLEU score {ds}: {bleu_mean}') print(f'f1 measure {ds}: {f1_mean}') -## msmarco base refinement -#rouge score for diamond: 0.42568057136051424 -#BLEU score diamond: 0.21029683850476916 -#f1 measure diamond: 0.416822260588614 - -#rouge score for platinum: 0.42969398167649014 -#BLEU score platinum: 0.2115697317809715 -#f1 measure platinum: 0.41915060043151015 +# msmarco base refinement +# rouge score for diamond: 0.42568057136051424 +# BLEU score diamond: 0.21029683850476916 +# f1 measure diamond: 0.416822260588614 -#rouge score for gold: 0.4356421447877607 -#BLEU score gold: 0.21678563445709004 -#f1 measure gold: 0.42532275952238546 +# rouge score for platinum: 0.42969398167649014 +# BLEU score platinum: 0.2115697317809715 +# f1 measure platinum: 0.41915060043151015 -## msmarco transfer refinement -#rouge score for diamond: 0.4186038646735186 -#BLEU score diamond: 0.2061915039658764 -#f1 measure diamond: 0.41743805966063974 +# rouge score for gold: 0.4356421447877607 +# BLEU score gold: 0.21678563445709004 +# f1 measure gold: 0.42532275952238546 -#rouge score for platinum: 0.4180435950210718 -#BLEU score platinum: 0.20547429381774826 -#f1 measure platinum: 0.41576693292575345 +# msmarco transfer refinement +# rouge score for diamond: 0.4186038646735186 +# BLEU score diamond: 0.2061915039658764 +# f1 measure diamond: 0.41743805966063974 -#rouge score for gold: 0.4256894215931806 -#BLEU score gold: 0.21156194047222282 -#f1 measure gold: 0.42353650076526406 +# rouge score for platinum: 0.4180435950210718 +# BLEU score platinum: 0.20547429381774826 +# f1 measure platinum: 0.41576693292575345 +# rouge score for gold: 0.4256894215931806 +# BLEU score gold: 0.21156194047222282 +# f1 measure gold: 0.42353650076526406 -## AOL title transfer -#rouge score for diamond: 0.24785768028983485 -#BLEU score diamond: 0.1057627554917162 -#f1 measure diamond: 0.23602467853502 +# AOL title transfer +# rouge score for diamond: 0.24785768028983485 +# BLEU score diamond: 0.1057627554917162 +# f1 measure diamond: 0.23602467853502 -#rouge score for platinum: 0.2148899068052302 -#BLEU score platinum: 0.08699215515241193 -#f1 measure platinum: 0.2011023023056243 +# rouge score for platinum: 0.2148899068052302 +# BLEU score platinum: 0.08699215515241193 +# f1 measure platinum: 0.2011023023056243 -## AOL title Base -#rouge score for diamond: 0.17408527383170763 -#BLEU score diamond: 0.07404099482985602 -#f1 measure diamond: 0.16509779706595218 +# AOL title Base +# rouge score for diamond: 0.17408527383170763 +# BLEU score diamond: 0.07404099482985602 +# f1 measure diamond: 0.16509779706595218 -#rouge score for platinum: 0.1555684917625179 -#BLEU score platinum: 0.06250855995300625 -#f1 measure platinum: 0.14425863392143193 +# rouge score for platinum: 0.1555684917625179 +# BLEU score platinum: 0.06250855995300625 +# f1 measure platinum: 0.14425863392143193 -#rouge score for gold: 0.17282032172417106 -#BLEU score gold: 0.07026203522067144 -#f1 measure gold: 0.16097593193386475 +# rouge score for gold: 0.17282032172417106 +# BLEU score gold: 0.07026203522067144 +# f1 measure gold: 0.16097593193386475 -## AOL title URL base -#rouge score for diamond: 0.25625260529488425 -#BLEU score diamond: 0.1079382534269401 -#f1 measure diamond: 0.24463295875478222 +# AOL title URL base +# rouge score for diamond: 0.25625260529488425 +# BLEU score diamond: 0.1079382534269401 +# f1 measure diamond: 0.24463295875478222 -#rouge score for platinum: 0.2213520954721022 -#BLEU score platinum: 0.08731537824992998 -#f1 measure platinum: 0.20731018675681637 +# rouge score for platinum: 0.2213520954721022 +# BLEU score platinum: 0.08731537824992998 +# f1 measure platinum: 0.20731018675681637 -#rouge score gold: 0.2493297981807377 -#BLEU score gold: 0.0997131509862043 -#f1 measure gold: 0.23478355839064657 +# rouge score gold: 0.2493297981807377 +# BLEU score gold: 0.0997131509862043 +# f1 measure gold: 0.23478355839064657 # AOL TITLE URL Transfer -#rouge score diamond: 0.25625260529488425 -#BLEU score diamond: 0.1079382534269401 -#f1 measure diamond: 0.24463295875478222 +# rouge score diamond: 0.25625260529488425 +# BLEU score diamond: 0.1079382534269401 +# f1 measure diamond: 0.24463295875478222 -#rouge score platinum: 0.2213520954721022 -#BLEU score platinum: 0.08731537824992998 -#f1 measure platinum: 0.20731018675681637 +# rouge score platinum: 0.2213520954721022 +# BLEU score platinum: 0.08731537824992998 +# f1 measure platinum: 0.20731018675681637 -#rouge score gold: 0.2493297981807377 -#BLEU score gold: 0.0997131509862043 -#f1 measure gold: 0.23478355839064657 \ No newline at end of file +# rouge score gold: 0.2493297981807377 +# BLEU score gold: 0.0997131509862043 +# f1 measure gold: 0.23478355839064657 diff --git a/src/evl/trecw.py b/src/evl/trecw.py index f885101..8bde3d8 100644 --- a/src/evl/trecw.py +++ b/src/evl/trecw.py @@ -1,8 +1,9 @@ import os + def evaluate(in_docids, out_metrics, qrels, metric, lib, mean=False, topk=10): - #qrels can have queries that are not in in_docids (superset) - #also prediction may have queries that are not known to qrels + # qrels can have queries that are not in in_docids (superset) + # also prediction may have queries that are not known to qrels # with open('pred', 'w') as f: # f.write(f'1\tQ0\t2\t1\t20.30781\tPyserini Batch\n') # f.write(f'1\tQ0\t3\t1\t5.30781\tPyserini Batch\n') @@ -10,11 +11,11 @@ def evaluate(in_docids, out_metrics, qrels, metric, lib, mean=False, topk=10): # with open('qrel', 'w') as f: # f.write(f'1\t0\t2\t1\n') # f.write(f'3\t0\t3\t1\n')#does not exist in prediction - #"./../trec_eval.9.0.4/trec_eval" -q -m ndcg qrel pred + # "./../trec_eval.9.0.4/trec_eval" -q -m ndcg qrel pred # ndcg 1 1.0000 # ndcg all 1.0000 # - #However, no duplicate [qid, docid] can be in qrels!! + # However, no duplicate [qid, docid] can be in qrels!! print(f'Evaluating retrieved docs for {in_docids} with {metric} ...') if 'trec_eval' in lib: @@ -31,4 +32,3 @@ def evaluate(in_docids, out_metrics, qrels, metric, lib, mean=False, topk=10): # 'map', # '../trec_eval.9.0.4/trec_eval', # mean=False) - diff --git a/src/main.py b/src/main.py index 5b11306..0f94593 100644 --- a/src/main.py +++ b/src/main.py @@ -7,6 +7,7 @@ import param + def run(data_list, domain_list, output, settings): # 'qrels.train.tsv' => ,["qid","did","pid","relevancy"] # 'queries.train.tsv' => ["qid","query"] @@ -30,12 +31,12 @@ def run(data_list, domain_list, output, settings): query_qrel_doc = None if 'pair' in settings['cmd']: - print(f'Pairing queries and relevant passages for training set ...') + print('Pairing queries and relevant passages for training set ...') cat = True if 'docs' in {in_type, out_type} else False query_qrel_doc = ds.pair(datapath, f'{prep_output}/{ds.user_pairing}queries.qrels.doc{"s" if cat else ""}.ctx.{index_item_str}.train.no_dups.tsv', cat=cat) # print(f'Pairing queries and relevant passages for test set ...') - #TODO: query_qrel_doc = pair(datapath, f'{prep_output}/queries.qrels.doc.ctx.{index_item_str}.test.tsv') - #query_qrel_doc = ds.pair(datapath, f'{prep_output}/queries.qrels.doc{"s" if cat else ""}.ctx.{index_item_str}.test.tsv', cat=cat) + # TODO: query_qrel_doc = pair(datapath, f'{prep_output}/queries.qrels.doc.ctx.{index_item_str}.test.tsv') + # query_qrel_doc = ds.pair(datapath, f'{prep_output}/queries.qrels.doc{"s" if cat else ""}.ctx.{index_item_str}.test.tsv', cat=cat) query_qrel_doc.to_csv(tsv_path['train'], sep='\t', encoding='utf-8', index=False, columns=[in_type, out_type], header=False) query_qrel_doc.to_csv(tsv_path['test'], sep='\t', encoding='utf-8', index=False, columns=[in_type, out_type], header=False) @@ -49,9 +50,9 @@ def run(data_list, domain_list, output, settings): print(f"Finetuning {t5_model} for {settings['iter']} iterations and storing the checkpoints at {t5_output} ...") mt5w.finetune( tsv_path=tsv_path, - pretrained_dir=f'./../output/t5-data/pretrained_models/{t5_model.split(".")[0]}', #"gs://t5-data/pretrained_models/{"small", "base", "large", "3B", "11B"} + pretrained_dir=f'./../output/t5-data/pretrained_models/{t5_model.split(".")[0]}', # "gs://t5-data/pretrained_models/{"small", "base", "large", "3B", "11B"} steps=settings['iter'], - output=t5_output, task_name=f"{domain.replace('-', '')}_cf",#:DD Task name must match regex: ^[\w\d\.\:_]+$ + output=t5_output, task_name=f"{domain.replace('-', '')}_cf", # :DD Task name must match regex: ^[\w\d\.\:_]+$ lseq=settings[domain]['lseq'], nexamples=None, in_type=in_type, out_type=out_type, gcloud=False) @@ -64,13 +65,13 @@ def run(data_list, domain_list, output, settings): output=t5_output, lseq=settings[domain]['lseq'], gcloud=False) - if 'search' in settings['cmd']: #'bm25 ranker' + if 'search' in settings['cmd']: # 'bm25 ranker' print(f"Searching documents for query changes using {settings['ranker']} ...") - #seems for some queries there is no qrels, so they are missed for t5 prediction. - #query_originals = pd.read_csv(f'{datapath}/queries.train.tsv', sep='\t', names=['qid', 'query'], dtype={'qid': str}) - #we use the file after panda.merge that create the training set so we make sure the mapping of qids + # seems for some queries there is no qrels, so they are missed for t5 prediction. + # query_originals = pd.read_csv(f'{datapath}/queries.train.tsv', sep='\t', names=['qid', 'query'], dtype={'qid': str}) + # we use the file after panda.merge that create the training set so we make sure the mapping of qids query_originals = pd.read_csv(f'{prep_output}/{ds.user_pairing}queries.qrels.doc{"s" if "docs" in {in_type, out_type} else ""}.ctx.{index_item_str}.train.tsv', sep='\t', usecols=['qid', 'query'], dtype={'qid': str}) - if settings['large_ds']: # we can run this logic if shape of query_originals is greater than split_size + if settings['large_ds']: # we can run this logic if shape of query_originals is greater than split_size import numpy as np import glob split_size = 1000000 # need to make this dynamic based on shape of query_originals. @@ -113,8 +114,6 @@ def run(data_list, domain_list, output, settings): if 'eval' in settings['cmd']: from evl import trecw - import glob - import itertools if settings['large_ds']: import glob import itertools @@ -127,10 +126,10 @@ def run(data_list, domain_list, output, settings): with mp.Pool(settings['ncore']) as p: p.starmap(partial(trecw.evaluate, metric=settings['metric'], lib=settings['treclib']), search_results) - #merge after results + # merge after results original_metrics_results_list = list() - #original merge + # original merge for i in [file for file in os.listdir(f'{t5_output}/original') if file.endswith(f'{settings["ranker"]}.{settings["metric"]}')]: print(f'appending query and metric for original, iteration {i} ') original_metrics_results_list.append(pd.read_csv(f'{t5_output}/original/{i}', sep='\t', names=['metric_name', 'qid', 'metric'], index_col=False, dtype={'qid': str})) @@ -143,7 +142,7 @@ def run(data_list, domain_list, output, settings): metrics_results_list = list() print(f'appending query and metric split files for pred.{i}') for change in [file for file in os.listdir(f'{t5_output}/{i}') if len(file.split('.')) == 2]: - #avoid loading queries if they already exists + # avoid loading queries if they already exists if not isfile(f'{t5_output}/pred.{i}'): metrics_query_list.append(pd.read_csv(f'{t5_output}/{i}/{change}', skip_blank_lines=False, names=['query'], sep='\r\r', index_col=False, @@ -176,6 +175,7 @@ def run(data_list, domain_list, output, settings): ds.aggregate(originals, changes, t5_output, settings["large_ds"]) if 'box' in settings['cmd']: + from evl import trecw box_path = join(t5_output, f'{settings["ranker"]}.{settings["metric"]}.boxes') if not os.path.isdir(box_path): os.makedirs(box_path) gold_df = pd.read_csv(f'{t5_output}/{settings["ranker"]}.{settings["metric"]}.agg.all.tsv', sep='\t', header=0, dtype={'qid': str}) @@ -185,7 +185,7 @@ def run(data_list, domain_list, output, settings): ds.box(gold_df, qrels, box_path, box_condition) for c in box_condition.keys(): print(f'Stamping boxes for {settings["ranker"]}.{settings["metric"]} before and after refinements ...') - from evl import trecw + if not os.path.isdir(join(box_path, 'stamps')): os.makedirs(join(box_path, 'stamps')) df = pd.read_csv(f'{box_path}/{c}.tsv', sep='\t', encoding='utf-8', index_col=False, header=None, names=['qid', 'query', 'metric', 'query_', 'metric_'], dtype={'qid': str}) @@ -211,7 +211,7 @@ def run(data_list, domain_list, output, settings): for change, metric_value in changes: all.append((row[change], row[f'{change}.{ranker}.{metric}'], change)) all = sorted(all, key=lambda x: x[1], reverse=True) # if row[f'original.{ranker}.{metric}'] == 0 and all[0][1] <= 0.1: #poor perf - if row[f'original.{ranker}.{metric}'] > all[0][1] and row[f'original.{ranker}.{metric}'] <= 1: # no prediction + if row[f'original.{ranker}.{metric}'] > all[0][1] and row[f'original.{ranker}.{metric}'] <= 1: # no prediction agg_poor_perf.write(f'{row.qid}\t{row.query}\t{row[f"original.{ranker}.{metric}"]}\t{all[0][0]}\t{all[0][1]}\n') original = pd.read_csv(f'{t5_output}/{ranker}.{metric}.agg.{condition}.tsv', sep='\t', encoding="utf-8", header=0, index_col=False, names=['qid', 'query', f'{ranker}.{metric}', 'query_', f'{ranker}.{metric}_']) @@ -228,7 +228,7 @@ def run(data_list, domain_list, output, settings): with mp.Pool(settings['ncore']) as p: p.starmap(partial(ds.search_df, qids=original['qid'].values.tolist(), ranker='tct_colbert', topk=100, batch=None, ncores=settings['ncore'], index=settings[f'{domain}']["dense_index"], encoder=settings[f'{domain}']['dense_encoder']), search_list) - p.starmap(partial(trecw.evaluate, qrels=f'{datapath}/qrels.train.tsv_', metric=settings['metric'], + p.starmap(partial(trecw.evaluate, qrels=f'{datapath}/qrels.train.tsv_', metric=settings['metric'], lib=settings['treclib']), search_results) # aggregate colbert results and compare with bm25 results @@ -252,7 +252,8 @@ def run(data_list, domain_list, output, settings): agg_df.to_csv(f'{t5_output}/colbert.comparison.{condition}.{metric}.tsv', sep="\t", index=None) if 'stats' in settings['cmd']: - from stats import stats + from stats import stats + def addargs(parser): dataset = parser.add_argument_group('dataset') @@ -266,6 +267,7 @@ def addargs(parser): # python -u main.py -data ../data/raw/toy.aol-ia -domain aol-ia # python -u main.py -data ../data/raw/toy.msmarco.passage ../data/raw/toy.aol-ia -domain msmarco.passage aol-ia + if __name__ == '__main__': freeze_support() mp.set_start_method('spawn') diff --git a/src/param.py b/src/param.py index 249a82f..12b033b 100644 --- a/src/param.py +++ b/src/param.py @@ -19,7 +19,7 @@ 'topk': 100, # number of retrieved documents for a query 'metric': 'recip_rank.10', # any valid trec_eval.9.0.4 metric like map, ndcg, recip_rank, ... 'large_ds': True, - 'treclib': f'"./trec_eval.9.0.4/trec_eval{extension}"', #in non-windows, remove .exe, also for pytrec_eval, 'pytrec' + 'treclib': f'"./trec_eval.9.0.4/trec_eval{extension}"', # in non-windows, remove .exe, also for pytrec_eval, 'pytrec' 'box': {'gold': 'refined_q_metric >= original_q_metric and refined_q_metric > 0', 'platinum': 'refined_q_metric > original_q_metric', 'diamond': 'refined_q_metric > original_q_metric and refined_q_metric == 1'}, @@ -33,9 +33,9 @@ }, 'aol-ia': { 'index_item': ['title'], # ['url'], ['title', 'url'], ['title', 'url', 'text'] - 'index': '../data/raw/aol-ia/lucene-index/title/', #change based on index_item - 'dense_index': '../data/raw/aol-ia/dense-index/tct_colbert.title/', #change based on index_item - 'dense_encoder':'../data/raw/aol-ia/dense-encoder/tct_colbert.title/', #change based on index_item + 'index': '../data/raw/aol-ia/lucene-index/title/', # change based on index_item + 'dense_index': '../data/raw/aol-ia/dense-index/tct_colbert.title/', # change based on index_item + 'dense_encoder':'../data/raw/aol-ia/dense-encoder/tct_colbert.title/', # change based on index_item 'pairing': [None, 'docs', 'query'], # [context={2 scenarios, one with userID and one without userID). input={'userid','query','doc(s)'} output={'query','doc(s)'} 'lseq': {"inputs": 32, "targets": 256}, # query length and doc length for t5 model, 'filter': {'minql': 1, 'mindocl': 10} # [min query length, min doc length], after merge queries with relevant 'index_item', if |query| <= minql drop the row, if |'index_item'| < mindocl, drop row diff --git a/src/qs.py b/src/qs.py index 646f772..d3812d0 100644 --- a/src/qs.py +++ b/src/qs.py @@ -1,18 +1,19 @@ import os, sys, time, random, string, json, numpy, glob, pandas as pd from collections import OrderedDict +from cair.main.recommender import run sys.path.extend(["./cair", "./cair/main"]) numpy.random.seed(7881) -from cair.main.recommender import run - ReQue = { 'input': '../output', 'output': './output' } + def generate_random_string(n=12): return ''.join(random.choice(string.ascii_lowercase + string.digits) for _ in range(n)) + def tsv2json(df, output, topn=1): if not os.path.isdir(output): os.makedirs(output, exist_ok=True) @@ -41,7 +42,7 @@ def tsv2json(df, output, topn=1): qcol = 'query_' if (qcol not in df.columns) or pd.isna(row[qcol]): break - #check if the query string is a dict (for weighted expanders such as onfields) + # check if the query string is a dict (for weighted expanders such as onfields) try: row[qcol] = ' '.join(eval(row[qcol]).keys()) except: @@ -60,7 +61,7 @@ def tsv2json(df, output, topn=1): ('session_id', generate_random_string()), ('query', session_queries) ]) - print(str(row.qid) + ": " + qObj['text'] + '--' + str(i)+ '--> ' + q_Obj['text']); + print(str(row.qid) + ": " + qObj['text'] + '--' + str(i) + '--> ' + q_Obj['text']) fds.write(json.dumps(obj) + '\n') @@ -72,11 +73,12 @@ def tsv2json(df, output, topn=1): else: ftest.write(json.dumps(obj) + '\n') + def call_cair_run(data_dir, epochs): - dataset_name = 'msmarco'#it is hard code in the library. Do not touch! :)) + dataset_name = 'msmarco' # it is hard code in the library. Do not touch! :)) baseline_path = 'cair/' - cli_cmd = '' #'python ' + cli_cmd = '' # 'python ' cli_cmd += '{}main/recommender.py '.format(baseline_path) cli_cmd += '--dataset_name {} '.format(dataset_name) cli_cmd += '--data_dir {} '.format(data_dir) @@ -92,8 +94,8 @@ def call_cair_run(data_dir, epochs): cli_cmd += '--embed_dir {}data/fasttext/ '.format(baseline_path) cli_cmd += '--embedding_file crawl-300d-2M-subword.vec ' - #the models config are in QueStion\qs\cair\neuroir\hyparam.py - #only hredqs can be unidirectional! all other models are in bidirectional mode + # the models config are in QueStion\qs\cair\neuroir\hyparam.py + # only hredqs can be unidirectional! all other models are in bidirectional mode df = pd.DataFrame(columns=['model', 'epoch', 'rouge', 'bleu', 'bleu_list', 'exact_match', 'f1', 'elapsed_time']) for baseline in ['hredqs']: for epoch in epochs: @@ -105,13 +107,14 @@ def call_cair_run(data_dir, epochs): 'epoch': epoch, 'rouge': test_resutls['rouge'], 'bleu': test_resutls['bleu'], - 'bleu_list': ','.join([str(b) for b in test_resutls['bleu_list']]), + 'bleu_list': ','.join([str(b) for b in test_resutls['bleu_list']]), 'exact_match': test_resutls['em'], 'f1': test_resutls['f1'], 'elapsed_time': elapsed_time}, ignore_index=True) df.to_csv('{}/results.csv'.format(data_dir, baseline), index=False) + def aggregate(path): fs = glob.glob(path + "/**/results.csv", recursive=True) print(fs) @@ -130,10 +133,12 @@ def aggregate(path): df.to_csv(path + "agg_results.csv", index=False) -# # {CUDA_VISIBLE_DEVICES={zero-base gpu indexes, comma seprated reverse to the system}} python -u main.py {topn=[1,2,...]} {topics=[robust04, gov2, clueweb09b, clueweb12b13, all]} 2>&1 | tee log & -# # CUDA_VISIBLE_DEVICES=0,1 python -u main.py 1 robust04 2>&1 | tee robust04.topn1.log & -if __name__=='__main__': +# {CUDA_VISIBLE_DEVICES={zero-base gpu indexes, comma seprated reverse to the system}} python -u main.py {topn=[1,2,...]} {topics=[robust04, gov2, clueweb09b, clueweb12b13, all]} 2>&1 | tee log & +# CUDA_VISIBLE_DEVICES=0,1 python -u main.py 1 robust04 2>&1 | tee robust04.topn1.log & + + +if __name__ == '__main__': topn = int(sys.argv[1]) corpora = sys.argv[2:] if not corpora: @@ -153,7 +158,7 @@ def aggregate(path): tsv2json(df, f'{ReQue["input"]}/{corpus}/t5.base.gc.docs.query.title.url/boxes/qs-gold/', topn) data_dir = f'{ReQue["input"]}/{corpus}/t5.base.gc.docs.query.title.url/boxes/qs-gold' print('INFO: MAIN: Calling cair for {}'.format(data_dir)) - #call_cair_run(data_dir, epochs=[e for e in range(1, 10)] + [e * 10 for e in range(1, 21)]) + # call_cair_run(data_dir, epochs=[e for e in range(1, 10)] + [e * 10 for e in range(1, 21)]) call_cair_run(data_dir, epochs=[10]) - aggregate(ReQue['output'] + '/') \ No newline at end of file + aggregate(ReQue['output'] + '/') diff --git a/src/stats/get_stats.py b/src/stats/get_stats.py index 3ad6110..b2acf95 100644 --- a/src/stats/get_stats.py +++ b/src/stats/get_stats.py @@ -80,7 +80,6 @@ def count_total_queries(self): total_queries = [original_queries, refuned_queries] return total_queries - def combined_stats(self): i = 1 row_count = self.num_rows - 1 diff --git a/src/stats/stats.py b/src/stats/stats.py index d790b08..4dcff55 100644 --- a/src/stats/stats.py +++ b/src/stats/stats.py @@ -4,11 +4,13 @@ from stats.get_stats import get_stats datasets = ['diamond', 'platinum', 'gold'] + + def plot_stats(box_path): for ds in datasets: map_ds = pd.read_csv(f'{box_path}/{ds}.tsv', sep='\t', encoding='utf-8', names=['qid', 'i', 'i_map', 't', 't_map']) map_ds.sort_values(by='i_map', inplace=True) - stats = map_ds.groupby(np.arange(len(map_ds))//(len(map_ds)/10)).mean() + stats = map_ds.groupby(np.arange(len(map_ds)) // (len(map_ds) / 10)).mean() X = [x for x in range(1, 11)] original_mean = stats['i_map'] changes_mean = stats['t_map'] @@ -23,6 +25,7 @@ def plot_stats(box_path): plt.savefig(f'{box_path}/{ds}.jpg') plt.clf() + plot_stats('../output/toy.msmarco.passage/t5.small.local.docs.query.passage/bm25.map.boxes') file_path = '../output/toy.msmarco.passage/t5.small.local.docs.query.passage/bm25.recip_rank.10.agg.gold.tsv' @@ -52,5 +55,3 @@ def plot_stats(box_path): print(f'original_queries_score_stats: {stats["original_queries_score_stats"]}') print(f'delta_lengths_stats: {stats["delta_lengths_stats"]}') print(f'delta_scores_stats: {stats["delta_scores_stats"]}') - - diff --git a/testing_reqs.txt b/testing_reqs.txt new file mode 100644 index 0000000..1ae3ec6 --- /dev/null +++ b/testing_reqs.txt @@ -0,0 +1 @@ +flake8 \ No newline at end of file