From c76756e18a8a04bdcbbb4462770f423587725b26 Mon Sep 17 00:00:00 2001
From: Jean-Marc Valin <jmvalin@amazon.com>
Date: Sun, 18 Jul 2021 02:24:21 -0400
Subject: [PATCH] Adding sparse training for GRU B inputs

---
 dnn/training_tf2/lpcnet.py | 51 ++++++++++++++++++++++++++++++++++++++
 1 file changed, 51 insertions(+)

diff --git a/dnn/training_tf2/lpcnet.py b/dnn/training_tf2/lpcnet.py
index 368f2750d..e4346c3ec 100644
--- a/dnn/training_tf2/lpcnet.py
+++ b/dnn/training_tf2/lpcnet.py
@@ -116,6 +116,57 @@ class Sparsify(Callback):
                 #print(thresh, np.mean(mask))
             w[1] = p
             layer.set_weights(w)
+
+class SparsifyGRUB(Callback):
+    def __init__(self, t_start, t_end, interval, grua_units, density):
+        super(SparsifyGRUB, self).__init__()
+        self.batch = 0
+        self.t_start = t_start
+        self.t_end = t_end
+        self.interval = interval
+        self.final_density = density
+        self.grua_units = grua_units
+
+    def on_batch_end(self, batch, logs=None):
+        #print("batch number", self.batch)
+        self.batch += 1
+        if self.batch < self.t_start or ((self.batch-self.t_start) % self.interval != 0 and self.batch < self.t_end):
+            #print("don't constrain");
+            pass
+        else:
+            #print("constrain");
+            layer = self.model.get_layer('gru_b')
+            w = layer.get_weights()
+            p = w[0]
+            N = p.shape[0]
+            M = p.shape[1]//3
+            for k in range(3):
+                density = self.final_density[k]
+                if self.batch < self.t_end:
+                    r = 1 - (self.batch-self.t_start)/(self.t_end - self.t_start)
+                    density = 1 - (1-self.final_density[k])*(1 - r*r*r)
+                A = p[:, k*M:(k+1)*M]
+                #This is needed because of the CuDNNGRU strange weight ordering
+                A = np.reshape(A, (M, N))
+                A = np.transpose(A, (1, 0))
+                N2 = self.grua_units
+                A2 = A[:N2, :]
+                L=np.reshape(A2, (N2//4, 4, M//8, 8))
+                S=np.sum(L*L, axis=-1)
+                S=np.sum(S, axis=1)
+                SS=np.sort(np.reshape(S, (-1,)))
+                thresh = SS[round(M*N2//32*(1-density))]
+                mask = (S>=thresh).astype('float32');
+                mask = np.repeat(mask, 4, axis=0)
+                mask = np.repeat(mask, 8, axis=1)
+                A = np.concatenate([A2*mask, A[N2:,:]], axis=0)
+                #This is needed because of the CuDNNGRU strange weight ordering
+                A = np.transpose(A, (1, 0))
+                A = np.reshape(A, (N, M))
+                p[:, k*M:(k+1)*M] = A
+                #print(thresh, np.mean(mask))
+            w[0] = p
+            layer.set_weights(w)
 
 
 class PCMInit(Initializer):
--
GitLab
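
Editorial note: the core of this patch is magnitude-based block pruning. The GRU B
input weight matrix is viewed as 4x8 blocks, each block's energy (sum of squared
weights) is computed, and the lowest-energy blocks are zeroed until the target
density is reached; the density itself is annealed cubically from 1 toward
final_density between t_start and t_end, re-applied every `interval` batches.
The sketch below isolates just the block-masking step on a plain matrix. It is a
minimal illustration under assumptions, not the patch itself: the helper name
block_sparsify is made up here, and the CuDNNGRU weight reordering and the
density schedule are deliberately omitted.

    import numpy as np

    def block_sparsify(A, density, block=(4, 8)):
        # Hypothetical helper (not part of the patch): zero out the
        # lowest-energy blocks of A so roughly `density` of them survive.
        N, M = A.shape
        h, w = block
        assert N % h == 0 and M % w == 0
        # One energy value per block: sum of squared weights
        L = np.reshape(A, (N // h, h, M // w, w))
        S = np.sum(L * L, axis=(1, 3))
        # Threshold so that a (1 - density) fraction of blocks fall below it
        SS = np.sort(S, axis=None)
        thresh = SS[int(round(S.size * (1 - density)))]
        # Grow the block mask back to element resolution and apply it
        mask = (S >= thresh).astype(A.dtype)
        mask = np.repeat(np.repeat(mask, h, axis=0), w, axis=1)
        return A * mask

    # Keep roughly half of the 4x8 blocks of a random 16x24 matrix
    rng = np.random.default_rng(0)
    A = rng.standard_normal((16, 24)).astype('float32')
    print(np.mean(block_sparsify(A, 0.5) != 0))  # ~0.5

Pruning whole blocks rather than individual weights is what lets sparse
inference code skip entire zero blocks at run time instead of testing every
element.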