Skip to content

Commit

Permalink
Add elapsed time and ETA to robust_map.
Browse files: browse the repository at this point in the history.
PiperOrigin-RevId: 578223292
  • Loading branch information
dustinvtran authored and edward-bot committed Oct 31, 2023
1 parent edf5e4c commit f4a32d9
Showing 1 changed file with 20 additions and 6 deletions.
26 changes: 20 additions & 6 deletions edward2/maps.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -16,6 +16,8 @@
"""A better map."""

import concurrent.futures
import datetime
import time
from typing import Callable, Literal, Sequence, TypeVar, overload

from absl import logging
Expand Down Expand Up @@ -104,15 +106,18 @@ def robust_map(
wait=tenacity.wait_random_exponential(min=1, max=30),
stop=tenacity.stop_after_attempt(max_retries),
)(fn)
num_existing = len(index_to_output)
num_inputs = len(inputs)
log_steps = max(1, num_inputs * log_percent // 100)
logging.info('Found %s/%s existing examples.', num_existing, num_inputs)
indices = [i for i in range(num_inputs) if i not in index_to_output.keys()]
log_steps = max(1, num_inputs * log_percent // 100)
with concurrent.futures.ThreadPoolExecutor(
max_workers=max_workers
) as executor:
future_to_index = {
executor.submit(fn_with_backoff, inputs[i]): i for i in indices
}
start = time.time()
for future in concurrent.futures.as_completed(future_to_index):
index = future_to_index[future]
try:
Expand All @@ -131,13 +136,22 @@ def robust_map(
e,
)
index_to_output[index] = error_output
processed_len = len(index_to_output)
if processed_len % log_steps == 0 or processed_len == num_inputs:
num_so_far = len(index_to_output)
if num_so_far % log_steps == 0 or num_so_far == num_inputs:
end = time.time()
elapsed = datetime.timedelta(seconds=end - start)
num_completed = num_so_far - num_existing
avg_per_example = elapsed / num_completed
num_remaining = num_inputs - num_so_far
eta = avg_per_example * num_remaining
logging.info(
'Completed %s/%s inputs, with %s left to retry.',
processed_len,
'Completed %s/%s inputs. Elapsed time (started with %s inputs): %s.'
' ETA: %s.',
num_so_far,
num_inputs,
len(indices),
num_existing,
elapsed,
eta,
)
outputs = [index_to_output[i] for i in range(num_inputs)]
return outputs

0 comments on commit f4a32d9

Please sign in to comment.