Skip to content
Snippets Groups Projects
Commit e1181bca authored by Jean-Marc Valin's avatar Jean-Marc Valin
Browse files

oops, fix band loss

parent c8cbfa7e
No related branches found
No related tags found
No related merge requests found
......@@ -99,8 +99,8 @@ def plc_loss(alpha=1.0):
mask = y_true[:,:,-1:]
y_true = y_true[:,:,:-1]
e = (y_true - y_pred)*mask
e_bands = tf.signal.idct(e, norm='ortho')
l1_loss = K.mean(K.abs(e) + alpha*K.abs(e_bands))
e_bands = tf.signal.idct(e[:,:,:-2], norm='ortho')
l1_loss = K.mean(K.abs(e)) + alpha*K.mean(K.abs(e_bands))
return l1_loss
return loss
......@@ -118,7 +118,7 @@ def plc_band_loss():
mask = y_true[:,:,-1:]
y_true = y_true[:,:,:-1]
e = (y_true - y_pred)*mask
e_bands = tf.signal.idct(e, norm='ortho')
e_bands = tf.signal.idct(e[:,:,:-2], norm='ortho')
l1_loss = K.mean(K.abs(e_bands))
return l1_loss
return L1_band_loss
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment