-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathSampler.lua
More file actions
27 lines (22 loc) · 812 Bytes
/
Sampler.lua
File metadata and controls
27 lines (22 loc) · 812 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
-- Based on JoinTable module
require 'nn'
--- nn.Sampler: scales freshly drawn standard-normal noise by the input.
--
-- updateOutput draws `eps` of size N x dim with torch.randn (N = input:size(1))
-- and returns eps .* input, where input is expanded to N x dim. Because of that
-- expansion, input is expected to be N x 1 (one scale per row) or already
-- N x dim (one scale per output element).
--
-- The noise is kept in self.eps and reused by updateGradInput, so a backward
-- call refers to the noise drawn by the most recent forward call.
local Sampler, parent = torch.class('nn.Sampler', 'nn.Module')

--- Create a sampler.
-- @tparam number dim number of noise columns drawn per input row
function Sampler:__init(dim)
   parent.__init(self)
   self.dim = dim
end

--- Forward pass: output = eps .* expand(input, N, dim), with eps ~ N(0, 1).
-- @param input 2-D tensor, N x 1 or N x dim
-- @return self.output, an N x dim tensor of the same type as input
function Sampler:updateOutput(input)
   local n = input:size(1)
   -- Fresh noise every call, cast to the input's tensor type.
   self.eps = torch.randn(n, self.dim):type(input:type())
   -- The fallback must be built from `input`: when self.output is nil,
   -- `self.output.new()` would itself raise an "index a nil value" error.
   self.output = self.output or input.new()
   self.output:resizeAs(self.eps):copy(self.eps)
   self.output:cmul(torch.expand(input, n, self.dim))
   return self.output
end

--- Backward pass: gradient of updateOutput with respect to input.
-- Since d(output[i][j]) / d(input[i]) = eps[i][j]:
--   * N x 1 input:   gradInput[i] = sum_j eps[i][j] * gradOutput[i][j]
--   * N x dim input: gradInput = eps .* gradOutput (elementwise)
-- (Previously this returned sum_j 0.5 * input[i] * eps[i][j] * gradOutput[i][j],
-- which is not the derivative of updateOutput and would not pass a
-- finite-difference gradient check.)
-- @param input the tensor passed to the matching updateOutput call
-- @param gradOutput N x dim gradient with respect to the output
-- @return self.gradInput, shaped like input
function Sampler:updateGradInput(input, gradOutput)
   assert(self.eps, 'nn.Sampler: updateGradInput called before updateOutput')
   self.gradInput = self.gradInput or input.new()
   self.gradInput:resizeAs(input)
   local grad = torch.cmul(self.eps, gradOutput)
   if grad:nElement() == input:nElement() then
      -- input already supplied one scale per output element
      self.gradInput:copy(grad)
   else
      -- input was broadcast across the dim columns: accumulate over them
      self.gradInput:copy(torch.sum(grad, 2))
   end
   return self.gradInput
end