diff --git a/dnn/ceps_vq_train.c b/dnn/ceps_vq_train.c
index 8a3b5e4767bcc6e12cdc75021646cea73c0c831d..ed54239ad48f00fd4b9aea6fddf2d3d3d64c3469 100644
--- a/dnn/ceps_vq_train.c
+++ b/dnn/ceps_vq_train.c
@@ -6,9 +6,12 @@
 #include <math.h>
 
 #define MIN(a,b) ((a)<(b)?(a):(b))
-#define COEF 0.75f
+#define COEF 0.0f
 #define MAX_ENTRIES 16384
 
+#define MULTI 4               /* candidate sub-vectors per training vector; must be a power of two */
+#define MULTI_MASK (MULTI-1)  /* maps a codebook entry index to its sub-vector slot */
+
 void compute_weights(const float *x, float *w, int ndim)
 {
   int i;
@@ -48,6 +51,37 @@ int find_nearest(const float *codebook, int nb_entries, const float *x, int ndim
   return nearest;
 }
 
+/* Nearest-entry search for a codebook interleaved over MULTI slots.
+   x holds MULTI candidate sub-vectors of ndim floats each; entry i is only
+   compared against candidate (i & MULTI_MASK). Returns the index of the entry
+   with the smallest squared error (lowest index on ties) and, when dist is
+   non-NULL, stores that error in *dist. */
+int find_nearest_multi(const float *codebook, int nb_entries, const float *x, int ndim, float *dist)
+{
+  int i, j;
+  float min_dist = 1e15;
+  int nearest = 0;
+
+  for (i=0;i<nb_entries;i++)
+  {
+    int offset;
+    /* Named d rather than dist so the output parameter is not shadowed. */
+    float d=0;
+    offset = (i&MULTI_MASK)*ndim;
+    for (j=0;j<ndim;j++)
+      d += (x[offset+j]-codebook[i*ndim+j])*(x[offset+j]-codebook[i*ndim+j]);
+    if (d<min_dist)
+    {
+      min_dist = d;
+      nearest = i;
+    }
+  }
+  if (dist)
+    *dist = min_dist;
+  return nearest;
+}
+
+
 int find_nearest_weighted(const float *codebook, int nb_entries, float *x, const float *w, int ndim)
 {
   int i, j;
@@ -203,6 +237,67 @@ void update(float *data, int nb_vectors, float *codebook, int nb_entries, int nd
   //fprintf(stderr, "%f / %d\n", 1./w2, nb_entries);
 }
 
+/* One assign-then-average (k-means style) pass over a MULTI-slot codebook.
+   data holds nb_vectors groups of MULTI candidate sub-vectors (ndim floats each).
+   Every group is assigned to its nearest entry, then each entry is replaced by
+   the mean of the candidates (taken from the entry's own slot) assigned to it.
+   Entries that attract no vector keep their previous value; dividing by a zero
+   count would otherwise turn them into NaN for good.
+   Prints the RMS quantization error measured before the update. */
+void update_multi(float *data, int nb_vectors, float *codebook, int nb_entries, int ndim)
+{
+  int i,j;
+  int *count;
+  int *nearest;
+  double err=0;
+
+  /* Heap rather than VLAs: nb_vectors is the size of the training set and can
+     easily exceed what fits on the stack. */
+  count = calloc(nb_entries > 0 ? nb_entries : 1, sizeof(*count));
+  nearest = malloc((nb_vectors > 0 ? nb_vectors : 1)*sizeof(*nearest));
+  if (count == NULL || nearest == NULL)
+  {
+    fprintf(stderr, "update_multi: out of memory\n");
+    exit(1);
+  }
+
+  for (i=0;i<nb_vectors;i++)
+  {
+    float dist;
+    nearest[i] = find_nearest_multi(codebook, nb_entries, data+MULTI*i*ndim, ndim, &dist);
+    count[nearest[i]]++;
+    err += dist;
+  }
+  printf("RMS error = %f\n", sqrt(err/nb_vectors/ndim));
+
+  /* Reset only the entries that are about to be re-estimated. */
+  for (i=0;i<nb_entries;i++)
+  {
+    if (count[i] == 0)
+      continue;
+    for (j=0;j<ndim;j++)
+      codebook[i*ndim+j] = 0;
+  }
+
+  for (i=0;i<nb_vectors;i++)
+  {
+    int n = nearest[i];
+    for (j=0;j<ndim;j++)
+      codebook[n*ndim+j] += data[(MULTI*i + (n&MULTI_MASK))*ndim+j];
+  }
+
+  for (i=0;i<nb_entries;i++)
+  {
+    if (count[i] == 0)
+      continue;
+    for (j=0;j<ndim;j++)
+      codebook[i*ndim+j] *= (1./count[i]);
+  }
+  free(nearest);
+  free(count);
+}
+
+
 void update_weighted(float *data, float *weight, int nb_vectors, float *codebook, int nb_entries, int ndim)
 {
   int i,j;
@@ -271,6 +366,44 @@ void vq_train(float *data, int nb_vectors, float *codebook, int nb_entries, int
     update(data, nb_vectors, codebook, e, ndim);
 }
 
+/* Train a MULTI-slot codebook by repeated splitting.
+   The first MULTI entries start at the per-slot mean of data plus a small
+   random perturbation and are refined with update_multi(); split() is then
+   called and the entry count doubled, with further refinement after each step.
+   NOTE(review): codebook must have room for the final size, which is MULTI
+   times a power of two -- any other nb_entries overshoots. */
+void vq_train_multi(float *data, int nb_vectors, float *codebook, int nb_entries, int ndim)
+{
+  int i, j, e;
+  for (e=0;e<MULTI;e++) {
+    for (j=0;j<ndim;j++)
+      codebook[e*ndim+j] = 0;
+    for (i=0;i<nb_vectors;i++)
+      for (j=0;j<ndim;j++)
+        codebook[e*ndim+j] += data[(MULTI*i+e)*ndim+j];
+    for (j=0;j<ndim;j++) {
+      float delta = .01*(rand()/(float)RAND_MAX-.5);
+      codebook[e*ndim+j] *= (1./nb_vectors);
+      codebook[e*ndim+j] += delta;
+    }
+  }
+  e = MULTI;
+  for (j=0;j<10;j++)
+    update_multi(data, nb_vectors, codebook, e, ndim);
+
+  while (e < nb_entries)
+  {
+    split(codebook, e, ndim);
+    e<<=1;
+    fprintf(stderr, "%d\n", e);
+    for (j=0;j<4;j++)
+      update_multi(data, nb_vectors, codebook, e, ndim);
+  }
+  for (j=0;j<ndim*2;j++)
+    update_multi(data, nb_vectors, codebook, e, ndim);
+}
+
+
 void vq_train_weighted(float *data, float *weight, int nb_vectors, float *codebook, int nb_entries, int ndim)
 {
   int i, j, e;
@@ -303,8 +436,9 @@ void vq_train_weighted(float *data, float *weight, int nb_vectors, float *codebo
 int main(int argc, char **argv)
 {
   int i,j;
-  int nb_vectors, nb_entries, ndim, ndim0, total_dim;
-  float *data, *pred, *codebook, *codebook2;
+  int nb_vectors, nb_entries, nb_entries2, ndim, ndim0, total_dim;
+  float *data, *pred, *multi_data, *multi_data2, *qdata;
+  float *codebook, *codebook2, *codebook_diff2, *codebook_diff4;
   float *delta;
   double err;
   FILE *fout;
@@ -314,11 +448,19 @@ int main(int argc, char **argv)
   total_dim = atoi(argv[2]);
   nb_vectors = atoi(argv[3]);
   nb_entries = 1<<atoi(argv[4]);
+  nb_entries2 = 64;
 
   data = malloc((nb_vectors*ndim+total_dim)*sizeof(*data));
+  qdata = malloc((nb_vectors*ndim+total_dim)*sizeof(*qdata));
   pred = malloc(nb_vectors*ndim0*sizeof(*pred));
+  /* calloc: only the first ndim0 of every ndim floats are filled in below,
+     yet training reads all ndim, so the remainder must be defined. */
+  multi_data = calloc((size_t)MULTI*nb_vectors*ndim, sizeof(*multi_data));
+  multi_data2 = calloc((size_t)MULTI*nb_vectors*ndim, sizeof(*multi_data2));
   codebook = malloc(nb_entries*ndim0*sizeof(*codebook));
   codebook2 = malloc(nb_entries*ndim0*sizeof(*codebook2));
+  codebook_diff4 = malloc(nb_entries*ndim*sizeof(*codebook_diff4));
+  codebook_diff2 = malloc(nb_entries2*ndim*sizeof(*codebook_diff2));
 
   for (i=0;i<nb_vectors;i++)
   {
@@ -348,8 +490,11 @@ int main(int argc, char **argv)
   for (i=0;i<nb_vectors;i++)
   {
     int nearest = find_nearest(codebook, nb_entries, &pred[i*ndim0], ndim0, NULL);
+    /* Element 0 is copied as is; j is stale at this point and must not index it. */
+    qdata[i*ndim] = data[i*ndim];
     for (j=0;j<ndim0;j++)
     {
+      qdata[i*ndim+j+1] = codebook[nearest*ndim0+j];
       delta[i*ndim0+j] = pred[i*ndim0+j] - codebook[nearest*ndim0+j];
       err += delta[i*ndim0+j]*delta[i*ndim0+j];
     }
@@ -366,12 +511,48 @@ int main(int argc, char **argv)
     n1 = find_nearest(codebook2, nb_entries, &delta[i*ndim0], ndim0, NULL);
     for (j=0;j<ndim0;j++)
     {
-      delta[i*ndim0+j] = delta[i*ndim0+j] - codebook2[n1*ndim0+j];
+      qdata[i*ndim+j+1] += codebook2[n1*ndim0+j];
+      delta[i*ndim0+j] = qdata[i*ndim+j+1] - data[i*ndim+j+1];
       err += delta[i*ndim0+j]*delta[i*ndim0+j];
     }
   }
   fprintf(stderr, "Cepstrum RMS error after stage 2: %f)\n", sqrt(err/nb_vectors/ndim));
 
+  /* Residuals of frame i+1 against quantized neighbours i and i+2, one
+     candidate per slot: 0 and 1 = mean of both neighbours, 2 = frame i,
+     3 = frame i+2.
+     NOTE(review): slots 0 and 1 are identical -- confirm this is intended. */
+  for (i=0;i<nb_vectors-4;i++)
+  {
+    for (j=0;j<ndim0;j++)
+      multi_data[MULTI*i*ndim+j] = data[(i+1)*ndim+j+1] - .5*(qdata[i*ndim+j+1]+qdata[(i+2)*ndim+j+1]);
+    for (j=0;j<ndim0;j++)
+      multi_data[(MULTI*i+1)*ndim+j] = data[(i+1)*ndim+j+1] - .5*(qdata[i*ndim+j+1]+qdata[(i+2)*ndim+j+1]);
+    for (j=0;j<ndim0;j++)
+      multi_data[(MULTI*i+2)*ndim+j] = data[(i+1)*ndim+j+1] - qdata[i*ndim+j+1];
+    for (j=0;j<ndim0;j++)
+      multi_data[(MULTI*i+3)*ndim+j] = data[(i+1)*ndim+j+1] - qdata[(i+2)*ndim+j+1];
+  }
+
+  /* Same candidates for frame i+2 against quantized neighbours i and i+4. */
+  for (i=0;i<nb_vectors-4;i++)
+  {
+    for (j=0;j<ndim0;j++)
+      multi_data2[MULTI*i*ndim+j] = data[(i+2)*ndim+j+1] - .5*(qdata[i*ndim+j+1]+qdata[(i+4)*ndim+j+1]);
+    for (j=0;j<ndim0;j++)
+      multi_data2[(MULTI*i+1)*ndim+j] = data[(i+2)*ndim+j+1] - .5*(qdata[i*ndim+j+1]+qdata[(i+4)*ndim+j+1]);
+    for (j=0;j<ndim0;j++)
+      multi_data2[(MULTI*i+2)*ndim+j] = data[(i+2)*ndim+j+1] - qdata[i*ndim+j+1];
+    for (j=0;j<ndim0;j++)
+      multi_data2[(MULTI*i+3)*ndim+j] = data[(i+2)*ndim+j+1] - qdata[(i+4)*ndim+j+1];
+  }
+
+  vq_train_multi(multi_data2, nb_vectors-4, codebook_diff4, nb_entries, ndim);
+
+  printf("done\n");
+  vq_train_multi(multi_data, nb_vectors-4, codebook_diff2, nb_entries2, ndim);
+
+
   fout = fopen("ceps_codebooks.c", "w");
   fprintf(fout, "/* This file is automatically generated */\n\n");
   fprintf(fout, "float ceps_codebook1[%d*%d] = {\n",nb_entries, ndim0);
@@ -385,7 +566,6 @@ int main(int argc, char **argv)
   fprintf(fout, "};\n\n");
 
   fprintf(fout, "float ceps_codebook2[%d*%d] = {\n",nb_entries, ndim0);
-
   for (i=0;i<nb_entries;i++)
   {
     for (j=0;j<ndim0;j++)
@@ -393,6 +573,24 @@ int main(int argc, char **argv)
     fprintf(fout, "\n");
   }
   fprintf(fout, "};\n\n");
+
+  fprintf(fout, "float ceps_codebook_diff4[%d*%d] = {\n",nb_entries, ndim);
+  for (i=0;i<nb_entries;i++)
+  {
+    for (j=0;j<ndim;j++)
+      fprintf(fout, "%f, ", codebook_diff4[i*ndim+j]);
+    fprintf(fout, "\n");
+  }
+  fprintf(fout, "};\n\n");
+
+  fprintf(fout, "float ceps_codebook_diff2[%d*%d] = {\n",nb_entries2, ndim);
+  for (i=0;i<nb_entries2;i++)
+  {
+    for (j=0;j<ndim;j++)
+      fprintf(fout, "%f, ", codebook_diff2[i*ndim+j]);
+    fprintf(fout, "\n");
+  }
+  fprintf(fout, "};\n\n");
 
   fclose(fout);
   return 0;