From 183a820212381f6c447b5a7c9b92b34fa01c629b Mon Sep 17 00:00:00 2001
From: Jean-Marc Valin <jmvalin@jmvalin.ca>
Date: Sat, 3 Feb 2024 02:51:08 -0500
Subject: [PATCH] Refactoring: store all states

---
 silk/dred_encoder.c | 4 ++--
 silk/dred_encoder.h | 3 +--
 2 files changed, 3 insertions(+), 4 deletions(-)

diff --git a/silk/dred_encoder.c b/silk/dred_encoder.c
index 378e8a1cc..23a697435 100644
--- a/silk/dred_encoder.c
+++ b/silk/dred_encoder.c
@@ -95,6 +95,7 @@ static void dred_process_frame(DREDEnc *enc, int arch)
     celt_assert(enc->loaded);
     /* shift latents buffer */
     OPUS_MOVE(enc->latents_buffer + DRED_LATENT_DIM, enc->latents_buffer, (DRED_MAX_FRAMES - 1) * DRED_LATENT_DIM);
+    OPUS_MOVE(enc->state_buffer + DRED_STATE_DIM, enc->state_buffer, (DRED_MAX_FRAMES - 1) * DRED_STATE_DIM);
 
     /* calculate LPCNet features */
     lpcnet_compute_single_frame_features_float(&enc->lpcnet_enc_state, enc->input_buffer, feature_buffer, arch);
@@ -212,7 +213,6 @@ void dred_compute_latents(DREDEnc *enc, const float *pcm, int frame_size, int ex
             /* 15 ms (6*2.5 ms) is the ideal offset for DRED because it corresponds to our vocoder look-ahead. */
             if (enc->dred_offset < 6) {
                 enc->dred_offset += 8;
-                OPUS_COPY(enc->initial_state, enc->state_buffer, DRED_STATE_DIM);
             } else {
                 enc->latent_offset++;
             }
@@ -277,7 +277,7 @@ int dred_encode_silk_frame(const DREDEnc *enc, unsigned char *buf, int max_chunk
     state_qoffset = q0*DRED_STATE_DIM;
     dred_encode_latents(
         &ec_encoder,
-        enc->initial_state,
+        &enc->state_buffer[enc->latent_offset*DRED_STATE_DIM],
         dred_state_quant_scales_q8 + state_qoffset,
         dred_state_dead_zone_q8 + state_qoffset,
         dred_state_r_q8 + state_qoffset,
diff --git a/silk/dred_encoder.h b/silk/dred_encoder.h
index 136f52ced..3831f355b 100644
--- a/silk/dred_encoder.h
+++ b/silk/dred_encoder.h
@@ -53,8 +53,7 @@ typedef struct {
     int latent_offset;
     float latents_buffer[DRED_MAX_FRAMES * DRED_LATENT_DIM];
     int latents_buffer_fill;
-    float state_buffer[DRED_STATE_DIM];
-    float initial_state[DRED_STATE_DIM];
+    float state_buffer[DRED_MAX_FRAMES * DRED_STATE_DIM];
     float resample_mem[RESAMPLING_ORDER + 1];
 } DREDEnc;
 
-- 
GitLab