From cbb9f535c2e702681f884a95987ff3cfafb3b08e Mon Sep 17 00:00:00 2001
From: Jean-Marc Valin <jmvalin@amazon.com>
Date: Fri, 16 Jun 2023 17:06:55 -0400
Subject: [PATCH] Add support the DRED in CELT

---
 Makefile.am         |   8 +--
 celt/celt.h         |   8 +++
 celt/celt_decoder.c | 165 ++++++++++++++++++++++++++++++++++++++++++--
 src/opus_decoder.c  |   8 ++-
 4 files changed, 179 insertions(+), 10 deletions(-)

diff --git a/Makefile.am b/Makefile.am
index d7c955941..a3ac442ca 100644
--- a/Makefile.am
+++ b/Makefile.am
@@ -197,7 +197,7 @@ celt_tests_test_unit_cwrs32_SOURCES = celt/tests/test_unit_cwrs32.c
 celt_tests_test_unit_cwrs32_LDADD = $(LIBM)
 
 celt_tests_test_unit_dft_SOURCES = celt/tests/test_unit_dft.c
-celt_tests_test_unit_dft_LDADD = $(CELT_OBJ) $(NE10_LIBS) $(LIBM)
+celt_tests_test_unit_dft_LDADD = $(CELT_OBJ) $(LPCNET_OBJ) $(NE10_LIBS) $(LIBM)
 if OPUS_ARM_EXTERNAL_ASM
 celt_tests_test_unit_dft_LDADD += libarmasm.la
 endif
@@ -209,19 +209,19 @@ celt_tests_test_unit_laplace_SOURCES = celt/tests/test_unit_laplace.c
 celt_tests_test_unit_laplace_LDADD = $(LIBM)
 
 celt_tests_test_unit_mathops_SOURCES = celt/tests/test_unit_mathops.c
-celt_tests_test_unit_mathops_LDADD = $(CELT_OBJ) $(NE10_LIBS) $(LIBM)
+celt_tests_test_unit_mathops_LDADD = $(CELT_OBJ) $(LPCNET_OBJ) $(NE10_LIBS) $(LIBM)
 if OPUS_ARM_EXTERNAL_ASM
 celt_tests_test_unit_mathops_LDADD += libarmasm.la
 endif
 
 celt_tests_test_unit_mdct_SOURCES = celt/tests/test_unit_mdct.c
-celt_tests_test_unit_mdct_LDADD = $(CELT_OBJ) $(NE10_LIBS) $(LIBM)
+celt_tests_test_unit_mdct_LDADD = $(CELT_OBJ) $(LPCNET_OBJ) $(NE10_LIBS) $(LIBM)
 if OPUS_ARM_EXTERNAL_ASM
 celt_tests_test_unit_mdct_LDADD += libarmasm.la
 endif
 
 celt_tests_test_unit_rotation_SOURCES = celt/tests/test_unit_rotation.c
-celt_tests_test_unit_rotation_LDADD = $(CELT_OBJ) $(NE10_LIBS) $(LIBM)
+celt_tests_test_unit_rotation_LDADD = $(CELT_OBJ) $(LPCNET_OBJ) $(NE10_LIBS) $(LIBM)
 if OPUS_ARM_EXTERNAL_ASM
 celt_tests_test_unit_rotation_LDADD += libarmasm.la
 endif
diff --git a/celt/celt.h b/celt/celt.h
index 24b6b2b52..2f5197bbe 100644
--- a/celt/celt.h
+++ b/celt/celt.h
@@ -41,6 +41,7 @@
 #include "entenc.h"
 #include "entdec.h"
 #include "arch.h"
