Unverified Commit 92739d88 authored by Jean-Marc Valin's avatar Jean-Marc Valin
Browse files

New features (#9)

Also, adding an error^4 term to the loss function
parent c60a7634
...@@ -563,7 +563,7 @@ int main(int argc, char **argv) { ...@@ -563,7 +563,7 @@ int main(int argc, char **argv) {
float E=0; float E=0;
if (++gain_change_count > 101*300) { if (++gain_change_count > 101*300) {
speech_gain = pow(10., (-40+(rand()%60))/20.); speech_gain = pow(10., (-40+(rand()%60))/20.);
noise_gain = pow(10., (-30+(rand()%40))/20.); noise_gain = pow(10., (-30+(rand()%50))/20.);
if (rand()%10==0) noise_gain = 0; if (rand()%10==0) noise_gain = 0;
noise_gain *= speech_gain; noise_gain *= speech_gain;
if (rand()%10==0) speech_gain = 0; if (rand()%10==0) speech_gain = 0;
......
This diff is collapsed.
...@@ -38,7 +38,7 @@ def msse(y_true, y_pred): ...@@ -38,7 +38,7 @@ def msse(y_true, y_pred):
return K.mean(mymask(y_true) * K.square(K.sqrt(y_pred) - K.sqrt(y_true)), axis=-1) return K.mean(mymask(y_true) * K.square(K.sqrt(y_pred) - K.sqrt(y_true)), axis=-1)
def mycost(y_true, y_pred): def mycost(y_true, y_pred):
return K.mean(mymask(y_true) * (K.square(K.sqrt(y_pred) - K.sqrt(y_true)) + 0.01*K.binary_crossentropy(y_pred, y_true)), axis=-1) return K.mean(mymask(y_true) * (10*K.square(K.square(K.sqrt(y_pred) - K.sqrt(y_true))) + K.square(K.sqrt(y_pred) - K.sqrt(y_true)) + 0.01*K.binary_crossentropy(y_pred, y_true)), axis=-1)
def my_accuracy(y_true, y_pred): def my_accuracy(y_true, y_pred):
return K.mean(2*K.abs(y_true-0.5) * K.equal(y_true, K.round(y_pred)), axis=-1) return K.mean(2*K.abs(y_true-0.5) * K.equal(y_true, K.round(y_pred)), axis=-1)
...@@ -82,7 +82,7 @@ model.compile(loss=[mycost, my_crossentropy], ...@@ -82,7 +82,7 @@ model.compile(loss=[mycost, my_crossentropy],
batch_size = 32 batch_size = 32
print('Loading data...') print('Loading data...')
with h5py.File('denoise_data6.h5', 'r') as hf: with h5py.File('denoise_data9.h5', 'r') as hf:
all_data = hf['data'][:] all_data = hf['data'][:]
print('done.') print('done.')
...@@ -113,4 +113,4 @@ model.fit(x_train, [y_train, vad_train], ...@@ -113,4 +113,4 @@ model.fit(x_train, [y_train, vad_train],
batch_size=batch_size, batch_size=batch_size,
epochs=120, epochs=120,
validation_split=0.1) validation_split=0.1)
model.save("newweights6c.hdf5") model.save("newweights9i.hdf5")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment