From 1711e971655439c02423ceeec58e3078d2eee5a3 Mon Sep 17 00:00:00 2001
From: Jan Buethe <jbuethe@amazon.de>
Date: Mon, 6 May 2024 14:11:59 +0200
Subject: [PATCH] fixed enable_binary_blob option for CWriter

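Previously CWriter only emitted the model struct, the init_<model>() prototype,
and the weight-array/init-function definitions when enable_binary_blob was set.
With this change they are always written; only the USE_WEIGHTS_FILE and
DUMP_BINARY_WEIGHTS preprocessor guards remain conditional on the flag, and
common.py now appends the weight array name outside the conditional block so it
is always registered. export_lossgen.py passes enable_binary_blob=False
explicitly.

A minimal usage sketch (illustrative only, not part of this patch; the import
path is assumed from the repository layout):

    from wexchange.c_export import CWriter  # assumed import path

    # With the fix, disabling the binary blob still yields the model struct
    # and the init_lossgen() definition, just without the #ifndef guards.
    writer = CWriter("lossgen_data", message="example",
                     model_struct_name='LossGen', enable_binary_blob=False)
    # ... export layers with the wexchange dump helpers ...
    writer.close()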
---
 dnn/torch/lossgen/export_lossgen.py           |  2 +-
 .../wexchange/c_export/c_writer.py            | 77 +++++++++----------
 .../wexchange/c_export/common.py              |  2 +-
 3 files changed, 40 insertions(+), 41 deletions(-)

diff --git a/dnn/torch/lossgen/export_lossgen.py b/dnn/torch/lossgen/export_lossgen.py
index 1f7df957d..da63118f0 100644
--- a/dnn/torch/lossgen/export_lossgen.py
+++ b/dnn/torch/lossgen/export_lossgen.py
@@ -52,7 +52,7 @@ def c_export(args, model):
 
     message = f"Auto generated from checkpoint {os.path.basename(args.checkpoint)}"
 
-    writer = CWriter(os.path.join(args.output_dir, "lossgen_data"), message=message, model_struct_name='LossGen')
+    writer = CWriter(os.path.join(args.output_dir, "lossgen_data"), message=message, model_struct_name='LossGen', enable_binary_blob=False)
     writer.header.write(
 f"""
 #include "opus_types.h"
diff --git a/dnn/torch/weight-exchange/wexchange/c_export/c_writer.py b/dnn/torch/weight-exchange/wexchange/c_export/c_writer.py
index 2745f3371..74c8b5552 100644
--- a/dnn/torch/weight-exchange/wexchange/c_export/c_writer.py
+++ b/dnn/torch/weight-exchange/wexchange/c_export/c_writer.py
@@ -120,50 +120,49 @@ f"""
     def _finalize_header(self):
 
         # create model type
-        if self.enable_binary_blob:
-            if self.add_typedef:
-                self.header.write(f"\ntypedef struct {{")
-            else:
-                self.header.write(f"\nstruct {self.model_struct_name} {{")
-            for name, data in self.layer_dict.items():
-                layer_type = data[0]
-                self.header.write(f"\n    {layer_type} {name};")
-            if self.add_typedef:
-                self.header.write(f"\n}} {self.model_struct_name};\n")
-            else:
-                self.header.write(f"\n}};\n")
-
-            init_prototype = f"int init_{self.model_struct_name.lower()}({self.model_struct_name} *model, const WeightArray *arrays)"
-            self.header.write(f"\n{init_prototype};\n")
+        if self.add_typedef:
+            self.header.write(f"\ntypedef struct {{")
+        else:
+            self.header.write(f"\nstruct {self.model_struct_name} {{")
+        for name, data in self.layer_dict.items():
+            layer_type = data[0]
+            self.header.write(f"\n    {layer_type} {name};")
+        if self.add_typedef:
+            self.header.write(f"\n}} {self.model_struct_name};\n")
+        else:
+            self.header.write(f"\n}};\n")
+
+        init_prototype = f"int init_{self.model_struct_name.lower()}({self.model_struct_name} *model, const WeightArray *arrays)"
+        self.header.write(f"\n{init_prototype};\n")
 
         self.header.write(f"\n#endif /* {self.header_guard} */\n")
 
     def _finalize_source(self):
 
-        if self.enable_binary_blob:
-            # create weight array
-            if len(set(self.weight_arrays)) != len(self.weight_arrays):
-                raise ValueError("error: detected duplicates in weight arrays")
-            self.source.write("\n#ifndef USE_WEIGHTS_FILE\n")
-            self.source.write(f"const WeightArray {self.model_struct_name.lower()}_arrays[] = {{\n")
-            for name in self.weight_arrays:
-                self.source.write(f"#ifdef WEIGHTS_{name}_DEFINED\n")
-                self.source.write(f'    {{"{name}",  WEIGHTS_{name}_TYPE, sizeof({name}), {name}}},\n')
-                self.source.write(f"#endif\n")
-            self.source.write("    {NULL, 0, 0, NULL}\n")
-            self.source.write("};\n")
-
-            self.source.write("#endif /* USE_WEIGHTS_FILE */\n")
-
-            # create init function definition
-            init_prototype = f"int init_{self.model_struct_name.lower()}({self.model_struct_name} *model, const WeightArray *arrays)"
-            self.source.write("\n#ifndef DUMP_BINARY_WEIGHTS\n")
-            self.source.write(f"{init_prototype} {{\n")
-            for name, data in self.layer_dict.items():
-                self.source.write(f"    if ({data[1]}) return 1;\n")
-            self.source.write("    return 0;\n")
-            self.source.write("}\n")
-            self.source.write("#endif /* DUMP_BINARY_WEIGHTS */\n")
+
+        # create weight array
+        if len(set(self.weight_arrays)) != len(self.weight_arrays):
+            raise ValueError("error: detected duplicates in weight arrays")
+        if self.enable_binary_blob: self.source.write("\n#ifndef USE_WEIGHTS_FILE\n")
+        self.source.write(f"const WeightArray {self.model_struct_name.lower()}_arrays[] = {{\n")
+        for name in self.weight_arrays:
+            self.source.write(f"#ifdef WEIGHTS_{name}_DEFINED\n")
+            self.source.write(f'    {{"{name}",  WEIGHTS_{name}_TYPE, sizeof({name}), {name}}},\n')
+            self.source.write(f"#endif\n")
+        self.source.write("    {NULL, 0, 0, NULL}\n")
+        self.source.write("};\n")
+
+        if self.enable_binary_blob: self.source.write("#endif /* USE_WEIGHTS_FILE */\n")
+
+        # create init function definition
+        init_prototype = f"int init_{self.model_struct_name.lower()}({self.model_struct_name} *model, const WeightArray *arrays)"
+        if self.enable_binary_blob: self.source.write("\n#ifndef DUMP_BINARY_WEIGHTS\n")
+        self.source.write(f"{init_prototype} {{\n")
+        for name, data in self.layer_dict.items():
+            self.source.write(f"    if ({data[1]}) return 1;\n")
+        self.source.write("    return 0;\n")
+        self.source.write("}\n")
+        if self.enable_binary_blob: self.source.write("#endif /* DUMP_BINARY_WEIGHTS */\n")
 
 
     def close(self):
diff --git a/dnn/torch/weight-exchange/wexchange/c_export/common.py b/dnn/torch/weight-exchange/wexchange/c_export/common.py
index 039edd9b1..b96e0d6c1 100644
--- a/dnn/torch/weight-exchange/wexchange/c_export/common.py
+++ b/dnn/torch/weight-exchange/wexchange/c_export/common.py
@@ -54,7 +54,7 @@ f'''
 #ifndef USE_WEIGHTS_FILE
 '''
         )
-        writer.weight_arrays.append(name)
+    writer.weight_arrays.append(name)
 
     if reshape_8x4:
         vector = vector.reshape((vector.shape[0]//4, 4, vector.shape[1]//8, 8))
-- 
GitLab