Skip to content

Commit

Permalink
Retry only upon RPCException and streamline retry logic.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 577383233
  • Loading branch information
dustinvtran authored and edward-bot committed Oct 28, 2023
1 parent 0fbed1f commit 147b2e0
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 57 deletions.
106 changes: 52 additions & 54 deletions edward2/maps.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@
"""A better map."""

import concurrent.futures
from typing import Callable, Literal, Optional, Sequence, TypeVar, overload
from typing import Callable, Literal, Sequence, TypeVar, overload

from absl import logging
import grpc
import tenacity

T = TypeVar('T')
Expand All @@ -31,10 +32,10 @@ def robust_map(
fn: Callable[[T], U],
inputs: Sequence[T],
error_output: V = ...,
index_to_output: Optional[dict[int, U | V]] = ...,
log_percent: Optional[int | float] = ...,
max_retries: Optional[int] = ...,
max_workers: Optional[int] = ...,
index_to_output: dict[int, U | V] | None = ...,
log_percent: float = ...,
max_retries: int | None = ...,
max_workers: int | None = ...,
raise_error: Literal[False] = ...,
) -> Sequence[U | V]:
...
Expand All @@ -45,10 +46,10 @@ def robust_map(
fn: Callable[[T], U],
inputs: Sequence[T],
error_output: V = ...,
index_to_output: Optional[dict[int, U | V]] = ...,
log_percent: Optional[int | float] = ...,
max_retries: Optional[int] = ...,
max_workers: Optional[int] = ...,
index_to_output: dict[int, U | V] | None = ...,
log_percent: float = ...,
max_retries: int | None = ...,
max_workers: int | None = ...,
raise_error: Literal[True] = ...,
) -> Sequence[U]:
...
Expand All @@ -61,15 +62,15 @@ def robust_map(
fn: Callable[[T], U],
inputs: Sequence[T],
error_output: V = None,
index_to_output: Optional[dict[int, U | V]] = None,
log_percent: Optional[int | float] = 5,
max_retries: Optional[int] = None,
max_workers: Optional[int] = None,
index_to_output: dict[int, U | V] | None = None,
log_percent: float = 5,
max_retries: int | None = None,
max_workers: int | None = None,
raise_error: bool = False,
) -> Sequence[U | V]:
"""Maps a function to inputs using a threadpool.
The map supports exception handling, retries with exponential backoff, and
The map supports RPC exception handling, retries with exponential backoff, and
in-place updates in order to store intermediate progress.
Args:
Expand All @@ -95,53 +96,50 @@ def robust_map(
if index_to_output is None:
index_to_output = {}
if max_retries is None:
# Apply exponential backoff with 3 retries. Retry infinitely in outer loop.
fn_retries = 3
fn_with_backoff = tenacity.retry(
retry=tenacity.retry_if_exception_type(grpc.RpcError),
wait=tenacity.wait_random_exponential(min=1, max=30),
)(fn)
else:
fn_retries = max_retries
fn_with_backoff = tenacity.retry(
wait=tenacity.wait_random_exponential(min=1, max=60),
stop=tenacity.stop_after_attempt(fn_retries),
)(fn)
fn_with_backoff = tenacity.retry(
retry=tenacity.retry_if_exception_type(grpc.RpcError),
wait=tenacity.wait_random_exponential(min=1, max=30),
stop=tenacity.stop_after_attempt(max_retries),
)(fn)
num_inputs = len(inputs)
log_steps = max(1, num_inputs * log_percent // 100)
indices = [i for i in range(num_inputs) if i not in index_to_output.keys()]
with concurrent.futures.ThreadPoolExecutor(
max_workers=max_workers
) as executor:
while indices:
future_to_index = {
executor.submit(fn_with_backoff, inputs[i]): i for i in indices
}
indices = [] # Clear the list since the tasks have been submitted.
for future in concurrent.futures.as_completed(future_to_index):
index = future_to_index[future]
try:
output = future.result()
index_to_output[index] = output
except tenacity.RetryError as e:
if max_retries is not None and raise_error:
logging.exception('Item %s exceeded max retries.', index)
raise e
elif max_retries is not None:
logging.warning(
'Item %s exceeded max retries. Output is set to %s. '
'Exception: %s.',
index,
error_output,
e,
)
index_to_output[index] = error_output
else:
logging.info('Retrying item %s after exception: %s.', index, e)
indices.append(index)
processed_len = len(index_to_output)
if processed_len % log_steps == 0 or processed_len == num_inputs:
logging.info(
'Completed %s/%s inputs, with %s left to retry.',
processed_len,
num_inputs,
len(indices),
future_to_index = {
executor.submit(fn_with_backoff, inputs[i]): i for i in indices
}
for future in concurrent.futures.as_completed(future_to_index):
index = future_to_index[future]
try:
output = future.result()
index_to_output[index] = output
except tenacity.RetryError as e:
if raise_error:
logging.exception('Item %s exceeded max retries.', index)
raise e
else:
logging.warning(
'Item %s exceeded max retries. Output is set to %s. '
'Exception: %s.',
index,
error_output,
e,
)
index_to_output[index] = error_output
processed_len = len(index_to_output)
if processed_len % log_steps == 0 or processed_len == num_inputs:
logging.info(
'Completed %s/%s inputs, with %s left to retry.',
processed_len,
num_inputs,
len(indices),
)
outputs = [index_to_output[i] for i in range(num_inputs)]
return outputs
18 changes: 15 additions & 3 deletions edward2/maps_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import edward2 as ed
from edward2 import maps
import grpc
import numpy as np
import tenacity
import tensorflow as tf
Expand All @@ -41,7 +42,7 @@ def test_robust_map_library_import(self):
def test_robust_map_error_output(self):
def fn(x):
if x == 1:
raise ValueError('Input value 1 is not supported.')
raise grpc.RpcError('Input value 1 takes too long to process.')
else:
return x + 1

Expand Down Expand Up @@ -69,7 +70,7 @@ def test_robust_map_index_to_output(self):
def test_robust_map_max_retries(self):
def fn(x):
if x == 1:
raise ValueError('Input value 1 is not supported.')
raise grpc.RpcError('Input value 1 takes too long to process.')
else:
return x + 1

Expand All @@ -84,7 +85,7 @@ def fn(x):
def test_robust_map_raise_error(self):
def fn(x):
if x == 1:
raise ValueError('Input value 1 is not supported.')
raise grpc.RpcError('Input value 1 is not supported.')
else:
return x + 1

Expand All @@ -97,6 +98,17 @@ def fn(x):
raise_error=True,
)

def test_robust_map_non_rpc_error(self):
def fn(x):
if x == 1:
raise ValueError('Input value 1 is not supported.')
else:
return x + 1

x = [0, 1, 2]
with self.assertRaises(ValueError):
maps.robust_map(fn, x)


if __name__ == '__main__':
tf.test.main()

0 comments on commit 147b2e0

Please sign in to comment.