+#include "lpcnet.h"
 
 #ifdef __cplusplus
 extern "C" {
@@ -149,6 +150,13 @@ int celt_decoder_get_size(int channels);
 
 int celt_decoder_init(CELTDecoder *st, opus_int32 sampling_rate, int channels);
 
+int celt_decode_with_ec_dred(CELTDecoder * OPUS_RESTRICT st, const unsigned char *data,
+      int len, opus_val16 * OPUS_RESTRICT pcm, int frame_size, ec_dec *dec, int accum
+#ifdef ENABLE_NEURAL_FEC
+      ,LPCNetPLCState *lpcnet
+#endif
+      );
+
 int celt_decode_with_ec(OpusCustomDecoder * OPUS_RESTRICT st, const unsigned char *data,
       int len, opus_val16 * OPUS_RESTRICT pcm, int frame_size, ec_dec *dec, int accum);
 
diff --git a/celt/celt_decoder.c b/celt/celt_decoder.c
index 9d2552ac0..197c94528 100644
--- a/celt/celt_decoder.c
+++ b/celt/celt_decoder.c
@@ -50,6 +50,8 @@
 #include <stdarg.h>
 #include "celt_lpc.h"
 #include "vq.h"
+#include "lpcnet.h"
+#include "lpcnet/src/lpcnet_private.h"
 
 /* The maximum pitch lag to allow in the pitch-based PLC. It's possible to save
    CPU time in the PLC pitch search by making this smaller than MAX_PERIOD. The
@@ -69,6 +71,9 @@
 /**********************************************************************/
 #define DECODE_BUFFER_SIZE 2048
 
+#define PLC_UPDATE_FRAMES 4
+#define PLC_UPDATE_SAMPLES (PLC_UPDATE_FRAMES*FRAME_SIZE)
+
 /** Decoder state
  @brief Decoder state
  */
@@ -102,6 +107,12 @@ struct OpusCustomDecoder {
 
    celt_sig preemph_memD[2];
 
+#ifdef ENABLE_NEURAL_FEC
+   opus_int16 plc_pcm[PLC_UPDATE_SAMPLES];
+   int plc_fill;
+   float plc_preemphasis_mem;
+#endif
+
    celt_sig _decode_mem[1]; /* Size = channels*(DECODE_BUFFER_SIZE+mode->overlap) */
    /* opus_val16 lpc[],  Size = channels*CELT_LPC_ORDER */
    /* opus_val16 oldEBands[], Size = 2*mode->nbEBands */
@@ -537,7 +548,62 @@ static void prefilter_and_fold(CELTDecoder * OPUS_RESTRICT st, int N)
    } while (++c<CC);
 }
 
