From 00580a63aa4d0f796bd9b11c7514fce7c09dcd94 Mon Sep 17 00:00:00 2001
From: Jan Buethe <jbuethe@amazon.de>
Date: Fri, 22 Sep 2023 11:39:22 +0200
Subject: [PATCH] osce/lavoce: fix pulse alpha shape and feature_transform padding

---
 dnn/torch/osce/models/lavoce.py     | 5 +++--
 dnn/torch/osce/models/lavoce_400.py | 2 +-
 2 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/dnn/torch/osce/models/lavoce.py b/dnn/torch/osce/models/lavoce.py
index 47b6d1e9a..fcfdc8bfa 100644
--- a/dnn/torch/osce/models/lavoce.py
+++ b/dnn/torch/osce/models/lavoce.py
@@ -89,6 +89,7 @@ class LaVoce(nn.Module):
         self.kernel_size            = kernel_size
         self.preemph                = preemph
         self.pulses                 = pulses
+        self.ftrans_k               = ftrans_k
 
         assert self.FEATURE_FRAME_SIZE % self.FRAME_SIZE == 0
         self.upsamp_factor =  self.FEATURE_FRAME_SIZE // self.FRAME_SIZE
@@ -145,7 +146,7 @@ class LaVoce(nn.Module):
             f = (2.0 * torch.pi / periods[:, sframe]).unsqueeze(-1)
 
             if self.pulses:
-                alpha = torch.cos(f)
+                alpha = torch.cos(f).view(batch_size, 1, 1)
                 chunk_sin = torch.sin(f  * progression + phase0).view(batch_size, 1, self.FRAME_SIZE)
                 pulse_a = torch.relu(chunk_sin - alpha) / (1 - alpha)
                 pulse_b = torch.relu(-chunk_sin - alpha) / (1 - alpha)
@@ -186,7 +187,7 @@ class LaVoce(nn.Module):
 
     def feature_transform(self, f, layer):
         f = f.permute(0, 2, 1)
-        f = F.pad(f, [1, 0])
+        f = F.pad(f, [self.ftrans_k - 1, 0])
         f = torch.tanh(layer(f))
         return f.permute(0, 2, 1)
 
diff --git a/dnn/torch/osce/models/lavoce_400.py b/dnn/torch/osce/models/lavoce_400.py
index e9a543cff..fe8263beb 100644
--- a/dnn/torch/osce/models/lavoce_400.py
+++ b/dnn/torch/osce/models/lavoce_400.py
@@ -130,7 +130,7 @@ class LaVoce400(nn.Module):
             f = (2.0 * torch.pi / periods[:, sframe]).unsqueeze(-1)
 
             if self.pulses:
-                alpha = torch.cos(f)
+                alpha = torch.cos(f).view(batch_size, 1, 1)
                 chunk_sin = torch.sin(f  * progression + phase0).view(batch_size, 1, self.FRAME_SIZE)
                 pulse_a = torch.relu(chunk_sin - alpha) / (1 - alpha)
                 pulse_b = torch.relu(-chunk_sin - alpha) / (1 - alpha)
-- 
GitLab