Skip to content

Commit

Permalink
feat: Add Xpuls AI Integration + Demo Environment + Prompt Debugging (
Browse files Browse the repository at this point in the history
…#11)

* - Add Xpuls.ai Integration
- Add Demo Environment

* cleanup

* Make xpulsai tracing optional

* Make xpulsai tracing optional

* Bump version to 0.1.0
  • Loading branch information
SHARANTANGEDA authored Oct 6, 2023
1 parent 38d3b1c commit e5e9eca
Show file tree
Hide file tree
Showing 15 changed files with 345 additions and 42 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
.idea
venv
dist

*.env
9 changes: 8 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,19 @@ pip install xpuls-mlmonitor
## Usage Example
```python
from xpuls.mlmonitor.langchain.instrument import LangchainTelemetry
import os

# Enable this for advanced tracking with our xpuls-ml platform
os.environ["XPULSAI_TRACING_ENABLED"] = "true"

# Add default labels that will be added to all captured metrics
default_labels = {"service": "ml-project-service", "k8s_cluster": "app0", "namespace": "dev", "agent_name": "fallback_value"}

# Enable the auto-telemetry
LangchainTelemetry(default_labels=default_labels).auto_instrument()
LangchainTelemetry(
default_labels=default_labels,
xpuls_host_url="http://app.xpuls.ai" # Optional param, required when XPULSAI_TRACING is enabled
).auto_instrument()

## [Optional] Override labels for scope of decorator [Useful if you have multiple scopes where you need to override the default label values]
@TelemetryOverrideLabels(agent_name="chat_agent_alpha")
Expand Down
Empty file added __init__.py
Empty file.
4 changes: 4 additions & 0 deletions demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from demo.openai_langchain import run_openai_agent

# Run the demo agent once and echo its final answer to stdout.
response = run_openai_agent()
print(str(response))
Empty file added demo/__init__.py
Empty file.
65 changes: 65 additions & 0 deletions demo/openai_langchain.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import logging
import os

import openai
from langchain.agents import initialize_agent, AgentType
from langchain.chat_models import AzureChatOpenAI
from langchain.memory import ConversationBufferMemory

from xpuls.mlmonitor.langchain.decorators.map_xpuls_project import MapXpulsProject
from xpuls.mlmonitor.langchain.decorators.telemetry_override_labels import TelemetryOverrideLabels
from xpuls.mlmonitor.langchain.instrument import LangchainTelemetry

logger = logging.getLogger(__name__)

# --- Azure OpenAI SDK configuration (read from the environment) ---
openai.api_key = os.getenv("OPENAI_API_KEY")
openai.api_type = "azure"
# NOTE(review): os.getenv("OPENAI_URL") returns None when the variable is
# unset, and assigning None into os.environ raises TypeError — this demo
# assumes OPENAI_URL is always exported. Confirm before reuse.
openai.api_base = os.getenv("OPENAI_URL")
os.environ["OPENAI_API_BASE"] = os.getenv("OPENAI_URL")
os.environ["OPENAI_API_VERSION"] = "2023-03-15-preview"
openai.api_version = "2023-03-15-preview"

# Set to "true" to enable advanced prompt tracing against the xpuls server.
# os.environ["XPULSAI_TRACING_ENABLED"] = "false"
os.environ["XPULSAI_TRACING_ENABLED"] = "true"

# Default labels attached to every captured metric.
default_labels = {"system": "openai-ln-test", "agent_name": "fallback_value"}

# Patch langchain so all chain runs are instrumented automatically.
LangchainTelemetry(
    default_labels=default_labels,
    xpuls_host_url="http://localhost:8000"
).auto_instrument()

# Shared conversation memory and Azure-hosted chat model used by the demo agent.
memory = ConversationBufferMemory(memory_key="chat_history")
chat_model = AzureChatOpenAI(
    deployment_name="gpt35turbo",
    model_name="gpt-35-turbo",
    temperature=0
)


@TelemetryOverrideLabels(agent_name="chat_agent_alpha")
@MapXpulsProject(project_id="default")  # Get Project ID from console
def run_openai_agent():
    """Run one conversational query through a langchain agent and return the answer.

    Returns:
        str: the agent's final answer, or the salvaged LLM output when the
        agent raised the well-known "Could not parse LLM output" ValueError.

    Raises:
        ValueError: re-raised unchanged when it is not a parse failure.
    """
    agent = initialize_agent(llm=chat_model,
                             verbose=True,
                             tools=[],
                             agent=AgentType.CONVERSATIONAL_REACT_DESCRIPTION,
                             memory=memory,
                             # handle_parsing_errors="Check your output and make sure it conforms!",
                             return_intermediate_steps=False,
                             agent_executor_kwargs={"extra_prompt_messages": "test"})

    try:
        res = agent.run("You are to behave as a think tank to answer the asked question in most creative way,"
                        " ensure to NOT be abusive or racist, you should validate your response w.r.t to validity "
                        "in practical world before giving final answer" +
                        f"\nQuestion: How does nature work?, is balance of life true? \n")
    except ValueError as e:
        res = str(e)
        if not res.startswith("Could not parse LLM output: `"):
            # Bare `raise` preserves the original traceback (`raise e` would
            # restart it from this line).
            raise
        logger.error(f" Got ValueError: {e}")
        # Salvage the raw LLM text embedded in the parse-failure message.
        res = res.removeprefix("Could not parse LLM output: `").removesuffix("`")

    return res
2 changes: 2 additions & 0 deletions demo_requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
openai
langchain
5 changes: 2 additions & 3 deletions requirements/base_requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
opentelemetry-instrumentation-requests
opentelemetry-api
opentelemetry-sdk
prometheus-client
pydantic
requests
urllib3
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def read_requirements(file_name):

setup(
name='xpuls-mlmonitor',
version='0.0.8',
version='0.1.0',
author='Sai Sharan Tangeda',
author_email='saisarantangeda@gmail.com',
description='Automated telemetry and monitoring for ML & LLM Frameworks',
Expand Down
21 changes: 21 additions & 0 deletions xpuls/mlmonitor/langchain/decorators/map_xpuls_project.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import contextvars
import functools
from typing import Optional, Any, Dict


class MapXpulsProject:
    """Decorator that records which xpuls project a call belongs to.

    Before invoking the wrapped function, it publishes the configured
    ``project_id``/``project_slug`` into a context variable that the
    langchain instrumentation patches read to resolve the project.
    """

    # Context variable read by the instrumentation patches; None until a
    # decorated function runs in the current context.
    _context: contextvars.ContextVar[Optional[Dict[str, Any]]] = contextvars.ContextVar('telemetry_extra_labels_vars',
                                                                                        default=None)

    def __init__(self, project_id: Optional[str] = None, project_slug: Optional[str] = None):
        """Store the project identifiers.

        Args:
            project_id: xpuls project id (from the console).
            project_slug: human-readable project slug.

        Raises:
            ValueError: when neither identifier is provided.
        """
        if project_id is None and project_slug is None:
            raise ValueError("Both `project_id` and `project_slug` cannot be null")
        self.project_id = project_id
        self.project_slug = project_slug

    def __call__(self, func):
        # functools.wraps preserves the wrapped function's metadata
        # (__name__, __doc__, …), which a plain inner def would clobber.
        @functools.wraps(func)
        def wrapped_func(*args, **kwargs):
            self._context.set({'project_id': self.project_id, 'project_slug': self.project_slug})

            return func(*args, **kwargs)

        return wrapped_func
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ def __init__(self, **labels):
def __call__(self, func):
    """Wrap *func* so this decorator's labels are published before each call."""
    def wrapped_func(*args, **kwargs):
        # Publish the override labels into the context variable; the
        # instrumentation patches read them to label captured metrics.
        self._context.set(self.labels)

        return func(*args, **kwargs)

    return wrapped_func
12 changes: 2 additions & 10 deletions xpuls/mlmonitor/langchain/handlers/callback_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,31 +4,23 @@

import pydantic
from langchain.callbacks.base import AsyncCallbackHandler

from langchain.schema.output import LLMResult
from langchain.schema.messages import BaseMessage

from langchain.schema.agent import AgentAction, AgentFinish

from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider

from xpuls.mlmonitor.langchain.profiling.prometheus import LangchainChainMetrics, LangchainPrometheusMetrics, \
LangchainChatModelMetrics, LangchainOpenAITokens, LangchainToolMetrics
from xpuls.mlmonitor.utils.common import get_safe_dict_value

from . import constants as c
# Set the tracer provider and a console exporter
trace.set_tracer_provider(TracerProvider())
trace.get_tracer_provider()

tracer = trace.get_tracer(__name__)


class CallbackHandler(AsyncCallbackHandler):
log = logging.getLogger()

def __init__(self, ln_metrics: LangchainPrometheusMetrics, chain_run_id: str, override_labels: Dict[str, str]) -> None:
def __init__(self, ln_metrics: LangchainPrometheusMetrics, chain_run_id: str,
override_labels: Dict[str, str]) -> None:
self.llm_start_time = None
self.llm_end_time = None

Expand Down
17 changes: 12 additions & 5 deletions xpuls/mlmonitor/langchain/instrument.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,29 @@
from typing import Dict, Any
from langsmith import Client


from xpuls.mlmonitor.langchain.patches import patch_chain
from xpuls.mlmonitor.langchain.profiling.prometheus import LangchainPrometheusMetrics
from xpuls.mlmonitor.langchain.xpuls_client import XpulsAILangChainClient


class LangchainTelemetry:
    """Entry point that configures and applies langchain auto-instrumentation.

    Note: this span contained diff residue (old and new ``__init__`` parameter
    lists interleaved); reconstructed to the post-change version.
    """

    def __init__(self, default_labels: Dict[str, Any],
                 xpuls_host_url: str = "http://localhost:8000",
                 enable_prometheus: bool = True,
                 enable_otel_tracing: bool = True,
                 enable_otel_logging: bool = False):
        """Create the telemetry wiring.

        Args:
            default_labels: labels added to every captured metric.
            xpuls_host_url: xpuls backend URL; used when XPULSAI tracing is enabled.
            enable_prometheus: toggle Prometheus metrics.
            enable_otel_tracing: toggle OpenTelemetry tracing.
            enable_otel_logging: toggle OpenTelemetry logging.
        """
        # Prometheus metrics pre-seeded with the caller's default labels.
        self.ln_metrics = LangchainPrometheusMetrics(default_labels)

        # Client used for advanced prompt tracing against the xpuls backend.
        self.xpuls_client = XpulsAILangChainClient(
            api_url=xpuls_host_url
        )

        self.default_labels = default_labels
        self.enable_prometheus = enable_prometheus
        self.enable_otel_tracing = enable_otel_tracing
        self.enable_otel_logging = enable_otel_logging

    def auto_instrument(self):
        """Patch langchain's Chain run methods so every chain call is instrumented."""
        patch_chain(self.ln_metrics, self.xpuls_client)
        print("** ProfileML -> Langchain auto-instrumentation completed successfully **")

62 changes: 40 additions & 22 deletions xpuls/mlmonitor/langchain/patches/patch_chain.py
Original file line number Diff line number Diff line change
@@ -1,59 +1,77 @@
import os
import uuid
from typing import Dict, Any

from langchain.callbacks import LangChainTracer
from langchain.chains.base import Chain
from langsmith import Client

from xpuls.mlmonitor.langchain.decorators.telemetry_override_labels import TelemetryOverrideLabels
from xpuls.mlmonitor.langchain.decorators.map_xpuls_project import MapXpulsProject
from xpuls.mlmonitor.langchain.handlers.callback_handlers import CallbackHandler

from xpuls.mlmonitor.langchain.profiling.prometheus import LangchainPrometheusMetrics
from xpuls.mlmonitor.langchain.xpuls_client import XpulsAILangChainClient


def patch_chain(ln_metrics: LangchainPrometheusMetrics):
def patch_chain(ln_metrics: LangchainPrometheusMetrics, xpuls_client: XpulsAILangChainClient):
# Store the original run method
original_run = Chain.run
original_arun = Chain.arun

def patched_run(self, *args, **kwargs):
def _apply_patch(kwargs):
try:
override_labels = TelemetryOverrideLabels._context.get()
if override_labels is None:
override_labels = {}
except Exception as e:
override_labels = {}
print(f"Error getting labels. Exception: {e}")

try:
project_details = MapXpulsProject._context.get()
if project_details is None:
project_details = {}
except Exception as e:
project_details = {'project_id': 'default'}

updated_labels = dict(ln_metrics.get_default_labels(), **override_labels)
chain_run_id = str(uuid.uuid4())
ln_tracer = LangChainTracer(
project_name=project_details['project_id'] if project_details['project_id'] is not None else
project_details['project_slug'],
client=xpuls_client,
)

callback_handler = CallbackHandler(ln_metrics, chain_run_id, override_labels)

with ln_metrics.agent_run_histogram.labels(**dict(ln_metrics.get_default_labels(), **override_labels)).time():
if 'callbacks' in kwargs:
kwargs['callbacks'].append(callback_handler)
else:
kwargs['callbacks'] = [callback_handler]

# Call the original run method
return original_run(self, *args, **kwargs)
if os.getenv("XPULSAI_TRACING_ENABLED", "false") == "true":
kwargs['callbacks'].append(ln_tracer)
metadata = {'xpuls': {'labels': updated_labels, 'run_id': chain_run_id,
'project_id': project_details['project_id'] if project_details['project_id'] is not None else project_details['project_slug']}}
if 'metadata' in kwargs:
kwargs['metadata'] = dict(kwargs['metadata'], **metadata)
else:
kwargs['metadata'] = metadata
return kwargs, ln_tracer, updated_labels

def patched_arun(self, *args, **kwargs):
try:
override_labels = TelemetryOverrideLabels._context.get()
if override_labels is None:
override_labels = {}
except Exception as e:
override_labels = {}
print(f"Error getting labels. Exception: {e}")
def patched_run(self, *args, **kwargs):

updated_kwargs, ln_tracer, updated_labels = _apply_patch(kwargs)

chain_run_id = str(uuid.uuid4())
callback_handler = CallbackHandler(ln_metrics, chain_run_id, override_labels)
with ln_metrics.agent_run_histogram.labels(**dict(ln_metrics.get_default_labels(), **override_labels)).time():
if 'callbacks' in kwargs:
kwargs['callbacks'].append(callback_handler)
else:
kwargs['callbacks'] = [callback_handler]
# Call the original run method
return original_run(self, *args, **updated_kwargs)

def patched_arun(self, *args, **kwargs):
updated_kwargs, ln_tracer, updated_labels = _apply_patch(kwargs)

# Call the original run method
return original_arun(self, *args, **kwargs)
# Call the original run method
return original_arun(self, *args, **updated_kwargs)

# Patch the Chain class's run method with the new one
Chain.run = patched_run
Expand Down
Loading

0 comments on commit e5e9eca

Please sign in to comment.