From bd7b6bf1fd20855f228c3bef9b3ce2686c6fbf13 Mon Sep 17 00:00:00 2001 From: giefferre Date: Thu, 6 Jun 2024 10:56:23 +0200 Subject: [PATCH] feat: implement CustomEndpoint --- runpod/__init__.py | 2 +- runpod/endpoint/__init__.py | 2 +- runpod/endpoint/runner.py | 71 +++++++++++++++++++++++++++++++------ 3 files changed, 63 insertions(+), 12 deletions(-) diff --git a/runpod/__init__.py b/runpod/__init__.py index 5d5d135b..2b481b20 100644 --- a/runpod/__init__.py +++ b/runpod/__init__.py @@ -26,7 +26,7 @@ get_credentials, set_credentials, ) -from .endpoint import AsyncioEndpoint, AsyncioJob, Endpoint +from .endpoint import AsyncioEndpoint, AsyncioJob, Endpoint, CustomEndpoint from .serverless.modules.rp_logger import RunPodLogger from .version import __version__ diff --git a/runpod/endpoint/__init__.py b/runpod/endpoint/__init__.py index c931d5f1..e30870b9 100644 --- a/runpod/endpoint/__init__.py +++ b/runpod/endpoint/__init__.py @@ -1,5 +1,5 @@ ''' Allows endpoints to be imported as a module. ''' -from .runner import Endpoint, Job +from .runner import Endpoint, Job, CustomEndpoint from .asyncio.asyncio_runner import Endpoint as AsyncioEndpoint from .asyncio.asyncio_runner import Job as AsyncioJob diff --git a/runpod/endpoint/runner.py b/runpod/endpoint/runner.py index 34baafff..a07dcd4e 100644 --- a/runpod/endpoint/runner.py +++ b/runpod/endpoint/runner.py @@ -17,7 +17,7 @@ class RunPodClient: """A client for running endpoint calls.""" - def __init__(self): + def __init__(self, custom_endpoint_url_base: Optional[str] = None): """ Initialize a RunPodClient instance. @@ -39,6 +39,8 @@ def __init__(self): } self.endpoint_url_base = endpoint_url_base + if custom_endpoint_url_base: + self.endpoint_url_base = custom_endpoint_url_base def _request(self, method: str, endpoint: str, data: Optional[dict] = None, timeout: int = 10): @@ -101,8 +103,8 @@ def __init__(self, endpoint_id: str, job_id: str, client: RunPodClient): def _fetch_job(self, source: str = "status") -> Dict[str, Any]: """ Returns the raw json of the status, raises an exception if invalid """ - status_url = f"{self.endpoint_id}/{source}/{self.job_id}" - job_state = self.rp_client.get(endpoint=status_url) + status_url = self._get_endpoint_url_for_job_method(source) + job_state = self._get_job_state(status_url) if is_completed(job_state["status"]): self.job_status = job_state["status"] @@ -154,8 +156,23 @@ def cancel(self, timeout: int = 3) -> Any: Args: timeout: The number of seconds to wait for the server to respond before giving up. """ - return self.rp_client.post(f"{self.endpoint_id}/cancel/{self.job_id}", - data=None, timeout=timeout) + return self.rp_client.post( + self._get_endpoint_url_for_job_method("cancel"), + data=None, + timeout=timeout, + ) + + def _get_endpoint_url_for_job_method(self, method: str) -> str: + """ Returns the endpoint URL for the given method. """ + if not self.endpoint_id: + return f"{method}/{self.job_id}" + return f"{self.endpoint_id}/{method}/{self.job_id}" + + def _get_job_state(self, status_url: str) -> Dict[str, Any]: + """ Returns the state of the job. """ + if not self.endpoint_id: # this is due to a bug in the local server api + return self.rp_client.post(endpoint=status_url, data=None, timeout=60) + return self.rp_client.get(endpoint=status_url) # ---------------------------------------------------------------------------- # @@ -180,7 +197,7 @@ def __init__(self, endpoint_id: str): self.endpoint_id = endpoint_id self.rp_client = RunPodClient() - def run(self, request_input: Dict[str, Any]) -> Job: + def run(self, request_input: Dict[str, Any], timeout: int = 10) -> Job: """ Run the endpoint with the given input. @@ -193,7 +210,11 @@ def run(self, request_input: Dict[str, Any]) -> Job: if not request_input.get("input"): request_input = {"input": request_input} - job_request = self.rp_client.post(f"{self.endpoint_id}/run", request_input) + job_request = self.rp_client.post( + self._get_endpoint_url_for_method("run"), + request_input, + timeout=timeout, + ) return Job(self.endpoint_id, job_request["id"], self.rp_client) def run_sync(self, request_input: Dict[str, Any], timeout: int = 86400) -> Dict[str, Any]: @@ -207,7 +228,10 @@ def run_sync(self, request_input: Dict[str, Any], timeout: int = 86400) -> Dict[ request_input = {"input": request_input} job_request = self.rp_client.post( - f"{self.endpoint_id}/runsync", request_input, timeout=timeout) + self._get_endpoint_url_for_method("runsync"), + request_input, + timeout=timeout, + ) if job_request["status"] in FINAL_STATES: return job_request.get("output", None) @@ -221,7 +245,8 @@ def health(self, timeout: int = 3) -> Dict[str, Any]: Args: timeout: The number of seconds to wait for the server to respond before giving up. """ - return self.rp_client.get(f"{self.endpoint_id}/health", timeout=timeout) + return self.rp_client.get( + self._get_endpoint_url_for_method("health"), timeout=timeout) def purge_queue(self, timeout: int = 3) -> Dict[str, Any]: """ @@ -230,4 +255,30 @@ def purge_queue(self, timeout: int = 3) -> Dict[str, Any]: Args: timeout: The number of seconds to wait for the server to respond before giving up. """ - return self.rp_client.post(f"{self.endpoint_id}/purge-queue", data=None, timeout=timeout) + return self.rp_client.post( + self._get_endpoint_url_for_method("purge-queue"), + data=None, + timeout=timeout, + ) + + def _get_endpoint_url_for_method(self, method: str) -> str: + """ Returns the endpoint URL for the given method. """ + if method not in ["run", "runsync", "health", "purge-queue"]: + raise ValueError(f"Method '{method}' is not supported.") + + if not self.endpoint_id: + return f"{method}" + return f"{self.endpoint_id}/{method}" + + +class CustomEndpoint(Endpoint): + def __init__(self, custom_endpoint_url_base: str): + """ + Initialize an Endpoint instance with the given endpoint base URL. + Intended for usage with regular Pods or test servers. + + Args: + custom_endpoint_url_base: The custom endpoint URL base. + """ + self.endpoint_id = None + self.rp_client = RunPodClient(custom_endpoint_url_base)