
Ignoring nn.Module arguments
manu12121999 committed Dec 20, 2024
1 parent 474c942 commit 407c012
Showing 1 changed file with 30 additions and 10 deletions.
40 changes: 30 additions & 10 deletions ctrl_c_nn.py
@@ -642,6 +642,8 @@ class nn:
     class Module:

         def __init__(self, *args, **kwargs):
+            if args != () or kwargs != {}:
+                print('Warning, ignoring', args, kwargs)
             self.cache = None

         def __call__(self, x: Tensor):
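
With this change, a Module constructed with positional or keyword arguments it does not recognize warns and continues instead of raising a TypeError. A minimal sketch of the new behavior (the `affine=True` kwarg is just a stand-in for any argument the constructor does not accept):

    m = nn.Module(affine=True)
    # prints: Warning, ignoring () {'affine': True}
    # construction still succeeds; m.cache is None
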
@@ -694,6 +696,7 @@ def forward(self, x: Tensor):
             start = time.time()
             self.cache = x
             res = x.matmul_T_2d(self.weight) + self.bias
+            print("Linear took ", time.time() - start)
             return res

         def backward(self, dout: Tensor):
@@ -766,14 +769,16 @@ def backward(self, dout: Tensor):
             return dout

     class Conv2d(Module):
-        def __init__(self, in_channels: int, out_channels: int, kernel_size: int, stride=1, padding=0, bias=True):
+        def __init__(self, in_channels: int, out_channels: int, kernel_size: int, stride=1, padding=0, bias=True, *args, **kwargs):
+            if args != () or kwargs != {}:
+                print('Warning, Conv2d ignoring', args, kwargs)
             super().__init__()
             self.stride = stride
             self.padding = padding
             self.kernel_size = kernel_size
             self.out_channels = out_channels
             self.weight = Tensor.fill(shape=(out_channels, in_channels, kernel_size, kernel_size), number=0.0)
-            self.bias = Tensor.fill(shape=(out_channels, ), number=0.0 if bias else 0)
+            self.bias = Tensor.fill(shape=(out_channels, ), number=0.0 if bias else 0.0)

         def forward(self, x: Tensor):
             return self.forward_gemm(x)
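
Conv2d now accepts and discards extra constructor arguments the same way, presumably so that layer definitions written against the richer PyTorch signature can still be instantiated. A hedged usage sketch (`dilation` here is an example of a kwarg this implementation does not support, not a real parameter of this class):

    conv = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, dilation=2)
    # prints: Warning, Conv2d ignoring () {'dilation': 2}
    # the layer is built exactly as if dilation had never been passed
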
@@ -832,7 +837,9 @@ def im2col(x_pad):
             return res

     class Conv2dTranspose(Module):
-        def __init__(self, in_channels: int, out_channels: int, kernel_size: int, stride=1, padding=0, bias=True):
+        def __init__(self, in_channels: int, out_channels: int, kernel_size: int, stride=1, padding=0, bias=True, *args, **kwargs):
+            if args != () or kwargs != {}:
+                print('Warning, Conv2dTranspose ignoring', args, kwargs)
             super().__init__()
             self.stride = stride
             self.padding = padding
@@ -846,14 +853,16 @@ def forward(self, x: Tensor):

     class BatchNorm2d(Module):
         def __init__(self, num_features, eps=1e-05, *args, **kwargs):
-            self.weight = Tensor.random_float((num_features,))
-            self.bias = Tensor.random_float((num_features,))
-            self.running_mean = Tensor.zeros((num_features,))
-            self.running_var = Tensor.ones((num_features,))
-            self.num_batches_tracked = Tensor([0])
+            if args != () or kwargs != {}:
+                print('Warning, BatchNorm2d ignoring', args, kwargs)
+            self.weight = Tensor.fill((num_features,), 0.0)
+            self.bias = Tensor.fill((num_features,), 0.0)
+            self.running_mean = Tensor.fill((num_features,), 0.0)
+            self.running_var = Tensor.fill((num_features,), 1.0)
+            self.num_batches_tracked = Tensor([0.0])
             self.C = num_features
             self.eps = eps
-            super().__init__(*args, **kwargs)
+            super().__init__()

         def forward(self, x: Tensor):
             start_time = time.time()
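
The forward body is collapsed in this diff. For reference, inference-mode batch normalization applies the running statistics per channel; a scalar sketch of the standard formula follows (the textbook computation, not necessarily this file's exact collapsed code):

    def bn_scalar(x, w, b, mean, var, eps=1e-05):
        # y = weight * (x - running_mean) / sqrt(running_var + eps) + bias
        return w * (x - mean) / (var + eps) ** 0.5 + b

With the buffers initialized as above (running_mean=0.0, running_var=1.0), a freshly constructed layer normalizes with identity statistics until trained or loaded values replace them.
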
@@ -869,7 +878,9 @@ def forward(self, x: Tensor):
             return y

     class MaxPool2d(Module):
-        def __init__(self, kernel_size=2, stride=2, padding=0):
+        def __init__(self, kernel_size=2, stride=2, padding=0, *args, **kwargs):
+            if args != () or kwargs != {}:
+                print('Warning, MaxPool2d ignoring', args, kwargs)
             self.kernel_size = kernel_size
             self.stride = stride
             self.padding = padding
@@ -919,6 +930,16 @@ class Dropout(Module):
         def forward(self, x: Tensor):
             return x

+    class Upsample(Module):
+        def __init__(self, scale_factor, *args, **kwargs):
+            super().__init__()
+            if args != () or kwargs != {}:
+                print('Warning, Upsample ignoring', args, kwargs)
+            self.scale_factor = scale_factor
+
+        def forward(self, x: Tensor):
+            return F.interpolate(x, scale_factor=self.scale_factor)
+
     class AbstractLoss:
         def __init__(self):
             self.cache = None
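
The new Upsample module simply forwards to F.interpolate. A usage sketch, assuming NCHW tensors so that scale_factor=2 doubles the two spatial dimensions:

    up = nn.Upsample(scale_factor=2)
    y = up(x)  # x: Tensor of shape (N, C, H, W) -> y: (N, C, 2*H, 2*W)
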
