Skip to content
Snippets Groups Projects
Unverified Commit ec04a94e authored by Jan Buethe's avatar Jan Buethe
Browse files

bugfix in SilkFeatureNetPL

parent 5f8201c7
Branches opus-ng-lace-integration5
No related tags found
1 merge request!97OSCE_MAX_RNN_NEURONS fix
Pipeline #5070 passed
......@@ -66,18 +66,17 @@ class SilkFeatureNetPL(nn.Module):
self.conv1 = norm(nn.Conv1d(feature_dim, self.hidden_feature_dim, 1))
self.conv2 = norm(nn.Conv1d(4 * self.hidden_feature_dim, num_channels, 2))
self.tconv = norm(nn.ConvTranspose1d(num_channels, num_channels, 4, 4))
gru_input_dim = num_channels + self.repeat_upsamp_dim if self.repeat_upsamp else num_channels
self.gru = norm(norm(nn.GRU(gru_input_dim, num_channels, batch_first=True), name='weight_hh_l0'), name='weight_ih_l0')
self.gru = norm(norm(nn.GRU(num_channels, num_channels, batch_first=True), name='weight_hh_l0'), name='weight_ih_l0')
if softquant:
self.conv2 = soft_quant(self.conv2)
if not self.repeat_upsamp: self.tconv = soft_quant(self.tconv)
self.tconv = soft_quant(self.tconv)
self.gru = soft_quant(self.gru, names=['weight_hh_l0', 'weight_ih_l0'])
if sparsify:
mark_for_sparsification(self.conv2, (sparsification_density[0], [8, 4]))
if not self.repeat_upsamp: mark_for_sparsification(self.tconv, (sparsification_density[1], [8, 4]))
mark_for_sparsification(self.tconv, (sparsification_density[1], [8, 4]))
mark_for_sparsification(
self.gru,
{
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment