From 0dc559f060db0d62d95f424e3fd26a5f673b2f6b Mon Sep 17 00:00:00 2001
From: Jan Buethe <jbuethe@amazon.de>
Date: Wed, 24 Apr 2024 12:17:51 +0200
Subject: [PATCH] added some bwe-related stuff

---
 dnn/torch/osce/losses/td_lowpass.py | 34 +++++++++++++++++++++++++++++
 dnn/torch/osce/silk_16_to_48.py     | 28 ++++++++++++++++++++++++
 dnn/torch/osce/utils/layers/fir.py  | 27 +++++++++++++++++++++++
 3 files changed, 89 insertions(+)
 create mode 100644 dnn/torch/osce/losses/td_lowpass.py
 create mode 100644 dnn/torch/osce/silk_16_to_48.py
 create mode 100644 dnn/torch/osce/utils/layers/fir.py

diff --git a/dnn/torch/osce/losses/td_lowpass.py b/dnn/torch/osce/losses/td_lowpass.py
new file mode 100644
index 000000000..af422fb55
--- /dev/null
+++ b/dnn/torch/osce/losses/td_lowpass.py
@@ -0,0 +1,34 @@
+import torch
+import scipy.signal
+
+
+from utils.layers.fir import FIR
+
+class TDLowpass(torch.nn.Module):
+    def __init__(self, numtaps, cutoff, power=2):
+        super().__init__()
+        
+        self.b = scipy.signal.firwin(numtaps, cutoff)
+        self.weight = torch.from_numpy(self.b).float().view(1, 1, -1)
+        self.power = power
+        
+    def forward(self, y_true, y_pred):
+        
+        assert len(y_true.shape) == 3 and len(y_pred.shape) == 3
+        
+        diff = y_true - y_pred
+        diff_lp = torch.nn.functional.conv1d(diff, self.weight)
+        
+        loss = torch.mean(torch.abs(diff_lp ** self.power))
+        
+        return loss, diff_lp
+    
+    def get_freqz(self):
+        freq, response = scipy.signal.freqz(self.b)
+        
+        return freq, response
+        
+        
+        
+        
+        
\ No newline at end of file
diff --git a/dnn/torch/osce/silk_16_to_48.py b/dnn/torch/osce/silk_16_to_48.py
new file mode 100644
index 000000000..e59b6cc84
--- /dev/null
+++ b/dnn/torch/osce/silk_16_to_48.py
@@ -0,0 +1,28 @@
+import argparse
+
+from scipy.io import wavfile
+import torch
+import numpy as np
+
+from utils.layers.silk_upsampler import SilkUpsampler
+
+parser = argparse.ArgumentParser()
+parser.add_argument("input", type=str, help="input wave file")
+parser.add_argument("output", type=str, help="output wave file")
+
+if __name__ == "__main__":
+    args = parser.parse_args()
+    
+    fs, x = wavfile.read(args.input)
+
+    # being lazy for now
+    assert fs == 16000 and x.dtype == np.int16
+    
+    x = torch.from_numpy(x.astype(np.float32)).view(1, 1, -1)
+    
+    upsampler = SilkUpsampler()
+    y = upsampler(x)
+    
+    y = y.squeeze().numpy().astype(np.int16)
+    
+    wavfile.write(args.output, 48000, y[13:])
\ No newline at end of file
diff --git a/dnn/torch/osce/utils/layers/fir.py b/dnn/torch/osce/utils/layers/fir.py
new file mode 100644
index 000000000..7eeb3e4e5
--- /dev/null
+++ b/dnn/torch/osce/utils/layers/fir.py
@@ -0,0 +1,27 @@
+import numpy as np
+import scipy.signal
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+
+class FIR(nn.Module):
+    def __init__(self, numtaps, bands, desired, fs=2):
+        super().__init__()
+        
+        if numtaps % 2 == 0:
+            print(f"warning: numtaps must be odd, increasing numtaps to {numtaps + 1}")
+            numtaps += 1
+        
+        a = scipy.signal.firls(numtaps, bands, desired, fs=fs)
+        
+        self.weight = torch.from_numpy(a.astype(np.float32))
+        
+    def forward(self, x):
+        num_channels = x.size(1)
+        
+        weight = torch.repeat_interleave(self.weight.view(1, 1, -1), num_channels, 0)
+        
+        y = F.conv1d(x, weight, groups=num_channels)
+        
+        return y
\ No newline at end of file
-- 
GitLab