From f0ec990dba011b2862c60a8903954d782cb92d19 Mon Sep 17 00:00:00 2001
From: Jean-Marc Valin <jmvalin@amazon.com>
Date: Mon, 2 Oct 2023 02:23:41 -0400
Subject: [PATCH] Switching to neural pitch estimator

Remove old pitch estimator and retrain all models
---
 autogen.sh                         |   2 +-
 dnn/lpcnet_enc.c                   | 113 +++--------------------------
 dnn/lpcnet_private.h               |  10 +--
 dnn/nnet.c                         |  27 ++++---
 dnn/nnet.h                         |  10 ++-
 dnn/parse_lpcnet_weights.c         |  25 +++++++
 dnn/pitchdnn.c                     |  31 +++++---
 dnn/pitchdnn.h                     |   6 +-
 dnn/torch/neural-pitch/training.py |   6 ++
 dnn/torch/rdovae/rdovae/rdovae.py  |   2 +-
 lpcnet_headers.mk                  |   4 +-
 lpcnet_sources.mk                  |   2 +
 silk/dred_config.h                 |   2 +-
 13 files changed, 103 insertions(+), 137 deletions(-)

diff --git a/autogen.sh b/autogen.sh
index 7f9f83bcb..b0ab990fb 100755
--- a/autogen.sh
+++ b/autogen.sh
@@ -9,7 +9,7 @@ set -e
 srcdir=`dirname $0`
 test -n "$srcdir" && cd "$srcdir"
 
-dnn/download_model.sh 27663d3
+dnn/download_model.sh da7f4c6
 
 echo "Updating build configuration files, please wait...."
 
