
Commit

Merge pull request #419 from kardSIM/v3_Hf_support
Add Hugging Face API
feder-cr authored Sep 23, 2024
2 parents 4dff44a + f32f837 commit ec44aeb
Showing 3 changed files with 58 additions and 21 deletions.
3 changes: 3 additions & 0 deletions data_folder/config.yaml
@@ -49,4 +49,7 @@ job_applicants_threshold:
 
 llm_model_type: openai
 llm_model: gpt-4o-mini
 
+#llm_model_type: huggingface
+#llm_model: 'tiiuae/falcon-7b-instruct'
+# llm_api_url: https://api.pawan.krd/cosmosrp/v1  # optional
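To actually switch the bot to the new backend, the idea (a sketch based on the commented-out defaults above, using the same example model) is to make the Hugging Face pair the active values:

llm_model_type: huggingface
llm_model: 'tiiuae/falcon-7b-instruct'

The API token itself does not live in this file; presumably it is supplied the same way the other providers' keys are.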
1 change: 1 addition & 0 deletions requirements.txt
@@ -6,6 +6,7 @@ jsonschema==4.23.0
 jsonschema-specifications==2023.12.1
 langchain==0.2.11
 langchain-anthropic
+langchain-huggingface
 langchain-community==0.2.10
 langchain-core===0.2.36
 langchain-google-genai==1.0.10
75 changes: 54 additions & 21 deletions src/llm/llm_manager.py
@@ -90,6 +90,18 @@ def invoke(self, prompt: str) -> BaseMessage:
         response = self.model.invoke(prompt)
         return response
 
+class HuggingFaceModel(AIModel):
+    def __init__(self, api_key: str, llm_model: str):
+        from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace
+        self.model = HuggingFaceEndpoint(repo_id=llm_model, huggingfacehub_api_token=api_key,
+                                         temperature=0.4)
+        self.chatmodel = ChatHuggingFace(llm=self.model)
+
+    def invoke(self, prompt: str) -> BaseMessage:
+        logger.debug("Invoking model via the Hugging Face API")
+        response = self.chatmodel.invoke(prompt)
+        logger.debug(f"Hugging Face response received: {response}")
+        return response
+
 class AIAdapter:
     def __init__(self, config: dict, api_key: str):
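For a sense of how the new class behaves on its own, here is a hedged usage sketch; the token value is a placeholder and the model id simply mirrors the config.yaml example:

# Illustrative only: drive the new wrapper directly.
model = HuggingFaceModel(api_key="hf_...", llm_model="tiiuae/falcon-7b-instruct")
reply = model.invoke("Say hello in one short sentence.")
print(reply.content)  # invoke() returns a BaseMessage; .content holds the text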
@@ -111,6 +123,8 @@ def _create_model(self, config: dict, api_key: str) -> AIModel:
             return OllamaModel(llm_model, llm_api_url)
         elif llm_model_type == "gemini":
             return GeminiModel(api_key, llm_model)
+        elif llm_model_type == "huggingface":
+            return HuggingFaceModel(api_key, llm_model)
         else:
             raise ValueError(f"Unsupported model type: {llm_model_type}")
 
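In normal operation the class is reached through AIAdapter instead; a minimal sketch, assuming the adapter reads llm_model_type and llm_model out of the config dict as the dispatch above implies:

# Hypothetical wiring, mirroring data_folder/config.yaml with the new block enabled.
config = {"llm_model_type": "huggingface", "llm_model": "tiiuae/falcon-7b-instruct"}
adapter = AIAdapter(config, api_key="hf_...")  # _create_model() now returns a HuggingFaceModel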
@@ -286,27 +300,46 @@ def parse_llmresult(self, llmresult: AIMessage) -> Dict[str, Dict]:
         logger.debug(f"Parsing LLM result: {llmresult}")
 
         try:
-            content = llmresult.content
-            response_metadata = llmresult.response_metadata
-            id_ = llmresult.id
-            usage_metadata = llmresult.usage_metadata
-
-            parsed_result = {
-                "content": content,
-                "response_metadata": {
-                    "model_name": response_metadata.get("model_name", ""),
-                    "system_fingerprint": response_metadata.get("system_fingerprint", ""),
-                    "finish_reason": response_metadata.get("finish_reason", ""),
-                    "logprobs": response_metadata.get("logprobs", None),
-                },
-                "id": id_,
-                "usage_metadata": {
-                    "input_tokens": usage_metadata.get("input_tokens", 0),
-                    "output_tokens": usage_metadata.get("output_tokens", 0),
-                    "total_tokens": usage_metadata.get("total_tokens", 0),
-                },
-            }
-
+            if getattr(llmresult, 'usage_metadata', None):
+                content = llmresult.content
+                response_metadata = llmresult.response_metadata
+                id_ = llmresult.id
+                usage_metadata = llmresult.usage_metadata
+
+                parsed_result = {
+                    "content": content,
+                    "response_metadata": {
+                        "model_name": response_metadata.get("model_name", ""),
+                        "system_fingerprint": response_metadata.get("system_fingerprint", ""),
+                        "finish_reason": response_metadata.get("finish_reason", ""),
+                        "logprobs": response_metadata.get("logprobs", None),
+                    },
+                    "id": id_,
+                    "usage_metadata": {
+                        "input_tokens": usage_metadata.get("input_tokens", 0),
+                        "output_tokens": usage_metadata.get("output_tokens", 0),
+                        "total_tokens": usage_metadata.get("total_tokens", 0),
+                    },
+                }
+            else:
+                content = llmresult.content
+                response_metadata = llmresult.response_metadata
+                id_ = llmresult.id
+                token_usage = response_metadata['token_usage']
+
+                parsed_result = {
+                    "content": content,
+                    "response_metadata": {
+                        "model_name": response_metadata.get("model", ""),
+                        "finish_reason": response_metadata.get("finish_reason", ""),
+                    },
+                    "id": id_,
+                    "usage_metadata": {
+                        "input_tokens": token_usage.prompt_tokens,
+                        "output_tokens": token_usage.completion_tokens,
+                        "total_tokens": token_usage.total_tokens,
+                    },
+                }
             logger.debug(f"Parsed LLM result successfully: {parsed_result}")
             return parsed_result
 
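The new branch exists because providers surface token usage in different places. A sketch of the two shapes the parser now distinguishes (values invented for the example, not taken from the diff):

# First branch: OpenAI-style messages carry usage on the message itself as a dict.
usage_metadata = {"input_tokens": 12, "output_tokens": 34, "total_tokens": 46}
# Else branch: Hugging Face-style messages lack usage_metadata; usage instead sits
# in response_metadata["token_usage"] as an object read by attribute —
# token_usage.prompt_tokens, .completion_tokens, .total_tokens — which is why
# that branch uses attribute access rather than .get().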
