-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_models_response.py
112 lines (86 loc) · 3.18 KB
/
test_models_response.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import json
import logging
import os
import boto3
import click
# Configure the root handler once: "LEVEL message" lines at INFO and above.
logging.basicConfig(level=logging.INFO, format="%(levelname)s %(message)s")

# The root logger is shared by every function in this script.
LOGGER = logging.getLogger()
LOGGER.setLevel(logging.INFO)

# Bedrock model ID used by ask_model(); None when the variable is not set.
MODEL_ID = os.environ.get("MODEL_ID")
def converse_with_model(model_id, text, type):
    """Send a single message to a Bedrock model through the Converse API.

    Args:
        model_id: Bedrock model ID, e.g. ``"amazon.titan-text-express-v1"``.
        text: Message text sent to the model.
        type: Role of the message (``"user"`` or ``"assistant"``). The name
            shadows the builtin but is kept because callers pass it by keyword.

    Returns:
        The model's reply text (also logged at INFO) on success, or a string
        starting with ``"Error: "`` if the call fails for any reason.
    """
    role = type  # local alias so the builtin name is not used below
    try:
        client = boto3.client("bedrock-runtime")
        response = client.converse(
            modelId=model_id,
            messages=[{"role": role, "content": [{"text": text}]}],
        )
        # The reply may be split across several content blocks; keep the text
        # from each of them rather than only the first.
        content = response["output"]["message"]["content"]
        reply = "".join(block.get("text", "") for block in content)
        LOGGER.info(reply)
        return reply
    except Exception as e:
        # Best-effort CLI helper: report the failure instead of raising.
        LOGGER.error("Error: %s", e)
        return "Error: " + str(e)
def ask_model(text, model_id=None):
    """Generate text with an Amazon Titan text model via Bedrock ``InvokeModel``.

    Args:
        text: Prompt sent to the model as ``inputText``.
        model_id: Bedrock model ID to invoke. When omitted (or empty), falls
            back to the module-level ``MODEL_ID`` read from the environment.

    Returns:
        The generated text (the ``outputText`` of every result, joined with
        newlines) on success, or a string starting with
        ``"Error summarizing text: "`` if anything fails.
    """
    model_id = model_id or MODEL_ID
    try:
        if not model_id:
            # Fail with a clear message instead of an opaque botocore
            # parameter-validation error.
            raise ValueError("no model ID given and MODEL_ID is not set")

        bedrock = boto3.client("bedrock-runtime")
        body = json.dumps(
            {
                "inputText": text,
                "textGenerationConfig": {
                    "maxTokenCount": 4096,
                    "stopSequences": [],
                    "temperature": 0,
                    "topP": 1,
                },
            }
        )
        response = bedrock.invoke_model(
            body=body,
            modelId=model_id,
            accept="application/json",
            contentType="application/json",
        )
        response_body = json.loads(response["body"].read())

        error = response_body.get("error")
        if error is not None:
            raise RuntimeError(f"Text generation error. Error is {error}")

        LOGGER.info("Successfully generated text with Amazon Titan model %s", model_id)

        outputs = []
        for result in response_body["results"]:
            LOGGER.info("Token count: %s", result["tokenCount"])
            LOGGER.info("Output text: %s", result["outputText"])
            LOGGER.info("Completion reason: %s", result["completionReason"])
            outputs.append(result["outputText"])
        return "\n".join(outputs)
    except Exception as e:
        # Best-effort CLI helper: report the failure instead of raising.
        LOGGER.error("Error summarizing text: %s", e)
        return "Error summarizing text: " + str(e)
@click.command()
@click.option("--prompt", default="", help="Prompt for the model")
def handler(prompt):
    """Send ``--prompt`` to the model via ``ask_model``.

    ``ask_model`` writes the generated text, or an error description, to
    the log. Nothing is printed to stdout.

    Args:
        prompt: Text passed to the model as its input.

    Returns:
        dict: ``{"statusCode": 200, "message": "Success"}``. It is returned
        unconditionally, even when ``ask_model`` reports an error. When the
        command runs through click's standalone mode, the return value is
        discarded.
    """
    ask_model(prompt)
    return {"statusCode": 200, "message": "Success"}
@click.command()
@click.option("--prompt", default="What is your name?", help="Prompt for the model")
@click.option(
    "--type",
    default="user",
    type=click.Choice(["user", "assistant"]),
    help="user or assistant",
)
@click.option(
    "--model", default="amazon.titan-text-express-v1", help="AWS Bedrock Model ID"
)
def speak(model, prompt, type):
    """Send one message to a Bedrock model with the Converse API.

    ``--type`` is validated by click, so an unknown role is rejected at
    argument parsing instead of failing inside the service call. The reply
    is logged by ``converse_with_model``.

    NOTE(review): the Converse API appears to require a conversation to start
    with a user message, so ``--type assistant`` is likely rejected by the
    service. Confirm before relying on it.
    """
    converse_with_model(model_id=model, text=prompt, type=type)
if __name__ == "__main__":
    # Run the Converse-API command. The InvokeModel path is exposed as the
    # separate `handler` command; call it here instead to use that path.
    speak()