Skip to content

Commit

Permalink
Added final some stuff
Browse files Browse the repository at this point in the history
  • Loading branch information
XKTZ committed Sep 15, 2024
1 parent 89bab7c commit 65d3508
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 35 deletions.
4 changes: 4 additions & 0 deletions src/rank_llm/rerank/listwise/listwise_rankllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,10 @@ def rerank_batch(
result.extend(batch_result)
bar.update(len(batch))

logger.info(
f"Average consumption per request: {consumption.consumption_reference_by_item / len(requests) : .2f}"
)

return result

def get_output_filename(
Expand Down
6 changes: 4 additions & 2 deletions src/rank_llm/rerank/listwise/rank_listwise_os_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@
class RankListwiseOSLLM(ListwiseRankLLM):
def __init__(
self,
reorder_policy: ReorderPolicy,
model: str,
reorder_policy: ReorderPolicy = None,
name: str = "",
context_size: int = 4096,
window_size: int = 20,
Expand Down Expand Up @@ -270,7 +270,9 @@ def chunks(lst, n):
all_completed_prompts = []

with ThreadPoolExecutor() as executor:
for batch in tqdm(chunks(results, batch_size), desc="Processing batches"):
for batch in tqdm(
chunks(results, batch_size), desc="Processing batches", leave=False
):
completed_prompts = list(
executor.map(
lambda req: self.create_prompt(req[0], req[1]),
Expand Down
66 changes: 45 additions & 21 deletions src/rank_llm/rerank/listwise/reorder/reorder_policy.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import copy
import random
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Callable, Dict, List, Tuple, TypeVar, Union

import numpy as np

from rank_llm.data import Result

T = TypeVar("T")
Expand Down Expand Up @@ -42,11 +45,39 @@ def name() -> str:
pass

@staticmethod
def _shuffle_and_rescore(
    results: List[Result], select_indexes: List[int]
) -> List[Result]:
    """Placeholder hook for shuffling and rescoring candidates.

    Currently a no-op: the very same ``results`` list object is handed back
    and ``select_indexes`` is not consulted.
    """
    # TODO: do nothing for now
    untouched = results
    return untouched
def _shuffle_indices(indices: List[int]) -> List[int]:
    """Return a randomly ordered copy of ``indices``.

    The caller's list is not modified; the shuffle is applied to a fresh
    copy using the module-level ``random`` generator.
    """
    shuffled = [*indices]
    random.shuffle(shuffled)
    return shuffled

@staticmethod
def _shuffled(
    func: Callable[[List[Tuple[Result, List[int]]]], List[List[int]]]
) -> Callable[[List[Tuple[Result, List[int]]]], List[List[int]]]:
    """Wrap a batch function so it sees every index list in random order.

    For each ``(result, indices)`` pair in the batch the wrapper draws a
    random permutation, hands ``func`` the permuted index list, and then
    re-indexes ``func``'s output through the inverse permutation so that
    entry ``x`` of each returned list lines up with entry ``x`` of the
    caller's original ``indices``.

    NOTE(review): the un-shuffling step assumes each list returned by
    ``func`` is aligned position-by-position with the index list it was
    given -- confirm against the callers.

    Args:
        func: Batch function taking ``(result, indices)`` pairs and
            returning one list per pair.

    Returns:
        A function with the same call signature as ``func``.
    """

    def fun(batch: List[Tuple[Result, List[int]]]) -> List[List[int]]:
        inverse_perms = []
        batch_feed = []
        for res, ind in batch:
            perm = np.random.permutation(len(ind)).tolist()

            # inverse[src] is the position at which ind[src] is shown to func.
            inverse = [0] * len(perm)
            for pos, src in enumerate(perm):
                inverse[src] = pos

            batch_feed.append((res, [ind[src] for src in perm]))
            inverse_perms.append(inverse)

        # The shuffled batch must be what func sees; passing the original
        # batch here would leave func unshuffled while the inverse
        # permutation below still scrambled its output.
        result_raw = func(batch_feed)

        return [
            [result[inverse[x]] for x in range(len(result))]
            for result, inverse in zip(result_raw, inverse_perms)
        ]

    return fun

@staticmethod
def _reorder_by_rank(items: List[T], idxes: List[int], rank: List[int]) -> List[T]:
Expand All @@ -69,13 +100,10 @@ class SlidingWindowReorderPolicy(ReorderPolicy):
def __init__(
    self,
    step: int = 10,
    shuffle_candidates: bool = False,
    **kwargs,
):
    """Record the sliding-window configuration on the instance.

    ``step`` is stored as ``_step_size`` and ``shuffle_candidates`` is
    coerced to a strict ``bool`` and stored as ``_shuffle_candidates``.
    Any further keyword arguments are accepted and ignored here.
    """
    self._shuffle_candidates = True if shuffle_candidates else False
    self._step_size = step

def reorder(
self,
requests: List[Result],
Expand All @@ -89,20 +117,16 @@ def reorder(
) -> List[Result]:
window_size = model.window_size

rerank_results = [
Result(
query=copy.deepcopy(request.query),
candidates=copy.deepcopy(request.candidates),
ranking_exec_summary=[],
)
for request in requests
]

if self._shuffle_candidates:
self._shuffle_and_rescore(rerank_results, [*range(rank_start, rank_end)])

# order of requests
request_ranks = [[*range(len(request.candidates))] for request in requests]
if shuffle_candidates:
request_ranks = [
self._shuffle_indices(list(range(len(request.candidates))))
for request in requests
]
else:
request_ranks = [
list(range(len(request.candidates))) for request in requests
]

end_pos = rank_end
start_pos = rank_end - window_size
Expand Down
17 changes: 10 additions & 7 deletions src/rank_llm/rerank/listwise/reorder/top_down_reorder_policy.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import copy
import logging
import random
from dataclasses import dataclass
from typing import Callable, List, Tuple

Expand Down Expand Up @@ -62,11 +61,6 @@ def _remove_from_occ(self, lst: List[int], inds: List[int]):
st = set(inds)
return [x for x in lst if x not in st]

def _shuffle(self, lst: List[int]) -> List[int]:
    """Return a randomly reordered copy of ``lst``, leaving ``lst`` untouched."""
    shuffled = list(lst)
    random.shuffle(shuffled)
    return shuffled

def perform(self):
top_k = self._top_k
window_size = self._window_size
Expand Down Expand Up @@ -186,6 +180,7 @@ def reorder(
rank_start: int,
rank_end: int,
model: ModelFunction,
shuffle_candidates: bool = False,
**kwargs,
) -> list[Result]:
window_size = model.window_size
Expand All @@ -197,9 +192,17 @@ def reorder(
model.create_prompt(reqs), [ind for req, ind in reqs]
)

if shuffle_candidates:
indices = [
self._shuffle_indices(list(range(len(request.candidates))))
for request in requests
]
else:
indices = [list(range(rank_start, rank_end)) for _ in range(len(requests))]

request_ranks = multiple_sort(
requests,
[list(range(rank_start, rank_end)) for _ in range(len(requests))],
indices,
runner=runner,
top_k=self._top_k,
pivot=pivot,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,8 +179,6 @@ def __init__(self, indices: List[int], window_size: int, r: int):
indices, window_size=window_size, top_k=r
)

self.count_inference = 0

def _pop(self, x: int) -> List[TournamentSortNode]:
on: TournamentSortNode = self._idx_to_node[x]
lst = []
Expand All @@ -201,7 +199,6 @@ def perform(self, top_k: int):
padded = self._pad_size(resort_param)
request = ResortRequest(padded, [])
yield request
self.count_inference += 1
cleaned_result = self._unpad_perm(resort_param, padded, request.result)
nd.resort(cleaned_result)

Expand All @@ -219,7 +216,6 @@ def perform(self, top_k: int):
padded = self._pad_size(resort_param)
request = ResortRequest(padded, [])
yield request
self.count_inference += 1
assert len(request.result) > 0
cleaned_result = self._unpad_perm(
resort_param, padded, request.result
Expand Down Expand Up @@ -281,6 +277,7 @@ def reorder(
rank_start: int,
rank_end: int,
model: ModelFunction,
shuffle_candidates: bool = False,
**kwargs,
) -> list[Result]:
window_size = model.window_size
Expand All @@ -291,9 +288,17 @@ def reorder(
model.create_prompt(reqs), [ind for req, ind in reqs]
)

if shuffle_candidates:
indices = [
self._shuffle_indices(list(range(len(request.candidates))))
for request in requests
]
else:
indices = [list(range(rank_start, rank_end)) for _ in range(len(requests))]

request_ranks = multiple_sort(
requests,
[list(range(rank_start, rank_end)) for _ in range(len(requests))],
indices,
runner=runner,
window_size=window_size,
top_k=self._top_k,
Expand Down

0 comments on commit 65d3508

Please sign in to comment.