Skip to content

Commit

Permalink
Update mask_decoder.py
Browse files Browse the repository at this point in the history
  • Loading branch information
BlackBoyZeus committed Sep 4, 2024
1 parent f18d54c commit faea0eb
Showing 1 changed file with 91 additions and 59 deletions.
150 changes: 91 additions & 59 deletions samkeras (1)/Sam/mask_decoder.py
Original file line number Diff line number Diff line change
@@ -1,46 +1,63 @@
# /modeling/sam/mask_decoder.py
# sam2_tfkeras/modeling/sam/mask_decoder.py

from typing import List, Optional, Tuple, Type

import tensorflow as tf
from tensorflow.keras import layers

from sam2.modeling.sam2_utils import LayerNorm2d, MLP
from sam2_tfkeras.modeling.sam2_utils import LayerNorm2d, MLP
from sam2_tfkeras.modeling.sam.transformer import TwoWayTransformer
from ncps.tf import CfC

class MaskDecoder(layers.Layer):
"""
Predicts masks given an image and prompt embeddings, using a
transformer architecture.
This implementation incorporates a CfC (Closed-form Continuous-time) layer
to enhance the model's ability to capture temporal dynamics.
"""

def __init__(
self,
transformer_dim: int,
transformer: layers.Layer, # TensorFlow transformer layer
num_multimask_outputs: int = 3,
activation: Type[layers.Layer] = layers.Activation('gelu'), # Keras Activation
iou_head_depth: int = 3,
iou_head_hidden_dim: int = 256,
use_high_res_features: bool = False,
iou_prediction_use_sigmoid=False,
dynamic_multimask_via_stability=False,
dynamic_multimask_stability_delta=0.05,
dynamic_multimask_stability_thresh=0.98,
pred_obj_scores: bool = False,
pred_obj_scores_mlp: bool = False,
use_multimask_token_for_obj_ptr: bool = False,
*,
transformer_dim: int, # Channel dimension of the transformer
transformer: layers.Layer, # TensorFlow transformer layer
num_multimask_outputs: int = 3, # Number of masks to predict for disambiguation
activation: Type[layers.Layer] = layers.Activation('gelu'), # Activation function
iou_head_depth: int = 3, # Depth of the MLP for mask quality prediction
iou_head_hidden_dim: int = 256, # Hidden dimension of the MLP for mask quality
use_high_res_features: bool = False, # Whether to use high-res features
iou_prediction_use_sigmoid=False, # Whether to use sigmoid for IoU prediction
dynamic_multimask_via_stability=False, # Dynamic multimask selection
dynamic_multimask_stability_delta=0.05, # Delta for stability score
dynamic_multimask_stability_thresh=0.98, # Threshold for stability score
pred_obj_scores: bool = False, # Whether to predict object scores
pred_obj_scores_mlp: bool = False, # Whether to use MLP for object score prediction
use_multimask_token_for_obj_ptr: bool = False, # Use multimask token for object pointer
cfc_units: int = 128, # Number of units in the CfC layer
mixed_memory: bool = True, # Use mixed memory in CfC
cfc_mode: str = "default", # Mode for the CfC layer
) -> None:
super(MaskDecoder, self).__init__()
self.transformer_dim = transformer_dim
self.transformer = transformer

self.num_multimask_outputs = num_multimask_outputs

# Embeddings for IoU and mask tokens
self.iou_token = layers.Embedding(1, transformer_dim)
self.num_mask_tokens = num_multimask_outputs + 1
self.mask_tokens = layers.Embedding(self.num_mask_tokens, transformer_dim)

# Object score prediction
self.pred_obj_scores = pred_obj_scores
if self.pred_obj_scores:
self.obj_score_token = layers.Embedding(1, transformer_dim)
self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr

# Upsampling
# --- Upsampling Layers ---
self.output_upscaling = tf.keras.Sequential([
layers.Conv2DTranspose(
filters=transformer_dim // 4,
Expand All @@ -49,26 +66,29 @@ def __init__(
padding='same'
),
LayerNorm2d(transformer_dim // 4),
activation,
activation,
layers.Conv2DTranspose(
filters=transformer_dim // 8,
kernel_size=2,
strides=2,
padding='same'
),
activation,
activation,
])

# High-resolution features
self.use_high_res_features = use_high_res_features
if use_high_res_features:
self.conv_s0 = layers.Conv2D(transformer_dim // 8, kernel_size=1, strides=1, padding='same')
self.conv_s1 = layers.Conv2D(transformer_dim // 4, kernel_size=1, strides=1, padding='same')

# Output hypernetworks
self.output_hypernetworks_mlps = [
MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
for _ in range(self.num_mask_tokens)
]

# IoU prediction head
self.iou_prediction_head = MLP(
transformer_dim,
iou_head_hidden_dim,
Expand All @@ -77,29 +97,42 @@ def __init__(
sigmoid_output=iou_prediction_use_sigmoid,
)

# Object score prediction head
if self.pred_obj_scores:
self.pred_obj_score_head = layers.Dense(1)
if pred_obj_scores_mlp:
self.pred_obj_score_head = MLP(transformer_dim, transformer_dim, 1, 3)

# Dynamic multimask selection parameters
self.dynamic_multimask_via_stability = dynamic_multimask_via_stability
self.dynamic_multimask_stability_delta = dynamic_multimask_stability_delta
self.dynamic_multimask_stability_thresh = dynamic_multimask_stability_thresh

# --- CfC Layer Initialization ---
self.cfc_layer = CfC(units=cfc_units, mixed_memory=mixed_memory, mode=cfc_mode)

def call(
self,
image_embeddings: tf.Tensor,
image_pe: tf.Tensor,
sparse_prompt_embeddings: tf.Tensor,
dense_prompt_embeddings: tf.Tensor,
multimask_output: bool,
repeat_image: bool,
high_res_features: Optional[List[tf.Tensor]] = None,
image_embeddings: tf.Tensor, # Image embeddings from the encoder
image_pe: tf.Tensor, # Positional encodings for the image
sparse_prompt_embeddings: tf.Tensor, # Sparse prompt embeddings
dense_prompt_embeddings: tf.Tensor, # Dense prompt embeddings
multimask_output: bool, # Whether to output multiple masks
repeat_image: bool, # Whether to repeat the image embeddings
high_res_features: Optional[List[tf.Tensor]] = None, # High-resolution features (optional)
training=False,
) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor]:
"""
Predict masks given image and prompt embeddings.
Returns:
masks: Predicted masks
iou_pred: Predicted IoU scores for the masks
sam_tokens_out: SAM tokens for the mask output
object_score_logits: Logits for object scores (if enabled)
"""

# Predict masks using helper function
masks, iou_pred, mask_tokens_out, object_score_logits = self.predict_masks(
image_embeddings=image_embeddings,
image_pe=image_pe,
Expand All @@ -110,31 +143,34 @@ def call(
training=training
)

# Select the appropriate masks based on output settings
if multimask_output:
masks = masks[:, 1:, :, :]
iou_pred = iou_pred[:, 1:]
elif self.dynamic_multimask_via_stability and not training:
elif self.dynamic_multimask_via_stability and not training:
masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred)
else:
masks = masks[:, 0:1, :, :]
iou_pred = iou_pred[:, 0:1]

# Select SAM output tokens based on configuration
if multimask_output and self.use_multimask_token_for_obj_ptr:
sam_tokens_out = mask_tokens_out[:, 1:] # [b, 3, c] shape
sam_tokens_out = mask_tokens_out[:, 1:]
else:
sam_tokens_out = mask_tokens_out[:, 0:1] # [b, 1, c] shape
sam_tokens_out = mask_tokens_out[:, 0:1]

# Return the predicted masks, IoU scores, SAM tokens, and object score logits
return masks, iou_pred, sam_tokens_out, object_score_logits

def predict_masks(
self,
image_embeddings: tf.Tensor,
image_pe: tf.Tensor,
sparse_prompt_embeddings: tf.Tensor,
dense_prompt_embeddings: tf.Tensor,
repeat_image: bool,
high_res_features: Optional[List[tf.Tensor]] = None,
training=False,
image_embeddings: tf.Tensor,
image_pe: tf.Tensor,
sparse_prompt_embeddings: tf.Tensor,
dense_prompt_embeddings: tf.Tensor,
repeat_image: bool,
high_res_features: Optional[List[tf.Tensor]] = None,
training=False,
) -> Tuple[tf.Tensor, tf.Tensor]:
"""Predicts masks. See 'call' for more details."""
# Concatenate output tokens
Expand All @@ -156,52 +192,61 @@ def predict_masks(
output_tokens = tf.tile(tf.expand_dims(output_tokens, axis=0), [tf.shape(sparse_prompt_embeddings)[0], 1, 1])
tokens = tf.concat((output_tokens, sparse_prompt_embeddings), axis=1)

# Expand per-image data in batch direction to be per-mask
# Expand image embeddings if needed
if repeat_image:
src = tf.repeat(image_embeddings, repeats=tf.shape(tokens)[0], axis=0)
else:
tf.debugging.assert_equal(tf.shape(image_embeddings)[0], tf.shape(tokens)[0])
src = image_embeddings
src = src + dense_prompt_embeddings

tf.debugging.assert_equal(tf.shape(image_pe)[0], 1, message="image_pe should have size 1 in batch dim (from `get_dense_pe()`)")
# Repeat image positional encodings
tf.debugging.assert_equal(tf.shape(image_pe)[0], 1, message="image_pe should have size 1 in batch dim")
pos_src = tf.repeat(image_pe, repeats=tf.shape(tokens)[0], axis=0)
b, c, h, w = tf.shape(src)

# Run the transformer
# Get shape of source tensor
b, c, h, w = tf.shape(src)

# --- Run the Transformer ---
hs, src = self.transformer(src, pos_src, tokens, training=training)

# --- Apply CfC Layer ---
hs = self.cfc_layer(hs, training=training)

# Extract IoU token and mask tokens
iou_token_out = hs[:, s, :]
mask_tokens_out = hs[:, s + 1 : (s + 1 + self.num_mask_tokens), :]

# Upscale mask embeddings and predict masks using the mask tokens
src = tf.reshape(tf.transpose(src, perm=[0, 2, 1]), (b, c, h, w))
# Upscale mask embeddings
src = tf.reshape(tf.transpose(src, perm=[0, 2, 1]), (b, c, h, w))
if not self.use_high_res_features:
upscaled_embedding = self.output_upscaling(src)
else:
# Assuming self.output_upscaling is a Sequential model
# Apply upscaling with high-resolution features
dc1, ln1, act1, dc2, act2 = self.output_upscaling.layers
feat_s0, feat_s1 = high_res_features
upscaled_embedding = act1(ln1(dc1(src) + feat_s1))
upscaled_embedding = act2(dc2(upscaled_embedding) + feat_s0)

# Predict masks using hypernetworks
hyper_in_list = []
for i in range(self.num_mask_tokens):
hyper_in_list.append(
self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])
)
hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
hyper_in = tf.stack(hyper_in_list, axis=1)
b, c, h, w = tf.shape(upscaled_embedding)
masks = tf.reshape(tf.linalg.matmul(hyper_in, tf.reshape(upscaled_embedding, (b, c, h * w))), (b, -1, h, w))

# Generate mask quality predictions
iou_pred = self.iou_prediction_head(iou_token_out)

# Generate object score logits (if enabled)
if self.pred_obj_scores:
assert s == 1
object_score_logits = self.pred_obj_score_head(hs[:, 0, :])
else:
# Obj scores logits - default to 10.0, i.e. assuming the object is present, sigmoid(10)=1
object_score_logits = 10.0 * tf.ones_like(iou_pred[:, :1])

# Return predicted masks, IoU predictions, mask tokens, and object score logits
return masks, iou_pred, mask_tokens_out, object_score_logits

def _get_stability_scores(self, mask_logits):
Expand Down Expand Up @@ -255,16 +300,3 @@ def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores):
return mask_logits_out, iou_scores_out


'''
Explanation and Adaptations:
TensorFlow Transformer: The transformer argument in the constructor is now expected to be a TensorFlow/Keras implementation of the transformer (refer to modeling/sam/transformer.py for this implementation).
MLP Class: The MLP class is used from sam2_utils.py (You'll need to implement this separately).
Upsampling: Uses layers.Conv2DTranspose for upsampling operations.
Keras Activation: Utilizes layers.Activation for applying activation functions.
TensorFlow Equivalents: Employs TensorFlow equivalents like tf.concat, tf.tile, tf.expand_dims, tf.repeat, tf.reshape, tf.transpose, tf.math.reduce_sum, tf.cast, tf.where, tf.gather_nd, tf.broadcast_to, etc.
Training Argument: The call and predict_masks methods include the training argument for controlling behaviors during training.
Challenges:
Transformer Implementation: You will need to have a working TensorFlow/Keras implementation of the TwoWayTransformer in the modeling/sam/transformer.py file.
Complex Number Operations: The RoPE attention in the transformer might require special handling for complex numbers in TensorFlow.'''

0 comments on commit faea0eb

Please sign in to comment.