Skip to content

Commit

Permalink
Merge pull request #3292 from lllyasviel/develop
Browse files Browse the repository at this point in the history
Release v2.5.0
  • Loading branch information
mashb1t authored Jul 17, 2024
2 parents 5a71495 + 97a8475 commit f97adaf
Show file tree
Hide file tree
Showing 39 changed files with 2,754 additions and 927 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ __pycache__
*.partial
*.onnx
sorted_styles.json
hash_cache.txt
/input
/cache
/language/default.json
Expand Down
9 changes: 6 additions & 3 deletions args_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,14 @@
args_parser.parser.add_argument("--disable-preset-download", action='store_true',
help="Disables downloading models for presets", default=False)

args_parser.parser.add_argument("--enable-describe-uov-image", action='store_true',
help="Disables automatic description of uov images when prompt is empty", default=False)
args_parser.parser.add_argument("--enable-auto-describe-image", action='store_true',
help="Enables automatic description of uov and enhance image when prompt is empty", default=False)

args_parser.parser.add_argument("--always-download-new-model", action='store_true',
help="Always download newer models ", default=False)
help="Always download newer models", default=False)

args_parser.parser.add_argument("--rebuild-hash-cache", help="Generates missing model and LoRA hashes.",
type=int, nargs="?", metavar="CPU_NUM_THREADS", const=-1)

