From faea0eb0a606ec29230f68c8c18b8fec0c222159 Mon Sep 17 00:00:00 2001
From: Pegasi Assocation <128257630+BlackBoyZeus@users.noreply.github.com>
Date: Wed, 4 Sep 2024 12:50:00 -0700
Subject: [PATCH] Update mask_decoder.py

---
 samkeras (1)/Sam/mask_decoder.py | 150 +++++++++++++++++++------------
 1 file changed, 91 insertions(+), 59 deletions(-)

diff --git a/samkeras (1)/Sam/mask_decoder.py b/samkeras (1)/Sam/mask_decoder.py
index 25ef2252..ea84a3e2 100644
--- a/samkeras (1)/Sam/mask_decoder.py
+++ b/samkeras (1)/Sam/mask_decoder.py
@@ -1,29 +1,44 @@
 # /modeling/sam/mask_decoder.py
+# sam2_tfkeras/modeling/sam/mask_decoder.py
 
 from typing import List, Optional, Tuple, Type
 
 import tensorflow as tf
 from tensorflow.keras import layers
 
-from sam2.modeling.sam2_utils import LayerNorm2d, MLP
+from sam2_tfkeras.modeling.sam2_utils import LayerNorm2d, MLP
+from sam2_tfkeras.modeling.sam.transformer import TwoWayTransformer
+from ncps.tf import CfC
 
 
 class MaskDecoder(layers.Layer):
+    """
+    Predicts masks given an image and prompt embeddings, using a
+    transformer architecture.
+
+    This implementation incorporates a CfC (Closed-form Continuous-time) layer
+    to enhance the model's ability to capture temporal dynamics.
+    """
+
     def __init__(
         self,
-        transformer_dim: int,
-        transformer: layers.Layer,  # TensorFlow transformer layer
-        num_multimask_outputs: int = 3,
-        activation: Type[layers.Layer] = layers.Activation('gelu'),  # Keras Activation
-        iou_head_depth: int = 3,
-        iou_head_hidden_dim: int = 256,
-        use_high_res_features: bool = False,
-        iou_prediction_use_sigmoid=False,
-        dynamic_multimask_via_stability=False,
-        dynamic_multimask_stability_delta=0.05,
-        dynamic_multimask_stability_thresh=0.98,
-        pred_obj_scores: bool = False,
-        pred_obj_scores_mlp: bool = False,
-        use_multimask_token_for_obj_ptr: bool = False,
+        *,
+        transformer_dim: int,  # Channel dimension of the transformer
+        transformer: layers.Layer,  # TensorFlow transformer layer
+        num_multimask_outputs: int = 3,  # Number of masks to predict for disambiguation
+        activation: layers.Layer = layers.Activation('gelu'),  # Activation layer instance (not a class)
+        iou_head_depth: int = 3,  # Depth of the MLP for mask quality prediction
+        iou_head_hidden_dim: int = 256,  # Hidden dimension of the MLP for mask quality
+        use_high_res_features: bool = False,  # Whether to use high-res features
+        iou_prediction_use_sigmoid: bool = False,  # Whether to use sigmoid for IoU prediction
+        dynamic_multimask_via_stability: bool = False,  # Dynamic multimask selection
+        dynamic_multimask_stability_delta: float = 0.05,  # Delta for stability score
+        dynamic_multimask_stability_thresh: float = 0.98,  # Threshold for stability score
+        pred_obj_scores: bool = False,  # Whether to predict object scores
+        pred_obj_scores_mlp: bool = False,  # Whether to use MLP for object score prediction
+        use_multimask_token_for_obj_ptr: bool = False,  # Use multimask token for object pointer
+        cfc_units: int = 128,  # Number of units in the CfC layer
+        mixed_memory: bool = True,  # Use mixed memory in CfC
+        cfc_mode: str = "default",  # Mode for the CfC layer
     ) -> None:
         super(MaskDecoder, self).__init__()
         self.transformer_dim = transformer_dim
@@ -31,16 +46,18 @@ def __init__(
 
         self.num_multimask_outputs = num_multimask_outputs
 
+        # Embeddings for IoU and mask tokens
         self.iou_token = layers.Embedding(1, transformer_dim)
         self.num_mask_tokens = num_multimask_outputs + 1
         self.mask_tokens = layers.Embedding(self.num_mask_tokens, transformer_dim)
 
+        # Object score prediction
         self.pred_obj_scores = pred_obj_scores
         if self.pred_obj_scores:
             self.obj_score_token = layers.Embedding(1, transformer_dim)
         self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr
 
-        # Upsampling
+        # --- Upsampling Layers ---
         self.output_upscaling = tf.keras.Sequential([
             layers.Conv2DTranspose(
                 filters=transformer_dim // 4,
@@ -49,26 +66,29 @@ def __init__(
                 padding='same'
             ),
             LayerNorm2d(transformer_dim // 4),
-            activation,
+            activation,
             layers.Conv2DTranspose(
                 filters=transformer_dim // 8,
                 kernel_size=2,
                 strides=2,
                 padding='same'
             ),
-            activation,
+            activation,
         ])
 
+        # High-resolution features
         self.use_high_res_features = use_high_res_features
         if use_high_res_features:
             self.conv_s0 = layers.Conv2D(transformer_dim // 8, kernel_size=1, strides=1, padding='same')
             self.conv_s1 = layers.Conv2D(transformer_dim // 4, kernel_size=1, strides=1, padding='same')
 
+        # Output hypernetworks
         self.output_hypernetworks_mlps = [
             MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
             for _ in range(self.num_mask_tokens)
         ]
 
+        # IoU prediction head
         self.iou_prediction_head = MLP(
             transformer_dim,
             iou_head_hidden_dim,
@@ -77,29 +97,42 @@ def __init__(
             sigmoid_output=iou_prediction_use_sigmoid,
         )
 
+        # Object score prediction head
         if self.pred_obj_scores:
             self.pred_obj_score_head = layers.Dense(1)
             if pred_obj_scores_mlp:
                 self.pred_obj_score_head = MLP(transformer_dim, transformer_dim, 1, 3)
 
+        # Dynamic multimask selection parameters
         self.dynamic_multimask_via_stability = dynamic_multimask_via_stability
         self.dynamic_multimask_stability_delta = dynamic_multimask_stability_delta
         self.dynamic_multimask_stability_thresh = dynamic_multimask_stability_thresh
 
+        # --- CfC Layer Initialization ---
+        # return_sequences=True keeps the per-token outputs (batch, tokens, units);
+        # the Keras RNN default would collapse the token axis to the last step only,
+        # breaking the token indexing in predict_masks.
+        self.cfc_layer = CfC(units=cfc_units, mixed_memory=mixed_memory, mode=cfc_mode,
+                             return_sequences=True)
+        # Project CfC outputs back to transformer_dim so the token indexing and
+        # hypernetwork MLPs downstream see the channel size they expect.
+        self.cfc_projection = layers.Dense(transformer_dim)
+
     def call(
         self,
-        image_embeddings: tf.Tensor,
-        image_pe: tf.Tensor,
-        sparse_prompt_embeddings: tf.Tensor,
-        dense_prompt_embeddings: tf.Tensor,
-        multimask_output: bool,
-        repeat_image: bool,
-        high_res_features: Optional[List[tf.Tensor]] = None,
+        image_embeddings: tf.Tensor,  # Image embeddings from the encoder
+        image_pe: tf.Tensor,  # Positional encodings for the image
+        sparse_prompt_embeddings: tf.Tensor,  # Sparse prompt embeddings
+        dense_prompt_embeddings: tf.Tensor,  # Dense prompt embeddings
+        multimask_output: bool,  # Whether to output multiple masks
+        repeat_image: bool,  # Whether to repeat the image embeddings
+        high_res_features: Optional[List[tf.Tensor]] = None,  # High-resolution features (optional)
         training=False,
     ) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor]:
         """
         Predict masks given image and prompt embeddings.
+
+        Returns:
+            masks: Predicted masks
+            iou_pred: Predicted IoU scores for the masks
+            sam_tokens_out: SAM tokens for the mask output
+            object_score_logits: Logits for object scores (if enabled)
         """
+
+        # Predict masks using helper function
         masks, iou_pred, mask_tokens_out, object_score_logits = self.predict_masks(
             image_embeddings=image_embeddings,
             image_pe=image_pe,
@@ -110,31 +143,34 @@ def call(
             training=training
         )
 
+        # Select the appropriate masks based on output settings
         if multimask_output:
             masks = masks[:, 1:, :, :]
             iou_pred = iou_pred[:, 1:]
-        elif self.dynamic_multimask_via_stability and not training:
+        elif self.dynamic_multimask_via_stability and not training:
             masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred)
         else:
             masks = masks[:, 0:1, :, :]
             iou_pred = iou_pred[:, 0:1]
 
+        # Select SAM output tokens based on configuration
         if multimask_output and self.use_multimask_token_for_obj_ptr:
-            sam_tokens_out = mask_tokens_out[:, 1:]  # [b, 3, c] shape
+            sam_tokens_out = mask_tokens_out[:, 1:]  # [b, 3, c]
         else:
-            sam_tokens_out = mask_tokens_out[:, 0:1]  # [b, 1, c] shape
+            sam_tokens_out = mask_tokens_out[:, 0:1]  # [b, 1, c]
 
+        # Return the predicted masks, IoU scores, SAM tokens, and object score logits
         return masks, iou_pred, sam_tokens_out, object_score_logits
 
     def predict_masks(
         self,
-        image_embeddings: tf.Tensor,
-        image_pe: tf.Tensor,
-        sparse_prompt_embeddings: tf.Tensor,
-        dense_prompt_embeddings: tf.Tensor,
-        repeat_image: bool,
-        high_res_features: Optional[List[tf.Tensor]] = None,
-        training=False,
-    ) -> Tuple[tf.Tensor, tf.Tensor]:
+        image_embeddings: tf.Tensor,
+        image_pe: tf.Tensor,
+        sparse_prompt_embeddings: tf.Tensor,
+        dense_prompt_embeddings: tf.Tensor,
+        repeat_image: bool,
+        high_res_features: Optional[List[tf.Tensor]] = None,
+        training=False,
+    ) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor]:
         """Predicts masks.
         See 'call' for more details."""
 
         # Concatenate output tokens
@@ -156,7 +192,7 @@ def predict_masks(
         output_tokens = tf.tile(tf.expand_dims(output_tokens, axis=0), [tf.shape(sparse_prompt_embeddings)[0], 1, 1])
         tokens = tf.concat((output_tokens, sparse_prompt_embeddings), axis=1)
 
-        # Expand per-image data in batch direction to be per-mask
+        # Expand image embeddings if needed
         if repeat_image:
             src = tf.repeat(image_embeddings, repeats=tf.shape(tokens)[0], axis=0)
         else:
@@ -164,44 +200,53 @@ def predict_masks(
             src = image_embeddings
         src = src + dense_prompt_embeddings
 
-        tf.debugging.assert_equal(tf.shape(image_pe)[0], 1, message="image_pe should have size 1 in batch dim (from `get_dense_pe()`)")
+        # Repeat image positional encodings
+        tf.debugging.assert_equal(tf.shape(image_pe)[0], 1, message="image_pe should have size 1 in batch dim")
         pos_src = tf.repeat(image_pe, repeats=tf.shape(tokens)[0], axis=0)
-        b, c, h, w = tf.shape(src)
 
-        # Run the transformer
+        # Get shape of source tensor
+        b, c, h, w = tf.unstack(tf.shape(src))
+
+        # --- Run the Transformer ---
         hs, src = self.transformer(src, pos_src, tokens, training=training)
+
+        # --- Apply CfC Layer (the token axis is treated as the sequence axis) ---
+        hs = self.cfc_layer(hs, training=training)
+        hs = self.cfc_projection(hs)  # Restore the transformer_dim channel size
+
+        # Extract IoU token and mask tokens
         iou_token_out = hs[:, s, :]
         mask_tokens_out = hs[:, s + 1 : (s + 1 + self.num_mask_tokens), :]
 
-        # Upscale mask embeddings and predict masks using the mask tokens
-        src = tf.reshape(tf.transpose(src, perm=[0, 2, 1]), (b, c, h, w))
+        # Upscale mask embeddings
+        src = tf.reshape(tf.transpose(src, perm=[0, 2, 1]), (b, c, h, w))
         if not self.use_high_res_features:
             upscaled_embedding = self.output_upscaling(src)
         else:
-            # Assuming self.output_upscaling is a Sequential model
+            # Apply upscaling with high-resolution features
             dc1, ln1, act1, dc2, act2 = self.output_upscaling.layers
             feat_s0, feat_s1 = high_res_features
             upscaled_embedding = act1(ln1(dc1(src) + feat_s1))
             upscaled_embedding = act2(dc2(upscaled_embedding) + feat_s0)
 
+        # Predict masks using hypernetworks
         hyper_in_list = []
         for i in range(self.num_mask_tokens):
-            hyper_in_list.append(
-                self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])
-            )
+            hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
         hyper_in = tf.stack(hyper_in_list, axis=1)
         b, c, h, w = tf.unstack(tf.shape(upscaled_embedding))
         masks = tf.reshape(tf.linalg.matmul(hyper_in, tf.reshape(upscaled_embedding, (b, c, h * w))), (b, -1, h, w))
 
         # Generate mask quality predictions
         iou_pred = self.iou_prediction_head(iou_token_out)
+
+        # Generate object score logits (if enabled)
         if self.pred_obj_scores:
             assert s == 1
             object_score_logits = self.pred_obj_score_head(hs[:, 0, :])
         else:
-            # Obj scores logits - default to 10.0, i.e. assuming the object is present, sigmoid(10)=1
+            # Default object score logit of 10.0, i.e. assume the object is present: sigmoid(10) ~= 1
             object_score_logits = 10.0 * tf.ones_like(iou_pred[:, :1])
 
+        # Return predicted masks, IoU predictions, mask tokens, and object score logits
         return masks, iou_pred, mask_tokens_out, object_score_logits
 
     def _get_stability_scores(self, mask_logits):
@@ -255,16 +300,3 @@ def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores):
 
         return mask_logits_out, iou_scores_out
-
-    '''
-    Explanation and Adaptations:
-
-TensorFlow Transformer: The transformer argument in the constructor is now expected to be a TensorFlow/Keras implementation of the transformer (refer to modeling/sam/transformer.py for this implementation).
-MLP Class: The MLP class is used from sam2_utils.py (You'll need to implement this separately).
-Upsampling: Uses layers.Conv2DTranspose for upsampling operations.
-Keras Activation: Utilizes layers.Activation for applying activation functions.
-TensorFlow Equivalents: Employs TensorFlow equivalents like tf.concat, tf.tile, tf.expand_dims, tf.repeat, tf.reshape, tf.transpose, tf.math.reduce_sum, tf.cast, tf.where, tf.gather_nd, tf.broadcast_to, etc.
-Training Argument: The call and predict_masks methods include the training argument for controlling behaviors during training.
-Challenges:
-
-Transformer Implementation: You will need to have a working TensorFlow/Keras implementation of the TwoWayTransformer in the modeling/sam/transformer.py file.
-Complex Number Operations: The RoPE attention in the transformer might require special handling for complex numbers in TensorFlow.'''
\ No newline at end of file
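
Reviewer note (not part of the patch): the shape behavior of the new CfC layer is the one subtle point in this change. ncps's CfC is a Keras RNN, so by default it returns only the final step of the input sequence; the token indexing in predict_masks needs per-token outputs, which is why the hunks above pass return_sequences=True and project back to transformer_dim. A minimal sketch of that behavior, assuming ncps is installed and standard Keras RNN semantics:

    import tensorflow as tf
    from ncps.tf import CfC

    # Token sequence shaped like the transformer output hs: (batch, tokens, channels).
    hs = tf.random.normal((2, 7, 256))

    # Default: the RNN collapses the token axis and returns only the last step.
    last_step = CfC(units=128, mixed_memory=True)(hs)
    print(last_step.shape)   # (2, 128)

    # return_sequences=True keeps the per-token outputs, (batch, tokens, units),
    # which the slicing hs[:, s, :] and hs[:, s + 1 : ...] in predict_masks relies on.
    per_token = CfC(units=128, mixed_memory=True, return_sequences=True)(hs)
    print(per_token.shape)   # (2, 7, 128)

Since the constructor now begins with a bare `*`, all arguments are keyword-only. A hypothetical construction call follows; the TwoWayTransformer arguments mirror the SAM reference configuration and are assumptions about the port, not confirmed by this patch:

    from sam2_tfkeras.modeling.sam.transformer import TwoWayTransformer
    from sam2_tfkeras.modeling.sam.mask_decoder import MaskDecoder

    decoder = MaskDecoder(
        transformer_dim=256,
        transformer=TwoWayTransformer(depth=2, embedding_dim=256, mlp_dim=2048, num_heads=8),
        num_multimask_outputs=3,
        cfc_units=128,  # projected back to transformer_dim inside the decoder
    )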