diff --git a/docker-compose.yaml b/docker-compose.yaml index 94044916b..926b8515c 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -43,7 +43,7 @@ services: - redis_data:/data super__postgres: - image: "docker.io/library/postgres:latest" + image: "docker.io/library/postgres:15" environment: - POSTGRES_USER=superagi - POSTGRES_PASSWORD=password diff --git a/gui/pages/Content/Models/AddModel.js b/gui/pages/Content/Models/AddModel.js index e596cb80c..0ef3d5497 100644 --- a/gui/pages/Content/Models/AddModel.js +++ b/gui/pages/Content/Models/AddModel.js @@ -1,14 +1,14 @@ import React, {useEffect, useState} from "react"; import ModelForm from "./ModelForm"; -export default function AddModel({internalId, getModels, sendModelData}){ +export default function AddModel({internalId, getModels, sendModelData, env}){ return(
- +
diff --git a/gui/pages/Content/Models/ModelForm.js b/gui/pages/Content/Models/ModelForm.js index d8b248c56..45794bf18 100644 --- a/gui/pages/Content/Models/ModelForm.js +++ b/gui/pages/Content/Models/ModelForm.js @@ -1,12 +1,12 @@ import React, {useEffect, useRef, useState} from "react"; import {removeTab, openNewTab, createInternalId, getUserClick} from "@/utils/utils"; import Image from "next/image"; -import {fetchApiKey, storeModel, verifyEndPoint} from "@/pages/api/DashboardService"; +import {fetchApiKey, storeModel, testModel, verifyEndPoint} from "@/pages/api/DashboardService"; import {BeatLoader, ClipLoader} from "react-spinners"; import {ToastContainer, toast} from 'react-toastify'; -export default function ModelForm({internalId, getModels, sendModelData}){ - const models = ['OpenAI', 'Replicate', 'Hugging Face', 'Google Palm']; +export default function ModelForm({internalId, getModels, sendModelData, env}){ + const models = env === 'DEV' ? ['OpenAI', 'Replicate', 'Hugging Face', 'Google Palm', 'Local LLM'] : ['OpenAI', 'Replicate', 'Hugging Face', 'Google Palm']; const [selectedModel, setSelectedModel] = useState('Select a Model'); const [modelName, setModelName] = useState(''); const [modelDescription, setModelDescription] = useState(''); @@ -14,9 +14,12 @@ export default function ModelForm({internalId, getModels, sendModelData}){ const [modelEndpoint, setModelEndpoint] = useState(''); const [modelDropdown, setModelDropdown] = useState(false); const [modelVersion, setModelVersion] = useState(''); + const [modelContextLength, setContextLength] = useState(4096); const [tokenError, setTokenError] = useState(false); const [lockAddition, setLockAddition] = useState(true); const [isLoading, setIsLoading] = useState(false) + const [modelStatus, setModelStatus] = useState(null); + const [createClickable, setCreateClickable] = useState(true); const modelRef = useRef(null); useEffect(() => { @@ -79,13 +82,31 @@ export default function ModelForm({internalId, getModels, sendModelData}){ }) } + const handleModelStatus = async () => { + try { + setCreateClickable(false); + const response = await testModel(); + if(response.status === 200) { + setModelStatus(true); + setCreateClickable(true); + } else { + setModelStatus(false); + setCreateClickable(true); + } + } catch(error) { + console.log("Error Message:: " + error); + setModelStatus(false); + setCreateClickable(true); + } + } + const handleModelSuccess = (model) => { model.contentType = 'Model' sendModelData(model) } const storeModelDetails = (modelProviderId) => { - storeModel(modelName,modelDescription, modelEndpoint, modelProviderId, modelTokenLimit, "Custom", modelVersion).then((response) =>{ + storeModel(modelName,modelDescription, modelEndpoint, modelProviderId, modelTokenLimit, "Custom", modelVersion, modelContextLength).then((response) =>{ setIsLoading(false) let data = response.data if (data.error) { @@ -153,18 +174,42 @@ export default function ModelForm({internalId, getModels, sendModelData}){ onChange={(event) => setModelVersion(event.target.value)}/>
} + {(selectedModel === 'Local LLM') &&
+ Model Context Length + setContextLength(event.target.value)}/> +
} +
Token Limit setModelTokenLimit(parseInt(event.target.value, 10))}/>
-
- - + {modelStatus===false &&
+ error-icon +
+ Test model failed +
+
} + + {modelStatus===true &&
+ +
+ Test model successful +
+
} + +
+ {selectedModel==='Local LLM' && } +
+ + +
diff --git a/gui/pages/Dashboard/Content.js b/gui/pages/Dashboard/Content.js index 0611a7be0..5ad7740b6 100644 --- a/gui/pages/Dashboard/Content.js +++ b/gui/pages/Dashboard/Content.js @@ -470,7 +470,7 @@ export default function Content({env, selectedView, selectedProjectId, organisat organisationId={organisationId} sendKnowledgeData={addTab} sendAgentData={addTab} selectedProjectId={selectedProjectId} editAgentId={tab.id} fetchAgents={getAgentList} toolkits={toolkits} template={null} edit={true} agents={agents}/>} - {tab.contentType === 'Add_Model' && } + {tab.contentType === 'Add_Model' && } {tab.contentType === 'Model' && }
}
diff --git a/gui/pages/_app.css b/gui/pages/_app.css index 88e8d2dd3..973582b49 100644 --- a/gui/pages/_app.css +++ b/gui/pages/_app.css @@ -231,18 +231,6 @@ input[type="range"]::-moz-range-track { z-index: 10; } -.dropdown_container_models { - flex-direction: column; - align-items: flex-start; - border-radius: 8px; - background: #2E293F; - box-shadow: -2px 2px 24px rgba(0, 0, 0, 0.4); - position: absolute; - width: fit-content; - height: fit-content; - padding: 8px; -} - .dropdown_container { width: 150px; height: auto; @@ -783,7 +771,6 @@ p { .mt_74{margin-top: 74px;} .mt_80{margin-top: 80px;} .mt_90{margin-top: 90px;} -.mt_130{margin-top: 130px;} .mb_1{margin-bottom: 1px;} .mb_2{margin-bottom: 2px;} @@ -991,22 +978,6 @@ p { line-height: normal; } -.text_20 { - color: #FFF; - font-size: 20px; - font-style: normal; - font-weight: 400; - line-height: normal; -} - -.text_20 { - color: #FFF; - font-size: 20px; - font-style: normal; - font-weight: 400; - line-height: normal; -} - .text_20_bold{ color: #FFF; font-size: 20px; @@ -1107,7 +1078,6 @@ p { .w_73{width: 73%} .w_97{width: 97%} .w_100{width: 100%} -.w_99vw{width: 99vw} .w_inherit{width: inherit} .w_fit_content{width:fit-content} .w_inherit{width: inherit} @@ -1125,11 +1095,11 @@ p { .h_80vh{height: 80vh} .h_calc92{height: calc(100vh - 92px)} .h_calc_add40{height: calc(80vh + 40px)} -.h_calc_sub_60{height: calc(92.5vh - 60px)} .mxh_78vh{max-height: 78vh} .flex_dir_col{flex-direction: column} +.flex_none{flex: none} .justify_center{justify-content: center} .justify_end{justify-content: flex-end} @@ -1138,8 +1108,6 @@ p { .display_flex{display: inline-flex} .display_flex_container{display: flex} -.display_none{display: none} -.display_block{display: block} .align_center{align-items: center} .align_start{align-items: flex-start} @@ -1178,8 +1146,6 @@ p { .bt_white{border-top: 1px solid rgba(255, 255, 255, 0.08);} -.bt_white{border-top: 1px solid rgba(255, 255, 255, 0.08);} - .color_white{color:#FFFFFF} .color_gray{color:#888888} @@ -1188,7 +1154,7 @@ p { .lh_18{line-height: 18px;} .lh_24{line-height: 24px;} -.padding_0{padding: 0} +.padding_0{padding: 0;} .padding_5{padding: 5px;} .padding_6{padding: 6px;} .padding_8{padding: 8px;} @@ -1505,7 +1471,6 @@ tr{ .bg_none{background: none;} .bg_primary{background: #2E293F;} .bg_secondary{background: #272335;} -.bg_none{background: none} .container { height: 100%; @@ -1871,6 +1836,13 @@ tr{ padding: 12px; } +.success_box{ + border-radius: 8px; + padding: 12px; + border-left: 4px solid rgba(255, 255, 255, 0.60); + background: rgba(255, 255, 255, 0.08); +} + .horizontal_line { margin: 16px 0 16px -16px; border: 1px solid #ffffff20; @@ -1922,26 +1894,4 @@ tr{ .tooltip-class { background-color: green; border-radius: 6px; -} - -.text_dropdown { - color: #FFFFFF; - font-family: Plus Jakarta Sans, sans-serif; - font-style: normal; - font-weight: 500; - line-height: normal; -} - -.text_dropdown_18 { - font-size: 18px; -} - -.vertical_divider { - background: transparent; - /*border-color: rgba(255, 255, 255, 0.08);*/ - border: 1.2px solid rgba(255, 255, 255, 0.08);; - height: 20px; - width: 0; -} - - +} \ No newline at end of file diff --git a/gui/pages/api/DashboardService.js b/gui/pages/api/DashboardService.js index 38e13c698..2e5f93869 100644 --- a/gui/pages/api/DashboardService.js +++ b/gui/pages/api/DashboardService.js @@ -358,8 +358,12 @@ export const verifyEndPoint = (model_api_key, end_point, model_provider) => { }); } -export const storeModel = (model_name, description, end_point, model_provider_id, token_limit, type, version) => { - return api.post(`/models_controller/store_model`,{model_name, description, end_point, model_provider_id, token_limit, type, version}); +export const storeModel = (model_name, description, end_point, model_provider_id, token_limit, type, version, context_length) => { + return api.post(`/models_controller/store_model`,{model_name, description, end_point, model_provider_id, token_limit, type, version, context_length}); +} + +export const testModel = () => { + return api.get(`/models_controller/test_local_llm`); } export const fetchModels = () => { @@ -389,7 +393,6 @@ export const getToolLogs = (toolName) => { export const publishTemplateToMarketplace = (agentData) => { return api.post(`/agent_templates/publish_template`, agentData); }; - export const getKnowledgeMetrics = (knowledgeName) => { return api.get(`analytics/knowledge/${knowledgeName}/usage`) } diff --git a/superagi/controllers/models_controller.py b/superagi/controllers/models_controller.py index 2521f9abd..40f202461 100644 --- a/superagi/controllers/models_controller.py +++ b/superagi/controllers/models_controller.py @@ -2,6 +2,7 @@ from superagi.helper.auth import check_auth, get_user_organisation from superagi.helper.models_helper import ModelsHelper from superagi.apm.call_log_helper import CallLogHelper +from superagi.lib.logger import logger from superagi.models.models import Models from superagi.models.models_config import ModelsConfig from superagi.config.config import get_config @@ -9,6 +10,7 @@ from fastapi_sqlalchemy import db import logging from pydantic import BaseModel +from superagi.helper.llm_loader import LLMLoader router = APIRouter() @@ -26,6 +28,7 @@ class StoreModelRequest(BaseModel): token_limit: int type: str version: str + context_length: int class ModelName (BaseModel): model: str @@ -69,7 +72,9 @@ async def verify_end_point(model_api_key: str = None, end_point: str = None, mod @router.post("/store_model", status_code=200) async def store_model(request: StoreModelRequest, organisation=Depends(get_user_organisation)): try: - return Models.store_model_details(db.session, organisation.id, request.model_name, request.description, request.end_point, request.model_provider_id, request.token_limit, request.type, request.version) + #context_length = 4096 + logger.info(request) + return Models.store_model_details(db.session, organisation.id, request.model_name, request.description, request.end_point, request.model_provider_id, request.token_limit, request.type, request.version, request.context_length) except Exception as e: logging.error(f"Error storing the Model Details: {str(e)}") raise HTTPException(status_code=500, detail="Internal Server Error") @@ -164,4 +169,32 @@ def get_models_details(page: int = 0): marketplace_models = Models.fetch_marketplace_list(page) marketplace_models_with_install = Models.get_model_install_details(db.session, marketplace_models, organisation_id, ModelsTypes.MARKETPLACE.value) - return marketplace_models_with_install \ No newline at end of file + return marketplace_models_with_install + +@router.get("/test_local_llm", status_code=200) +def test_local_llm(): + try: + llm_loader = LLMLoader(context_length=4096) + llm_model = llm_loader.model + llm_grammar = llm_loader.grammar + if llm_model is None: + logger.error("Model not found.") + raise HTTPException(status_code=404, detail="Error while loading the model. Please check your model path and try again.") + if llm_grammar is None: + logger.error("Grammar not found.") + raise HTTPException(status_code=404, detail="Grammar not found.") + + messages = [ + {"role":"system", + "content":"You are an AI assistant. Give response in a proper JSON format"}, + {"role":"user", + "content":"Hi!"} + ] + response = llm_model.create_chat_completion(messages=messages, grammar=llm_grammar) + content = response["choices"][0]["message"]["content"] + logger.info(content) + return "Model loaded successfully." + + except Exception as e: + logger.info("Error: ",e) + raise HTTPException(status_code=404, detail="Error while loading the model. Please check your model path and try again.") \ No newline at end of file diff --git a/superagi/helper/llm_loader.py b/superagi/helper/llm_loader.py index 8c2b19e45..8d78337da 100644 --- a/superagi/helper/llm_loader.py +++ b/superagi/helper/llm_loader.py @@ -35,4 +35,4 @@ def grammar(self): "superagi/llms/grammar/json.gbnf") except Exception as e: logger.error(e) - return self._grammar + return self._grammar \ No newline at end of file diff --git a/superagi/jobs/agent_executor.py b/superagi/jobs/agent_executor.py index e47486a05..45e91ee4d 100644 --- a/superagi/jobs/agent_executor.py +++ b/superagi/jobs/agent_executor.py @@ -13,7 +13,6 @@ from superagi.lib.logger import logger from superagi.llms.google_palm import GooglePalm from superagi.llms.hugging_face import HuggingFace -from superagi.llms.replicate import Replicate from superagi.llms.llm_model_factory import get_model from superagi.llms.replicate import Replicate from superagi.models.agent import Agent @@ -28,8 +27,6 @@ from superagi.worker import execute_agent from superagi.agent.types.agent_workflow_step_action_types import AgentWorkflowStepAction from superagi.agent.types.agent_execution_status import AgentExecutionStatus -from superagi.vector_store.redis import Redis -from superagi.config.config import get_config # from superagi.helper.tool_helper import get_tool_config_by_key @@ -139,6 +136,8 @@ def get_embedding(cls, model_source, model_api_key): return HuggingFace(api_key=model_api_key) if "Replicate" in model_source: return Replicate(api_key=model_api_key) + if "Custom" in model_source: + return LocalLLM() return None def _check_for_max_iterations(self, session, organisation_id, agent_config, agent_execution_id): @@ -184,4 +183,4 @@ def execute_waiting_workflows(self): AgentWaitStepHandler(session=session, agent_id=agent_execution.agent_id, agent_execution_id=agent_execution.id).handle_next_step() execute_agent.delay(agent_execution.id, datetime.now()) - session.close() + session.close() \ No newline at end of file diff --git a/superagi/llms/llm_model_factory.py b/superagi/llms/llm_model_factory.py index af6cfedf6..345c4f8c7 100644 --- a/superagi/llms/llm_model_factory.py +++ b/superagi/llms/llm_model_factory.py @@ -34,6 +34,9 @@ def get_model(organisation_id, api_key, model="gpt-3.5-turbo", **kwargs): elif provider_name == 'Hugging Face': print("Provider is Hugging Face") return HuggingFace(model=model_instance.model_name, end_point=model_instance.end_point, api_key=api_key, **kwargs) + elif provider_name == 'Local LLM': + print("Provider is Local LLM") + return LocalLLM(model=model_instance.model_name, context_length=model_instance.context_length) else: print('Unknown provider.') @@ -46,5 +49,7 @@ def build_model_with_api_key(provider_name, api_key): return GooglePalm(api_key=api_key) elif provider_name.lower() == 'hugging face': return HuggingFace(api_key=api_key) + elif provider_name.lower() == 'local llm': + return LocalLLM(api_key=api_key) else: print('Unknown provider.') \ No newline at end of file diff --git a/superagi/llms/local_llm.py b/superagi/llms/local_llm.py index 608afa289..a146d7daa 100644 --- a/superagi/llms/local_llm.py +++ b/superagi/llms/local_llm.py @@ -89,4 +89,4 @@ def get_models(self): return self.model def verify_access_key(self, api_key): - return True + return True \ No newline at end of file diff --git a/superagi/models/models.py b/superagi/models/models.py index 5a58b74d6..0474f5fa2 100644 --- a/superagi/models/models.py +++ b/superagi/models/models.py @@ -1,3 +1,4 @@ +import yaml from sqlalchemy import Column, Integer, String, and_ from sqlalchemy.sql import func from typing import List, Dict, Union @@ -5,6 +6,7 @@ from superagi.controllers.types.models_types import ModelsTypes from superagi.helper.encyption_helper import decrypt_data import requests, logging +from superagi.lib.logger import logger marketplace_url = "https://app.superagi.com/api" # marketplace_url = "http://localhost:8001" @@ -39,6 +41,7 @@ class Models(DBBaseModel): version = Column(String, nullable=False) org_id = Column(Integer, nullable=False) model_features = Column(String, nullable=False) + context_length = Column(Integer, nullable=True) def __repr__(self): """ @@ -103,7 +106,7 @@ def fetch_model_tokens(cls, session, organisation_id) -> Dict[str, int]: return {"error": "Unexpected Error Occured"} @classmethod - def store_model_details(cls, session, organisation_id, model_name, description, end_point, model_provider_id, token_limit, type, version): + def store_model_details(cls, session, organisation_id, model_name, description, end_point, model_provider_id, token_limit, type, version, context_length): from superagi.models.models_config import ModelsConfig if not model_name: return {"error": "Model Name is empty or undefined"} @@ -129,9 +132,12 @@ def store_model_details(cls, session, organisation_id, model_name, description, return model # Return error message if model not found # Check the 'provider' from ModelsConfig table - if not end_point and model["provider"] not in ['OpenAI', 'Google Palm', 'Replicate']: + if not end_point and model["provider"] not in ['OpenAI', 'Google Palm', 'Replicate','Local LLM']: return {"error": "End Point is empty or undefined"} + if context_length is None: + context_length = 0 + try: model = Models( model_name=model_name, @@ -142,7 +148,8 @@ def store_model_details(cls, session, organisation_id, model_name, description, type=type, version=version, org_id=organisation_id, - model_features='' + model_features='', + context_length=context_length ) session.add(model) session.commit() @@ -228,4 +235,4 @@ def fetch_model_details(cls, session, organisation_id, model_id: int) -> Dict[st except Exception as e: logging.error(f"Unexpected Error Occured: {e}") - return {"error": "Unexpected Error Occured"} + return {"error": "Unexpected Error Occured"} \ No newline at end of file diff --git a/superagi/models/models_config.py b/superagi/models/models_config.py index 998e8170c..7d577d331 100644 --- a/superagi/models/models_config.py +++ b/superagi/models/models_config.py @@ -1,4 +1,5 @@ from sqlalchemy import Column, Integer, String, and_, distinct +from superagi.lib.logger import logger from superagi.models.base_model import DBBaseModel from superagi.models.organisation import Organisation from superagi.models.project import Project @@ -69,6 +70,9 @@ def fetch_value_by_agent_id(cls, session, agent_id: int, model: str): if not config: return None + if config.provider == 'Local LLM': + return {"provider": config.provider, "api_key": config.api_key} if config else None + return {"provider": config.provider, "api_key": decrypt_data(config.api_key)} if config else None @classmethod @@ -123,8 +127,13 @@ def fetch_api_key(cls, session, organisation_id, model_provider): api_key_data = session.query(ModelsConfig.id, ModelsConfig.provider, ModelsConfig.api_key).filter( and_(ModelsConfig.org_id == organisation_id, ModelsConfig.provider == model_provider)).first() + logger.info(api_key_data) if api_key_data is None: return [] + elif api_key_data.provider == 'Local LLM': + api_key = [{'id': api_key_data.id, 'provider': api_key_data.provider, + 'api_key': api_key_data.api_key}] + return api_key else: api_key = [{'id': api_key_data.id, 'provider': api_key_data.provider, 'api_key': decrypt_data(api_key_data.api_key)}] diff --git a/superagi/types/model_source_types.py b/superagi/types/model_source_types.py index f811a60c6..6e9de18ad 100644 --- a/superagi/types/model_source_types.py +++ b/superagi/types/model_source_types.py @@ -6,6 +6,7 @@ class ModelSourceType(Enum): OpenAI = 'OpenAi' Replicate = 'Replicate' HuggingFace = 'Hugging Face' + LocalLLM = 'Local LLM' @classmethod def get_model_source_type(cls, name): diff --git a/tests/unit_tests/controllers/test_models_controller.py b/tests/unit_tests/controllers/test_models_controller.py index 489cff636..790229789 100644 --- a/tests/unit_tests/controllers/test_models_controller.py +++ b/tests/unit_tests/controllers/test_models_controller.py @@ -2,6 +2,11 @@ import pytest from fastapi.testclient import TestClient from main import app +from llama_cpp import Llama +from llama_cpp import LlamaGrammar +import llama_cpp + +from superagi.helper.llm_loader import LLMLoader client = TestClient(app) @@ -50,7 +55,8 @@ def test_store_model_success(mock_get_db): "model_provider_id": 1, "token_limit": 10, "type": "mock_type", - "version": "mock_version" + "version": "mock_version", + "context_length":4096 } with patch('superagi.helper.auth.get_user_organisation') as mock_get_user_org, \ patch('superagi.helper.auth.db') as mock_auth_db: @@ -100,3 +106,13 @@ def test_get_marketplace_models_list_success(mock_get_db): patch('superagi.helper.auth.db') as mock_auth_db: response = client.get("/models_controller/marketplace/list/0") assert response.status_code == 200 + +def test_get_local_llm(): + with(patch.object(LLMLoader, 'model', new_callable=MagicMock)) as mock_model: + with(patch.object(LLMLoader, 'grammar', new_callable=MagicMock)) as mock_grammar: + + mock_model.create_chat_completion.return_value = {"choices": [{"message": {"content": "Hello!"}}]} + + response = client.get("/models_controller/test_local_llm") + + assert response.status_code == 200 \ No newline at end of file diff --git a/tests/unit_tests/models/test_models.py b/tests/unit_tests/models/test_models.py index 3bdc43075..d4880538c 100644 --- a/tests/unit_tests/models/test_models.py +++ b/tests/unit_tests/models/test_models.py @@ -133,6 +133,7 @@ def test_store_model_details_when_model_exists(mock_session): token_limit=500, type="type", version="v1.0", + context_length=4096 ) # Assert @@ -161,6 +162,7 @@ def test_store_model_details_when_model_not_exists(mock_session, monkeypatch): token_limit=500, type="type", version="v1.0", + context_length=4096 ) # Assert @@ -187,6 +189,7 @@ def test_store_model_details_when_unexpected_error_occurs(mock_session, monkeypa token_limit=500, type="type", version="v1.0", + context_length=4096 ) # Assert @@ -229,6 +232,4 @@ def test_fetch_model_details(mock_models_config, mock_session): "token_limit": 100, "type": "type1", "model_provider": "example_provider" - } - - + } \ No newline at end of file