Skip to content

Commit

Permalink
vit_small_ds to backend-agnostic in Keras 3
Browse files Browse the repository at this point in the history
  • Loading branch information
SuryanarayanaY committed Dec 15, 2023
1 parent e60722d commit a751ebe
Showing 1 changed file with 11 additions and 18 deletions.
29 changes: 11 additions & 18 deletions examples/vision/vit_small_ds.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,9 @@
"""
## Setup
"""
import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import math
import numpy as np
import tensorflow as tf
import keras
import matplotlib.pyplot as plt
from keras import layers
Expand All @@ -55,11 +51,6 @@
SEED = 42
keras.utils.set_random_seed(SEED)

# TF imports required for this tutorial
from tensorflow import image as tf_image
from tensorflow import range as tf_range
from tensorflow import constant as tf_constant


"""
## Prepare the data
Expand Down Expand Up @@ -199,13 +190,15 @@ def crop_shift_pad(self, images, mode):
shift_width = self.half_patch

# Crop the shifted images and pad them
crop = tf_image.crop_to_bounding_box(
images,
offset_height=crop_height,
offset_width=crop_width,
target_height=self.image_size - self.half_patch,
target_width=self.image_size - self.half_patch,
)
target_height = self.image_size - self.half_patch
target_width = self.image_size - self.half_patch
crop = images[
:,
crop_height : crop_height + target_height,
crop_width : crop_width + target_width,
:,
]

shift_pad = ops.image.pad_images(
crop,
top_padding=shift_height,
Expand Down Expand Up @@ -311,7 +304,7 @@ def __init__(
self.position_embedding = layers.Embedding(
input_dim=num_patches, output_dim=projection_dim
)
self.positions = tf_range(start=0, limit=self.num_patches, delta=1)
self.positions = ops.arange(start=0, stop=self.num_patches, step=1)

def call(self, encoded_patches):
encoded_positions = self.position_embedding(self.positions)
Expand Down Expand Up @@ -460,7 +453,7 @@ def __init__(
self.total_steps = total_steps
self.warmup_learning_rate = warmup_learning_rate
self.warmup_steps = warmup_steps
self.pi = tf_constant(np.pi)
self.pi = ops.array(np.pi)

def __call__(self, step):
if self.total_steps < self.warmup_steps:
Expand Down

0 comments on commit a751ebe

Please sign in to comment.