From e51d3da901703b7c004f4538446aafa7997dafbd Mon Sep 17 00:00:00 2001
From: Jean-Marc Valin <jmvalin@amazon.com>
Date: Mon, 16 Oct 2023 23:01:17 -0400
Subject: [PATCH] Fix tests

---
 Makefile.am  | 10 ++++++----
 dnn/fargan.h |  1 +
 dnn/nnet.c   |  3 ++-
 3 files changed, 9 insertions(+), 5 deletions(-)

diff --git a/Makefile.am b/Makefile.am
index d892eebef..e2081a768 100644
--- a/Makefile.am
+++ b/Makefile.am
@@ -141,7 +141,6 @@ TESTS = celt/tests/test_unit_cwrs32 \
         silk/tests/test_unit_LPC_inv_pred_gain \
         tests/test_opus_api \
         tests/test_opus_decode \
-        tests/test_opus_dred \
         tests/test_opus_encode \
         tests/test_opus_extensions \
         tests/test_opus_padding \
@@ -176,9 +175,6 @@ tests_test_opus_decode_LDADD = libopus.la $(NE10_LIBS) $(LIBM)
 tests_test_opus_padding_SOURCES = tests/test_opus_padding.c tests/test_opus_common.h
 tests_test_opus_padding_LDADD = libopus.la $(NE10_LIBS) $(LIBM)
 
-tests_test_opus_dred_SOURCES = tests/test_opus_dred.c tests/test_opus_common.h
-tests_test_opus_dred_LDADD = libopus.la $(NE10_LIBS) $(LIBM)
-
 CELT_OBJ = $(CELT_SOURCES:.c=.lo)
 SILK_OBJ = $(SILK_SOURCES:.c=.lo)
 LPCNET_OBJ = $(LPCNET_SOURCES:.c=.lo)
@@ -254,6 +250,12 @@ dump_data_LDADD = $(LPCNET_OBJ) $(CELT_OBJ) $(LIBM)
 dump_weights_blob_SOURCES = dnn/write_lpcnet_weights.c
 dump_weights_blob_LDADD = $(LIBM)
 dump_weights_blob_CFLAGS = $(AM_CFLAGS) -DDUMP_BINARY_WEIGHTS
+endif
+if ENABLE_DRED
+TESTS += tests/test_opus_dred
+tests_test_opus_dred_SOURCES = tests/test_opus_dred.c tests/test_opus_common.h
+tests_test_opus_dred_LDADD = libopus.la $(NE10_LIBS) $(LIBM)
+
 endif
 endif
 
diff --git a/dnn/fargan.h b/dnn/fargan.h
index df5751ba3..1031c0054 100644
--- a/dnn/fargan.h
+++ b/dnn/fargan.h
@@ -41,6 +41,7 @@
 #define SIG_NET_INPUT_SIZE (FARGAN_COND_SIZE+2*FARGAN_SUBFRAME_SIZE+4)
 #define SIG_NET_FWC0_STATE_SIZE (2*SIG_NET_INPUT_SIZE)
 
+#define FARGAN_MAX_RNN_NEURONS SIG_NET_GRU1_OUT_SIZE
 typedef struct {
   FARGAN model;
   int arch;
diff --git a/dnn/nnet.c b/dnn/nnet.c
index 7f4658e0b..97ac74f32 100644
--- a/dnn/nnet.c
+++ b/dnn/nnet.c
@@ -38,6 +38,7 @@
 #include "nnet.h"
 #include "dred_rdovae_constants.h"
 #include "plc_data.h"
+#include "fargan.h"
 #include "os_support.h"
 
 #ifdef NO_OPTIMIZATIONS
@@ -108,7 +109,7 @@ void compute_generic_dense(const LinearLayer *layer, float *output, const float
    compute_activation(output, output, layer->nb_outputs, activation);
 }
 
-#define MAX_RNN_NEURONS_ALL IMAX(PLC_MAX_RNN_NEURONS, DRED_MAX_RNN_NEURONS)
+#define MAX_RNN_NEURONS_ALL IMAX(IMAX(FARGAN_MAX_RNN_NEURONS, PLC_MAX_RNN_NEURONS), DRED_MAX_RNN_NEURONS)
 
 
 void compute_generic_gru(const LinearLayer *input_weights, const LinearLayer *recurrent_weights, float *state, const float *in)
-- 
GitLab