-
Notifications
You must be signed in to change notification settings - Fork 176
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
365 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,215 @@ | ||
<html> | ||
<head> | ||
<script | ||
defer | ||
src="https://cdn.jsdelivr.net/npm/alpinejs@3.x.x/dist/cdn.min.js" | ||
></script> | ||
<script src="https://cdnjs.cloudflare.com/ajax/libs/marked/5.1.2/marked.min.js"></script> | ||
<link href="https://unpkg.com/@tailwindcss/typography@0.4.1/dist/typography.min.css" rel="stylesheet"> | ||
<script src="https://cdn.tailwindcss.com"></script> | ||
|
||
<!-- <script src="https://cdnjs.cloudflare.com/ajax/libs/dompurify/3.0.5/purify.min.js"></script> | ||
item.markdownCompletion = DOMPurify.sanitize( | ||
marked.parse(item.completion, {mangle: false, headerIds: false}) | ||
); --> | ||
|
||
<meta name="viewport" content="width=device-width, initial-scale=1" /> | ||
<title>Modal vLLM Engine</title> | ||
</head> | ||
<body> | ||
<section x-data="state()" class="max-w-2xl mx-auto pt-16 px-4"> | ||
<div class="text-xs font-semibold tracking-wide uppercase text-center text-white"> | ||
<a | ||
href="https://modal.com/" | ||
class="inline-flex gap-x-1 items-center bg-lime-400 py-0.5 px-3 rounded-full hover:text-lime-400 hover:ring hover:ring-lime-400 hover:bg-white focus:outline-neutral-400" | ||
target="_blank" | ||
> | ||
powered by Modal | ||
<svg | ||
xmlns="http://www.w3.org/2000/svg" | ||
class="w-4 h-4 animate-pulse -mr-1" | ||
viewBox="0 0 24 24" | ||
fill="none" | ||
stroke="currentColor" | ||
stroke-width="2" | ||
stroke-linecap="round" | ||
stroke-linejoin="round" | ||
class="lucide lucide-chevrons-right" | ||
> | ||
<path d="m6 17 5-5-5-5" /> | ||
<path d="m13 17 5-5-5-5" /> | ||
</svg> | ||
</a> | ||
</div> | ||
<div class="text-4xl mt-4 mb-4 font-semibold tracking-tighter text-center"> | ||
LLaMA 2 | ||
</div> | ||
|
||
<div class="flex flex-wrap justify-center items-center mt-8 mb-6"> | ||
<div | ||
x-init="setInterval(() => refreshInfo(), 1000)" | ||
class="inline-flex justify-center items-center gap-x-4 text-sm text-white px-3 py-1 bg-neutral-600 rounded-full" | ||
> | ||
<div x-show="!info.loaded" class="flex items-center gap-x-1"> | ||
<div class="animate-spin w-4 h-4"> | ||
<svg | ||
xmlns="http://www.w3.org/2000/svg" | ||
viewBox="0 0 24 24" | ||
fill="none" | ||
stroke="currentColor" | ||
stroke-width="2" | ||
stroke-linecap="round" | ||
stroke-linejoin="round" | ||
> | ||
<path d="M21 12a9 9 0 1 1-6.219-8.56" /> | ||
</svg> | ||
</div> | ||
<span>loading stats</span> | ||
</div> | ||
<div x-show="info.loaded && info.backlog > 0"> | ||
<span x-text="info.backlog"></span> | ||
inputs in queue | ||
</div> | ||
<div x-show="info.loaded && (info.num_active_runners > 0 || info.backlog === 0)"> | ||
<span x-text="info.num_total_runners"></span> | ||
<span x-text="info.num_total_runners === 1 ? 'GPU online' : 'GPUs online'"></span> | ||
</div> | ||
<div x-show="info.loaded && info.num_active_runners > 0 && info.tps !== undefined"> | ||
<span x-show="info.tps !== undefined" x-text="info.tps.toFixed(2)"></span> | ||
tokens/s | ||
</div> | ||
<div | ||
class="flex items-center gap-x-1" | ||
x-show="info.num_active_runners == 0 && info.backlog > 0" | ||
> | ||
<div class="animate-spin w-4 h-4"> | ||
<svg | ||
xmlns="http://www.w3.org/2000/svg" | ||
viewBox="0 0 24 24" | ||
fill="none" | ||
stroke="currentColor" | ||
stroke-width="2" | ||
stroke-linecap="round" | ||
stroke-linejoin="round" | ||
> | ||
<path d="M21 12a9 9 0 1 1-6.219-8.56" /> | ||
</svg> | ||
</div> | ||
<span> GPU cold-starting </span> | ||
</div> | ||
</div> | ||
</div> | ||
|
||
<form class="relative"> | ||
<input | ||
x-model="nextPrompt" | ||
type="text" | ||
placeholder="Ask something ..." | ||
class="flex grow w-full h-10 pl-4 pr-12 py-2 text-md bg-white border rounded-3xl border-neutral-300 ring-offset-background placeholder:text-neutral-500 focus:border-neutral-300 focus:outline-none focus:ring-2 focus:ring-offset-2 focus:ring-neutral-400 disabled:cursor-not-allowed disabled:opacity-50" | ||
@keydown.window.prevent.ctrl.k="$el.focus()" | ||
@keydown.window.prevent.cmd.k="$el.focus()" | ||
autofocus | ||
/> | ||
<div class="absolute top-0 right-0 flex items-center h-full pr-[0.3125rem]"> | ||
<button | ||
@click.prevent="callApi()" | ||
class="rounded-full bg-lime-400 p-2 focus:border-neutral-300 focus:outline-neutral-400" | ||
> | ||
<svg | ||
xmlns="http://www.w3.org/2000/svg" | ||
class="w-4 h-4" | ||
viewBox="0 0 24 24" | ||
fill="none" | ||
stroke="currentColor" | ||
stroke-width="2" | ||
stroke-linecap="round" | ||
stroke-linejoin="round" | ||
class="lucide lucide-plus" | ||
> | ||
<path d="M5 12h14" /> | ||
<path d="M12 5v14" /> | ||
</svg> | ||
</button> | ||
</div> | ||
</form> | ||
|
||
<div class="flex flex-col gap-y-4 my-8"> | ||
<template x-for="(item, index) in [...items].reverse()" :key="index"> | ||
<div class="w-full border px-4 py-2 rounded-3xl"> | ||
<div | ||
x-data | ||
class="text-sm mt-2 mb-4 whitespace-pre-line" | ||
x-text="item.prompt" | ||
:class="{'animate-pulse': item.loading}" | ||
></div> | ||
<div | ||
x-show="item.completion.length === 0" | ||
class="h-4 w-2 mt-2 mb-4 bg-neutral-500 animate-pulse" | ||
></div> | ||
<div | ||
class="text-sm mt-2 mb-4 text-neutral-500 w-full prose max-w-none prose-neutral-100 leading-6" | ||
x-show="item.completion.length > 0" | ||
x-html="item.markdownCompletion" | ||
> | ||
</div> | ||
</div> | ||
</template> | ||
</div> | ||
|
||
<script> | ||
function state() { | ||
return { | ||
nextPrompt: "", | ||
items: [], | ||
info: { backlog: 0, num_total_runners: 0, num_active_runners: 0 }, | ||
callApi() { | ||
console.log(this.nextPrompt); | ||
if (!this.nextPrompt) return; | ||
|
||
let item = { | ||
id: Math.random(), | ||
prompt: this.nextPrompt, | ||
completion: "", | ||
loading: true, | ||
markdownCompletion: "", | ||
}; | ||
this.nextPrompt = ""; | ||
this.items.push(item); | ||
const eventSource = new EventSource( | ||
`/completion/${encodeURIComponent(item.prompt)}`, | ||
); | ||
|
||
console.log("Created event source ..."); | ||
|
||
eventSource.onmessage = (event) => { | ||
item.completion += JSON.parse(event.data).text; | ||
item.markdownCompletion = marked.parse(item.completion, {mangle: false, headerIds: false}); | ||
// Hacky way to notify element to update | ||
this.items = this.items.map((i) => | ||
i.id === item.id ? { ...item } : i, | ||
); | ||
}; | ||
|
||
eventSource.onerror = (event) => { | ||
eventSource.close(); | ||
item.loading = false; | ||
this.items = this.items.map((i) => | ||
i.id === item.id ? { ...item } : i, | ||
); | ||
console.log(item.completion); | ||
}; | ||
}, | ||
refreshInfo() { | ||
fetch("/stats") | ||
.then((response) => response.json()) | ||
.then((data) => { | ||
this.info = { ...data, loaded: true }; | ||
}) | ||
.catch((error) => console.log(error)); | ||
}, | ||
}; | ||
} | ||
</script> | ||
</section> | ||
</body> | ||
</html> |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,150 @@ | ||
import time | ||
import os | ||
import json | ||
|
||
from modal import Stub, Mount, Image, Secret, Dict, asgi_app, web_endpoint, method, gpu | ||
|
||
from pathlib import Path | ||
|
||
MODEL_DIR = "/model" | ||
def download_model_to_folder(): | ||
from huggingface_hub import snapshot_download | ||
|
||
snapshot_download("meta-llama/Llama-2-13b-chat-hf", local_dir=MODEL_DIR, local_dir_use_symlinks=False, token=os.environ["HUGGINGFACE_TOKEN"]) | ||
|
||
vllm_image = ( | ||
Image.from_dockerhub("nvcr.io/nvidia/pytorch:22.12-py3") | ||
.pip_install( | ||
"torch==2.0.1", index_url="https://download.pytorch.org/whl/cu118" | ||
) | ||
# Pinned to 07/21/2023 | ||
.pip_install( | ||
"vllm @ git+https://github.com/vllm-project/vllm.git@d7a1c6d614756b3072df3e8b52c0998035fb453f" | ||
) | ||
.run_function(download_model_to_folder, secret=Secret.from_name("huggingface")) | ||
) | ||
|
||
stub = Stub("llama-demo") | ||
stub.dict = Dict.new() | ||
|
||
# vLLM class | ||
@stub.cls(gpu=gpu.A100(), image=vllm_image, allow_concurrent_inputs=60, concurrency_limit=1, container_idle_timeout=600) | ||
class Engine: | ||
def __enter__(self): | ||
from vllm.engine.arg_utils import AsyncEngineArgs | ||
from vllm.engine.async_llm_engine import AsyncLLMEngine | ||
|
||
# tokens generated since last report | ||
self.last_report, self.generated_tokens = time.time(), 0 | ||
|
||
engine_args = AsyncEngineArgs( | ||
model=MODEL_DIR, | ||
# Only uses 90% of GPU memory by default | ||
gpu_memory_utilization=0.95 | ||
) | ||
|
||
self.engine = AsyncLLMEngine.from_engine_args(engine_args) | ||
self.template = """<s>[INST] <<SYS>> | ||
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. | ||
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. | ||
<</SYS>> | ||
{} [/INST] """ | ||
|
||
def generated(self, n: int): | ||
# Log that n tokens have been generated | ||
t = time.time() | ||
self.generated_tokens += n | ||
# Save to dict every second | ||
if t - self.last_report > 1.0: | ||
stub.app.dict.update( | ||
tps=self.generated_tokens / (t - self.last_report), | ||
t=self.last_report | ||
) | ||
self.last_report, self.generated_tokens = t, 0 | ||
|
||
@method() | ||
async def completion(self, question: str): | ||
if not question: | ||
return | ||
|
||
from vllm.sampling_params import SamplingParams | ||
from vllm.utils import random_uuid | ||
|
||
sampling_params = SamplingParams( | ||
presence_penalty=0.8, | ||
temperature=0.2, | ||
top_p=0.95, | ||
top_k=50, | ||
max_tokens=1024, | ||
) | ||
request_id = random_uuid() | ||
results_generator = self.engine.generate(self.template.format(question), sampling_params, request_id) | ||
|
||
t0 = time.time() | ||
index, tokens = 0, 0 | ||
async for request_output in results_generator: | ||
if '\ufffd' == request_output.outputs[0].text[-1]: | ||
continue | ||
yield request_output.outputs[0].text[index:] | ||
index = len(request_output.outputs[0].text) | ||
|
||
# Token accounting | ||
new_tokens = len(request_output.outputs[0].token_ids) | ||
self.generated(new_tokens - tokens) | ||
tokens = new_tokens | ||
|
||
throughput = tokens / (time.time() - t0) | ||
print(f"Request completed: {throughput:.4f} tokens/s") | ||
print(request_output.outputs[0].text) | ||
|
||
|
||
# Front-end functionality | ||
frontend_path = Path(__file__).parent / "vllm-hosted" | ||
|
||
@stub.function( | ||
mounts=[Mount.from_local_dir(frontend_path, remote_path="/assets")], | ||
keep_warm=3, | ||
concurrency_limit=6, | ||
allow_concurrent_inputs=24, | ||
timeout=600, | ||
) | ||
@asgi_app() | ||
def app(): | ||
import fastapi | ||
import fastapi.staticfiles | ||
from fastapi.responses import StreamingResponse | ||
from fastapi.middleware.cors import CORSMiddleware | ||
|
||
web_app = fastapi.FastAPI() | ||
|
||
@web_app.get("/stats") | ||
async def stats(): | ||
stats = Engine().completion.get_current_stats() | ||
try: | ||
tps, t = stub.app.dict.get("tps"), stub.app.dict.get("t") | ||
except KeyError: | ||
tps, t = 0, 0 | ||
return { | ||
"backlog": stats.backlog, | ||
"num_active_runners": stats.num_active_runners, | ||
"num_total_runners": stats.num_total_runners, | ||
"tps": tps if t > time.time() - 4.0 else 0, | ||
} | ||
|
||
@web_app.get("/completion/{question}") | ||
async def get(question: str): | ||
from urllib.parse import unquote | ||
|
||
print("Web server received request for", unquote(question)) | ||
|
||
# FastAPI will run this in a separate thread | ||
def generate(): | ||
for chunk in Engine().completion.call(unquote(question)): | ||
yield f'data: {json.dumps(dict(text=chunk), ensure_ascii=False)}\n\n' | ||
|
||
return StreamingResponse(generate(), media_type="text/event-stream") | ||
|
||
web_app.mount("/", fastapi.staticfiles.StaticFiles(directory="/assets", html=True)) | ||
return web_app |