Commit

Example for hosting vLLM

gongy committed Aug 4, 2023
1 parent f5ea162 commit de3ff1b

Showing 2 changed files with 365 additions and 0 deletions.
215 changes: 215 additions & 0 deletions 06_gpu_and_ml/vllm-hosted/index.html
@@ -0,0 +1,215 @@
<html>
<head>
<script
defer
src="https://cdn.jsdelivr.net/npm/alpinejs@3.x.x/dist/cdn.min.js"
></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/marked/5.1.2/marked.min.js"></script>
<link href="https://unpkg.com/@tailwindcss/typography@0.4.1/dist/typography.min.css" rel="stylesheet">
<script src="https://cdn.tailwindcss.com"></script>

<!-- <script src="https://cdnjs.cloudflare.com/ajax/libs/dompurify/3.0.5/purify.min.js"></script>
item.markdownCompletion = DOMPurify.sanitize(
marked.parse(item.completion, {mangle: false, headerIds: false})
); -->

<meta name="viewport" content="width=device-width, initial-scale=1" />
<title>Modal vLLM Engine</title>
</head>
<body>
<section x-data="state()" class="max-w-2xl mx-auto pt-16 px-4">
<div class="text-xs font-semibold tracking-wide uppercase text-center text-white">
<a
href="https://modal.com/"
class="inline-flex gap-x-1 items-center bg-lime-400 py-0.5 px-3 rounded-full hover:text-lime-400 hover:ring hover:ring-lime-400 hover:bg-white focus:outline-neutral-400"
target="_blank"
>
powered by Modal
<svg
xmlns="http://www.w3.org/2000/svg"
class="w-4 h-4 animate-pulse -mr-1"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="2"
stroke-linecap="round"
stroke-linejoin="round"
class="lucide lucide-chevrons-right"
>
<path d="m6 17 5-5-5-5" />
<path d="m13 17 5-5-5-5" />
</svg>
</a>
</div>
<div class="text-4xl mt-4 mb-4 font-semibold tracking-tighter text-center">
LLaMA 2
</div>

<div class="flex flex-wrap justify-center items-center mt-8 mb-6">
<div
x-init="setInterval(() => refreshInfo(), 1000)"
class="inline-flex justify-center items-center gap-x-4 text-sm text-white px-3 py-1 bg-neutral-600 rounded-full"
>
<div x-show="!info.loaded" class="flex items-center gap-x-1">
<div class="animate-spin w-4 h-4">
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="2"
stroke-linecap="round"
stroke-linejoin="round"
>
<path d="M21 12a9 9 0 1 1-6.219-8.56" />
</svg>
</div>
<span>loading stats</span>
</div>
<div x-show="info.loaded && info.backlog > 0">
<span x-text="info.backlog"></span>
inputs in queue
</div>
<div x-show="info.loaded && (info.num_active_runners > 0 || info.backlog === 0)">
<span x-text="info.num_total_runners"></span>
<span x-text="info.num_total_runners === 1 ? 'GPU online' : 'GPUs online'"></span>
</div>
<div x-show="info.loaded && info.num_active_runners > 0 && info.tps !== undefined">
<span x-show="info.tps !== undefined" x-text="info.tps.toFixed(2)"></span>
tokens/s
</div>
<div
class="flex items-center gap-x-1"
x-show="info.num_active_runners == 0 && info.backlog > 0"
>
<div class="animate-spin w-4 h-4">
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="2"
stroke-linecap="round"
stroke-linejoin="round"
>
<path d="M21 12a9 9 0 1 1-6.219-8.56" />
</svg>
</div>
<span> GPU cold-starting </span>
</div>
</div>
</div>

<form class="relative">
<input
x-model="nextPrompt"
type="text"
placeholder="Ask something ..."
class="flex grow w-full h-10 pl-4 pr-12 py-2 text-md bg-white border rounded-3xl border-neutral-300 ring-offset-background placeholder:text-neutral-500 focus:border-neutral-300 focus:outline-none focus:ring-2 focus:ring-offset-2 focus:ring-neutral-400 disabled:cursor-not-allowed disabled:opacity-50"
@keydown.window.prevent.ctrl.k="$el.focus()"
@keydown.window.prevent.cmd.k="$el.focus()"
autofocus
/>
<div class="absolute top-0 right-0 flex items-center h-full pr-[0.3125rem]">
<button
@click.prevent="callApi()"
class="rounded-full bg-lime-400 p-2 focus:border-neutral-300 focus:outline-neutral-400"
>
<svg
xmlns="http://www.w3.org/2000/svg"
class="w-4 h-4"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="2"
stroke-linecap="round"
stroke-linejoin="round"
class="lucide lucide-plus"
>
<path d="M5 12h14" />
<path d="M12 5v14" />
</svg>
</button>
</div>
</form>

<div class="flex flex-col gap-y-4 my-8">
<template x-for="(item, index) in [...items].reverse()" :key="index">
<div class="w-full border px-4 py-2 rounded-3xl">
<div
x-data
class="text-sm mt-2 mb-4 whitespace-pre-line"
x-text="item.prompt"
:class="{'animate-pulse': item.loading}"
></div>
<div
x-show="item.completion.length === 0"
class="h-4 w-2 mt-2 mb-4 bg-neutral-500 animate-pulse"
></div>
<div
class="text-sm mt-2 mb-4 text-neutral-500 w-full prose max-w-none prose-neutral-100 leading-6"
x-show="item.completion.length > 0"
x-html="item.markdownCompletion"
>
</div>
</div>
</template>
</div>

<script>
function state() {
return {
nextPrompt: "",
items: [],
info: { backlog: 0, num_total_runners: 0, num_active_runners: 0 },
callApi() {
console.log(this.nextPrompt);
if (!this.nextPrompt) return;

let item = {
id: Math.random(),
prompt: this.nextPrompt,
completion: "",
loading: true,
markdownCompletion: "",
};
this.nextPrompt = "";
this.items.push(item);
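        // Stream the completion from the backend as server-sent events.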
const eventSource = new EventSource(
`/completion/${encodeURIComponent(item.prompt)}`,
);

console.log("Created event source ...");

eventSource.onmessage = (event) => {
item.completion += JSON.parse(event.data).text;
item.markdownCompletion = marked.parse(item.completion, {mangle: false, headerIds: false});
// Hacky way to notify element to update
this.items = this.items.map((i) =>
i.id === item.id ? { ...item } : i,
);
};

eventSource.onerror = (event) => {
eventSource.close();
item.loading = false;
this.items = this.items.map((i) =>
i.id === item.id ? { ...item } : i,
);
console.log(item.completion);
};
},
refreshInfo() {
fetch("/stats")
.then((response) => response.json())
.then((data) => {
this.info = { ...data, loaded: true };
})
.catch((error) => console.log(error));
},
};
}
</script>
</section>
</body>
</html>
150 changes: 150 additions & 0 deletions 06_gpu_and_ml/vllm_hosted.py
@@ -0,0 +1,150 @@
import time
import os
import json

from modal import Stub, Mount, Image, Secret, Dict, asgi_app, method, gpu

from pathlib import Path

MODEL_DIR = "/model"


def download_model_to_folder():
    from huggingface_hub import snapshot_download

    snapshot_download(
        "meta-llama/Llama-2-13b-chat-hf",
        local_dir=MODEL_DIR,
        local_dir_use_symlinks=False,
        token=os.environ["HUGGINGFACE_TOKEN"],
    )

vllm_image = (
Image.from_dockerhub("nvcr.io/nvidia/pytorch:22.12-py3")
.pip_install(
"torch==2.0.1", index_url="https://download.pytorch.org/whl/cu118"
)
# Pinned to 07/21/2023
.pip_install(
"vllm @ git+https://github.com/vllm-project/vllm.git@d7a1c6d614756b3072df3e8b52c0998035fb453f"
)
.run_function(download_model_to_folder, secret=Secret.from_name("huggingface"))
)

stub = Stub("llama-demo")
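# Shared Dict for passing tokens/s measurements from the GPU container to the web server.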
stub.dict = Dict.new()

# The vLLM engine, running on a single A100 and serving up to 60 concurrent inputs.
@stub.cls(
    gpu=gpu.A100(),
    image=vllm_image,
    allow_concurrent_inputs=60,
    concurrency_limit=1,
    container_idle_timeout=600,
)
class Engine:
def __enter__(self):
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine

# tokens generated since last report
self.last_report, self.generated_tokens = time.time(), 0

        engine_args = AsyncEngineArgs(
            model=MODEL_DIR,
            # vLLM only uses 90% of GPU memory by default; allow it 95%.
            gpu_memory_utilization=0.95,
        )

self.engine = AsyncLLMEngine.from_engine_args(engine_args)
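        # Llama 2 chat format: a <<SYS>> system prompt wrapped in an [INST]
        # block, with the user's question substituted at the {} placeholder.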
self.template = """<s>[INST] <<SYS>>
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.
<</SYS>>
{} [/INST] """

def generated(self, n: int):
# Log that n tokens have been generated
t = time.time()
self.generated_tokens += n
# Save to dict every second
if t - self.last_report > 1.0:
stub.app.dict.update(
tps=self.generated_tokens / (t - self.last_report),
t=self.last_report
)
self.last_report, self.generated_tokens = t, 0

@method()
async def completion(self, question: str):
if not question:
return

from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid

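        # Conservative sampling: low temperature plus a presence penalty to
        # discourage repetition; cap responses at 1024 new tokens.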
sampling_params = SamplingParams(
presence_penalty=0.8,
temperature=0.2,
top_p=0.95,
top_k=50,
max_tokens=1024,
)
request_id = random_uuid()
        results_generator = self.engine.generate(
            self.template.format(question), sampling_params, request_id
        )

t0 = time.time()
index, tokens = 0, 0
async for request_output in results_generator:
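            # A trailing Unicode replacement character means the latest token
            # decoded to an incomplete byte sequence; wait for more output.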
if '\ufffd' == request_output.outputs[0].text[-1]:
continue
yield request_output.outputs[0].text[index:]
index = len(request_output.outputs[0].text)

# Token accounting
new_tokens = len(request_output.outputs[0].token_ids)
self.generated(new_tokens - tokens)
tokens = new_tokens

throughput = tokens / (time.time() - t0)
print(f"Request completed: {throughput:.4f} tokens/s")
print(request_output.outputs[0].text)


# Front-end functionality
frontend_path = Path(__file__).parent / "vllm-hosted"

@stub.function(
mounts=[Mount.from_local_dir(frontend_path, remote_path="/assets")],
keep_warm=3,
concurrency_limit=6,
allow_concurrent_inputs=24,
timeout=600,
)
@asgi_app()
def app():
import fastapi
import fastapi.staticfiles
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware

web_app = fastapi.FastAPI()

@web_app.get("/stats")
async def stats():
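        # Backlog and runner counts come from Modal's function stats; tokens/s
        # comes from the shared Dict, zeroed out once the last report is stale.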
stats = Engine().completion.get_current_stats()
try:
tps, t = stub.app.dict.get("tps"), stub.app.dict.get("t")
except KeyError:
tps, t = 0, 0
return {
"backlog": stats.backlog,
"num_active_runners": stats.num_active_runners,
"num_total_runners": stats.num_total_runners,
"tps": tps if t > time.time() - 4.0 else 0,
}

@web_app.get("/completion/{question}")
async def get(question: str):
from urllib.parse import unquote

print("Web server received request for", unquote(question))

# FastAPI will run this in a separate thread
def generate():
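            # Frame each chunk as a server-sent event: "data: <json>\n\n".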
for chunk in Engine().completion.call(unquote(question)):
yield f'data: {json.dumps(dict(text=chunk), ensure_ascii=False)}\n\n'

return StreamingResponse(generate(), media_type="text/event-stream")

web_app.mount("/", fastapi.staticfiles.StaticFiles(directory="/assets", html=True))
return web_app
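
To try the example end to end, a minimal sketch (assuming the Modal CLI of this era; the URL is a placeholder, since Modal assigns the real hostname from your workspace and app name at deploy time):

    modal deploy 06_gpu_and_ml/vllm_hosted.py
    # Open the printed web endpoint in a browser, or stream a completion
    # directly over server-sent events (-N disables curl's buffering):
    curl -N "https://<your-workspace>--llama-demo-app.modal.run/completion/Hello"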
