Skip to content
Snippets Groups Projects
Commit 3a632721 authored by Jean-Marc Valin's avatar Jean-Marc Valin
Browse files

Fixes a few training issues

parent d9e062e0
No related branches found
No related tags found
No related merge requests found
......@@ -149,7 +149,7 @@ double compute_gradient(MLPTrain *net, float *inputs, float *outputs, int nbSamp
netOut[i] = tansig_approx(sum);
error[i] = out[i] - netOut[i];
rms += error[i]*error[i];
*error_rate += fabs(error[i])>.5;
*error_rate += fabs(error[i])>1;
}
/* Back-propagate error */
for (i=0;i<outDim;i++)
......@@ -301,22 +301,22 @@ float mlp_train_backprop(MLPTrain *net, float *inputs, float *outputs, int nbSam
}
float mean_rate = 0, min_rate = 1e10;
rms = sqrt(rms/(outDim*nbSamples));
rms = (rms/(outDim*nbSamples));
error_rate = (error_rate/(outDim*nbSamples));
fprintf (stderr, "%f (%f) ", error_rate, last_rms);
for (i=0;i<W0_size;i++)
{
if (W0_oldgrad[i]*W0_grad[i] >= 0)
if (W0_oldgrad[i]*W0_grad[i] > 0)
W0_rate[i] *= 1.01;
else
else if (W0_oldgrad[i]*W0_grad[i] < 0)
W0_rate[i] *= .9;
mean_rate += W0_rate[i];
if (W0_rate[i] < min_rate)
min_rate = W0_rate[i];
if (W0_rate[i] < 1e-15)
W0_rate[i] = 1e-15;
if (W0_rate[i] > 1)
W0_rate[i] = 1;
if (W0_rate[i] > .01)
W0_rate[i] = .01;
W0_oldgrad[i] = W0_grad[i];
W0_old2[i] = W0_old[i];
W0_old[i] = W0[i];
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment