From f9aee675dcf60970dc0cdfc99cc7ac3b79b54e38 Mon Sep 17 00:00:00 2001
From: Jan Buethe <jbuethe@amazon.de>
Date: Sat, 22 Jul 2023 13:31:22 -0700
Subject: [PATCH] added ShapeNet and ShapeUp48 models

---
 dnn/torch/osce/models/shape_net.py            | 154 ++++++++++++++++++
 dnn/torch/osce/models/shape_up_48.py          | 150 +++++++++++++++++
 dnn/torch/osce/utils/layers/silk_upsampler.py | 138 ++++++++++++++++
 dnn/torch/osce/utils/layers/td_shaper.py      | 129 +++++++++++++++
 4 files changed, 571 insertions(+)
 create mode 100644 dnn/torch/osce/models/shape_net.py
 create mode 100644 dnn/torch/osce/models/shape_up_48.py
 create mode 100644 dnn/torch/osce/utils/layers/silk_upsampler.py
 create mode 100644 dnn/torch/osce/utils/layers/td_shaper.py

diff --git a/dnn/torch/osce/models/shape_net.py b/dnn/torch/osce/models/shape_net.py
new file mode 100644
index 000000000..ebb0a678b
--- /dev/null
+++ b/dnn/torch/osce/models/shape_net.py
@@ -0,0 +1,154 @@
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+import numpy as np
+
+from utils.layers.limited_adaptive_comb1d import LimitedAdaptiveComb1d
+from utils.layers.limited_adaptive_conv1d import LimitedAdaptiveConv1d
+from utils.layers.td_shaper import TDShaper
+from utils.complexity import _conv1d_flop_count
+
+from models.nns_base import NNSBase
+from models.silk_feature_net_pl import SilkFeatureNetPL
+from models.silk_feature_net import SilkFeatureNet
+from .scale_embedding import ScaleEmbedding
+
+class ShapeNet(NNSBase):
+    """ Adaptive Noise Re-Shaping """
+    FRAME_SIZE=80
+
+    def __init__(self,
+                 num_features=47,
+                 pitch_embedding_dim=64,
+                 cond_dim=256,
+                 pitch_max=257,
+                 kernel_size=15,
+                 preemph=0.85,
+                 skip=91,
+                 comb_gain_limit_db=-6,
+                 global_gain_limits_db=[-6, 6],
+                 conv_gain_limits_db=[-6, 6],
+                 numbits_range=[50, 650],
+                 numbits_embedding_dim=8,
+                 hidden_feature_dim=64,
+                 partial_lookahead=True,
+                 norm_p=2,
+                 avg_pool_k=4):
+
+        super().__init__(skip=skip, preemph=preemph)
+
+
+        self.num_features           = num_features
+        self.cond_dim               = cond_dim
+        self.pitch_max              = pitch_max
+        self.pitch_embedding_dim    = pitch_embedding_dim
+        self.kernel_size            = kernel_size
+        self.preemph                = preemph
+        self.skip                   = skip
+        self.numbits_range          = numbits_range
+        self.numbits_embedding_dim  = numbits_embedding_dim
+        self.hidden_feature_dim     = hidden_feature_dim
+        self.partial_lookahead      = partial_lookahead
+
+        # pitch embedding
+        self.pitch_embedding = nn.Embedding(pitch_max + 1, pitch_embedding_dim)
+
+        # numbits embedding
+        self.numbits_embedding = ScaleEmbedding(numbits_embedding_dim, *numbits_range, logscale=True)
+
+        # feature net
+        if partial_lookahead:
+            self.feature_net = SilkFeatureNetPL(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim, hidden_feature_dim)
+        else:
+            self.feature_net = SilkFeatureNet(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim)
+
+        # comb filters
+        left_pad = self.kernel_size // 2
+        right_pad = self.kernel_size - 1 - left_pad
+        self.cf1 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, use_bias=False, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p)
+        self.cf2 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, use_bias=False, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p)
+
+        # spectral shaping
+        self.af1 = LimitedAdaptiveConv1d(1, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
+
+        # non-linear transforms
+        self.tdshape1 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k)
+        self.tdshape2 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k)
+        self.tdshape3 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k)
+
+        # combinators
+        self.af2 = LimitedAdaptiveConv1d(2, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
+        self.af3 = LimitedAdaptiveConv1d(2, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
+        self.af4 = LimitedAdaptiveConv1d(2, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
+
+        # feature transforms
+        self.post_cf1 = nn.Conv1d(cond_dim, cond_dim, 2)
+        self.post_cf2 = nn.Conv1d(cond_dim, cond_dim, 2)
+        self.post_af1 = nn.Conv1d(cond_dim, cond_dim, 2)
+        self.post_af2 = nn.Conv1d(cond_dim, cond_dim, 2)
+        self.post_af3 = nn.Conv1d(cond_dim, cond_dim, 2)
+
+
+    def flop_count(self, rate=16000, verbose=False):
+
+        frame_rate = rate / self.FRAME_SIZE
+
+        # feature net
+        feature_net_flops = self.feature_net.flop_count(frame_rate)
+        comb_flops = self.cf1.flop_count(rate) + self.cf2.flop_count(rate)
+        af_flops = self.af1.flop_count(rate) + self.af2.flop_count(rate) + self.af3.flop_count(rate) + self.af4.flop_count(rate)
+        feature_flops = (_conv1d_flop_count(self.post_cf1, frame_rate) + _conv1d_flop_count(self.post_cf2, frame_rate)
+                         + _conv1d_flop_count(self.post_af1, frame_rate) + _conv1d_flop_count(self.post_af2, frame_rate) + _conv1d_flop_count(self.post_af3, frame_rate))
+
+        if verbose:
+            print(f"feature net: {feature_net_flops / 1e6} MFLOPS")
+            print(f"comb filters: {comb_flops / 1e6} MFLOPS")
+            print(f"adaptive conv: {af_flops / 1e6} MFLOPS")
+            print(f"feature transforms: {feature_flops / 1e6} MFLOPS")
+
+        return feature_net_flops + comb_flops + af_flops + feature_flops
+
+    def feature_transform(self, f, layer):
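+        # causal frame-wise transform: left-pad by one frame so each output
+        # frame depends only on the current and previous conditioning frame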
+        f = f.permute(0, 2, 1)
+        f = F.pad(f, [1, 0])
+        f = torch.tanh(layer(f))
+        return f.permute(0, 2, 1)
+
+    def forward(self, x, features, periods, numbits, debug=False):
+
+        periods         = periods.squeeze(-1)
+        pitch_embedding = self.pitch_embedding(periods)
+        numbits_embedding = self.numbits_embedding(numbits).flatten(2)
+
+        full_features = torch.cat((features, pitch_embedding, numbits_embedding), dim=-1)
+        cf = self.feature_net(full_features)
+
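+        # comb filtering: attenuate coding noise between pitch harmonics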
+        y = self.cf1(x, cf, periods, debug=debug)
+        cf = self.feature_transform(cf, self.post_cf1)
+
+        y = self.cf2(y, cf, periods, debug=debug)
+        cf = self.feature_transform(cf, self.post_cf2)
+
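+        # adaptive spectral shaping: af1 fans the signal out into two branches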
+        y = self.af1(y, cf, debug=debug)
+        cf = self.feature_transform(cf, self.post_af1)
+
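+        # three rounds of time-domain shaping on the second branch, each
+        # followed by an adaptive re-mix (af2, af3, af4)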
+        y1 = y[:, 0:1, :]
+        y2 = self.tdshape1(y[:, 1:2, :], cf)
+        y = torch.cat((y1, y2), dim=1)
+        y = self.af2(y, cf, debug=debug)
+        cf = self.feature_transform(cf, self.post_af2)
+
+        y1 = y[:, 0:1, :]
+        y2 = self.tdshape2(y[:, 1:2, :], cf)
+        y = torch.cat((y1, y2), dim=1)
+        y = self.af3(y, cf, debug=debug)
+        cf = self.feature_transform(cf, self.post_af3)
+
+        y1 = y[:, 0:1, :]
+        y2 = self.tdshape3(y[:, 1:2, :], cf)
+        y = torch.cat((y1, y2), dim=1)
+        y = self.af4(y, cf, debug=debug)
+
+        return y
\ No newline at end of file
diff --git a/dnn/torch/osce/models/shape_up_48.py b/dnn/torch/osce/models/shape_up_48.py
new file mode 100644
index 000000000..0e11b5800
--- /dev/null
+++ b/dnn/torch/osce/models/shape_up_48.py
@@ -0,0 +1,150 @@
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+import numpy as np
+
+from utils.layers.silk_upsampler import SilkUpsampler
+from utils.layers.limited_adaptive_conv1d import LimitedAdaptiveConv1d
+from utils.layers.td_shaper import TDShaper
+from utils.layers.deemph import Deemph
+from utils.misc import freeze_model
+
+from models.nns_base import NNSBase
+from models.silk_feature_net_pl import SilkFeatureNetPL
+from models.silk_feature_net import SilkFeatureNet
+from .scale_embedding import ScaleEmbedding
+
+
+class ShapeUp48(NNSBase):
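+    """ Bandwidth extension of enhanced SILK output from 16 kHz to target_fs.
+
+    Minimal usage sketch (tensor shapes are assumptions mirroring ShapeNet;
+    an optional pre-trained enhancement model can be passed as prenet):
+
+        model    = ShapeUp48()
+        x        = torch.zeros((1, 1, 16000))                  # decoded speech at 16 kHz
+        features = torch.zeros((1, 200, 47))
+        periods  = torch.zeros((1, 200, 1), dtype=torch.long)
+        numbits  = torch.zeros((1, 200, 2))
+        y48      = model(x, features, periods, numbits)        # (1, 1, 48000) at 48 kHz
+    """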
+    FRAME_SIZE16k=80
+
+    def __init__(self,
+                 num_features=47,
+                 pitch_embedding_dim=64,
+                 cond_dim=256,
+                 pitch_max=257,
+                 kernel_size=15,
+                 preemph=0.85,
+                 skip=288,
+                 conv_gain_limits_db=[-6, 6],
+                 numbits_range=[50, 650],
+                 numbits_embedding_dim=8,
+                 hidden_feature_dim=64,
+                 partial_lookahead=True,
+                 norm_p=2,
+                 target_fs=48000,
+                 noise_amplitude=0,
+                 prenet=None,
+                 avg_pool_k=4):
+
+        super().__init__(skip=skip, preemph=preemph)
+
+
+        self.num_features           = num_features
+        self.cond_dim               = cond_dim
+        self.pitch_max              = pitch_max
+        self.pitch_embedding_dim    = pitch_embedding_dim
+        self.kernel_size            = kernel_size
+        self.preemph                = preemph
+        self.skip                   = skip
+        self.numbits_range          = numbits_range
+        self.numbits_embedding_dim  = numbits_embedding_dim
+        self.hidden_feature_dim     = hidden_feature_dim
+        self.partial_lookahead      = partial_lookahead
+        self.frame_size48           = int(self.FRAME_SIZE16k * target_fs / 16000 + .1)
+        self.frame_size32           = self.FRAME_SIZE16k * 2
+        self.noise_amplitude        = noise_amplitude
+        self.prenet                 = prenet
+
+        # freeze prenet if given
+        if prenet is not None:
+            freeze_model(self.prenet)
+            try:
+                self.deemph = Deemph(prenet.preemph)
+            except AttributeError:
+                print("[warning] prenet model has no preemph attribute, assuming preemph=0")
+                self.deemph = Deemph(0)
+
+        # upsampler
+        self.upsampler = SilkUpsampler()
+
+        # pitch embedding
+        self.pitch_embedding = nn.Embedding(pitch_max + 1, pitch_embedding_dim)
+
+        # numbits embedding
+        self.numbits_embedding = ScaleEmbedding(numbits_embedding_dim, *numbits_range, logscale=True)
+
+        # feature net
+        if partial_lookahead:
+            self.feature_net = SilkFeatureNetPL(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim, hidden_feature_dim)
+        else:
+            self.feature_net = SilkFeatureNet(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim)
+
+        # non-linear transforms
+        self.tdshape1 = TDShaper(cond_dim, frame_size=self.frame_size32, avg_pool_k=avg_pool_k)
+        self.tdshape2 = TDShaper(cond_dim, frame_size=self.frame_size48, avg_pool_k=avg_pool_k)
+
+        # spectral shaping
+        self.af_noise = LimitedAdaptiveConv1d(1, 1, self.kernel_size, cond_dim, frame_size=self.frame_size32, overlap_size=self.frame_size32//2, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=[-30, 0], norm_p=norm_p)
+        self.af1 = LimitedAdaptiveConv1d(1, 2, self.kernel_size, cond_dim, frame_size=self.frame_size32, overlap_size=self.frame_size32//2, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
+        self.af2 = LimitedAdaptiveConv1d(3, 2, self.kernel_size, cond_dim, frame_size=self.frame_size32, overlap_size=self.frame_size32//2, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
+        self.af3 = LimitedAdaptiveConv1d(2, 1, self.kernel_size, cond_dim, frame_size=self.frame_size48, overlap_size=self.frame_size48//2, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
+
+
+    def flop_count(self, rate=16000, verbose=False):
+
+        frame_rate = rate / self.FRAME_SIZE16k
+
+        # feature net
+        feature_net_flops = self.feature_net.flop_count(frame_rate)
+        # af1 and af2 run on the 2x-upsampled signal, af3 on the 3x signal;
+        # af_noise and the TDShaper stages are not included in this count
+        af_flops = self.af1.flop_count(2 * rate) + self.af2.flop_count(2 * rate) + self.af3.flop_count(3 * rate)
+
+        if verbose:
+            print(f"feature net: {feature_net_flops / 1e6} MFLOPS")
+            print(f"adaptive conv: {af_flops / 1e6} MFLOPS")
+
+        return feature_net_flops + af_flops
+
+    def forward(self, x, features, periods, numbits, debug=False):
+
+        if self.prenet is not None:
+            with torch.no_grad():
+                x = self.prenet(x, features, periods, numbits)
+                x = self.deemph(x)
+
+        periods         = periods.squeeze(-1)
+        pitch_embedding = self.pitch_embedding(periods)
+        numbits_embedding = self.numbits_embedding(numbits).flatten(2)
+
+        full_features = torch.cat((features, pitch_embedding, numbits_embedding), dim=-1)
+        cf = self.feature_net(full_features)
+
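+        # 16 kHz -> 32 kHz via FIR approximation of SILK's HQ 2x upsampler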
+        y32 = self.upsampler.hq_2x_up(x)
+
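+        # conditioned noise branch for filling in the extension band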
+        noise = self.noise_amplitude * torch.randn_like(y32)
+        noise = self.af_noise(noise, cf)
+
+        y32 = self.af1(y32, cf, debug=debug)
+
+        y32_1 = y32[:, 0:1, :]
+        y32_2 = self.tdshape1(y32[:, 1:2, :], cf)
+        y32 = torch.cat((y32_1, y32_2, noise), dim=1)
+
+        y32 = self.af2(y32, cf, debug=debug)
+
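+        # 32 kHz -> 48 kHz via fractional 3/2 interpolation, then final shaping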
+        y48 = self.upsampler.interpolate_3_2(y32)
+
+        y48_1 = y48[:, 0:1, :]
+        y48_2 = self.tdshape2(y48[:, 1:2, :], cf)
+        y48 = torch.cat((y48_1, y48_2), dim=1)
+
+        y48 = self.af3(y48, cf, debug=debug)
+
+        return y48
diff --git a/dnn/torch/osce/utils/layers/silk_upsampler.py b/dnn/torch/osce/utils/layers/silk_upsampler.py
new file mode 100644
index 000000000..d5f396ed2
--- /dev/null
+++ b/dnn/torch/osce/utils/layers/silk_upsampler.py
@@ -0,0 +1,138 @@
+""" This module implements the SILK upsampler from 16kHz to 24 or 48 kHz """
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+import numpy as np
+
+frac_fir = np.array(
+    [
+        [189, -600, 617, 30567, 2996, -1375, 425, -46],
+        [117, -159, -1070, 29704, 5784, -2143, 611, -71],
+        [52, 221, -2392, 28276, 8798, -2865, 773, -91],
+        [-4, 529, -3350, 26341, 11950, -3487, 896, -103],
+        [-48, 758, -3956, 23973, 15143, -3957, 967, -107],
+        [-80, 905, -4235, 21254, 18278, -4222, 972, -99],
+        [-99, 972, -4222, 18278, 21254, -4235, 905, -80],
+        [-107, 967, -3957, 15143, 23973, -3956, 758, -48],
+        [-103, 896, -3487, 11950, 26341, -3350, 529, -4],
+        [-91, 773, -2865, 8798, 28276, -2392, 221, 52],
+        [-71, 611, -2143, 5784, 29704, -1070, -159, 117],
+        [-46, 425, -1375, 2996, 30567, 617, -600, 189]
+    ],
+    dtype=np.float32
+) / 2**15
+
+
+hq_2x_up_c_even = [x / 2**16 for x in [1746, 14986, 39083 - 65536]]
+hq_2x_up_c_odd  = [x / 2**16 for x in [6854, 25769, 55542 - 65536]]
+
+
+def get_impz(coeffs, n):
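+    # impulse response of the cascade of three first-order allpass sections
+    # used by the SILK 2x upsampler; truncated to n taps, it yields the FIR
+    # approximation used below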
+    s = 3*[0]
+    y = np.zeros(n)
+    x = 1
+
+    for i in range(n):
+        Y = x - s[0]
+        X = Y * coeffs[0]
+        tmp1 = s[0] + X
+        s[0] = x + X
+
+        Y = tmp1 - s[1]
+        X = Y * coeffs[1]
+        tmp2 = s[1] + X
+        s[1] = tmp1 + X
+
+        Y = tmp2 - s[2]
+        X = Y * (1 + coeffs[2])
+        tmp3 = s[2] + X
+        s[2] = tmp2 + X
+
+        y[i] = tmp3
+        x = 0
+
+    return y
+
+
+
+class SilkUpsampler(nn.Module):
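+    """ FIR implementation of the SILK 16 kHz -> 24/48 kHz upsampler.
+
+    Minimal usage sketch (input layout assumed to be (batch, channels, samples)):
+
+        up = SilkUpsampler(fs_in=16000, fs_out=48000)
+        y  = up(torch.zeros((1, 1, 16000)))   # -> (1, 1, 48000)
+    """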
+    SUPPORTED_TARGET_RATES = {24000, 48000}
+    SUPPORTED_SOURCE_RATES = {16000}
+    def __init__(self,
+                 fs_in=16000,
+                 fs_out=48000):
+
+        super().__init__()
+        self.fs_in = fs_in
+        self.fs_out = fs_out
+
+        if fs_in not in self.SUPPORTED_SOURCE_RATES:
+            raise ValueError(f'SilkUpsampler currently only supports upsampling from {self.SUPPORTED_SOURCE_RATES} Hz')
+
+
+        if fs_out not in self.SUPPORTED_TARGET_RATES:
+            raise ValueError(f'SilkUpsampler currently only supports upsampling to {self.SUPPORTED_TARGET_RATES} Hz')
+
+
+        # hq 2x upsampler as FIR approximation
+        hq_2x_up_even = get_impz(hq_2x_up_c_even, 128)[::-1].copy()
+        hq_2x_up_odd  = get_impz(hq_2x_up_c_odd , 128)[::-1].copy()
+
+        self.hq_2x_up_even = nn.Parameter(torch.from_numpy(hq_2x_up_even).float().view(1, 1, -1), requires_grad=False)
+        self.hq_2x_up_odd  = nn.Parameter(torch.from_numpy(hq_2x_up_odd ).float().view(1, 1, -1), requires_grad=False)
+        self.hq_2x_up_padding = [127, 0]
+
+        # interpolation filters
+        frac_01_24 = frac_fir[0]
+        frac_17_24 = frac_fir[8]
+        frac_09_24 = frac_fir[4]
+
+        self.frac_01_24 = nn.Parameter(torch.from_numpy(frac_01_24).view(1, 1, -1), requires_grad=False)
+        self.frac_17_24 = nn.Parameter(torch.from_numpy(frac_17_24).view(1, 1, -1), requires_grad=False)
+        self.frac_09_24 = nn.Parameter(torch.from_numpy(frac_09_24).view(1, 1, -1), requires_grad=False)
+
+        self.stride = 1 if fs_out == 48000 else 2
+
+    def hq_2x_up(self, x):
+
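+        # even- and odd-phase FIRs run as grouped convolutions (one group per
+        # channel); interleaving their outputs doubles the sample rate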
+        num_channels = x.size(1)
+
+        weight_even = torch.repeat_interleave(self.hq_2x_up_even, num_channels, 0)
+        weight_odd  = torch.repeat_interleave(self.hq_2x_up_odd , num_channels, 0)
+
+        x_pad  = F.pad(x, self.hq_2x_up_padding)
+        y_even = F.conv1d(x_pad, weight_even, groups=num_channels)
+        y_odd  = F.conv1d(x_pad, weight_odd , groups=num_channels)
+
+        y = torch.cat((y_even.unsqueeze(-1), y_odd.unsqueeze(-1)), dim=-1).flatten(2)
+
+        return y
+
+    def interpolate_3_2(self, x):
+
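+        # 3/2 resampling: three fractional-delay phases sampled at stride 2
+        # yield three output samples per two input samples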
+        num_channels = x.size(1)
+
+        weight_01_24 = torch.repeat_interleave(self.frac_01_24, num_channels, 0)
+        weight_17_24 = torch.repeat_interleave(self.frac_17_24, num_channels, 0)
+        weight_09_24 = torch.repeat_interleave(self.frac_09_24, num_channels, 0)
+
+        x_pad = F.pad(x, [8, 0])
+        y_01_24     = F.conv1d(x_pad, weight_01_24, stride=2, groups=num_channels)
+        y_17_24     = F.conv1d(x_pad, weight_17_24, stride=2, groups=num_channels)
+        y_09_24_sh1 = F.conv1d(torch.roll(x_pad, -1, -1), weight_09_24, stride=2, groups=num_channels)
+
+
+        y = torch.cat(
+            (y_01_24.unsqueeze(-1), y_17_24.unsqueeze(-1), y_09_24_sh1.unsqueeze(-1)),
+            dim=-1).flatten(2)
+
+        return y[..., :-3]
+
+    def forward(self, x):
+
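+        # 16 kHz -> 32 kHz -> 48 kHz; a stride of 2 in the final slice keeps
+        # every other sample for 24 kHz output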
+        y_2x = self.hq_2x_up(x)
+        y_3x = self.interpolate_3_2(y_2x)
+
+        return y_3x[:, :, ::self.stride]
diff --git a/dnn/torch/osce/utils/layers/td_shaper.py b/dnn/torch/osce/utils/layers/td_shaper.py
new file mode 100644
index 000000000..2ab12bad6
--- /dev/null
+++ b/dnn/torch/osce/utils/layers/td_shaper.py
@@ -0,0 +1,129 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from utils.complexity import _conv1d_flop_count
+
+class TDShaper(nn.Module):
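+    """ Time-domain envelope shaper.
+
+    Predicts one positive gain per output sample from conditioning features
+    and the input's own temporal envelope. Minimal usage sketch (shapes are
+    assumptions taken from the forward signature):
+
+        shaper = TDShaper(feature_dim=256, frame_size=160)
+        x = torch.zeros((1, 1, 1600))   # 10 frames of 160 samples
+        f = torch.zeros((1, 10, 256))   # one feature vector per frame
+        y = shaper(x, f)                # same shape as x
+    """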
+    COUNTER = 1
+
+    def __init__(self,
+                 feature_dim,
+                 frame_size=160,
+                 avg_pool_k=4,
+                 innovate=False
+    ):
+        """
+
+        Parameters:
+        -----------
+
+
+        feature_dim : int
+            dimension of input features
+
+        frame_size : int
+            frame size
+
+        avg_pool_k : int, optional
+            kernel size and stride for avg pooling
+
+        padding : List[int, int]
+
+        """
+
+        super().__init__()
+
+
+        self.feature_dim    = feature_dim
+        self.frame_size     = frame_size
+        self.avg_pool_k     = avg_pool_k
+        self.innovate       = innovate
+
+        assert frame_size % avg_pool_k == 0
+        self.env_dim = frame_size // avg_pool_k + 1
+
+        # feature transform
+        self.feature_alpha1 = nn.Conv1d(self.feature_dim + self.env_dim, frame_size, 2)
+        self.feature_alpha2 = nn.Conv1d(frame_size, frame_size, 2)
+
+        if self.innovate:
+            self.feature_alpha1b = nn.Conv1d(self.feature_dim + self.env_dim, frame_size, 2)
+            self.feature_alpha1c = nn.Conv1d(self.feature_dim + self.env_dim, frame_size, 2)
+
+            self.feature_alpha2b = nn.Conv1d(frame_size, frame_size, 2)
+            self.feature_alpha2c = nn.Conv1d(frame_size, frame_size, 2)
+
+
+    def flop_count(self, rate):
+
+        frame_rate = rate / self.frame_size
+
+        shape_flops = sum([_conv1d_flop_count(x, frame_rate) for x in (self.feature_alpha1, self.feature_alpha2)]) + 11 * frame_rate * self.frame_size
+
+        if self.innovate:
+            inno_flops = sum([_conv1d_flop_count(x, frame_rate) for x in (self.feature_alpha1b, self.feature_alpha2b, self.feature_alpha1c, self.feature_alpha2c)]) + 22 * frame_rate * self.frame_size
+        else:
+            inno_flops = 0
+
+        return shape_flops + inno_flops
+
+    def envelope_transform(self, x):
+
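+        # rectified, average-pooled envelope in the log domain (with a small
+        # floor to avoid log(0)), split into per-frame mean and deviation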
+        x = torch.abs(x)
+        x = F.avg_pool1d(x, self.avg_pool_k, self.avg_pool_k)
+        x = torch.log(x + .5**16)
+
+        x = x.reshape(x.size(0), -1, self.env_dim - 1)
+        avg_x = torch.mean(x, -1, keepdim=True)
+
+        x = torch.cat((x - avg_x, avg_x), dim=-1)
+
+        return x
+
+    def forward(self, x, features, debug=False):
+        """ innovate signal parts with temporal shaping
+
+
+        Parameters:
+        -----------
+        x : torch.tensor
+            input signal of shape (batch_size, 1, num_samples)
+
+        features : torch.tensor
+            frame-wise features of shape (batch_size, num_frames, feature_dim)
+
+        """
+
+        batch_size = x.size(0)
+        num_frames = features.size(1)
+        num_samples = x.size(2)
+        frame_size = self.frame_size
+
+        # generate temporal envelope
+        tenv = self.envelope_transform(x)
+
+        # feature path
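+        # two causal conv layers map features + envelope to one gain per
+        # sample; the exp activation keeps the gains positive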
+        f = torch.cat((features, tenv), dim=-1)
+        f = F.pad(f.permute(0, 2, 1), [1, 0])
+        alpha = F.leaky_relu(self.feature_alpha1(f), 0.2)
+        alpha = torch.exp(self.feature_alpha2(F.pad(alpha, [1, 0])))
+        alpha = alpha.permute(0, 2, 1)
+
+        if self.innovate:
+            inno_alpha = F.leaky_relu(self.feature_alpha1b(f), 0.2)
+            inno_alpha = torch.exp(self.feature_alpha2b(F.pad(inno_alpha, [1, 0])))
+            inno_alpha = inno_alpha.permute(0, 2, 1)
+
+            inno_x = F.leaky_relu(self.feature_alpha1c(f), 0.2)
+            inno_x = torch.tanh(self.feature_alpha2c(F.pad(inno_x, [1, 0])))
+            inno_x = inno_x.permute(0, 2, 1)
+
+        # signal path
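+        # apply the predicted per-sample gains frame by frame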
+        y = x.reshape(batch_size, num_frames, -1)
+        y = alpha * y
+
+        if self.innovate:
+            y = y + inno_alpha * inno_x
+
+        return y.reshape(batch_size, 1, num_samples)
-- 
GitLab