Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Configurable Default Engines via Environment Variables #265

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ Install the [PyPI](https://pypi.org/project/stability-sdk/) package via:
- `pyenv/bin/activate` to use the venv.
- Set the `STABILITY_HOST` environment variable. This is by default set to the production endpoint `grpc.stability.ai:443`.
- Set the `STABILITY_KEY` environment variable.
- Optional, set the `DEFAULT_ENGINE` environment variable. This is by default set to `stable-diffusion-xl-1024-v1-0`.
- Optional, set the `DEFAULT_UPSCALE_ENGINE` environment variable. This is by default set to `esrgan-v1-x2plus`.

Then to invoke:

Expand Down
22 changes: 13 additions & 9 deletions src/stability_sdk/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,8 @@ def __init__(
self,
host: str = "grpc.stability.ai:443",
key: str = "",
engine: str = "stable-diffusion-xl-1024-v1-0",
upscale_engine: str = "esrgan-v1-x2plus",
engine: str = None,
upscale_engine: str = None,
verbose: bool = False,
wait_for_ready: bool = True,
):
Expand All @@ -105,16 +105,20 @@ def __init__(

:param host: Host to connect to.
:param key: Key to use for authentication.
:param engine: Engine to use.
:param upscale_engine: Upscale engine to use.
:param engine: Engine to use. Defaults to the value from the environment
variable DEFAULT_ENGINE, or "stable-diffusion-xl-1024-v1-0" if the
variable is not set.
:param upscale_engine: Upscale engine to use. Defaults to the value from the
environment variable UPSCALE_ENGINE, or "esrgan-v1-x2plus" if the variable
is not set.
:param verbose: Whether to print debug messages.
:param wait_for_ready: Whether to wait for the server to be ready, or
to fail immediately.
"""
self.verbose = verbose
self.engine = engine
self.upscale_engine = upscale_engine

self.engine = engine or os.getenv("DEFAULT_ENGINE", "stable-diffusion-xl-1024-v1-0")
self.upscale_engine = upscale_engine or os.getenv("DEFAULT_UPSCALE_ENGINE", "esrgan-v1-x2plus")
self.grpc_args = {"wait_for_ready": wait_for_ready}
if verbose:
logger.info(f"Opening channel to {host}")
Expand Down Expand Up @@ -500,7 +504,7 @@ def process_cli(
"-e",
type=str,
help="engine to use for upscale",
default="esrgan-v1-x2plus",
default=None,
)
parser_upscale.add_argument(
"prompt", nargs="*"
Expand Down Expand Up @@ -573,7 +577,7 @@ def process_cli(
"-e",
type=str,
help="engine to use for inference",
default="stable-diffusion-xl-1024-v1-0",
default=None,
)
parser_generate.add_argument(
"--init_image",
Expand Down