-- Cumsum.lua
require 'nn'

local Cumsum, parent = torch.class('nn.Cumsum', 'nn.Module')
--[[
The input is always b x 2 x h x w. The first channel is cumulatively
summed along the y axis (height) and the second along the x axis (width).
]]
function Cumsum:__init()
   parent.__init(self)
end
function Cumsum:updateOutput(input)
   -- split into two b x 1 x h x w slices, one per channel
   local yxz = input:split(1, 2)
   self.output:resizeAs(input)
   for i = 1, input:size(2) do
      -- squeeze only the channel dimension (squeeze() without an argument
      -- would also drop a batch dimension of size 1), then cumsum channel 1
      -- along dim 2 (y/height) and channel 2 along dim 3 (x/width)
      self.output[{{}, i}]:cumsum(yxz[i]:squeeze(2), i + 1)
   end
   return self.output
end
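--[[
A minimal forward sketch (the shapes follow the docstring above; the
tensor values are illustrative only):

   local m = nn.Cumsum()
   local input = torch.ones(1, 2, 3, 3)   -- b x 2 x h x w
   local output = m:forward(input)
   -- output[{1, 1}] is summed down the y axis (height):
   --   {{1, 1, 1}, {2, 2, 2}, {3, 3, 3}}
   -- output[{1, 2}] is summed along the x axis (width):
   --   {{1, 2, 3}, {1, 2, 3}, {1, 2, 3}}
]]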
function Cumsum:updateGradInput(input, gradOutput)
   if not gradOutput:isContiguous() then
      self._gradOutput = self._gradOutput or gradOutput.new()
      self._gradOutput:resizeAs(gradOutput):copy(gradOutput)
      gradOutput = self._gradOutput
   end
   -- spatial sizes: dim[1] = h, dim[2] = w
   local dim = {}
   for i = 1, input:nDimension() - 2 do
      dim[i] = input:size(i + 2)
   end
   local yxGrad = gradOutput:split(1, 2)
   self.gradInput:resizeAs(input)
   for i = 1, input:size(2) do
      -- the gradient of cumsum is a reverse cumulative sum along the summed
      -- dimension, implemented here as flip, cumsum, flip back
      local grad = yxGrad[i]:squeeze(2)
      local flipIdx = torch.linspace(dim[i], 1, dim[i]):long()
      local gradFlip = grad:index(i + 1, flipIdx)
      local gradFlipCum = gradFlip:cumsum(i + 1)
      local gradCum = gradFlipCum:index(i + 1, flipIdx)
      self.gradInput[{{}, i}]:copy(gradCum)
   end
   return self.gradInput
end
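--[[
Why flip-cumsum-flip: the forward pass computes y[k] = sum_{j <= k} x[j],
so each x[j] contributes to every y[k] with k >= j, and therefore
gradInput[j] = sum_{k >= j} gradOutput[k], a reverse cumulative sum.
For example, a gradOutput row {g1, g2, g3} flips to {g3, g2, g1},
cumsums to {g3, g3 + g2, g3 + g2 + g1}, and flips back to
{g1 + g2 + g3, g2 + g3, g3}, which is the correct gradInput row.
]]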
function Cumsum:clearState()
   nn.utils.clear(self, '_gradOutput')
   return parent.clearState(self)
end
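--[[
A quick numerical gradient check, assuming the nn test utilities are
available (nn.Jacobian ships with the nn package):

   local m = nn.Cumsum()
   local input = torch.rand(2, 2, 4, 5)   -- b x 2 x h x w
   local err = nn.Jacobian.testJacobian(m, input)
   print(err)   -- should be close to zero
]]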