Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

core: Fixed inspecting schema issue when working with InjectedToolArg annotations #28435

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
3 changes: 3 additions & 0 deletions libs/core/langchain_core/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@
from langchain_core.tools.base import (
InjectedToolArg as InjectedToolArg,
)
from langchain_core.tools.base import (
InjectedToolArgSchema as InjectedToolArgSchema,
)
from langchain_core.tools.base import SchemaAnnotationError as SchemaAnnotationError
from langchain_core.tools.base import (
ToolException as ToolException,
Expand Down
46 changes: 38 additions & 8 deletions libs/core/langchain_core/tools/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
PydanticDeprecationWarning,
SkipValidation,
ValidationError,
WithJsonSchema,
model_validator,
validate_arguments,
)
Expand Down Expand Up @@ -103,7 +104,7 @@ def _get_filtered_args(
for i, (k, param) in enumerate(valid_keys.items())
if k not in filter_args
and (i > 0 or param.name not in ("self", "cls"))
and (include_injected or not _is_injected_arg_type(param.annotation))
and (include_injected or not _check_injected_arg_type(param.annotation))
}


Expand Down Expand Up @@ -263,21 +264,26 @@ def create_schema_from_function(

inferred_model = validated.model # type: ignore

# extract the function parameters
existing_params: list[str] = list(sig.parameters.keys())

# Create filtered arguments list
if filter_args:
filter_args_ = filter_args
filter_args_ = list(filter_args)

else:
# Handle classmethods and instance methods
existing_params: list[str] = list(sig.parameters.keys())
if existing_params and existing_params[0] in ("self", "cls") and in_class:
filter_args_ = [existing_params[0]] + list(FILTERED_ARGS)
else:
filter_args_ = list(FILTERED_ARGS)

for existing_param in existing_params:
if not include_injected and _is_injected_arg_type(
sig.parameters[existing_param].annotation
):
filter_args_.append(existing_param)
# add arguments with InjectedToolArg annotation to filter_args_
for existing_param in existing_params:
if not include_injected and _check_injected_arg_type(
sig.parameters[existing_param].annotation
):
filter_args_.append(existing_param)

description, arg_descriptions = _infer_arg_descriptions(
func,
Expand Down Expand Up @@ -954,6 +960,17 @@ class InjectedToolArg:
"""Annotation for a Tool arg that is **not** meant to be generated by a model."""


# Add Custom json schema for injected tool arguments
InjectedToolArgSchema = Annotated[
InjectedToolArg,
WithJsonSchema(
{
"type": "Injected-Tool-Argument",
}
),
]


def _is_injected_arg_type(type_: type) -> bool:
return any(
isinstance(arg, InjectedToolArg)
Expand All @@ -962,6 +979,19 @@ def _is_injected_arg_type(type_: type) -> bool:
)


# Identify if a type contains an InjectedToolArg annotation
# Used to filter out injected arguments from the schema
def _check_injected_arg_type(type_: type) -> bool:
if type_ is InjectedToolArg:
return True
for arg in get_args(type_):
if arg is InjectedToolArg or (
isinstance(arg, type) and issubclass(arg, InjectedToolArg)
):
return True
return False


def get_all_basemodel_annotations(
cls: Union[TypeBaseModel, Any], *, default_to_bound: bool = True
) -> dict[str, type]:
Expand Down
2 changes: 2 additions & 0 deletions libs/core/langchain_core/tools/structured.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,12 +173,14 @@ def add(a: int, b: int) -> int:
name = name or source_function.__name__
if args_schema is None and infer_schema:
# schema name is appended within function
# Use include_injected to exclude injected args in the schema
args_schema = create_schema_from_function(
name,
source_function,
parse_docstring=parse_docstring,
error_on_invalid_docstring=error_on_invalid_docstring,
filter_args=_filter_schema_args(source_function),
# include_injected=False,
)
description_ = description
if description is None and not parse_docstring:
Expand Down
Loading