From c76756e18a8a04bdcbbb4462770f423587725b26 Mon Sep 17 00:00:00 2001
From: Jean-Marc Valin <jmvalin@amazon.com>
Date: Sun, 18 Jul 2021 02:24:21 -0400
Subject: [PATCH] Adding sparse training for GRU B inputs

---
 dnn/training_tf2/lpcnet.py | 51 ++++++++++++++++++++++++++++++++++++++
 1 file changed, 51 insertions(+)

diff --git a/dnn/training_tf2/lpcnet.py b/dnn/training_tf2/lpcnet.py
index 368f2750d..e4346c3ec 100644
--- a/dnn/training_tf2/lpcnet.py
+++ b/dnn/training_tf2/lpcnet.py
@@ -116,6 +116,57 @@ class Sparsify(Callback):
                 #print(thresh, np.mean(mask))
             w[1] = p
             layer.set_weights(w)
+
+class SparsifyGRUB(Callback):
+    def __init__(self, t_start, t_end, interval, grua_units, density):
+        super(SparsifyGRUB, self).__init__()
+        self.batch = 0
+        self.t_start = t_start
+        self.t_end = t_end
+        self.interval = interval
+        self.final_density = density
+        self.grua_units = grua_units
+
+    def on_batch_end(self, batch, logs=None):
+        #print("batch number", self.batch)
+        self.batch += 1
+        if self.batch < self.t_start or ((self.batch-self.t_start) % self.interval != 0 and self.batch < self.t_end):
+            #print("don't constrain")
+            pass
+        else:
+            #print("constrain")
+            layer = self.model.get_layer('gru_b')
+            w = layer.get_weights()
+            p = w[0]
+            N = p.shape[0]
+            M = p.shape[1]//3
+            for k in range(3):   # one sub-matrix per GRU gate
+                density = self.final_density[k]
+                if self.batch < self.t_end:
+                    r = 1 - (self.batch-self.t_start)/(self.t_end - self.t_start)
+                    density = 1 - (1-self.final_density[k])*(1 - r*r*r)   # cubic ramp from 1 down to final_density
+                A = p[:, k*M:(k+1)*M]
+                # Needed because of CuDNNGRU's unusual weight ordering
+                A = np.reshape(A, (M, N))
+                A = np.transpose(A, (1, 0))
+                N2 = self.grua_units
+                A2 = A[:N2, :]   # only the rows driven by GRU A's output are sparsified
+                L = np.reshape(A2, (N2//4, 4, M//8, 8))    # split into 4x8 blocks
+                S = np.sum(L*L, axis=-1)
+                S = np.sum(S, axis=1)                      # energy of each block
+                SS = np.sort(np.reshape(S, (-1,)))
+                thresh = SS[round(M*N2//32*(1-density))]   # keep only the highest-energy blocks
+                mask = (S>=thresh).astype('float32')
+                mask = np.repeat(mask, 4, axis=0)
+                mask = np.repeat(mask, 8, axis=1)          # expand block mask to per-weight mask
+                A = np.concatenate([A2*mask, A[N2:,:]], axis=0)
+                # Convert back to CuDNNGRU's weight ordering
+                A = np.transpose(A, (1, 0))
+                A = np.reshape(A, (N, M))
+                p[:, k*M:(k+1)*M] = A
+                #print(thresh, np.mean(mask))
+            w[0] = p
+            layer.set_weights(w)
             
 
 class PCMInit(Initializer):
-- 
GitLab
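
For reference, a minimal sketch of how the new callback might be exercised; it is
not part of the patch. It builds a toy model that mimics the LPCNet layout (a
'gru_a' whose output forms the first block of inputs to a layer named 'gru_b') and
assumes it is run from dnn/training_tf2 so that lpcnet.py is importable. The layer
sizes, schedule values and densities are illustrative assumptions, not taken from
the commit; in the real model the 4x8 block geometry relies on the CuDNN-style
weight layout that the callback's reshape accounts for, while the plain GRU here
only demonstrates that the callback wires in and runs.

import numpy as np
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, GRU, Dense, Concatenate
from lpcnet import SparsifyGRUB

feat = Input(shape=(None, 20))
gru_a = GRU(384, return_sequences=True, name='gru_a')(feat)
gru_b = GRU(16, return_sequences=True, name='gru_b')(Concatenate()([gru_a, feat]))
out = Dense(8, activation='softmax')(gru_b)
model = Model(feat, out)
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')

# Start constraining at batch 2, reach the final density at batch 6, re-apply
# every 2 batches in between; only the first 384 input rows (GRU A's output)
# are sparsified, down to 25% density per gate sub-matrix.
grub_sparsify = SparsifyGRUB(2, 6, 2, 384, (0.25, 0.25, 0.25))

x = np.random.rand(64, 15, 20).astype('float32')
y = np.random.randint(0, 8, (64, 15, 1))
model.fit(x, y, batch_size=8, epochs=1, callbacks=[grub_sparsify])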