Merge pull request #367 from waleedqk/include_stop_sequence_param
Add param include_stop_sequence
gkumbhat authored Jul 16, 2024
2 parents 7d54770 + 5443644 commit 8f71845
Showing 5 changed files with 25 additions and 1 deletion.
4 changes: 4 additions & 0 deletions caikit_nlp/modules/text_generation/peft_tgis_remote.py
@@ -216,6 +216,7 @@ def run(
generated_tokens: bool = True,
token_logprobs: bool = True,
token_ranks: bool = True,
include_stop_sequence: Optional[bool] = True,
context: Optional[RuntimeServerContextType] = None,
) -> GeneratedTextResult:
f"""Run inference against the model running in TGIS.
@@ -242,6 +243,7 @@ def run(
generated_tokens=generated_tokens,
token_logprobs=token_logprobs,
token_ranks=token_ranks,
include_stop_sequence=include_stop_sequence,
max_new_tokens=max_new_tokens,
min_new_tokens=min_new_tokens,
truncate_input_tokens=truncate_input_tokens,
@@ -281,6 +283,7 @@ def run_stream_out(
generated_tokens: bool = True,
token_logprobs: bool = True,
token_ranks: bool = True,
include_stop_sequence: Optional[bool] = True,
context: Optional[RuntimeServerContextType] = None,
) -> Iterable[GeneratedTextStreamResult]:
f"""Run output stream inferencing against the model running in TGIS
@@ -308,6 +311,7 @@ def run_stream_out(
generated_tokens=generated_tokens,
token_logprobs=token_logprobs,
token_ranks=token_ranks,
include_stop_sequence=include_stop_sequence,
max_new_tokens=max_new_tokens,
min_new_tokens=min_new_tokens,
truncate_input_tokens=truncate_input_tokens,
4 changes: 4 additions & 0 deletions caikit_nlp/modules/text_generation/text_generation_tgis.py
@@ -238,6 +238,7 @@ def run(
generated_tokens: bool = True,
token_logprobs: bool = True,
token_ranks: bool = True,
include_stop_sequence: Optional[bool] = True,
context: Optional[RuntimeServerContextType] = None,
) -> GeneratedTextResult:
f"""Run inference against the model running in TGIS.
@@ -258,6 +259,7 @@ def run(
generated_tokens=generated_tokens,
token_logprobs=token_logprobs,
token_ranks=token_ranks,
include_stop_sequence=include_stop_sequence,
max_new_tokens=max_new_tokens,
min_new_tokens=min_new_tokens,
truncate_input_tokens=truncate_input_tokens,
@@ -297,6 +299,7 @@ def run_stream_out(
generated_tokens: bool = True,
token_logprobs: bool = True,
token_ranks: bool = True,
include_stop_sequence: Optional[bool] = True,
context: Optional[RuntimeServerContextType] = None,
) -> Iterable[GeneratedTextStreamResult]:
f"""Run output stream inferencing for text generation module.
@@ -316,6 +319,7 @@ def run_stream_out(
generated_tokens=generated_tokens,
token_logprobs=token_logprobs,
token_ranks=token_ranks,
include_stop_sequence=include_stop_sequence,
max_new_tokens=max_new_tokens,
min_new_tokens=min_new_tokens,
truncate_input_tokens=truncate_input_tokens,
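For orientation, below is a minimal sketch of how a caller could exercise the new flag through the TGIS-backed text generation module once this change lands. The class name, model path, prompt, and result field are illustrative assumptions rather than details taken from this diff; only the `include_stop_sequence` keyword itself is introduced by the commit.

```python
# Hypothetical usage sketch -- class name, model path, and result field are
# assumptions for illustration; include_stop_sequence is the new parameter
# added in this commit and defaults to True.
from caikit_nlp.modules.text_generation import TextGenerationTGIS

model = TextGenerationTGIS.load("path/to/tgis-backed-model")  # assumed location

result = model.run(
    "Write a one-line greeting.",
    max_new_tokens=32,
    stop_sequences=["\n\n"],
    # False asks the server to omit the matched stop sequence from the
    # returned text; leaving it at True preserves prior behavior.
    include_stop_sequence=False,
)
print(result.generated_text)
```
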
15 changes: 15 additions & 0 deletions caikit_nlp/toolkit/text_generation/tgis_utils.py
@@ -59,6 +59,9 @@
token_ranks: bool
Whether or not to include rank of each returned token.
Applicable only if generated_tokens == true and/or input_tokens == true
include_stop_sequence: bool
Whether or not to include stop sequence.
If not specified, default behavior depends on server setting.
""".format(
GENERATE_FUNCTION_ARGS
)
@@ -109,6 +112,7 @@ def validate_inf_params(
generated_tokens,
token_logprobs,
token_ranks,
include_stop_sequence,
eos_token,
max_new_tokens,
min_new_tokens,
@@ -139,6 +143,9 @@ def validate_inf_params(
error.type_check("<NLP65883539E>", bool, generated_tokens=generated_tokens)
error.type_check("<NLP65883540E>", bool, token_logprobs=token_logprobs)
error.type_check("<NLP65883541E>", bool, token_ranks=token_ranks)
error.type_check(
"<NLP65883542E>", bool, include_stop_sequence=include_stop_sequence
)
error.type_check("<NLP85452188E>", str, allow_none=True, eos_token=eos_token)
error.type_check(
"<NLP03860681E>",
@@ -243,6 +250,7 @@ def get_params(
generated_tokens,
token_logprobs,
token_ranks,
include_stop_sequence,
max_new_tokens,
min_new_tokens,
truncate_input_tokens,
@@ -290,6 +298,7 @@ def get_params(
max_new_tokens=max_new_tokens,
min_new_tokens=min_new_tokens,
time_limit_millis=int(max_time * 1000) if max_time else None,
include_stop_sequence=include_stop_sequence,
)

if exponential_decay_length_penalty:
@@ -358,6 +367,7 @@ def unary_generate(
generated_tokens,
token_logprobs,
token_ranks,
include_stop_sequence,
max_new_tokens,
min_new_tokens,
truncate_input_tokens,
@@ -399,6 +409,7 @@ def unary_generate(
generated_tokens=generated_tokens,
token_logprobs=token_logprobs,
token_ranks=token_ranks,
include_stop_sequence=include_stop_sequence,
eos_token=self.eos_token,
max_new_tokens=max_new_tokens,
min_new_tokens=min_new_tokens,
@@ -423,6 +434,7 @@ def unary_generate(
generated_tokens=generated_tokens,
token_logprobs=token_logprobs,
token_ranks=token_ranks,
include_stop_sequence=include_stop_sequence,
max_new_tokens=max_new_tokens,
min_new_tokens=min_new_tokens,
truncate_input_tokens=truncate_input_tokens,
@@ -506,6 +518,7 @@ def stream_generate(
generated_tokens,
token_logprobs,
token_ranks,
include_stop_sequence,
max_new_tokens,
min_new_tokens,
truncate_input_tokens,
@@ -547,6 +560,7 @@ def stream_generate(
generated_tokens=generated_tokens,
token_logprobs=token_logprobs,
token_ranks=token_ranks,
include_stop_sequence=include_stop_sequence,
eos_token=self.eos_token,
max_new_tokens=max_new_tokens,
min_new_tokens=min_new_tokens,
@@ -569,6 +583,7 @@ def stream_generate(
generated_tokens=generated_tokens,
token_logprobs=token_logprobs,
token_ranks=token_ranks,
include_stop_sequence=include_stop_sequence,
max_new_tokens=max_new_tokens,
min_new_tokens=min_new_tokens,
truncate_input_tokens=truncate_input_tokens,
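To see where the flag ends up, here is a small sketch of the stopping-criteria message that `get_params` builds and hands to TGIS. The `generation_pb2.StoppingCriteria` message and its field names are assumed from the caikit-tgis-backend generation protos; the diff above only shows the keyword being forwarded alongside the other stopping parameters.

```python
# Sketch of the stopping-criteria construction performed inside get_params.
# Message and field names are assumptions based on the caikit-tgis-backend
# protos; the new include_stop_sequence field is presumably why the next file
# in this diff bumps that dependency to >= 0.1.36.
from caikit_tgis_backend.protobufs import generation_pb2

stopping = generation_pb2.StoppingCriteria(
    stop_sequences=["\n\n"],
    max_new_tokens=32,
    min_new_tokens=1,
    time_limit_millis=5000,
    include_stop_sequence=True,  # new field surfaced by this commit
)
```
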
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -15,7 +15,7 @@ classifiers=[
]
dependencies = [
"caikit[runtime-grpc,runtime-http]>=0.26.34,<0.27.0",
"caikit-tgis-backend>=0.1.34,<0.2.0",
"caikit-tgis-backend>=0.1.36,<0.2.0",
# TODO: loosen dependencies
"grpcio>=1.62.2", # explicitly pin grpc dependencies to a recent version to avoid pip backtracking
"grpcio-reflection>=1.62.2",
1 change: 1 addition & 0 deletions tests/toolkit/text_generation/test_tgis_utils.py
@@ -118,6 +118,7 @@ def test_TGISGenerationClient_rpc_errors(status_code, method):
max_time=None,
exponential_decay_length_penalty=None,
stop_sequences=["asdf"],
include_stop_sequence=True,
)
if method.endswith("_generate")
else dict()
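For completeness, a streaming counterpart to the earlier usage sketch: as the first two changed files show, `run_stream_out` accepts the same new keyword. The `model` object and the stream-chunk field name are illustrative assumptions carried over from that sketch.

```python
# Hypothetical streaming call -- mirrors the unary sketch above; the model
# object and chunk field name are assumptions for illustration.
for chunk in model.run_stream_out(
    "Write a one-line greeting.",
    max_new_tokens=32,
    stop_sequences=["\n\n"],
    include_stop_sequence=False,  # same default (True) as the unary path
):
    print(chunk.generated_text, end="", flush=True)
```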