diff --git a/dnn/lpcnet_enc.c b/dnn/lpcnet_enc.c
index 7205ddb5c..c51a68524 100644
--- a/dnn/lpcnet_enc.c
+++ b/dnn/lpcnet_enc.c
@@ -51,6 +51,7 @@ int lpcnet_encoder_get_size() {
 int lpcnet_encoder_init(LPCNetEncState *st) {
   memset(st, 0, sizeof(*st));
   st->exc_mem = lin2ulaw(0.f);
+  pitchdnn_init(&st->pitchdnn);
   return 0;
 }
 
@@ -98,7 +99,6 @@ void compute_frame_features(LPCNetEncState *st, const float *in) {
   float Ex[NB_BANDS];
   float xcorr[PITCH_MAX_PERIOD];
   float ener0;
-  int sub;
   float ener;
   /* [b,a]=ellip(2, 2, 20, 1200/8000); */
   static const float lp_b[2] = {-0.84946f, 1.f};
@@ -163,110 +163,21 @@ void compute_frame_features(LPCNetEncState *st, const float *in) {
     }
     /*printf("\n");*/
   }
-  /* Cross-correlation on half-frames. */
-  for (sub=0;sub<2;sub++) {
-    int off = sub*FRAME_SIZE/2;
-    double ener1;
-    celt_pitch_xcorr(&st->exc_buf[PITCH_MAX_PERIOD+off], st->exc_buf+off, xcorr, FRAME_SIZE/2, PITCH_MAX_PERIOD, st->arch);
-    ener0 = celt_inner_prod_c(&st->exc_buf[PITCH_MAX_PERIOD+off], &st->exc_buf[PITCH_MAX_PERIOD+off], FRAME_SIZE/2);
-    ener1 = celt_inner_prod_c(&st->exc_buf[off], &st->exc_buf[off], FRAME_SIZE/2-1);
-    st->frame_weight[sub] = ener0;
-    /*printf("%f\n", st->frame_weight[sub]);*/
-    for (i=0;i<PITCH_MAX_PERIOD;i++) {
-      ener1 += st->exc_buf[i+off+FRAME_SIZE/2-1]*st->exc_buf[i+off+FRAME_SIZE/2-1];
-      ener = 1 + ener0 + ener1;
-      st->xc[sub][i] = 2*xcorr[i] / ener;
-      ener1 -= st->exc_buf[i+off]*st->exc_buf[i+off];
-    }
-    if (1) {
-      /* Upsample correlation by 3x and keep the max. */
-      float interpolated[PITCH_MAX_PERIOD]={0};
-      /* interp=sinc([-3:3]+1/3).*(.5+.5*cos(pi*[-3:3]/4.5)); interp=interp/sum(interp); */
-      static const float interp[7] = {0.026184f, -0.098339f, 0.369938f, 0.837891f, -0.184969f, 0.070242f, -0.020947f};
-      for (i=4;i<PITCH_MAX_PERIOD-4;i++) {
-        float val1=0, val2=0;
-        int j;
-        for (j=0;j<7;j++) {
-          val1 += st->xc[sub][i-3+j]*interp[j];
-          val2 += st->xc[sub][i+3-j]*interp[j];
-          interpolated[i] = MAX16(st->xc[sub][i], MAX16(val1, val2));
-        }
-      }
-      for (i=4;i<PITCH_MAX_PERIOD-4;i++) {
-        st->xc[sub][i] = interpolated[i];
-      }
-    }
-#if 0
-    for (i=0;i<PITCH_MAX_PERIOD;i++)
-      printf("%f ", st->xc[sub][i]);
-    printf("\n");
-#endif
-  }
+  st->dnn_pitch = compute_pitchdnn(&st->pitchdnn, st->if_features, st->xcorr_features);
 }
 
 void process_single_frame(LPCNetEncState *st, FILE *ffeat) {
-  int i;
-  int sub;
-  int best_i;
-  int best[4];
-  int pitch_prev[2][PITCH_MAX_PERIOD];
   float frame_corr;
-  float frame_weight_sum = 1e-15f;
-  for(sub=0;sub<2;sub++) frame_weight_sum += st->frame_weight[sub];
-  for(sub=0;sub<2;sub++) st->frame_weight[sub] *= (2.f/frame_weight_sum);
-  for(sub=0;sub<2;sub++) {
-    float max_path_all = -1e15f;
-    best_i = 0;
-    for (i=0;i<PITCH_MAX_PERIOD-2*PITCH_MIN_PERIOD;i++) {
-      float xc_half = MAX16(MAX16(st->xc[sub][(PITCH_MAX_PERIOD+i)/2], st->xc[sub][(PITCH_MAX_PERIOD+i+2)/2]), st->xc[sub][(PITCH_MAX_PERIOD+i-1)/2]);
-      if (st->xc[sub][i] < xc_half*1.1f) st->xc[sub][i] *= .8f;
-    }
-    for (i=0;i<PITCH_MAX_PERIOD-PITCH_MIN_PERIOD;i++) {
-      int j;
-      float max_prev;
-      max_prev = st->pitch_max_path_all - 6.f;
-      pitch_prev[sub][i] = st->best_i;
-      for (j=IMAX(-4, -i);j<=4 && i+j<PITCH_MAX_PERIOD-PITCH_MIN_PERIOD;j++) {
-        if (st->pitch_max_path[0][i+j] - .02f*abs(j)*abs(j) > max_prev) {
-          max_prev = st->pitch_max_path[0][i+j] - .02f*abs(j)*abs(j);
-          pitch_prev[sub][i] = i+j;
-        }
-      }
-      st->pitch_max_path[1][i] = max_prev + st->frame_weight[sub]*st->xc[sub][i];
-      if (st->pitch_max_path[1][i] > max_path_all) {
-        max_path_all = st->pitch_max_path[1][i];
-        best_i = i;
-      }
-    }
-    /* Renormalize. */
-    for (i=0;i<PITCH_MAX_PERIOD-PITCH_MIN_PERIOD;i++) st->pitch_max_path[1][i] -= max_path_all;
-    /*for (i=0;i<PITCH_MAX_PERIOD-PITCH_MIN_PERIOD;i++) printf("%f ", st->pitch_max_path[1][i]);
-    printf("\n");*/
-    OPUS_COPY(&st->pitch_max_path[0][0], &st->pitch_max_path[1][0], PITCH_MAX_PERIOD);
-    st->pitch_max_path_all = max_path_all;
-    st->best_i = best_i;
-  }
-  best_i = st->best_i;
-  frame_corr = 0;
-  /* Backward pass. */
-  for (sub=1;sub>=0;sub--) {
-    best[2+sub] = PITCH_MAX_PERIOD-best_i;
-    frame_corr += st->frame_weight[sub]*st->xc[sub][best_i];
-    best_i = pitch_prev[sub][best_i];
-  }
-  frame_corr /= 2;
-  if (0) {
-    float xy, xx, yy;
-    int pitch = (best[2]+best[3])/2;
-    xx = celt_inner_prod_c(&st->lp_buf[PITCH_MAX_PERIOD], &st->lp_buf[PITCH_MAX_PERIOD], FRAME_SIZE);
-    yy = celt_inner_prod_c(&st->lp_buf[PITCH_MAX_PERIOD-pitch], &st->lp_buf[PITCH_MAX_PERIOD-pitch], FRAME_SIZE);
-    xy = celt_inner_prod_c(&st->lp_buf[PITCH_MAX_PERIOD], &st->lp_buf[PITCH_MAX_PERIOD-pitch], FRAME_SIZE);
-    //printf("%f %f\n", frame_corr, xy/sqrt(1e-15+xx*yy));
-    frame_corr = xy/sqrt(1+xx*yy);
-    //frame_corr = MAX32(0, xy/sqrt(1+xx*yy));
-    frame_corr = log(1.f+exp(5.f*frame_corr))/log(1+exp(5.f));
-  }
-  st->features[NB_BANDS] = .01f*(IMAX(66, IMIN(510, best[2]+best[3]))-200);
+  float xy, xx, yy;
+  /*int pitch = (best[2]+best[3])/2;*/
+  int pitch = (int)floor(.5+256./pow(2.f,((1./60.)*((st->dnn_pitch+1.5)*60))));
+  xx = celt_inner_prod_c(&st->lp_buf[PITCH_MAX_PERIOD], &st->lp_buf[PITCH_MAX_PERIOD], FRAME_SIZE);
+  yy = celt_inner_prod_c(&st->lp_buf[PITCH_MAX_PERIOD-pitch], &st->lp_buf[PITCH_MAX_PERIOD-pitch], FRAME_SIZE);
+  xy = celt_inner_prod_c(&st->lp_buf[PITCH_MAX_PERIOD], &st->lp_buf[PITCH_MAX_PERIOD-pitch], FRAME_SIZE);
+  /*printf("%f %f\n", frame_corr, xy/sqrt(1e-15+xx*yy));*/
+  frame_corr = xy/sqrt(1+xx*yy);
+  frame_corr = log(1.f+exp(5.f*frame_corr))/log(1+exp(5.f));
+  st->features[NB_BANDS] = st->dnn_pitch;
   st->features[NB_BANDS + 1] = frame_corr-.5f;
   if (ffeat) {
     fwrite(st->features, sizeof(float), NB_TOTAL_FEATURES, ffeat);
diff --git a/dnn/lpcnet_private.h b/dnn/lpcnet_private.h
index 597b487d6..c533cbaf9 100644
--- a/dnn/lpcnet_private.h
+++ b/dnn/lpcnet_private.h
@@ -7,9 +7,8 @@
 #include "nnet_data.h"
 #include "plc_data.h"
 #include "kiss99.h"
+#include "pitchdnn.h"
 
-#define PITCH_MIN_PERIOD 32
-#define PITCH_MAX_PERIOD 256
 
 #define PITCH_FRAME_SIZE 320
 #define PITCH_BUF_SIZE (PITCH_MAX_PERIOD+PITCH_FRAME_SIZE)
@@ -44,22 +43,19 @@ struct LPCNetState {
 };
 
 struct LPCNetEncState{
+  PitchDNNState pitchdnn;
   int arch;
   float analysis_mem[OVERLAP_SIZE];
   float mem_preemph;
   kiss_fft_cpx prev_if[PITCH_IF_MAX_FREQ];
   float if_features[PITCH_IF_FEATURES];
   float xcorr_features[PITCH_MAX_PERIOD - PITCH_MIN_PERIOD];
+  float dnn_pitch;
   float pitch_mem[LPC_ORDER];
   float pitch_filt;
-  float xc[2][PITCH_MAX_PERIOD+1];
-  float frame_weight[2];
   float exc_buf[PITCH_BUF_SIZE];
   float lp_buf[PITCH_BUF_SIZE];
   float lp_mem[4];
-  float pitch_max_path[2][PITCH_MAX_PERIOD];
-  float pitch_max_path_all;
-  int best_i;
   float last_gain;
   int last_period;
   float lpc[LPC_ORDER];
diff --git a/dnn/nnet.c b/dnn/nnet.c
index d5ef904ec..0587ff5b2 100644
--- a/dnn/nnet.c
+++ b/dnn/nnet.c
@@ -408,22 +408,22 @@ void compute_conv1d(const Conv1DLayer *layer, float *output, float *mem, const f
    storing the output as [ out_channels x len2 ].
    We assume that the output dimension along the ksize1 axis is 1,
    i.e. processing one frame at a time. */
-void conv2d_float(float *out, const float *weights, int in_channels, int out_channels, int ktime, int kheight, const float *in, int len2)
+void conv2d_float(float *out, const float *weights, int in_channels, int out_channels, int ktime, int kheight, const float *in, int height, int hstride)
 {
    int i;
    int in_stride;
-   in_stride = len2+kheight-1;
-   OPUS_CLEAR(out, out_channels*len2);
+   in_stride = height+kheight-1;
    for (i=0;i<out_channels;i++) {
       int m;
+      OPUS_CLEAR(&out[i*hstride], height);
       for (m=0;m<in_channels;m++) {
          int t;
          for (t=0;t<ktime;t++) {
             int h;
             for (h=0;h<kheight;h++) {
                int j;
-               for (j=0;j<len2;j++) {
-                  out[i*len2 + j] += weights[i*in_channels*ktime*kheight + m*ktime*kheight + t*kheight + h] *
+               for (j=0;j<height;j++) {
+                  out[i*hstride + j] += weights[i*in_channels*ktime*kheight + m*ktime*kheight + t*kheight + h] *
                                      in[t*in_channels*in_stride + m*in_stride + j + h];
                }
             }
@@ -432,26 +432,31 @@ void conv2d_float(float *out, const float *weights, int in_channels, int out_cha
    }
 }
 
-#define MAX_CONV2D_INPUTS 2048
+#define MAX_CONV2D_INPUTS 8192
 
-void compute_conv2d(const Conv2dLayer *conv, float *out, float *mem, const float *in, int len2, int activation)
+void compute_conv2d(const Conv2dLayer *conv, float *out, float *mem, const float *in, int height, int hstride, int activation)
 {
    int i;
    const float *bias;
    float in_buf[MAX_CONV2D_INPUTS];
    int time_stride;
    celt_assert(in != out);
-   time_stride = conv->in_channels*(len2+conv->kheight);
+   time_stride = conv->in_channels*(height+conv->kheight-1);
    celt_assert(conv->ktime*time_stride <= MAX_CONV2D_INPUTS);
    OPUS_COPY(in_buf, mem, (conv->ktime-1)*time_stride);
    OPUS_COPY(&in_buf[(conv->ktime-1)*time_stride], in, time_stride);
    OPUS_COPY(mem, &in_buf[time_stride], (conv->ktime-1)*time_stride);
    bias = conv->bias;
-   conv2d_float(out, conv->float_weights, conv->in_channels, conv->out_channels, conv->ktime, conv->kheight, in_buf, len2);
+   conv2d_float(out, conv->float_weights, conv->in_channels, conv->out_channels, conv->ktime, conv->kheight, in_buf, height, hstride);
    if (bias != NULL) {
-      for (i=0;i<conv->out_channels*len2;i++) out[i] += bias[i];
+     for (i=0;i<conv->out_channels;i++) {
+       int j;
+       for (j=0;j<height;j++) out[i*hstride+j] += bias[i];
+     }
+   }
+   for (i=0;i<conv->out_channels;i++) {
+     compute_activation(&out[i*hstride], &out[i*hstride], height, activation);
    }
-   compute_activation(out, out, conv->out_channels*len2, activation);
 }
 
 
diff --git a/dnn/nnet.h b/dnn/nnet.h
index 9ed20b028..c379a90f7 100644
--- a/dnn/nnet.h
+++ b/dnn/nnet.h
@@ -189,6 +189,14 @@ int linear_init(LinearLayer *layer, const WeightArray *arrays,
   int nb_inputs,
   int nb_outputs);
 
+int conv2d_init(Conv2dLayer *layer, const WeightArray *arrays,
+  const char *bias,
+  const char *float_weights,
+  int in_channels,
+  int out_channels,
+  int ktime,
+  int kheight);
+
 int mdense_init(MDenseLayer *layer, const WeightArray *arrays,
   const char *bias,
   const char *input_weights,
@@ -234,7 +242,7 @@ int conv1d_init(Conv1DLayer *layer, const WeightArray *arrays,
   int nb_neurons,
   int activation);
 
-void compute_conv2d(const Conv2dLayer *conv, float *out, float *mem, const float *in, int len2, int activation);
+void compute_conv2d(const Conv2dLayer *conv, float *out, float *mem, const float *in, int height, int hstride, int activation);
 
 int embedding_init(EmbeddingLayer *layer, const WeightArray *arrays,
   const char *embedding_weights,
diff --git a/dnn/parse_lpcnet_weights.c b/dnn/parse_lpcnet_weights.c
index 7413099b6..9805ec8c7 100644
--- a/dnn/parse_lpcnet_weights.c
+++ b/dnn/parse_lpcnet_weights.c
@@ -272,6 +272,31 @@ int conv1d_init(Conv1DLayer *layer, const WeightArray *arrays,
   return 0;
 }
 
+int conv2d_init(Conv2dLayer *layer, const WeightArray *arrays,
+  const char *bias,
+  const char *float_weights,
+  int in_channels,
+  int out_channels,
+  int ktime,
+  int kheight)
+{
+  int err;
+  layer->bias = NULL;
+  layer->float_weights = NULL;
+  if (bias != NULL) {
+    if ((layer->bias = find_array_check(arrays, bias, out_channels*sizeof(layer->bias[0]))) == NULL) return 1;
+  }
+  if (float_weights != NULL) {
+    layer->float_weights = opt_array_check(arrays, float_weights, in_channels*out_channels*ktime*kheight*sizeof(layer->float_weights[0]), &err);
+    if (err) return 1;
+  }
+  layer->in_channels = in_channels;
+  layer->out_channels = out_channels;
+  layer->ktime = ktime;
+  layer->kheight = kheight;
+  return 0;
+}
+
 int embedding_init(EmbeddingLayer *layer, const WeightArray *arrays,
   const char *embedding_weights,
   int nb_inputs,
diff --git a/dnn/pitchdnn.c b/dnn/pitchdnn.c
index 5a35936ce..02c674442 100644
--- a/dnn/pitchdnn.c
+++ b/dnn/pitchdnn.c
@@ -9,7 +9,7 @@
 #include "lpcnet_private.h"
 
 
-int compute_pitchdnn(
+float compute_pitchdnn(
     PitchDNNState *st,
     const float *if_features,
     const float *xcorr_features
@@ -18,35 +18,40 @@ int compute_pitchdnn(
   float if1_out[DENSE_IF_UPSAMPLER_1_OUT_SIZE];
   float downsampler_in[NB_XCORR_FEATURES + DENSE_IF_UPSAMPLER_2_OUT_SIZE];
   float downsampler_out[DENSE_DOWNSAMPLER_OUT_SIZE];
-  float conv1_tmp1[NB_XCORR_FEATURES + 2] = {0};
-  float conv1_tmp2[NB_XCORR_FEATURES + 2] = {0};
+  float conv1_tmp1[(NB_XCORR_FEATURES + 2)*8] = {0};
+  float conv1_tmp2[(NB_XCORR_FEATURES + 2)*8] = {0};
   float output[DENSE_FINAL_UPSAMPLER_OUT_SIZE];
   int i;
   int pos=0;
   float maxval=-1;
+  float sum=0;
+  float count=0;
   PitchDNN *model = &st->model;
-
   /* IF */
   compute_generic_dense(&model->dense_if_upsampler_1, if1_out, if_features, ACTIVATION_TANH);
   compute_generic_dense(&model->dense_if_upsampler_2, &downsampler_in[NB_XCORR_FEATURES], if1_out, ACTIVATION_TANH);
-
   /* xcorr*/
   OPUS_COPY(&conv1_tmp1[1], xcorr_features, NB_XCORR_FEATURES);
-  compute_conv2d(&model->conv2d_1, &conv1_tmp2[1], st->xcorr_mem1, conv1_tmp1, NB_XCORR_FEATURES, ACTIVATION_TANH);
-  compute_conv2d(&model->conv2d_1, &conv1_tmp1[1], st->xcorr_mem2, conv1_tmp2, NB_XCORR_FEATURES, ACTIVATION_TANH);
-  compute_conv2d(&model->conv2d_1, downsampler_in, st->xcorr_mem3, conv1_tmp1, NB_XCORR_FEATURES, ACTIVATION_TANH);
+  compute_conv2d(&model->conv2d_1, &conv1_tmp2[1], st->xcorr_mem1, conv1_tmp1, NB_XCORR_FEATURES, NB_XCORR_FEATURES+2, ACTIVATION_TANH);
+  compute_conv2d(&model->conv2d_2, &conv1_tmp1[1], st->xcorr_mem2, conv1_tmp2, NB_XCORR_FEATURES, NB_XCORR_FEATURES+2, ACTIVATION_TANH);
+  compute_conv2d(&model->conv2d_3, downsampler_in, st->xcorr_mem3, conv1_tmp1, NB_XCORR_FEATURES, NB_XCORR_FEATURES, ACTIVATION_TANH);
 
   compute_generic_dense(&model->dense_downsampler, downsampler_out, downsampler_in, ACTIVATION_TANH);
   compute_generic_gru(&model->gru_1_input, &model->gru_1_recurrent, st->gru_state, downsampler_out);
   compute_generic_dense(&model->dense_final_upsampler, output, st->gru_state, ACTIVATION_LINEAR);
-
-  for (i=0;i<DENSE_FINAL_UPSAMPLER_OUT_SIZE;i++) {
+  for (i=0;i<180;i++) {
     if (output[i] > maxval) {
       pos = i;
       maxval = output[i];
     }
   }
-  return (1.f/60.f)*pos - 1.5;
+  for (i=IMAX(0, pos-2); i<=IMIN(179, pos+2); i++) {
+    float p = exp(output[i]);
+    sum += p*i;
+    count += p;
+  }
+  /*printf("%d %f\n", pos, sum/count);*/
+  return (1.f/60.f)*(sum/count) - 1.5;
   /*return 256.f/pow(2.f, (1.f/60.f)*i);*/
 }
 
@@ -55,7 +60,11 @@ void pitchdnn_init(PitchDNNState *st)
 {
   int ret;
   OPUS_CLEAR(st, 1);
+#ifndef USE_WEIGHTS_FILE
   ret = init_pitchdnn(&st->model, pitchdnn_arrays);
+#else
+  ret = 0;
+#endif
   celt_assert(ret == 0);
   /* FIXME: perform arch detection. */
 }
diff --git a/dnn/pitchdnn.h b/dnn/pitchdnn.h
index 74eacd77d..3799510da 100644
--- a/dnn/pitchdnn.h
+++ b/dnn/pitchdnn.h
@@ -5,7 +5,9 @@
 typedef struct PitchDNN PitchDNN;
 
 #include "pitchdnn_data.h"
-#include "lpcnet_private.h"
+
+#define PITCH_MIN_PERIOD 32
+#define PITCH_MAX_PERIOD 256
 
 #define NB_XCORR_FEATURES (PITCH_MAX_PERIOD-PITCH_MIN_PERIOD)
 
@@ -21,7 +23,7 @@ typedef struct {
 
 void pitchdnn_init(PitchDNNState *st);
 
-int compute_pitchdnn(
+float compute_pitchdnn(
     PitchDNNState *st,
     const float *if_features,
     const float *xcorr_features
diff --git a/dnn/torch/neural-pitch/training.py b/dnn/torch/neural-pitch/training.py
index e725e57c6..62da1351d 100644
--- a/dnn/torch/neural-pitch/training.py
+++ b/dnn/torch/neural-pitch/training.py
@@ -24,6 +24,7 @@ parser.add_argument('--learning_rate', type=float, help='Learning Rate',default
 parser.add_argument('--epochs', type=int, help='Number of training epochs',default = 50,required = False)
 parser.add_argument('--choice_cel', type=str, help='Choice of Cross Entropy Loss (default or robust)',choices=['default','robust'],default = 'default',required = False)
 parser.add_argument('--prefix', type=str, help="prefix for model export, default: model", default='model')
+parser.add_argument('--initial-checkpoint', type=str, help='initial checkpoint to start training from, default: None', default=None)
 
 
 args = parser.parse_args()
@@ -55,6 +56,11 @@ elif args.data_format == 'xcorr':
 else:
     pitch_nn = PitchDNN(3 * args.freq_keep - 2, 224, args.gru_dim, args.output_dim)
 
+if type(args.initial_checkpoint) != type(None):
+    checkpoint = torch.load(args.initial_checkpoint, map_location='cpu')
+    pitch_nn.load_state_dict(checkpoint['state_dict'], strict=False)
+
+
 dataset_training = PitchDNNDataloader(args.features,args.features_pitch,args.confidence_threshold,args.context,args.data_format)
 
 def loss_custom(logits,labels,confidence,choice = 'default',nmax = 192,q = 0.7):
diff --git a/dnn/torch/rdovae/rdovae/rdovae.py b/dnn/torch/rdovae/rdovae/rdovae.py
index b126d4c44..3552cf906 100644
--- a/dnn/torch/rdovae/rdovae/rdovae.py
+++ b/dnn/torch/rdovae/rdovae/rdovae.py
@@ -159,7 +159,7 @@ def distortion_loss(y_true, y_pred, rate_lambda=None):
         raise ValueError('distortion loss is designed to work with 20 features')
 
     ceps_error   = y_pred[..., :18] - y_true[..., :18]
-    pitch_error  = 2 * (y_pred[..., 18:19] - y_true[..., 18:19]) / (2 + y_true[..., 18:19])
+    pitch_error  = 2*(y_pred[..., 18:19] - y_true[..., 18:19])
     corr_error   = y_pred[..., 19:] - y_true[..., 19:]
     pitch_weight = torch.relu(y_true[..., 19:] + 0.5) ** 2
 
diff --git a/lpcnet_headers.mk b/lpcnet_headers.mk
index fc3fc84c5..178792105 100644
--- a/lpcnet_headers.mk
+++ b/lpcnet_headers.mk
@@ -24,4 +24,6 @@ dnn/dred_rdovae_enc.h \
 dnn/dred_rdovae_enc_data.h \
 dnn/dred_rdovae_dec.h \
 dnn/dred_rdovae_dec_data.h \
-dnn/dred_rdovae_stats_data.h
+dnn/dred_rdovae_stats_data.h \
+dnn/pitchdnn.h \
+dnn/pitchdnn_data.h
diff --git a/lpcnet_sources.mk b/lpcnet_sources.mk
index 61cbb1f1c..bb0ec5f18 100644
--- a/lpcnet_sources.mk
+++ b/lpcnet_sources.mk
@@ -18,6 +18,8 @@ dnn/dred_rdovae_enc_data.c \
 dnn/dred_rdovae_dec.c \
 dnn/dred_rdovae_dec_data.c \
 dnn/dred_rdovae_stats_data.c \
+dnn/pitchdnn.c \
+dnn/pitchdnn_data.c \
 silk/dred_encoder.c \
 silk/dred_coding.c \
 silk/dred_decoder.c
diff --git a/silk/dred_config.h b/silk/dred_config.h
index 03125f1b6..46f32a8d9 100644
--- a/silk/dred_config.h
+++ b/silk/dred_config.h
@@ -32,7 +32,7 @@
 #define DRED_EXTENSION_ID 126
 
 /* Remove these two completely once DRED gets an extension number assigned. */
-#define DRED_EXPERIMENTAL_VERSION 4
+#define DRED_EXPERIMENTAL_VERSION 5
 #define DRED_EXPERIMENTAL_BYTES 2
 
 
-- 
GitLab