From 0dc559f060db0d62d95f424e3fd26a5f673b2f6b Mon Sep 17 00:00:00 2001 From: Jan Buethe <jbuethe@amazon.de> Date: Wed, 24 Apr 2024 12:17:51 +0200 Subject: [PATCH] added some bwe-related stuff --- dnn/torch/osce/losses/td_lowpass.py | 34 +++++++++++++++++++++++++++++ dnn/torch/osce/silk_16_to_48.py | 28 ++++++++++++++++++++++++ dnn/torch/osce/utils/layers/fir.py | 27 +++++++++++++++++++++++ 3 files changed, 89 insertions(+) create mode 100644 dnn/torch/osce/losses/td_lowpass.py create mode 100644 dnn/torch/osce/silk_16_to_48.py create mode 100644 dnn/torch/osce/utils/layers/fir.py diff --git a/dnn/torch/osce/losses/td_lowpass.py b/dnn/torch/osce/losses/td_lowpass.py new file mode 100644 index 000000000..af422fb55 --- /dev/null +++ b/dnn/torch/osce/losses/td_lowpass.py @@ -0,0 +1,34 @@ +import torch +import scipy.signal + + +from utils.layers.fir import FIR + +class TDLowpass(torch.nn.Module): + def __init__(self, numtaps, cutoff, power=2): + super().__init__() + + self.b = scipy.signal.firwin(numtaps, cutoff) + self.weight = torch.from_numpy(self.b).float().view(1, 1, -1) + self.power = power + + def forward(self, y_true, y_pred): + + assert len(y_true.shape) == 3 and len(y_pred.shape) == 3 + + diff = y_true - y_pred + diff_lp = torch.nn.functional.conv1d(diff, self.weight) + + loss = torch.mean(torch.abs(diff_lp ** self.power)) + + return loss, diff_lp + + def get_freqz(self): + freq, response = scipy.signal.freqz(self.b) + + return freq, response + + + + + \ No newline at end of file diff --git a/dnn/torch/osce/silk_16_to_48.py b/dnn/torch/osce/silk_16_to_48.py new file mode 100644 index 000000000..e59b6cc84 --- /dev/null +++ b/dnn/torch/osce/silk_16_to_48.py @@ -0,0 +1,28 @@ +import argparse + +from scipy.io import wavfile +import torch +import numpy as np + +from utils.layers.silk_upsampler import SilkUpsampler + +parser = argparse.ArgumentParser() +parser.add_argument("input", type=str, help="input wave file") +parser.add_argument("output", type=str, help="output wave file") + +if __name__ == "__main__": + args = parser.parse_args() + + fs, x = wavfile.read(args.input) + + # being lazy for now + assert fs == 16000 and x.dtype == np.int16 + + x = torch.from_numpy(x.astype(np.float32)).view(1, 1, -1) + + upsampler = SilkUpsampler() + y = upsampler(x) + + y = y.squeeze().numpy().astype(np.int16) + + wavfile.write(args.output, 48000, y[13:]) \ No newline at end of file diff --git a/dnn/torch/osce/utils/layers/fir.py b/dnn/torch/osce/utils/layers/fir.py new file mode 100644 index 000000000..7eeb3e4e5 --- /dev/null +++ b/dnn/torch/osce/utils/layers/fir.py @@ -0,0 +1,27 @@ +import numpy as np +import scipy.signal +import torch +from torch import nn +import torch.nn.functional as F + + +class FIR(nn.Module): + def __init__(self, numtaps, bands, desired, fs=2): + super().__init__() + + if numtaps % 2 == 0: + print(f"warning: numtaps must be odd, increasing numtaps to {numtaps + 1}") + numtaps += 1 + + a = scipy.signal.firls(numtaps, bands, desired, fs=fs) + + self.weight = torch.from_numpy(a.astype(np.float32)) + + def forward(self, x): + num_channels = x.size(1) + + weight = torch.repeat_interleave(self.weight.view(1, 1, -1), num_channels, 0) + + y = F.conv1d(x, weight, groups=num_channels) + + return y \ No newline at end of file -- GitLab