diff --git a/src/analysis.c b/src/analysis.c index 54005d3a30607df37647830191165440dac6ca22..34ea5107345f989eb2786b3bdfcc73aac0e0017a 100644 --- a/src/analysis.c +++ b/src/analysis.c @@ -141,7 +141,6 @@ static inline float fast_atan2f(float y, float x) { void tonality_get_info(TonalityAnalysisState *tonal, AnalysisInfo *info_out, int len) { -#if 1 int pos; int curr_lookahead; float psum; @@ -184,31 +183,6 @@ void tonality_get_info(TonalityAnalysisState *tonal, AnalysisInfo *info_out, int /*printf("%f %f\n", psum, info_out->music_prob);*/ info_out->music_prob = psum; -#else - /* If data not available, return invalid */ - if (tonal->read_pos==tonal->write_pos) - { - info_out->valid=0; - return; - } - - OPUS_COPY(info_out, &tonal->info[tonal->read_pos], 1); - tonal->read_subframe += len/480; - while (tonal->read_subframe>=4) - { - tonal->read_subframe -= 4; - tonal->read_pos++; - } - if (tonal->read_pos>=DETECT_SIZE) - tonal->read_pos-=DETECT_SIZE; - if (tonal->read_pos == tonal->write_pos) - { - tonal->read_pos = tonal->write_pos-1; - if (tonal->read_pos<0) - tonal->read_pos=DETECT_SIZE-1; - tonal->read_subframe = 3; - } -#endif } void tonality_analysis(TonalityAnalysisState *tonal, AnalysisInfo *info_out, const CELTMode *celt_mode, const void *x, int len, int offset, int C, int lsb_depth, downmix_func downmix) @@ -234,7 +208,7 @@ void tonality_analysis(TonalityAnalysisState *tonal, AnalysisInfo *info_out, con float slope=0; float frame_stationarity; float relativeE; - float frame_prob; + float frame_probs[2]; float alpha, alphaE, alphaE2; float frame_loudness; float bandwidth_mask; @@ -494,32 +468,34 @@ void tonality_analysis(TonalityAnalysisState *tonal, AnalysisInfo *info_out, con features[24] = tonal->lowECount; #ifndef FIXED_POINT - mlp_process(&net, features, &frame_prob); - frame_prob = .5f*(frame_prob+1); + mlp_process(&net, features, frame_probs); + frame_probs[0] = .5f*(frame_probs[0]+1); /* Curve fitting between the MLP probability and the actual probability */ - frame_prob = .01f + 1.21f*frame_prob*frame_prob - .23f*(float)pow(frame_prob, 10); + frame_probs[0] = .01f + 1.21f*frame_probs[0]*frame_probs[0] - .23f*(float)pow(frame_probs[0], 10); + frame_probs[1] = .5*frame_probs[1]+.5; + frame_probs[0] = frame_probs[1]*frame_probs[0] + (1-frame_probs[1])*.5; - /*printf("%f\n", frame_prob);*/ + /*printf("%f %f ", frame_probs[0], frame_probs[1]);*/ { float tau, beta; float p0, p1; float max_certainty; /* One transition every 3 minutes */ - tau = .00005f; - beta = .1f; + tau = .00005f*frame_probs[1]; + beta = .05f; max_certainty = .01f+1.f/(20.f+.5f*tonal->last_transition); max_certainty = 0; p0 = (1-tonal->music_prob)*(1-tau) + tonal->music_prob *tau; p1 = tonal->music_prob *(1-tau) + (1-tonal->music_prob)*tau; - p0 *= (float)pow(1-frame_prob, beta); - p1 *= (float)pow(frame_prob, beta); + p0 *= (float)pow(1-frame_probs[0], beta); + p1 *= (float)pow(frame_probs[0], beta); tonal->music_prob = MAX16(max_certainty, MIN16(1-max_certainty, p1/(p0+p1))); info->music_prob = tonal->music_prob; - info->music_prob = frame_prob; + info->music_prob = frame_probs[0]; float psum=1e-20; - float speech0 = (float)pow(1-frame_prob, beta); - float music0 = (float)pow(frame_prob, beta); + float speech0 = (float)pow(1-frame_probs[0], beta); + float music0 = (float)pow(frame_probs[0], beta); if (tonal->count==1) { tonal->pspeech[0]=.5; @@ -550,7 +526,7 @@ void tonality_analysis(TonalityAnalysisState *tonal, AnalysisInfo *info_out, con for (i=1;i<DETECT_SIZE;i++) psum += tonal->pspeech[i]; - /*printf("%f %f %f\n", frame_prob, info->music_prob, psum);*/ + /*printf("%f\n", psum);*/ } if (tonal->last_music != (tonal->music_prob>.5f)) tonal->last_transition=0; diff --git a/src/mlp_data.c b/src/mlp_data.c index 5c13ca408df136c107d06e5773bea7b41e5032d4..9085b85faa3f95754258f850cb25f28c751c5313 100644 --- a/src/mlp_data.c +++ b/src/mlp_data.c @@ -3,74 +3,103 @@ #include "mlp.h" -/* RMS error was 0.179835, seed was 1322103961 */ +/* RMS error was 0.138320, seed was 1361535663 */ -static const float weights[271] = { +static const float weights[422] = { /* hidden layer */ -1.55597f, -0.0739792f, -0.0646761f, -0.099531f, -0.0794943f, -0.0180174f, -0.0391354f, 0.0508224f, -0.0160169f, -0.0773263f, --0.0300002f, -0.0865361f, 0.124477f, -0.28648f, -0.0860702f, --0.518949f, -0.0873341f, -0.235393f, -0.907833f, -0.383573f, -0.535388f, -0.57944f, 0.98116f, 0.8482f, 1.12426f, --3.23721f, -0.647072f, -0.0265139f, 0.0711052f, -0.00125666f, --0.0396181f, -0.44282f, -0.510495f, -0.201865f, 0.0134336f, --0.167205f, -0.155406f, 0.00041678f, -0.00468705f, -0.0233224f, -0.264279f, -0.301375f, 0.00234895f, 0.0144741f, -0.137535f, -0.200323f, 0.0192027f, 3.19818f, 2.03495f, 0.705517f, --4.6025f, -0.11485f, -0.792716f, 0.150714f, 0.10608f, -0.240633f, 0.0690698f, 0.0695297f, 0.124819f, 0.0501433f, -0.0460952f, 0.147639f, 0.10327f, 0.158007f, 0.113714f, -0.0276191f, 0.0680749f, -0.130012f, 0.0796126f, 0.133067f, -0.51495f, 0.747578f, -0.128742f, 5.98112f, -1.16698f, --0.276492f, -1.73549f, -3.90234f, 2.01489f, -0.040118f, --0.113002f, -0.146751f, -0.113569f, 0.0534873f, 0.0989832f, -0.0872875f, 0.049266f, 0.0367557f, -0.00889148f, -0.0648461f, --0.00190352f, 0.0143773f, 0.0259364f, -0.0592133f, -0.0672924f, -0.1399f, -0.0987886f, -0.347402f, 0.101326f, -0.0680876f, -0.469186f, 0.246922f, 10.4017f, 3.44846f, -0.662725f, --0.0328208f, -0.0561274f, -0.0167744f, 0.00044282f, -0.0457645f, --0.0408314f, -0.013113f, -0.0373873f, -0.0474122f, -0.0273745f, --0.0308505f, 0.000582959f, -0.0421135f, 0.464859f, 0.196842f, -0.320538f, 0.0435528f, -0.200168f, 0.266475f, -0.0853727f, -1.20397f, 0.711542f, -1.04397f, -1.47759f, 1.26768f, -0.446958f, 0.266477f, -0.30802f, 0.28431f, -0.118541f, -0.00836345f, 0.0689026f, -0.0137996f, -0.0395417f, 0.26982f, --0.206255f, 0.16066f, 0.114757f, 0.359587f, -0.106503f, --0.0948534f, 0.175358f, -0.122966f, -0.0056675f, 0.483848f, --0.134916f, -0.427567f, -0.140172f, -1.0866f, -2.73921f, -0.549843f, 0.17685f, 0.0010675f, -0.00137386f, 0.0884424f, --0.0698736f, -0.00174136f, 0.0718775f, -0.0396849f, 0.0448056f, -0.0577853f, -0.0372353f, 0.134599f, 0.0260656f, 0.140322f, -0.22704f, -0.020568f, -0.0142424f, -0.21723f, -0.997704f, --0.884573f, -0.163495f, 2.33617f, 0.224142f, 0.19635f, --0.957387f, 0.144678f, 1.47035f, -0.00700498f, -0.0472309f, --0.0137848f, -0.0189145f, 0.00856479f, 0.0316965f, 0.00613373f, -0.00209807f, 0.00270964f, -0.0490206f, 0.0105712f, -0.0465045f, --0.0381532f, -0.0985268f, -0.108297f, 0.0146409f, -0.0040718f, --0.0698572f, -0.380568f, -0.230479f, 3.98917f, 0.457652f, --1.02355f, -7.4435f, -0.475314f, 1.61743f, 0.0254017f, --0.00791293f, 0.047217f, 0.0220995f, -0.0304311f, 0.0052168f, --0.0404054f, -0.0230293f, 0.00169229f, -0.0138178f, 0.0043137f, --0.0598088f, -0.133601f, 0.0555138f, -0.177358f, -0.159856f, --0.137281f, 0.108051f, -0.305973f, 0.393775f, 0.0747287f, -0.783993f, -0.875086f, 1.06862f, 0.340519f, -0.352681f, --0.0830912f, -0.100017f, 0.0729085f, -0.00829403f, 0.027489f, --0.0779597f, 0.082286f, -0.164181f, -0.41519f, 0.00282335f, --0.29573f, 0.125571f, 0.726935f, 0.392137f, 0.491348f, -0.0723196f, -0.0259758f, -0.0636332f, -0.452384f, -0.000225974f, --2.34001f, 2.45211f, -0.544628f, 5.62944f, -3.44507f, +-0.0941125f, -0.302976f, -0.603555f, -0.19393f, -0.185983f, +-0.601617f, -0.0465317f, -0.114563f, -0.103599f, -0.618938f, +-0.317859f, -0.169949f, -0.0702885f, 0.148065f, 0.409524f, +0.548432f, 0.367649f, -0.494393f, 0.764306f, -1.83957f, +0.170849f, 12.786f, -1.08848f, -1.27284f, -16.2606f, +24.1773f, -5.57454f, -0.17276f, -0.163388f, -0.224421f, +-0.0948944f, -0.0728695f, -0.26557f, -0.100283f, -0.0515459f, +-0.146142f, -0.120674f, -0.180655f, 0.12857f, 0.442138f, +-0.493735f, 0.167767f, 0.206699f, -0.197567f, 0.417999f, +1.50364f, -0.773341f, -10.0401f, 0.401872f, 2.97966f, +15.2165f, -1.88905f, -1.19254f, 0.0285397f, -0.00405139f, +0.0707565f, 0.00825699f, -0.0927269f, -0.010393f, -0.00428882f, +-0.00489743f, -0.0709731f, -0.00255992f, 0.0395619f, 0.226424f, +0.0325231f, 0.162175f, -0.100118f, 0.485789f, 0.12697f, +0.285937f, 0.0155637f, 0.10546f, 3.05558f, 1.15059f, +-1.00904f, -1.83088f, 3.31766f, -3.42516f, -0.119135f, +-0.0405654f, 0.00690068f, 0.0179877f, -0.0382487f, 0.00597941f, +-0.0183611f, 0.00190395f, -0.144322f, -0.0435671f, 0.000990594f, +0.221087f, 0.142405f, 0.484066f, 0.404395f, 0.511955f, +-0.237255f, 0.241742f, 0.35045f, -0.699428f, 10.3993f, +2.6507f, -2.43459f, -4.18838f, 1.05928f, 1.71067f, +0.00667811f, -0.0721335f, -0.0397346f, 0.0362704f, -0.11496f, +-0.0235776f, 0.0082161f, -0.0141741f, -0.0329699f, -0.0354253f, +0.00277404f, -0.290654f, -1.14767f, -0.319157f, -0.686544f, +0.36897f, 0.478899f, 0.182579f, -0.411069f, 0.881104f, +-4.60683f, 1.4697f, 0.335845f, -1.81905f, -30.1699f, +5.55225f, 0.0019508f, -0.123576f, -0.0727332f, -0.0641597f, +-0.0534458f, -0.108166f, -0.0937368f, -0.0697883f, -0.0275475f, +-0.192309f, -0.110074f, 0.285375f, -0.405597f, 0.0926724f, +-0.287881f, -0.851193f, -0.099493f, -0.233764f, -1.2852f, +1.13611f, 3.12168f, -0.0699f, -1.86216f, 2.65292f, +-7.31036f, 2.44776f, -0.00111802f, -0.0632786f, -0.0376296f, +-0.149851f, 0.142963f, 0.184368f, 0.123433f, 0.0756158f, +0.117312f, 0.0933395f, 0.0692163f, 0.0842592f, 0.0704683f, +0.0589963f, 0.0942205f, -0.448862f, 0.0262677f, 0.270352f, +-0.262317f, 0.172586f, 2.00227f, -0.159216f, 0.038422f, +10.2073f, 4.15536f, -2.3407f, -0.0550265f, 0.00964792f, +-0.141336f, 0.0274501f, 0.0343921f, -0.0487428f, 0.0950172f, +-0.00775017f, -0.0372492f, -0.00548121f, -0.0663695f, 0.0960506f, +-0.200008f, -0.0412827f, 0.58728f, 0.0515787f, 0.337254f, +0.855024f, 0.668371f, -0.114904f, -3.62962f, -0.467477f, +-0.215472f, 2.61537f, 0.406117f, -1.36373f, 0.0425394f, +0.12208f, 0.0934502f, 0.123055f, 0.0340935f, -0.142466f, +0.035037f, -0.0490666f, 0.0733208f, 0.0576672f, 0.123984f, +-0.0517194f, -0.253018f, 0.590565f, 0.145849f, 0.315185f, +0.221534f, -0.149081f, 0.216161f, -0.349575f, 24.5664f, +-0.994196f, 0.614289f, -18.7905f, -2.83277f, -0.716801f, +-0.347201f, 0.479515f, -0.246027f, 0.0758683f, 0.137293f, +-0.17781f, 0.118751f, -0.00108329f, -0.237334f, 0.355732f, +-0.12991f, -0.0547627f, -0.318576f, -0.325524f, 0.180494f, +-0.0625604f, 0.141219f, 0.344064f, 0.37658f, -0.591772f, +5.8427f, -0.38075f, 0.221894f, -1.41934f, -1.87943e+06f, +1.34114f, 0.0283355f, -0.0447856f, -0.0211466f, -0.0256927f, +0.0139618f, 0.0207934f, -0.0107666f, 0.0110969f, 0.0586069f, +-0.0253545f, -0.0328433f, 0.11872f, -0.216943f, 0.145748f, +0.119808f, -0.0915211f, -0.120647f, -0.0787719f, -0.143644f, +-0.595116f, -1.152f, -1.25335f, -1.17092f, 4.34023f, +-975268.f, -1.37033f, -0.0401123f, 0.210602f, -0.136656f, +0.135962f, -0.0523293f, 0.0444604f, 0.0143928f, 0.00412666f, +-0.0193003f, 0.218452f, -0.110204f, -2.02563f, 0.918238f, +-2.45362f, 1.19542f, -0.061362f, -1.92243f, 0.308111f, +0.49764f, 0.912356f, 0.209272f, -2.34525f, 2.19326f, +-6.47121f, 1.69771f, -0.725123f, 0.0118929f, 0.0377944f, +0.0554003f, 0.0226452f, -0.0704421f, -0.0300309f, 0.0122978f, +-0.0041782f, -0.0686612f, 0.0313115f, 0.039111f, 0.364111f, +-0.0945548f, 0.0229876f, -0.17414f, 0.329795f, 0.114714f, +0.30022f, 0.106997f, 0.132355f, 5.79932f, 0.908058f, +-0.905324f, -3.3561f, 0.190647f, 0.184211f, -0.673648f, +0.231807f, -0.0586222f, 0.230752f, -0.438277f, 0.245857f, +-0.17215f, 0.0876383f, -0.720512f, 0.162515f, 0.0170571f, +0.101781f, 0.388477f, 1.32931f, 1.08548f, -0.936301f, +-2.36958f, -6.71988f, -3.44376f, 2.13818f, 14.2318f, +4.91459f, -3.09052f, -9.69191f, -0.768234f, 1.79604f, +0.0549653f, 0.163399f, 0.0797025f, 0.0343933f, -0.0555876f, +-0.00505673f, 0.0187258f, 0.0326628f, 0.0231486f, 0.15573f, +0.0476223f, -0.254824f, 1.60155f, -0.801221f, 2.55496f, +0.737629f, -1.36249f, -0.695463f, -2.44301f, -1.73188f, +3.95279f, 1.89068f, 0.486087f, -11.3343f, 3.9416e+06f, /* output layer */ --3.13835f, 0.994751f, 0.444901f, 1.59518f, 1.23665f, -3.37012f, -1.34606f, 1.99131f, 1.33476f, 1.3885f, -1.12559f, }; +-0.381439, 0.12115, -0.906927, 2.93878, 1.6388, +0.882811, 0.874344, 1.21726, -0.874545, 0.321706, +0.785055, 0.946558, -0.575066, -3.46553, 0.884905, +0.0924047, -9.90712, 0.391338, 0.160103, -2.04954, +4.1455, 0.0684029, -0.144761, -0.285282, 0.379244, +-1.1584, -0.0277241, -9.85, -4.82386, 3.71333, +3.87308, 3.52558, }; -static const int topo[3] = {25, 10, 1}; +static const int topo[3] = {25, 15, 2}; const MLP net = { - 3, - topo, - weights + 3, + topo, + weights }; - diff --git a/src/mlp_train.c b/src/mlp_train.c index 5fbbff082f96d631dcfe430dc8bf98755cb8503e..2e9568ba4e15b7174716bc3644899fba56064d62 100644 --- a/src/mlp_train.c +++ b/src/mlp_train.c @@ -106,6 +106,7 @@ MLPTrain * mlp_init(int *topo, int nbLayers, float *inputs, float *outputs, int } #define MAX_NEURONS 100 +#define MAX_OUT 10 double compute_gradient(MLPTrain *net, float *inputs, float *outputs, int nbSamples, double *W0_grad, double *W1_grad, double *error_rate) { @@ -120,7 +121,8 @@ double compute_gradient(MLPTrain *net, float *inputs, float *outputs, int nbSamp double netOut[MAX_NEURONS]; double error[MAX_NEURONS]; - *error_rate = 0; + for (i=0;i<outDim;i++) + error_rate[i] = 0; topo = net->topo; inDim = net->topo[0]; hiddenDim = net->topo[1]; @@ -153,7 +155,7 @@ double compute_gradient(MLPTrain *net, float *inputs, float *outputs, int nbSamp netOut[i] = tansig_approx(sum); error[i] = out[i] - netOut[i]; rms += error[i]*error[i]; - *error_rate += fabs(error[i])>1; + error_rate[i] += fabs(error[i])>1; /*error[i] = error[i]/(1+fabs(error[i]));*/ } /* Back-propagate error */ @@ -194,7 +196,7 @@ struct GradientArg { double *W0_grad; double *W1_grad; double rms; - double error_rate; + double error_rate[MAX_OUT]; }; void *gradient_thread_process(void *_arg) @@ -213,7 +215,7 @@ void *gradient_thread_process(void *_arg) sem_wait(&sem_begin[arg->id]); if (arg->done) break; - arg->rms = compute_gradient(arg->net, arg->inputs, arg->outputs, arg->nbSamples, arg->W0_grad, arg->W1_grad, &arg->error_rate); + arg->rms = compute_gradient(arg->net, arg->inputs, arg->outputs, arg->nbSamples, arg->W0_grad, arg->W1_grad, arg->error_rate); sem_post(&sem_end[arg->id]); } fprintf(stderr, "done\n"); @@ -295,7 +297,7 @@ float mlp_train_backprop(MLPTrain *net, float *inputs, float *outputs, int nbSam for (e=0;e<nbEpoch;e++) { double rms=0; - double error_rate = 0; + double error_rate[2] = {0,0}; for (i=0;i<NB_THREADS;i++) { sem_post(&sem_begin[i]); @@ -306,7 +308,8 @@ float mlp_train_backprop(MLPTrain *net, float *inputs, float *outputs, int nbSam { sem_wait(&sem_end[i]); rms += args[i].rms; - error_rate += args[i].error_rate; + error_rate[0] += args[i].error_rate[0]; + error_rate[1] += args[i].error_rate[1]; for (j=0;j<W0_size;j++) W0_grad[j] += args[i].W0_grad[j]; for (j=0;j<W1_size;j++) @@ -315,8 +318,9 @@ float mlp_train_backprop(MLPTrain *net, float *inputs, float *outputs, int nbSam float mean_rate = 0, min_rate = 1e10; rms = (rms/(outDim*nbSamples)); - error_rate = (error_rate/(outDim*nbSamples)); - fprintf (stderr, "%f (%f %f) ", error_rate, rms, best_rms); + error_rate[0] = (error_rate[0]/(nbSamples)); + error_rate[1] = (error_rate[1]/(nbSamples)); + fprintf (stderr, "%f %f (%f %f) ", error_rate[0], error_rate[1], rms, best_rms); if (rms < best_rms) { best_rms = rms; @@ -445,6 +449,7 @@ int main(int argc, char **argv) outputs = malloc(nbOutputs*nbSamples*sizeof(*outputs)); seed = time(NULL); + /*seed = 1361480659;*/ fprintf (stderr, "Seed is %u\n", seed); srand(seed); build_tansig_table();