Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

273 consistent styling #295

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.0.1
rev: v5.0.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
Expand All @@ -10,28 +10,30 @@ repos:
- id: detect-private-key

- repo: https://github.com/psf/black
rev: 24.4.2
rev: 24.10.0
hooks:
- id: black
args: ['--line-length=88']
args: ['--config=pyproject.toml']
exclude: ^docs/|.*\.(json|yaml|md|txt)$

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.4.2
rev: v0.8.2
hooks:
# Run the linter.
- id: ruff
args: ['--fix']
args: ['--fix', '--config=pyproject.toml']
exclude: ^docs/|.*\.(json|yaml|md|txt)$

# Add local hooks to run custom commands
# stage files after ruff
- repo: local
hooks:
- id: run-make-format
name: Run Make Format
entry: make format
language: system
stages: [commit]
pass_filenames: false

# - repo: https://github.com/pycqa/flake8
# rev: 4.0.1
# hooks:
Expand Down
6 changes: 4 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,14 @@ setup:
# Format code using Black and Ruff
.PHONY: format
format:
$(PYTHON) black $(SRC_DIR)
git ls-files | xargs pre-commit run black --files
$(PYTHON) black $(SRC_DIR) --config pyproject.toml
$(PYTHON) ruff check --fix $(SRC_DIR)
# remove git ls-files | xargs pre-commit run black --files, causes a circular dependency

# Run lint checks using Ruff
.PHONY: lint
lint:
$(PYTHON) black --check $(SRC_DIR) --config pyproject.toml
$(PYTHON) ruff check $(SRC_DIR)

# Run all pre-commit hooks on all files
Expand Down
1 change: 0 additions & 1 deletion adalflow/adalflow/components/agent/react.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,6 @@ def _execute_action(self, action_step: StepOutput) -> Optional[StepOutput]:
"""Parse the action string to a function call and execute it. Update the action_step with the result."""
action = action_step.action
try:

