Skip to content

Commit

Permalink
fix use 1 file _PROXY_track_cost_callback (#7304)
Browse files Browse the repository at this point in the history
  • Loading branch information
ishaan-jaff authored Dec 19, 2024
1 parent cf9312a commit 4b2958b
Show file tree
Hide file tree
Showing 3 changed files with 131 additions and 115 deletions.
130 changes: 130 additions & 0 deletions litellm/proxy/hooks/proxy_track_cost_callback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
import asyncio
import traceback
from typing import Optional, Union

import litellm
from litellm._logging import verbose_proxy_logger
from litellm.litellm_core_utils.core_helpers import (
_get_parent_otel_span_from_kwargs,
get_litellm_metadata_from_kwargs,
)
from litellm.proxy.auth.auth_checks import log_db_metrics
from litellm.types.utils import StandardLoggingPayload
from litellm.utils import get_end_user_id_for_cost_tracking


@log_db_metrics
async def _PROXY_track_cost_callback(
kwargs, # kwargs to completion
completion_response: litellm.ModelResponse, # response from completion
start_time=None,
end_time=None, # start/end time for completion
):
from litellm.proxy.proxy_server import (
prisma_client,
proxy_logging_obj,
update_cache,
update_database,
)

verbose_proxy_logger.debug("INSIDE _PROXY_track_cost_callback")
try:
verbose_proxy_logger.debug(
f"kwargs stream: {kwargs.get('stream', None)} + complete streaming response: {kwargs.get('complete_streaming_response', None)}"
)
parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs=kwargs)
litellm_params = kwargs.get("litellm_params", {}) or {}
end_user_id = get_end_user_id_for_cost_tracking(litellm_params)
metadata = get_litellm_metadata_from_kwargs(kwargs=kwargs)
user_id = metadata.get("user_api_key_user_id", None)
team_id = metadata.get("user_api_key_team_id", None)
org_id = metadata.get("user_api_key_org_id", None)
key_alias = metadata.get("user_api_key_alias", None)
end_user_max_budget = metadata.get("user_api_end_user_max_budget", None)
sl_object: Optional[StandardLoggingPayload] = kwargs.get(
"standard_logging_object", None
)
response_cost = (
sl_object.get("response_cost", None)
if sl_object is not None
else kwargs.get("response_cost", None)
)

if response_cost is not None:
user_api_key = metadata.get("user_api_key", None)
if kwargs.get("cache_hit", False) is True:
response_cost = 0.0
verbose_proxy_logger.info(
f"Cache Hit: response_cost {response_cost}, for user_id {user_id}"
)

verbose_proxy_logger.debug(
f"user_api_key {user_api_key}, prisma_client: {prisma_client}"
)
if user_api_key is not None or user_id is not None or team_id is not None:
## UPDATE DATABASE
await update_database(
token=user_api_key,
response_cost=response_cost,
user_id=user_id,
end_user_id=end_user_id,
team_id=team_id,
kwargs=kwargs,
completion_response=completion_response,
start_time=start_time,
end_time=end_time,
org_id=org_id,
)

# update cache
asyncio.create_task(
update_cache(
token=user_api_key,
user_id=user_id,
end_user_id=end_user_id,
response_cost=response_cost,
team_id=team_id,
parent_otel_span=parent_otel_span,
)
)

