Skip to content

Commit

Permalink
Merge pull request #7 from XKTZ/main
Browse files Browse the repository at this point in the history
merge
  • Loading branch information
XKTZ authored Sep 12, 2024
2 parents 3795308 + edd1ff9 commit 89bab7c
Show file tree
Hide file tree
Showing 10 changed files with 23 additions and 13 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
We offer a suite of rerankers - pointwise models like monoT5 and listwise models with a focus on open source LLMs compatible with [FastChat](https://github.com/lm-sys/FastChat?tab=readme-ov-file#supported-models) (e.g., Vicuna, Zephyr, etc.) or [vLLM](https://github.com/vllm-project/vllm). We also support RankGPT variants, which are proprietary listwise rerankers. Some of the code in this repository is borrowed from [RankGPT](https://github.com/sunnweiwei/RankGPT), [PyGaggle](https://github.com/castorini/pygaggle), and [LiT5](https://github.com/castorini/LiT5)!

# Releases
current_version = 0.20.1
current_version = 0.20.2

## 📟 Instructions

Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "rank-llm"
version = "0.20.1"
version = "0.20.2"
description = "A Package for running prompt decoders like RankVicuna"
readme = "README.md"
authors = [
Expand Down Expand Up @@ -35,7 +35,7 @@ vllm = [
Homepage = "https://github.com/castorini/rank_llm"

[tool.bumpver]
current_version = "0.20.1"
current_version = "0.20.2"
version_pattern = "MAJOR.MINOR.PATCH"
commit_message = "Bump version {old_version} -> {new_version}"
commit = true
Expand Down
3 changes: 0 additions & 3 deletions src/rank_llm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +0,0 @@
from .retrieve_and_rerank import retrieve_and_rerank

__all__ = ["retrieve_and_rerank"]
2 changes: 1 addition & 1 deletion src/rank_llm/api/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
import torch
from flask import Flask, jsonify, request

from rank_llm import retrieve_and_rerank
from rank_llm.rerank import PromptMode, get_azure_openai_args, get_openai_api_key
from rank_llm.rerank.listwise import RankListwiseOSLLM, SafeOpenai
from rank_llm.retrieve import RetrievalMethod, RetrievalMode
from rank_llm.retrieve_and_rerank import retrieve_and_rerank

""" API URL FORMAT
Expand Down
12 changes: 11 additions & 1 deletion src/rank_llm/rerank/api_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,20 @@

from dotenv import load_dotenv

# Environment-variable names under which an OpenAI API key may be stored,
# checked in order of preference.
paths = [
    "OPENAI_API_KEY",
    "OPEN_AI_API_KEY",
]


def get_openai_api_key() -> str:
    """Return the OpenAI API key from the environment, or None if unset.

    Loads ``.env.local`` (if present) into the environment first, then
    checks each candidate variable name in ``paths`` and returns the
    first value found.

    Returns:
        The API key string, or None when no candidate variable is set.
        (NOTE(review): annotation says ``str`` but the miss path yields
        None — callers should handle a missing key.)
    """
    # No f-string needed: the path is a fixed literal.
    load_dotenv(dotenv_path=".env.local")

    for path in paths:
        # Read once per candidate; return the first non-None value.
        value = os.getenv(path)
        if value is not None:
            return value
    return None


def get_azure_openai_args() -> Dict[str, str]:
Expand Down
4 changes: 3 additions & 1 deletion src/rank_llm/rerank/reranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,9 @@ def extract_kwargs(
"""
extracted_kwargs = []
for key, default in keys_and_defaults:
value = kwargs.get(key, default)
value = kwargs.get(key, None)
if value is None:
value = default
if (
value is not None
and default is not None
Expand Down
3 changes: 2 additions & 1 deletion src/rank_llm/retrieve_and_rerank.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from typing import Any, Dict, List, Union

from rank_llm.data import Query, Request
from rank_llm.evaluation.trec_eval import EvalFunction
from rank_llm.rerank import IdentityReranker, RankLLM, Reranker
from rank_llm.rerank.reranker import extract_kwargs
from rank_llm.retrieve import (
Expand Down Expand Up @@ -104,6 +103,8 @@ def retrieve_and_rerank(
and dataset not in ["dl22", "dl22-passage", "news"]
and TOPICS[dataset] not in ["dl22", "dl22-passage", "news"]
):
from rank_llm.evaluation.trec_eval import EvalFunction

print("Evaluating:")
EvalFunction.eval(["-c", "-m", "ndcg_cut.1", TOPICS[dataset], file_name])
EvalFunction.eval(["-c", "-m", "ndcg_cut.5", TOPICS[dataset], file_name])
Expand Down
2 changes: 1 addition & 1 deletion src/rank_llm/scripts/run_rank_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
parent = os.path.dirname(parent)
sys.path.append(parent)

from rank_llm import retrieve_and_rerank
from rank_llm.rerank import PromptMode
from rank_llm.retrieve import TOPICS, RetrievalMethod, RetrievalMode
from rank_llm.retrieve_and_rerank import retrieve_and_rerank


def main(args):
Expand Down
2 changes: 1 addition & 1 deletion test/retrieve/test_ServiceRetriever.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import unittest

from rank_llm import retrieve_and_rerank
from rank_llm.data import Candidate, Query, Request
from rank_llm.retrieve import RetrievalMethod, RetrievalMode, ServiceRetriever
from rank_llm.retrieve_and_rerank import retrieve_and_rerank


class TestServiceRetriever(unittest.TestCase):
Expand Down
2 changes: 1 addition & 1 deletion test/test_retrieve_and_rerank.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import unittest
from unittest.mock import MagicMock, patch

from rank_llm import retrieve_and_rerank
from rank_llm.retrieve_and_rerank import retrieve_and_rerank


# Anserini API must be hosted at 8081
Expand Down

0 comments on commit 89bab7c

Please sign in to comment.