fun: Function = self.tool_manager.parse_func_expr(action)
result: FunctionOutput = self.tool_manager.execute_func(fun)
# TODO: optimize the action_step
Expand Down
1 change: 0 additions & 1 deletion adalflow/adalflow/components/model_client/cohere_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,6 @@ def call(self, api_kwargs: Dict = {}, model_type: ModelType = ModelType.UNDEFINE
if (
model_type == ModelType.RERANKER
): # query -> # scores for top_k documents, index for the top_k documents, return as tuple

response = self.sync_client.rerank(**api_kwargs)
top_k_scores = [result.relevance_score for result in response.results]
top_k_indices = [result.index for result in response.results]
Expand Down
1 change: 0 additions & 1 deletion adalflow/adalflow/components/model_client/google_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,6 @@ def convert_inputs_to_api_kwargs(
raise TypeError("input must be a sequence of text")
final_model_kwargs["input"] = input
elif model_type == ModelType.LLM:

final_model_kwargs["prompt"] = input
else:
raise ValueError(f"model_type {model_type} is not supported")
Expand Down
4 changes: 2 additions & 2 deletions adalflow/adalflow/components/model_client/ollama_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
import warnings
from adalflow.core.types import ModelType, GeneratorOutput

from adalflow.utils.lazy_import import safe_import, OptionalPackages

ollama = safe_import(OptionalPackages.OLLAMA.value[0], OptionalPackages.OLLAMA.value[1])
# need to pick either safe or regular import
# ollama = safe_import(OptionalPackages.OLLAMA.value[0], OptionalPackages.OLLAMA.value[1])
import ollama
from ollama import RequestError, ResponseError, GenerateResponse

Expand Down
13 changes: 3 additions & 10 deletions adalflow/adalflow/components/model_client/transformers_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,10 @@
from adalflow.core.functional import get_top_k_indices_scores

# optional import
from adalflow.utils.lazy_import import safe_import, OptionalPackages


transformers = safe_import(
OptionalPackages.TRANSFORMERS.value[0], OptionalPackages.TRANSFORMERS.value[1]
)
torch = safe_import(OptionalPackages.TORCH.value[0], OptionalPackages.TORCH.value[1])
# need to pick either safe or regular import
# transformers = safe_import(OptionalPackages.TRANSFORMERS.value[0], OptionalPackages.TRANSFORMERS.value[1])
# torch = safe_import(OptionalPackages.TORCH.value[0], OptionalPackages.TORCH.value[1])

import torch

Expand Down Expand Up @@ -201,7 +198,6 @@ def infer_bge_reranker_base(
input = [(query, doc) for doc in documents]

with torch.no_grad():

inputs = self.tokenizer(
input,
padding=True,
Expand Down Expand Up @@ -358,7 +354,6 @@ def init_model(self, model_name: str):
raise ValueError(f"Model {model_name} is not supported")

def _parse_chat_completion_from_pipeline(self, completion: Any) -> str:

text = completion[0]["generated_text"]

pattern = r"(?<=\|assistant\|>).*"
Expand Down Expand Up @@ -407,7 +402,6 @@ def _infer_from_pipeline(
)

if model == "HuggingFaceH4/zephyr-7b-beta":

prompt = model_to_use.tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
Expand Down Expand Up @@ -728,7 +722,6 @@ def convert_inputs_to_api_kwargs(
}

class CustomizeLLM:

def __init__(self) -> None:
pass

Expand Down
1 change: 1 addition & 0 deletions adalflow/adalflow/components/model_client/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"Helpers for model client for integrating models and parsing the output."

from adalflow.core.types import EmbedderOutput, Embedding, Usage


Expand Down
2 changes: 0 additions & 2 deletions adalflow/adalflow/components/output_parsers/outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,6 @@ def __init__(
exclude_fields: ExcludeType = None,
return_data_class: bool = False,
):

super().__init__()
if not is_dataclass(data_class):
raise TypeError(f"Provided class is not a dataclass: {data_class}")
Expand Down Expand Up @@ -349,7 +348,6 @@ def format_instructions(self) -> str:
return "The output should be a boolean value. True or False."

def call(self, input: str) -> bool:

input = input.strip()
output = None
# evaluate the expression to get the boolean value
Expand Down
2 changes: 1 addition & 1 deletion adalflow/adalflow/components/retriever/bm25_retriever.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""BM25 retriever implementation. """
"""BM25 retriever implementation."""

from typing import List, Dict, Optional, Callable, Any, Sequence
import numpy as np
Expand Down
2 changes: 0 additions & 2 deletions adalflow/adalflow/core/base_data_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,10 +189,8 @@ class MyOutputs(DataClass):
__output_fields__: List[str] = []

def __post_init__(self):

for f in fields(self):
if "desc" not in f.metadata and "description" not in f.metadata:

logger.debug(
f"Class { self.__class__.__name__} Field {f.name} is missing 'desc' in metadata"
)
Expand Down
1 change: 0 additions & 1 deletion adalflow/adalflow/core/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -840,7 +840,6 @@ def remove_from(*dicts_or_sets):
)
self.register_parameter(name, value)
else: # set component

components = self.__dict__.get("_components")
if isinstance(value, Component):
if components is None:
Expand Down
1 change: 0 additions & 1 deletion adalflow/adalflow/core/embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ def __init__(
model_kwargs: Dict[str, Any] = {},
output_processors: Optional[Component] = None,
) -> None:

super().__init__(model_kwargs=model_kwargs)
if not isinstance(model_kwargs, Dict):
raise TypeError(
Expand Down
1 change: 0 additions & 1 deletion adalflow/adalflow/core/func_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,6 @@ def _extra_repr(self) -> str:


if __name__ == "__main__":

import asyncio
import time

Expand Down
3 changes: 0 additions & 3 deletions adalflow/adalflow/core/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,6 @@ class TrecDataList:
if is_dataclass(cls) or is_potential_dataclass(
cls
): # Optional[Address] will be false, and true for each check

log.debug(
f"{is_dataclass(cls)} of {cls}, {is_potential_dataclass(cls)} of {cls}"
)
Expand Down Expand Up @@ -922,7 +921,6 @@ def get_top_k_indices_scores(


def generate_readable_key_for_function(fn: Callable) -> str:

module_name = fn.__module__
function_name = fn.__name__
return f"{module_name}.{function_name}"
Expand Down Expand Up @@ -1236,7 +1234,6 @@ def parse_json_str_to_obj(json_str: str) -> Union[Dict[str, Any], List[Any]]:
except json.JSONDecodeError:
# 3rd attemp using yaml
try:

# NOTE: parsing again with pyyaml
# pyyaml is less strict, and allows for trailing commas
# right now we rely on this since guidance program generates
Expand Down
4 changes: 0 additions & 4 deletions adalflow/adalflow/core/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,6 @@ def get_cache_path(self) -> str:
def _get_default_mapping(
output: "GeneratorOutput" = None,
) -> Tuple[Dict[str, Callable], List[str]]:

if (
output.data
and isinstance(output.data, DataClass)
Expand Down Expand Up @@ -546,7 +545,6 @@ def backward(
backward_engine: Optional["Generator"] = None,
id: Optional[str] = None, # the id of the input
) -> Parameter:

log.info(f"Generator: Backward: {response}")

children_params = response.predecessors
Expand Down Expand Up @@ -678,7 +676,6 @@ def _backward_through_one_predecessor(
data=manual_response, raw_response=manual_response
)
else:

gradient_output: GeneratorOutput = backward_engine(
prompt_kwargs=backward_engine_prompt_kwargs
)
Expand Down Expand Up @@ -881,7 +878,6 @@ def failure_message_to_backward_engine(


class BackwardEngine(Generator): # it is a generator with defaule template

__doc__ = """The backward engine is a Generator with a default template for the backward pass.

If you want to customize the template, you can create your own backward engine"""
Expand Down
2 changes: 0 additions & 2 deletions adalflow/adalflow/core/prompt_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,9 +169,7 @@ def _convert_prompt_kwargs_to_str(prompt_kwargs: Dict) -> Dict[str, str]:
prompt_kwargs_str: Dict[str, str] = {}

for key, p in prompt_kwargs.items():

if isinstance(p, Parameter):

prompt_kwargs_str[key] = p.data
else:
prompt_kwargs_str[key] = p
Expand Down
1 change: 0 additions & 1 deletion adalflow/adalflow/core/tool_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,6 @@ def execute_func_expr(self, expr: FunctionExpression) -> FunctionOutput:
r"""Execute the function expression. Support both sync and async functions."""
func: Function = self.parse_func_expr(expr)
try:

return self.execute_func(func)
except Exception as e:
# NOTE: if the function expression is not a function call, try to execute it as a function expression
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,6 @@ def process_batch(self, documents: List[Document]):
def __call__(self, documents: List[Document]):
batch_size = self.batch_size
for i in range(0, len(documents), batch_size):

List = documents[i : i + batch_size]
print(i, len(List))
self.process_batch(List)
Expand Down
2 changes: 0 additions & 2 deletions adalflow/adalflow/datasets/big_bench_hard.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ def __init__(
*args,
**kwargs,
):

if split not in ["train", "val", "test"]:
raise ValueError("Split must be one of 'train', 'val', 'test'")

Expand All @@ -65,7 +64,6 @@ def __init__(
) # dont use a tuple, use a dict {"x": ..., "y": ...}

def _check_or_download_dataset(self, data_path: str = None, split: str = "train"):

if data_path is None:
raise ValueError("data_path must be specified")
json_path = os.path.join(data_path, f"{self.task_name}.json")
Expand Down
2 changes: 0 additions & 2 deletions adalflow/adalflow/datasets/trec.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ def calculate_class_weights(labels: torch.Tensor) -> torch.Tensor:


def sample_subset_dataset(dataset, num_samples: int, sample_weights):

# Create a WeightedRandomSampler to get 400 samples
sampler = WeightedRandomSampler(
weights=sample_weights, num_samples=num_samples, replacement=False
Expand Down Expand Up @@ -171,7 +170,6 @@ def __init__(
)

def _check_or_download_dataset(self, data_path: str = None, split: str = "train"):

if data_path is None:
raise ValueError("data_path must be specified")
split_csv_path = os.path.join(data_path, f"{split}.csv")
Expand Down
1 change: 0 additions & 1 deletion adalflow/adalflow/eval/g_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@


class GEvalMetric(Enum):

RELEVANCE = "Relevance" # range [1, 5]
FLUENCY = "Fluency" # range [1, 3]
CONSISTENCY = "Consistency" # range [1, 5]
Expand Down
1 change: 0 additions & 1 deletion adalflow/adalflow/optim/_llm_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from adalflow.core.base_data_class import DataClass

if TYPE_CHECKING:

from adalflow.core.model_client import ModelClient


Expand Down
3 changes: 0 additions & 3 deletions adalflow/adalflow/optim/few_shot/bootstrap_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ def add_scores(self, ids: List[str], scores: List[float], is_teacher: bool = Tru
)

for score in scores:

if not isinstance(score, float):
raise ValueError(
f"score must be a float, got {type(score)}, score: {score}"
Expand Down Expand Up @@ -198,7 +197,6 @@ def samples_to_str(
sample_strs = []
for sample in samples:
try:

# process the input fields
if augmented:
exclude_fields = ["id", "score"]
Expand Down Expand Up @@ -239,7 +237,6 @@ def propose(self):

demo_str = ""
if len(sampled_augmented_demos) > 0:

demo_str = self.samples_to_str(
samples=sampled_augmented_demos,
augmented=True,
Expand Down
1 change: 0 additions & 1 deletion adalflow/adalflow/optim/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ def zero_grad(self):


class DemoOptimizer(Optimizer):

__doc__ = r"""Base class for all demo optimizers.

Demo optimizer are few-shot optimization, where it will sample raw examples from train dataset or bootstrap examples from the model's output.
Expand Down
3 changes: 0 additions & 3 deletions adalflow/adalflow/optim/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,6 @@ def build_graph(node: "Parameter"):
def backward(
self,
): # engine should be the llm or customized backwards function to pass feedback

# topological sort of all the predecessors of the current parameter in the graph
log.debug(f"Backward pass for {self.data}, backward function: {self.grad_fn}")
topo: List[Parameter] = []
Expand Down Expand Up @@ -577,7 +576,6 @@ def wrap_and_escape(text, width=40):
log.info(f"Node: {n.name}, {n.to_dict()}")
# track gradients
for g in n.gradients:

log.info(f"Gradient: {g.name}, {g.to_dict()}")
log.info(f"Gradient prompt: {g.gradient_prompt}")
for n1, n2 in edges:
Expand Down Expand Up @@ -685,7 +683,6 @@ def __repr__(self):


def _check_and_reduce_gradients(variable: Parameter) -> Set[Parameter]:

if variable.get_gradient_and_context_text() == "":
log.debug(f"No gradients detected for {variable.data}")
return variable.gradients
Expand Down
Loading
Loading