infer.py
from diffusers import DiffusionPipeline
import torch
import os

# Base SDXL model and the subject used to fill in the prompt templates below.
model = "stabilityai/stable-diffusion-xl-base-1.0"
subject = "cat"
templates = [
"A picture of a {} in the jungle",
"A picture of a {} in the snow",
"A picture of a {} on the beach",
"A picture of a {} on a cobblestone street",
"A picture of a {} on top of pink fabric",
"A picture of a {} on top of a wooden floor",
"A picture of a {} with a city in the background",
"A picture of a {} with a mountain in the background",
"A picture of a {} with a blue house in the background",
"A picture of a {} on top of a purple rug in a forest",
"A picture of a {} with a wheat field in the background",
"A picture of a {} with a tree and autumn leaves in the background",
"A picture of a {} with the Eiffel Tower in the background",
"A picture of a {} floating on top of water",
"A picture of a {} floating in an ocean of milk",
"A picture of a {} on top of green grass with sunflowers around it",
"A picture of a {} on top of a mirror",
"A picture of a {} on top of the sidewalk in a crowded street",
"A picture of a {} on top of a dirt road",
"A picture of a {} on top of a white rug",
"A picture of a red {}",
"A picture of a purple {}",
"A picture of a shiny {}",
"A picture of a wet {}",
"A picture of a cube shaped {}",
]
prompts = [template.format(subject) for template in templates]
# Generate images with both the standard LoRA weights and the LoRA-DASH weights.
for method in ["lora", "lora-dash"]:
    if method == "lora":
        path = "./lora-trained-xl-{}".format(subject.replace(" ", "_"))
    elif method == "lora-dash":
        path = "./lora-trained-xl-dash-{}".format(subject.replace(" ", "_"))

    # Load the base pipeline and attach the fine-tuned LoRA weights for this method.
    pipe = DiffusionPipeline.from_pretrained(model, torch_dtype=torch.float16)
    pipe = pipe.to("cuda")
    pipe.load_lora_weights(path)

    for prompt in prompts:
        print(prompt)
        image = pipe(prompt, num_inference_steps=50).images[0]
        # Save one image per prompt, grouped by subject and prompt text.
        output_p_dir = "output_images/{}/{}".format(subject, prompt.replace(" ", "_"))
        os.makedirs(output_p_dir, exist_ok=True)
        image.save("{}/{}.jpg".format(output_p_dir, method))