Skip to content
Snippets Groups Projects
Commit c76756e1 authored by Jean-Marc Valin
Browse files

Adding sparse training for GRU B inputs

parent 8bdbbfa1
No related branches found
No related tags found
No related merge requests found
......@@ -116,6 +116,57 @@ class Sparsify(Callback):
#print(thresh, np.mean(mask))
w[1] = p
layer.set_weights(w)
class SparsifyGRUB(Callback):
    """Callback that progressively block-sparsifies weight array 0 of layer 'gru_b'.

    The weight array is treated as three equal-width column groups (one
    per index k in 0..2, presumably the three GRU gates -- confirm against
    the layer definition). For each group, after the reshape/transpose
    described in on_batch_end, the first `grua_units` rows are split into
    blocks of 4 rows by 8 columns. Blocks with the lowest energy (sum of
    squared weights) are zeroed so that roughly `density[k]` of the blocks
    survive. Rows past `grua_units` are never masked.

    Parameters
    ----------
    t_start : int
        Batch count at which sparsification begins.
    t_end : int
        Batch count from which the final density is enforced on every batch.
    interval : int
        Between t_start and t_end the constraint is applied only every
        `interval` batches (counted from t_start).
    grua_units : int
        Number of leading rows that are subject to sparsification. The
        reshape below requires it to be a multiple of 4.
    density : sequence of 3 floats
        Target fraction of blocks kept for each of the three column groups.
    """

    def __init__(self, t_start, t_end, interval, grua_units, density):
        super(SparsifyGRUB, self).__init__()
        # Number of batches seen so far; kept locally instead of relying on
        # the `batch` argument passed to on_batch_end.
        self.batch = 0
        self.t_start = t_start
        self.t_end = t_end
        self.interval = interval
        self.final_density = density
        self.grua_units = grua_units

    def _gate_density(self, k):
        """Return the density to enforce for column group k at the current batch.

        Before t_end the density is annealed from 1 (at t_start) towards
        final_density[k] following 1 - (1 - final) * (1 - r**3), where r
        falls linearly from 1 to 0 between t_start and t_end.
        """
        final = self.final_density[k]
        if self.batch >= self.t_end:
            return final
        r = 1 - (self.batch - self.t_start)/(self.t_end - self.t_start)
        return 1 - (1 - final)*(1 - r*r*r)

    def on_batch_end(self, batch, logs=None):
        """Count the batch and, when scheduled, re-apply the block-sparsity mask."""
        self.batch += 1
        # Nothing to do before the schedule starts.
        if self.batch < self.t_start:
            return
        # Until t_end, only constrain once every `interval` batches.
        if (self.batch - self.t_start) % self.interval != 0 and self.batch < self.t_end:
            return

        layer = self.model.get_layer('gru_b')
        w = layer.get_weights()
        p = w[0]
        N = p.shape[0]
        M = p.shape[1]//3
        N2 = self.grua_units
        for k in range(3):
            density = self._gate_density(k)
            A = p[:, k*M:(k+1)*M]
            # This is needed because of the CuDNNGRU strange weight ordering
            A = np.reshape(A, (M, N))
            A = np.transpose(A, (1, 0))
            A2 = A[:N2, :]
            # Energy of each 4x8 block: sum of squares over the 8 columns,
            # then over the 4 rows.
            L = np.reshape(A2, (N2//4, 4, M//8, 8))
            S = np.sum(L*L, axis=-1)
            S = np.sum(S, axis=1)
            SS = np.sort(np.reshape(S, (-1,)))
            # Number of lowest-energy blocks to drop. int() guards against
            # round() returning a non-int for NumPy scalars, and max() keeps
            # an out-of-range density (> 1) from producing a negative index
            # that would silently wrap around.
            num_pruned = max(0, int(round(M*N2//32*(1-density))))
            if num_pruned < SS.size:
                thresh = SS[num_pruned]
                mask = (S >= thresh).astype('float32')
            else:
                # density == 0 asks for every block to be dropped; indexing
                # SS here would be one past the end and raise IndexError.
                mask = np.zeros(S.shape, dtype='float32')
            # Expand the per-block mask back to one entry per weight.
            mask = np.repeat(mask, 4, axis=0)
            mask = np.repeat(mask, 8, axis=1)
            A = np.concatenate([A2*mask, A[N2:, :]], axis=0)
            # This is needed because of the CuDNNGRU strange weight ordering
            A = np.transpose(A, (1, 0))
            A = np.reshape(A, (N, M))
            p[:, k*M:(k+1)*M] = A
        w[0] = p
        layer.set_weights(w)
class PCMInit(Initializer):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment