From 585de8e4671f8994b80451c36268dd1c7dd51269 Mon Sep 17 00:00:00 2001
From: Jan Buethe <jbuethe@amazon.de>
Date: Wed, 26 Oct 2022 10:15:39 +0000
Subject: [PATCH] changed data types for r, dead_zone, quant_scale and p0 to
 opus_uint16

---
 dnn/dred_rdovae.c               |  8 ++++----
 dnn/include/dred_rdovae.h       |  8 ++++----
 dnn/training_tf2/dump_rdovae.py | 24 ++++++++++++------------
 3 files changed, 20 insertions(+), 20 deletions(-)

diff --git a/dnn/dred_rdovae.c b/dnn/dred_rdovae.c
index 82454277e..610891061 100644
--- a/dnn/dred_rdovae.c
+++ b/dnn/dred_rdovae.c
@@ -99,22 +99,22 @@ void DRED_rdovae_decode_qframe(RDOVAEDec *h, float *qframe, const float *z)
 }
 
 
-const opus_int16 * DRED_rdovae_get_p0_pointer(void)
+const opus_uint16 * DRED_rdovae_get_p0_pointer(void)
 {
     return &dred_p0_q15[0];
 }
 
-const opus_int16 * DRED_rdovae_get_dead_zone_pointer(void)
+const opus_uint16 * DRED_rdovae_get_dead_zone_pointer(void)
 {
     return &dred_dead_zone_q10[0];
 }
 
-const opus_int16 * DRED_rdovae_get_r_pointer(void)
+const opus_uint16 * DRED_rdovae_get_r_pointer(void)
 {
     return &dred_r_q15[0];
 }
 
-const opus_int16 * DRED_rdovae_get_quant_scales_pointer(void)
+const opus_uint16 * DRED_rdovae_get_quant_scales_pointer(void)
 {
     return &dred_quant_scales_q8[0];
 }
\ No newline at end of file
diff --git a/dnn/include/dred_rdovae.h b/dnn/include/dred_rdovae.h
index 4997f4bd9..bc4211a26 100644
--- a/dnn/include/dred_rdovae.h
+++ b/dnn/include/dred_rdovae.h
@@ -51,7 +51,7 @@ void DRED_rdovae_dec_init_states(RDOVAEDec *h, const float * initial_state);
 
 void DRED_rdovae_decode_qframe(RDOVAEDec *h, float *qframe, const float * z);
 
-const opus_int16 * DRED_rdovae_get_p0_pointer(void);
-const opus_int16 * DRED_rdovae_get_dead_zone_pointer(void);
-const opus_int16 * DRED_rdovae_get_r_pointer(void);
-const opus_int16 * DRED_rdovae_get_quant_scales_pointer(void);
+const opus_uint16 * DRED_rdovae_get_p0_pointer(void);
+const opus_uint16 * DRED_rdovae_get_dead_zone_pointer(void);
+const opus_uint16 * DRED_rdovae_get_r_pointer(void);
+const opus_uint16 * DRED_rdovae_get_quant_scales_pointer(void);
diff --git a/dnn/training_tf2/dump_rdovae.py b/dnn/training_tf2/dump_rdovae.py
index 14f55f2b4..692d4884a 100644
--- a/dnn/training_tf2/dump_rdovae.py
+++ b/dnn/training_tf2/dump_rdovae.py
@@ -92,22 +92,22 @@ def dump_statistical_model(qembedding, f, fh):
     p0              = tf.math.sigmoid(w[:, 4 * N : 5 * N]).numpy()
     p0              = 1 - r ** (0.5 + 0.5 * p0)
 
-    quant_scales_q8 = np.round(quant_scales * 2**8).astype(np.int16)
-    dead_zone_q10   = np.round(dead_zone * 2**10).astype(np.int16)
-    r_q15           = np.round(r * 2**15).astype(np.int16)
-    p0_q15          = np.round(p0 * 2**15).astype(np.int16)
+    quant_scales_q8 = np.round(quant_scales * 2**8).astype(np.uint16)
+    dead_zone_q10   = np.round(dead_zone * 2**10).astype(np.uint16)
+    r_q15           = np.round(r * 2**15).astype(np.uint16)
+    p0_q15          = np.round(p0 * 2**15).astype(np.uint16)
 
-    printVector(f, quant_scales_q8, 'dred_quant_scales_q8', dtype='opus_int16', static=False)
-    printVector(f, dead_zone_q10, 'dred_dead_zone_q10', dtype='opus_int16', static=False)
-    printVector(f, r_q15, 'dred_r_q15', dtype='opus_int16', static=False)
-    printVector(f, p0_q15, 'dred_p0_q15', dtype='opus_int16', static=False)
+    printVector(f, quant_scales_q8, 'dred_quant_scales_q8', dtype='opus_uint16', static=False)
+    printVector(f, dead_zone_q10, 'dred_dead_zone_q10', dtype='opus_uint16', static=False)
+    printVector(f, r_q15, 'dred_r_q15', dtype='opus_uint16', static=False)
+    printVector(f, p0_q15, 'dred_p0_q15', dtype='opus_uint16', static=False)
 
     fh.write(
 f"""
-extern const opus_int16 dred_quant_scales_q8[{levels * N}];
-extern const opus_int16 dred_dead_zone_q10[{levels * N}];
-extern const opus_int16 dred_r_q15[{levels * N}];
-extern const opus_int16 dred_p0_q15[{levels * N}];
+extern const opus_uint16 dred_quant_scales_q8[{levels * N}];
+extern const opus_uint16 dred_dead_zone_q10[{levels * N}];
+extern const opus_uint16 dred_r_q15[{levels * N}];
+extern const opus_uint16 dred_p0_q15[{levels * N}];
 
 """
     )
-- 
GitLab