-
Notifications
You must be signed in to change notification settings - Fork 5.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #3292 from lllyasviel/develop
Release v2.5.0
- Loading branch information
Showing
39 changed files
with
2,754 additions
and
927 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,6 +10,7 @@ __pycache__ | |
*.partial | ||
*.onnx | ||
sorted_styles.json | ||
hash_cache.txt | ||
/input | ||
/cache | ||
/language/default.json | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
# https://github.com/sail-sg/EditAnything/blob/main/sam2groundingdino_edit.py | ||
|
||
import numpy as np | ||
from PIL import Image | ||
|
||
from extras.inpaint_mask import SAMOptions, generate_mask_from_image | ||
|
||
original_image = Image.open('cat.webp') | ||
image = np.array(original_image, dtype=np.uint8) | ||
|
||
sam_options = SAMOptions( | ||
dino_prompt='eye', | ||
dino_box_threshold=0.3, | ||
dino_text_threshold=0.25, | ||
dino_erode_or_dilate=0, | ||
dino_debug=False, | ||
max_detections=2, | ||
model_type='vit_b' | ||
) | ||
|
||
mask_image, _, _, _ = generate_mask_from_image(image, sam_options=sam_options) | ||
|
||
merged_masks_img = Image.fromarray(mask_image) | ||
merged_masks_img.show() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
batch_size = 1 | ||
modelname = "groundingdino" | ||
backbone = "swin_T_224_1k" | ||
position_embedding = "sine" | ||
pe_temperatureH = 20 | ||
pe_temperatureW = 20 | ||
return_interm_indices = [1, 2, 3] | ||
backbone_freeze_keywords = None | ||
enc_layers = 6 | ||
dec_layers = 6 | ||
pre_norm = False | ||
dim_feedforward = 2048 | ||
hidden_dim = 256 | ||
dropout = 0.0 | ||
nheads = 8 | ||
num_queries = 900 | ||
query_dim = 4 | ||
num_patterns = 0 | ||
num_feature_levels = 4 | ||
enc_n_points = 4 | ||
dec_n_points = 4 | ||
two_stage_type = "standard" | ||
two_stage_bbox_embed_share = False | ||
two_stage_class_embed_share = False | ||
transformer_activation = "relu" | ||
dec_pred_bbox_embed_share = True | ||
dn_box_noise_scale = 1.0 | ||
dn_label_noise_ratio = 0.5 | ||
dn_label_coef = 1.0 | ||
dn_bbox_coef = 1.0 | ||
embed_init_tgt = True | ||
dn_labelbook_size = 2000 | ||
max_text_len = 256 | ||
text_encoder_type = "bert-base-uncased" | ||
use_text_enhancer = True | ||
use_fusion_layer = True | ||
use_checkpoint = True | ||
use_transformer_ckpt = True | ||
use_text_cross_attention = True | ||
text_dropout = 0.0 | ||
fusion_dropout = 0.0 | ||
fusion_droppath = 0.1 | ||
sub_sentence_present = True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
from typing import Tuple, List | ||
|
||
import ldm_patched.modules.model_management as model_management | ||
from ldm_patched.modules.model_patcher import ModelPatcher | ||
from modules.config import path_inpaint | ||
from modules.model_loader import load_file_from_url | ||
|
||
import numpy as np | ||
import supervision as sv | ||
import torch | ||
from groundingdino.util.inference import Model | ||
from groundingdino.util.inference import load_model, preprocess_caption, get_phrases_from_posmap | ||
|
||
|
||
class GroundingDinoModel(Model): | ||
def __init__(self): | ||
self.config_file = 'extras/GroundingDINO/config/GroundingDINO_SwinT_OGC.py' | ||
self.model = None | ||
self.load_device = torch.device('cpu') | ||
self.offload_device = torch.device('cpu') | ||
|
||
@torch.no_grad() | ||
@torch.inference_mode() | ||
def predict_with_caption( | ||
self, | ||
image: np.ndarray, | ||
caption: str, | ||
box_threshold: float = 0.35, | ||
text_threshold: float = 0.25 | ||
) -> Tuple[sv.Detections, torch.Tensor, torch.Tensor, List[str]]: | ||
if self.model is None: | ||
filename = load_file_from_url( | ||
url="https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth", | ||
file_name='groundingdino_swint_ogc.pth', | ||
model_dir=path_inpaint) | ||
model = load_model(model_config_path=self.config_file, model_checkpoint_path=filename) | ||
|
||
self.load_device = model_management.text_encoder_device() | ||
self.offload_device = model_management.text_encoder_offload_device() | ||
|
||
model.to(self.offload_device) | ||
|
||
self.model = ModelPatcher(model, load_device=self.load_device, offload_device=self.offload_device) | ||
|
||
model_management.load_model_gpu(self.model) | ||
|
||
processed_image = GroundingDinoModel.preprocess_image(image_bgr=image).to(self.load_device) | ||
boxes, logits, phrases = predict( | ||
model=self.model, | ||
image=processed_image, | ||
caption=caption, | ||
box_threshold=box_threshold, | ||
text_threshold=text_threshold, | ||
device=self.load_device) | ||
source_h, source_w, _ = image.shape | ||
detections = GroundingDinoModel.post_process_result( | ||
source_h=source_h, | ||
source_w=source_w, | ||
boxes=boxes, | ||
logits=logits) | ||
return detections, boxes, logits, phrases | ||
|
||
|
||
def predict( | ||
model, | ||
image: torch.Tensor, | ||
caption: str, | ||
box_threshold: float, | ||
text_threshold: float, | ||
device: str = "cuda" | ||
) -> Tuple[torch.Tensor, torch.Tensor, List[str]]: | ||
caption = preprocess_caption(caption=caption) | ||
|
||
# override to use model wrapped by patcher | ||
model = model.model.to(device) | ||
image = image.to(device) | ||
|
||
with torch.no_grad(): | ||
outputs = model(image[None], captions=[caption]) | ||
|
||
prediction_logits = outputs["pred_logits"].cpu().sigmoid()[0] # prediction_logits.shape = (nq, 256) | ||
prediction_boxes = outputs["pred_boxes"].cpu()[0] # prediction_boxes.shape = (nq, 4) | ||
|
||
mask = prediction_logits.max(dim=1)[0] > box_threshold | ||
logits = prediction_logits[mask] # logits.shape = (n, 256) | ||
boxes = prediction_boxes[mask] # boxes.shape = (n, 4) | ||
|
||
tokenizer = model.tokenizer | ||
tokenized = tokenizer(caption) | ||
|
||
phrases = [ | ||
get_phrases_from_posmap(logit > text_threshold, tokenized, tokenizer).replace('.', '') | ||
for logit | ||
in logits | ||
] | ||
|
||
return boxes, logits.max(dim=1)[0], phrases | ||
|
||
|
||
default_groundingdino = GroundingDinoModel().predict_with_caption |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
import sys | ||
|
||
import modules.config | ||
import numpy as np | ||
import torch | ||
from extras.GroundingDINO.util.inference import default_groundingdino | ||
from extras.sam.predictor import SamPredictor | ||
from rembg import remove, new_session | ||
from segment_anything import sam_model_registry | ||
from segment_anything.utils.amg import remove_small_regions | ||
|
||
|
||
class SAMOptions: | ||
def __init__(self, | ||
# GroundingDINO | ||
dino_prompt: str = '', | ||
dino_box_threshold=0.3, | ||
dino_text_threshold=0.25, | ||
dino_erode_or_dilate=0, | ||
dino_debug=False, | ||
|
||
# SAM | ||
max_detections=2, | ||
model_type='vit_b' | ||
): | ||
self.dino_prompt = dino_prompt | ||
self.dino_box_threshold = dino_box_threshold | ||
self.dino_text_threshold = dino_text_threshold | ||
self.dino_erode_or_dilate = dino_erode_or_dilate | ||
self.dino_debug = dino_debug | ||
self.max_detections = max_detections | ||
self.model_type = model_type | ||
|
||
|
||
def optimize_masks(masks: torch.Tensor) -> torch.Tensor: | ||
""" | ||
removes small disconnected regions and holes | ||
""" | ||
fine_masks = [] | ||
for mask in masks.to('cpu').numpy(): # masks: [num_masks, 1, h, w] | ||
fine_masks.append(remove_small_regions(mask[0], 400, mode="holes")[0]) | ||
masks = np.stack(fine_masks, axis=0)[:, np.newaxis] | ||
return torch.from_numpy(masks) | ||
|
||
|
||
def generate_mask_from_image(image: np.ndarray, mask_model: str = 'sam', extras=None, | ||
sam_options: SAMOptions | None = SAMOptions) -> tuple[np.ndarray | None, int | None, int | None, int | None]: | ||
dino_detection_count = 0 | ||
sam_detection_count = 0 | ||
sam_detection_on_mask_count = 0 | ||
|
||
if image is None: | ||
return None, dino_detection_count, sam_detection_count, sam_detection_on_mask_count | ||
|
||
if extras is None: | ||
extras = {} | ||
|
||
if 'image' in image: | ||
image = image['image'] | ||
|
||
if mask_model != 'sam' or sam_options is None: | ||
result = remove( | ||
image, | ||
session=new_session(mask_model, **extras), | ||
only_mask=True, | ||
**extras | ||
) | ||
|
||
return result, dino_detection_count, sam_detection_count, sam_detection_on_mask_count | ||
|
||
detections, boxes, logits, phrases = default_groundingdino( | ||
image=image, | ||
caption=sam_options.dino_prompt, | ||
box_threshold=sam_options.dino_box_threshold, | ||
text_threshold=sam_options.dino_text_threshold | ||
) | ||
|
||
H, W = image.shape[0], image.shape[1] | ||
boxes = boxes * torch.Tensor([W, H, W, H]) | ||
boxes[:, :2] = boxes[:, :2] - boxes[:, 2:] / 2 | ||
boxes[:, 2:] = boxes[:, 2:] + boxes[:, :2] | ||
|
||
sam_checkpoint = modules.config.download_sam_model(sam_options.model_type) | ||
sam = sam_model_registry[sam_options.model_type](checkpoint=sam_checkpoint) | ||
|
||
sam_predictor = SamPredictor(sam) | ||
final_mask_tensor = torch.zeros((image.shape[0], image.shape[1])) | ||
dino_detection_count = boxes.size(0) | ||
|
||
if dino_detection_count > 0: | ||
sam_predictor.set_image(image) | ||
|
||
if sam_options.dino_erode_or_dilate != 0: | ||
for index in range(boxes.size(0)): | ||
assert boxes.size(1) == 4 | ||
boxes[index][0] -= sam_options.dino_erode_or_dilate | ||
boxes[index][1] -= sam_options.dino_erode_or_dilate | ||
boxes[index][2] += sam_options.dino_erode_or_dilate | ||
boxes[index][3] += sam_options.dino_erode_or_dilate | ||
|
||
if sam_options.dino_debug: | ||
from PIL import ImageDraw, Image | ||
debug_dino_image = Image.new("RGB", (image.shape[1], image.shape[0]), color="black") | ||
draw = ImageDraw.Draw(debug_dino_image) | ||
for box in boxes.numpy(): | ||
draw.rectangle(box.tolist(), fill="white") | ||
return np.array(debug_dino_image), dino_detection_count, sam_detection_count, sam_detection_on_mask_count | ||
|
||
transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes, image.shape[:2]) | ||
masks, _, _ = sam_predictor.predict_torch( | ||
point_coords=None, | ||
point_labels=None, | ||
boxes=transformed_boxes, | ||
multimask_output=False, | ||
) | ||
|
||
masks = optimize_masks(masks) | ||
sam_detection_count = len(masks) | ||
if sam_options.max_detections == 0: | ||
sam_options.max_detections = sys.maxsize | ||
sam_objects = min(len(logits), sam_options.max_detections) | ||
for obj_ind in range(sam_objects): | ||
mask_tensor = masks[obj_ind][0] | ||
final_mask_tensor += mask_tensor | ||
sam_detection_on_mask_count += 1 | ||
|
||
final_mask_tensor = (final_mask_tensor > 0).to('cpu').numpy() | ||
mask_image = np.dstack((final_mask_tensor, final_mask_tensor, final_mask_tensor)) * 255 | ||
mask_image = np.array(mask_image, dtype=np.uint8) | ||
return mask_image, dino_detection_count, sam_detection_count, sam_detection_on_mask_count |
Oops, something went wrong.