From 147b2e0cd5437db676aacc3bd99ac4582648875d Mon Sep 17 00:00:00 2001 From: Dustin Tran Date: Fri, 27 Oct 2023 22:02:30 -0700 Subject: [PATCH] Retry only upon RPCException and streamline retry logic. PiperOrigin-RevId: 577383233 --- edward2/maps.py | 106 +++++++++++++++++++++---------------------- edward2/maps_test.py | 18 ++++++-- 2 files changed, 67 insertions(+), 57 deletions(-) diff --git a/edward2/maps.py b/edward2/maps.py index efcb5f53..06559025 100644 --- a/edward2/maps.py +++ b/edward2/maps.py @@ -16,9 +16,10 @@ """A better map.""" import concurrent.futures -from typing import Callable, Literal, Optional, Sequence, TypeVar, overload +from typing import Callable, Literal, Sequence, TypeVar, overload from absl import logging +import grpc import tenacity T = TypeVar('T') @@ -31,10 +32,10 @@ def robust_map( fn: Callable[[T], U], inputs: Sequence[T], error_output: V = ..., - index_to_output: Optional[dict[int, U | V]] = ..., - log_percent: Optional[int | float] = ..., - max_retries: Optional[int] = ..., - max_workers: Optional[int] = ..., + index_to_output: dict[int, U | V] | None = ..., + log_percent: float = ..., + max_retries: int | None = ..., + max_workers: int | None = ..., raise_error: Literal[False] = ..., ) -> Sequence[U | V]: ... @@ -45,10 +46,10 @@ def robust_map( fn: Callable[[T], U], inputs: Sequence[T], error_output: V = ..., - index_to_output: Optional[dict[int, U | V]] = ..., - log_percent: Optional[int | float] = ..., - max_retries: Optional[int] = ..., - max_workers: Optional[int] = ..., + index_to_output: dict[int, U | V] | None = ..., + log_percent: float = ..., + max_retries: int | None = ..., + max_workers: int | None = ..., raise_error: Literal[True] = ..., ) -> Sequence[U]: ... @@ -61,15 +62,15 @@ def robust_map( fn: Callable[[T], U], inputs: Sequence[T], error_output: V = None, - index_to_output: Optional[dict[int, U | V]] = None, - log_percent: Optional[int | float] = 5, - max_retries: Optional[int] = None, - max_workers: Optional[int] = None, + index_to_output: dict[int, U | V] | None = None, + log_percent: float = 5, + max_retries: int | None = None, + max_workers: int | None = None, raise_error: bool = False, ) -> Sequence[U | V]: """Maps a function to inputs using a threadpool. - The map supports exception handling, retries with exponential backoff, and + The map supports RPC exception handling, retries with exponential backoff, and in-place updates in order to store intermediate progress. Args: @@ -95,53 +96,50 @@ def robust_map( if index_to_output is None: index_to_output = {} if max_retries is None: - # Apply exponential backoff with 3 retries. Retry infinitely in outer loop. - fn_retries = 3 + fn_with_backoff = tenacity.retry( + retry=tenacity.retry_if_exception_type(grpc.RpcError), + wait=tenacity.wait_random_exponential(min=1, max=30), + )(fn) else: - fn_retries = max_retries - fn_with_backoff = tenacity.retry( - wait=tenacity.wait_random_exponential(min=1, max=60), - stop=tenacity.stop_after_attempt(fn_retries), - )(fn) + fn_with_backoff = tenacity.retry( + retry=tenacity.retry_if_exception_type(grpc.RpcError), + wait=tenacity.wait_random_exponential(min=1, max=30), + stop=tenacity.stop_after_attempt(max_retries), + )(fn) num_inputs = len(inputs) log_steps = max(1, num_inputs * log_percent // 100) indices = [i for i in range(num_inputs) if i not in index_to_output.keys()] with concurrent.futures.ThreadPoolExecutor( max_workers=max_workers ) as executor: - while indices: - future_to_index = { - executor.submit(fn_with_backoff, inputs[i]): i for i in indices - } - indices = [] # Clear the list since the tasks have been submitted. - for future in concurrent.futures.as_completed(future_to_index): - index = future_to_index[future] - try: - output = future.result() - index_to_output[index] = output - except tenacity.RetryError as e: - if max_retries is not None and raise_error: - logging.exception('Item %s exceeded max retries.', index) - raise e - elif max_retries is not None: - logging.warning( - 'Item %s exceeded max retries. Output is set to %s. ' - 'Exception: %s.', - index, - error_output, - e, - ) - index_to_output[index] = error_output - else: - logging.info('Retrying item %s after exception: %s.', index, e) - indices.append(index) - processed_len = len(index_to_output) - if processed_len % log_steps == 0 or processed_len == num_inputs: - logging.info( - 'Completed %s/%s inputs, with %s left to retry.', - processed_len, - num_inputs, - len(indices), + future_to_index = { + executor.submit(fn_with_backoff, inputs[i]): i for i in indices + } + for future in concurrent.futures.as_completed(future_to_index): + index = future_to_index[future] + try: + output = future.result() + index_to_output[index] = output + except tenacity.RetryError as e: + if raise_error: + logging.exception('Item %s exceeded max retries.', index) + raise e + else: + logging.warning( + 'Item %s exceeded max retries. Output is set to %s. ' + 'Exception: %s.', + index, + error_output, + e, ) + index_to_output[index] = error_output + processed_len = len(index_to_output) + if processed_len % log_steps == 0 or processed_len == num_inputs: + logging.info( + 'Completed %s/%s inputs, with %s left to retry.', + processed_len, + num_inputs, + len(indices), + ) outputs = [index_to_output[i] for i in range(num_inputs)] return outputs diff --git a/edward2/maps_test.py b/edward2/maps_test.py index b8f9c903..849099e5 100644 --- a/edward2/maps_test.py +++ b/edward2/maps_test.py @@ -17,6 +17,7 @@ import edward2 as ed from edward2 import maps +import grpc import numpy as np import tenacity import tensorflow as tf @@ -41,7 +42,7 @@ def test_robust_map_library_import(self): def test_robust_map_error_output(self): def fn(x): if x == 1: - raise ValueError('Input value 1 is not supported.') + raise grpc.RpcError('Input value 1 takes too long to process.') else: return x + 1 @@ -69,7 +70,7 @@ def test_robust_map_index_to_output(self): def test_robust_map_max_retries(self): def fn(x): if x == 1: - raise ValueError('Input value 1 is not supported.') + raise grpc.RpcError('Input value 1 takes too long to process.') else: return x + 1 @@ -84,7 +85,7 @@ def fn(x): def test_robust_map_raise_error(self): def fn(x): if x == 1: - raise ValueError('Input value 1 is not supported.') + raise grpc.RpcError('Input value 1 is not supported.') else: return x + 1 @@ -97,6 +98,17 @@ def fn(x): raise_error=True, ) + def test_robust_map_non_rpc_error(self): + def fn(x): + if x == 1: + raise ValueError('Input value 1 is not supported.') + else: + return x + 1 + + x = [0, 1, 2] + with self.assertRaises(ValueError): + maps.robust_map(fn, x) + if __name__ == '__main__': tf.test.main()