Skip to content

Commit

Permalink
Merge pull request #64 from marclp-es/add_imagegen_example
Browse files Browse the repository at this point in the history
  • Loading branch information
tjbck authored Jun 5, 2024
2 parents ac14bf6 + dbb6800 commit 53eee3d
Showing 1 changed file with 86 additions and 0 deletions.
86 changes: 86 additions & 0 deletions examples/pipelines/providers/openai_dalle_manifold_pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
"""A manifold to integrate OpenAI's ImageGen models into Open-WebUI"""

from typing import List, Union, Generator, Iterator

from pydantic import BaseModel

from openai import OpenAI

class Pipeline:
"""OpenAI ImageGen pipeline"""

class Valves(BaseModel):
"""Options to change from the WebUI"""

OPENAI_API_BASE_URL: str = "https://api.openai.com/v1"
OPENAI_API_KEY: str = ""
IMAGE_SIZE: str = "1024x1024"
NUM_IMAGES: int = 1

def __init__(self):
self.type = "manifold"
self.name = "ImageGen: "

self.valves = self.Valves()
self.client = OpenAI(
base_url=self.valves.OPENAI_API_BASE_URL,
api_key=self.valves.OPENAI_API_KEY,
)

self.pipelines = self.get_openai_assistants()

async def on_startup(self) -> None:
"""This function is called when the server is started."""
print(f"on_startup:{__name__}")

async def on_shutdown(self):
"""This function is called when the server is stopped."""
print(f"on_shutdown:{__name__}")

async def on_valves_updated(self):
"""This function is called when the valves are updated."""
print(f"on_valves_updated:{__name__}")
self.client = OpenAI(
base_url=self.valves.OPENAI_API_BASE_URL,
api_key=self.valves.OPENAI_API_KEY,
)
self.pipelines = self.get_openai_assistants()

def get_openai_assistants(self) -> List[dict]:
"""Get the available ImageGen models from OpenAI
Returns:
List[dict]: The list of ImageGen models
"""

if self.valves.OPENAI_API_KEY:
models = self.client.models.list()
return [
{
"id": model.id,
"name": model.id,
}
for model in models
if "dall-e" in model.id
]

return []

def pipe(
self, user_message: str, model_id: str, messages: List[dict], body: dict
) -> Union[str, Generator, Iterator]:
print(f"pipe:{__name__}")

response = self.client.images.generate(
model=model_id,
prompt=user_message,
size=self.valves.IMAGE_SIZE,
n=self.valves.NUM_IMAGES,
)

message = ""
for image in response.data:
if image.url:
message += "![image](" + image.url + ")\n"

yield message

0 comments on commit 53eee3d

Please sign in to comment.