-static void celt_decode_lost(CELTDecoder * OPUS_RESTRICT st, int N, int LM)
+#ifdef ENABLE_NEURAL_FEC
+
+#define SINC_ORDER 48
+/* h=cos(pi/2*abs(sin([-24:24]/48*pi*23./24)).^2);
+   b=sinc([-24:24]/3*1.02).*h;
+   b=b/sum(b); */
+static const float sinc_filter[SINC_ORDER+1] = {
+    4.2931e-05f, -0.000190293f, -0.000816132f, -0.000637162f, 0.00141662f, 0.00354764f, 0.00184368f, -0.00428274f,
+    -0.00856105f, -0.0034003f, 0.00930201f, 0.0159616f, 0.00489785f, -0.0169649f, -0.0259484f, -0.00596856f,
+    0.0286551f, 0.0405872f, 0.00649994f, -0.0509284f, -0.0716655f, -0.00665212f,  0.134336f,  0.278927f,
+    0.339995f,  0.278927f,  0.134336f, -0.00665212f, -0.0716655f, -0.0509284f, 0.00649994f, 0.0405872f,
+    0.0286551f, -0.00596856f, -0.0259484f, -0.0169649f, 0.00489785f, 0.0159616f, 0.00930201f, -0.0034003f,
+    -0.00856105f, -0.00428274f, 0.00184368f, 0.00354764f, 0.00141662f, -0.000637162f, -0.000816132f, -0.000190293f,
+    4.2931e-05f
+};
+
+void update_plc_state(LPCNetPLCState *lpcnet, celt_sig *decode_mem[2], int CC)
+{
+   int i;
+   int tmp_read_post, tmp_fec_skip;
+   int offset;
+   celt_sig buf48k[DECODE_BUFFER_SIZE];
+   opus_int16 buf16k[PLC_UPDATE_SAMPLES];
+   if (CC == 1) OPUS_COPY(buf48k, decode_mem[0], DECODE_BUFFER_SIZE);
+   else {
+      for (i=0;i<DECODE_BUFFER_SIZE;i++) {
+         buf48k[i] = .5*(decode_mem[0][i] + decode_mem[1][i]);
+      }
+   }
+   /* Down-sample the last 40 ms. */
+   for (i=1;i<DECODE_BUFFER_SIZE;i++) buf48k[i] += PREEMPHASIS*buf48k[i-1];
+   offset = DECODE_BUFFER_SIZE-SINC_ORDER-1 - 3*(PLC_UPDATE_SAMPLES-1);
+   celt_assert(3*(PLC_UPDATE_SAMPLES-1) + SINC_ORDER + offset == DECODE_BUFFER_SIZE-1);
+   for (i=0;i<PLC_UPDATE_SAMPLES;i++) {
+      int j;
+      float sum = 0;
+      for (j=0;j<SINC_ORDER+1;j++) {
+         sum += buf48k[3*i + j + offset]*sinc_filter[j];
+      }
+      buf16k[i] = sum;
+   }
+   tmp_read_post = lpcnet->fec_read_pos;
+   tmp_fec_skip = lpcnet->fec_skip;
+   for (i=0;i<PLC_UPDATE_FRAMES;i++) {
+      lpcnet_plc_update(lpcnet, &buf16k[FRAME_SIZE*i]);
+   }
+   lpcnet->fec_read_pos = tmp_read_post;
+   lpcnet->fec_skip = tmp_fec_skip;
+}
+#endif
+
+static void celt_decode_lost(CELTDecoder * OPUS_RESTRICT st, int N, int LM
+#ifdef ENABLE_NEURAL_FEC
+      ,LPCNetPLCState *lpcnet
+#endif
+      )
 {
    int c;
    int i;
@@ -572,7 +638,11 @@ static void celt_decode_lost(CELTDecoder * OPUS_RESTRICT st, int N, int LM)
 
    loss_duration = st->loss_duration;
    start = st->start;
+#ifdef ENABLE_NEURAL_FEC
+   noise_based = start != 0 || (lpcnet->fec_fill_pos == 0 && (st->skip_plc || loss_duration >= 80));
+#else
    noise_based = loss_duration >= 40 || start != 0 || st->skip_plc;
+#endif
    if (noise_based)
    {
       /* Noise-based PLC/CNG */
@@ -645,6 +715,9 @@ static void celt_decode_lost(CELTDecoder * OPUS_RESTRICT st, int N, int LM)
 
       if (loss_duration == 0)
       {
+#ifdef ENABLE_NEURAL_FEC
+         update_plc_state(lpcnet, decode_mem, C);
+#endif
          st->last_pitch_index = pitch_index = celt_plc_pitch_search(decode_mem, C, st->arch);
       } else {
          pitch_index = st->last_pitch_index;
@@ -834,6 +907,67 @@ static void celt_decode_lost(CELTDecoder * OPUS_RESTRICT st, int N, int LM)
          }
 
       } while (++c<C);
+
+#ifdef ENABLE_NEURAL_FEC
+      {
+         float overlap_mem;
+         int samples_needed16k;
+         int ignored = 0;
+         celt_sig *buf;
+         VARDECL(float, buf_copy);
+         buf = decode_mem[0];
+         ALLOC(buf_copy, C*overlap, float);
+         c=0; do {
+            OPUS_COPY(buf_copy+c*overlap, &decode_mem[c][DECODE_BUFFER_SIZE-N], overlap);
+         } while (++c<C);
+
+         /* Need enough samples from the PLC to cover the frame size, resampling delay,
+            and the overlap at the end. */
+         samples_needed16k = (N+SINC_ORDER+overlap)/3;
+         if (loss_duration == 0) {
+            /* Ignore the first 8 samples due to the update resampling delay. */
+            ignored = SINC_ORDER/6;
+            samples_needed16k += ignored;
+            st->plc_fill = 0;
+         }
+         while (st->plc_fill < samples_needed16k) {
+            lpcnet_plc_conceal(lpcnet, &st->plc_pcm[st->plc_fill]);
+            st->plc_fill += FRAME_SIZE;
+         }
+         /* Resample to 48 kHz. */
+         for (i=0;i<(N+overlap)/3;i++) {
+            int j;
+            float sum;
+            for (sum=0, j=0;j<17;j++) sum += 3*st->plc_pcm[ignored+i+j]*sinc_filter[3*j];
+            buf[DECODE_BUFFER_SIZE-N+3*i] = sum;
+            for (sum=0, j=0;j<16;j++) sum += 3*st->plc_pcm[ignored+i+j+1]*sinc_filter[3*j+2];
+            buf[DECODE_BUFFER_SIZE-N+3*i+1] = sum;
+            for (sum=0, j=0;j<16;j++) sum += 3*st->plc_pcm[ignored+i+j+1]*sinc_filter[3*j+1];
+            buf[DECODE_BUFFER_SIZE-N+3*i+1] = sum;
+         }
+         OPUS_MOVE(st->plc_pcm, &st->plc_pcm[N/3+ignored], st->plc_fill-N/3-ignored);
+         st->plc_fill -= N/3+ignored;
+         for (i=0;i<N;i++) {
+            float tmp = buf[DECODE_BUFFER_SIZE-N+i];
+            buf[DECODE_BUFFER_SIZE-N+i] -= PREEMPHASIS*st->plc_preemphasis_mem;
+            st->plc_preemphasis_mem = tmp;
+         }
+         overlap_mem = st->plc_preemphasis_mem;
+         for (i=0;i<overlap;i++) {
+            float tmp = buf[DECODE_BUFFER_SIZE+i];
+            buf[DECODE_BUFFER_SIZE+i] -= PREEMPHASIS*overlap_mem;
+            overlap_mem = tmp;
+         }
+         /* For now, we just do mono PLC. */
+         if (C==2) OPUS_COPY(decode_mem[1], decode_mem[0], DECODE_BUFFER_SIZE+overlap);
+         c=0; do {
+            /* Cross-fade with 48-kHz non-neural PLC for the first 2.5 ms to avoid a discontinuity. */
+            if (loss_duration == 0) {
+               for (i=0;i<overlap;i++) decode_mem[c][DECODE_BUFFER_SIZE-N+i] = (1-window[i])*buf_copy[c*overlap+i] + (window[i])*decode_mem[c][DECODE_BUFFER_SIZE-N+i];
+            }
+         } while (++c<C);
+      }
+#endif
       st->prefilter_and_fold = 1;
    }
 
