diff --git a/environment.yml b/environment.yml index f59e4d4..6c33716 100644 --- a/environment.yml +++ b/environment.yml @@ -1,4 +1,4 @@ -name: repair +name: Repair channels: - conda-forge - pytorch diff --git a/output/toy.msmarco.passage/t5.small.local.docs.query.passage/refiner_param.py b/output/toy.msmarco.passage/t5.small.local.docs.query.passage/refiner_param.py new file mode 100644 index 0000000..0458df5 --- /dev/null +++ b/output/toy.msmarco.passage/t5.small.local.docs.query.passage/refiner_param.py @@ -0,0 +1,167 @@ +import random, os, numpy as np, platform +import torch + +random.seed(0) +torch.manual_seed(0) +torch.cuda.manual_seed_all(0) +np.random.seed(0) +extension = '.exe' if platform.system() == 'Windows' else "" +os.environ['CUDA_VISIBLE_DEVICES'] = '-1' + +settings = { + 'query_refinement': True, + 'cmd': ['search', 'eval','agg', 'box'], # steps of pipeline, ['pair', 'finetune', 'predict', 'search', 'eval','agg', 'box','dense_retrieve', 'stats] + 'ncore': 1, + 't5model': 'small.local', # 'base.gc' on google cloud tpu, 'small.local' on local machine + 'iter': 5, # number of finetuning iteration for t5 + 'nchanges': 5, # number of changes to a query + 'ranker': 'bm25', # 'qld', 'bm25', 'tct_colbert' + 'batch': None, # search per batch of queries for IR search using pyserini, if None, search per query + 'topk': 100, # number of retrieved documents for a query + 'metric': 'map', # any valid trec_eval.9.0.4 metric like map, ndcg, recip_rank, ... + 'large_ds': False, + 'treclib': f'"./trec_eval.9.0.4/trec_eval{extension}"', # in non-windows, remove .exe, also for pytrec_eval, 'pytrec' + 'box': {'gold': 'refined_q_metric >= original_q_metric and refined_q_metric > 0', + 'platinum': 'refined_q_metric > original_q_metric', + 'diamond': 'refined_q_metric > original_q_metric and refined_q_metric == 1'} +} + +corpora = { + 'msmarco.passage': { + 'index_item': ['passage'], + 'index': '../data/raw/msmarco.passage/lucene-index.msmarco-v1-passage.20220131.9ea315/', + 'dense_encoder': 'castorini/tct_colbert-msmarco', + 'dense_index': 'msmarco-passage-tct_colbert-hnsw', + 'extcorpus': 'orcas', + 'pairing': [None, 'docs', 'query'], # [context={msmarco does not have userinfo}, input={query, doc, doc(s)}, output={query, doc, doc(s)}], s means concat of docs + 'lseq': {"inputs": 32, "targets": 256}, # query length and doc length for t5 model, + }, + 'aol-ia': { + 'index_item': ['title'], # ['url'], ['title', 'url'], ['title', 'url', 'text'] + 'index': '../data/raw/aol-ia/lucene-index/title/', # change based on index_item + 'dense_index': '../data/raw/aol-ia/dense-index/tct_colbert.title/', # change based on index_item + 'dense_encoder':'../data/raw/aol-ia/dense-encoder/tct_colbert.title/', # change based on index_item + 'pairing': [None, 'docs', 'query'], # [context={2 scenarios, one with userID and one without userID). 
input={'userid','query','doc(s)'} output={'query','doc(s)'} + 'extcorpus': 'msmarco.passage', + 'lseq': {"inputs": 32, "targets": 256}, # query length and doc length for t5 model, + 'filter': {'minql': 1, 'mindocl': 10} # [min query length, min doc length], after merge queries with relevant 'index_item', if |query| <= minql drop the row, if |'index_item'| < mindocl, drop row + }, + 'robust04': { + 'index': '../data/raw/robust04/lucene-index.robust04.pos+docvectors+rawdocs', + 'dense_index': '../data/raw/robust04/faiss_index_robust04', + 'encoded': '../data/raw/robust04/encoded_robust04', + 'size': 528155, + 'topics': '../data/raw/robust04/topics.robust04.txt', + 'prels': '', # this will be generated after a retrieval {bm25, qld} + 'w_t': 2.25, # OnFields + 'w_a': 1, # OnFields + 'tokens': 148000000, + 'qrels': '../data/raw/robust04/qrels.robust04.txt', + 'extcorpus': 'gov2', # AdaptOnFields + 'pairing': [None, None, None], + 'index_item': [], + }, + 'gov2': { + 'index': '../data/raw/gov2/lucene-index.gov2.pos+docvectors+rawdocs', + 'size': 25000000, + 'topics': '../data/raw/gov2/topics.terabyte0{}.txt', # {} is a placeholder for subtopics in main.py -> run() + 'trec': ['4.701-750', '5.751-800', '6.801-850'], + 'prels': '', # this will be generated after a retrieval {bm25, qld} + 'w_t': 4, # OnFields + 'w_a': 0.25, # OnFields + 'tokens': 17000000000, + 'qrels': '../data/raw/gov2/qrels.terabyte0{}.txt', # {} is a placeholder for subtopics in main.py -> run() + 'extcorpus': 'robust04', # AdaptOnFields + 'pairing': [None, None, None], + 'index_item': [], + }, + 'clueweb09b': { + 'index': '../data/raw/clueweb09b/lucene-index.cw09b.pos+docvectors+rawdocs', + 'size': 50000000, + 'topics': '../data/raw/clueweb09b/topics.web.{}.txt', # {} is a placeholder for subtopics in main.py -> run() + 'trec': ['1-50', '51-100', '101-150', '151-200'], + 'prels': '', # this will be generated after a retrieval {bm25, qld} + 'w_t': 1, # OnFields + 'w_a': 0, # OnFields + 'tokens': 31000000000, + 'qrels': '../data/raw/clueweb09b/qrels.web.{}.txt', # {} is a placeholder for subtopics in main.py -> run() + 'extcorpus': 'gov2', # AdaptOnFields + 'pairing': [None, None, None], + 'index_item': [], + }, + 'clueweb12b13': { + 'index': '../data/raw/clueweb12b13/lucene-index.cw12b13.pos+docvectors+rawdocs', + 'size': 50000000, + 'topics': '../data/raw/clueweb12b13/topics.web.{}.txt', # {} is a placeholder for subtopics in main.py -> run() + 'trec': ['201-250', '251-300'], + 'prels': '', # this will be generated after a retrieval {bm25, qld} + 'w_t': 4, # OnFields + 'w_a': 0, # OnFields + 'tokens': 31000000000, + 'qrels': '../data/raw/clueweb12b13/qrels.web.{}.txt', # {} is a placeholder for subtopics in main.py -> run() + 'extcorpus': 'gov2', # AdaptOnFields + 'pairing': [None, None, None], + 'index_item': [], + }, + 'antique': { + 'index': '../data/raw/antique/lucene-index-antique', + 'size': 403000, + 'topics': '../data/raw/antique/topics.antique.txt', + 'prels': '', # this will be generated after a retrieval {bm25, qld} + 'w_t': 2.25, # OnFields # to be tuned + 'w_a': 1, # OnFields # to be tuned + 'tokens': 16000000, + 'qrels': '../data/raw/antique/qrels.antique.txt', + 'extcorpus': 'gov2', # AdaptOnFields + 'pairing': [None, None, None], + 'index_item': [], + }, + 'trec09mq': { + 'index': '/data/raw/clueweb09b/lucene-index.cw09b.pos+docvectors+rawdocs', + 'size': 50000000, + # 'topics': '../ds/trec2009mq/prep/09.mq.topics.20001-60000.prep.tsv', + 'topics': '../data/raw/trec09mq/09.mq.topics.20001-60000.prep', + 'prels': 
'', # this will be generated after a retrieval {bm25, qld} + 'w_t': 2.25, # OnFields # to be tuned + 'w_a': 1, # OnFields # to be tuned + 'tokens': 16000000, + 'qrels': '../data/raw/trec09mq/prels.20001-60000.prep', + 'extcorpus': 'gov2', # AdaptOnFields + 'pairing': [None, None, None], + 'index_item': [], + }, + 'dbpedia': { + 'index': '../data/raw/dbpedia/lucene-index-dbpedia-collection', + 'size': 4632359, + 'topics': '../data/raw/dbpedia/topics.dbpedia.txt', + 'prels': '', # this will be generated after a retrieval {bm25, qld} + 'w_t': 1, # OnFields # to be tuned + 'w_a': 1, # OnFields # to be tuned + 'tokens': 200000000, + 'qrels': '../data/raw/dbpedia/qrels.dbpedia.txt', + 'extcorpus': 'gov2', # AdaptOnFields + 'pairing': [None, None, None], + 'index_item': [], + }, + 'orcas': { + 'index': '../data/raw/orcas/lucene-index.msmarco-v1-doc.20220131.9ea315', + 'size': 50000000, + # 'topics': '../data/raw/trec2009mq/prep/09.mq.topics.20001-60000.prep.tsv', + 'topics': '../data/raw/orcas/preprocess/orcas-I-2M_topics.prep', + 'prels': '', # this will be generated after a retrieval {bm25, qld} + 'w_t': 2.25, # OnFields # to be tuned + 'w_a': 1, # OnFields # to be tuned + 'tokens': 16000000, + 'qrels': '../data/raw/orcas/preprocess/orcas-doctrain-qrels.prep', + 'extcorpus': 'gov2', # AdaptOnFields + 'pairing': [None, None, None], + 'index_item': [], + }, +} + +# Only for sparse indexing +anserini = { + 'path': '../anserini/', + 'trec_eval': '../anserini/eval/trec_eval.9.0.4/trec_eval' +} + diff --git a/src/dal/aol.py b/src/dal/aol.py index 0d9f1e7..9c5d9aa 100644 --- a/src/dal/aol.py +++ b/src/dal/aol.py @@ -3,14 +3,15 @@ from shutil import copyfile from ftfy import fix_text from pyserini.search.lucene import LuceneSearcher -from dal.ds import Dataset +from ds import Dataset +from query import Query tqdm.pandas() class Aol(Dataset): - def __init__(self, settings, homedir, ncore): - try: super(Aol, self).__init__(settings=settings) + def __init__(self, settings, domain, homedir, ncore): + try: super(Aol, self).__init__(settings=settings, domain=domain) except: self._build_index(homedir, Dataset.settings['index_item'], Dataset.settings['index'], ncore) @classmethod @@ -22,7 +23,7 @@ def _build_index(cls, homedir, index_item, indexdir, ncore): index_item_str = '.'.join(index_item) if not os.path.isdir(f'{indexdir}/{cls.user_pairing}{index_item_str}'): os.makedirs(f'{indexdir}/{cls.user_pairing}{index_item_str}') import ir_datasets - from cmn.lucenex import lucenex + from src.cmn.lucenex import lucenex print(f"Setting up aol corpus using ir-datasets at {homedir}...") aolia = ir_datasets.load("aol-ia") @@ -48,6 +49,17 @@ def _build_index(cls, homedir, index_item, indexdir, ncore): # for d in os.listdir(homedir): # if not (d.find('aol-ia') > -1) and os.path.isdir(f'./../data/raw/{d}'): shutil.rmtree(f'./../data/raw/{d}') + def read_queries(cls, input, domain): + queries = pd.read_csv(f'{input}/queries.train.tsv', encoding='UTF-8', sep='\t', index_col=False, names=['qid', 'query'], converters={'query': str.lower}, header=None) + # the column order in the file is [qid, uid, did, uid]!!!! STUPID!! + qrels = pd.read_csv(f'{input}/qrels.train.tsv', encoding='UTF-8', sep='\t', index_col=False, names=['qid', 'uid', 'did', 'rel'], header=None) + qrels.to_csv(f'{input}/qrels.train.tsv_', index=False, sep='\t', header=False) + # docid is a hash of the URL. qid is the a hash of the *noramlised query* ==> two uid may have same qid then, same docid. 
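# Illustrative aside, not part of the patch: a minimal pandas sketch (hypothetical rows) of why
# read_queries dedups on (qid, did, uid) just below -- since qid hashes the normalised query and
# did hashes the clicked URL, repeated clicks by the same user collapse onto identical triples.
import pandas as pd
_toy_qrels = pd.DataFrame({'qid': ['q1', 'q1', 'q1'], 'uid': ['u1', 'u1', 'u2'],
                           'did': ['d9', 'd9', 'd9'], 'rel': [1, 1, 1]})
print(len(_toy_qrels.drop_duplicates(subset=['qid', 'did', 'uid'])))  # 2: the exact repeat is dropped, the distinct user u2 is kept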
+ qrels.drop_duplicates(subset=['qid', 'did', 'uid'], inplace=True) + queries_qrels = pd.merge(queries, qrels, on='qid', how='inner', copy=False) + queries_qrels = queries_qrels.sort_values(by='qid') + cls.create_query_objects(queries_qrels, ['qid', 'uid', 'did', 'rel'], domain) + @classmethod def create_jsonl(cls, aolia, index_item, output): """ @@ -69,19 +81,21 @@ def create_jsonl(cls, aolia, index_item, output): output_jsonl_file.close() @classmethod - def pair(cls, input, output, cat=True): + def pair(cls, queries, output, cat=True): + # TODO: change the code in a way to use read_queries queries = pd.read_csv(f'{input}/queries.train.tsv', encoding='UTF-8', sep='\t', index_col=False, names=['qid', 'query'], converters={'query': str.lower}, header=None) # the column order in the file is [qid, uid, did, uid]!!!! STUPID!! qrels = pd.read_csv(f'{input}/qrels.train.tsv', encoding='UTF-8', sep='\t', index_col=False, names=['qid', 'uid', 'did', 'rel'], header=None) # docid is a hash of the URL. qid is the a hash of the *noramlised query* ==> two uid may have same qid then, same docid. qrels.drop_duplicates(subset=['qid', 'did', 'uid'], inplace=True) queries_qrels = pd.merge(queries, qrels, on='qid', how='inner', copy=False) + doccol = 'docs' if cat else 'doc' del queries # in the event of user oriented pairing, we simply concat qid and uid if cls.user_pairing: queries_qrels['qid'] = queries_qrels['qid'].astype(str) + "_" + queries_qrels['uid'].astype(str) queries_qrels = queries_qrels.astype('category') - queries_qrels[doccol] = queries_qrels['did'].progress_apply(cls._txt) + queries_qrels[doccol] = queries_qrels['did'].apply(cls._txt) # queries_qrels.drop_duplicates(subset=['qid', 'did','pid'], inplace=True) # two users with same click for same query if not cls.user_pairing: queries_qrels['uid'] = -1 @@ -111,3 +125,4 @@ def pair(cls, input, output, cat=True): qrels_splits.to_csv(f'../output/aol-ia/{cls.user_pairing}t5.base.gc.docs.query.{index_item_str}/qrels/qrels.splits.{_}.tsv_', sep='\t', encoding='utf-8', index=False, header=False, columns=['qid', 'uid', 'did', 'rel']) return queries_qrels + pass diff --git a/src/dal/ds.py b/src/dal/ds.py index ad7b6f2..6839cea 100644 --- a/src/dal/ds.py +++ b/src/dal/ds.py @@ -1,16 +1,17 @@ import json, pandas as pd from tqdm import tqdm from os.path import isfile,join - +from src.dal.query import Query from pyserini.search.lucene import LuceneSearcher from pyserini.search.faiss import FaissSearcher, TctColBertQueryEncoder class Dataset(object): + queries = [] searcher = None settings = None - def __init__(self, settings): + def __init__(self, settings, domain): Dataset.settings = settings # https://github.com/castorini/pyserini/blob/master/docs/prebuilt-indexes.md # searcher = LuceneSearcher.from_prebuilt_index('msmarco-v1-passage') @@ -19,19 +20,67 @@ def __init__(self, settings): Dataset.user_pairing = "user/" if "user" in settings["pairing"] else "" index_item_str = '.'.join(settings["index_item"]) if self.__class__.__name__ != 'MsMarcoPsg' else "" Dataset.searcher = LuceneSearcher(f'{Dataset.settings["index"]}{self.user_pairing}{index_item_str}') + Dataset.domain = domain if not Dataset.searcher: raise ValueError(f'Lucene searcher cannot find/build index at {Dataset.settings["index"]}!') + @classmethod + def read_queries(cls, input, domain): + is_tag_file = False + q, qid = '', '' + queries = pd.DataFrame(columns=['qid']) + with open(f'{input}/topics.{cls.domain}.txt', 'r', encoding='UTF-8') as Qfile: + for line in Qfile: + if '' in line and not 
is_tag_file: is_tag_file = True + if '' in line: + qid = int(line[line.index(':') + 1:]) + elif line[:7] == '': + q = line[8:].strip() + if not q: q = next(Qfile).strip() + elif '<topic' in line: + s = line.index('\"') + 1 + e = line.index('\"', s + 1) + qid = int(line[s:e]) + elif line[2:9] == '<query>': + q = line[9:-9] + elif len(line.split('\t')) >= 2 and not is_tag_file: + qid = line.split('\t')[0].rstrip() + q = line.split('\t')[1].rstrip() + if q != '' and qid != '': + new_line = {'qid': qid, 'query': q} + queries = pd.concat([queries, pd.DataFrame([new_line])], ignore_index=True) + q, qid = '', '' + infile = f'{input}/qrels.{cls.domain}.txt' + with open(infile, 'r') as file: separator = '\t' if '\t' in file.readline() else '\s' + qrels = pd.read_csv(infile, sep=separator, index_col=False, names=['qid', '0', 'did', 'relevancy'], header=None, engine='python') + qrels.drop_duplicates(subset=['qid', 'did'], inplace=True) # qrels have duplicates!! + qrels.to_csv(f'{input}/qrels.train.tsv_', index=False, sep='\t', header=False) + queries_qrels = pd.merge(queries, qrels, on='qid', how='left', copy=False) + queries_qrels = queries_qrels.sort_values(by='qid') + cls.create_query_objects(queries_qrels, ['qid', 'did', 'relevancy'], domain) + + @classmethod + def pair(cls, input, output, index_item, cat=True): pass + @classmethod def _txt(cls, pid): # The``docid`` is overloaded: if it is of type ``str``, it is treated as an external collection ``docid``; # if it is of type ``int``, it is treated as an internal Lucene``docid``. # stupid!! - try:return json.loads(cls.searcher.doc(str(pid)).raw())['contents'].lower() + try: return json.loads(cls.searcher.doc(str(pid)).raw())['contents'].lower() except AttributeError: return '' # if Dataset.searcher.doc(str(pid)) is None except Exception as e: raise e @classmethod - def pair(cls, input, output, index_item, cat=True): pass + def create_query_objects(cls, queries_qrels, qrel_col, domain): + qid = "" + query = None + for i, row in queries_qrels.iterrows(): + if qid != row['qid']: + if query: cls.queries.append(query) + qid = row['qid'] + query = Query(domain=domain, qid=qid, q=row['query']) + query.docs.update({col: str(row[col]) for col in qrel_col}) + if query: cls.queries.append(query) # gpu-based t5 generate the predictions in b'' format!!! @classmethod @@ -41,26 +90,24 @@ def clean(cls, tf_txt): return tf_txt.replace('b\'', '').replace('\'', '').replace('b\"', '').replace('\"', '') @classmethod - def search(cls, in_query, out_docids, qids, ranker='bm25', topk=100, batch=None, ncores=1, index=None): - print(f'Searching docs for {in_query} ...') + def search(cls, in_query:str, out_docids:str, qids:list, ranker='bm25', topk=100, batch=None, ncores=1, index=None): + ansi_reset = "\033[0m" + print(f'Searching docs for {hex_to_ansi("#3498DB")}{in_query} {ansi_reset}and writing results in {hex_to_ansi("#F1C40F")}{out_docids}{ansi_reset} ...') # https://github.com/google-research/text-to-text-transfer-transformer/issues/322 # with open(in_query, 'r', encoding='utf-8') as f: [to_docids(l) for l in f] - queries = pd.read_csv(in_query, names=['query'], sep='\r\r', skip_blank_lines=False, engine='python') # a query might be empty str (output of t5)!! 
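# Illustrative aside, not part of the patch: the branch added below dispatches on the file name.
# Refiner outputs ("refiner.<name>") are read as TSV with the refined query in column 1, while
# t5 prediction files keep the original one-query-per-line read; a hypothetical restatement:
def _is_refiner_output(path: str) -> bool:
    return path.split('/')[-1].split('.')[0] == 'refiner'

assert _is_refiner_output('t5_output/refiner.wordnet')
assert not _is_refiner_output('t5_output/pred.0-1004000')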
+ if (in_query.split('/')[-1]).split('.')[0] == 'refiner': queries = pd.read_csv(in_query, names=['query'], sep='\t', usecols=[1], skip_blank_lines=False, engine='python') + else: queries = pd.read_csv(in_query, names=['query'], sep='\r\r', skip_blank_lines=False, engine='python') # a query might be empty str (output of t5)!! assert len(queries) == len(qids) cls.search_df(queries, out_docids, qids, ranker=ranker, topk=topk, batch=batch, ncores=ncores, index=index) @classmethod - def search_df(cls, queries, out_docids, qids, ranker='bm25', topk=100, batch=None, ncores=1, index=None,encoder=None): - + def search_df(cls, queries, out_docids, qids, ranker='bm25', topk=100, batch=None, ncores=1, index=None, encoder=None): if not cls.searcher: if ranker == 'tct_colbert': cls.encoder = TctColBertQueryEncoder(encoder) - if 'msmarco.passage' in out_docids.split('/'): - cls.searcher = FaissSearcher.from_prebuilt_index(index, cls.encoder) - else: - cls.searcher = FaissSearcher(index, cls.encoder) - else: - cls.searcher = LuceneSearcher(index) + if 'msmarco.passage' in out_docids.split('/'): cls.searcher = FaissSearcher.from_prebuilt_index(index, cls.encoder) + else: cls.searcher = FaissSearcher(index, cls.encoder) + else: cls.searcher = LuceneSearcher(index) if ranker == 'bm25': cls.searcher.set_bm25(0.82, 0.68) if ranker == 'qld': cls.searcher.set_qld() @@ -86,9 +133,10 @@ def _docids(row): print(f'unique docids fetched less than {topk}') else: hits = cls.searcher.search(row.query, k=topk, remove_dups=True) - for i, h in enumerate(hits): o.write(f'{qids[row.name]}\tQ0\t{h.docid:7}\t{i + 1:2}\t{h.score:.5f}\tPyserini\n') + for i, h in enumerate(hits): o.write( + f'{qids[row.name]}\tQ0\t{h.docid:7}\t{i + 1:2}\t{h.score:.5f}\tPyserini\n') - queries.progress_apply(_docids, axis=1) + queries.apply(_docids, axis=1) @classmethod def aggregate(cls, original, changes, output, is_large_ds=False): @@ -96,12 +144,13 @@ def aggregate(cls, original, changes, output, is_large_ds=False): metric = '.'.join(changes[0][1].split('.')[3:]) # e.g., pred.0-1004000.bm25.success.10 => success.10 for change, metric_value in changes: - pred = pd.read_csv(join(output, change), sep='\r\r', skip_blank_lines=False, names=[change], converters={change: cls.clean}, engine='python', index_col=False, header=None) + if 'refiner.' 
in change: pred = pd.read_csv(join(output, change), sep='\t', usecols=[1], skip_blank_lines=False, names=[change], converters={change: cls.clean}, engine='python', index_col=False, header=None) + else: pred = pd.read_csv(join(output, change), sep='\r\r', skip_blank_lines=False, names=[change], converters={change: cls.clean}, engine='python', index_col=False, header=None) assert len(original['qid']) == len(pred[change]) if is_large_ds: pred_metric_values = pd.read_csv(join(output, metric_value), sep='\t', usecols=[1, 2], names=['qid', f'{change}.{ranker}.{metric}'], index_col=False, dtype={'qid': str}) else: - pred_metric_values = pd.read_csv(join(output, metric_value), sep='\t', usecols=[1, 2], names=['qid', f'{change}.{ranker}.{metric}'], index_col=False,skipfooter=1, dtype={'qid': str}) + pred_metric_values = pd.read_csv(join(output, metric_value), sep='\t', usecols=[1, 2], names=['qid', f'{change}.{ranker}.{metric}'], index_col=False,skipfooter=1, dtype={'qid': str}, engine='python') original[change] = pred # to know the actual change original = original.merge(pred_metric_values, how='left', on='qid') # to know the metric value of the change original[f'{change}.{ranker}.{metric}'].fillna(0, inplace=True) @@ -115,14 +164,14 @@ def aggregate(cls, original, changes, output, is_large_ds=False): agg_gold.write(f'qid\torder\tquery\t{ranker}.{metric}\n') agg_all.write(f'qid\torder\tquery\t{ranker}.{metric}\n') for index, row in tqdm(original.iterrows(), total=original.shape[0]): - agg_gold.write(f'{row.qid}\t-1\t{row.query}\t{row[f"original.{ranker}.{metric}"]}\n') - agg_all.write(f'{row.qid}\t-1\t{row.query}\t{row[f"original.{ranker}.{metric}"]}\n') + agg_gold.write(f'{row.qid}\t-1\t{row.query}\t{row[f"original.queries.{ranker}.{metric}"]}\n') + agg_all.write(f'{row.qid}\t-1\t{row.query}\t{row[f"original.queries.{ranker}.{metric}"]}\n') all = list() for change, metric_value in changes: all.append((row[change], row[f'{change}.{ranker}.{metric}'], change)) all = sorted(all, key=lambda x: x[1], reverse=True) for i, (query, metric_value, change) in enumerate(all): agg_all.write(f'{row.qid}\t{change}\t{query}\t{metric_value}\n') - if metric_value > 0 and metric_value >= row[f'original.{ranker}.{metric}']: agg_gold.write(f'{row.qid}\t{change}\t{query}\t{metric_value}\n') + if metric_value > 0 and metric_value >= row[f'original.queries.{ranker}.{metric}']: agg_gold.write(f'{row.qid}\t{change}\t{query}\t{metric_value}\n') @classmethod def box(cls, input, qrels, output, checks): @@ -149,6 +198,14 @@ def box(cls, input, qrels, output, checks): # df.drop_duplicates(subset=['qid'], inplace=True) del ds df.to_csv(f'{output}/{c}.tsv', sep='\t', encoding='utf-8', index=False, header=False) - print(f'{c} has {df.shape[0]} queries') - qrels = df.merge(qrels, on='qid', how='inner') - qrels.to_csv(f'{output}/{c}.qrels.tsv', sep='\t', encoding='utf-8', index=False, header=False, columns=['qid', 'did', 'pid', 'rel']) + print(f'{c} has {df.shape[0]} queries\n') + df = df.merge(qrels, on='qid', how='inner') + df.to_csv(f'{output}/{c}.qrels.tsv', sep='\t', encoding='utf-8', index=False, header=False, columns=list(qrels.columns)) + + +def hex_to_ansi(hex_color_code): + hex_color_code = hex_color_code.lstrip('#') + red = int(hex_color_code[0:2], 16) + green = int(hex_color_code[2:4], 16) + blue = int(hex_color_code[4:6], 16) + return f'\033[38;2;{red};{green};{blue}m' diff --git a/src/dal/msmarco.py b/src/dal/msmarco.py index c6ac671..d33610a 100644 --- a/src/dal/msmarco.py +++ b/src/dal/msmarco.py @@ -1,24 +1,33 @@ 
-import os -from os.path import isfile,join import pandas as pd from tqdm import tqdm -from dal.ds import Dataset +from src.dal.ds import Dataset tqdm.pandas() class MsMarcoPsg(Dataset): - def __init__(self, settings): super(MsMarcoPsg, self).__init__(settings=settings) + def __init__(self, settings, domain): super(MsMarcoPsg, self).__init__(settings=settings, domain=domain) @classmethod - def pair(cls, input, output, cat=True): + def read_queries(cls, input, domain): queries = pd.read_csv(f'{input}/queries.train.tsv', sep='\t', index_col=False, names=['qid', 'query'], converters={'query': str.lower}, header=None) qrels = pd.read_csv(f'{input}/qrels.train.tsv', sep='\t', index_col=False, names=['qid', 'did', 'pid', 'relevancy'], header=None) qrels.drop_duplicates(subset=['qid', 'pid'], inplace=True) # qrels have duplicates!! qrels.to_csv(f'{input}/qrels.train.tsv_', index=False, sep='\t', header=False) # trec_eval.9.0.4 does not accept duplicate rows!! queries_qrels = pd.merge(queries, qrels, on='qid', how='inner', copy=False) + queries_qrels = queries_qrels.sort_values(by='qid') + cls.create_query_objects(queries_qrels, ['qid', 'did', 'pid', 'relevancy'], domain) + + @classmethod + def pair(cls, output, cat=True): + # TODO: change the code in a way to use read_queries + queries = pd.read_csv(f'{input}/queries.train.tsv', sep='\t', index_col=False, names=['qid', 'query'],converters={'query': str.lower}, header=None) + qrels = pd.read_csv(f'{input}/qrels.train.tsv', sep='\t', index_col=False, names=['qid', 'did', 'pid', 'relevancy'], header=None) + qrels.drop_duplicates(subset=['qid', 'pid'], inplace=True) # qrels have duplicates!! + qrels.to_csv(f'{input}/qrels.train.tsv_', index=False, sep='\t', header=False) # trec_eval.9.0.4 does not accept duplicate rows!! + queries_qrels = pd.merge(queries, qrels, on='qid', how='inner', copy=False) doccol = 'docs' if cat else 'doc' - queries_qrels[doccol] = queries_qrels['pid'].progress_apply(cls._txt) # 100%|██████████| 532761/532761 [00:32<00:00, 16448.77it/s] + queries_qrels[doccol] = queries_qrels['pid'].apply(cls._txt) # 100%|██████████| 532761/532761 [00:32<00:00, 16448.77it/s] queries_qrels['ctx'] = '' if cat: queries_qrels = queries_qrels.groupby(['qid', 'query'], as_index=False, observed=True).agg({'did': list, 'pid': list, doccol: ' '.join}) queries_qrels.to_csv(output, sep='\t', encoding='utf-8', index=False) diff --git a/src/dal/query.py b/src/dal/query.py new file mode 100644 index 0000000..b5f8533 --- /dev/null +++ b/src/dal/query.py @@ -0,0 +1,44 @@ +from tqdm import tqdm +import pandas as pd + +class Query: + """ + Query Class + + Represents a query with associated attributes and features. + + Attributes: + qid (int): The query identifier. + q (str): The query text. + docs (list): A list of tuples containing document information. + Each tuple includes docid and relevancy, and additional information + related to documents can be added in between. + q_ (list): A list of tuples containing semantic similarity refined query, score and refiner's name. + user_id (str, optional): The user identifier associated with the query. + time (str, optional): The time of the query. + location (str, optional): The location associated with the query. + + Args: + qid (str): The query identifier. + q (str): The query text. + args (dict, optional): Additional features and attributes associated with the query, + including 'id' for user identifier, 'time' for time information, and 'location' + for location information. 
+ + Example Usage: + # Creating a Query object + query = Query(qid='Q123', q='Sample query text', args={'id': 'U456', 'time': '2023-10-31'}) + + """ + def __init__(self, domain, qid, q, args=None): + self.domain = domain + self.qid = qid + self.q = q + self.docs = dict() + self.q_ = dict() + + # Features + if args: + if args['id']: self.user_id = args['id'] + if args['time']: self.time = args['time'] + if args['location']: self.time = args['location'] diff --git a/src/main.py b/src/main.py index 0f94593..62205e1 100644 --- a/src/main.py +++ b/src/main.py @@ -2,13 +2,15 @@ from functools import partial from multiprocessing import freeze_support from os import listdir -from os.path import isfile, join +from os.path import isfile, join, exists from shutil import copyfile import param +from src.refinement import refiner_factory as rf +from refinement.refiners.abstractqrefiner import AbstractQRefiner -def run(data_list, domain_list, output, settings): +def run(data_list, domain_list, output, corpora, settings): # 'qrels.train.tsv' => ,["qid","did","pid","relevancy"] # 'queries.train.tsv' => ["qid","query"] @@ -19,17 +21,45 @@ def run(data_list, domain_list, output, settings): if domain == 'msmarco.passage': from dal.msmarco import MsMarcoPsg - ds = MsMarcoPsg(param.settings[domain]) - if domain == 'aol-ia': + ds = MsMarcoPsg(corpora[domain], domain) + elif domain == 'aol-ia': from dal.aol import Aol - ds = Aol(param.settings[domain], datapath, param.settings['ncore']) - if domain == 'yandex' in domain_list: raise ValueError('Yandex is yet to be added ...') + ds = Aol(corpora[domain], domain, datapath, param.settings['ncore']) + elif domain == 'yandex' in domain_list: raise ValueError('Yandex is yet to be added ...') + else: + from dal.ds import Dataset + ds = Dataset(corpora[domain], domain) - index_item_str = '.'.join(settings[domain]['index_item']) - in_type, out_type = settings[domain]['pairing'][1], settings[domain]['pairing'][2] + ds.read_queries(datapath, domain) + + index_item_str = '.'.join(corpora[domain]['index_item']) + in_type, out_type = corpora[domain]['pairing'][1], corpora[domain]['pairing'][2] tsv_path = {'train': f'{prep_output}/{ds.user_pairing}{in_type}.{out_type}.{index_item_str}.train.tsv', 'test': f'{prep_output}/{ds.user_pairing}{in_type}.{out_type}.{index_item_str}.test.tsv'} + #TODO: change files naming + t5_model = settings['t5model'] # {"small", "base", "large", "3B", "11B"} cross {"local", "gc"} + t5_output = f'../output/{os.path.split(datapath)[-1]}/{ds.user_pairing}t5.{t5_model}.{in_type}.{out_type}.{index_item_str}' + if not os.path.isdir(t5_output): os.makedirs(t5_output) + copyfile('./param.py', f'{t5_output}/refiner_param.py') + query_qrel_doc = None + + # Query refinement - refining queries using the selected refiners + if settings['query_refinement']: + refiners = rf.get_nrf_refiner() + if rf: refiners += rf.get_rf_refiner(rankers=settings['ranker'], corpus=corpora[domain], output=t5_output, ext_corpus=corpora[corpora[domain]['extcorpus']]) + with mp.Pool(settings['ncore']) as p: + for refiner in refiners: + if refiner.get_model_name() == 'original.queries': refiner_outfile = f'{t5_output}/{refiner.get_model_name()}' + else: refiner_outfile = f'{t5_output}/refiner.{refiner.get_model_name()}' + if not exists(refiner_outfile): + print(f'Writing results from {refiner.get_model_name()} queries in {refiner_outfile}') + ds.queries = p.map(partial(refiner.preprocess_query), ds.queries) + refiner.write_queries(queries=ds.queries, outfile=refiner_outfile) + else: 
print(f'Results from {refiner.get_model_name()} queries in {refiner_outfile}') + + # Consider t5 as a refiner + # TODO: add paring with other expanders if 'pair' in settings['cmd']: print('Pairing queries and relevant passages for training set ...') cat = True if 'docs' in {in_type, out_type} else False @@ -40,10 +70,6 @@ def run(data_list, domain_list, output, settings): query_qrel_doc.to_csv(tsv_path['train'], sep='\t', encoding='utf-8', index=False, columns=[in_type, out_type], header=False) query_qrel_doc.to_csv(tsv_path['test'], sep='\t', encoding='utf-8', index=False, columns=[in_type, out_type], header=False) - t5_model = settings['t5model'] # {"small", "base", "large", "3B", "11B"} cross {"local", "gc"} - t5_output = f'../output/{os.path.split(datapath)[-1]}/{ds.user_pairing}t5.{t5_model}.{in_type}.{out_type}.{index_item_str}' - if not os.path.isdir(t5_output): os.makedirs(t5_output) - copyfile('./param.py', f'{t5_output}/param.py') if {'finetune', 'predict'} & set(settings['cmd']): from mdl import mt5w if 'finetune' in settings['cmd']: @@ -53,7 +79,7 @@ def run(data_list, domain_list, output, settings): pretrained_dir=f'./../output/t5-data/pretrained_models/{t5_model.split(".")[0]}', # "gs://t5-data/pretrained_models/{"small", "base", "large", "3B", "11B"} steps=settings['iter'], output=t5_output, task_name=f"{domain.replace('-', '')}_cf", # :DD Task name must match regex: ^[\w\d\.\:_]+$ - lseq=settings[domain]['lseq'], + lseq=corpora[domain]['lseq'], nexamples=None, in_type=in_type, out_type=out_type, gcloud=False) if 'predict' in settings['cmd']: @@ -63,54 +89,68 @@ def run(data_list, domain_list, output, settings): split='test', tsv_path=tsv_path, output=t5_output, - lseq=settings[domain]['lseq'], + lseq=corpora[domain]['lseq'], gcloud=False) + if 'search' in settings['cmd']: # 'bm25 ranker' print(f"Searching documents for query changes using {settings['ranker']} ...") # seems for some queries there is no qrels, so they are missed for t5 prediction. # query_originals = pd.read_csv(f'{datapath}/queries.train.tsv', sep='\t', names=['qid', 'query'], dtype={'qid': str}) - # we use the file after panda.merge that create the training set so we make sure the mapping of qids - query_originals = pd.read_csv(f'{prep_output}/{ds.user_pairing}queries.qrels.doc{"s" if "docs" in {in_type, out_type} else ""}.ctx.{index_item_str}.train.tsv', sep='\t', usecols=['qid', 'query'], dtype={'qid': str}) - if settings['large_ds']: # we can run this logic if shape of query_originals is greater than split_size - import numpy as np + # we use the file after panda.merge that create the training set, so we make sure the mapping of qids + # query_originals = pd.read_csv(f'{prep_output}/{ds.user_pairing}queries.qrels.doc{"s" if "docs" in {in_type, out_type} else ""}.ctx.{index_item_str}.train.tsv', sep='\t', usecols=['qid', 'query'], dtype={'qid': str}) + + + + # we can run this logic if shape of queries is greater than split_size + if settings['large_ds']: import glob - split_size = 1000000 # need to make this dynamic based on shape of query_originals. - for _, chunk in query_originals.groupby(np.arange(query_originals.shape[0]) // split_size): - file_changes = [(file, f'{file}.{settings["ranker"]}') for file in - glob.glob(f'{t5_output}/**/pred.{_}*') if f'{file}.{settings["ranker"]}' not in glob.glob(f'{t5_output}/**')] - chunk.drop_duplicates(subset=['qid'], inplace=True) # in the event there are duplicates in query qrels. 
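# Illustrative aside, not part of the patch: the chunking pattern used in this large_ds branch.
# Unpacking "_, chunk" directly from a plain list of slices would generally fail (a slice is not
# a pair), so the loops here are presumably meant to enumerate the slices, which also yields the
# running index used in original.{_}.tsv; a minimal sketch with hypothetical data:
_toy_queries, _split = list(range(10)), 4
for _, chunk in enumerate(_toy_queries[i:i + _split] for i in range(0, len(_toy_queries), _split)):
    print(_, chunk)  # 0 [0, 1, 2, 3] / 1 [4, 5, 6, 7] / 2 [8, 9]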
- with mp.Pool(settings['ncore']) as p: - p.starmap(partial(ds.search, qids=chunk['qid'].values.tolist(), ranker=settings['ranker'], - topk=settings['topk'], batch=settings['batch'], ncores=settings['ncore'], - index=ds.searcher.index_dir), file_changes) - - print('for original queries') + original_dir = f'{t5_output}/original' + if not os.path.isdir(original_dir): os.makedirs(original_dir) + split_size = 1000000 # need to make this dynamic based on shape of queries. + for _, chunk in [ds.queries[i:i + split_size] for i in range(0, len(ds.queries), split_size)]: + # Generate original queries' files - keep records + original_file_i = f'{original_dir}/original.{_}.tsv' + pd.DataFrame({'query': [query.q for query in chunk]}).to_csv(original_file_i, sep='\t', index=False, header=False) + + file_changes = [(file, f'{file}.{settings["ranker"]}') for file in glob.glob(f'{t5_output}/**/pred.{_}*') if f'{file}.{settings["ranker"]}' not in glob.glob(f'{t5_output}/**')] + file_changes.extend([(f'{t5_output}/{f}', f'{t5_output}/{f}.{settings["ranker"]}') for f in os.listdir(t5_output) if os.path.isfile(os.path.join(t5_output, f)) and f.startswith('refiner.')]) + + with mp.Pool(settings['ncore']) as p: p.starmap(partial(ds.search, qids=[query.qid for query in chunk], ranker=settings['ranker'], topk=settings['topk'], batch=settings['batch'], ncores=settings['ncore'], index=ds.searcher.index_dir), file_changes) + + print('For original queries') file_changes = list() - for _, chunk in query_originals.groupby(np.arange(query_originals.shape[0]) // split_size): - chunk.drop_duplicates(subset=['qid'], inplace=True) - file_changes.append((f'{t5_output}/original/original.{_}.tsv', - f'{t5_output}/original/original.{_}.tsv.bm25', chunk['qid'].values.tolist())) - with mp.Pool(settings['ncore']) as p: - p.starmap(partial(ds.search, ranker=settings['ranker'], topk=settings['topk'], batch=settings['batch'], ncores=settings['ncore'], index=ds.searcher.index_dir), file_changes) - else: - query_changes = [(f'{t5_output}/{f}', f'{t5_output}/{f}.{settings["ranker"]}') for f in - listdir(t5_output) if - isfile(join(t5_output, f)) and f.startswith('pred.') and len(f.split('.')) == 2 and f'{f}.{settings["ranker"]}' not in listdir(t5_output)] + for _, chunk in [ds.queries[i:i + split_size] for i in range(0, len(ds.queries), split_size)]: file_changes.append((f'{t5_output}/original/original.{_}.tsv', f'{t5_output}/original/original.{_}.tsv.bm25', [query.qid for query in chunk])) + with mp.Pool(settings['ncore']) as p: p.starmap(partial(ds.search, ranker=settings['ranker'], topk=settings['topk'], batch=settings['batch'], ncores=settings['ncore'], index=ds.searcher.index_dir), file_changes) - # for (i, o) in query_changes: ds.search(i, o, query_originals['qid'].values.tolist(), settings['ranker'], topk=settings['topk'], batch=settings['batch']) + else: + # Here it considers generated queries from t5 or refiners and the original queries + query_changes = [ + (f'{t5_output}/{f}', f'{t5_output}/{f}.{settings["ranker"]}') + for f in listdir(t5_output) + if isfile(join(t5_output, f)) and ( + f.startswith('pred.') or f.startswith('refiner.') or f.startswith('original.') + ) and len(f.split('.')) == 2 and f'{f}.{settings["ranker"]}' not in listdir(t5_output) + ] + # query_changes = [] + # for f in listdir(t5_output): + # if isfile(join(t5_output, f)) and (f.startswith('pred.') or f.startswith('refiner.')) and len( + # f.split('.')) == 2 and f'{f}.{settings["ranker"]}' not in listdir(t5_output): + # 
query_changes.append((f'{t5_output}/{f}', f'{t5_output}/{f}.{settings["ranker"]}')) + # for (i, o) in query_changes: ds.search(i, o, query_originals['qid'].values.tolist(), settings['ranker'], topk=settings['topk'], batch=settings['batch']) # batch search: # for (i, o) in query_changes: ds.search(i, o, query_originals['qid'].values.tolist(), settings['ranker'], topk=settings['topk'], batch=settings['batch']) # seems the LuceneSearcher cannot be shared in multiple processes! See dal.ds.py - # parallel on each file ==> Problem: starmap does not understand inherited Dataset.searcher attribute! - with mp.Pool(settings['ncore']) as p: - p.starmap(partial(ds.search, qids=query_originals['qid'].values.tolist(), ranker=settings['ranker'], - topk=settings['topk'], batch=settings['batch'], ncores=settings['ncore'], - index=None), query_changes) + #TODO: parallel on each file ==> Problem: starmap does not understand inherited Dataset.searcher attribute! + user_pairing = "user/" if "user" in ds.settings["pairing"] else "" + index_item_str = '.'.join(settings["index_item"]) if ds.__class__.__name__ != 'MsMarcoPsg' else "" + with mp.Pool(settings['ncore']) as p: p.starmap(partial(ds.search, qids=[query.qid for query in ds.queries], ranker=settings['ranker'], topk=settings['topk'], batch=settings['batch'], ncores=settings['ncore'], index=f'{ds.settings["index"]}{user_pairing}{index_item_str}'), query_changes) + # we need to add the original queries as well - if not isfile(join(t5_output, f'original.{settings["ranker"]}')): - query_originals.to_csv(f'{t5_output}/original', columns=['query'], index=False, header=False) - ds.search_df(pd.DataFrame(query_originals['query']), f'{t5_output}/original.{settings["ranker"]}', query_originals['qid'].values.tolist(), settings['ranker'], topk=settings['topk'], batch=settings['batch'], ncores=settings['ncore']) + # original_path = f'{t5_output}/original.{settings["ranker"]}' + # if not isfile(original_path): + # pd.DataFrame({'query': [query.q for query in ds.queries]}).to_csv(original_path, sep='\t', index=False, header=False) + # ds.search_df(queries=pd.DataFrame([query.q for query in ds.queries]), out_docids=original_path, qids=[query.qid for query in ds.queries], ranker=settings['ranker'], topk=settings['topk'], batch=settings['batch'], ncores=settings['ncore']) if 'eval' in settings['cmd']: from evl import trecw @@ -158,19 +198,22 @@ def run(data_list, domain_list, output, settings): print('all files are merged and ready for aggregation') else: search_results = [(f'{t5_output}/{f}', f'{t5_output}/{f}.{settings["metric"]}') for f in listdir(t5_output) if f.endswith(settings["ranker"]) and f'{f}.{settings["ranker"]}.{settings["metric"]}' not in listdir(t5_output)] - if not isfile(f'{datapath}/{ds.user_pairing}qrels.train.tsv_'): - qrels = pd.read_csv(f'{datapath}/{ds.user_pairing}qrels.train.tsv', sep='\t', index_col=False, names=['qid', 'did', 'pid', 'relevancy'], header=None) - qrels.drop_duplicates(subset=['qid', 'pid'], inplace=True) # qrels have duplicates!! - qrels.to_csv(f'{datapath}/qrels.train.tsv_', index=False, sep='\t', header=False) # trec_eval.9.0.4 does not accept duplicate rows!! + # This snipet code is added in the read_quereis function! + # if not isfile(f'{datapath}/{ds.user_pairing}qrels.train.tsv_'): + # qrels = pd.read_csv(f'{datapath}/{ds.user_pairing}qrels.train.tsv', sep='\t', index_col=False, names=['qid', 'did', 'pid', 'relevancy'], header=None) + # qrels.drop_duplicates(subset=['qid', 'pid'], inplace=True) # qrels have duplicates!! 
+ # qrels.to_csv(f'{datapath}/qrels.train.tsv_', index=False, sep='\t', header=False) # trec_eval.9.0.4 does not accept duplicate rows!! # for (i, o) in search_results: trecw.evaluate(i, o, qrels=f'{datapath}/qrels.train.tsv_', metric=settings['metric'], lib=settings['treclib']) - with mp.Pool(settings['ncore']) as p: - p.starmap(partial(trecw.evaluate, qrels=f'{datapath}/{ds.user_pairing}qrels.train.tsv_', metric=settings['metric'], lib=settings['treclib'], mean=not settings['large_ds']), search_results) + with mp.Pool(settings['ncore']) as p: p.starmap(partial(trecw.evaluate, qrels=f'{datapath}/{ds.user_pairing}qrels.train.tsv_', metric=settings['metric'], lib=settings['treclib'], mean=not settings['large_ds']), search_results) + if 'agg' in settings['cmd']: - originals = pd.read_csv(f'{prep_output}/queries.qrels.doc{"s" if "docs" in {in_type, out_type} else ""}.ctx.{index_item_str}.train.tsv', sep='\t', usecols=['qid', 'query'], dtype={'qid': str}) - original_metric_values = pd.read_csv(join(t5_output, f'original.{settings["ranker"]}.{settings["metric"]}'), sep='\t', usecols=[1, 2], names=['qid', f'original.{settings["ranker"]}.{settings["metric"]}'], index_col=False, dtype={'qid': str}) + # originals = pd.read_csv(f'{prep_output}/queries.qrels.doc{"s" if "docs" in {in_type, out_type} else ""}.ctx.{index_item_str}.train.tsv', sep='\t', usecols=['qid', 'query'], dtype={'qid': str}) + originals = pd.DataFrame({'qid': [str(query.qid) for query in ds.queries], 'query': [query.q for query in ds.queries]}) + original_metric_values = pd.read_csv(join(t5_output, f'original.queries.{settings["ranker"]}.{settings["metric"]}'), sep='\t', usecols=[1, 2], names=['qid', f'original.queries.{settings["ranker"]}.{settings["metric"]}'], index_col=False, dtype={'qid': str}) + originals = originals.merge(original_metric_values, how='left', on='qid') - originals[f'original.{settings["ranker"]}.{settings["metric"]}'].fillna(0, inplace=True) + originals[f'original.queries.{settings["ranker"]}.{settings["metric"]}'].fillna(0, inplace=True) changes = [('.'.join(f.split('.')[0:2]), f) for f in os.listdir(t5_output) if f.endswith(f'{settings["ranker"]}.{settings["metric"]}') and 'original' not in f] ds.aggregate(originals, changes, t5_output, settings["large_ds"]) @@ -179,21 +222,22 @@ def run(data_list, domain_list, output, settings): box_path = join(t5_output, f'{settings["ranker"]}.{settings["metric"]}.boxes') if not os.path.isdir(box_path): os.makedirs(box_path) gold_df = pd.read_csv(f'{t5_output}/{settings["ranker"]}.{settings["metric"]}.agg.all.tsv', sep='\t', header=0, dtype={'qid': str}) - qrels = pd.read_csv(f'{datapath}/qrels.train.tsv_', names=['qid', 'did', 'pid', 'rel'], sep='\t', dtype={'qid': str}) + qrels = pd.DataFrame([query.docs for query in ds.queries]) box_condition = settings['box'] ds.box(gold_df, qrels, box_path, box_condition) for c in box_condition.keys(): - print(f'Stamping boxes for {settings["ranker"]}.{settings["metric"]} before and after refinements ...') - + print(f'{c}: Stamping boxes for {settings["ranker"]}.{settings["metric"]} before and after refinements ...') if not os.path.isdir(join(box_path, 'stamps')): os.makedirs(join(box_path, 'stamps')) - df = pd.read_csv(f'{box_path}/{c}.tsv', sep='\t', encoding='utf-8', index_col=False, header=None, names=['qid', 'query', 'metric', 'query_', 'metric_'], dtype={'qid': str}) df.drop_duplicates(subset=['qid'], inplace=True) # See ds.boxing(): in case we store more than two changes with the same metric value - 
ds.search_df(df['query'].to_frame(), f'{box_path}/stamps/{c}.original.{settings["ranker"]}', df['qid'].values.tolist(), settings['ranker'], topk=settings['topk'], batch=settings['batch'], ncores=settings['ncore']) - trecw.evaluate(f'{box_path}/stamps/{c}.original.{settings["ranker"]}', f'{box_path}/stamps/{c}.original.{settings["ranker"]}.{settings["metric"]}', qrels=f'{datapath}/qrels.train.tsv_', metric=settings['metric'], lib=settings['treclib'], mean=True) - ds.search_df(df['query_'].to_frame().rename(columns={'query_': 'query'}), f'{box_path}/stamps/{c}.change.{settings["ranker"]}', df['qid'].values.tolist(), settings['ranker'], topk=settings['topk'], batch=settings['batch'], ncores=settings['ncore']) - trecw.evaluate(f'{box_path}/stamps/{c}.change.{settings["ranker"]}', f'{box_path}/stamps/{c}.change.{settings["ranker"]}.{settings["metric"]}', qrels=f'{datapath}/qrels.train.tsv_', metric=settings['metric'], lib=settings['treclib'], mean=True) + if df['query'].to_frame().empty: print(f'No queries for {c}') + else: + ds.search_df(df['query'].to_frame(), f'{box_path}/stamps/{c}.original.{settings["ranker"]}', df['qid'].values.tolist(), settings['ranker'], topk=settings['topk'], batch=settings['batch'], ncores=settings['ncore']) + trecw.evaluate(f'{box_path}/stamps/{c}.original.{settings["ranker"]}', f'{box_path}/stamps/{c}.original.{settings["ranker"]}.{settings["metric"]}', qrels=f'{datapath}/qrels.train.tsv_', metric=settings['metric'], lib=settings['treclib'], mean=True) + ds.search_df(df['query_'].to_frame().rename(columns={'query_': 'query'}), f'{box_path}/stamps/{c}.change.{settings["ranker"]}', df['qid'].values.tolist(), settings['ranker'], topk=settings['topk'], batch=settings['batch'], ncores=settings['ncore']) + trecw.evaluate(f'{box_path}/stamps/{c}.change.{settings["ranker"]}', f'{box_path}/stamps/{c}.change.{settings["ranker"]}.{settings["metric"]}', qrels=f'{datapath}/qrels.train.tsv_', metric=settings['metric'], lib=settings['treclib'], mean=True) + if 'dense_retrieve' in settings['cmd']: from evl import trecw from tqdm import tqdm @@ -226,7 +270,7 @@ def run(data_list, domain_list, output, settings): search_results = [(f'{t5_output}/original.{condition}.tct_colbert', f'{t5_output}/original.{condition}.tct_colbert.{metric}'), (f'{t5_output}/pred.{condition}.tct_colbert', f'{t5_output}/pred.{condition}.tct_colbert.{metric}')] with mp.Pool(settings['ncore']) as p: - p.starmap(partial(ds.search_df, qids=original['qid'].values.tolist(), ranker='tct_colbert', topk=100, batch=None, + p.starmap(partial(ds.search_list, qids=original['qid'].values.tolist(), ranker='tct_colbert', topk=100, batch=None, ncores=settings['ncore'], index=settings[f'{domain}']["dense_index"], encoder=settings[f'{domain}']['dense_encoder']), search_list) p.starmap(partial(trecw.evaluate, qrels=f'{datapath}/qrels.train.tsv_', metric=settings['metric'], lib=settings['treclib']), search_results) @@ -251,8 +295,7 @@ def run(data_list, domain_list, output, settings): agg_df['pred_sparse'] = original[f'{ranker}.{metric}_'] agg_df.to_csv(f'{t5_output}/colbert.comparison.{condition}.{metric}.tsv', sep="\t", index=None) - if 'stats' in settings['cmd']: - from stats import stats + if 'stats' in settings['cmd']: from stats import stats def addargs(parser): @@ -278,6 +321,7 @@ def addargs(parser): run(data_list=args.data_list, domain_list=args.domain_list, output=args.output, + corpora=param.corpora, settings=param.settings) # after finetuning and predict, we can benchmark on rankers and metrics diff --git 
a/src/mdl/mt5w.py b/src/mdl/mt5w.py index ff2575d..07e1198 100644 --- a/src/mdl/mt5w.py +++ b/src/mdl/mt5w.py @@ -1,6 +1,6 @@ import functools, os, sys, time import tensorflow.compat.v1 as tf -import tensorflow_datasets as tfds +# import tensorflow_datasets as tfds import t5 import t5.models @@ -95,7 +95,7 @@ def predict(iter, split, tsv_path, output, lseq, vocab_model_path='./../output/t # def predict(iter, split, tsv_path, pretrained_dir, steps, output, lseq, task_name, nexamples=None, in_type='query', out_type='doc', vocab_model_path='./../output/t5-data/vocabs/cc_en.32000/sentencepiece.model', gcloud=False): if gcloud: import gcloud - model_parallelism, train_batch_size, keep_checkpoint_max = {"small": (1, 256, 16), "base": (2, 128, 8), "large": (8, 64, 4), "3B": (8, 16, 1), "11B": (8, 16, 1)}[output.split('.')[-4]] + model_parallelism, train_batch_size, keep_checkpoint_max = {"small": (1, 256, 16), "base": (2, 128, 8), "large": (8, 64, 4), "3B": (8, 16, 1), "11B": (8, 16, 1)}[output.split('.')[-5]] model = t5.models.MtfModel( model_dir=output.replace('/', os.path.sep), tpu=gcloud.TPU_ADDRESS if gcloud else None, diff --git a/src/param.py b/src/param.py index 12b033b..c3108ff 100644 --- a/src/param.py +++ b/src/param.py @@ -9,25 +9,30 @@ os.environ['CUDA_VISIBLE_DEVICES'] = '-1' settings = { - 'cmd': ['stats'], # steps of pipeline, ['pair', 'finetune', 'predict', 'search', 'eval','agg', 'box','dense_retrieve'] + 'query_refinement': True, + 'cmd': ['pair', 'finetune', 'predict', 'search', 'eval','agg', 'box'], # steps of pipeline, ['pair', 'finetune', 'predict', 'search', 'eval','agg', 'box','dense_retrieve', 'stats] 'ncore': 2, - 't5model': 'base.gc', # 'base.gc' on google cloud tpu, 'small.local' on local machine + 't5model': 'small.local', # 'base.gc' on google cloud tpu, 'small.local' on local machine 'iter': 5, # number of finetuning iteration for t5 'nchanges': 5, # number of changes to a query 'ranker': 'bm25', # 'qld', 'bm25', 'tct_colbert' 'batch': None, # search per batch of queries for IR search using pyserini, if None, search per query 'topk': 100, # number of retrieved documents for a query - 'metric': 'recip_rank.10', # any valid trec_eval.9.0.4 metric like map, ndcg, recip_rank, ... - 'large_ds': True, + 'metric': 'map', # any valid trec_eval.9.0.4 metric like map, ndcg, recip_rank, ... + 'large_ds': False, 'treclib': f'"./trec_eval.9.0.4/trec_eval{extension}"', # in non-windows, remove .exe, also for pytrec_eval, 'pytrec' 'box': {'gold': 'refined_q_metric >= original_q_metric and refined_q_metric > 0', 'platinum': 'refined_q_metric > original_q_metric', - 'diamond': 'refined_q_metric > original_q_metric and refined_q_metric == 1'}, + 'diamond': 'refined_q_metric > original_q_metric and refined_q_metric == 1'} +} + +corpora = { 'msmarco.passage': { 'index_item': ['passage'], 'index': '../data/raw/msmarco.passage/lucene-index.msmarco-v1-passage.20220131.9ea315/', 'dense_encoder': 'castorini/tct_colbert-msmarco', 'dense_index': 'msmarco-passage-tct_colbert-hnsw', + 'extcorpus': 'orcas', 'pairing': [None, 'docs', 'query'], # [context={msmarco does not have userinfo}, input={query, doc, doc(s)}, output={query, doc, doc(s)}], s means concat of docs 'lseq': {"inputs": 32, "targets": 256}, # query length and doc length for t5 model, }, @@ -39,5 +44,115 @@ 'pairing': [None, 'docs', 'query'], # [context={2 scenarios, one with userID and one without userID). 
input={'userid','query','doc(s)'} output={'query','doc(s)'} 'lseq': {"inputs": 32, "targets": 256}, # query length and doc length for t5 model, 'filter': {'minql': 1, 'mindocl': 10} # [min query length, min doc length], after merge queries with relevant 'index_item', if |query| <= minql drop the row, if |'index_item'| < mindocl, drop row - } + }, + 'robust04': { + 'index': '../data/raw/robust04/lucene-index.robust04.pos+docvectors+rawdocs', + 'dense_index': '../data/raw/robust04/faiss_index_robust04', + 'encoded': '../data/raw/robust04/encoded_robust04', + 'size': 528155, + 'topics': '../data/raw/robust04/topics.robust04.txt', + 'prels': '', # this will be generated after a retrieval {bm25, qld} + 'w_t': 2.25, # OnFields + 'w_a': 1, # OnFields + 'tokens': 148000000, + 'qrels': '../data/raw/robust04/qrels.robust04.txt', + 'extcorpus': 'gov2', # AdaptOnFields + 'pairing': [None, None, None], + 'index_item': [], + }, + 'gov2': { + 'index': '../data/raw/gov2/lucene-index.gov2.pos+docvectors+rawdocs', + 'size': 25000000, + 'topics': '../data/raw/gov2/topics.terabyte0{}.txt', # {} is a placeholder for subtopics in main.py -> run() + 'trec': ['4.701-750', '5.751-800', '6.801-850'], + 'prels': '', # this will be generated after a retrieval {bm25, qld} + 'w_t': 4, # OnFields + 'w_a': 0.25, # OnFields + 'tokens': 17000000000, + 'qrels': '../data/raw/gov2/qrels.terabyte0{}.txt', # {} is a placeholder for subtopics in main.py -> run() + 'extcorpus': 'robust04', # AdaptOnFields + 'pairing': [None, None, None], + 'index_item': [], + }, + 'clueweb09b': { + 'index': '../data/raw/clueweb09b/lucene-index.cw09b.pos+docvectors+rawdocs', + 'size': 50000000, + 'topics': '../data/raw/clueweb09b/topics.web.{}.txt', # {} is a placeholder for subtopics in main.py -> run() + 'trec': ['1-50', '51-100', '101-150', '151-200'], + 'prels': '', # this will be generated after a retrieval {bm25, qld} + 'w_t': 1, # OnFields + 'w_a': 0, # OnFields + 'tokens': 31000000000, + 'qrels': '../data/raw/clueweb09b/qrels.web.{}.txt', # {} is a placeholder for subtopics in main.py -> run() + 'extcorpus': 'gov2', # AdaptOnFields + 'pairing': [None, None, None], + 'index_item': [], + }, + 'clueweb12b13': { + 'index': '../data/raw/clueweb12b13/lucene-index.cw12b13.pos+docvectors+rawdocs', + 'size': 50000000, + 'topics': '../data/raw/clueweb12b13/topics.web.{}.txt', # {} is a placeholder for subtopics in main.py -> run() + 'trec': ['201-250', '251-300'], + 'prels': '', # this will be generated after a retrieval {bm25, qld} + 'w_t': 4, # OnFields + 'w_a': 0, # OnFields + 'tokens': 31000000000, + 'qrels': '../data/raw/clueweb12b13/qrels.web.{}.txt', # {} is a placeholder for subtopics in main.py -> run() + 'extcorpus': 'gov2', # AdaptOnFields + 'pairing': [None, None, None], + 'index_item': [], + }, + 'antique': { + 'index': '../data/raw/antique/lucene-index-antique', + 'size': 403000, + 'topics': '../data/raw/antique/topics.antique.txt', + 'prels': '', # this will be generated after a retrieval {bm25, qld} + 'w_t': 2.25, # OnFields # to be tuned + 'w_a': 1, # OnFields # to be tuned + 'tokens': 16000000, + 'qrels': '../ds/antique/qrels.antique.txt', + 'extcorpus': 'gov2', # AdaptOnFields + }, + 'trec09mq': { + 'index': 'D:\clueweb09b\lucene-index.cw09b.pos+docvectors+rawdocs', + 'size': 50000000, + # 'topics': '../ds/trec2009mq/prep/09.mq.topics.20001-60000.prep.tsv', + 'topics': '../ds/trec09mq/09.mq.topics.20001-60000.prep', + 'prels': '', # this will be generated after a retrieval {bm25, qld} + 'w_t': 2.25, # OnFields # to be tuned + 
'w_a': 1, # OnFields # to be tuned + 'tokens': 16000000, + 'qrels': '../ds/trec09mq/prels.20001-60000.prep', + 'extcorpus': 'gov2', # AdaptOnFields + }, + 'dbpedia': { + 'index': '../ds/dbpedia/lucene-index-dbpedia-collection', + 'size': 4632359, + 'topics': '../ds/dbpedia/topics.dbpedia.txt', + 'prels': '', # this will be generated after a retrieval {bm25, qld} + 'w_t': 1, # OnFields # to be tuned + 'w_a': 1, # OnFields # to be tuned + 'tokens': 200000000, + 'qrels': '../ds/dbpedia/qrels.dbpedia.txt', + 'extcorpus': 'gov2', # AdaptOnFields + }, + 'orcas': { + 'index': '../ds/orcas/lucene-index.msmarco-v1-doc.20220131.9ea315', + 'size': 50000000, + # 'topics': '../ds/trec2009mq/prep/09.mq.topics.20001-60000.prep.tsv', + 'topics': '../ds/orcas/preprocess/orcas-I-2M_topics.prep', + 'prels': '', # this will be generated after a retrieval {bm25, qld} + 'w_t': 2.25, # OnFields # to be tuned + 'w_a': 1, # OnFields # to be tuned + 'tokens': 16000000, + 'qrels': '../ds/orcas/preprocess/orcas-doctrain-qrels.prep', + 'extcorpus': 'gov2', # AdaptOnFields + }, } + +# Only for sparse indexing +anserini = { + 'path': '../anserini/', + 'trec_eval': '../anserini/eval/trec_eval.9.0.4/trec_eval' +} + diff --git a/src/refinement/refiner_factory.py b/src/refinement/refiner_factory.py new file mode 100644 index 0000000..7db2951 --- /dev/null +++ b/src/refinement/refiner_factory.py @@ -0,0 +1,55 @@ +from src.refinement.refiners.abstractqrefiner import AbstractQRefiner +from src.refinement.refiners.stem import Stem # Stem refiner is the wrapper for all stemmers as an refiner :) +from src.refinement import refiner_param +from src.refinement import utils + +#global analysis +def get_nrf_refiner(): + refiners_list = [AbstractQRefiner()] + if refiner_param.refiners['Thesaurus']: from src.refinement.refiners.thesaurus import Thesaurus; refiners_list.append(Thesaurus()) + if refiner_param.refiners['Thesaurus']: from src.refinement.refiners.thesaurus import Thesaurus; refiners_list.append(Thesaurus(replace=True)) + if refiner_param.refiners['Wordnet']: from src.refinement.refiners.wordnet import Wordnet; refiners_list.append(Wordnet()) + if refiner_param.refiners['Wordnet']: from src.refinement.refiners.wordnet import Wordnet; refiners_list.append(Wordnet(replace=True)) + if refiner_param.refiners['Word2Vec']: from src.refinement.refiners.word2vec import Word2Vec; refiners_list.append(Word2Vec('../pre/wiki-news-300d-1M.vec')) + if refiner_param.refiners['Word2Vec']: from src.refinement.refiners.word2vec import Word2Vec; refiners_list.append(Word2Vec('../pre/wiki-news-300d-1M.vec', replace=True)) + if refiner_param.refiners['Glove']: from src.refinement.refiners.glove import Glove; refiners_list.append(Glove('../pre/glove.6B.300d')) + if refiner_param.refiners['Glove']: from src.refinement.refiners.glove import Glove; refiners_list.append(Glove('../pre/glove.6B.300d', replace=True)) + if refiner_param.refiners['Anchor']: from src.refinement.refiners.anchor import Anchor; refiners_list.append(Anchor(anchorfile='../pre/anchor_text_en.ttl', vectorfile='../pre/wiki-anchor-text-en-ttl-300d.vec')) + if refiner_param.refiners['Anchor']: from src.refinement.refiners.anchor import Anchor; refiners_list.append(Anchor(anchorfile='../pre/anchor_text_en.ttl', vectorfile='../pre/wiki-anchor-text-en-ttl-300d.vec', replace=True)) + if refiner_param.refiners['Wiki']: from src.refinement.refiners.wiki import Wiki; refiners_list.append(Wiki('../pre/temp_model_Wiki')) + if refiner_param.refiners['Wiki']: from src.refinement.refiners.wiki 
import Wiki; refiners_list.append(Wiki('../pre/temp_model_Wiki', replace=True)) + if refiner_param.refiners['Tagmee']: from src.refinement.refiners.tagmee import Tagmee; refiners_list.append(Tagmee()) + if refiner_param.refiners['Tagmee']: from src.refinement.refiners.tagmee import Tagmee; refiners_list.append(Tagmee(replace=True)) + if refiner_param.refiners['SenseDisambiguation']: from src.refinement.refiners.sensedisambiguation import SenseDisambiguation; refiners_list.append(SenseDisambiguation()) + if refiner_param.refiners['SenseDisambiguation']: from src.refinement.refiners.sensedisambiguation import SenseDisambiguation; refiners_list.append(SenseDisambiguation(replace=True)) + if refiner_param.refiners['Conceptnet']: from src.refinement.refiners.conceptnet import Conceptnet; refiners_list.append(Conceptnet()) + if refiner_param.refiners['Conceptnet']: from src.refinement.refiners.conceptnet import Conceptnet; refiners_list.append(Conceptnet(replace=True)) + if refiner_param.refiners['KrovetzStemmer']: from stemmers.krovetz import KrovetzStemmer; refiners_list.append(Stem(KrovetzStemmer(jarfile='stemmers/kstem-3.4.jar'))) + if refiner_param.refiners['LovinsStemmer']: from stemmers.lovins import LovinsStemmer; refiners_list.append(Stem(LovinsStemmer())) + if refiner_param.refiners['PaiceHuskStemmer']: from stemmers.paicehusk import PaiceHuskStemmer; refiners_list.append(Stem(PaiceHuskStemmer())) + if refiner_param.refiners['PorterStemmer']: from stemmers.porter import PorterStemmer; refiners_list.append(Stem(PorterStemmer())) + if refiner_param.refiners['Porter2Stemmer']: from stemmers.porter2 import Porter2Stemmer; refiners_list.append(Stem(Porter2Stemmer())) + if refiner_param.refiners['SRemovalStemmer']: from stemmers.sstemmer import SRemovalStemmer; refiners_list.append(Stem(SRemovalStemmer())) + if refiner_param.refiners['Trunc4Stemmer']: from stemmers.trunc4 import Trunc4Stemmer; refiners_list.append(Stem(Trunc4Stemmer())) + if refiner_param.refiners['Trunc5Stemmer']: from stemmers.trunc5 import Trunc5Stemmer; refiners_list.append(Stem(Trunc5Stemmer())) + if refiner_param.refiners['BackTranslation']: from src.refinement.refiners.backtranslation import BackTranslation; refiners_list.extend([BackTranslation(each_lng) for index, each_lng in enumerate(refiner_param.backtranslation['tgt_lng'])]) + # since RF needs index and search output which depends on ir method and topics corpora, we cannot add this here. 
Instead, we run it individually + # RF assumes that there exist abstractqueryexpansion files + + return refiners_list + +#local analysis +def get_rf_refiner(rankers, corpus, output, ext_corpus=None): + refiners_list = [] + for ranker in rankers: + ranker_name = utils.get_ranker_name(ranker) + if refiner_param.refiners['RM3']: from src.refinement.refiners.rm3 import RM3; refiners_list.append(RM3(ranker=ranker_name, index=corpus['index'])) + if refiner_param.refiners['RelevanceFeedback']: from src.refinement.refiners.relevancefeedback import RelevanceFeedback; refiners_list.append(RelevanceFeedback(ranker=ranker_name, prels='{}.abstractqueryexpansion.{}.txt'.format(output, ranker_name), anserini=refiner_param.anserini['path'], index=corpus['index'])) + if refiner_param.refiners['Docluster']: from src.refinement.refiners.docluster import Docluster; refiners_list.append(Docluster(ranker=ranker_name, prels='{}.abstractqueryexpansion.{}.txt'.format(output, ranker_name), anserini=refiner_param.anserini['path'], index=corpus['index'])), + if refiner_param.refiners['Termluster']: from src.refinement.refiners.termluster import Termluster; refiners_list.append(Termluster(ranker=ranker_name, prels='{}.abstractqueryexpansion.{}.txt'.format(output, ranker_name), anserini=refiner_param.anserini['path'], index=corpus['index'])) + if refiner_param.refiners['Conceptluster']: from src.refinement.refiners.conceptluster import Conceptluster; refiners_list.append(Conceptluster(ranker=ranker_name, prels='{}.abstractqueryexpansion.{}.txt'.format(output, ranker_name), anserini=refiner_param.anserini['path'], index=corpus['index'])) + if refiner_param.refiners['BertQE']: from src.refinement.refiners.bertqe import BertQE; refiners_list.append(BertQE(ranker=ranker_name, prels='{}.abstractqueryexpansion.{}.txt'.format(output, ranker_name), index=corpus['index'], anserini=refiner_param.anserini['path'])) + if refiner_param.refiners['OnFields']: from src.refinement.refiners.onfields import OnFields; refiners_list.append(OnFields(ranker=ranker_name, prels='{}.abstractqueryexpansion.{}.txt'.format(output, ranker_name), anserini=refiner_param.anserini['path'], index=refiner_param.corpora[corpus]['index'], w_t=corpus['w_t'], w_a=corpus['w_a'], corpus_size=corpus['size'])) + if refiner_param.refiners['AdapOnFields']: from src.refinement.refiners.adaponfields import AdapOnFields; refiners_list.append(AdapOnFields(ranker=ranker_name, prels='{}.abstractqueryexpansion.{}.txt'.format(output, ranker_name), anserini=refiner_param.anserini['path'], index=corpus['index'], w_t=corpus['w_t'], w_a=corpus['w_a'], corpus_size=corpus['size'], collection_tokens=corpus['tokens'], ext_corpus=ext_corpus, ext_index=ext_corpus['index'], ext_collection_tokens=ext_corpus['tokens'], ext_w_t=ext_corpus['w_t'], ext_w_a=ext_corpus['w_a'], ext_corpus_size=ext_corpus['size'], adap=True)) + + return refiners_list diff --git a/src/refinement/refiner_param.py b/src/refinement/refiner_param.py new file mode 100644 index 0000000..63b6296 --- /dev/null +++ b/src/refinement/refiner_param.py @@ -0,0 +1,50 @@ +import sys, platform + +extension = '.exe' if platform.system() == 'Windows' else "" + +settings = { + 'transformer_model': 'johngiorgi/declutr-small', +} + +refiners = { + 'SenseDisambiguation': 0, + 'Thesaurus': 0, + 'Wordnet': 0, + 'Conceptnet': 0, + 'Tagmee': 0, + + 'Word2Vec': 0, + 'Glove': 0, + 'Anchor': 0, + 'Wiki': 0, + + 'KrovetzStemmer': 0, + 'LovinsStemmer': 0, + 'PaiceHuskStemmer': 0, + 'PorterStemmer': 0, + 'Porter2Stemmer': 0, + 
'SRemovalStemmer': 0, + 'Trunc4Stemmer': 0, + 'Trunc5Stemmer': 0, + + 'RelevanceFeedback': 0, + 'Docluster': 0, + 'Termluster': 0, + 'Conceptluster': 0, + 'OnFields': 0, # make sure that the index for 'extcorpus' is available + 'AdapOnFields': 0, # make sure that the index for 'extcorpus' is available + 'BertQE': 0, + 'RM3': 0, + + 'BackTranslation': 1, + } + +# Backtranslation settings +backtranslation = { + 'src_lng': 'eng_Latn', + 'tgt_lng': ['fra_Latn'], # ['yue_Hant', 'kor_Hang', 'arb_Arab', 'pes_Arab', 'fra_Latn', 'deu_Latn', 'rus_Cyrl', 'zsm_Latn', 'tam_Taml', 'swh_Latn'] + 'max_length': 512, + 'device': 'cpu', + 'model_card': 'facebook/nllb-200-distilled-600M', + 'transformer_model': 'johngiorgi/declutr-small', +} diff --git a/src/refinement/refiners/__init__.py b/src/refinement/refiners/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/refinement/refiners/abstractqrefiner.py b/src/refinement/refiners/abstractqrefiner.py new file mode 100644 index 0000000..6d95db5 --- /dev/null +++ b/src/refinement/refiners/abstractqrefiner.py @@ -0,0 +1,74 @@ +import traceback +from src.dal.query import Query +from sentence_transformers import SentenceTransformer +from scipy.spatial.distance import cosine +import sys +sys.path.extend(['../src']) +from src.refinement import utils +from src.refinement.refiner_param import settings + + +class AbstractQRefiner: + def __init__(self, replace=False, topn=None): + self.transformer_model = SentenceTransformer(settings['transformer_model']) + self.replace = replace + self.topn = topn + + # All children expanders must call this in the returning line + def get_refined_query(self, q: Query, args=None): return q.q + + def get_refined_query_batch(self, queries, args=None): return queries, [1] * len(queries) + + def get_model_name(self): + # this is for backward compatibility for renaming this class + if self.__class__.__name__ == 'AbstractQRefiner': return 'original.queries'.lower() + return f"{self.__class__.__name__.lower()}{f'.topn{self.topn}' if self.topn else ''}{'.replace' if self.replace else ''}" + + def preprocess_query(self, query, clean=True): + ansi_reset = "\033[0m" + try: + q_ = self.get_refined_query(query) + q_ = utils.clean(q_) if clean else q_ + semsim = self.get_semsim(query, q_) + print(f'{utils.hex_to_ansi("#F1C40F")}Info: {utils.hex_to_ansi("#3498DB")}({self.get_model_name()}){ansi_reset} {query.qid}: {query.q} -> {utils.hex_to_ansi("#52BE80")}{q_}{ansi_reset}') + except Exception as e: + print(f'{utils.hex_to_ansi("#E74C3C")}WARNING: {utils.hex_to_ansi("#3498DB")}({self.get_model_name()}){ansi_reset} Refining query [{query.qid}:{query.q}] failed!') + print(traceback.format_exc()) + q_, semsim = query.q, 1 + + query.q_[self.get_model_name()] = (q_, semsim) + return query + + def preprocess_query_batch(self, queries, clean=True): + q_s, semsims = self.get_refined_query_batch(queries) + for q_, semsim, query in zip(q_s, semsims, queries): + if q_: + q_ = [utils.clean(q_) if clean else q_] + semsim = self.get_semsim(query, q_) + print(f'INFO: MAIN: {self.get_model_name()}: {query.qid}: {query.q} -> {q_}') + else: + print(f'WARNING: MAIN: {self.get_model_name()}: Refining query [{query.qid}:{query.q}] failed!') + print(traceback.format_exc()) + q_, semsim = query.q, 1 + + query.q_[self.get_model_name()] = (q_, semsim) + return queries + + ''' + Calculates the difference between the original and back-translated query + ''' + def get_semsim(self, q1, q2): + me, you = self.transformer_model.encode([q1, q2]) + return 1 - 
cosine(me, you) + + def write_queries(self, queries, outfile): + with open(outfile, 'w', encoding='utf-8') as file: + # file.write(f"qid\tq\tq_\tsemsim\n") + for query in queries: + file.write(f"{query.qid}\t{query.q}\t{query.q_[self.get_model_name()][0]}\t{query.q_[self.get_model_name()][1]}\n") + + +if __name__ == "__main__": + qe = AbstractQRefiner() + print(qe.get_model_name()) + print(qe.get_refined_query('International Crime Organization')) diff --git a/src/refinement/refiners/adaponfields.py b/src/refinement/refiners/adaponfields.py new file mode 100644 index 0000000..2abaf23 --- /dev/null +++ b/src/refinement/refiners/adaponfields.py @@ -0,0 +1,163 @@ +import sys +sys.path.extend(['../refinement']) + +import traceback, os, subprocess, nltk, string, math +from bs4 import BeautifulSoup +from nltk.tokenize import word_tokenize +from nltk.stem import PorterStemmer +from collections import Counter +from nltk.corpus import stopwords +from pyserini import analysis, index +import pyserini +from pyserini.search import SimpleSearcher +from pyserini import analysis, index + +from refiners.onfields import OnFields +import utils + +# @article{DBLP:journals/ipm/HeO07, +# author = {Ben He and +# Iadh Ounis}, +# title = {Combining fields for query expansion and adaptive query expansion}, +# journal = {Inf. Process. Manag.}, +# volume = {43}, +# number = {5}, +# pages = {1294--1307}, +# year = {2007}, +# url = {https://doi.org/10.1016/j.ipm.2006.11.002}, +# doi = {10.1016/j.ipm.2006.11.002}, +# timestamp = {Fri, 21 Feb 2020 13:11:30 +0100}, +# biburl = {https://dblp.org/rec/journals/ipm/HeO07.bib}, +# bibsource = {dblp computer science bibliography, https://dblp.org} +# } + +class AdapOnFields(OnFields): + + def __init__(self, ranker, prels, anserini, index, w_t, w_a,corpus_size, collection_tokens, + ext_index, ext_corpus, ext_collection_tokens, ext_w_t, ext_w_a, ext_corpus_size, + replace=False, topn=3, topw=10, adap=False): + OnFields.__init__(self, ranker, prels, anserini, index, w_t, w_a,corpus_size, topn=topn, replace=replace, topw=topw, adap=adap) + + self.collection_tokens = collection_tokens # number of tokens in the collection + + self.ext_index=ext_index + self.ext_corpus=ext_corpus + self.ext_collection_tokens=ext_collection_tokens # number of tokens in the external collection + self.ext_w_t=ext_w_t + self.ext_w_a=ext_w_a + self.ext_corpus_size=ext_corpus_size + + + def get_refined_query(self, q, args): + qid=args[0] + Preferred_expansion=self.avICTF(q) + if Preferred_expansion =="NoExpansionPreferred": + output_weighted_q_dic={} + for terms in q.split(): + output_weighted_q_dic[ps.stem(terms)]=2 + return super().get_refined_query(output_weighted_q_dic) + + elif Preferred_expansion =="InternalExpansionPreferred": + return super().get_refined_query(q, [qid]) + + elif Preferred_expansion =="ExternalExpansionPreferred": + self.adap = True + self.prels = None#when adap is True, no need for prels since it does the retrieval again! 
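+            # Adaptive step (He & Ounis, 2007): the external collection was judged more
+            # informative (higher avICTF), so swap every corpus statistic to the external
+            # one before delegating to OnFields, which re-retrieves (adap=True) instead of
+            # reading the prels run file.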
+ self.index = self.ext_index + self.corpus = self.ext_corpus + self.w_t = self.ext_w_t + self.w_a = self.ext_w_a + self.corpus_size = self.ext_corpus_size + + return super().get_refined_query(q, [qid]) + + def get_model_name(self): + return super().get_model_name().replace('topn{}'.format(self.topn), + 'topn{}.ex{}.{}.{}'.format(self.topn,self.ext_corpus, self.ext_w_t, self.ext_w_a)) + + def write_expanded_queries(self, Qfilename, Q_filename,clean=False): + return super().write_expanded_queries(Qfilename, Q_filename, clean=False) + + def avICTF(self,query): + index_reader = index.IndexReader(self.ext_index) + ql=len(query.split()) + sub_result=1 + for term in query.split(): + try: + df, collection_freq = index_reader.get_term_counts(ps.stem(term.lower())) + except: + collection_freq=1 + df=1 + + if isinstance(collection_freq,int)==False: + collection_freq=1 + df=1 + + try: + sub_result= sub_result * (self.ext_collection_tokens / collection_freq) + except: + sub_result= sub_result * self.ext_collection_tokens + sub_result=math.log2(sub_result) + externalavICTF= (sub_result/ql) + index_reader = index.IndexReader(self.index) + sub_result=1 + for term in query.split(): + try: + df, collection_freq = index_reader.get_term_counts(ps.stem(term.lower())) + except: + collection_freq=1 + df=1 + if isinstance(collection_freq,int)==False: + df=1 + collection_freq=1 + try: + sub_result= sub_result * (self.ext_collection_tokens / collection_freq) + except: + sub_result= sub_result * self.ext_collection_tokens + sub_result=math.log2(sub_result) + internalavICTF = (sub_result/ql) + if internalavICTF < 10 and externalavICTF < 10: + return "NoExpansionPreferred" + elif internalavICTF >= externalavICTF: + return "InternalExpansionPreferred" + elif externalavICTF > internalavICTF: + return "ExternalExpansionPreferred" + + +if __name__ == "__main__": + number_of_tokens_in_collections={'robust04':148000000, + 'gov2' : 17000000000, + 'cw09' : 31000000000, + 'cw12' : 31000000000} + + tuned_weights={'robust04': {'w_t':2.25 , 'w_a':1 }, + 'gov2': {'w_t':4 , 'w_a':0.25 }, + 'cw09': {'w_t': 1, 'w_a': 0}, + 'cw12': {'w_t': 4, 'w_a': 0}} + + total_documents_number = { 'robust04':520000 , + 'gov2' : 25000000, + 'cw09' : 50000000 , + 'cw12': 50000000} + + qe = AdapOnFields(ranker='bm25', + corpus='robust04', + index='../anserini/lucene-index.robust04.pos+docvectors+rawdocs', + prels='./output/robust04/topics.robust04.abstractqueryexpansion.bm25.txt', + anserini='../anserini/', + w_t=tuned_weights['robust04']['w_t'], + w_a=tuned_weights['robust04']['w_a'], + corpus_size=total_documents_number['robust04'], + collection_tokens=number_of_tokens_in_collections['robust04'], + ext_corpus='gov2', + ext_index='../anserini/lucene-index.gov2.pos+docvectors+rawdocs', + ext_prels='./output/gov2/topics.terabyte04.701-750.abstractqueryexpansion.bm25.txt', + ext_collection_tokens = number_of_tokens_in_collections['gov2'], + ext_corpus_size=total_documents_number['gov2'], + ext_w_t= tuned_weights['gov2']['w_t'], + ext_w_a= tuned_weights['gov2']['w_a'], + ) + + print(qe.get_model_name()) + + print(qe.get_refined_query('most dangerous vehicle', [305])) diff --git a/src/refinement/refiners/anchor.py b/src/refinement/refiners/anchor.py new file mode 100644 index 0000000..548ab47 --- /dev/null +++ b/src/refinement/refiners/anchor.py @@ -0,0 +1,69 @@ +import gensim +from gensim.models.callbacks import CallbackAny2Vec +from gensim.models import KeyedVectors +# from rdflib import Graph + +from nltk.stem import PorterStemmer +ps = 
PorterStemmer() + +import sys, os +sys.path.extend(['../refinement']) + +from refiners.word2vec import Word2Vec +# The anchor texts dataset: +# https://wiki.dbpedia.org/downloads-2016-10 +# http://downloads.dbpedia.org/2016-10/core-i18n/en/anchor_text_en.ttl.bz2 + +class Anchor(Word2Vec): + def __init__(self, anchorfile, vectorfile, topn=3, replace=False): + Word2Vec.__init__(self, vectorfile, topn=topn, replace=replace) + Anchor.anchorfile = anchorfile + + def train(self): + + class AnchorIter: + def __init__(self, anchorfile): + self.anchorfile = anchorfile + def __iter__(self): + for i, line in enumerate(open(self.anchorfile, encoding='utf-8')): + if (i % 10000 == 0 and i > 0): + print('INFO: ANCHOR: {} anchors have been read ...'.format(i)) + s = line.find('> "') + e = line.find('"@en', s) + if s < 1: + continue + anchor_text = line[s + 3:e] + yield [ps.stem(w) for w in anchor_text.lower().split(' ')] + + class EpochLogger(CallbackAny2Vec): + def __init__(self, epoch_count): + self.epoch = 1 + self.epoch_count = epoch_count + def on_epoch_begin(self, model): + print("Epoch {}/{} ...".format(self.epoch, self.epoch_count)) + self.epoch += 1 + anchors = AnchorIter(Anchor.anchorfile) + anchors = [anchor for anchor in AnchorIter(Anchor.anchorfile)]#all in memory at once + model = gensim.models.Word2Vec(anchors, size=300, sg=1, window=2, iter=100, workers=40, min_count=0, callbacks=[EpochLogger(100)]) + model.wv.save(Anchor.vectorfile) + + def get_refined_query(self, q, args=None): + if not Word2Vec.word2vec: + if not os.path.exists(Anchor.vectorfile): + print('INFO: ANCHOR: Pretrained anchor vector file {} does not exist! Training has been started ...'.format(Anchor.vectorfile)) + self.train() + print('INFO: ANCHOR: Loading anchor vectors in {} ...'.format(Anchor.vectorfile)) + Word2Vec.word2vec = gensim.models.KeyedVectors.load(Anchor.vectorfile, mmap='r') + + return super().get_refined_query(q) + +if __name__ == "__main__": + qe = Anchor(anchorfile='../pre/anchor_text_en.ttl', vectorfile='../pre/wiki-anchor-text-en-ttl-300d-100iter.vec') + for i in range(5): + print(qe.get_model_name()) + print(qe.get_refined_query('HosseinFani actor International Crime Organization')) + + qe.replace = True + for i in range(5): + print(qe.get_model_name()) + print(qe.get_refined_query('HosseinFani actor International Crime Organization')) diff --git a/src/refinement/refiners/backtranslation.py b/src/refinement/refiners/backtranslation.py new file mode 100644 index 0000000..f905760 --- /dev/null +++ b/src/refinement/refiners/backtranslation.py @@ -0,0 +1,49 @@ +from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM +from src.refinement.refiners.abstractqrefiner import AbstractQRefiner +from src.refinement.refiner_param import backtranslation + + +class BackTranslation(AbstractQRefiner): + def __init__(self, tgt): + AbstractQRefiner.__init__(self) + + # Initialization + self.tgt = tgt + model = AutoModelForSeq2SeqLM.from_pretrained(backtranslation['model_card']) + tokenizer = AutoTokenizer.from_pretrained(backtranslation['model_card']) + + # Translation models + self.translator = pipeline("translation", model=model, tokenizer=tokenizer, src_lang=backtranslation['src_lng'], tgt_lang=self.tgt, max_length=backtranslation['max_length'], device=backtranslation['device']) + self.back_translator = pipeline("translation", model=model, tokenizer=tokenizer, src_lang=self.tgt, tgt_lang=backtranslation['src_lng'], max_length=backtranslation['max_length'], device=backtranslation['device']) + # Model use 
for calculating semsim + + ''' + Generates the backtranslated query then calculates the semantic similarity of the two queries + ''' + def get_refined_query(self, query, args=None): + translated_query = self.translator(query.q) + back_translated_query = self.back_translator(translated_query[0]['translation_text']) + return back_translated_query[0]['translation_text'] + # return super().get_expanded_query(q, [0]) + + def get_refined_query_batch(self, queries, args=None): + try: + translated_queries = self.translator([query.q for query in queries]) + back_translated_queries = self.back_translator([tq_['translation_text'] for tq_ in translated_queries]) + q_s = [q_['translation_text'] for q_ in back_translated_queries] + except: + q_s = [None] * len(queries) + return q_s + + ''' + Returns the name of the model ('backtranslation) with name of the target language + Example: 'backtranslation_fra_latn' + ''' + def get_model_name(self): + return super().get_model_name() + '_' + self.tgt.lower() + + +if __name__ == "__main__": + qe = BackTranslation() + print(qe.get_model_name()) + print(qe.get_refined_query('This is my pc')) diff --git a/src/refinement/refiners/bertqe.py b/src/refinement/refiners/bertqe.py new file mode 100644 index 0000000..3931542 --- /dev/null +++ b/src/refinement/refiners/bertqe.py @@ -0,0 +1,107 @@ +import sys +sys.path.extend(['../refinement']) +sys.path.extend(['../pygaggle']) + +import pyserini +from pyserini import index +#from pyserini.search import SimpleSearcher +import subprocess, string +import nltk +from bs4 import BeautifulSoup +from nltk.tokenize import word_tokenize +from collections import Counter +from nltk.corpus import stopwords +from pygaggle.rerank.base import Query, Text +from pygaggle.rerank.transformer import MonoT5 +from nltk.tokenize import word_tokenize + +from pygaggle.rerank.transformer import MonoBERT +from pygaggle.rerank.base import hits_to_texts + +from refiners.relevancefeedback import RelevanceFeedback +import utils + +reranker = MonoBERT() + +#@inproceedings{zheng-etal-2020-bert, +# title = "{BERT-QE}: {C}ontextualized {Q}uery {E}xpansion for {D}ocument {R}e-ranking", +# author = "Zheng, Zhi and Hui, Kai and He, Ben and Han, Xianpei and Sun, Le and Yates, Andrew", +# booktitle = "Findings of the Association for Computational Linguistics: EMNLP 2020", +# month = nov, +# year = "2020", +# address = "Online", +# publisher = "Association for Computational Linguistics", +# url = "https://www.aclweb.org/anthology/2020.findings-emnlp.424", +# pages = "4718--4728", +#} + +class BertQE(RelevanceFeedback): + def __init__(self, ranker, prels, anserini, index): + RelevanceFeedback.__init__(self, ranker, prels, anserini, index, topn=10) + self.index_reader = pyserini.index.IndexReader(self.index) + + + def get_refined_query(self, q, args): + q=q.translate(str.maketrans('', '', string.punctuation)) + qid=args[0] + topn_docs = self.get_topn_relevant_docids(qid) + print() + topn_text="" + for docid in topn_docs: + raw_doc=self.index_reader.doc_raw(docid).lower() + raw_doc= ''.join([i if ord(i) < 128 else ' ' for i in raw_doc]) + topn_text= topn_text+ ' ' + raw_doc + + chunk_dic_for_bert=[] + chunks=self.make_chunks(topn_text) + for i in range(len(chunks)): + chunk_dic_for_bert.append([i,chunks[i]]) + + chunk_scores=self.Bert_Score(q,chunk_dic_for_bert) + scores=list(chunk_scores.values()) + norm = [(float(i)-min(scores))/(max(scores)-min(scores)) for i in scores] + normalized_chunks={} + normalized_chunks[q]=1.5 + for i in range(5): + 
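+            # keep only the five highest-scoring BERT chunks; their scores were min-max
+            # normalized above, while the original query keeps a fixed weight of 1.5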
normalized_chunks[list(chunk_scores.keys())[i]]=norm[i] + return super().get_expanded_query(str(normalized_chunks)) + + def write_expanded_queries(self, Qfilename, Q_filename,clean=False): + return super().write_expanded_queries(Qfilename, Q_filename, clean=False) + + def make_chunks(self,raw_doc): + chunks=[] + terms=raw_doc.split() + for i in range(0, len(terms),5 ): + chunk='' + for j in range(i,i+5): + if j < (len(terms)-1): + chunk=chunk+' '+terms[j] + chunks.append(chunk) + return chunks + + def Bert_Score(self,q,doc_dic_for_bert): + chunk_scores={} + query = Query(q) + texts = [ Text(p[1], {'docid': p[0]}, 0) for p in doc_dic_for_bert] + reranked = reranker.rerank(query, texts) + reranked.sort(key=lambda x: x.score, reverse=True) + for i in range(0,10): + chunk_text=reranked[i].text + word_tokens = word_tokenize(chunk_text) + filtered_sentence = [w for w in word_tokens if not w in stop_words] + filtered_sentence = (" ").join(filtered_sentence).translate(str.maketrans('', '', string.punctuation)) + chunk_scores[filtered_sentence]=round(reranked[i].score,3) + #print(f'{i+1:2} {reranked[i].score:.5f} {reranked[i].text}') + return chunk_scores + +if __name__ == "__main__": + + qe = BertQE(ranker='bm25', + prels='./output/robust04/topics.robust04.abstractqueryexpansion.bm25.txt', + anserini='../anserini/', + index='../anserini/lucene-index.robust04.pos+docvectors+rawdocs') + print(qe.get_model_name()) + print(qe.get_expanded_query('International Organized Crime ', [305])) + + diff --git a/src/refinement/refiners/conceptluster.py b/src/refinement/refiners/conceptluster.py new file mode 100644 index 0000000..78f71f2 --- /dev/null +++ b/src/refinement/refiners/conceptluster.py @@ -0,0 +1,50 @@ +import tagme +tagme.GCUBE_TOKEN = "10df41c6-f741-45fc-88dd-9b24b2568a7b" + +import os,sys +sys.path.extend(['../refinement']) + +from refiners.termluster import Termluster +import utils +class Conceptluster(Termluster): + def __init__(self, ranker, prels, anserini, index, topn=5, topw=3): + Termluster.__init__(self, ranker, prels, anserini, index, topn=topn, topw=topw) + + def get_refined_query(self, q, args): + qid = args[0] + list_of_concept_lists = [] + docids = self.get_topn_relevant_docids(qid) + for docid in docids: + doc_text = self.get_document(docid) + concept_list = self.get_concepts(doc_text, score=0.1) + list_of_concept_lists.append(concept_list) + + G, cluster_dict = self.make_graph_document(list_of_concept_lists, min_edge=10) + expanded_query = self.expand_query_concept_cluster(q, G, cluster_dict, k_relevant_words=self.topw) + return super().get_expanded_query(expanded_query) + + def expand_query_concept_cluster(self, q, G, cluster_dict, k_relevant_words): + q += ' ' + ' '.join(self.get_concepts(q, 0.1)) + return super().refined_query_term_cluster(q, G, cluster_dict, k_relevant_words) + + def get_document(self, docid): + command = '\"{}target/appassembler/bin/IndexUtils\" -index \"{}\" -dumpRawDoc \"{}\"'.format(self.anserini, self.index, docid) + stream = os.popen(command) + return stream.read() + + def get_concepts(self, text, score): + concepts = tagme.annotate(text).get_annotations(score) + return list(set([c.entity_title for c in concepts if c.entity_title not in text])) + + +if __name__ == "__main__": + qe = Conceptluster(ranker='bm25', + prels='./output/robust04/topics.robust04.abstractqueryexpansion.bm25.txt', + anserini='../anserini/', + index='../ds/robust04/index-robust04-20191213') + + for i in range(5): + print(qe.get_model_name()) + print(qe.get_expanded_query('HosseinFani 
International Crime Organization', [301])) + print(qe.get_expanded_query('Agoraphobia', [698])) + print(qe.get_expanded_query('Unsolicited Faxes', [317])) diff --git a/src/refinement/refiners/conceptnet.py b/src/refinement/refiners/conceptnet.py new file mode 100644 index 0000000..f5a16b1 --- /dev/null +++ b/src/refinement/refiners/conceptnet.py @@ -0,0 +1,62 @@ +import requests + +import sys +sys.path.extend(['../refinement']) +from nltk.stem import PorterStemmer + +from refiners.abstractqrefiner import AbstractQRefiner +import utils + +class Conceptnet(AbstractQRefiner): + def __init__(self, replace=False, topn=3): + AbstractQRefiner.__init__(self, replace, topn) + + def get_refined_query(self, q, args=None): + upd_query = utils.get_tokenized_query(q) + res = [] + if not self.replace: + res = [w for w in upd_query] + ps = PorterStemmer() + for q in upd_query: + q_stem = ps.stem(q) + found_flag = False + try: + obj = requests.get('http://api.conceptnet.io/c/en/' + q).json() + except: + if self.replace: + res.append(q) + continue + if len(obj['edges']) < self.topn: + x = len(obj['edges']) + else: + x = self.topn + for i in range(x): + + try: + start_lan = obj['edges'][i]['start']['language'] + end_lan = obj['edges'][i]['end']['language'] + except: + continue + if obj['edges'][i]['start']['language'] != 'en' or obj['edges'][i]['end']['language'] != 'en': + continue + if obj['edges'][i]['start']['label'].lower() == q: + if obj['edges'][i]['end']['label'] not in res and q_stem != ps.stem(obj['edges'][i]['end']['label']): + found_flag = True + res.append(obj['edges'][i]['end']['label']) + elif obj['edges'][i]['end']['label'].lower() == q: + if obj['edges'][i]['start']['label'] not in res and q_stem != ps.stem(obj['edges'][i]['start']['label']): + found_flag = True + res.append(obj['edges'][i]['start']['label']) + if not found_flag and self.replace: + res.append(q) + return super().get_expanded_query(' '.join(res)) + + +if __name__ == "__main__": + qe = Conceptnet() + print(qe.get_model_name()) + print(qe.get_refined_query('HosseinFani International Crime Organization')) + + qe = Conceptnet(replace=True) + print(qe.get_model_name()) + print(qe.get_refined_query('compost pile')) diff --git a/src/refinement/refiners/docluster.py b/src/refinement/refiners/docluster.py new file mode 100644 index 0000000..7cbfd12 --- /dev/null +++ b/src/refinement/refiners/docluster.py @@ -0,0 +1,127 @@ +import networkx as nx +from networkx.algorithms import community +import math + +import sys +sys.path.extend(['../refinement']) + +from refiners.relevancefeedback import RelevanceFeedback +class Docluster(RelevanceFeedback): + def __init__(self, ranker, prels, anserini, index, topn=10, topw=3): + RelevanceFeedback.__init__(self, ranker, prels, anserini, index, topn=topn) + self.topw = topw + + def get_model_name(self): + return super().get_model_name().replace('topn{}'.format(self.topn),'topn{}.{}'.format(self.topn, self.topw)) + + def getsim(self, tfidf1, tfidf2): + words_doc_id_1 = [] + values_doc_id_1 = [] + + words_doc_id_2 = [] + values_doc_id_2 = [] + + for x in tfidf1.split('\n'): + if not x: + continue + x_splited = x.split() + words_doc_id_1.append(x_splited[0]) + values_doc_id_1.append(int(x_splited[1])) + + for x in tfidf2.split('\n'): + if not x: + continue + x_splited = x.split() + words_doc_id_2.append(x_splited[0]) + values_doc_id_2.append(int(x_splited[1])) + + sum_docs_1_2 = 0 + i = 0 + for word in words_doc_id_1: + try: + index = words_doc_id_2.index(word) + except ValueError: + index = -1 + if 
index != -1: + sum_docs_1_2 = sum_docs_1_2 + values_doc_id_1[i] * values_doc_id_2[index] + i = i + 1 + + sum_doc_1 = 0 + for j in range(len(values_doc_id_1)): + sum_doc_1 = sum_doc_1 + (values_doc_id_1[j] * values_doc_id_1[j]) + + sum_doc_2 = 0 + for j in range(len(values_doc_id_2)): + sum_doc_2 = sum_doc_2 + (values_doc_id_2[j] * values_doc_id_2[j]) + + if sum_doc_1 == 0 or sum_doc_2 == 0: + return 0 + + result = sum_docs_1_2 / (math.sqrt(sum_doc_1) * math.sqrt(sum_doc_2)) + + return result + + def get_refined_query(self, q, args): + qid = args[0] + selected_words = [] + docids = self.get_topn_relevant_docids(qid) + tfidfs = [] + for docid in docids: + tfidfs.append(self.get_tfidf(docid)) + + G = nx.Graph() + for i in range(len(docids)): + G.add_node(docids[i]) + for j in range(i + 1, len(docids) - 1): + sim = self.getsim(tfidfs[i], tfidfs[j]) + if sim > 0.5: + G.add_weighted_edges_from([(docids[i], docids[j], sim)]) + comp = community.girvan_newman(G) + partitions = tuple(sorted(c) for c in next(comp)) + for partition in partitions: + if len(partition) > 1: + pairlist = [] + for p in partition: + pairlist.append(self.get_top_word(tfidf=tfidfs[docids.index(p)])) + + top_k = self.get_top_k(pairlist, self.topw) + for (word, value) in top_k: + selected_words.append(word) + + query_splited = q.lower().split() + for word in selected_words: + if word.lower() not in query_splited: + query_splited.append(word) + + return super().get_refined_query(' '.join(query_splited)) + + def get_top_k(self, pairlist, k): + output = [] + from_index = 0 + for j in range(min(len(pairlist), k)): + max_value = 0 + max_index = 0 + max_word = "" + for i in range(from_index, len(pairlist)): + (word, value) = pairlist[i] + if value > max_value: + max_value = value + max_word = word + max_index = i + output.append((max_word, max_value)) + temp = pairlist[from_index] + pairlist[from_index] = pairlist[max_index] + pairlist[max_index] = temp + from_index = from_index + 1 + return output + +if __name__ == "__main__": + qe = Docluster(ranker='bm25', + prels='./output/robust04/topics.robust04.abstractqueryexpansion.bm25.txt', + anserini='../anserini/', + index='../ds/robust04/index-robust04-20191213') + for i in range(5): + print(qe.get_model_name()) + # print(qe.get_expanded_query('HosseinFani International Crime Organization', [301])) + # print(qe.get_expanded_query('Agoraphobia', [698])) + print(qe.get_refined_query('Unsolicited Faxes', [317])) diff --git a/src/refinement/refiners/expanders.csv b/src/refinement/refiners/expanders.csv new file mode 100644 index 0000000..875f8bb --- /dev/null +++ b/src/refinement/refiners/expanders.csv @@ -0,0 +1,59 @@ +name, category +stem.krovetz,Stemming_Analysis +stem.lovins,Stemming_Analysis +stem.paicehusk,Stemming_Analysis +stem.porter,Stemming_Analysis +stem.porter2,Stemming_Analysis +stem.sstemmer,Stemming_Analysis +stem.trunc4,Stemming_Analysis +stem.trunc5,Stemming_Analysis +conceptnet.topn3,Semantic_Analysis +conceptnet.topn3.replace,Semantic_Analysis +glove.topn3,Semantic_Analysis +glove.topn3.replace,Semantic_Analysis +sensedisambiguation,Semantic_Analysis +sensedisambiguation.replace,Semantic_Analysis +thesaurus.topn3,Semantic_Analysis +thesaurus.topn3.replace,Semantic_Analysis +word2vec.topn3,Semantic_Analysis +word2vec.topn3.replace,Semantic_Analysis +wordnet.topn3,Semantic_Analysis +wordnet.topn3.replace,Semantic_Analysis +termluster.topn5.3.bm25,Term_Clustering +termluster.topn5.3.qld,Term_Clustering +conceptluster.topn5.3.bm25,Concept_Clustering 
+conceptluster.topn5.3.qld,Concept_Clustering +anchor.topn3,Anchor_Text +anchor.topn3.replace,Anchor_Text +wiki.topn3,Wikipedia +wiki.topn3.replace,Wikipedia +tagmee.topn3,Wikipedia +tagmee.topn3.replace,Wikipedia +relevancefeedback.topn10.bm25,Top_Documents +relevancefeedback.topn10.qld,Top_Documents +rm3.topn10.10.0.5.bm25,Top_Documents +rm3.topn10.10.0.5.qld,Top_Documents +adaponfields.topn3.exgov2.4.0.25.10.2.25.1.bm25,Top_Documents +adaponfields.topn3.exgov2.4.0.25.10.1.1.bm25,Top_Documents +adaponfields.topn3.exgov2.4.0.25.10.1.1.qld,Top_Documents +onfields.topn3.10.2.25.1.bm25,Top_Documents +onfields.topn3.10.1.1.bm25,Top_Documents +onfields.topn3.10.2.25.1.qld,Top_Documents +onfields.topn3.10.1.1.qld,Top_Documents +onfields.topn3.10.4.0.25.bm25,Top_Documents +onfields.topn3.10.1.0.bm25,Top_Documents +adaponfields.topn3.exgov2.4.0.25.10.2.25.1.qld,Top_Documents +adaponfields.topn3.exgov2.4.0.25.10.1.0.qld,Top_Documents +adaponfields.topn3.exrobust04.2.25.1.10.4.0.25.bm25,Top_Documents +adaponfields.topn3.exrobust04.2.25.1.10.4.0.25.qld,Top_Documents +adaponfields.topn3.exgov2.4.0.25.10.4.0.qld,Top_Documents +onfields.topn3.10.4.0.25.qld,Top_Documents +onfields.topn3.10.4.0.bm25,Top_Documents +adaponfields.topn3.exgov2.4.0.25.10.4.0.bm25,Top_Documents +onfields.topn3.10.4.0.qld,Top_Documents +adaponfields.topn3.exgov2.4.0.25.10.1.0.bm25,Top_Documents +onfields.topn3.10.1.0.qld,Top_Documents +bertqe.topn10.bm25,Top_Documents +bertqe.topn10.qld,Top_Documents +docluster.topn10.3.bm25,Document_Summaries +docluster.topn10.3.qld,Document_Summaries diff --git a/src/refinement/refiners/glove.py b/src/refinement/refiners/glove.py new file mode 100644 index 0000000..db0cc2c --- /dev/null +++ b/src/refinement/refiners/glove.py @@ -0,0 +1,70 @@ +import scipy +from nltk.stem import PorterStemmer +import numpy as np + +import sys +sys.path.extend(['../refinement']) + +from refiners.abstractqrefiner import AbstractQRefiner +import utils + +class Glove(AbstractQRefiner): + def __init__(self, vectorfile, replace=False, topn=3): + AbstractQRefiner.__init__(self, replace, topn) + Glove.vectorfile = vectorfile + Glove.glove = None + + + def get_refined_query(self, q, args=None): + if not Glove.glove: + print('INFO: Glove: Loading word vectors in {} ...'.format(Glove.vectorfile)) + Glove.glove = load_glove_model(Glove.vectorfile) + + upd_query = utils.get_tokenized_query(q) + synonyms = [] + res = [] + if not self.replace: + res = [w for w in upd_query] + ps = PorterStemmer() + for qw in upd_query: + found_flag = False + qw_stem = ps.stem(qw) + if qw.lower() in Glove.glove.keys(): + w = sorted(Glove.glove.keys(), key=lambda word: scipy.spatial.distance.euclidean(Glove.glove[word], Glove.glove[qw])) + w = w[:self.topn] + for u in w: + u_stem = ps.stem(u) + if u_stem != qw_stem: + found_flag = True + res.append(u) + if not found_flag and self.replace: + res.append(qw) + return super().get_refined_query(' '.join(res)) + + +def load_glove_model(glove_file): + with open(glove_file + ".txt", 'r', encoding='utf-8') as f: + model = {} + counter=0 + for line in f: + if counter>0: + split_line = line.split() + word = split_line[0] + embedding = np.array([float(val) for val in split_line[1:]]) + model[word] = embedding + counter+=1 + else: + counter+=1 + return model + +if __name__ == "__main__": + qe = Glove('../pre/glove.6B.300d') + for i in range(5): + print(qe.get_model_name()) + print(qe.get_refined_query('HosseinFani International Crime Organization')) + qe.replace = True + for i in range(5): + 
print(qe.get_model_name()) + print(qe.get_refined_query('HosseinFani International Crime Organization')) + + diff --git a/src/refinement/refiners/onfields.py b/src/refinement/refiners/onfields.py new file mode 100644 index 0000000..71a9578 --- /dev/null +++ b/src/refinement/refiners/onfields.py @@ -0,0 +1,271 @@ +import sys +sys.path.extend(['../refinement']) + +from pyserini.search import SimpleSearcher +import traceback, os, subprocess, nltk, string, math +from bs4 import BeautifulSoup +from nltk.tokenize import word_tokenize +from nltk.stem import PorterStemmer +from collections import Counter +from nltk.corpus import stopwords +from pyserini import analysis, index +import pyserini + +from refiners.relevancefeedback import RelevanceFeedback +import utils + +# @article{DBLP:journals/ipm/HeO07, +# author = {Ben He and +# Iadh Ounis}, +# title = {Combining fields for query expansion and adaptive query expansion}, +# journal = {Inf. Process. Manag.}, +# volume = {43}, +# number = {5}, +# pages = {1294--1307}, +# year = {2007}, +# url = {https://doi.org/10.1016/j.ipm.2006.11.002}, +# doi = {10.1016/j.ipm.2006.11.002}, +# timestamp = {Fri, 21 Feb 2020 13:11:30 +0100}, +# biburl = {https://dblp.org/rec/journals/ipm/HeO07.bib}, +# bibsource = {dblp computer science bibliography, https://dblp.org} +# } + +class OnFields(RelevanceFeedback): + def __init__(self, ranker, prels, anserini, index, w_t, w_a, corpus_size, replace=False, topn=3, topw=10,adap=False): + RelevanceFeedback.__init__(self, ranker, prels, anserini, index, topn=topn) + self.index_reader = pyserini.index.IndexReader(self.index) + self.topw=10 + self.adap=adap + self.w_t = w_t # weight for title field + self.w_a = w_a # weight for anchor field + self.corpus_size=corpus_size #total number of documents in the collection + + def get_refined_query(self, q, args): + q=q.translate(str.maketrans('', '', string.punctuation)) + qid=args[0] + if self.adap == False: + topn_docs = self.get_topn_relevant_docids(qid) + elif self.adap == True: + topn_docs=self.retrieve_and_get_topn_relevant_docids(ps.stem(q.lower())) + topn_title='' + topn_body='' + topn_anchor='' + for docid in topn_docs: + raw_doc=self.extract_raw_documents(docid) + raw_doc=raw_doc.lower() + raw_doc= ''.join([i if ord(i) < 128 else ' ' for i in raw_doc]) + + title='' + body=raw_doc + anchor='' + try: + title=self.extract_specific_field(raw_doc,'title') + except: + # 'title' field do not exist + pass + try: + body=self.extract_specific_field(raw_doc,'body') + if body =='': + body=raw_doc + except: + # 'body' field do not exist + pass + try: + anchor=self.extract_specific_field(raw_doc,'anchor') + except: + #'anchor' field do not exist + pass + + topn_title='{} {}'.format(topn_title ,title) + topn_anchor = '{} {}'.format(topn_anchor,anchor) + topn_body='{} {}'.format(topn_body,body) + topn_title=topn_title.translate(str.maketrans('', '', string.punctuation)) + topn_anchor=topn_anchor.translate(str.maketrans('', '', string.punctuation)) + topn_body=topn_body.translate(str.maketrans('', '', string.punctuation)) + all_topn_docs= '{} {} {}'.format(topn_body, topn_anchor ,topn_title) + tfx = self.term_weighting(topn_title,topn_anchor,topn_body) + tfx=dict(sorted(tfx.items(), key=lambda x: x[1])[::-1]) + w_t_dic={} + c=0 + + + for term in tfx.keys(): + if term.isalnum(): + + c=c+1 + collection_freq =1 + try: + df, collection_freq = self.index_reader.get_term_counts(term) + except: + # term do not exist in the collection + pass + if collection_freq==0 or collection_freq==None: + 
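+                    # fall back to 1 so that the informativeness weight computed next
+                    # (P_n = collection_freq / corpus_size) never divides by zero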
collection_freq=1 + + P_n = collection_freq / self.corpus_size + try: + term_weight= tfx[term] * math.log2( (1 + P_n ) / P_n) + math.log2( 1 + P_n) + w_t_dic[term]=term_weight + except: + + pass + + + sorted_term_weights=dict(sorted(w_t_dic.items(), key=lambda x: x[1])[::-1]) + counter=0 + top_n_informative_words={} + for keys,values in sorted_term_weights.items(): + counter=counter+1 + top_n_informative_words[keys]=values + if counter>self.topw: + break + + expanded_term_freq= {} + for keys,values in top_n_informative_words.items(): + expanded_term_freq[keys]=all_topn_docs.count(keys) + + for keys,values in top_n_informative_words.items(): + part_A = expanded_term_freq[keys] /max(expanded_term_freq.values()) + part_B = top_n_informative_words[keys] / max(top_n_informative_words.values()) + top_n_informative_words[keys]= round(part_A+part_B,3) + + for original_q_term in q.lower().split(): + top_n_informative_words[ps.stem(original_q_term)]=2 + + top_n_informative_words=dict(sorted(top_n_informative_words.items(), key=lambda x: x[1])[::-1]) + return super().get_refined_query(str(top_n_informative_words)) + + def get_model_name(self): + return super().get_model_name().replace('topn{}'.format(self.topn), + 'topn{}.{}.{}.{}'.format(self.topn, self.topw, self.w_t, self.w_a)) + + def write_expanded_queries(self, Qfilename, Q_filename,clean=False): + return super().write_expanded_queries(Qfilename, Q_filename, clean=False) + + def extract_raw_documents(self,docid): + index_address=self.index + anserini_address=self.anserini + cmd = '\"{}/target/appassembler/bin/IndexUtils\" -index \"{}\" -dumpRawDoc \"{}\"'.format(anserini_address,index_address,docid) + output = subprocess.check_output(cmd, shell=True) + return (output.decode('utf-8')) + + def extract_specific_field(self,raw_document,field): + soup = BeautifulSoup(raw_document) # txt is simply the a string with your XML file + title_out='' + body_out='' + anchor_out='' + if field=='title': + try: + title= soup.find('headline') + title_out=title.text + except: + if title_out=='': + try: + title = soup.find('title') + title_out=title.text + except: + # document had not 'title' + pass + + return title_out + + if field == 'body': + if '<html>' not in raw_document: + pageText = soup.findAll(text=True) + body_out= (' '.join(pageText)) + else: + bodies= soup.find_all('body') + for b in bodies: + try: + body_out='{} {}'.format(body_out,b.text.strip()) + except: + # no 'body' field in the document + pass + return body_out + + if field=='anchor': + for link in soup.findAll('a'): + if link.string != None: + anchor_out='{} {}'.format(anchor_out,link.string) + return anchor_out + + + def term_weighting(self,topn_title,topn_anchor,topn_body): + # w_t and w_a is tuned for all the copora ( should be tuned for future corpora as well) + + w_b = 1 + topn_title=topn_title.translate(str.maketrans('', '', string.punctuation)) + topn_body=topn_body.translate(str.maketrans('', '', string.punctuation)) + topn_anchor=topn_anchor.translate(str.maketrans('', '', string.punctuation)) + + title_tokens = word_tokenize(topn_title) + body_tokens= word_tokenize(topn_body) + anchor_tokens= word_tokenize(topn_anchor) + + filtered_words_title = [ps.stem(word) for word in title_tokens if word not in stop_words] + filtered_words_body = [ps.stem(word) for word in body_tokens if word not in stop_words] + filtered_words_anchor = [ps.stem(word) for word in anchor_tokens if word not in stop_words] + + term_freq_title = dict(Counter(filtered_words_title)) + term_freq_body = 
dict(Counter(filtered_words_body)) + term_freq_anchor = dict(Counter(filtered_words_anchor)) + + term_weights={} + for term in list(set(filtered_words_title)): + if term not in term_weights.keys(): + term_weights[term]=0 + term_weights[term]=term_weights[term] + term_freq_title[term] * self.w_t + + for term in list(set(filtered_words_body)): + if term_freq_body[term] : + if term not in term_weights.keys(): + term_weights[term]=0 + term_weights[term]=term_weights[term] + term_freq_body[term] * w_b + + for term in list(set(filtered_words_anchor)): + if term not in term_weights.keys(): + term_weights[term]=0 + term_weights[term]=term_weights[term] + term_freq_anchor[term] * self.w_a + + return term_weights + + def retrieve_and_get_topn_relevant_docids(self, q): + relevant_documents = [] + searcher = SimpleSearcher(self.index) + + if self.ranker =='bm25': + searcher.set_bm25() + elif self.ranker=='qld': + searcher.set_qld() + hits = searcher.search(q) + for i in range(0, self.topn): + relevant_documents.append(hits[i].docid) + return relevant_documents + + +if __name__ == "__main__": + tuned_weights={'robust04': {'w_t':2.25 , 'w_a':1 }, + 'gov2': {'w_t':4 , 'w_a':0.25 }, + 'cw09': {'w_t': 1, 'w_a': 0}, + 'cw12': {'w_t': 4, 'w_a': 0}} + + total_documents_number = { 'robust04':520000 , + 'gov2' : 25000000, + 'cw09' : 50000000 , + 'cw12': 50000000} + + + + qe = OnFields(ranker='bm25', + prels='./output/robust04/topics.robust04.abstractqueryexpansion.bm25.txt', + anserini='../anserini/', + index='./ds/robust04/lucene-index.robust04.pos+docvectors+rawdocs', + corpus='robust04', + w_t=tuned_weights['robust04']['w_t'], + w_a=tuned_weights['robust04']['w_a'], + corpus_size= total_documents_number['robust04']) + + print(qe.get_model_name()) + print(qe.get_refined_query('Most Dangerous Vehicles', [305])) + + diff --git a/src/refinement/refiners/relevancefeedback.py b/src/refinement/refiners/relevancefeedback.py new file mode 100644 index 0000000..a971ec0 --- /dev/null +++ b/src/refinement/refiners/relevancefeedback.py @@ -0,0 +1,86 @@ +import os +import sys +sys.path.extend(['../refinement']) + +from refiners.abstractqrefiner import AbstractQRefiner +class RelevanceFeedback(AbstractQRefiner): + def __init__(self, ranker, prels, anserini, index, topn=10): + AbstractQRefiner.__init__(self, replace=False, topn=topn) + self.prels = prels + self.f = None + self.anserini = anserini + self.index = index + self.ranker = ranker + + def get_model_name(self): + return super().get_model_name() + '.' 
+ self.ranker + + def get_refined_query(self, q, args): + qid = args[0] + selected_words = [] + docids = self.get_topn_relevant_docids(qid) + for docid in docids: + tfidf = self.get_tfidf(docid) + top_word, _ = self.get_top_word(tfidf) + selected_words.append(top_word) + + query_splited = q.lower().split() + for word in selected_words: + if word.lower() not in query_splited: + query_splited.append(word) + + return super().get_refined_query(' '.join(query_splited)) + + def get_topn_relevant_docids(self, qid): + relevant_documents = [] + if not self.f: + self.f = open(self.prels, "r", encoding='utf-8') + self.f.seek(0) + i = 0 + for x in self.f: + x_splited = x.split() + try : + if (int(x_splited[0]) == qid or x_splited[0] == qid): + relevant_documents.append(x_splited[2]) + i = i+1 + if i >= self.topn: + break + except: + if ('dbpedia' in self.prels and x_splited[0] == qid): + relevant_documents.append(x_splited[2]) + i = i+1 + if i >= self.topn: + break + return super().get_refined_query(relevant_documents) + + def get_tfidf(self, docid): + #command = "target/appassembler/bin/IndexUtils -index lucene-index.robust04.pos+docvectors+rawdocs -dumpDocVector FBIS4-40260 -docVectorWeight TF_IDF " + cli_cmd = '\"{}target/appassembler/bin/IndexUtils\" -index \"{}\" -dumpDocVector \"{}\" -docVectorWeight TF_IDF'.format(self.anserini, self.index, docid) + stream = os.popen(cli_cmd) + return stream.read() + + def get_top_word(self, tfidf): + i = 0 + max = 0 + top_word = "" + for x in tfidf.split('\n'): + if not x: + continue + x_splited = x.split() + word = x_splited[0] + value = int(x_splited[1]) + if value > max: + top_word = word + max = value + + return top_word, max + + +if __name__ == "__main__": + qe = RelevanceFeedback(ranker='bm25', + prels='./output/robust04/topics.robust04.abstractqueryexpansion.bm25.txt', + anserini='../anserini/', + index='../ds/robust04/index-robust04-20191213') + for i in range(5): + print(qe.get_model_name()) + print(qe.get_refined_query('Agoraphobia', [698])) diff --git a/src/refinement/refiners/rm3.py b/src/refinement/refiners/rm3.py new file mode 100644 index 0000000..5626ab7 --- /dev/null +++ b/src/refinement/refiners/rm3.py @@ -0,0 +1,53 @@ +import os, sys, re, io +sys.path.extend(['../refinement']) + +from pyserini.search import SimpleSearcher +import utils #although it's not used here, this is required! 
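+# (utils provides stdout_redirector_2_stream, used in get_refined_query below to capture
+# Anserini's RM3 log from stdout so the rewritten weighted query can be parsed out)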
+from refiners.relevancefeedback import RelevanceFeedback + +class RM3(RelevanceFeedback): + def __init__(self, ranker, index, topn=10, topw=10, original_q_w=0.5): + RelevanceFeedback.__init__(self, ranker=ranker, prels=None, anserini=None, index=index, topn=topn) + self.topw=topw + self.searcher = SimpleSearcher(index) + self.ranker=ranker + self.original_q_w=original_q_w + + + def get_refined_query(self, q, args=None): + + if self.ranker=='bm25': + self.searcher.set_bm25() + elif self.ranker=='qld': + self.searcher.set_qld() + + self.searcher.set_rm3(fb_terms=self.topw, fb_docs=self.topn, original_query_weight=self.original_q_w, rm3_output_query=True) + + f = io.BytesIO() + with utils.stdout_redirector_2_stream(f): + self.searcher.search(q) + print('RM3 Log: {0}"'.format(f.getvalue().decode('utf-8'))) + q_= self.parse_rm3_log(f.getvalue().decode('utf-8')) + + # with stdout_redirected(to='rm3.log'): + # self.searcher.search(q) + # rm3_log = open('rm3.log', 'r').read() + # q_ = self.parse_rm3_log(rm3_log) + # os.remove("rm3.log") + + return super().get_refined_query(q_) + + def get_model_name(self): + return super().get_model_name().replace('topn{}'.format(self.topn), 'topn{}.{}.{}'.format(self.topn, self.topw, self.original_q_w)) + + def parse_rm3_log(self,rm3_log): + new_q=rm3_log.split('Running new query:')[1] + new_q_clean=re.findall('\(([^)]+)', new_q) + new_q_clean=" ".join(new_q_clean) + return new_q_clean + + +if __name__ == "__main__": + qe = RM3(index='../ds/robust04/index-robust04-20191213/' ) + print(qe.get_model_name()) + print(qe.get_refined_query('International Crime Organization')) diff --git a/src/refinement/refiners/sensedisambiguation.py b/src/refinement/refiners/sensedisambiguation.py new file mode 100644 index 0000000..d2a5301 --- /dev/null +++ b/src/refinement/refiners/sensedisambiguation.py @@ -0,0 +1,36 @@ +from pywsd import disambiguate + +import sys +sys.path.extend(['../refinement']) + +from refiners.abstractqrefiner import AbstractQRefiner + +class SenseDisambiguation(AbstractQRefiner): + def __init__(self, replace=False): + AbstractQRefiner.__init__(self, replace) + + def get_refined_query(self, q, args=None): + res=[] + disamb = disambiguate(q) + for i,t in enumerate(disamb): + if t[1] is not None: + if not self.replace: + res.append(t[0]) + x=t[1].name().split('.')[0].split('_') + if t[0].lower() != (' '.join(x)).lower() or self.replace: + res.append(' '.join(x)) + else: + res.append(t[0]) + return super().get_refined_query(' '.join(res)) + + +if __name__ == "__main__": + qe = SenseDisambiguation() + print(qe.get_model_name()) + print(qe.get_refined_query('obama family tree')) + print(qe.get_refined_query('HosseinFani International Crime Organization')) + + qe = SenseDisambiguation(replace=True) + print(qe.get_model_name()) + print(qe.get_refined_query('maryland department of natural resources')) + print(qe.get_refined_query('HosseinFani International Crime Organization')) diff --git a/src/refinement/refiners/stem.py b/src/refinement/refiners/stem.py new file mode 100644 index 0000000..6f3ddbf --- /dev/null +++ b/src/refinement/refiners/stem.py @@ -0,0 +1,49 @@ +import sys +sys.path.extend(['../refinement']) + +from src.refinement.refiners.abstractqrefiner import AbstractQRefiner +from src.refinement.stemmers.abstractstemmer import AbstractStemmer +class Stem(AbstractQRefiner): + def __init__(self, stemmer:AbstractStemmer): + AbstractQRefiner.__init__(self, replace=False) + self.stemmer = stemmer + + def get_model_name(self): + return super().get_model_name() 
+ '.' + self.stemmer.basename + + def get_refined_query(self, q, args=None): + return super().get_refined_query(self.stemmer.stem_query(q)) + +if __name__ == "__main__": + from src.refinement.stemmers.krovetz import KrovetzStemmer + qe = Stem(KrovetzStemmer(jarfile='stemmers/kstem-3.4.jar')) + print(qe.get_model_name()) + print(qe.get_refined_query('International Crime Organization')) + from refinement.stemmers.lovins import LovinsStemmer + qe = Stem(LovinsStemmer()) + print(qe.get_model_name()) + print(qe.get_refined_query('International Crime Organization')) + from refinement.stemmers.paicehusk import PaiceHuskStemmer + qe = Stem(PaiceHuskStemmer()) + print(qe.get_model_name()) + print(qe.get_refined_query('International Crime Organization')) + from refinement.stemmers.porter import PorterStemmer + qe = Stem(PorterStemmer()) + print(qe.get_model_name()) + print(qe.get_refined_query('International Crime Organization')) + from refinement.stemmers.porter2 import Porter2Stemmer + qe = Stem(Porter2Stemmer()) + print(qe.get_model_name()) + print(qe.get_refined_query('International Crime Organization')) + from refinement.stemmers.sstemmer import SRemovalStemmer + qe = Stem(SRemovalStemmer()) + print(qe.get_model_name()) + print(qe.get_refined_query('International Crime Organization')) + from refinement.stemmers.trunc4 import Trunc4Stemmer + qe = Stem(Trunc4Stemmer()) + print(qe.get_model_name()) + print(qe.get_refined_query('International Crime Organization')) + from refinement.stemmers.trunc5 import Trunc5Stemmer + qe = Stem(Trunc5Stemmer()) + print(qe.get_model_name()) + print(qe.get_refined_query('International Crime Organization')) diff --git a/src/refinement/refiners/tagmee.py b/src/refinement/refiners/tagmee.py new file mode 100644 index 0000000..3b577a3 --- /dev/null +++ b/src/refinement/refiners/tagmee.py @@ -0,0 +1,43 @@ +import tagme +tagme.GCUBE_TOKEN = "10df41c6-f741-45fc-88dd-9b24b2568a7b" + +import sys, os +sys.path.extend(['../refinement']) + +from refiners.abstractqrefiner import AbstractQRefiner +import utils + +class Tagmee(AbstractQRefiner): + def __init__(self, topn=3, replace=False): + AbstractQRefiner.__init__(self, replace, topn) + + def get_concepts(self, text, score): + concepts = tagme.annotate(text).get_annotations(score) + res = [] + for ann in concepts: + res.append(ann.entity_title) + return res + + def get_refined_query(self, q, args=None): + + query_concepts = self.get_concepts(q, 0.1) + upd_query = utils.get_tokenized_query(q) + res = [] + if not self.replace: + res = [w for w in upd_query] + for c in query_concepts: + res.append(c) + return super().get_refined_query(' '.join(res)) + + +if __name__ == "__main__": + qe = Tagmee() + + print(qe.get_model_name()) + print(qe.get_refined_query('HosseinFani actor International Crime Organization')) + + + qe = Tagmee(replace=True) + print(qe.get_model_name()) + print(qe.get_refined_query('Magnetic Levitation-Maglev')) + diff --git a/src/refinement/refiners/termluster.py b/src/refinement/refiners/termluster.py new file mode 100644 index 0000000..31806bd --- /dev/null +++ b/src/refinement/refiners/termluster.py @@ -0,0 +1,111 @@ +import networkx as nx +from collections import defaultdict +from community import community_louvain +from nltk.stem import PorterStemmer + +import sys +sys.path.extend(['../refinement']) + +from refiners.relevancefeedback import RelevanceFeedback +import utils +class Termluster(RelevanceFeedback): + def __init__(self, ranker, prels, anserini, index, topn=5, topw=3): + 
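+        # topn: number of feedback documents read from the prels run file;
+        # topw: number of co-occurring cluster terms appended per query term (k_relevant_words below)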
RelevanceFeedback.__init__(self, ranker, prels, anserini, index, topn=topn) + self.topw = topw + + def get_model_name(self): + return super().get_model_name().replace('topn{}'.format(self.topn),'topn{}.{}'.format(self.topn, self.topw)) + + def get_refined_query(self, q, args): + qid = args[0] + list_of_word_lists = [] + docids = self.get_topn_relevant_docids(qid) + for docid in docids: + tfidf = self.get_tfidf(docid) + list_of_word_lists.append(self.get_list_of_words(tfidf, threshold=2)) + + G, cluster_dict = self.make_graph_document(list_of_word_lists, min_edge=4) + + # add three relevant words from each cluster for each query word + refined_query = self.refined_query_term_cluster(q, G, cluster_dict, k_relevant_words=self.topw) + + return super().get_refined_query(refined_query) + + def make_graph_document(self, list_s, min_edge): + G = nx.Graph() + counter = 1 + for s in list_s: + for i in range(len(s) - 1): + j = i + 1 + while j < len(s): + if (s[i], s[j]) in G.edges(): + G[s[i]][s[j]]['weight'] += 1 + elif (s[j], s[i]) in G.edges(): + G[s[j]][s[i]]['weight'] += 1 + else: + G.add_weighted_edges_from([(s[i], s[j], 1)]) + j += 1 + counter += 1 + G_TH = self.remove_nodes_from_graph(G, min_edge=min_edge) + clusters_dict = self.get_the_clusters(G_TH) + return G, clusters_dict + + def remove_nodes_from_graph(self, G, min_edge): + G = G.copy() + for n in G.copy().edges(data=True): + if n[2]['weight'] < min_edge: + G.remove_edge(n[0], n[1]) + G.remove_nodes_from(list(nx.isolates(G))) + return G + + def get_the_clusters(self, G): + clusters = community_louvain.best_partition(G) + clusters_dic = defaultdict(list) + for key, value in clusters.items(): + clusters_dic[value].append(key) + return clusters_dic + + def refined_query_term_cluster(self, q, G, cluster_dict, k_relevant_words): + upd_query = utils.get_tokenized_query(q) + porter = PorterStemmer() + res = [w for w in upd_query] + for qw in upd_query: + counter = 0 + for cluster in cluster_dict.values(): + if qw in cluster or porter.stem(qw) in cluster: + list_neighbors = [i for i in cluster if (i != qw and i != porter.stem(qw))] + counter += 1 + break + if counter == 0: + continue + weight_list = [] + for i in list_neighbors: + weight_list.append((i, G.edges[(qw, i)]['weight'] if (qw, i) in G.edges else (porter.stem(qw), i))) + final_res = sorted(weight_list, key=lambda x: x[1], reverse=True)[:k_relevant_words] + for u, v in final_res: + res.append(u) + return ' '.join(res) + + def get_list_of_words(self, tfidf, threshold): + list = [] + for x in tfidf.split('\n'): + if not x: + continue + x_splited = x.split() + w = x_splited[0] + value = int(x_splited[1]) + if not (w.isdigit()) and w not in stop_words and len(w) > 2 and value > threshold: + list.append(x_splited[0]) + + return list + +if __name__ == "__main__": + qe = Termluster(ranker='bm25', + prels='./output/robust04/topics.robust04.abstractqueryexpansion.bm25.txt', + anserini='../anserini/', + index='../ds/robust04/index-robust04-20191213') + for i in range(5): + print(qe.get_model_name()) + # print(qe.get_expanded_query('HosseinFani International Crime Organization', [301])) + # print(qe.get_expanded_query('Agoraphobia', [698])) + print(qe.get_refined_query('Unsolicited Faxes', [317])) diff --git a/src/refinement/refiners/thesaurus.py b/src/refinement/refiners/thesaurus.py new file mode 100644 index 0000000..5734322 --- /dev/null +++ b/src/refinement/refiners/thesaurus.py @@ -0,0 +1,79 @@ +from nltk.corpus import wordnet +import urllib, traceback +from urllib.request import urlopen 
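+# (urlopen and BeautifulSoup are used in get_synonym() below to scrape per-POS synonym lists from thesaurus.com)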
+from bs4 import BeautifulSoup + +import sys +sys.path.extend(['../refinement']) + +from refiners.abstractqrefiner import AbstractQRefiner +import utils + +class Thesaurus(AbstractQRefiner): + def __init__(self, replace=False, topn=3): + AbstractQRefiner.__init__(self, replace, topn) + + def get_refined_query(self, q, args=None): + pos_dict = {'n': 'noun', 'v': 'verb', 'a': 'adjective', 's': 'satellite adj', 'r': 'adverb'} + upd_query = utils.get_tokenized_query(q) + q_ = [] + if not self.replace: + q_ = [w for w in upd_query] + for w in upd_query: + found_flag = False + if utils.valid(w): + pos = wordnet.synsets(w)[0].pos() if wordnet.synsets(w) else 'n' + syn = self.get_synonym(w, pos_dict[pos]) + if not syn and self.replace: + q_.append(w) + else: + q_.append(' '.join(syn)) + + return super().get_refined_query(' '.join(q_)) + + + def get_synonym(self, word, pos="noun"): + try: + if pos == "noun": + response = urlopen('http://www.thesaurus.com/browse/{}/noun'.format(word)) + # print(response) + elif pos == "verb": + response = urlopen('http://www.thesaurus.com/browse/{}/verb'.format(word)) + elif pos == "adjective": + response = urlopen('http://www.thesaurus.com/browse/{}/adjective'.format(word)) + else: + # raise PosTagError('invalid pos tag: {}, valid POS tags: {{noun,verb,adj}}'.format(pos)) + print('WARNING: Thesaurus: Invalid pos tag: {}'.format(pos)) + return [] + html = response.read().decode('utf-8') + soup = BeautifulSoup(html, 'lxml') + counter=0 + result = [] + if len(soup.findAll('ul', {'class': "css-1lc0dpe et6tpn80"})) > 0: + for s in str(soup.findAll('ul',{'class':"css-1lc0dpe et6tpn80"})[0]).split('href'): + if counter < self.topn: + counter+=1 + start_index=s.index('>') + end_index=s.index('<', start_index + 1) + result.append(s[start_index+1:end_index]) + return result + except urllib.error.HTTPError as err: + if err.code == 404: + return [] + except urllib.error.URLError: + print("No Internet Connection") + return [] + except: + print('WARNING: Thesaurus: Exception has been raised!') + print(traceback.format_exc()) + return [] + + +if __name__ == "__main__": + qe = Thesaurus() + print(qe.get_model_name()) + print(qe.get_refined_query('HosseinFani International Crime Organization quickly')) + + qe = Thesaurus(replace=True) + print(qe.get_model_name()) + print(qe.get_refined_query('HosseinFani International Crime Organization')) diff --git a/src/refinement/refiners/wiki.py b/src/refinement/refiners/wiki.py new file mode 100644 index 0000000..5108159 --- /dev/null +++ b/src/refinement/refiners/wiki.py @@ -0,0 +1,82 @@ +import gensim +import tagme +tagme.GCUBE_TOKEN = "10df41c6-f741-45fc-88dd-9b24b2568a7b" + +import sys, os +sys.path.extend(['../refinement']) + +# @inproceedings{DBLP:conf/coling/LiZTHIS16, +# author = {Yuezhang Li and +# Ronghuo Zheng and +# Tian Tian and +# Zhiting Hu and +# Rahul Iyer and +# Katia P. 
Sycara}, +# editor = {Nicoletta Calzolari and +# Yuji Matsumoto and +# Rashmi Prasad}, +# title = {Joint Embedding of Hierarchical Categories and Entities for Concept +# Categorization and Dataless Classification}, +# booktitle = {{COLING} 2016, 26th International Conference on Computational Linguistics, +# Proceedings of the Conference: Technical Papers, December 11-16, 2016, +# Osaka, Japan}, +# pages = {2678--2688}, +# publisher = {{ACL}}, +# year = {2016}, +# url = {https://www.aclweb.org/anthology/C16-1252/}, +# timestamp = {Mon, 16 Sep 2019 17:08:53 +0200}, +# biburl = {https://dblp.org/rec/conf/coling/LiZTHIS16.bib}, +# bibsource = {dblp computer science bibliography, https://dblp.org} +# } + +import utils +from refiners.word2vec import Word2Vec + +class Wiki(Word2Vec): + def __init__(self, vectorfile, topn=3, replace=False): + Word2Vec.__init__(self, vectorfile, topn=topn, replace=replace) + + def get_concepts(self, text, score): + concepts = tagme.annotate(text).get_annotations(score) + res = [] + for ann in concepts: + res.append(ann.entity_title) + return res + + def get_refined_query(self, q, args=None): + + if not Word2Vec.word2vec: + print('INFO: Word2Vec: Loading word vectors in {} ...'.format(Word2Vec.vectorfile)) + Word2Vec.word2vec = gensim.models.KeyedVectors.load(Word2Vec.vectorfile) + + query_concepts = self.get_concepts(q, 0.1) + upd_query = utils.get_tokenized_query(q) + res = [] + if not self.replace: + res = [w for w in upd_query] + for c in query_concepts: + c_lower_e = "e_" + c.replace(" ", "_").lower() + if c_lower_e in Word2Vec.word2vec.vocab: + w = Word2Vec.word2vec.most_similar(positive=[c_lower_e], topn=self.topn) + for u, v in w: + if u.startswith("e_"): + u = u.replace("e_", "") + elif u.startswith("c_"): + u = u.replace("c_", "") + res.append(u.replace("_", " ")) + + res.append(c) + return super().get_refined_query(' '.join(res)) + + +if __name__ == "__main__": + + qe = Wiki(vectorfile='../pre/temp_model_Wiki') + for i in range(5): + print(qe.get_model_name()) + print(qe.get_refined_query('HosseinFani actor International Crime Organization')) + + qe.replace = True + for i in range(5): + print(qe.get_model_name()) + print(qe.get_refined_query('HosseinFani actor International Crime Organization')) diff --git a/src/refinement/refiners/word2vec.py b/src/refinement/refiners/word2vec.py new file mode 100644 index 0000000..dffa864 --- /dev/null +++ b/src/refinement/refiners/word2vec.py @@ -0,0 +1,51 @@ +import gensim +from nltk.stem import PorterStemmer + +import sys +sys.path.extend(['../refinement']) + +from refiners.abstractqrefiner import AbstractQRefiner +import utils + +class Word2Vec(AbstractQRefiner): + def __init__(self, vectorfile, replace=False, topn=3): + AbstractQRefiner.__init__(self, replace, topn) + Word2Vec.vectorfile = vectorfile + Word2Vec.word2vec = None + + def get_refined_query(self, q, args=None): + if not Word2Vec.word2vec: + print('INFO: Word2Vec: Loading word vectors in {} ...'.format(Word2Vec.vectorfile)) + Word2Vec.word2vec = gensim.models.KeyedVectors.load_word2vec_format(Word2Vec.vectorfile) + + upd_query = utils.get_tokenized_query(q) + synonyms = [] + res = [] + if not self.replace: + res = [w for w in upd_query] + ps = PorterStemmer() + for qw in upd_query: + found_flag = False + qw_stem = ps.stem(qw) + if qw in Word2Vec.word2vec.key_to_index: #in gensim 4.0 vocab change to key_to_index + w = Word2Vec.word2vec.most_similar(positive=[qw], topn=self.topn) + for u,v in w: + u_stem=ps.stem(u) + if u_stem!=qw_stem: + found_flag = True 
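+                    # keep only neighbours whose stem differs from the query term's stem,
+                    # so expansions add new vocabulary rather than inflectional variants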
+ res.append(u) + if not found_flag and self.replace: + res.append(qw) + return super().get_refined_query(' '.join(res)) + + +if __name__ == "__main__": + qe = Word2Vec('../pre/wiki-news-300d-1M.vec') + for i in range(5): + print(qe.get_model_name()) + print(qe.get_refined_query('HosseinFani International Crime Organization')) + + qe.replace = True + for i in range(5): + print(qe.get_model_name()) + print(qe.get_refined_query('HosseinFani International Crime Organization')) diff --git a/src/refinement/refiners/wordnet.py b/src/refinement/refiners/wordnet.py new file mode 100644 index 0000000..73087e8 --- /dev/null +++ b/src/refinement/refiners/wordnet.py @@ -0,0 +1,48 @@ +from nltk.corpus import wordnet +from nltk.stem import PorterStemmer + +import sys +sys.path.extend(['../refinement']) + +from refiners.abstractqrefiner import AbstractQRefiner +import utils + +class Wordnet(AbstractQRefiner): + def __init__(self, replace=False, topn=3): + AbstractQRefiner.__init__(self, replace, topn) + + def get_refined_query(self, q, args=None): + upd_query = utils.get_tokenized_query(q) + ps = PorterStemmer() + synonyms =[] + res = [] + if not self.replace: + res=[w for w in upd_query] + for w in upd_query: + found_flag = False + w_stem=ps.stem(w) + for syn in wordnet.synsets(w): + for l in syn.lemmas(): + synonyms.append(l.name()) + synonyms=list(set(synonyms)) + synonyms=synonyms[:self.topn] + for s in synonyms: + s_stem=ps.stem(s) + if s_stem!=w_stem: + found_flag = True + res.append(s) + synonyms=[] + + if not found_flag and self.replace: + res.append(w) + return super().get_refined_query(' '.join(res)) + + +if __name__ == "__main__": + qe = Wordnet() + print(qe.get_model_name()) + print(qe.get_refined_query('HosseinFani International Crime Organization')) + + qe = Wordnet(replace=True) + print(qe.get_model_name()) + print(qe.get_refined_query('HosseinFani International Crime Organization')) diff --git a/src/refinement/stemmers/__init__.py b/src/refinement/stemmers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/refinement/stemmers/abstractstemmer.py b/src/refinement/stemmers/abstractstemmer.py new file mode 100644 index 0000000..db71a76 --- /dev/null +++ b/src/refinement/stemmers/abstractstemmer.py @@ -0,0 +1,63 @@ +#!/usr/bin/python +import gzip +import codecs +from collections import defaultdict +from nltk.tokenize import WordPunctTokenizer +import re +import sys + +from src.refinement import utils +class AbstractStemmer(object): + def __init__(self): + super(AbstractStemmer, self).__init__() + self.tokenizer = WordPunctTokenizer() + self.vocab = set() + self.basename = 'nostemmer' + + def stem_query(self, q): + # isword = re.compile('[a-z0-9]+') + q = utils.clean(q) + curr_words = self.tokenizer.tokenize(q) + clean_words = [word.lower() for word in curr_words] + processed_words = self.process(clean_words) + self.vocab.update(processed_words) + return ' '.join(processed_words) + + def stem(self, files): + # We write files to a -[stemmer].txt file + filename_mod = files[0].split('.')[0] + wf = codecs.open('{1}-{0}.txt'.format(self.basename, filename_mod), 'w', encoding='utf-8') + isword = re.compile('[a-z0-9]+') + + # We can work with both gzip and non-gzip + for fname in files: + if fname.endswith('gz'): + f = gzip.open(fname, 'r') + else: + f = open(fname) + for no, line in enumerate(f): + if isinstance(line, bytes): + line = line.decode('utf-8') + # We drop empty lines + if len(line.strip()) == 0: + continue + + # Clean and process words + curr_words = 
self.tokenizer.tokenize(line) + clean_words = [word.lower() for word in curr_words] + processed_words = self.process(clean_words) + + # Keep track of vocab size + self.vocab.update(processed_words) + + # We output according to the one-doc-per-line format for Mallet + current_line = u' '.join(processed_words) + line_fmt = '{0}\n'.format(current_line) + wf.write(line_fmt) + f.close() + + print ('Resulting vocab size: {0}'.format(len(self.vocab))) + wf.close() + + def process(self, words): + raise NotImplementedError("No stemmer here!") diff --git a/src/refinement/stemmers/krovetz.py b/src/refinement/stemmers/krovetz.py new file mode 100644 index 0000000..168e885 --- /dev/null +++ b/src/refinement/stemmers/krovetz.py @@ -0,0 +1,27 @@ +#!/bin/python +from stemmers.abstractstemmer import AbstractStemmer +import subprocess +import sys + + +class KrovetzStemmer(AbstractStemmer): + + def __init__(self, jarfile): + super(KrovetzStemmer, self).__init__() + self.jarfile = jarfile + self.basename = 'krovetz' + + def process(self, words): + new_words = [] + while len(words) > 1000: + # new_words += subprocess.check_output(['java', '-jar', 'kstem-3.4.jar', '-w', ' '.join(words[:1000])]).split() + new_words += subprocess.check_output(['java', '-jar', self.jarfile, '-w', ' '.join(words[:1000])]).split() + words = words[1000:] + # new_words += subprocess.check_output('java -jar kstem-3.4.jar -w ' + ' '.join(words),shell=True,).split() + new_words += subprocess.check_output('java -jar ' + self.jarfile + ' -w ' + ' '.join(words), shell=True, ).split() + return [s.decode() for s in new_words] + + +if __name__ == '__main__': + stemmer = KrovetzStemmer() + stemmer.stem(sys.argv[1:]) diff --git a/src/refinement/stemmers/kstem-3.4.jar b/src/refinement/stemmers/kstem-3.4.jar new file mode 100644 index 0000000..d811f35 Binary files /dev/null and b/src/refinement/stemmers/kstem-3.4.jar differ diff --git a/src/refinement/stemmers/lovins.py b/src/refinement/stemmers/lovins.py new file mode 100644 index 0000000..84a9e03 --- /dev/null +++ b/src/refinement/stemmers/lovins.py @@ -0,0 +1,19 @@ +#!/bin/python +from stemmers.abstractstemmer import AbstractStemmer +import stemmers.lovinsstemmer as lovinsstemmer +import sys + + +class LovinsStemmer(AbstractStemmer): + + def __init__(self, ): + super(LovinsStemmer, self).__init__() + self.basename = 'lovins' + + def process(self, words): + return [lovinsstemmer.stem(word) for word in words] + + +if __name__ == '__main__': + stemmer = LovinsStemmer() + stemmer.stem(sys.argv[1:]) diff --git a/src/refinement/stemmers/lovinsstemmer.py b/src/refinement/stemmers/lovinsstemmer.py new file mode 100644 index 0000000..6c59bb3 --- /dev/null +++ b/src/refinement/stemmers/lovinsstemmer.py @@ -0,0 +1,542 @@ +"""This module implements the Lovins stemming algorithm. 
Use the ``stem()`` +function:: + + stemmed_word = stem(word) +""" + +from collections import defaultdict + + +# Conditions + +def A(base): + # A No restrictions on stem + return True + +def B(base): + # B Minimum stem length = 3 + return len(base) > 2 + +def C(base): + # C Minimum stem length = 4 + return len(base) > 3 + +def D(base): + # D Minimum stem length = 5 + return len(base) > 4 + +def E(base): + # E Do not remove ending after e + return base[-1] != "e" + +def F(base): + # F Minimum stem length = 3 and do not remove ending after e + return len(base) > 2 and base[-1] != "e" + +def G(base): + # G Minimum stem length = 3 and remove ending only after f + return len(base) > 2 and base[-1] == "f" + +def H(base): + # H Remove ending only after t or ll + c1, c2 = base[-2:] + return c2 == "t" or (c2 == "l" and c1 == "l") + +def I(base): + # I Do not remove ending after o or e + c = base[-1] + return c != "o" and c != "e" + +def J(base): + # J Do not remove ending after a or e + c = base[-1] + return c != "a" and c != "e" + +def K(base): + # K Minimum stem length = 3 and remove ending only after l, i or u*e + c = base[-1] + cc = "" if len(base) < 3 else base[-3] + return len(base) > 2 and (c == "l" or c == "i" or (c == "e" and cc == "u")) + +def L(base): + # L Do not remove ending after u, x or s, unless s follows o + c1, c2 = base[-2:] + return c2 != "u" and c2 != "x" and (c2 != "s" or c1 == "o") + +def M(base): + # M Do not remove ending after a, c, e or m + c = base[-1] + return c != "a" and c!= "c" and c != "e" and c != "m" + +def N(base): + # N Minimum stem length = 4 after s**, elsewhere = 3 + return len(base) > 3 or (len(base) == 3 and base[-1] != "s") + +def O(base): + # O Remove ending only after l or i + c = base[-1] + return c == "l" or c == "i" + +def P(base): + # P Do not remove ending after c + return base[-1] != "c" + +def Q(base): + # Q Minimum stem length = 3 and do not remove ending after l or n + c = base[-1] + return len(base) > 2 and (c != "l" and c != "n") + +def R(base): + # R Remove ending only after n or r + c = base[-1] + return c == "n" or c == "r" + +def S(base): + # S Remove ending only after dr or t, unless t follows t + l2 = base[-2] + return l2 == "rd" or (base[-1] == "t" and l2 != "tt") + +def T(base): + # T Remove ending only after s or t, unless t follows o + c1, c2 = base[-2:] + return c2 == "s" or (c2 == "t" and c1 != "o") + +def U(base): + # U Remove ending only after l, m, n or r + c = base[-1] + return c == "l" or c == "m" or c == "n" or c == "r" + +def V(base): + # V Remove ending only after c + return base[-1] == "c" + +def W(base): + # W Do not remove ending after s or u + c = base[-1] + return c != "s" and c != "u" + +def X(base): + # X Remove ending only after l, i or u*e + c = base[-1] + cc = "" if len(base) < 3 else base[-3] + return c == "l" or c == "i" or (c == "e" and cc == "u") + +def Y(base): + # Y Remove ending only after in + return base[-2:] == "in" + +def Z(base): + # Z Do not remove ending after f + return base[-1] != "f" + +def a(base): + # a Remove ending only after d, f, ph, th, l, er, or, es or t + c = base[-1] + l2 = base[-2:] + return (c == "d" or c == "f" or l2 == "ph" or l2 == "th" or c == "l" + or l2 == "er" or l2 == "or" or l2 == "es" or c == "t") + +def b(base): + # b Minimum stem length = 3 and do not remove ending after met or ryst + return len(base) > 2 and not (base.endswith("met") + or base.endswith("ryst")) + +def c(base): + # c Remove ending only after l + return base[-1] == "l" + +# Endings + +m = [None] * 12 + 
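+# m[k] maps each k-letter ending recognised by the Lovins algorithm to the condition
+# (one of the A..Z / a..c predicates above) that the remaining stem must satisfy for removal.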
+m[11] = dict(( + ("alistically", B), + ("arizability", A), + ("izationally", B))) +m[10] = dict(( + ("antialness", A), + ("arisations", A), + ("arizations", A), + ("entialness", A))) +m[9] = dict(( + ("allically", C), + ("antaneous", A), + ("antiality", A), + ("arisation", A), + ("arization", A), + ("ationally", B), + ("ativeness", A), + ("eableness", E), + ("entations", A), + ("entiality", A), + ("entialize", A), + ("entiation", A), + ("ionalness", A), + ("istically", A), + ("itousness", A), + ("izability", A), + ("izational", A))) +m[8] = dict(( + ("ableness", A), + ("arizable", A), + ("entation", A), + ("entially", A), + ("eousness", A), + ("ibleness", A), + ("icalness", A), + ("ionalism", A), + ("ionality", A), + ("ionalize", A), + ("iousness", A), + ("izations", A), + ("lessness", A))) +m[7] = dict(( + ("ability", A), + ("aically", A), + ("alistic", B), + ("alities", A), + ("ariness", E), + ("aristic", A), + ("arizing", A), + ("ateness", A), + ("atingly", A), + ("ational", B), + ("atively", A), + ("ativism", A), + ("elihood", E), + ("encible", A), + ("entally", A), + ("entials", A), + ("entiate", A), + ("entness", A), + ("fulness", A), + ("ibility", A), + ("icalism", A), + ("icalist", A), + ("icality", A), + ("icalize", A), + ("ication", G), + ("icianry", A), + ("ination", A), + ("ingness", A), + ("ionally", A), + ("isation", A), + ("ishness", A), + ("istical", A), + ("iteness", A), + ("iveness", A), + ("ivistic", A), + ("ivities", A), + ("ization", F), + ("izement", A), + ("oidally", A), + ("ousness", A))) +m[6] = dict(( + ("aceous", A), + ("acious", B), + ("action", G), + ("alness", A), + ("ancial", A), + ("ancies", A), + ("ancing", B), + ("ariser", A), + ("arized", A), + ("arizer", A), + ("atable", A), + ("ations", B), + ("atives", A), + ("eature", Z), + ("efully", A), + ("encies", A), + ("encing", A), + ("ential", A), + ("enting", C), + ("entist", A), + ("eously", A), + ("ialist", A), + ("iality", A), + ("ialize", A), + ("ically", A), + ("icance", A), + ("icians", A), + ("icists", A), + ("ifully", A), + ("ionals", A), + ("ionate", D), + ("ioning", A), + ("ionist", A), + ("iously", A), + ("istics", A), + ("izable", E), + ("lessly", A), + ("nesses", A), + ("oidism", A))) +m[5] = dict(( + ("acies", A), + ("acity", A), + ("aging", B), + ("aical", A), + ("alist", A), + ("alism", B), + ("ality", A), + ("alize", A), + ("allic", b), + ("anced", B), + ("ances", B), + ("antic", C), + ("arial", A), + ("aries", A), + ("arily", A), + ("arity", B), + ("arize", A), + ("aroid", A), + ("ately", A), + ("ating", I), + ("ation", B), + ("ative", A), + ("ators", A), + ("atory", A), + ("ature", E), + ("early", Y), + ("ehood", A), + ("eless", A), + ("elily", A), + ("ement", A), + ("enced", A), + ("ences", A), + ("eness", E), + ("ening", E), + ("ental", A), + ("ented", C), + ("ently", A), + ("fully", A), + ("ially", A), + ("icant", A), + ("ician", A), + ("icide", A), + ("icism", A), + ("icist", A), + ("icity", A), + ("idine", I), + ("iedly", A), + ("ihood", A), + ("inate", A), + ("iness", A), + ("ingly", B), + ("inism", J), + ("inity", c), + ("ional", A), + ("ioned", A), + ("ished", A), + ("istic", A), + ("ities", A), + ("itous", A), + ("ively", A), + ("ivity", A), + ("izers", F), + ("izing", F), + ("oidal", A), + ("oides", A), + ("otide", A), + ("ously", A))) +m[4] = dict(( + ("able", A), + ("ably", A), + ("ages", B), + ("ally", B), + ("ance", B), + ("ancy", B), + ("ants", B), + ("aric", A), + ("arly", K), + ("ated", I), + ("ates", A), + ("atic", B), + ("ator", A), + ("ealy", Y), + ("edly", E), + 
("eful", A), + ("eity", A), + ("ence", A), + ("ency", A), + ("ened", E), + ("enly", E), + ("eous", A), + ("hood", A), + ("ials", A), + ("ians", A), + ("ible", A), + ("ibly", A), + ("ical", A), + ("ides", L), + ("iers", A), + ("iful", A), + ("ines", M), + ("ings", N), + ("ions", B), + ("ious", A), + ("isms", B), + ("ists", A), + ("itic", H), + ("ized", F), + ("izer", F), + ("less", A), + ("lily", A), + ("ness", A), + ("ogen", A), + ("ward", A), + ("wise", A), + ("ying", B), + ("yish", A))) +m[3] = dict(( + ("acy", A), + ("age", B), + ("aic", A), + ("als", b), + ("ant", B), + ("ars", O), + ("ary", F), + ("ata", A), + ("ate", A), + ("eal", Y), + ("ear", Y), + ("ely", E), + ("ene", E), + ("ent", C), + ("ery", E), + ("ese", A), + ("ful", A), + ("ial", A), + ("ian", A), + ("ics", A), + ("ide", L), + ("ied", A), + ("ier", A), + ("ies", P), + ("ily", A), + ("ine", M), + ("ing", N), + ("ion", Q), + ("ish", C), + ("ism", B), + ("ist", A), + ("ite", a), + ("ity", A), + ("ium", A), + ("ive", A), + ("ize", F), + ("oid", A), + ("one", R), + ("ous", A))) +m[2] = dict(( + ("ae", A), + ("al", b), + ("ar", X), + ("as", B), + ("ed", E), + ("en", F), + ("es", E), + ("ia", A), + ("ic", A), + ("is", A), + ("ly", B), + ("on", S), + ("or", T), + ("um", U), + ("us", V), + ("yl", R), + ("s'", A), + ("'s", A))) +m[1] = dict(( + ("a", A), + ("e", A), + ("i", A), + ("o", A), + ("s", W), + ("y", B))) + + +def remove_ending(word): + length = len(word) + el = 11 + while el > 0: + if length - el > 1: + ending = word[length-el:] + cond = m[el].get(ending) + if cond: + base = word[:length-el] + if cond(base): + return base + el -= 1 + return word + + +_endings = (("iev", "ief"), + ("uct", "uc"), + ("iev", "ief"), + ("uct", "uc"), + ("umpt", "um"), + ("rpt", "rb"), + ("urs", "ur"), + ("istr", "ister"), + ("metr", "meter"), + ("olv", "olut"), + ("ul", "l", "aoi"), + ("bex", "bic"), + ("dex", "dic"), + ("pex", "pic"), + ("tex", "tic"), + ("ax", "ac"), + ("ex", "ec"), + ("ix", "ic"), + ("lux", "luc"), + ("uad", "uas"), + ("vad", "vas"), + ("cid", "cis"), + ("lid", "lis"), + ("erid", "eris"), + ("pand", "pans"), + ("end", "ens", "s"), + ("ond", "ons"), + ("lud", "lus"), + ("rud", "rus"), + ("her", "hes", "pt"), + ("mit", "mis"), + ("ent", "ens", "m"), + ("ert", "ers"), + ("et", "es", "n"), + ("yt", "ys"), + ("yz", "ys")) + +# Hash the ending rules by the last letter of the target ending +_endingrules = defaultdict(list) +for rule in _endings: + _endingrules[rule[0][-1]].append(rule) + +_doubles = frozenset(("dd", "gg", "ll", "mm", "nn", "pp", "rr", "ss", "tt")) + + +def fix_ending(word): + if word[-2:] in _doubles: + word = word[:-1] + + for endingrule in _endingrules[word[-1]]: + target, newend = endingrule[:2] + if word.endswith(target): + if len(endingrule) > 2 and len(word) > len(target): + exceptafter = endingrule[2] + c = word[0-(len(target)+1)] + if c in exceptafter: return word + + return word[:0-len(target)] + newend + + return word + + +def stem(word): + """Returns the stemmed version of the argument string. 
+ """ + return fix_ending(remove_ending(word)) + + + diff --git a/src/refinement/stemmers/nostemmer.py b/src/refinement/stemmers/nostemmer.py new file mode 100644 index 0000000..bba960a --- /dev/null +++ b/src/refinement/stemmers/nostemmer.py @@ -0,0 +1,18 @@ +#!/bin/python +from abstractstemmer import AbstractStemmer +import sys + + +class NoStemmer(AbstractStemmer): + + def __init__(self, ): + super(NoStemmer, self).__init__() + self.basename = 'nostemmer' + + def process(self, words): + return words + + +if __name__ == '__main__': + stemmer = NoStemmer() + stemmer.stem(sys.argv[1:]) diff --git a/src/refinement/stemmers/paicehusk.py b/src/refinement/stemmers/paicehusk.py new file mode 100644 index 0000000..6046093 --- /dev/null +++ b/src/refinement/stemmers/paicehusk.py @@ -0,0 +1,19 @@ +#!/bin/python +from stemmers.abstractstemmer import AbstractStemmer +import stemmers.paicehuskstemmer as paicehuskstemmer +import sys + + +class PaiceHuskStemmer(AbstractStemmer): + + def __init__(self, ): + super(PaiceHuskStemmer, self).__init__() + self.basename = 'paicehusk' + + def process(self, words): + return [paicehuskstemmer.stem(word) for word in words] + + +if __name__ == '__main__': + stemmer = PaiceHuskStemmer() + stemmer.stem(sys.argv[1:]) diff --git a/src/refinement/stemmers/paicehuskstemmer.py b/src/refinement/stemmers/paicehuskstemmer.py new file mode 100644 index 0000000..6a1b866 --- /dev/null +++ b/src/refinement/stemmers/paicehuskstemmer.py @@ -0,0 +1,254 @@ +"""This module contains an object that implements the Paice-Husk stemming +algorithm. + +If you just want to use the standard Paice-Husk stemming rules, use the +module's ``stem()`` function:: + + stemmed_word = stem(word) + +If you want to use a custom rule set, read the rules into a string where the +rules are separated by newlines, and instantiate the object with the string, +then use the object's stem method to stem words:: + + stemmer = PaiceHuskStemmer(my_rules_string) + stemmed_word = stemmer.stem(word) +""" + +import re +from collections import defaultdict + + +class PaiceHuskStemmer(object): + """Implements the Paice-Husk stemming algorithm. + """ + + rule_expr = re.compile(r""" + ^(?P<ending>\w+) + (?P<intact>[*]?) + (?P<num>\d+) + (?P<append>\w*) + (?P<cont>[.>]) + """, re.UNICODE | re.VERBOSE) + + stem_expr = re.compile("^\w+", re.UNICODE) + + def __init__(self, ruletable): + """ + :param ruletable: a string containing the rule data, separated + by newlines. 
+ """ + self.rules = defaultdict(list) + self.read_rules(ruletable) + + def read_rules(self, ruletable): + rule_expr = self.rule_expr + rules = self.rules + + for line in ruletable.split("\n"): + line = line.strip() + if not line: + continue + + match = rule_expr.match(line) + if match: + ending = match.group("ending")[::-1] + lastchar = ending[-1] + intact = match.group("intact") == "*" + num = int(match.group("num")) + append = match.group("append") + cont = match.group("cont") == ">" + + rules[lastchar].append((ending, intact, num, append, cont)) + else: + raise Exception("Bad rule: %r" % line) + + def first_vowel(self, word): + # Find the first position of each regular vowel (if any) + has_vowel_list = [p for p in [word.find(v) for v in "aeiou"] if p > -1] + # We add some logic to make sure y can be the only vowel + if len(has_vowel_list) > 0: + vp = min(has_vowel_list) + else: + vp = -1 + yp = word.find("y") + if yp > 0 and (yp < vp | vp == -1): + return yp + return vp + + def strip_prefix(self, word): + for prefix in ("kilo", "micro", "milli", "intra", "ultra", "mega", + "nano", "pico", "pseudo"): + if word.startswith(prefix) and len(word) > len(prefix): + return word[len(prefix):] + return word + + def stem(self, word): + """Returns a stemmed version of the argument string. + """ + if len(word) == 0: + return word + + rules = self.rules + match = self.stem_expr.match(word) + if not match: return word + stem = self.strip_prefix(match.group(0)) + + is_intact = True + continuing = True + while continuing: + rulelist = rules.get(stem[-1]) + if not rulelist: break + pfv = self.first_vowel(stem) + continuing = False + for ending, intact, num, append, cont in rulelist: + if stem.endswith(ending): + if intact and not is_intact: continue + newlen = len(stem) - num + len(append) + if ((pfv == 0 and newlen < 2) + or (pfv > 0 and newlen < 3) + or (pfv >= newlen) + or (pfv < 0)): + # If word starts with vowel, minimum stem length is 2. + # If word starts with consonant, minimum stem length is + # 3 and there must be a vowel in the stem somewhere + continue + + is_intact = False + if num > 0: + stem = stem[:0-num] + append + + continuing = cont + break + + return stem + +# The default rules for the Paice-Husk stemming algorithm + +defaultrules = """ +ai*2. { -ia > - if intact } +a*1. { -a > - if intact } +bb1. { -bb > -b } +city3s. { -ytic > -ys } +ci2> { -ic > - } +cn1t> { -nc > -nt } +dd1. { -dd > -d } +dei3y> { -ied > -y } +deec2ss. { -ceed > -cess } +dee1. { -eed > -ee } +de2> { -ed > - } +dooh4> { -hood > - } +e1> { -e > - } +feil1v. { -lief > -liev } +fi2> { -if > - } +gni3> { -ing > - } +gai3y. { -iag > -y } +ga2> { -ag > - } +gg1. { -gg > -g } +ht*2. { -th > - if intact } +hsiug5ct. { -guish > -ct } +hsi3> { -ish > - } +i*1. { -i > - if intact } +i1y> { -i > -y } +ji1d. { -ij > -id -- see nois4j> & vis3j> } +juf1s. { -fuj > -fus } +ju1d. { -uj > -ud } +jo1d. { -oj > -od } +jeh1r. { -hej > -her } +jrev1t. { -verj > -vert } +jsim2t. { -misj > -mit } +jn1d. { -nj > -nd } +j1s. { -j > -s } +lbaifi6. { -ifiabl > - } +lbai4y. { -iabl > -y } +lba3> { -abl > - } +lbi3. { -ibl > - } +lib2l> { -bil > -bl } +lc1. { -cl > c } +lufi4y. { -iful > -y } +luf3> { -ful > - } +lu2. { -ul > - } +lai3> { -ial > - } +lau3> { -ual > - } +la2> { -al > - } +ll1. { -ll > -l } +mui3. { -ium > - } +mu*2. { -um > - if intact } +msi3> { -ism > - } +mm1. { -mm > -m } +nois4j> { -sion > -j } +noix4ct. { -xion > -ct } +noi3> { -ion > - } +nai3> { -ian > - } +na2> { -an > - } +nee0. 
{ protect -een } +ne2> { -en > - } +nn1. { -nn > -n } +pihs4> { -ship > - } +pp1. { -pp > -p } +re2> { -er > - } +rae0. { protect -ear } +ra2. { -ar > - } +ro2> { -or > - } +ru2> { -ur > - } +rr1. { -rr > -r } +rt1> { -tr > -t } +rei3y> { -ier > -y } +sei3y> { -ies > -y } +sis2. { -sis > -s } +si2> { -is > - } +ssen4> { -ness > - } +ss0. { protect -ss } +suo3> { -ous > - } +su*2. { -us > - if intact } +s*1> { -s > - if intact } +s0. { -s > -s } +tacilp4y. { -plicat > -ply } +ta2> { -at > - } +tnem4> { -ment > - } +tne3> { -ent > - } +tna3> { -ant > - } +tpir2b. { -ript > -rib } +tpro2b. { -orpt > -orb } +tcud1. { -duct > -duc } +tpmus2. { -sumpt > -sum } +tpec2iv. { -cept > -ceiv } +tulo2v. { -olut > -olv } +tsis0. { protect -sist } +tsi3> { -ist > - } +tt1. { -tt > -t } +uqi3. { -iqu > - } +ugo1. { -ogu > -og } +vis3j> { -siv > -j } +vie0. { protect -eiv } +vi2> { -iv > - } +ylb1> { -bly > -bl } +yli3y> { -ily > -y } +ylp0. { protect -ply } +yl2> { -ly > - } +ygo1. { -ogy > -og } +yhp1. { -phy > -ph } +ymo1. { -omy > -om } +ypo1. { -opy > -op } +yti3> { -ity > - } +yte3> { -ety > - } +ytl2. { -lty > -l } +yrtsi5. { -istry > - } +yra3> { -ary > - } +yro3> { -ory > - } +yfi3. { -ify > - } +ycn2t> { -ncy > -nt } +yca3> { -acy > - } +zi2> { -iz > - } +zy1s. { -yz > -ys } +""" + +# Make the standard rules available as a module-level function + +stem = PaiceHuskStemmer(defaultrules).stem + + + + + + + diff --git a/src/refinement/stemmers/porter.py b/src/refinement/stemmers/porter.py new file mode 100644 index 0000000..19506ea --- /dev/null +++ b/src/refinement/stemmers/porter.py @@ -0,0 +1,19 @@ +#!/bin/python +from stemmers.abstractstemmer import AbstractStemmer +from stemmers import porterstemmer +import sys + + +class PorterStemmer(AbstractStemmer): + + def __init__(self, ): + super(PorterStemmer, self).__init__() + self.basename = 'porter' + + def process(self, words): + return [porterstemmer.stem(word) for word in words] + + +if __name__ == '__main__': + stemmer = PorterStemmer() + # stemmer.stem(sys.argv[1:]) diff --git a/src/refinement/stemmers/porter2.py b/src/refinement/stemmers/porter2.py new file mode 100644 index 0000000..83a32cb --- /dev/null +++ b/src/refinement/stemmers/porter2.py @@ -0,0 +1,19 @@ +#!/bin/python +from stemmers.abstractstemmer import AbstractStemmer +import stemmers.porter2stemmer as porter2stemmer +import sys + + +class Porter2Stemmer(AbstractStemmer): + + def __init__(self, ): + super(Porter2Stemmer, self).__init__() + self.basename = 'porter2' + + def process(self, words): + return [porter2stemmer.stem(word) for word in words] + + +if __name__ == '__main__': + stemmer = Porter2Stemmer() + stemmer.stem(sys.argv[1:]) diff --git a/src/refinement/stemmers/porter2stemmer.py b/src/refinement/stemmers/porter2stemmer.py new file mode 100644 index 0000000..ebb2380 --- /dev/null +++ b/src/refinement/stemmers/porter2stemmer.py @@ -0,0 +1,288 @@ +"""An implementation of the Porter2 stemming algorithm. +See http://snowball.tartarus.org/algorithms/english/stemmer.html + +Adapted from pyporter2 by Michael Dirolf. + +This algorithm is more correct but (at least in this implementation) +several times slower than the original porter algorithm as implemented +in stemming.porter. 
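+
+Use the module-level ``stem()`` function::
+
+    stemmed_word = stem(word)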
+""" + +import re + +r_exp = re.compile(r"[^aeiouy]*[aeiouy]+[^aeiouy](\w*)") +ewss_exp1 = re.compile(r"^[aeiouy][^aeiouy]$") +ewss_exp2 = re.compile(r".*[^aeiouy][aeiouy][^aeiouywxY]$") +ccy_exp = re.compile(r"([aeiouy])y") +s1a_exp = re.compile(r"[aeiouy].") +s1b_exp = re.compile(r"[aeiouy]") + +def get_r1(word): + # exceptional forms + if word.startswith('gener') or word.startswith('arsen'): + return 5 + if word.startswith('commun'): + return 6 + + # normal form + match = r_exp.match(word) + if match: + return match.start(1) + return len(word) + +def get_r2(word): + match = r_exp.match(word, get_r1(word)) + if match: + return match.start(1) + return len(word) + +def ends_with_short_syllable(word): + if len(word) == 2: + if ewss_exp1.match(word): + return True + if ewss_exp2.match(word): + return True + return False + +def is_short_word(word): + if ends_with_short_syllable(word): + if get_r1(word) == len(word): + return True + return False + +def remove_initial_apostrophe(word): + if word.startswith("'"): + return word[1:] + return word + +def capitalize_consonant_ys(word): + if word.startswith('y'): + word = 'Y' + word[1:] + return ccy_exp.sub('\g<1>Y', word) + +def step_0(word): + if word.endswith("'s'"): + return word[:-3] + if word.endswith("'s"): + return word[:-2] + if word.endswith("'"): + return word[:-1] + return word + +def step_1a(word): + if word.endswith('sses'): + return word[:-4] + 'ss' + if word.endswith('ied') or word.endswith('ies'): + if len(word) > 4: + return word[:-3] + 'i' + else: + return word[:-3] + 'ie' + if word.endswith('us') or word.endswith('ss'): + return word + if word.endswith('s'): + preceding = word[:-1] + if s1a_exp.search(preceding): + return preceding + return word + return word + +doubles = ('bb', 'dd', 'ff', 'gg', 'mm', 'nn', 'pp', 'rr', 'tt') +def ends_with_double(word): + for double in doubles: + if word.endswith(double): + return True + return False +def step_1b_helper(word): + if word.endswith('at') or word.endswith('bl') or word.endswith('iz'): + return word + 'e' + if ends_with_double(word): + return word[:-1] + if is_short_word(word): + return word + 'e' + return word +s1b_suffixes = ('ed', 'edly', 'ing', 'ingly') + +def step_1b(word, r1): + if word.endswith('eedly'): + if len(word) - 5 >= r1: + return word[:-3] + return word + if word.endswith('eed'): + if len(word) - 3 >= r1: + return word[:-1] + return word + + for suffix in s1b_suffixes: + if word.endswith(suffix): + preceding = word[:-len(suffix)] + if s1b_exp.search(preceding): + return step_1b_helper(preceding) + return word + + return word + +def step_1c(word): + if word.endswith('y') or word.endswith('Y') and len(word) > 1: + if word[-2] not in 'aeiouy': + if len(word) > 2: + return word[:-1] + 'i' + return word + +def step_2_helper(word, r1, end, repl, prev): + if word.endswith(end): + if len(word) - len(end) >= r1: + if prev == []: + return word[:-len(end)] + repl + for p in prev: + if word[:-len(end)].endswith(p): + return word[:-len(end)] + repl + return word + return None +s2_triples = (('ization', 'ize', []), + ('ational', 'ate', []), + ('fulness', 'ful', []), + ('ousness', 'ous', []), + ('iveness', 'ive', []), + ('tional', 'tion', []), + ('biliti', 'ble', []), + ('lessli', 'less', []), + ('entli', 'ent', []), + ('ation', 'ate', []), + ('alism', 'al', []), + ('aliti', 'al', []), + ('ousli', 'ous', []), + ('iviti', 'ive', []), + ('fulli', 'ful', []), + ('enci', 'ence', []), + ('anci', 'ance', []), + ('abli', 'able', []), + ('izer', 'ize', []), + ('ator', 'ate', []), + ('alli', 
'al', []), + ('bli', 'ble', []), + ('ogi', 'og', ['l']), + ('li', '', ['c', 'd', 'e', 'g', 'h', 'k', 'm', 'n', 'r', 't'])) + +def step_2(word, r1): + for trip in s2_triples: + attempt = step_2_helper(word, r1, trip[0], trip[1], trip[2]) + if attempt: + return attempt + return word + +def step_3_helper(word, r1, r2, end, repl, r2_necessary): + if word.endswith(end): + if len(word) - len(end) >= r1: + if not r2_necessary: + return word[:-len(end)] + repl + else: + if len(word) - len(end) >= r2: + return word[:-len(end)] + repl + return word + return None +s3_triples = (('ational', 'ate', False), + ('tional', 'tion', False), + ('alize', 'al', False), + ('icate', 'ic', False), + ('iciti', 'ic', False), + ('ative', '', True), + ('ical', 'ic', False), + ('ness', '', False), + ('ful', '', False)) +def step_3(word, r1, r2): + for trip in s3_triples: + attempt = step_3_helper(word, r1, r2, trip[0], trip[1], trip[2]) + if attempt: + return attempt + return word + +s4_delete_list = ('al', 'ance', 'ence', 'er', 'ic', 'able', 'ible', 'ant', 'ement', + 'ment', 'ent', 'ism', 'ate', 'iti', 'ous', 'ive', 'ize') + +def step_4(word, r2): + for end in s4_delete_list: + if word.endswith(end): + if len(word) - len(end) >= r2: + return word[:-len(end)] + return word + + if word.endswith('sion') or word.endswith('tion'): + if len(word) - 3 >= r2: + return word[:-3] + + return word + +def step_5(word, r1, r2): + if word.endswith('l'): + if len(word) - 1 >= r2 and word[-2] == 'l': + return word[:-1] + return word + + if word.endswith('e'): + if len(word) - 1 >= r2: + return word[:-1] + if len(word) - 1 >= r1 and not ends_with_short_syllable(word[:-1]): + return word[:-1] + + return word + +def normalize_ys(word): + return word.replace('Y', 'y') + +exceptional_forms = {'skis': 'ski', + 'skies': 'sky', + 'dying': 'die', + 'lying': 'lie', + 'tying': 'tie', + 'idly': 'idl', + 'gently': 'gentl', + 'ugly': 'ugli', + 'early': 'earli', + 'only': 'onli', + 'singly': 'singl', + 'sky': 'sky', + 'news': 'news', + 'howe': 'howe', + 'atlas': 'atlas', + 'cosmos': 'cosmos', + 'bias': 'bias', + 'andes': 'andes'} + +exceptional_early_exit_post_1a = frozenset(['inning', 'outing', 'canning', 'herring', + 'earring', 'proceed', 'exceed', 'succeed']) + + +def stem(word): + if len(word) <= 2: + return word + word = remove_initial_apostrophe(word) + + # handle some exceptional forms + if word in exceptional_forms: + return exceptional_forms[word] + + word = capitalize_consonant_ys(word) + r1 = get_r1(word) + r2 = get_r2(word) + word = step_0(word) + word = step_1a(word) + + # handle some more exceptional forms + if word in exceptional_early_exit_post_1a: + return word + + word = step_1b(word, r1) + word = step_1c(word) + word = step_2(word, r1) + word = step_3(word, r1, r2) + word = step_4(word, r2) + word = step_5(word, r1, r2) + word = normalize_ys(word) + + return word + +if __name__ == "__main__": + assert stem("bill's") == "bill" + assert stem("y's") == "y" + + diff --git a/src/refinement/stemmers/porterstemmer.py b/src/refinement/stemmers/porterstemmer.py new file mode 100644 index 0000000..29ff5a6 --- /dev/null +++ b/src/refinement/stemmers/porterstemmer.py @@ -0,0 +1,188 @@ +""" +Reimplementation of the +`Porter stemming algorithm <http://tartarus.org/~martin/PorterStemmer/>`_ +in Python. + +In my quick tests, this implementation about 3.5 times faster than the +seriously weird Python linked from the official page. 
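+
+Use the ``stem()`` function::
+
+    stemmed_word = stem(word)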
+""" + +import re + +# Suffix replacement lists + +_step2list = { + "ational": "ate", + "tional": "tion", + "enci": "ence", + "anci": "ance", + "izer": "ize", + "bli": "ble", + "alli": "al", + "entli": "ent", + "eli": "e", + "ousli": "ous", + "ization": "ize", + "ation": "ate", + "ator": "ate", + "alism": "al", + "iveness": "ive", + "fulness": "ful", + "ousness": "ous", + "aliti": "al", + "iviti": "ive", + "biliti": "ble", + "logi": "log", + } + +_step3list = { + "icate": "ic", + "ative": "", + "alize": "al", + "iciti": "ic", + "ical": "ic", + "ful": "", + "ness": "", + } + + +_cons = "[^aeiou]" +_vowel = "[aeiouy]" +_cons_seq = "[^aeiouy]+" +_vowel_seq = "[aeiou]+" + +# m > 0 +_mgr0 = re.compile("^(" + _cons_seq + ")?" + _vowel_seq + _cons_seq) +# m == 0 +_meq1 = re.compile("^(" + _cons_seq + ")?" + _vowel_seq + _cons_seq + "(" + _vowel_seq + ")?$") +# m > 1 +_mgr1 = re.compile("^(" + _cons_seq + ")?" + _vowel_seq + _cons_seq + _vowel_seq + _cons_seq) +# vowel in stem +_s_v = re.compile("^(" + _cons_seq + ")?" + _vowel) +# ??? +_c_v = re.compile("^" + _cons_seq + _vowel + "[^aeiouwxy]$") + +# Patterns used in the rules + +_ed_ing = re.compile("^(.*)(ed|ing)$") +_at_bl_iz = re.compile("(at|bl|iz)$") +_step1b = re.compile("([^aeiouylsz])\\1$") +_step2 = re.compile("^(.+?)(ational|tional|enci|anci|izer|bli|alli|entli|eli|ousli|ization|ation|ator|alism|iveness|fulness|ousness|aliti|iviti|biliti|logi)$") +_step3 = re.compile("^(.+?)(icate|ative|alize|iciti|ical|ful|ness)$") +_step4_1 = re.compile("^(.+?)(al|ance|ence|er|ic|able|ible|ant|ement|ment|ent|ou|ism|ate|iti|ous|ive|ize)$") +_step4_2 = re.compile("^(.+?)(s|t)(ion)$") +_step5 = re.compile("^(.+?)e$") + +# Stemming function + +def stem(w): + """Uses the Porter stemming algorithm to remove suffixes from English + words. 
+ + >>> stem("fundamentally") + "fundament" + """ + + if len(w) < 3: return w + + first_is_y = w[0] == "y" + if first_is_y: + w = "Y" + w[1:] + + # Step 1a + if w.endswith("s"): + if w.endswith("sses"): + w = w[:-2] + elif w.endswith("ies"): + w = w[:-2] + elif w[-2] != "s": + w = w[:-1] + + # Step 1b + + if w.endswith("eed"): + s = w[:-3] + if _mgr0.match(s): + w = w[:-1] + else: + m = _ed_ing.match(w) + if m: + stem = m.group(1) + if _s_v.match(stem): + w = stem + if _at_bl_iz.match(w): + w += "e" + elif _step1b.match(w): + w = w[:-1] + elif _c_v.match(w): + w += "e" + + # Step 1c + + if w.endswith("y"): + stem = w[:-1] + if _s_v.match(stem): + w = stem + "i" + + # Step 2 + + m = _step2.match(w) + if m: + stem = m.group(1) + suffix = m.group(2) + if _mgr0.match(stem): + w = stem + _step2list[suffix] + + # Step 3 + + m = _step3.match(w) + if m: + stem = m.group(1) + suffix = m.group(2) + if _mgr0.match(stem): + w = stem + _step3list[suffix] + + # Step 4 + + m = _step4_1.match(w) + if m: + stem = m.group(1) + if _mgr1.match(stem): + w = stem + else: + m = _step4_2.match(w) + if m: + stem = m.group(1) + m.group(2) + if _mgr1.match(stem): + w = stem + + # Step 5 + + m = _step5.match(w) + if m: + stem = m.group(1) + if _mgr1.match(stem) or (_meq1.match(stem) and not _c_v.match(stem)): + w = stem + + if w.endswith("ll") and _mgr1.match(w): + w = w[:-1] + + if first_is_y: + w = "y" + w[1:] + + return w + +if __name__ == '__main__': + print (stem("fundamentally")) + + + + + + + + + + + + diff --git a/src/refinement/stemmers/sstemmer.py b/src/refinement/stemmers/sstemmer.py new file mode 100644 index 0000000..ed2a4e7 --- /dev/null +++ b/src/refinement/stemmers/sstemmer.py @@ -0,0 +1,28 @@ +#!/bin/python +from stemmers.abstractstemmer import AbstractStemmer +import sys + + +class SRemovalStemmer(AbstractStemmer): + + def __init__(self, ): + super(SRemovalStemmer, self).__init__() + self.basename = 'sstemmer' + + def process(self, words): + return [self.stem_word(word) for word in words] + + def stem_word(self, word): + if len(word) > 5 and word[-3:] == 'ies' and word[-4] not in 'ae': + return word[:-3] + 'y' + elif len(word) > 4 and word[-2:] == 'es' and word[-3] not in 'aeo': + return word[:-1] + elif len(word) > 3 and word[-1] == 's' and word[-2] not in 'us': + return word[:-1] + else: + return word + + +if __name__ == '__main__': + stemmer = SRemovalStemmer() + stemmer.stem(sys.argv[1:]) diff --git a/src/refinement/stemmers/trunc4.py b/src/refinement/stemmers/trunc4.py new file mode 100644 index 0000000..969a171 --- /dev/null +++ b/src/refinement/stemmers/trunc4.py @@ -0,0 +1,24 @@ +#!/bin/python +from stemmers.abstractstemmer import AbstractStemmer +import sys + + +class Trunc4Stemmer(AbstractStemmer): + + def __init__(self, ): + super(Trunc4Stemmer, self).__init__() + self.basename = 'trunc4' + + def process(self, words): + return [self.stem_word(word) for word in words] + + def stem_word(self, word): + if len(word) > 4: + return word[:4] + else: + return word + + +if __name__ == '__main__': + stemmer = Trunc4Stemmer() + stemmer.stem(sys.argv[1:]) diff --git a/src/refinement/stemmers/trunc5.py b/src/refinement/stemmers/trunc5.py new file mode 100644 index 0000000..e1c0bd9 --- /dev/null +++ b/src/refinement/stemmers/trunc5.py @@ -0,0 +1,24 @@ +#!/bin/python +from stemmers.abstractstemmer import AbstractStemmer +import sys + + +class Trunc5Stemmer(AbstractStemmer): + + def __init__(self, ): + super(Trunc5Stemmer, self).__init__() + self.basename = 'trunc5' + + def process(self, words): + return 
[self.stem_word(word) for word in words] + + def stem_word(self, word): + if len(word) > 5: + return word[:5] + else: + return word + + +if __name__ == '__main__': + stemmer = Trunc5Stemmer() + stemmer.stem(sys.argv[1:]) diff --git a/src/refinement/utils.py b/src/refinement/utils.py new file mode 100644 index 0000000..534f043 --- /dev/null +++ b/src/refinement/utils.py @@ -0,0 +1,141 @@ +from nltk.corpus import stopwords +from nltk.tokenize import word_tokenize +from nltk.stem import PorterStemmer +import re +from contextlib import contextmanager +import os, sys, threading, io, tempfile + +stop_words = set(stopwords.words('english')) +l = ['.', ',', '!', '?', ';', 'a', 'an', '(', ')', "'", '_', '-', '<', '>', 'if', '/', '[', ']', ' '] +stop_words.update(l) + +ps = PorterStemmer() + +def get_tokenized_query(q): + word_tokens = word_tokenize(q) + q_ = [w.lower() for w in word_tokens if w.lower() not in stop_words] + return q_ + +def valid(word): + """ + Check if input is null or contains only spaces or numbers or special characters + """ + temp = re.sub(r'[^A-Za-z ]', ' ', word) + temp = re.sub(r"\s+", " ", temp) + temp = temp.strip() + if temp != "": + return True + return False + +def clean(str): + result = [ch if ch.isalpha() else ' ' for ch in str] + return ''.join(result) + +def insert_row(df, idx, row): + import pandas as pd + df1 = df[0:idx] + df2 = df[idx:] + df1.loc[idx] = row + df = pd.concat([df1, df2]) + df.index = [*range(df.shape[0])] + return df + +def get_raw_query(topicreader,Q_filename): + q_file=open(Q_filename,'r').readlines() + raw_queries={} + if topicreader=='Trec': + for line in q_file: + if '<title>' in line : + raw_queries[qid]=line.split('<title>')[1].rstrip().lower() + elif '<num>' in line: + qid=line.split(':')[1].rstrip() + + elif topicreader=='Webxml': + for line in q_file: + if '<query>' in line: + raw_queries[qid]=line.split('<query>')[1].rstrip().lower().split('</query>')[0] + elif '<topic number' in line: + qid=line.split('<topic number="')[1].split('"')[0] + elif topicreader=='TsvInt' or topicreader=='TsvString': + for line in q_file: + qid=line.split('\t')[0] + raw_queries[qid]=line.split('\t')[1].rstrip().lower() + return raw_queries + +def get_ranker_name(ranker): + return ranker.replace('-', '').replace(' ', '.') + +# Thanks to the following links, we can capture outputs from external c/java libs +# - https://eli.thegreenplace.net/2015/redirecting-all-kinds-of-stdout-in-python/ +# - https://stackoverflow.com/questions/5081657/how-do-i-prevent-a-c-shared-library-to-print-on-stdout-in-python/17954769#17954769 + +# libc = ctypes.CDLL(None) +# c_stdout = ctypes.c_void_p.in_dll(libc, 'stdout') +@contextmanager +def stdout_redirector_2_stream(stream): + # The original fd stdout points to. Usually 1 on POSIX systems. 
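+    # Redirects stdout at the file-descriptor level (dup2 into a temporary file) so output
+    # written by external C/Java libraries is captured too, then copies it into `stream`.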
+ original_stdout_fd = sys.stdout.fileno() + + def _redirect_stdout(to_fd): + """Redirect stdout to the given file descriptor.""" + # Flush the C-level buffer stdout + # libc.fflush(c_stdout) + # Flush and close sys.stdout - also closes the file descriptor (fd) + sys.stdout.close() + # Make original_stdout_fd point to the same file as to_fd + os.dup2(to_fd, original_stdout_fd) + # Create a new sys.stdout that points to the redirected fd + sys.stdout = io.TextIOWrapper(os.fdopen(original_stdout_fd, 'wb')) + + # Save a copy of the original stdout fd in saved_stdout_fd + saved_stdout_fd = os.dup(original_stdout_fd) + try: + # Create a temporary file and redirect stdout to it + tfile = tempfile.TemporaryFile(mode='w+b') + _redirect_stdout(tfile.fileno()) + # Yield to caller, then redirect stdout back to the saved fd + yield + _redirect_stdout(saved_stdout_fd) + # Copy contents of temporary file to the given stream + tfile.flush() + tfile.seek(0, io.SEEK_SET) + stream.write(tfile.read()) + finally: + tfile.close() + os.close(saved_stdout_fd) + +@contextmanager +def stdout_redirected_2_file(to=os.devnull): + ''' + import os + with stdout_redirected(to=filename): + print("from Python") + os.system("echo non-Python applications are also supported") + ''' + fd = sys.stdout.fileno() + + ##### assert that Python and C stdio write using the same file descriptor + ####assert libc.fileno(ctypes.c_void_p.in_dll(libc, "stdout")) == fd == 1 + + def _redirect_stdout(to): + sys.stdout.close() # + implicit flush() + os.dup2(to.fileno(), fd) # fd writes to 'to' file + sys.stdout = os.fdopen(fd, 'w') # Python writes to fd + + with os.fdopen(os.dup(fd), 'w') as old_stdout: + with open(to, 'w') as file: + _redirect_stdout(to=file) + try: + yield # allow code to be run with the redirected stdout + finally: + _redirect_stdout(to=old_stdout) # restore stdout. + # buffering and flags such as + # CLOEXEC may be different + + +def hex_to_ansi(hex_color_code): + hex_color_code = hex_color_code.lstrip('#') + red = int(hex_color_code[0:2], 16) + green = int(hex_color_code[2:4], 16) + blue = int(hex_color_code[4:6], 16) + return f'\033[38;2;{red};{green};{blue}m' diff --git a/src/stats/stats.py b/src/stats/stats.py index 4dcff55..9e3bb26 100644 --- a/src/stats/stats.py +++ b/src/stats/stats.py @@ -10,7 +10,7 @@ def plot_stats(box_path): for ds in datasets: map_ds = pd.read_csv(f'{box_path}/{ds}.tsv', sep='\t', encoding='utf-8', names=['qid', 'i', 'i_map', 't', 't_map']) map_ds.sort_values(by='i_map', inplace=True) - stats = map_ds.groupby(np.arange(len(map_ds)) // (len(map_ds) / 10)).mean() + stats = map_ds.groupby(np.arange(len(map_ds)) // (len(map_ds) / 10)).mean(numeric_only=True) X = [x for x in range(1, 11)] original_mean = stats['i_map'] changes_mean = stats['t_map']