From d15be43af425013e27ef872ac672700e0b642ac1 Mon Sep 17 00:00:00 2001
From: Jean-Marc Valin <jmvalin@amazon.com>
Date: Sun, 23 Jul 2023 18:11:15 -0400
Subject: [PATCH] Make bias/subias/diag/scale optional

---
 dnn/parse_lpcnet_weights.c | 24 +++++++++++++++++-------
 1 file changed, 17 insertions(+), 7 deletions(-)

diff --git a/dnn/parse_lpcnet_weights.c b/dnn/parse_lpcnet_weights.c
index 833f972fd..0f2def8b6 100644
--- a/dnn/parse_lpcnet_weights.c
+++ b/dnn/parse_lpcnet_weights.c
@@ -124,16 +124,22 @@ int linear_init(LinearLayer *layer, const WeightArray *arrays,
   int nb_inputs,
   int nb_outputs)
 {
-  int total_blocks;
-  if ((layer->bias = find_array_check(arrays, bias, nb_outputs*sizeof(layer->bias[0]))) == NULL) return 1;
-  if ((layer->subias = find_array_check(arrays, subias, nb_outputs*sizeof(layer->subias[0]))) == NULL) return 1;
+  layer->bias = NULL;
+  layer->subias = NULL;
   layer->weights = NULL;
   layer->float_weights = NULL;
   layer->weights_idx = NULL;
-  if (weights_idx != NULL) {
-    if ((layer->weights_idx = find_idx_check(arrays, weights_idx, nb_outputs, nb_inputs, &total_blocks)) == NULL) return 1;
+  layer->diag = NULL;
+  layer->scale = NULL;
+  if (bias != NULL) {
+    if ((layer->bias = find_array_check(arrays, bias, nb_outputs*sizeof(layer->bias[0]))) == NULL) return 1;
+  }
+  if (subias != NULL) {
+    if ((layer->subias = find_array_check(arrays, subias, nb_outputs*sizeof(layer->subias[0]))) == NULL) return 1;
   }
   if (weights_idx != NULL) {
+    int total_blocks;
+    if ((layer->weights_idx = find_idx_check(arrays, weights_idx, nb_outputs, nb_inputs, &total_blocks)) == NULL) return 1;
     if (weights != NULL) {
       if ((layer->weights = find_array_check(arrays, weights, SPARSE_BLOCK_SIZE*total_blocks*sizeof(layer->weights[0]))) == NULL) return 1;
     }
@@ -148,8 +154,12 @@ int linear_init(LinearLayer *layer, const WeightArray *arrays,
       if ((layer->float_weights = find_array_check(arrays, float_weights, nb_inputs*nb_outputs*sizeof(layer->float_weights[0]))) == NULL) return 1;
     }
   }
-  if ((layer->diag = find_array_check(arrays, diag, nb_outputs*sizeof(layer->diag[0]))) == NULL) return 1;
-  if ((layer->scale = find_array_check(arrays, scale, nb_outputs*sizeof(layer->scale[0]))) == NULL) return 1;
+  if (diag != NULL) {
+    if ((layer->diag = find_array_check(arrays, diag, nb_outputs*sizeof(layer->diag[0]))) == NULL) return 1;
+  }
+  if (weights != NULL) {
+    if ((layer->scale = find_array_check(arrays, scale, nb_outputs*sizeof(layer->scale[0]))) == NULL) return 1;
+  }
   layer->nb_inputs = nb_inputs;
   layer->nb_outputs = nb_outputs;
   return 0;
-- 
GitLab