@@ -843,8 +977,12 @@ static void celt_decode_lost(CELTDecoder * OPUS_RESTRICT st, int N, int LM)
    RESTORE_STACK;
 }
 
-int celt_decode_with_ec(CELTDecoder * OPUS_RESTRICT st, const unsigned char *data,
-      int len, opus_val16 * OPUS_RESTRICT pcm, int frame_size, ec_dec *dec, int accum)
+int celt_decode_with_ec_dred(CELTDecoder * OPUS_RESTRICT st, const unsigned char *data,
+      int len, opus_val16 * OPUS_RESTRICT pcm, int frame_size, ec_dec *dec, int accum
+#ifdef ENABLE_NEURAL_FEC
+      ,LPCNetPLCState *lpcnet
+#endif
+      )
 {
    int c, i, N;
    int spread_decision;
@@ -961,11 +1099,21 @@ int celt_decode_with_ec(CELTDecoder * OPUS_RESTRICT st, const unsigned char *dat
 
    if (data == NULL || len<=1)
    {
-      celt_decode_lost(st, N, LM);
+      celt_decode_lost(st, N, LM
+#ifdef ENABLE_NEURAL_FEC
+      , lpcnet
+#endif
+                      );
       deemphasis(out_syn, pcm, N, CC, st->downsample, mode->preemph, st->preemph_memD, accum);
       RESTORE_STACK;
       return frame_size/st->downsample;
    }
+#ifdef ENABLE_NEURAL_FEC
+   else {
+      /* FIXME: This is a bit of a hack just to make sure opus_decode_native() knows we're no longer in PLC. */
+      lpcnet->blend = 0;
+   }
+#endif
 
    /* Check if there are at least two packets received consecutively before
     * turning on the pitch-based PLC */
@@ -1210,6 +1358,15 @@ int celt_decode_with_ec(CELTDecoder * OPUS_RESTRICT st, const unsigned char *dat
    return frame_size/st->downsample;
 }
 
+int celt_decode_with_ec(CELTDecoder * OPUS_RESTRICT st, const unsigned char *data,
+      int len, opus_val16 * OPUS_RESTRICT pcm, int frame_size, ec_dec *dec, int accum)
+{
+   return celt_decode_with_ec_dred(st, data, len, pcm, frame_size, dec, accum
+#ifdef ENABLE_NEURAL_FEC
+       , NULL
+#endif
+       );
+}
 
 #ifdef CUSTOM_MODES
 
diff --git a/src/opus_decoder.c b/src/opus_decoder.c
index 965e9f0a8..c7d62e5f7 100644
--- a/src/opus_decoder.c
+++ b/src/opus_decoder.c
@@ -532,8 +532,12 @@ static int opus_decode_frame(OpusDecoder *st, const unsigned char *data,
       if (mode != st->prev_mode && st->prev_mode > 0 && !st->prev_redundancy)
          MUST_SUCCEED(celt_decoder_ctl(celt_dec, OPUS_RESET_STATE));
       /* Decode CELT */
-      celt_ret = celt_decode_with_ec(celt_dec, decode_fec ? NULL : data,
-                                     len, pcm, celt_frame_size, &dec, celt_accum);
+      celt_ret = celt_decode_with_ec_dred(celt_dec, decode_fec ? NULL : data,
+                                     len, pcm, celt_frame_size, &dec, celt_accum
+#ifdef ENABLE_NEURAL_FEC
+                                     , &st->lpcnet
+#endif
+                                     );
    } else {
       unsigned char silence[2] = {0xFF, 0xFF};
       if (!celt_accum)
-- 
GitLab