From f9aee675dcf60970dc0cdfc99cc7ac3b79b54e38 Mon Sep 17 00:00:00 2001
From: Jan Buethe <jbuethe@amazon.de>
Date: Sat, 22 Jul 2023 13:31:22 -0700
Subject: [PATCH] added ShapeNet and ShapeUp48 models

---
 dnn/torch/osce/models/shape_net.py            | 154 ++++++++++++++++++
 dnn/torch/osce/models/shape_up_48.py          | 150 +++++++++++++++++
 dnn/torch/osce/utils/layers/silk_upsampler.py | 138 ++++++++++++++++
 dnn/torch/osce/utils/layers/td_shaper.py      | 129 +++++++++++++++
 4 files changed, 571 insertions(+)
 create mode 100644 dnn/torch/osce/models/shape_net.py
 create mode 100644 dnn/torch/osce/models/shape_up_48.py
 create mode 100644 dnn/torch/osce/utils/layers/silk_upsampler.py
 create mode 100644 dnn/torch/osce/utils/layers/td_shaper.py

diff --git a/dnn/torch/osce/models/shape_net.py b/dnn/torch/osce/models/shape_net.py
new file mode 100644
index 000000000..ebb0a678b
--- /dev/null
+++ b/dnn/torch/osce/models/shape_net.py
@@ -0,0 +1,154 @@
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+import numpy as np
+
+from utils.layers.limited_adaptive_comb1d import LimitedAdaptiveComb1d
+from utils.layers.limited_adaptive_conv1d import LimitedAdaptiveConv1d
+from utils.layers.td_shaper import TDShaper
+from utils.complexity import _conv1d_flop_count
+
+from models.nns_base import NNSBase
+from models.silk_feature_net_pl import SilkFeatureNetPL
+from models.silk_feature_net import SilkFeatureNet
+from .scale_embedding import ScaleEmbedding
+
+class ShapeNet(NNSBase):
+    """ Adaptive Noise Re-Shaping """
+    FRAME_SIZE=80
+
+    def __init__(self,
+                 num_features=47,
+                 pitch_embedding_dim=64,
+                 cond_dim=256,
+                 pitch_max=257,
+                 kernel_size=15,
+                 preemph=0.85,
+                 skip=91,
+                 comb_gain_limit_db=-6,
+                 global_gain_limits_db=[-6, 6],
+                 conv_gain_limits_db=[-6, 6],
+                 numbits_range=[50, 650],
+                 numbits_embedding_dim=8,
+                 hidden_feature_dim=64,
+                 partial_lookahead=True,
+                 norm_p=2,
+                 avg_pool_k=4):
+
+        super().__init__(skip=skip, preemph=preemph)
+
+
+        self.num_features = num_features
+        self.cond_dim = cond_dim
+        self.pitch_max = pitch_max
+        self.pitch_embedding_dim = pitch_embedding_dim
+        self.kernel_size = kernel_size
+        self.preemph = preemph
+        self.skip = skip
+        self.numbits_range = numbits_range
+        self.numbits_embedding_dim = numbits_embedding_dim
+        self.hidden_feature_dim = hidden_feature_dim
+        self.partial_lookahead = partial_lookahead
+
+        # pitch embedding
+        self.pitch_embedding = nn.Embedding(pitch_max + 1, pitch_embedding_dim)
+
+        # numbits embedding
+        self.numbits_embedding = ScaleEmbedding(numbits_embedding_dim, *numbits_range, logscale=True)
+
+        # feature net
+        if partial_lookahead:
+            self.feature_net = SilkFeatureNetPL(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim, hidden_feature_dim)
+        else:
+            self.feature_net = SilkFeatureNet(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim)
+
+        # comb filters
+        left_pad = self.kernel_size // 2
+        right_pad = self.kernel_size - 1 - left_pad
+        self.cf1 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, use_bias=False, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p)
+        self.cf2 = LimitedAdaptiveComb1d(self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, overlap_size=40, use_bias=False, padding=[left_pad, right_pad], max_lag=pitch_max + 1, gain_limit_db=comb_gain_limit_db, global_gain_limits_db=global_gain_limits_db, norm_p=norm_p)
+
+        # spectral shaping
+        self.af1 = LimitedAdaptiveConv1d(1, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
+
+        # non-linear transforms
+        self.tdshape1 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k)
+        self.tdshape2 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k)
+        self.tdshape3 = TDShaper(cond_dim, frame_size=self.FRAME_SIZE, avg_pool_k=avg_pool_k)
+
+        # combinators
+        self.af2 = LimitedAdaptiveConv1d(2, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
+        self.af3 = LimitedAdaptiveConv1d(2, 2, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
+        self.af4 = LimitedAdaptiveConv1d(2, 1, self.kernel_size, cond_dim, frame_size=self.FRAME_SIZE, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
+
+        # feature transforms
+        self.post_cf1 = nn.Conv1d(cond_dim, cond_dim, 2)
+        self.post_cf2 = nn.Conv1d(cond_dim, cond_dim, 2)
+        self.post_af1 = nn.Conv1d(cond_dim, cond_dim, 2)
+        self.post_af2 = nn.Conv1d(cond_dim, cond_dim, 2)
+        self.post_af3 = nn.Conv1d(cond_dim, cond_dim, 2)
+
+
+    def flop_count(self, rate=16000, verbose=False):
+
+        frame_rate = rate / self.FRAME_SIZE
+
+        # feature net
+        feature_net_flops = self.feature_net.flop_count(frame_rate)
+        comb_flops = self.cf1.flop_count(rate) + self.cf2.flop_count(rate)
+        af_flops = self.af1.flop_count(rate) + self.af2.flop_count(rate) + self.af3.flop_count(rate) + self.af4.flop_count(rate)
+        feature_flops = (_conv1d_flop_count(self.post_cf1, frame_rate) + _conv1d_flop_count(self.post_cf2, frame_rate)
+                         + _conv1d_flop_count(self.post_af1, frame_rate) + _conv1d_flop_count(self.post_af2, frame_rate) + _conv1d_flop_count(self.post_af3, frame_rate))
+
+        if verbose:
+            print(f"feature net: {feature_net_flops / 1e6} MFLOPS")
+            print(f"comb filters: {comb_flops / 1e6} MFLOPS")
+            print(f"adaptive conv: {af_flops / 1e6} MFLOPS")
+            print(f"feature transforms: {feature_flops / 1e6} MFLOPS")
+
+        return feature_net_flops + comb_flops + af_flops + feature_flops
+
+    def feature_transform(self, f, layer):
+        f = f.permute(0, 2, 1)
+        f = F.pad(f, [1, 0])
+        f = torch.tanh(layer(f))
+        return f.permute(0, 2, 1)
+
+    def forward(self, x, features, periods, numbits, debug=False):
+
+        periods = periods.squeeze(-1)
+        pitch_embedding = self.pitch_embedding(periods)
+        numbits_embedding = self.numbits_embedding(numbits).flatten(2)
+
+        full_features = torch.cat((features, pitch_embedding, numbits_embedding), dim=-1)
+        cf = self.feature_net(full_features)
+
+        y = self.cf1(x, cf, periods, debug=debug)
+        cf = self.feature_transform(cf, self.post_cf1)
+
+        y = self.cf2(y, cf, periods, debug=debug)
+        cf = self.feature_transform(cf, self.post_cf2)
+
+        y = self.af1(y, cf, debug=debug)
+        cf = self.feature_transform(cf, self.post_af1)
+
+        y1 = y[:, 0:1, :]
+        y2 = self.tdshape1(y[:, 1:2, :], cf)
+        y = torch.cat((y1, y2), dim=1)
+        y = self.af2(y, cf, debug=debug)
+        cf = self.feature_transform(cf, self.post_af2)
+
+        y1 = y[:, 0:1, :]
+        y2 = self.tdshape2(y[:, 1:2, :], cf)
+        y = torch.cat((y1, y2), dim=1)
+        y = self.af3(y, cf, debug=debug)
+        cf = self.feature_transform(cf, self.post_af3)
+
+        y1 = y[:, 0:1, :]
+        y2 = self.tdshape3(y[:, 1:2, :], cf)
+        y = torch.cat((y1, y2), dim=1)
+        y = self.af4(y, cf, debug=debug)
+
+        return y
\ No newline at end of file
diff --git a/dnn/torch/osce/models/shape_up_48.py b/dnn/torch/osce/models/shape_up_48.py
new file mode 100644
index 000000000..0e11b5800
--- /dev/null
+++ b/dnn/torch/osce/models/shape_up_48.py
@@ -0,0 +1,150 @@
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+import numpy as np
+
+from utils.layers.silk_upsampler import SilkUpsampler
+from utils.layers.limited_adaptive_conv1d import LimitedAdaptiveConv1d
+from utils.layers.td_shaper import TDShaper
+from utils.layers.deemph import Deemph
+from utils.misc import freeze_model
+
+from models.nns_base import NNSBase
+from models.silk_feature_net_pl import SilkFeatureNetPL
+from models.silk_feature_net import SilkFeatureNet
+from .scale_embedding import ScaleEmbedding
+
+
+
+class ShapeUp48(NNSBase):
+    FRAME_SIZE16k=80
+
+    def __init__(self,
+                 num_features=47,
+                 pitch_embedding_dim=64,
+                 cond_dim=256,
+                 pitch_max=257,
+                 kernel_size=15,
+                 preemph=0.85,
+                 skip=288,
+                 conv_gain_limits_db=[-6, 6],
+                 numbits_range=[50, 650],
+                 numbits_embedding_dim=8,
+                 hidden_feature_dim=64,
+                 partial_lookahead=True,
+                 norm_p=2,
+                 target_fs=48000,
+                 noise_amplitude=0,
+                 prenet=None,
+                 avg_pool_k=4):
+
+        super().__init__(skip=skip, preemph=preemph)
+
+
+        self.num_features = num_features
+        self.cond_dim = cond_dim
+        self.pitch_max = pitch_max
+        self.pitch_embedding_dim = pitch_embedding_dim
+        self.kernel_size = kernel_size
+        self.preemph = preemph
+        self.skip = skip
+        self.numbits_range = numbits_range
+        self.numbits_embedding_dim = numbits_embedding_dim
+        self.hidden_feature_dim = hidden_feature_dim
+        self.partial_lookahead = partial_lookahead
+        self.frame_size48 = int(self.FRAME_SIZE16k * target_fs / 16000 + .1)
+        self.frame_size32 = self.FRAME_SIZE16k * 2
+        self.noise_amplitude = noise_amplitude
+        self.prenet = prenet
+
+        # freeze prenet if given
+        if prenet is not None:
+            freeze_model(self.prenet)
+            try:
+                self.deemph = Deemph(prenet.preemph)
+            except AttributeError:
+                print("[warning] prenet model is expected to have a preemph attribute")
+                self.deemph = Deemph(0)
+
+
+
+        # upsampler
+        self.upsampler = SilkUpsampler()
+
+        # pitch embedding
+        self.pitch_embedding = nn.Embedding(pitch_max + 1, pitch_embedding_dim)
+
+        # numbits embedding
+        self.numbits_embedding = ScaleEmbedding(numbits_embedding_dim, *numbits_range, logscale=True)
+
+        # feature net
+        if partial_lookahead:
+            self.feature_net = SilkFeatureNetPL(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim, hidden_feature_dim)
+        else:
+            self.feature_net = SilkFeatureNet(num_features + pitch_embedding_dim + 2 * numbits_embedding_dim, cond_dim)
+
+        # non-linear transforms
+        self.tdshape1 = TDShaper(cond_dim, frame_size=self.frame_size32, avg_pool_k=avg_pool_k)
+        self.tdshape2 = TDShaper(cond_dim, frame_size=self.frame_size48, avg_pool_k=avg_pool_k)
+
+        # spectral shaping
+        self.af_noise = LimitedAdaptiveConv1d(1, 1, self.kernel_size, cond_dim, frame_size=self.frame_size32, overlap_size=self.frame_size32//2, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=[-30, 0], norm_p=norm_p)
+        self.af1 = LimitedAdaptiveConv1d(1, 2, self.kernel_size, cond_dim, frame_size=self.frame_size32, overlap_size=self.frame_size32//2, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
+        self.af2 = LimitedAdaptiveConv1d(3, 2, self.kernel_size, cond_dim, frame_size=self.frame_size32, overlap_size=self.frame_size32//2, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
+        self.af3 = LimitedAdaptiveConv1d(2, 1, self.kernel_size, cond_dim, frame_size=self.frame_size48, overlap_size=self.frame_size48//2, use_bias=False, padding=[self.kernel_size - 1, 0], gain_limits_db=conv_gain_limits_db, norm_p=norm_p)
+
+
+    def flop_count(self, rate=16000, verbose=False):
+
+        frame_rate = rate / self.FRAME_SIZE16k
+
+        # feature net
+        feature_net_flops = self.feature_net.flop_count(frame_rate)
+        af_flops = self.af1.flop_count(rate) + self.af2.flop_count(2 * rate) + self.af3.flop_count(3 * rate)
+
+        if verbose:
+            print(f"feature net: {feature_net_flops / 1e6} MFLOPS")
+            print(f"adaptive conv: {af_flops / 1e6} MFLOPS")
+
+        return feature_net_flops + af_flops
+
+    def forward(self, x, features, periods, numbits, debug=False):
+
+        if self.prenet is not None:
+            with torch.no_grad():
+                x = self.prenet(x, features, periods, numbits)
+                x = self.deemph(x)
+
+
+
+        periods = periods.squeeze(-1)
+        pitch_embedding = self.pitch_embedding(periods)
+        numbits_embedding = self.numbits_embedding(numbits).flatten(2)
+
+        full_features = torch.cat((features, pitch_embedding, numbits_embedding), dim=-1)
+        cf = self.feature_net(full_features)
+
+        y32 = self.upsampler.hq_2x_up(x)
+
+        noise = self.noise_amplitude * torch.randn_like(y32)
+        noise = self.af_noise(noise, cf)
+
+        y32 = self.af1(y32, cf, debug=debug)
+
+        y32_1 = y32[:, 0:1, :]
+        y32_2 = self.tdshape1(y32[:, 1:2, :], cf)
+        y32 = torch.cat((y32_1, y32_2, noise), dim=1)
+
+        y32 = self.af2(y32, cf, debug=debug)
+
+        y48 = self.upsampler.interpolate_3_2(y32)
+
+        y48_1 = y48[:, 0:1, :]
+        y48_2 = self.tdshape2(y48[:, 1:2, :], cf)
+        y48 = torch.cat((y48_1, y48_2), dim=1)
+
+        y48 = self.af3(y48, cf, debug=debug)
+
+        return y48
diff --git a/dnn/torch/osce/utils/layers/silk_upsampler.py b/dnn/torch/osce/utils/layers/silk_upsampler.py
new file mode 100644
index 000000000..d5f396ed2
--- /dev/null
+++ b/dnn/torch/osce/utils/layers/silk_upsampler.py
@@ -0,0 +1,138 @@
+""" This module implements the SILK upsampler from 16kHz to 24 or 48 kHz """
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+import numpy as np
+
+frac_fir = np.array(
+    [
+        [189, -600, 617, 30567, 2996, -1375, 425, -46],
+        [117, -159, -1070, 29704, 5784, -2143, 611, -71],
+        [52, 221, -2392, 28276, 8798, -2865, 773, -91],
+        [-4, 529, -3350, 26341, 11950, -3487, 896, -103],
+        [-48, 758, -3956, 23973, 15143, -3957, 967, -107],
+        [-80, 905, -4235, 21254, 18278, -4222, 972, -99],
+        [-99, 972, -4222, 18278, 21254, -4235, 905, -80],
+        [-107, 967, -3957, 15143, 23973, -3956, 758, -48],
+        [-103, 896, -3487, 11950, 26341, -3350, 529, -4],
+        [-91, 773, -2865, 8798, 28276, -2392, 221, 52],
+        [-71, 611, -2143, 5784, 29704, -1070, -159, 117],
+        [-46, 425, -1375, 2996, 30567, 617, -600, 189]
+    ],
+    dtype=np.float32
+) / 2**15
+
+
+hq_2x_up_c_even = [x / 2**16 for x in [1746, 14986, 39083 - 65536]]
+hq_2x_up_c_odd = [x / 2**16 for x in [6854, 25769, 55542 - 65536]]
+
+
+def get_impz(coeffs, n):
+    s = 3*[0]
+    y = np.zeros(n)
+    x = 1
+
+    for i in range(n):
+        Y = x - s[0]
+        X = Y * coeffs[0]
+        tmp1 = s[0] + X
+        s[0] = x + X
+
+        Y = tmp1 - s[1]
+        X = Y * coeffs[1]
+        tmp2 = s[1] + X
+        s[1] = tmp1 + X
+
+        Y = tmp2 - s[2]
+        X = Y * (1 + coeffs[2])
+        tmp3 = s[2] + X
+        s[2] = tmp2 + X
+
+        y[i] = tmp3
+        x = 0
+
+    return y
+
+
+
+class SilkUpsampler(nn.Module):
+    SUPPORTED_TARGET_RATES = {24000, 48000}
+    SUPPORTED_SOURCE_RATES = {16000}
+    def __init__(self,
+                 fs_in=16000,
+                 fs_out=48000):
+
+        super().__init__()
+        self.fs_in = fs_in
+        self.fs_out = fs_out
+
+        if fs_in not in self.SUPPORTED_SOURCE_RATES:
+            raise ValueError(f'SilkUpsampler currently only supports upsampling from {self.SUPPORTED_SOURCE_RATES} Hz')
+
+
+        if fs_out not in self.SUPPORTED_TARGET_RATES:
+            raise ValueError(f'SilkUpsampler currently only supports upsampling to {self.SUPPORTED_TARGET_RATES} Hz')
+
+
+        # hq 2x upsampler as FIR approximation
+        hq_2x_up_even = get_impz(hq_2x_up_c_even, 128)[::-1].copy()
+        hq_2x_up_odd = get_impz(hq_2x_up_c_odd , 128)[::-1].copy()
+
+        self.hq_2x_up_even = nn.Parameter(torch.from_numpy(hq_2x_up_even).float().view(1, 1, -1), requires_grad=False)
+        self.hq_2x_up_odd = nn.Parameter(torch.from_numpy(hq_2x_up_odd ).float().view(1, 1, -1), requires_grad=False)
+        self.hq_2x_up_padding = [127, 0]
+
+        # interpolation filters
+        frac_01_24 = frac_fir[0]
+        frac_17_24 = frac_fir[8]
+        frac_09_24 = frac_fir[4]
+
+        self.frac_01_24 = nn.Parameter(torch.from_numpy(frac_01_24).view(1, 1, -1), requires_grad=False)
+        self.frac_17_24 = nn.Parameter(torch.from_numpy(frac_17_24).view(1, 1, -1), requires_grad=False)
+        self.frac_09_24 = nn.Parameter(torch.from_numpy(frac_09_24).view(1, 1, -1), requires_grad=False)
+
+        self.stride = 1 if fs_out == 48000 else 2
+
+    def hq_2x_up(self, x):
+
+        num_channels = x.size(1)
+
+        weight_even = torch.repeat_interleave(self.hq_2x_up_even, num_channels, 0)
+        weight_odd = torch.repeat_interleave(self.hq_2x_up_odd , num_channels, 0)
+
+        x_pad = F.pad(x, self.hq_2x_up_padding)
+        y_even = F.conv1d(x_pad, weight_even, groups=num_channels)
+        y_odd = F.conv1d(x_pad, weight_odd , groups=num_channels)
+
+        y = torch.cat((y_even.unsqueeze(-1), y_odd.unsqueeze(-1)), dim=-1).flatten(2)
+
+        return y
+
+    def interpolate_3_2(self, x):
+
+        num_channels = x.size(1)
+
+        weight_01_24 = torch.repeat_interleave(self.frac_01_24, num_channels, 0)
+        weight_17_24 = torch.repeat_interleave(self.frac_17_24, num_channels, 0)
+        weight_09_24 = torch.repeat_interleave(self.frac_09_24, num_channels, 0)
+
+        x_pad = F.pad(x, [8, 0])
+        y_01_24 = F.conv1d(x_pad, weight_01_24, stride=2, groups=num_channels)
+        y_17_24 = F.conv1d(x_pad, weight_17_24, stride=2, groups=num_channels)
+        y_09_24_sh1 = F.conv1d(torch.roll(x_pad, -1, -1), weight_09_24, stride=2, groups=num_channels)
+
+
+        y = torch.cat(
+            (y_01_24.unsqueeze(-1), y_17_24.unsqueeze(-1), y_09_24_sh1.unsqueeze(-1)),
+            dim=-1).flatten(2)
+
+        return y[..., :-3]
+
+    def forward(self, x):
+
+        y_2x = self.hq_2x_up(x)
+        y_3x = self.interpolate_3_2(y_2x)
+
+        return y_3x[:, :, ::self.stride]
diff --git a/dnn/torch/osce/utils/layers/td_shaper.py b/dnn/torch/osce/utils/layers/td_shaper.py
new file mode 100644
index 000000000..2ab12bad6
--- /dev/null
+++ b/dnn/torch/osce/utils/layers/td_shaper.py
@@ -0,0 +1,129 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from utils.complexity import _conv1d_flop_count
+
+class TDShaper(nn.Module):
+    COUNTER = 1
+
+    def __init__(self,
+                 feature_dim,
+                 frame_size=160,
+                 avg_pool_k=4,
+                 innovate=False
+                 ):
+        """
+
+        Parameters:
+        -----------
+
+
+        feature_dim : int
+            dimension of input features
+
+        frame_size : int
+            frame size
+
+        avg_pool_k : int, optional
+            kernel size and stride for avg pooling
+
+        innovate : bool, optional
+            adds a second, non-linear innovation signal path when True
+
+        """
+
+        super().__init__()
+
+
+        self.feature_dim = feature_dim
+        self.frame_size = frame_size
+        self.avg_pool_k = avg_pool_k
+        self.innovate = innovate
+
+        assert frame_size % avg_pool_k == 0
+        self.env_dim = frame_size // avg_pool_k + 1
+
+        # feature transform
+        self.feature_alpha1 = nn.Conv1d(self.feature_dim + self.env_dim, frame_size, 2)
+        self.feature_alpha2 = nn.Conv1d(frame_size, frame_size, 2)
+
+        if self.innovate:
+            self.feature_alpha1b = nn.Conv1d(self.feature_dim + self.env_dim, frame_size, 2)
+            self.feature_alpha1c = nn.Conv1d(self.feature_dim + self.env_dim, frame_size, 2)
+
+            self.feature_alpha2b = nn.Conv1d(frame_size, frame_size, 2)
+            self.feature_alpha2c = nn.Conv1d(frame_size, frame_size, 2)
+
+
+    def flop_count(self, rate):
+
+        frame_rate = rate / self.frame_size
+
+        shape_flops = sum([_conv1d_flop_count(x, frame_rate) for x in (self.feature_alpha1, self.feature_alpha2)]) + 11 * frame_rate * self.frame_size
+
+        if self.innovate:
+            inno_flops = sum([_conv1d_flop_count(x, frame_rate) for x in (self.feature_alpha1b, self.feature_alpha2b, self.feature_alpha1c, self.feature_alpha2c)]) + 22 * frame_rate * self.frame_size
+        else:
+            inno_flops = 0
+
+        return shape_flops + inno_flops
+
+    def envelope_transform(self, x):
+
+        x = torch.abs(x)
+        x = F.avg_pool1d(x, self.avg_pool_k, self.avg_pool_k)
+        x = torch.log(x + .5**16)
+
+        x = x.reshape(x.size(0), -1, self.env_dim - 1)
+        avg_x = torch.mean(x, -1, keepdim=True)
+
+        x = torch.cat((x - avg_x, avg_x), dim=-1)
+
+        return x
+
+    def forward(self, x, features, debug=False):
+        """ shape signal x in the time domain (and add an innovation signal if innovate is set)
+
+
+        Parameters:
+        -----------
+        x : torch.tensor
+            input signal of shape (batch_size, 1, num_samples)
+
+        features : torch.tensor
+            frame-wise features of shape (batch_size, num_frames, feature_dim)
+
+        """
+
+        batch_size = x.size(0)
+        num_frames = features.size(1)
+        num_samples = x.size(2)
+        frame_size = self.frame_size
+
+        # generate temporal envelope
+        tenv = self.envelope_transform(x)
+
+        # feature path
+        f = torch.cat((features, tenv), dim=-1)
+        f = F.pad(f.permute(0, 2, 1), [1, 0])
+        alpha = F.leaky_relu(self.feature_alpha1(f), 0.2)
+        alpha = torch.exp(self.feature_alpha2(F.pad(alpha, [1, 0])))
+        alpha = alpha.permute(0, 2, 1)
+
+        if self.innovate:
+            inno_alpha = F.leaky_relu(self.feature_alpha1b(f), 0.2)
+            inno_alpha = torch.exp(self.feature_alpha2b(F.pad(inno_alpha, [1, 0])))
+            inno_alpha = inno_alpha.permute(0, 2, 1)
+
+            inno_x = F.leaky_relu(self.feature_alpha1c(f), 0.2)
+            inno_x = torch.tanh(self.feature_alpha2c(F.pad(inno_x, [1, 0])))
+            inno_x = inno_x.permute(0, 2, 1)
+
+        # signal path
+        y = x.reshape(batch_size, num_frames, -1)
+        y = alpha * y
+
+        if self.innovate:
+            y = y + inno_alpha * inno_x
+
+        return y.reshape(batch_size, 1, num_samples)
--
GitLab
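
Usage sketch (not part of the patch): a minimal example of driving the new 48 kHz path with the default configuration. The tensor layouts and value ranges below are assumptions inferred from the forward() signatures above (47 SILK features per frame, one pitch lag per 80-sample frame below pitch_max, a two-column numbits input), and it assumes the script is run from dnn/torch/osce so the models/ and utils/ imports resolve; the actual training pipeline may feed the model differently.

    import torch
    from models.shape_up_48 import ShapeUp48

    model = ShapeUp48()                            # default config, no prenet
    model.eval()

    B, T16 = 1, 16000                              # one second of 16 kHz input
    num_frames = T16 // ShapeUp48.FRAME_SIZE16k    # 80-sample (5 ms) frames

    x        = torch.randn(B, 1, T16)                      # degraded 16 kHz signal
    features = torch.randn(B, num_frames, 47)              # SILK features (assumed layout)
    periods  = torch.randint(32, 256, (B, num_frames, 1))  # pitch lags, must not exceed pitch_max
    numbits  = 400. * torch.ones(B, num_frames, 2)         # bits per frame (assumed two-column layout)

    with torch.no_grad():
        y48 = model(x, features, periods, numbits)

    print(y48.shape)   # expected: (B, 1, 3 * T16), i.e. the signal resampled to 48 kHz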