From 8d43b185b2bb44841cb9b13e6dc1f4f30da0db87 Mon Sep 17 00:00:00 2001
From: Jean-Marc Valin <jmvalin@amazon.com>
Date: Wed, 18 Oct 2023 17:40:54 -0400
Subject: [PATCH] Support OPUS_SET_COMPLEXITY() on decoder side

Controls whether deep PLC is enabled
---
 celt/celt.h         |  2 +-
 celt/celt_decoder.c | 43 ++++++++++++++++++++++++++++++++-----------
 silk/PLC.c          | 19 +++++++++++++------
 silk/control.h      |  3 +++
 silk/dec_API.c      |  1 +
 silk/structs.h      |  1 +
 src/opus_decoder.c  | 26 +++++++++++++++++++++++++-
 src/opus_demo.c     |  8 ++++++++
 8 files changed, 84 insertions(+), 19 deletions(-)

diff --git a/celt/celt.h b/celt/celt.h
index b5bafa373..86909239a 100644
--- a/celt/celt.h
+++ b/celt/celt.h
@@ -152,7 +152,7 @@ int celt_decoder_init(CELTDecoder *st, opus_int32 sampling_rate, int channels);
 
 int celt_decode_with_ec_dred(CELTDecoder * OPUS_RESTRICT st, const unsigned char *data,
       int len, opus_val16 * OPUS_RESTRICT pcm, int frame_size, ec_dec *dec, int accum
-#ifdef ENABLE_DRED
+#ifdef ENABLE_DEEP_PLC
       ,LPCNetPLCState *lpcnet
 #endif
       );
diff --git a/celt/celt_decoder.c b/celt/celt_decoder.c
index dd7279766..f2d480282 100644
--- a/celt/celt_decoder.c
+++ b/celt/celt_decoder.c
@@ -90,6 +90,7 @@ struct OpusCustomDecoder {
    int start, end;
    int signalling;
    int disable_inv;
+   int complexity;
    int arch;
 
    /* Everything beyond this point gets cleared on a reset */
@@ -110,7 +111,7 @@ struct OpusCustomDecoder {
 
    celt_sig preemph_memD[2];
 
-#ifdef ENABLE_DRED
+#ifdef ENABLE_DEEP_PLC
    opus_int16 plc_pcm[PLC_UPDATE_SAMPLES];
    int plc_fill;
    float plc_preemphasis_mem;
@@ -551,7 +552,7 @@ static void prefilter_and_fold(CELTDecoder * OPUS_RESTRICT st, int N)
    } while (++c<CC);
 }
 
