Skip to content
Snippets Groups Projects
Unverified Commit 7df2c67b authored by Jan Buethe's avatar Jan Buethe
Browse files

fixes in osce python code

parent 3499d0aa
No related branches found
No related tags found
No related merge requests found
......@@ -177,10 +177,7 @@ class NoLACE(NNSBase):
def feature_transform(self, f, layer):
f0 = f.permute(0, 2, 1)
f = F.pad(f0, [1, 0])
if self.residual_in_feature_transform:
f = torch.tanh(layer(f) + f0)
else:
f = torch.tanh(layer(f))
f = torch.tanh(layer(f))
return f.permute(0, 2, 1)
def forward(self, x, features, periods, numbits, debug=False):
......
......@@ -92,7 +92,7 @@ class SilkFeatureNetPL(nn.Module):
def flop_count(self, rate=200):
count = 0
for conv in [self.conv1, self.conv2] if self.repeat_upsamp else [self.conv1, self.conv2, self.tconv]:
for conv in self.conv1, self.conv2, self.tconv:
count += _conv1d_flop_count(conv, rate)
count += 2 * (3 * self.gru.input_size * self.gru.hidden_size + 3 * self.gru.hidden_size * self.gru.hidden_size) * rate
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment