Skip to content
Snippets Groups Projects
Verified commit 1312642f, authored by Jean-Marc Valin
Browse files

More DRED refactoring

progressive decode, avoid storing DRED decoder state
parent 906ee4b2
No related branches found
No related tags found
No related merge requests found
......@@ -522,7 +522,11 @@ OPUS_EXPORT OpusDRED *opus_dred_create(int *error);
OPUS_EXPORT void opus_dred_destroy(OpusDRED *dec);
OPUS_EXPORT int opus_dred_parse(OpusDecoder *st, const unsigned char *data,
OPUS_EXPORT int opus_dred_parse(OpusDRED *dred, const unsigned char *data, opus_int32 len, opus_int32 max_dred_samples, opus_int32 sampling_rate, int defer_processing);
OPUS_EXPORT int opus_dred_process(OpusDRED *dred);
OPUS_EXPORT int opus_decoder_dred_parse(OpusDecoder *st, const unsigned char *data,
opus_int32 len, int offset) OPUS_ARG_NONNULL(1);
/** Parse an opus packet into one or more frames.
......
......@@ -40,7 +40,6 @@
/* Initialize a DRED object in place.
   The whole structure is zeroed, then the rdovae_dec member is set to the
   value returned by DRED_rdovae_create_decoder(). Always returns OPUS_OK.
   dec is dereferenced without a NULL check.
   NOTE(review): this listing is a rendered diff without +/- markers. The hunk
   header (-40,7 +40,6) indicates one line of this hunk was removed by the
   commit, presumably the rdovae_dec assignment below; confirm against the
   repository before relying on it. */
int opus_dred_init(OpusDRED *dec)
{
/* Clear every member so the object starts from an all-zero state. */
memset(dec, 0, sizeof(*dec));
/* The result of the create call is stored without being checked. */
dec->rdovae_dec = DRED_rdovae_create_decoder();
return OPUS_OK;
}
......@@ -74,52 +73,47 @@ OpusDRED *opus_dred_create(int *error)
/* Tear down a DRED object by handing dec->rdovae_dec to
   DRED_rdovae_destroy_decoder(). Nothing else is released in the lines shown
   here, and dec is dereferenced without a NULL check.
   NOTE(review): this listing is a rendered diff without +/- markers, so some
   of these lines may belong only to the pre-commit version; confirm against
   the repository. */
void opus_dred_destroy(OpusDRED *dec)
{
DRED_rdovae_destroy_decoder(dec->rdovae_dec);
}
int dred_decode_redundancy_package(OpusDRED *dec, float *features, const opus_uint8 *bytes, int num_bytes, int min_feature_frames)
int dred_ec_decode(OpusDRED *dec, const opus_uint8 *bytes, int num_bytes, int min_feature_frames)
{
const opus_uint16 *p0 = DRED_rdovae_get_p0_pointer();
const opus_uint16 *quant_scales = DRED_rdovae_get_quant_scales_pointer();
const opus_uint16 *r = DRED_rdovae_get_r_pointer();
ec_dec ec;
int q_level;
int i;
int offset;
float state[DRED_STATE_DIM];
float latents[DRED_LATENT_DIM];
/* since features are decoded in quadruples, it makes no sense to go with an uneven number of redundancy frames */
celt_assert(DRED_NUM_REDUNDANCY_FRAMES % 2 == 0);
/* decode initial state and initialize RDOVAE decoder */
ec_dec_init(&ec, (unsigned char*)bytes, num_bytes);
dred_decode_state(&ec, state);
DRED_rdovae_dec_init_states(dec->rdovae_dec, state);
/* decode newest to oldest and store oldest to newest */
for (i = 0; i < IMIN(DRED_NUM_REDUNDANCY_FRAMES, (min_feature_frames+1)/2); i += 2)
{
/* FIXME: Figure out how to avoid missing a last frame that would take up < 8 bits. */
if (8*num_bytes - ec_tell(&ec) <= 7)
break;
q_level = (int) round(DRED_ENC_Q0 + 1.f * (DRED_ENC_Q1 - DRED_ENC_Q0) * i / (DRED_NUM_REDUNDANCY_FRAMES - 2));
offset = q_level * DRED_LATENT_DIM;
dred_decode_latents(
&ec,
latents,
quant_scales + offset,
r + offset,
p0 + offset
);
offset = 2 * i * DRED_NUM_FEATURES;
DRED_rdovae_decode_qframe(
dec->rdovae_dec,
features + offset,
latents);
}
return 2*i;
const opus_uint16 *p0 = DRED_rdovae_get_p0_pointer();
const opus_uint16 *quant_scales = DRED_rdovae_get_quant_scales_pointer();
const opus_uint16 *r = DRED_rdovae_get_r_pointer();
ec_dec ec;
int q_level;
int i;
int offset;
/* since features are decoded in quadruples, it makes no sense to go with an uneven number of redundancy frames */
celt_assert(DRED_NUM_REDUNDANCY_FRAMES % 2 == 0);
/* decode initial state and initialize RDOVAE decoder */
ec_dec_init(&ec, (unsigned char*)bytes, num_bytes);
dred_decode_state(&ec, dec->state);
/* decode newest to oldest and store oldest to newest */
for (i = 0; i < IMIN(DRED_NUM_REDUNDANCY_FRAMES, (min_feature_frames+1)/2); i += 2)
{
/* FIXME: Figure out how to avoid missing a last frame that would take up < 8 bits. */
if (8*num_bytes - ec_tell(&ec) <= 7)
break;
q_level = (int) round(DRED_ENC_Q0 + 1.f * (DRED_ENC_Q1 - DRED_ENC_Q0) * i / (DRED_NUM_REDUNDANCY_FRAMES - 2));
offset = q_level * DRED_LATENT_DIM;
dred_decode_latents(
&ec,
&dec->latents[(i/2)*DRED_LATENT_DIM],
quant_scales + offset,
r + offset,
p0 + offset
);
offset = 2 * i * DRED_NUM_FEATURES;
}
dec->process_stage = 1;
dec->nb_latents = i/2;
return i/2;
}
......@@ -31,10 +31,12 @@
#include "entcode.h"
struct OpusDRED {
RDOVAEDec *rdovae_dec;
float fec_features[2*DRED_NUM_REDUNDANCY_FRAMES*DRED_NUM_FEATURES];
int nb_fec_frames;
float state[DRED_STATE_DIM];
float latents[(DRED_NUM_REDUNDANCY_FRAMES/2)*DRED_LATENT_DIM];
int nb_latents;
int process_stage;
};
int dred_decode_redundancy_package(OpusDRED *dec, float *features, const opus_uint8 *bytes, int num_bytes, int min_feature_frames);
int dred_ec_decode(OpusDRED *dec, const opus_uint8 *bytes, int num_bytes, int min_feature_frames);
......@@ -1071,19 +1071,17 @@ int opus_decoder_get_nb_samples(const OpusDecoder *dec,
return opus_packet_get_nb_samples(packet, len, dec->Fs);
}
int opus_dred_parse(OpusDecoder *st, const unsigned char *data,
opus_int32 len, int offset)
{
#ifdef ENABLE_NEURAL_FEC
static int dred_find_payload(const unsigned char *data, opus_int32 len, const unsigned char **payload)
{
const unsigned char *data0;
int len0;
const unsigned char *payload = NULL;
opus_int32 payload_len;
*payload = NULL;
int frame = 0;
int ret;
const unsigned char *frames[48];
opus_int16 size[48];
/* Get the padding section of the packet. */
ret = opus_packet_parse_impl(data, len, 0, NULL, frames, size, NULL, NULL, &data0, &len0);
if (ret < 0)
......@@ -1119,12 +1117,22 @@ int opus_dred_parse(OpusDecoder *st, const unsigned char *data,
/* Check that temporary extension type and version match.
This check will be removed once extension is finalized. */
if (curr_payload_len > 2 && curr_payload[0] == 'D' && curr_payload[1] == DRED_VERSION) {
payload = curr_payload+2;
payload_len = curr_payload_len-2;
break;
*payload = curr_payload+2;
return curr_payload_len-2;
}
}
}
*payload = NULL;
return 0;
}
#endif
int opus_decoder_dred_parse(OpusDecoder *st, const unsigned char *data, opus_int32 len, int offset)
{
#ifdef ENABLE_NEURAL_FEC
const unsigned char *payload;
opus_int32 payload_len;
payload_len = dred_find_payload(data, len, &payload);
if (payload != NULL)
{
int min_feature_frames;
......@@ -1132,10 +1140,26 @@ int opus_dred_parse(OpusDecoder *st, const unsigned char *data,
silk_dec = (silk_decoder_state*)((char*)st+st->silk_dec_offset);
/*printf("Found: %p of size %d\n", payload, payload_len);*/
min_feature_frames = IMIN(2 + offset, 2*DRED_NUM_REDUNDANCY_FRAMES);
st->nb_fec_frames = dred_decode_redundancy_package(&st->dred_decoder, st->fec_features, payload, payload_len, min_feature_frames);
dred_ec_decode(&st->dred_decoder, payload, payload_len, min_feature_frames);
opus_dred_process(&st->dred_decoder);
OPUS_COPY(st->fec_features, st->dred_decoder.fec_features, 4*st->dred_decoder.nb_latents*DRED_NUM_FEATURES);
st->nb_fec_frames = 4*st->dred_decoder.nb_latents;
lpcnet_plc_fec_clear(silk_dec->sPLC.lpcnet);
return st->nb_fec_frames;
}
#endif
return 0;
}
/* Run the RDOVAE decoding step on a DRED object.
   Passes dred->fec_features, dred->state, dred->latents and dred->nb_latents
   to DRED_rdovae_decode_all(), then sets dred->process_stage to 2.
   Always returns OPUS_OK.
   Neither dred nor the current value of process_stage is checked on entry.
   NOTE(review): fec_features is presumably the output buffer of
   DRED_rdovae_decode_all() and state/latents its inputs; confirm against
   that function's declaration. */
int opus_dred_process(OpusDRED *dred)
{
DRED_rdovae_decode_all(dred->fec_features, dred->state, dred->latents, dred->nb_latents);
/* Record that this step has run (dred_ec_decode() sets the stage to 1). */
dred->process_stage = 2;
return OPUS_OK;
}
/* Placeholder: none of the parameters (st, dred, dred_offset, pcm,
   frame_size) are read or written, and the function returns OPUS_OK
   unconditionally. */
int opus_decoder_dred_output(OpusDecoder *st, OpusDRED *dred, int dred_offset, opus_int16 *pcm, int frame_size)
{
return OPUS_OK;
}
......@@ -801,7 +801,7 @@ int main(int argc, char *argv[])
opus_decoder_ctl(dec, OPUS_GET_LAST_PACKET_DURATION(&output_samples));
dred_input = lost_count*output_samples*100/sampling_rate;
/* Only decode the amount we need to fill in the gap. */
opus_dred_parse(dec, data, len, IMIN(100, IMAX(0, dred_input)));
opus_decoder_dred_parse(dec, data, len, IMIN(100, IMAX(0, dred_input)));
}
/* FIXME: Figure out how to trigger the decoder when the last packet of the file is lost. */
for (fr=0;fr<run_decoder;fr++) {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment