From 5f00747a9042b68c1591ca11efd0942acbecfac9 Mon Sep 17 00:00:00 2001
From: ochafik
Date: Sun, 29 Sep 2024 01:10:18 +0100
Subject: [PATCH] `server`: test request cancellation (WIP)

---
 examples/server/tests/features/cancel.feature | 43 +++++++++++++++++++
 examples/server/tests/features/steps/steps.py | 28 +++++++++++-
 2 files changed, 69 insertions(+), 2 deletions(-)
 create mode 100644 examples/server/tests/features/cancel.feature

diff --git a/examples/server/tests/features/cancel.feature b/examples/server/tests/features/cancel.feature
new file mode 100644
index 0000000000000..54ded24c67c19
--- /dev/null
+++ b/examples/server/tests/features/cancel.feature
@@ -0,0 +1,43 @@
+@llama.cpp
+@server
+Feature: Cancellation of llama.cpp server requests
+
+  Background: Server startup
+    Given a server listening on localhost:8080
+    And 500 milliseconds delay in sampler for testing
+    And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models
+    And a model file test-model.gguf
+    And a model alias tinyllama-2
+    And BOS token is 1
+    And 42 as server seed
+    # The KV cache size corresponds to the total number of tokens
+    # that can be stored across all independent sequences: #4130
+    # see --ctx-size and #5568
+    And 256 KV cache size
+    And 32 as batch size
+    And 1 slots
+    And 64 server max tokens to predict
+    Then the server is starting
+    Then the server is healthy
+
+  # Scenario: Health
+  #   Then the server is ready
+  #   And all slots are idle
+
+  @wip
+  Scenario Outline: Cancelling a completion request frees up the slot
+    Given a prompt:
+      """
+      Once upon
+      """
+    And 256 max tokens to predict
+    And 256 server max tokens to predict
+    And streaming is <enable_streaming>
+    And a completion request cancelled after 100 milliseconds
+    # And wait for 50 milliseconds
+    Then all slots are idle
+
+    Examples: Streaming
+      | enable_streaming |
+      | disabled         |
+      | enabled          |
diff --git a/examples/server/tests/features/steps/steps.py b/examples/server/tests/features/steps/steps.py
index 31bfb0b2b152a..5bc4b06316351 100644
--- a/examples/server/tests/features/steps/steps.py
+++ b/examples/server/tests/features/steps/steps.py
@@ -291,6 +291,25 @@ async def step_request_completion(context, api_error: Literal['raised'] | str):
         api_error_code = int(api_error)
         assert completion == api_error_code, f"completion must be an {api_error_code} status code: {completion}"
 
+@step('wait for {millis:d} milliseconds')
+@async_run_until_complete
+async def step_wait_milliseconds(context, millis: int):
+    await asyncio.sleep(millis / 1000.0)
+
+@step('a completion request cancelled after {disconnect_after_millis:d} milliseconds')
+@async_run_until_complete
+async def step_cancel_completion_request(context, disconnect_after_millis: int):
+    seeds = await completions_seed(context, num_seeds=1)
+    await request_completion(context.prompts.pop(),
+                             seeds[0] if seeds is not None else None,
+                             context.base_url,
+                             debug=context.debug,
+                             n_predict=context.n_predict,
+                             cache_prompt=context.cache_prompt,
+                             id_slot=context.id_slot,
+                             disconnect_after_millis=disconnect_after_millis,
+                             user_api_key=context.user_api_key,
+                             temperature=context.temperature)
 
 @step('{predicted_n:d} tokens are predicted matching {re_content}')
 def step_n_tokens_predicted_with_content(context, predicted_n, re_content):
@@ -982,9 +1001,10 @@ async def request_completion(prompt,
                              id_slot=None,
                              expect_api_error=None,
                              user_api_key=None,
+                             disconnect_after_millis=None,
                              temperature=None) -> int | dict[str, Any]:
     if debug:
-        print(f"Sending completion request: {prompt}")
+        print(f"Sending completion request: {prompt} with n_predict={n_predict}")
     origin = "my.super.domain"
     headers = {
         'Origin': origin
@@ -1008,6 +1028,10 @@ async def request_completion(prompt,
                 "n_probs": 2,
             },
             headers=headers) as response:
+            # Simulate a client-side cancellation by dropping the connection.
+            if disconnect_after_millis is not None:
+                await asyncio.sleep(disconnect_after_millis / 1000)
+                return 0
             if expect_api_error is None or not expect_api_error:
                 assert response.status == 200
                 assert response.headers['Access-Control-Allow-Origin'] == origin
@@ -1352,7 +1376,7 @@ async def request_slots_status(context, expected_slots):
 
 
 def assert_slots_status(slots, expected_slots):
-    assert len(slots) == len(expected_slots)
+    assert len(slots) == len(expected_slots), f'invalid number of slots: {len(slots)} (actual) != {len(expected_slots)} (expected)'
     for slot_id, (expected, slot) in enumerate(zip(expected_slots, slots)):
         for key in expected:
             assert expected[key] == slot[key], (f"invalid slot {slot_id}"
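
The `a completion request cancelled after {disconnect_after_millis:d} milliseconds` step relies on aiohttp closing the underlying HTTP connection when `request_completion` returns before the response body has been read; the server should detect the disconnect and free the busy slot. A minimal standalone sketch of the same pattern, assuming a llama.cpp server on localhost:8080 (the /completion endpoint and its "prompt"/"n_predict"/"stream" fields follow the server's API; the function name and timing values here are purely illustrative):

    # Minimal sketch of client-side cancellation, assuming a llama.cpp server
    # on localhost:8080. cancel_completion_after and its arguments are
    # illustrative, not part of the patch above.
    import asyncio

    import aiohttp


    async def cancel_completion_after(base_url: str, millis: int) -> None:
        async with aiohttp.ClientSession() as session:
            async with session.post(f"{base_url}/completion",
                                    json={
                                        "prompt": "Once upon",
                                        "n_predict": 256,
                                        "stream": True,
                                    }) as response:
                assert response.status == 200
                await asyncio.sleep(millis / 1000)
                # Returning here exits both context managers before the
                # streamed body has been consumed; aiohttp closes the
                # connection, which the server should treat as a cancellation
                # and use to free the slot.


    asyncio.run(cancel_completion_after("http://localhost:8080", 100))

Dropping the connection alone proves nothing by itself; it is the follow-up `Then all slots are idle` poll in the feature file that actually verifies the slot was released.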