await proxy_logging_obj.slack_alerting_instance.customer_spend_alert(
token=user_api_key,
key_alias=key_alias,
end_user_id=end_user_id,
response_cost=response_cost,
max_budget=end_user_max_budget,
)
else:
raise Exception(
"User API key and team id and user id missing from custom callback."
)
else:
if kwargs["stream"] is not True or (
kwargs["stream"] is True and "complete_streaming_response" in kwargs
):
if sl_object is not None:
cost_tracking_failure_debug_info: Union[dict, str] = (
sl_object["response_cost_failure_debug_info"] # type: ignore
or "response_cost_failure_debug_info is None in standard_logging_object"
)
else:
cost_tracking_failure_debug_info = (
"standard_logging_object not found"
)
model = kwargs.get("model")
raise Exception(
f"Cost tracking failed for model={model}.\nDebug info - {cost_tracking_failure_debug_info}\nAdd custom pricing - https://docs.litellm.ai/docs/proxy/custom_pricing"
)
except Exception as e:
error_msg = f"Error in tracking cost callback - {str(e)}\n Traceback:{traceback.format_exc()}"
model = kwargs.get("model", "")
metadata = kwargs.get("litellm_params", {}).get("metadata", {})
error_msg += f"\n Args to _PROXY_track_cost_callback\n model: {model}\n metadata: {metadata}\n"
asyncio.create_task(
proxy_logging_obj.failed_tracking_alert(
error_message=error_msg,
failing_model=model,
)
)
verbose_proxy_logger.exception("Error in tracking cost callback - %s", str(e))
4 changes: 0 additions & 4 deletions litellm/proxy/proxy_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,4 @@ model_list:
litellm_params:
model: openai/o1-preview
api_key: os.environ/OPENAI_API_KEY
- model_name: openai/*
litellm_params:
model: openai/*
api_key: os.environ/OPENAI_API_KEY

112 changes: 1 addition & 111 deletions litellm/proxy/proxy_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ def generate_feedback_box():
_OPTIONAL_PromptInjectionDetection,
)
from litellm.proxy.hooks.proxy_failure_handler import _PROXY_failure_handler
from litellm.proxy.hooks.proxy_track_cost_callback import _PROXY_track_cost_callback
from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
from litellm.proxy.management_endpoints.customer_endpoints import (
router as customer_router,
Expand Down Expand Up @@ -692,117 +693,6 @@ def cost_tracking():
litellm._async_success_callback.append(_PROXY_track_cost_callback) # type: ignore


@log_db_metrics
async def _PROXY_track_cost_callback(
kwargs, # kwargs to completion
completion_response: litellm.ModelResponse, # response from completion
start_time=None,
end_time=None, # start/end time for completion
):
verbose_proxy_logger.debug("INSIDE _PROXY_track_cost_callback")
global prisma_client
try:
verbose_proxy_logger.debug(
f"kwargs stream: {kwargs.get('stream', None)} + complete streaming response: {kwargs.get('complete_streaming_response', None)}"
)
parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs=kwargs)
litellm_params = kwargs.get("litellm_params", {}) or {}
end_user_id = get_end_user_id_for_cost_tracking(litellm_params)
metadata = get_litellm_metadata_from_kwargs(kwargs=kwargs)
user_id = metadata.get("user_api_key_user_id", None)
team_id = metadata.get("user_api_key_team_id", None)
org_id = metadata.get("user_api_key_org_id", None)
key_alias = metadata.get("user_api_key_alias", None)
end_user_max_budget = metadata.get("user_api_end_user_max_budget", None)
sl_object: Optional[StandardLoggingPayload] = kwargs.get(
"standard_logging_object", None
)
response_cost = (
sl_object.get("response_cost", None)
if sl_object is not None
else kwargs.get("response_cost", None)
)

if response_cost is not None:
user_api_key = metadata.get("user_api_key", None)
if kwargs.get("cache_hit", False) is True:
response_cost = 0.0
verbose_proxy_logger.info(
f"Cache Hit: response_cost {response_cost}, for user_id {user_id}"
)

verbose_proxy_logger.debug(
f"user_api_key {user_api_key}, prisma_client: {prisma_client}"
)
if user_api_key is not None or user_id is not None or team_id is not None:
## UPDATE DATABASE
await update_database(
token=user_api_key,
response_cost=response_cost,
user_id=user_id,
end_user_id=end_user_id,
team_id=team_id,
kwargs=kwargs,
completion_response=completion_response,
start_time=start_time,
end_time=end_time,
org_id=org_id,
)

# update cache
asyncio.create_task(
update_cache(
token=user_api_key,
user_id=user_id,
end_user_id=end_user_id,
response_cost=response_cost,
team_id=team_id,
parent_otel_span=parent_otel_span,
)
)

await proxy_logging_obj.slack_alerting_instance.customer_spend_alert(
token=user_api_key,
key_alias=key_alias,
end_user_id=end_user_id,
response_cost=response_cost,
max_budget=end_user_max_budget,
)
else:
raise Exception(
"User API key and team id and user id missing from custom callback."
)
else:
if kwargs["stream"] is not True or (
kwargs["stream"] is True and "complete_streaming_response" in kwargs
):
if sl_object is not None:
cost_tracking_failure_debug_info: Union[dict, str] = (
sl_object["response_cost_failure_debug_info"] # type: ignore
or "response_cost_failure_debug_info is None in standard_logging_object"
)
else:
cost_tracking_failure_debug_info = (
"standard_logging_object not found"
)
model = kwargs.get("model")
raise Exception(
f"Cost tracking failed for model={model}.\nDebug info - {cost_tracking_failure_debug_info}\nAdd custom pricing - https://docs.litellm.ai/docs/proxy/custom_pricing"
)
except Exception as e:
error_msg = f"Error in tracking cost callback - {str(e)}\n Traceback:{traceback.format_exc()}"
model = kwargs.get("model", "")
metadata = kwargs.get("litellm_params", {}).get("metadata", {})
error_msg += f"\n Args to _PROXY_track_cost_callback\n model: {model}\n metadata: {metadata}\n"
asyncio.create_task(
proxy_logging_obj.failed_tracking_alert(
error_message=error_msg,
failing_model=model,
)
)
verbose_proxy_logger.debug(error_msg)


def error_tracking():
global prisma_client
if prisma_client is not None:
Expand Down

0 comments on commit 4b2958b

Please sign in to comment.