-#ifdef ENABLE_DRED
+#ifdef ENABLE_DEEP_PLC
 
 #define SINC_ORDER 48
 /* h=cos(pi/2*abs(sin([-24:24]/48*pi*23./24)).^2);
@@ -603,7 +604,7 @@ void update_plc_state(LPCNetPLCState *lpcnet, celt_sig *decode_mem[2], int CC)
 #endif
 
 static void celt_decode_lost(CELTDecoder * OPUS_RESTRICT st, int N, int LM
-#ifdef ENABLE_DRED
+#ifdef ENABLE_DEEP_PLC
       ,LPCNetPLCState *lpcnet
 #endif
       )
@@ -641,7 +642,7 @@ static void celt_decode_lost(CELTDecoder * OPUS_RESTRICT st, int N, int LM
 
    loss_duration = st->loss_duration;
    start = st->start;
-#ifdef ENABLE_DRED
+#ifdef ENABLE_DEEP_PLC
    noise_based = start != 0 || (lpcnet->fec_fill_pos == 0 && (st->skip_plc || loss_duration >= 80));
 #else
    noise_based = loss_duration >= 40 || start != 0 || st->skip_plc;
@@ -718,7 +719,7 @@ static void celt_decode_lost(CELTDecoder * OPUS_RESTRICT st, int N, int LM
 
       if (loss_duration == 0)
       {
-#ifdef ENABLE_DRED
+#ifdef ENABLE_DEEP_PLC
          update_plc_state(lpcnet, decode_mem, C);
 #endif
          st->last_pitch_index = pitch_index = celt_plc_pitch_search(decode_mem, C, st->arch);
@@ -911,8 +912,8 @@ static void celt_decode_lost(CELTDecoder * OPUS_RESTRICT st, int N, int LM
 
       } while (++c<C);
 
-#ifdef ENABLE_DRED
-      {
+#ifdef ENABLE_DEEP_PLC
+      if (st->complexity >= 5 || lpcnet->fec_fill_pos > 0) {
          float overlap_mem;
          int samples_needed16k;
          int ignored = 0;
@@ -982,7 +983,7 @@ static void celt_decode_lost(CELTDecoder * OPUS_RESTRICT st, int N, int LM
 
 int celt_decode_with_ec_dred(CELTDecoder * OPUS_RESTRICT st, const unsigned char *data,
       int len, opus_val16 * OPUS_RESTRICT pcm, int frame_size, ec_dec *dec, int accum
-#ifdef ENABLE_DRED
+#ifdef ENABLE_DEEP_PLC
       ,LPCNetPLCState *lpcnet
 #endif
       )
@@ -1103,7 +1104,7 @@ int celt_decode_with_ec_dred(CELTDecoder * OPUS_RESTRICT st, const unsigned char
    if (data == NULL || len<=1)
    {
       celt_decode_lost(st, N, LM
-#ifdef ENABLE_DRED
+#ifdef ENABLE_DEEP_PLC
       , lpcnet
 #endif
                       );
@@ -1111,7 +1112,7 @@ int celt_decode_with_ec_dred(CELTDecoder * OPUS_RESTRICT st, const unsigned char
       RESTORE_STACK;
       return frame_size/st->downsample;
    }
-#ifdef ENABLE_DRED
+#ifdef ENABLE_DEEP_PLC
    else {
       /* FIXME: This is a bit of a hack just to make sure opus_decode_native() knows we're no longer in PLC. */
       if (lpcnet) lpcnet->blend = 0;
