From 393d463fdd6d17f3fef126a019102694ad89e2f7 Mon Sep 17 00:00:00 2001
From: Jean-Marc Valin <jmvalin@jmvalin.ca>
Date: Sat, 17 Feb 2024 14:20:44 -0500
Subject: [PATCH] Add lossgen_demo

Also skip the first loss values being generated since they're
biased towards "not lost" due to the initialization.
---
 Makefile.am        |  8 ++++++++
 dnn/lossgen.c      | 15 ++++++++++++++-
 dnn/lossgen.h      |  1 +
 dnn/lossgen_demo.c | 22 ++++++++++++++++++++++
 4 files changed, 45 insertions(+), 1 deletion(-)
 create mode 100644 dnn/lossgen_demo.c

diff --git a/Makefile.am b/Makefile.am
index f2631f727..a20d0a395 100644
--- a/Makefile.am
+++ b/Makefile.am
@@ -308,8 +308,16 @@ endif
 if ENABLE_DRED
 TESTS += tests/test_opus_dred
 endif
+
+if ENABLE_LOSSGEN
+noinst_PROGRAMS += lossgen_demo
+lossgen_demo_SOURCES = dnn/lossgen_demo.c $(LOSSGEN_SOURCES)
+lossgen_demo_LDADD = $(LIBM)
 endif
 
+endif
+
+
 EXTRA_DIST = opus.pc.in \
              opus-uninstalled.pc.in \
              opus.m4 \
diff --git a/dnn/lossgen.c b/dnn/lossgen.c
index f38763a9a..4730e6973 100644
--- a/dnn/lossgen.c
+++ b/dnn/lossgen.c
@@ -70,7 +70,7 @@ void compute_generic_dense_lossgen(const LinearLayer *layer, float *output, cons
 }
 
 
-int sample_loss(
+static int sample_loss_impl(
     LossGenState *st,
     float percent_loss)
 {
@@ -90,6 +90,19 @@ int sample_loss(
   return loss;
 }
 
+int sample_loss(
+    LossGenState *st,
+    float percent_loss)
+{
+   /* Due to GRU being initialized with zeros, the first packets aren't quite random,
+      so we skip them. */
+   if (!st->used) {
+      int i;
+      for (i=0;i<100;i++) sample_loss_impl(st, percent_loss);
+      st->used = 1;
+   }
+   return sample_loss_impl(st, percent_loss);
+}
 
 void lossgen_init(LossGenState *st)
 {
diff --git a/dnn/lossgen.h b/dnn/lossgen.h
index 0ac0860fe..06b771de2 100644
--- a/dnn/lossgen.h
+++ b/dnn/lossgen.h
@@ -15,6 +15,7 @@ typedef struct {
   float gru1_state[LOSSGEN_GRU1_STATE_SIZE];
   float gru2_state[LOSSGEN_GRU2_STATE_SIZE];
   int last_loss;
+  int used;
 } LossGenState;
 
 
diff --git a/dnn/lossgen_demo.c b/dnn/lossgen_demo.c
new file mode 100644
index 000000000..bad7bdc32
--- /dev/null
+++ b/dnn/lossgen_demo.c
@@ -0,0 +1,22 @@
+#include <stdio.h>
+#include <stdlib.h>
+#include "lossgen.h"
+int main(int argc, char **argv)
+{
+   LossGenState st;
+   long num_packets;
+   long i;
+   float percent;
+   if (argc != 3) {
+      fprintf(stderr, "usage: %s <percent_loss> <nb packets>\n", argv[0]);
+      return 1;
+   }
+   lossgen_init(&st);
+   percent = atof(argv[1]);
+   num_packets = atol(argv[2]);
+   /*printf("loss: %f %d\n", percent, num_packets);*/
+   for (i=0;i<num_packets;i++) {
+      printf("%d\n", sample_loss(&st, percent*0.01f));
+   }
+   return 0;
+}
-- 
GitLab