args_parser.parser.set_defaults(
disable_cuda_malloc=True,
Expand Down
2 changes: 1 addition & 1 deletion css/style.css
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ div:has(> #positive_prompt) {
}

.advanced_check_row {
width: 250px !important;
width: 330px !important;
}

.min_check {
Expand Down
24 changes: 24 additions & 0 deletions experiments_mask_generation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# https://github.com/sail-sg/EditAnything/blob/main/sam2groundingdino_edit.py

import numpy as np
from PIL import Image

from extras.inpaint_mask import SAMOptions, generate_mask_from_image

original_image = Image.open('cat.webp')
image = np.array(original_image, dtype=np.uint8)

sam_options = SAMOptions(
dino_prompt='eye',
dino_box_threshold=0.3,
dino_text_threshold=0.25,
dino_erode_or_dilate=0,
dino_debug=False,
max_detections=2,
model_type='vit_b'
)

mask_image, _, _, _ = generate_mask_from_image(image, sam_options=sam_options)

merged_masks_img = Image.fromarray(mask_image)
merged_masks_img.show()
43 changes: 43 additions & 0 deletions extras/GroundingDINO/config/GroundingDINO_SwinT_OGC.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
batch_size = 1
modelname = "groundingdino"
backbone = "swin_T_224_1k"
position_embedding = "sine"
pe_temperatureH = 20
pe_temperatureW = 20
return_interm_indices = [1, 2, 3]
backbone_freeze_keywords = None
enc_layers = 6
dec_layers = 6
pre_norm = False
dim_feedforward = 2048
hidden_dim = 256
dropout = 0.0
nheads = 8
num_queries = 900
query_dim = 4
num_patterns = 0
num_feature_levels = 4
enc_n_points = 4
dec_n_points = 4
two_stage_type = "standard"
two_stage_bbox_embed_share = False
two_stage_class_embed_share = False
transformer_activation = "relu"
dec_pred_bbox_embed_share = True
dn_box_noise_scale = 1.0
dn_label_noise_ratio = 0.5
dn_label_coef = 1.0
dn_bbox_coef = 1.0
embed_init_tgt = True
dn_labelbook_size = 2000
max_text_len = 256
text_encoder_type = "bert-base-uncased"
use_text_enhancer = True
use_fusion_layer = True
use_checkpoint = True
use_transformer_ckpt = True
use_text_cross_attention = True
text_dropout = 0.0
fusion_dropout = 0.0
fusion_droppath = 0.1
sub_sentence_present = True
100 changes: 100 additions & 0 deletions extras/GroundingDINO/util/inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
from typing import Tuple, List

import ldm_patched.modules.model_management as model_management
from ldm_patched.modules.model_patcher import ModelPatcher
from modules.config import path_inpaint
from modules.model_loader import load_file_from_url

import numpy as np
import supervision as sv
import torch
from groundingdino.util.inference import Model
from groundingdino.util.inference import load_model, preprocess_caption, get_phrases_from_posmap


class GroundingDinoModel(Model):
def __init__(self):
self.config_file = 'extras/GroundingDINO/config/GroundingDINO_SwinT_OGC.py'
self.model = None
self.load_device = torch.device('cpu')
self.offload_device = torch.device('cpu')

@torch.no_grad()
@torch.inference_mode()
def predict_with_caption(
self,
image: np.ndarray,
caption: str,
box_threshold: float = 0.35,
text_threshold: float = 0.25
) -> Tuple[sv.Detections, torch.Tensor, torch.Tensor, List[str]]:
if self.model is None:
filename = load_file_from_url(
url="https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth",
file_name='groundingdino_swint_ogc.pth',
model_dir=path_inpaint)
model = load_model(model_config_path=self.config_file, model_checkpoint_path=filename)

self.load_device = model_management.text_encoder_device()
self.offload_device = model_management.text_encoder_offload_device()

model.to(self.offload_device)

self.model = ModelPatcher(model, load_device=self.load_device, offload_device=self.offload_device)

model_management.load_model_gpu(self.model)

processed_image = GroundingDinoModel.preprocess_image(image_bgr=image).to(self.load_device)
boxes, logits, phrases = predict(
model=self.model,
image=processed_image,
caption=caption,
box_threshold=box_threshold,
text_threshold=text_threshold,
device=self.load_device)
source_h, source_w, _ = image.shape
detections = GroundingDinoModel.post_process_result(
source_h=source_h,
source_w=source_w,
boxes=boxes,
logits=logits)
return detections, boxes, logits, phrases


def predict(
model,
image: torch.Tensor,
caption: str,
box_threshold: float,
text_threshold: float,
device: str = "cuda"
) -> Tuple[torch.Tensor, torch.Tensor, List[str]]:
caption = preprocess_caption(caption=caption)

# override to use model wrapped by patcher
model = model.model.to(device)
image = image.to(device)

with torch.no_grad():
outputs = model(image[None], captions=[caption])

prediction_logits = outputs["pred_logits"].cpu().sigmoid()[0] # prediction_logits.shape = (nq, 256)
prediction_boxes = outputs["pred_boxes"].cpu()[0] # prediction_boxes.shape = (nq, 4)

mask = prediction_logits.max(dim=1)[0] > box_threshold
logits = prediction_logits[mask] # logits.shape = (n, 256)
boxes = prediction_boxes[mask] # boxes.shape = (n, 4)

tokenizer = model.tokenizer
tokenized = tokenizer(caption)

phrases = [
get_phrases_from_posmap(logit > text_threshold, tokenized, tokenizer).replace('.', '')
for logit
in logits
]

return boxes, logits.max(dim=1)[0], phrases


default_groundingdino = GroundingDinoModel().predict_with_caption
2 changes: 1 addition & 1 deletion extras/censor.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def censor(self, images: list | np.ndarray) -> list | np.ndarray:
model_management.load_model_gpu(self.safety_checker_model)

single = False
if not isinstance(images, list) or isinstance(images, np.ndarray):
if not isinstance(images, (list, np.ndarray)):
images = [images]
single = True

Expand Down
130 changes: 130 additions & 0 deletions extras/inpaint_mask.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
import sys

import modules.config
import numpy as np
import torch
from extras.GroundingDINO.util.inference import default_groundingdino
from extras.sam.predictor import SamPredictor
from rembg import remove, new_session
from segment_anything import sam_model_registry
from segment_anything.utils.amg import remove_small_regions


class SAMOptions:
def __init__(self,
# GroundingDINO
dino_prompt: str = '',
dino_box_threshold=0.3,
dino_text_threshold=0.25,
dino_erode_or_dilate=0,
dino_debug=False,

# SAM
max_detections=2,
model_type='vit_b'
):
self.dino_prompt = dino_prompt
self.dino_box_threshold = dino_box_threshold
self.dino_text_threshold = dino_text_threshold
self.dino_erode_or_dilate = dino_erode_or_dilate
self.dino_debug = dino_debug
self.max_detections = max_detections
self.model_type = model_type


def optimize_masks(masks: torch.Tensor) -> torch.Tensor:
"""
removes small disconnected regions and holes
"""
fine_masks = []
for mask in masks.to('cpu').numpy(): # masks: [num_masks, 1, h, w]
fine_masks.append(remove_small_regions(mask[0], 400, mode="holes")[0])
masks = np.stack(fine_masks, axis=0)[:, np.newaxis]
return torch.from_numpy(masks)


def generate_mask_from_image(image: np.ndarray, mask_model: str = 'sam', extras=None,
sam_options: SAMOptions | None = SAMOptions) -> tuple[np.ndarray | None, int | None, int | None, int | None]:
dino_detection_count = 0
sam_detection_count = 0
sam_detection_on_mask_count = 0

if image is None:
return None, dino_detection_count, sam_detection_count, sam_detection_on_mask_count

if extras is None:
extras = {}

if 'image' in image:
image = image['image']

if mask_model != 'sam' or sam_options is None:
result = remove(
image,
session=new_session(mask_model, **extras),
only_mask=True,
**extras
)

return result, dino_detection_count, sam_detection_count, sam_detection_on_mask_count

detections, boxes, logits, phrases = default_groundingdino(
image=image,
caption=sam_options.dino_prompt,
box_threshold=sam_options.dino_box_threshold,
text_threshold=sam_options.dino_text_threshold
)

H, W = image.shape[0], image.shape[1]
boxes = boxes * torch.Tensor([W, H, W, H])
boxes[:, :2] = boxes[:, :2] - boxes[:, 2:] / 2
boxes[:, 2:] = boxes[:, 2:] + boxes[:, :2]

sam_checkpoint = modules.config.download_sam_model(sam_options.model_type)
sam = sam_model_registry[sam_options.model_type](checkpoint=sam_checkpoint)

sam_predictor = SamPredictor(sam)
final_mask_tensor = torch.zeros((image.shape[0], image.shape[1]))
dino_detection_count = boxes.size(0)

if dino_detection_count > 0:
sam_predictor.set_image(image)

if sam_options.dino_erode_or_dilate != 0:
for index in range(boxes.size(0)):
assert boxes.size(1) == 4
boxes[index][0] -= sam_options.dino_erode_or_dilate
boxes[index][1] -= sam_options.dino_erode_or_dilate
boxes[index][2] += sam_options.dino_erode_or_dilate
boxes[index][3] += sam_options.dino_erode_or_dilate

if sam_options.dino_debug:
from PIL import ImageDraw, Image
debug_dino_image = Image.new("RGB", (image.shape[1], image.shape[0]), color="black")
draw = ImageDraw.Draw(debug_dino_image)
for box in boxes.numpy():
draw.rectangle(box.tolist(), fill="white")
return np.array(debug_dino_image), dino_detection_count, sam_detection_count, sam_detection_on_mask_count

transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes, image.shape[:2])
masks, _, _ = sam_predictor.predict_torch(
point_coords=None,
point_labels=None,
boxes=transformed_boxes,
multimask_output=False,
)

masks = optimize_masks(masks)
sam_detection_count = len(masks)
if sam_options.max_detections == 0:
sam_options.max_detections = sys.maxsize
sam_objects = min(len(logits), sam_options.max_detections)
for obj_ind in range(sam_objects):
mask_tensor = masks[obj_ind][0]
final_mask_tensor += mask_tensor
sam_detection_on_mask_count += 1

final_mask_tensor = (final_mask_tensor > 0).to('cpu').numpy()
mask_image = np.dstack((final_mask_tensor, final_mask_tensor, final_mask_tensor)) * 255
mask_image = np.array(mask_image, dtype=np.uint8)
return mask_image, dino_detection_count, sam_detection_count, sam_detection_on_mask_count
Loading

0 comments on commit f97adaf

Please sign in to comment.