Skip to content

Commit

Permalink
Start implementing interpolate
Browse files Browse the repository at this point in the history
  • Loading branch information
manu12121999 authored Dec 17, 2024
1 parent 8ab8ad2 commit 474c942
Showing 1 changed file with 67 additions and 3 deletions.
70 changes: 67 additions & 3 deletions ctrl_c_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,9 @@ def f_cat(iterable: list, dim: int) -> list:
class Tensor:
# Wrapper to use linalg operations on lists (of lists) (e.g. matmuls) in a nicer way

def __init__(self, elems: (list, int)):
def __init__(self, elems: (list, int, float)):
if not isinstance(elems, (list, int, float)): # e.g. numpy array or torch tensor
elems = elems.tolist()
self.elems = elems

# calculate number of dimensions
Expand Down Expand Up @@ -829,6 +831,19 @@ def im2col(x_pad):
print("Conv2d took in total", time.time() - start_time, " of which Matmul took", end_mat - start_mat)
return res

class Conv2dTranspose(Module):
    """Placeholder for a 2D transposed-convolution layer.

    The constructor records the layer hyperparameters and allocates
    zero-filled ``weight`` and ``bias`` tensors. ``forward`` is not
    implemented yet and always raises ``NotImplementedError``.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int, stride=1, padding=0, bias=True):
        super().__init__()
        # Plain hyperparameters, stored as given.
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding

        # NOTE(review): the weight is laid out as (out, in, k, k). PyTorch's
        # ConvTranspose2d stores (in, out, k, k) -- confirm which layout the
        # eventual forward() and the checkpoint loader expect.
        weight_shape = (out_channels, in_channels, kernel_size, kernel_size)
        self.weight = Tensor.fill(shape=weight_shape, number=0.0)

        # NOTE(review): both branches yield zero (float vs. int), so a bias
        # tensor is always allocated, even when ``bias`` is False.
        bias_fill = 0.0 if bias else 0
        self.bias = Tensor.fill(shape=(out_channels, ), number=bias_fill)

    def forward(self, x: Tensor):
        """Not implemented yet; always raises ``NotImplementedError``."""
        raise NotImplementedError

class BatchNorm2d(Module):
def __init__(self, num_features, eps=1e-05, *args, **kwargs):
self.weight = Tensor.random_float((num_features,))
Expand Down Expand Up @@ -936,6 +951,55 @@ def backward(self, dout: Tensor):

# missing: weight_init, ConvTranspose2d, AvgPool2d, Softmax, InstanceNorm2d, LayerNorm2d, Losses


class F:
    """Namespace for stateless tensor functions."""

    @staticmethod
    def interpolate(input, size=None, scale_factor=None, mode='nearest'):
        """Resize the two trailing (spatial) dimensions of a 4D tensor.

        Args:
            input: tensor whose ``shape`` unpacks to ``(B, C, H, W)``.
            size: target spatial size, either an int (used for both height
                and width) or a ``(new_H, new_W)`` pair. Takes precedence
                over ``scale_factor`` when both are given.
            scale_factor: multiplier for the spatial size, either a single
                number or a tuple/list whose last two entries are the
                height and width factors. Only used when ``size`` is None.
            mode: ``'nearest'`` or ``'bilinear'``.

        Returns:
            A new tensor of shape ``(B, C, new_H, new_W)``.

        Raises:
            ValueError: if neither ``size`` nor ``scale_factor`` is given,
                or if a scale factor is not positive.
        """
        B, C, H, W = input.shape
        assert mode in ['nearest', 'bilinear']

        if size is not None:
            if isinstance(size, int):
                size = (size, size)
            new_H, new_W = size
            # Step in source pixels per output pixel; the guard keeps an
            # empty output (size 0) from dividing by zero.
            step_h = H / new_H if new_H else 0.0
            step_w = W / new_W if new_W else 0.0
        elif scale_factor is not None:
            if isinstance(scale_factor, (tuple, list)):
                scale_factor_h, scale_factor_w = scale_factor[-2:]
            else:  # scale_factor is a single number
                scale_factor_h = scale_factor_w = scale_factor
            if scale_factor_h <= 0 or scale_factor_w <= 0:
                raise ValueError("scale_factor must be positive")
            new_H = int(math.floor(H * scale_factor_h))
            new_W = int(math.floor(W * scale_factor_w))
            step_h = 1.0 / scale_factor_h
            step_w = 1.0 / scale_factor_w
        else:
            raise ValueError("either size or scale_factor must be given")

        def source_taps(out_len, in_len, step):
            # For every output index return (low, high, frac): the two
            # neighbouring source indices and the fractional distance from
            # ``low``. The integer and fractional parts come from the same
            # product, so they cannot disagree. Float floor-division is
            # avoided on purpose: ``1 // 0.1`` is 9.0 although ``1 / 0.1``
            # is 10.0, which used to select the wrong source pixel.
            last = in_len - 1
            taps = []
            for dst in range(out_len):
                src = dst * step
                floor_src = math.floor(src)
                low = min(int(floor_src), last)  # clamp against rounding overshoot
                high = min(low + 1, last)
                taps.append((low, high, src - floor_src))
            return taps

        # Index maths depends on one axis only, so it is computed once per
        # row / column instead of once per output pixel.
        rows = source_taps(new_H, H, step_h)
        cols = source_taps(new_W, W, step_w)

        output_tensor = Tensor.fill((B, C, new_H, new_W), 0.0)
        if mode == 'nearest':
            for new_h, (old_h, _, _) in enumerate(rows):
                for new_w, (old_w, _, _) in enumerate(cols):
                    output_tensor[:, :, new_h, new_w] = input[:, :, old_h, old_w]
        elif mode == 'bilinear':
            # NOTE(review): samples at ``dst * step`` without a half-pixel
            # offset; this presumably differs from PyTorch's
            # ``align_corners=False`` convention -- confirm intended behaviour.
            for new_h, (old_h_low, old_h_high, frac_h) in enumerate(rows):
                for new_w, (old_w_low, old_w_high, frac_w) in enumerate(cols):
                    p_top_left = (1 - frac_h) * (1 - frac_w) * (input[:, :, old_h_low, old_w_low])
                    p_top_right = (1 - frac_h) * frac_w * (input[:, :, old_h_low, old_w_high])
                    p_bottom_left = frac_h * (1 - frac_w) * (input[:, :, old_h_high, old_w_low])
                    p_bottom_right = frac_h * frac_w * (input[:, :, old_h_high, old_w_high])
                    output_tensor[:, :, new_h, new_w] = p_top_left + p_top_right + p_bottom_left + p_bottom_right
        return output_tensor


class PthUnpickler(pickle.Unpickler):
def __init__(self, picklefile, zipfile, name):
self.zipfile = zipfile
Expand Down Expand Up @@ -1081,7 +1145,7 @@ def png_decompress(data_bytes, width, height, n_channels):
return lines

@staticmethod
def read_png(path, resize=None, dimorder="HWC", num_channels=3, to_float=True, mean=(0.0, 0.0, 0.0), std=(1.0, 1.0, 1.0)):
def read_png(path, resize=None, dimorder="HWC", num_channels=3, to_float=False, mean=(0.0, 0.0, 0.0), std=(1.0, 1.0, 1.0)):
with open(path, "br") as f:
image_bytes = f.read()
data = b''
Expand All @@ -1107,7 +1171,7 @@ def read_png(path, resize=None, dimorder="HWC", num_channels=3, to_float=True, m

if chunk_type == b'IEND':
start = None
lines = ImageIO.png_decompress(data, width, height, color_type_bytes)
lines = ImageIO.png_decompress(data, width, height, color_type_bytes)
t = Tensor(lines)

if t.shape[2] > num_channels:
Expand Down

0 comments on commit 474c942

Please sign in to comment.