Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: implement CustomEndpoint #321

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion runpod/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
get_credentials,
set_credentials,
)
from .endpoint import AsyncioEndpoint, AsyncioJob, Endpoint
from .endpoint import AsyncioEndpoint, AsyncioJob, Endpoint, CustomEndpoint
from .serverless.modules.rp_logger import RunPodLogger
from .version import __version__

Expand Down
2 changes: 1 addition & 1 deletion runpod/endpoint/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
''' Allows endpoints to be imported as a module. '''

from .runner import Endpoint, Job
from .runner import Endpoint, Job, CustomEndpoint
from .asyncio.asyncio_runner import Endpoint as AsyncioEndpoint
from .asyncio.asyncio_runner import Job as AsyncioJob
71 changes: 61 additions & 10 deletions runpod/endpoint/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
class RunPodClient:
"""A client for running endpoint calls."""

def __init__(self):
def __init__(self, custom_endpoint_url_base: Optional[str] = None):
"""
Initialize a RunPodClient instance.

Expand All @@ -39,6 +39,8 @@ def __init__(self):
}

self.endpoint_url_base = endpoint_url_base
if custom_endpoint_url_base:
self.endpoint_url_base = custom_endpoint_url_base

def _request(self,
method: str, endpoint: str, data: Optional[dict] = None, timeout: int = 10):
Expand Down Expand Up @@ -101,8 +103,8 @@ def __init__(self, endpoint_id: str, job_id: str, client: RunPodClient):

def _fetch_job(self, source: str = "status") -> Dict[str, Any]:
""" Returns the raw json of the status, raises an exception if invalid """
status_url = f"{self.endpoint_id}/{source}/{self.job_id}"
job_state = self.rp_client.get(endpoint=status_url)
status_url = self._get_endpoint_url_for_job_method(source)
job_state = self._get_job_state(status_url)

if is_completed(job_state["status"]):
self.job_status = job_state["status"]
Expand Down Expand Up @@ -154,8 +156,23 @@ def cancel(self, timeout: int = 3) -> Any:
Args:
timeout: The number of seconds to wait for the server to respond before giving up.
"""
return self.rp_client.post(f"{self.endpoint_id}/cancel/{self.job_id}",
data=None, timeout=timeout)
return self.rp_client.post(
self._get_endpoint_url_for_job_method("cancel"),
data=None,
timeout=timeout,
)

def _get_endpoint_url_for_job_method(self, method: str) -> str:
""" Returns the endpoint URL for the given method. """
if not self.endpoint_id:
return f"{method}/{self.job_id}"
return f"{self.endpoint_id}/{method}/{self.job_id}"

def _get_job_state(self, status_url: str) -> Dict[str, Any]:
""" Returns the state of the job. """
if not self.endpoint_id: # this is due to a bug in the local server api
return self.rp_client.post(endpoint=status_url, data=None, timeout=60)
return self.rp_client.get(endpoint=status_url)


# ---------------------------------------------------------------------------- #
Expand All @@ -180,7 +197,7 @@ def __init__(self, endpoint_id: str):
self.endpoint_id = endpoint_id
self.rp_client = RunPodClient()

def run(self, request_input: Dict[str, Any]) -> Job:
def run(self, request_input: Dict[str, Any], timeout: int = 10) -> Job:
"""
Run the endpoint with the given input.

Expand All @@ -193,7 +210,11 @@ def run(self, request_input: Dict[str, Any]) -> Job:
if not request_input.get("input"):
request_input = {"input": request_input}

job_request = self.rp_client.post(f"{self.endpoint_id}/run", request_input)
job_request = self.rp_client.post(
self._get_endpoint_url_for_method("run"),
request_input,
timeout=timeout,
)
return Job(self.endpoint_id, job_request["id"], self.rp_client)

def run_sync(self, request_input: Dict[str, Any], timeout: int = 86400) -> Dict[str, Any]:
Expand All @@ -207,7 +228,10 @@ def run_sync(self, request_input: Dict[str, Any], timeout: int = 86400) -> Dict[
request_input = {"input": request_input}

job_request = self.rp_client.post(
f"{self.endpoint_id}/runsync", request_input, timeout=timeout)
self._get_endpoint_url_for_method("runsync"),
request_input,
timeout=timeout,
)

if job_request["status"] in FINAL_STATES:
return job_request.get("output", None)
Expand All @@ -221,7 +245,8 @@ def health(self, timeout: int = 3) -> Dict[str, Any]:
Args:
timeout: The number of seconds to wait for the server to respond before giving up.
"""
return self.rp_client.get(f"{self.endpoint_id}/health", timeout=timeout)
return self.rp_client.get(
self._get_endpoint_url_for_method("health"), timeout=timeout)

def purge_queue(self, timeout: int = 3) -> Dict[str, Any]:
"""
Expand All @@ -230,4 +255,30 @@ def purge_queue(self, timeout: int = 3) -> Dict[str, Any]:
Args:
timeout: The number of seconds to wait for the server to respond before giving up.
"""
return self.rp_client.post(f"{self.endpoint_id}/purge-queue", data=None, timeout=timeout)
return self.rp_client.post(
self._get_endpoint_url_for_method("purge-queue"),
data=None,
timeout=timeout,
)

def _get_endpoint_url_for_method(self, method: str) -> str:
""" Returns the endpoint URL for the given method. """
if method not in ["run", "runsync", "health", "purge-queue"]:
raise ValueError(f"Method '{method}' is not supported.")

if not self.endpoint_id:
return f"{method}"
return f"{self.endpoint_id}/{method}"


class CustomEndpoint(Endpoint):
def __init__(self, custom_endpoint_url_base: str):
"""
Initialize an Endpoint instance with the given endpoint base URL.
Intended for usage with regular Pods or test servers.

Args:
custom_endpoint_url_base: The custom endpoint URL base.
"""
self.endpoint_id = None
self.rp_client = RunPodClient(custom_endpoint_url_base)