@@ -1365,7 +1366,7 @@ int celt_decode_with_ec(CELTDecoder * OPUS_RESTRICT st, const unsigned char *dat
       int len, opus_val16 * OPUS_RESTRICT pcm, int frame_size, ec_dec *dec, int accum)
 {
    return celt_decode_with_ec_dred(st, data, len, pcm, frame_size, dec, accum
-#ifdef ENABLE_DRED
+#ifdef ENABLE_DEEP_PLC
        , NULL
 #endif
        );
@@ -1443,6 +1444,26 @@ int opus_custom_decoder_ctl(CELTDecoder * OPUS_RESTRICT st, int request, ...)
    va_start(ap, request);
    switch (request)
    {
+      case OPUS_SET_COMPLEXITY_REQUEST:
+      {
+          opus_int32 value = va_arg(ap, opus_int32);
+          if(value<0 || value>10)
+          {
+             goto bad_arg;
+          }
+          st->complexity = value;
+      }
+      break;
+      case OPUS_GET_COMPLEXITY_REQUEST:
+      {
+          opus_int32 *value = va_arg(ap, opus_int32*);
+          if (!value)
+          {
+             goto bad_arg;
+          }
+          *value = st->complexity;
+      }
+      break;
       case CELT_SET_START_BAND_REQUEST:
       {
          opus_int32 value = va_arg(ap, opus_int32);
diff --git a/silk/PLC.c b/silk/PLC.c
index 8a5e4f3b3..1e5248232 100644
--- a/silk/PLC.c
+++ b/silk/PLC.c
@@ -398,12 +398,19 @@ static OPUS_INLINE void silk_PLC_conceal(
     }
 #ifdef ENABLE_DEEP_PLC
     if ( lpcnet != NULL && psDec->sPLC.fs_kHz == 16 ) {
-        for( k = 0; k < psDec->nb_subfr; k += 2 ) {
-            lpcnet_plc_conceal( lpcnet, frame + k * psDec->subfr_length );
-        }
-        /* We *should* be able to copy only from psDec->frame_length-MAX_LPC_ORDER, i.e. the last MAX_LPC_ORDER samples. */
-        for( i = 0; i < psDec->frame_length; i++ ) {
-            sLPC_Q14_ptr[ MAX_LPC_ORDER + i ] = (int)floor(.5 + frame[ i ] * (float)(1 << 24) / prevGain_Q10[ 1 ] );
+        int run_deep_plc = psDec->sPLC.enable_deep_plc || lpcnet->fec_fill_pos != 0;
+        if( run_deep_plc ) {
+            for( k = 0; k < psDec->nb_subfr; k += 2 ) {
+                lpcnet_plc_conceal( lpcnet, frame + k * psDec->subfr_length );
+            }
+            /* We *should* be able to copy only from psDec->frame_length-MAX_LPC_ORDER, i.e. the last MAX_LPC_ORDER samples. */
+            for( i = 0; i < psDec->frame_length; i++ ) {
+                sLPC_Q14_ptr[ MAX_LPC_ORDER + i ] = (int)floor(.5 + frame[ i ] * (float)(1 << 24) / prevGain_Q10[ 1 ] );
+            }
+        } else {
+          for( k = 0; k < psDec->nb_subfr; k += 2 ) {
+              lpcnet_plc_update( lpcnet, frame + k * psDec->subfr_length );
+          }
         }
     }
 #endif
diff --git a/silk/control.h b/silk/control.h
index 8d2392d33..d30d114c7 100644
--- a/silk/control.h
+++ b/silk/control.h
@@ -144,6 +144,9 @@ typedef struct {
 
     /* O:   Pitch lag of previous frame (0 if unvoiced), measured in samples at 48 kHz      */
     opus_int prevPitchLag;
+
+    /* I:   Enable Deep PLC                                                                 */
+    opus_int enable_deep_plc;
 } silk_DecControlStruct;
 
 #ifdef __cplusplus
diff --git a/silk/dec_API.c b/silk/dec_API.c
index 090d6089a..a29ecc73c 100644
--- a/silk/dec_API.c
+++ b/silk/dec_API.c
@@ -281,6 +281,7 @@ opus_int silk_Decode(                                   /* O    Returns error co
         has_side = !psDec->prev_decode_only_middle
               || (decControl->nChannelsInternal == 2 && lostFlag == FLAG_DECODE_LBRR && channel_state[1].LBRR_flags[ channel_state[1].nFramesDecoded ] == 1 );
     }
+    channel_state[ 0 ].sPLC.enable_deep_plc = decControl->enable_deep_plc;
     /* Call decoder for one frame */
     for( n = 0; n < decControl->nChannelsInternal; n++ ) {
         if( n == 0 || has_side ) {
diff --git a/silk/structs.h b/silk/structs.h
index 22e79f701..709d3557f 100644
--- a/silk/structs.h
+++ b/silk/structs.h
@@ -253,6 +253,7 @@ typedef struct {
     opus_int                    fs_kHz;
     opus_int                    nb_subfr;
     opus_int                    subfr_length;
+    opus_int                    enable_deep_plc;
 } silk_PLC_struct;
 
 /* Struct for CNG */
diff --git a/src/opus_decoder.c b/src/opus_decoder.c
index 1bdb6ad4e..6f8e517a0 100644
--- a/src/opus_decoder.c
+++ b/src/opus_decoder.c
@@ -63,6 +63,7 @@ struct OpusDecoder {
    opus_int32   Fs;          /** Sampling rate (at the API level) */
    silk_DecControlStruct DecControl;
    int          decode_gain;
+   int          complexity;
    int          arch;
 #ifdef ENABLE_DEEP_PLC
     LPCNetPLCState lpcnet;
@@ -142,6 +143,7 @@ int opus_decoder_init(OpusDecoder *st, opus_int32 Fs, int channels)
    silk_dec = (char*)st+st->silk_dec_offset;
    celt_dec = (CELTDecoder*)((char*)st+st->celt_dec_offset);
    st->stream_channels = st->channels = channels;
+   st->complexity = 0;
 
    st->Fs = Fs;
    st->DecControl.API_sampleRate = st->Fs;
@@ -404,6 +406,7 @@ static int opus_decode_frame(OpusDecoder *st, const unsigned char *data,
            st->DecControl.internalSampleRate = 16000;
         }
      }
+     st->DecControl.enable_deep_plc = st->complexity >= 5;
 
      lost_flag = data == NULL ? 1 : 2 * !!decode_fec;
      decoded_samples = 0;
@@ -537,7 +540,7 @@ static int opus_decode_frame(OpusDecoder *st, const unsigned char *data,
       /* Decode CELT */
       celt_ret = celt_decode_with_ec_dred(celt_dec, decode_fec ? NULL : data,
                                      len, pcm, celt_frame_size, &dec, celt_accum
-#ifdef ENABLE_DRED
+#ifdef ENABLE_DEEP_PLC
                                      , &st->lpcnet
 #endif
                                      );
@@ -911,6 +914,27 @@ int opus_decoder_ctl(OpusDecoder *st, int request, ...)
       *value = st->bandwidth;
    }
    break;
+   case OPUS_SET_COMPLEXITY_REQUEST:
+   {
+       opus_int32 value = va_arg(ap, opus_int32);
+       if(value<0 || value>10)
+       {
+          goto bad_arg;
+       }
+       st->complexity = value;
+       celt_decoder_ctl(celt_dec, OPUS_SET_COMPLEXITY(value));
+   }
+   break;
+   case OPUS_GET_COMPLEXITY_REQUEST:
+   {
+       opus_int32 *value = va_arg(ap, opus_int32*);
+       if (!value)
+       {
+          goto bad_arg;
+       }
+       *value = st->complexity;
+   }
+   break;
    case OPUS_GET_FINAL_RANGE_REQUEST:
    {
       opus_uint32 *value = va_arg(ap, opus_uint32*);
diff --git a/src/opus_demo.c b/src/opus_demo.c
index 66c7492a9..c5f6250fc 100644
--- a/src/opus_demo.c
+++ b/src/opus_demo.c
@@ -126,6 +126,7 @@ static opus_uint32 char_to_int(unsigned char ch[4])
 }
 
 #define check_encoder_option(decode_only, opt) do {if (decode_only) {fprintf(stderr, "option %s is only for encoding\n", opt); goto failure;}} while(0)
+#define check_decoder_option(encode_only, opt) do {if (encode_only) {fprintf(stderr, "option %s is only for decoding\n", opt); goto failure;}} while(0)
 
 static const int silk8_test[][4] = {
       {MODE_SILK_ONLY, OPUS_BANDWIDTH_NARROWBAND, 960*3, 1},
@@ -273,6 +274,7 @@ int main(int argc, char *argv[])
     int use_vbr;
     int max_payload_bytes;
     int complexity;
+    int dec_complexity;
     int use_inbandfec;
     int use_dtx;
     int forcechannels;
@@ -391,6 +393,7 @@ int main(int argc, char *argv[])
     use_vbr = 1;
     max_payload_bytes = MAX_PACKET;
     complexity = 10;
+    dec_complexity = 0;
     use_inbandfec = 0;
     forcechannels = OPUS_AUTO;
     use_dtx = 0;
@@ -456,6 +459,10 @@ int main(int argc, char *argv[])
             check_encoder_option(decode_only, "-complexity");
             complexity = atoi( argv[ args + 1 ] );
             args += 2;
+        } else if( strcmp( argv[ args ], "-dec_complexity" ) == 0 ) {
+            check_decoder_option(encode_only, "-dec_complexity");
+            dec_complexity = atoi( argv[ args + 1 ] );
+            args += 2;
         } else if( strcmp( argv[ args ], "-inbandfec" ) == 0 ) {
             use_inbandfec = 1;
             args++;
@@ -616,6 +623,7 @@ int main(int argc, char *argv[])
           fprintf(stderr, "Cannot create decoder: %s\n", opus_strerror(err));
           goto failure;
        }
+       opus_decoder_ctl(dec, OPUS_SET_COMPLEXITY(dec_complexity));
     }
     switch(bandwidth)
     {
-- 
GitLab