Skip to content
Snippets Groups Projects
Verified Commit 77594bf1 authored by Jean-Marc Valin's avatar Jean-Marc Valin
Browse files

Dumping RDOVAE stats from XML

parent 222662da
No related branches found
No related tags found
No related merge requests found
Pipeline #4191 failed
......@@ -48,8 +48,27 @@ from rdovae import RDOVAE
from wexchange.torch import dump_torch_weights
from wexchange.c_export import CWriter, print_vector
def dump_statistical_model(writer, w, name):
def print_xml(xmlout, val, param, anchor, name):
xmlout.write(
f"""
<table anchor="{anchor}_{name}">
<name>{param} values for {name}</name>
<thead>
<tr><th>k</th><th>Q0</th><th>Q1</th><th>Q2</th><th>Q3</th><th>Q4</th><th>Q5</th><th>Q6</th><th>Q7</th><th>Q8</th><th>Q9</th><th>Q10</th><th>Q11</th><th>Q12</th><th>Q13</th><th>Q14</th><th>Q15</th></tr>
</thead>
<tbody>
""")
for k in range(val.shape[1]):
xmlout.write(f" <tr><th>{k}</th>")
for j in range(val.shape[0]):
xmlout.write(f"<th>{val[j][k]}</th>")
xmlout.write("</tr>\n")
xmlout.write(
f"""
</tbody>
</table>
""")
def dump_statistical_model(writer, w, name, xmlout):
levels = w.shape[0]
print("printing statistical model")
......@@ -78,6 +97,11 @@ def dump_statistical_model(writer, w, name):
print_vector(writer.source, r_q8, f'dred_{name}_r_q8', dtype='opus_uint8', static=False)
print_vector(writer.source, p0_q8, f'dred_{name}_p0_q8', dtype='opus_uint8', static=False)
print_xml(xmlout, quant_scales_q8, "Scale", "scale", name)
print_xml(xmlout, dead_zone_q8, "Dead zone", "deadzone", name)
print_xml(xmlout, r_q8, "Decay (r)", "decay", name)
print_xml(xmlout, p0_q8, "P(0)", "p0", name)
writer.header.write(
f"""
extern const opus_uint8 dred_{name}_quant_scales_q8[{levels * N}];
......@@ -98,6 +122,7 @@ def c_export(args, model):
dec_writer = CWriter(os.path.join(args.output_dir, "dred_rdovae_dec_data"), message=message, model_struct_name='RDOVAEDec')
stats_writer = CWriter(os.path.join(args.output_dir, "dred_rdovae_stats_data"), message=message, enable_binary_blob=False)
constants_writer = CWriter(os.path.join(args.output_dir, "dred_rdovae_constants"), message=message, header_only=True, enable_binary_blob=False)
xmlout = open("stats.xml", "w")
# some custom includes
for writer in [enc_writer, dec_writer]:
......@@ -130,8 +155,8 @@ f"""
levels = qembedding.shape[0]
qembedding = torch.reshape(qembedding, (levels, 6, -1))
latent_dim, latent_mask, latent_scale = dump_statistical_model(stats_writer, qembedding[:, :, :orig_latent_dim], 'latent')
state_dim, state_mask, state_scale = dump_statistical_model(stats_writer, qembedding[:, :, orig_latent_dim:], 'state')
latent_dim, latent_mask, latent_scale = dump_statistical_model(stats_writer, qembedding[:, :, :orig_latent_dim], 'latent', xmlout)
state_dim, state_mask, state_scale = dump_statistical_model(stats_writer, qembedding[:, :, orig_latent_dim:], 'state', xmlout)
padded_latent_dim = (latent_dim+7)//8*8
latent_pad = padded_latent_dim - latent_dim;
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment