From 368c72374e14109a0af72567e4190b8e7bef302c Mon Sep 17 00:00:00 2001
From: Jeff Petkau <jpet@chromium.org>
Date: Thu, 13 Jun 2013 12:16:58 -0700
Subject: [PATCH] Change the encryption feature to use a callback for
 decryption.

This allows code calling the library can choose an arbitrary
encryption algorithm.

Decoder control parameter VP8_SET_DECRYPT_KEY is renamed to
VP8D_SET_DECRYPTOR, and now takes an small config struct instead
of just a byte array.

Change-Id: I0462b3388d8d45057e4f79a6b6777fe713dc546e
---
 test/vp8_boolcoder_test.cc | 35 ++++++++++------
 test/vp8_decrypt_test.cc   | 60 +++++++++++++++-------------
 vp8/decoder/dboolhuff.c    | 24 ++++++-----
 vp8/decoder/dboolhuff.h    | 24 +++++------
 vp8/decoder/decodframe.c   | 79 ++++++++++++++++---------------------
 vp8/decoder/onyxd_int.h    |  3 +-
 vp8/vp8_dx_iface.c         | 81 +++++++++++++++++++-------------------
 vpx/vp8dx.h                | 21 +++++++---
 8 files changed, 173 insertions(+), 154 deletions(-)

diff --git a/test/vp8_boolcoder_test.cc b/test/vp8_boolcoder_test.cc
index ab19c3412d..c3a8d12e18 100644
--- a/test/vp8_boolcoder_test.cc
+++ b/test/vp8_boolcoder_test.cc
@@ -27,18 +27,28 @@ extern "C" {
 namespace {
 const int num_tests = 10;
 
-void encrypt_buffer(uint8_t *buffer, int size, const uint8_t *key) {
+// In a real use the 'decrypt_state' parameter will be a pointer to a struct
+// with whatever internal state the decryptor uses. For testing we'll just
+// xor with a constant key, and decrypt_state will point to the start of
+// the original buffer.
+const uint8_t secret_key[16] = {
+  0x01, 0x12, 0x23, 0x34, 0x45, 0x56, 0x67, 0x78,
+  0x89, 0x9a, 0xab, 0xbc, 0xcd, 0xde, 0xef, 0xf0
+};
+
+void encrypt_buffer(uint8_t *buffer, int size) {
   for (int i = 0; i < size; ++i) {
-    buffer[i] ^= key[i % 32];
+    buffer[i] ^= secret_key[i & 15];
   }
 }
 
-const uint8_t secret_key[32] = {
-  234,  32,   2,  3,  4, 230,   6,  11,
-    0, 132,  22, 23, 45,  21, 124, 255,
-    0,  43,  52,  3, 23,  63,  99,   7,
-  120,   8, 252, 84,  4,  83,   6,  13
-};
+void test_decrypt_cb(void *decrypt_state, const uint8_t *input,
+                           uint8_t *output, int count) {
+  int offset = input - (uint8_t *)decrypt_state;
+  for (int i = 0; i < count; i++) {
+    output[i] = input[i] ^ secret_key[(offset + i) & 15];
+  }
+}
 
 }  // namespace
 
@@ -85,12 +95,13 @@ TEST(VP8, TestBitIO) {
         vp8_stop_encode(&bw);
 
         BOOL_DECODER br;
-
 #if CONFIG_DECRYPT
-        encrypt_buffer(bw_buffer, buffer_size, secret_key);
+        encrypt_buffer(bw_buffer, buffer_size);
+        vp8dx_start_decode(&br, bw_buffer, buffer_size,
+                           test_decrypt_cb, (void *)bw_buffer);
+#else
+        vp8dx_start_decode(&br, bw_buffer, buffer_size, NULL, NULL);
 #endif
-
-        vp8dx_start_decode(&br, bw_buffer, buffer_size, bw_buffer, secret_key);
         bit_rnd.Reset(random_seed);
         for (int i = 0; i < bits_to_test; ++i) {
           if (bit_method == 2) {
diff --git a/test/vp8_decrypt_test.cc b/test/vp8_decrypt_test.cc
index ea7b920499..d850f006c0 100644
--- a/test/vp8_decrypt_test.cc
+++ b/test/vp8_decrypt_test.cc
@@ -11,55 +11,61 @@
 #include <cstdio>
 #include <cstdlib>
 #include <string>
+#include <vector>
 #include "third_party/googletest/src/include/gtest/gtest.h"
-#include "test/decode_test_driver.h"
+#include "test/codec_factory.h"
 #include "test/ivf_video_source.h"
 
-#if CONFIG_DECRYPT
-
 namespace {
-
-const uint8_t decrypt_key[32] = {
-  255, 0, 0, 0, 0, 0, 0, 0,
-    0, 0, 0, 0, 0, 0, 0, 0,
-    0, 0, 0, 0, 0, 0, 0, 0,
-    0, 0, 0, 0, 0, 0, 0, 0,
+// In a real use the 'decrypt_state' parameter will be a pointer to a struct
+// with whatever internal state the decryptor uses. For testing we'll just
+// xor with a constant key, and decrypt_state will point to the start of
+// the original buffer.
+const uint8_t test_key[16] = {
+  0x01, 0x12, 0x23, 0x34, 0x45, 0x56, 0x67, 0x78,
+  0x89, 0x9a, 0xab, 0xbc, 0xcd, 0xde, 0xef, 0xf0
 };
 
-}  // namespace
+void encrypt_buffer(const uint8_t *src, uint8_t *dst, int size, int offset = 0) {
+  for (int i = 0; i < size; ++i) {
+    dst[i] = src[i] ^ test_key[(offset + i) & 15];
+  }
+}
 
-namespace libvpx_test {
+void test_decrypt_cb(void *decrypt_state, const uint8_t *input,
+                     uint8_t *output, int count) {
+  encrypt_buffer(input, output, count, input - (uint8_t *)decrypt_state);
+}
 
-TEST(TestDecrypt, NullKey) {
-  vpx_codec_dec_cfg_t cfg = {0};
-  vpx_codec_ctx_t decoder = {0};
-  vpx_codec_err_t res = vpx_codec_dec_init(&decoder, &vpx_codec_vp8_dx_algo,
-                                           &cfg, 0);
-  ASSERT_EQ(VPX_CODEC_OK, res);
+} // namespace
 
-  res = vpx_codec_control(&decoder, VP8_SET_DECRYPT_KEY, NULL);
-  ASSERT_EQ(VPX_CODEC_INVALID_PARAM, res);
-}
+namespace libvpx_test {
 
 TEST(TestDecrypt, DecryptWorks) {
   libvpx_test::IVFVideoSource video("vp80-00-comprehensive-001.ivf");
   video.Init();
 
   vpx_codec_dec_cfg_t dec_cfg = {0};
-  Decoder decoder(dec_cfg, 0);
+  VP8Decoder decoder(dec_cfg, 0);
 
-  // Zero decrypt key (by default)
   video.Begin();
+
+  // no decryption
   vpx_codec_err_t res = decoder.DecodeFrame(video.cxdata(), video.frame_size());
   ASSERT_EQ(VPX_CODEC_OK, res) << decoder.DecodeError();
 
-  // Non-zero decrypt key
+  // decrypt frame
   video.Next();
-  decoder.Control(VP8_SET_DECRYPT_KEY, decrypt_key);
+
+#if CONFIG_DECRYPT
+  std::vector<uint8_t> encrypted(video.frame_size());
+  encrypt_buffer(video.cxdata(), &encrypted[0], video.frame_size());
+  vp8_decrypt_init di = { test_decrypt_cb, &encrypted[0] };
+  decoder.Control(VP8D_SET_DECRYPTOR, &di);
+#endif  // CONFIG_DECRYPT
+
   res = decoder.DecodeFrame(video.cxdata(), video.frame_size());
-  ASSERT_NE(VPX_CODEC_OK, res) << decoder.DecodeError();
+  ASSERT_EQ(VPX_CODEC_OK, res) << decoder.DecodeError();
 }
 
 }  // namespace libvpx_test
-
-#endif  // CONFIG_DECRYPT
diff --git a/vp8/decoder/dboolhuff.c b/vp8/decoder/dboolhuff.c
index aa7a56a021..546fb2d217 100644
--- a/vp8/decoder/dboolhuff.c
+++ b/vp8/decoder/dboolhuff.c
@@ -14,16 +14,16 @@
 int vp8dx_start_decode(BOOL_DECODER *br,
                        const unsigned char *source,
                        unsigned int source_sz,
-                       const unsigned char *origin,
-                       const unsigned char *key)
+                       vp8_decrypt_cb *decrypt_cb,
+                       void *decrypt_state)
 {
     br->user_buffer_end = source+source_sz;
     br->user_buffer     = source;
     br->value    = 0;
     br->count    = -8;
     br->range    = 255;
-    br->origin = origin;
-    br->key = key;
+    br->decrypt_cb = decrypt_cb;
+    br->decrypt_state = decrypt_state;
 
     if (source_sz && !source)
         return 1;
@@ -37,13 +37,20 @@ int vp8dx_start_decode(BOOL_DECODER *br,
 void vp8dx_bool_decoder_fill(BOOL_DECODER *br)
 {
     const unsigned char *bufptr = br->user_buffer;
-    const unsigned char *bufend = br->user_buffer_end;
     VP8_BD_VALUE value = br->value;
     int count = br->count;
     int shift = VP8_BD_VALUE_SIZE - 8 - (count + 8);
-    size_t bits_left = (bufend - bufptr)*CHAR_BIT;
+    size_t bytes_left = br->user_buffer_end - bufptr;
+    size_t bits_left = bytes_left * CHAR_BIT;
     int x = (int)(shift + CHAR_BIT - bits_left);
     int loop_end = 0;
+    unsigned char decrypted[sizeof(VP8_BD_VALUE) + 1];
+
+    if (br->decrypt_cb) {
+        int n = bytes_left > sizeof(decrypted) ? sizeof(decrypted) : bytes_left;
+        br->decrypt_cb(br->decrypt_state, bufptr, decrypted, n);
+        bufptr = decrypted;
+    }
 
     if(x >= 0)
     {
@@ -56,14 +63,13 @@ void vp8dx_bool_decoder_fill(BOOL_DECODER *br)
         while(shift >= loop_end)
         {
             count += CHAR_BIT;
-            value |= ((VP8_BD_VALUE)decrypt_byte(bufptr, br->origin,
-                                                 br->key)) << shift;
+            value |= (VP8_BD_VALUE)*bufptr << shift;
             ++bufptr;
+            ++br->user_buffer;
             shift -= CHAR_BIT;
         }
     }
 
-    br->user_buffer = bufptr;
     br->value = value;
     br->count = count;
 }
diff --git a/vp8/decoder/dboolhuff.h b/vp8/decoder/dboolhuff.h
index 46a4dd60ed..4c0ca1ce73 100644
--- a/vp8/decoder/dboolhuff.h
+++ b/vp8/decoder/dboolhuff.h
@@ -28,17 +28,11 @@ typedef size_t VP8_BD_VALUE;
   Even relatively modest values like 100 would work fine.*/
 #define VP8_LOTS_OF_BITS (0x40000000)
 
-static unsigned char decrypt_byte(const unsigned char *ch,
-                                  const unsigned char *origin,
-                                  const unsigned char *key)
-{
-#if CONFIG_DECRYPT
-    const int offset = (int)(ch - origin);
-    return *ch ^ key[offset % 32];  // VP8_DECRYPT_KEY_SIZE
-#else
-    return *ch;
-#endif
-}
+/*Decrypt n bytes of data from input -> output, using the decrypt_state
+   passed in VP8D_SET_DECRYPTOR.
+*/
+typedef void (vp8_decrypt_cb)(void *decrypt_state, const unsigned char *input,
+                              unsigned char *output, int count);
 
 typedef struct
 {
@@ -47,8 +41,8 @@ typedef struct
     VP8_BD_VALUE         value;
     int                  count;
     unsigned int         range;
-    const unsigned char *origin;
-    const unsigned char *key;
+    vp8_decrypt_cb      *decrypt_cb;
+    void                *decrypt_state;
 } BOOL_DECODER;
 
 DECLARE_ALIGNED(16, extern const unsigned char, vp8_norm[256]);
@@ -56,8 +50,8 @@ DECLARE_ALIGNED(16, extern const unsigned char, vp8_norm[256]);
 int vp8dx_start_decode(BOOL_DECODER *br,
                        const unsigned char *source,
                        unsigned int source_sz,
-                       const unsigned char *origin,
-                       const unsigned char *key);
+                       vp8_decrypt_cb *decrypt_cb,
+                       void *decrypt_state);
 
 void vp8dx_bool_decoder_fill(BOOL_DECODER *br);
 
diff --git a/vp8/decoder/decodframe.c b/vp8/decoder/decodframe.c
index bb727db90b..44c35effe6 100644
--- a/vp8/decoder/decodframe.c
+++ b/vp8/decoder/decodframe.c
@@ -759,11 +759,16 @@ static void decode_mb_rows(VP8D_COMP *pbi)
 
 }
 
-static unsigned int read_partition_size(const unsigned char *cx_size)
+static unsigned int read_partition_size(VP8D_COMP *pbi,
+                                        const unsigned char *cx_size)
 {
-    const unsigned int size =
-        cx_size[0] + (cx_size[1] << 8) + (cx_size[2] << 16);
-    return size;
+    unsigned char temp[3];
+    if (pbi->decrypt_cb)
+    {
+        pbi->decrypt_cb(pbi->decrypt_state, cx_size, temp, 3);
+        cx_size = temp;
+    }
+    return cx_size[0] + (cx_size[1] << 8) + (cx_size[2] << 16);
 }
 
 static int read_is_valid(const unsigned char *start,
@@ -794,7 +799,7 @@ static unsigned int read_available_partition_size(
     if (i < num_part - 1)
     {
         if (read_is_valid(partition_size_ptr, 3, first_fragment_end))
-            partition_size = read_partition_size(partition_size_ptr);
+            partition_size = read_partition_size(pbi, partition_size_ptr);
         else if (pbi->ec_active)
             partition_size = (unsigned int)bytes_left;
         else
@@ -894,8 +899,7 @@ static void setup_token_decoder(VP8D_COMP *pbi,
         if (vp8dx_start_decode(bool_decoder,
                                pbi->fragments.ptrs[partition_idx],
                                pbi->fragments.sizes[partition_idx],
-                               pbi->fragments.ptrs[0],
-                               pbi->decrypt_key))
+                               pbi->decrypt_cb, pbi->decrypt_state))
             vpx_internal_error(&pbi->common.error, VPX_CODEC_MEM_ERROR,
                                "Failed to allocate bool decoder %d",
                                partition_idx);
@@ -986,7 +990,6 @@ int vp8_decode_frame(VP8D_COMP *pbi)
     VP8_COMMON *const pc = &pbi->common;
     MACROBLOCKD *const xd  = &pbi->mb;
     const unsigned char *data = pbi->fragments.ptrs[0];
-    const unsigned char *const origin = data;
     const unsigned char *data_end =  data + pbi->fragments.sizes[0];
     ptrdiff_t first_partition_length_in_bytes;
 
@@ -1019,18 +1022,21 @@ int vp8_decode_frame(VP8D_COMP *pbi)
     }
     else
     {
-        const unsigned char data0 = decrypt_byte(data + 0, origin,
-                                                 pbi->decrypt_key);
-        const unsigned char data1 = decrypt_byte(data + 1, origin,
-                                                 pbi->decrypt_key);
-        const unsigned char data2 = decrypt_byte(data + 2, origin,
-                                                 pbi->decrypt_key);
-
-        pc->frame_type = (FRAME_TYPE)(data0 & 1);
-        pc->version = (data0 >> 1) & 7;
-        pc->show_frame = (data0 >> 4) & 1;
+        unsigned char clear_buffer[10];
+        const unsigned char *clear = data;
+        if (pbi->decrypt_cb)
+        {
+            int n = data_end - data;
+            if (n > 10) n = 10;
+            pbi->decrypt_cb(pbi->decrypt_state, data, clear_buffer, n);
+            clear = clear_buffer;
+        }
+
+        pc->frame_type = (FRAME_TYPE)(clear[0] & 1);
+        pc->version = (clear[0] >> 1) & 7;
+        pc->show_frame = (clear[0] >> 4) & 1;
         first_partition_length_in_bytes =
-            (data0 | (data1 << 8) | (data2 << 16)) >> 5;
+            (clear[0] | (clear[1] << 8) | (clear[2] << 16)) >> 5;
 
         if (!pbi->ec_active &&
             (data + first_partition_length_in_bytes > data_end
@@ -1039,6 +1045,7 @@ int vp8_decode_frame(VP8D_COMP *pbi)
                                "Truncated packet or corrupt partition 0 length");
 
         data += 3;
+        clear += 3;
 
         vp8_setup_version(pc);
 
@@ -1051,13 +1058,7 @@ int vp8_decode_frame(VP8D_COMP *pbi)
              */
             if (!pbi->ec_active || data + 3 < data_end)
             {
-                const unsigned char data0 = decrypt_byte(data + 0, origin,
-                                                         pbi->decrypt_key);
-                const unsigned char data1 = decrypt_byte(data + 1, origin,
-                                                         pbi->decrypt_key);
-                const unsigned char data2 = decrypt_byte(data + 2, origin,
-                                                         pbi->decrypt_key);
-                if (data0 != 0x9d || data1 != 0x01 || data2 != 0x2a)
+                if (clear[0] != 0x9d || clear[1] != 0x01 || clear[2] != 0x2a)
                     vpx_internal_error(&pc->error, VPX_CODEC_UNSUP_BITSTREAM,
                                    "Invalid frame sync code");
             }
@@ -1068,22 +1069,13 @@ int vp8_decode_frame(VP8D_COMP *pbi)
              */
             if (!pbi->ec_active || data + 6 < data_end)
             {
-                const unsigned char data3 = decrypt_byte(data + 3, origin,
-                                                         pbi->decrypt_key);
-                const unsigned char data4 = decrypt_byte(data + 4, origin,
-                                                         pbi->decrypt_key);
-                const unsigned char data5 = decrypt_byte(data + 5, origin,
-                                                         pbi->decrypt_key);
-                const unsigned char data6 = decrypt_byte(data + 6, origin,
-                                                         pbi->decrypt_key);
-
-                pc->Width = (data3 | (data4 << 8)) & 0x3fff;
-                pc->horiz_scale = data4 >> 6;
-                pc->Height = (data5 | (data6 << 8)) & 0x3fff;
-                pc->vert_scale = data6 >> 6;
+                pc->Width = (clear[3] | (clear[4] << 8)) & 0x3fff;
+                pc->horiz_scale = clear[4] >> 6;
+                pc->Height = (clear[5] | (clear[6] << 8)) & 0x3fff;
+                pc->vert_scale = clear[6] >> 6;
             }
             data += 7;
-
+            clear += 7;
         }
         else
         {
@@ -1098,11 +1090,8 @@ int vp8_decode_frame(VP8D_COMP *pbi)
 
     init_frame(pbi);
 
-    if (vp8dx_start_decode(bc,
-                           data,
-                           (unsigned int)(data_end - data),
-                           pbi->fragments.ptrs[0],
-                           pbi->decrypt_key))
+    if (vp8dx_start_decode(bc, data, (unsigned int)(data_end - data),
+                           pbi->decrypt_cb, pbi->decrypt_state))
         vpx_internal_error(&pc->error, VPX_CODEC_MEM_ERROR,
                            "Failed to allocate bool decoder 0");
     if (pc->frame_type == KEY_FRAME) {
diff --git a/vp8/decoder/onyxd_int.h b/vp8/decoder/onyxd_int.h
index c2325ebef8..54a98f7cc3 100644
--- a/vp8/decoder/onyxd_int.h
+++ b/vp8/decoder/onyxd_int.h
@@ -122,7 +122,8 @@ typedef struct VP8D_COMP
     int independent_partitions;
     int frame_corrupt_residual;
 
-    const unsigned char *decrypt_key;
+    vp8_decrypt_cb *decrypt_cb;
+    void *decrypt_state;
 } VP8D_COMP;
 
 int vp8_decode_frame(VP8D_COMP *cpi);
diff --git a/vp8/vp8_dx_iface.c b/vp8/vp8_dx_iface.c
index 45cf3859e3..c826f696d3 100644
--- a/vp8/vp8_dx_iface.c
+++ b/vp8/vp8_dx_iface.c
@@ -29,8 +29,6 @@
 #define VP8_CAP_ERROR_CONCEALMENT (CONFIG_ERROR_CONCEALMENT ? \
                                     VPX_CODEC_CAP_ERROR_CONCEALMENT : 0)
 
-#define VP8_DECRYPT_KEY_SIZE 32
-
 typedef vpx_codec_stream_info_t  vp8_stream_info_t;
 
 /* Structures for handling memory allocations */
@@ -75,7 +73,8 @@ struct vpx_codec_alg_priv
     int                     dbg_color_b_modes_flag;
     int                     dbg_display_mv_flag;
 #endif
-    unsigned char           decrypt_key[VP8_DECRYPT_KEY_SIZE];
+    vp8_decrypt_cb          *decrypt_cb;
+    void                    *decrypt_state;
     vpx_image_t             img;
     int                     img_setup;
     struct frame_buffers    yv12_frame_buffers;
@@ -153,8 +152,6 @@ static vpx_codec_err_t vp8_validate_mmaps(const vp8_stream_info_t *si,
     return res;
 }
 
-static const unsigned char fake_decrypt_key[VP8_DECRYPT_KEY_SIZE] = { 0 };
-
 static void vp8_init_ctx(vpx_codec_ctx_t *ctx, const vpx_codec_mmap_t *mmap)
 {
     int i;
@@ -169,8 +166,8 @@ static void vp8_init_ctx(vpx_codec_ctx_t *ctx, const vpx_codec_mmap_t *mmap)
 
     ctx->priv->alg_priv->mmaps[0] = *mmap;
     ctx->priv->alg_priv->si.sz = sizeof(ctx->priv->alg_priv->si);
-    memcpy(ctx->priv->alg_priv->decrypt_key, fake_decrypt_key,
-           VP8_DECRYPT_KEY_SIZE);
+    ctx->priv->alg_priv->decrypt_cb = NULL;
+    ctx->priv->alg_priv->decrypt_state = NULL;
     ctx->priv->init_flags = ctx->init_flags;
 
     if (ctx->config.dec)
@@ -269,10 +266,11 @@ static vpx_codec_err_t vp8_destroy(vpx_codec_alg_priv_t *ctx)
     return VPX_CODEC_OK;
 }
 
-static vpx_codec_err_t vp8_peek_si_external(const uint8_t         *data,
-                                            unsigned int           data_sz,
+static vpx_codec_err_t vp8_peek_si_internal(const uint8_t *data,
+                                            unsigned int data_sz,
                                             vpx_codec_stream_info_t *si,
-                                            const unsigned char *decrypt_key)
+                                            vp8_decrypt_cb *decrypt_cb,
+                                            void *decrypt_state)
 {
     vpx_codec_err_t res = VPX_CODEC_OK;
 
@@ -288,27 +286,26 @@ static vpx_codec_err_t vp8_peek_si_external(const uint8_t         *data,
          * 4 bytes:- including image width and height in the lowest 14 bits
          *           of each 2-byte value.
          */
-
-        const uint8_t data0 = decrypt_byte(data, data, decrypt_key);
-        si->is_kf = 0;
-        if (data_sz >= 10 && !(data0 & 0x01))  /* I-Frame */
+        uint8_t clear_buffer[10];
+        const uint8_t *clear = data;
+        if (decrypt_cb)
         {
-            const uint8_t data3 = decrypt_byte(data + 3, data, decrypt_key);
-            const uint8_t data4 = decrypt_byte(data + 4, data, decrypt_key);
-            const uint8_t data5 = decrypt_byte(data + 5, data, decrypt_key);
-            const uint8_t data6 = decrypt_byte(data + 6, data, decrypt_key);
-            const uint8_t data7 = decrypt_byte(data + 7, data, decrypt_key);
-            const uint8_t data8 = decrypt_byte(data + 8, data, decrypt_key);
-            const uint8_t data9 = decrypt_byte(data + 9, data, decrypt_key);
+            int n = data_sz > 10 ? 10 : data_sz;
+            decrypt_cb(decrypt_state, data, clear_buffer, n);
+            clear = clear_buffer;
+        }
+        si->is_kf = 0;
 
+        if (data_sz >= 10 && !(clear[0] & 0x01))  /* I-Frame */
+        {
             si->is_kf = 1;
 
             /* vet via sync code */
-            if (data3 != 0x9d || data4 != 0x01 || data5 != 0x2a)
+            if (clear[3] != 0x9d || clear[4] != 0x01 || clear[5] != 0x2a)
                 res = VPX_CODEC_UNSUP_BITSTREAM;
 
-            si->w = (data6 | (data7 << 8)) & 0x3fff;
-            si->h = (data8 | (data9 << 8)) & 0x3fff;
+            si->w = (clear[6] | (clear[7] << 8)) & 0x3fff;
+            si->h = (clear[8] | (clear[9] << 8)) & 0x3fff;
 
             /*printf("w=%d, h=%d\n", si->w, si->h);*/
             if (!(si->h | si->w))
@@ -326,7 +323,7 @@ static vpx_codec_err_t vp8_peek_si_external(const uint8_t         *data,
 static vpx_codec_err_t vp8_peek_si(const uint8_t *data,
                                    unsigned int data_sz,
                                    vpx_codec_stream_info_t *si) {
-    return vp8_peek_si_external(data, data_sz, si, fake_decrypt_key);
+    return vp8_peek_si_internal(data, data_sz, si, NULL, NULL);
 }
 
 static vpx_codec_err_t vp8_get_si(vpx_codec_alg_priv_t    *ctx,
@@ -455,10 +452,8 @@ static vpx_codec_err_t vp8_decode(vpx_codec_alg_priv_t  *ctx,
     w = ctx->si.w;
     h = ctx->si.h;
 
-    res = vp8_peek_si_external(ctx->fragments.ptrs[0],
-                               ctx->fragments.sizes[0],
-                               &ctx->si,
-                               ctx->decrypt_key);
+    res = vp8_peek_si_internal(ctx->fragments.ptrs[0], ctx->fragments.sizes[0],
+                               &ctx->si, ctx->decrypt_cb, ctx->decrypt_state);
 
     if((res == VPX_CODEC_UNSUP_BITSTREAM) && !ctx->si.is_kf)
     {
@@ -532,7 +527,8 @@ static vpx_codec_err_t vp8_decode(vpx_codec_alg_priv_t  *ctx,
             }
 
             res = vp8_create_decoder_instances(&ctx->yv12_frame_buffers, &oxcf);
-            ctx->yv12_frame_buffers.pbi[0]->decrypt_key = ctx->decrypt_key;
+            ctx->yv12_frame_buffers.pbi[0]->decrypt_cb = ctx->decrypt_cb;
+            ctx->yv12_frame_buffers.pbi[0]->decrypt_state = ctx->decrypt_state;
         }
 
         ctx->decoder_init = 1;
@@ -956,17 +952,22 @@ static vpx_codec_err_t vp8_get_frame_corrupted(vpx_codec_alg_priv_t *ctx,
 
 }
 
-
-static vpx_codec_err_t vp8_set_decrypt_key(vpx_codec_alg_priv_t *ctx,
-                                           int ctr_id,
-                                           va_list args)
+static vpx_codec_err_t vp8_set_decryptor(vpx_codec_alg_priv_t *ctx,
+                                         int ctrl_id,
+                                         va_list args)
 {
-    const unsigned char *data = va_arg(args, const unsigned char *);
-    if (data == NULL) {
-        return VPX_CODEC_INVALID_PARAM;
-    }
+    vp8_decrypt_init *init = va_arg(args, vp8_decrypt_init *);
 
-    memcpy(ctx->decrypt_key, data, VP8_DECRYPT_KEY_SIZE);
+    if (init)
+    {
+        ctx->decrypt_cb = init->decrypt_cb;
+        ctx->decrypt_state = init->decrypt_state;
+    }
+    else
+    {
+        ctx->decrypt_cb = NULL;
+        ctx->decrypt_state = NULL;
+    }
     return VPX_CODEC_OK;
 }
 
@@ -982,7 +983,7 @@ vpx_codec_ctrl_fn_map_t vp8_ctf_maps[] =
     {VP8D_GET_LAST_REF_UPDATES,     vp8_get_last_ref_updates},
     {VP8D_GET_FRAME_CORRUPTED,      vp8_get_frame_corrupted},
     {VP8D_GET_LAST_REF_USED,        vp8_get_last_ref_frame},
-    {VP8_SET_DECRYPT_KEY,           vp8_set_decrypt_key},
+    {VP8D_SET_DECRYPTOR,            vp8_set_decryptor},
     { -1, NULL},
 };
 
diff --git a/vpx/vp8dx.h b/vpx/vp8dx.h
index 201df88fe2..97f111d697 100644
--- a/vpx/vp8dx.h
+++ b/vpx/vp8dx.h
@@ -63,11 +63,11 @@ enum vp8_dec_control_id {
    */
   VP8D_GET_LAST_REF_USED,
 
-  /** decryption key to protect encoded data buffer before decoding,
-   *  pointer to 32 byte array which is copied, so the array passed
-   *  does not need to be preserved
+  /** decryption function to decrypt encoded buffer data immediately
+   * before decoding. Takes a vp8_decrypt_init, which contains
+   * a callback function and opaque context pointer.
    */
-  VP8_SET_DECRYPT_KEY,
+  VP8D_SET_DECRYPTOR,
 
   /** For testing. */
   VP9_INVERT_TILE_DECODE_ORDER,
@@ -76,6 +76,17 @@ enum vp8_dec_control_id {
 };
 
 
+/*Decrypt n bytes of data from input -> output, using the decrypt_state
+   passed in VP8D_SET_DECRYPTOR.
+*/
+typedef void (vp8_decrypt_cb)(void *decrypt_state, const unsigned char *input,
+                              unsigned char *output, int count);
+
+typedef struct vp8_decrypt_init {
+    vp8_decrypt_cb *decrypt_cb;
+    void *decrypt_state;
+} vp8_decrypt_init;
+
 /*!\brief VP8 decoder control function parameter type
  *
  * Defines the data types that VP8D control functions take. Note that
@@ -87,7 +98,7 @@ enum vp8_dec_control_id {
 VPX_CTRL_USE_TYPE(VP8D_GET_LAST_REF_UPDATES,   int *)
 VPX_CTRL_USE_TYPE(VP8D_GET_FRAME_CORRUPTED,    int *)
 VPX_CTRL_USE_TYPE(VP8D_GET_LAST_REF_USED,      int *)
-VPX_CTRL_USE_TYPE(VP8_SET_DECRYPT_KEY,         const unsigned char *)
+VPX_CTRL_USE_TYPE(VP8D_SET_DECRYPTOR,          vp8_decrypt_init *)
 VPX_CTRL_USE_TYPE(VP9_INVERT_TILE_DECODE_ORDER, int)
 
 /*! @} - end defgroup vp8_decoder */
-- 
GitLab