diff --git a/av1/encoder/context_tree.c b/av1/encoder/context_tree.c
index bd9a7d0252911d045f23d0a1c58045486634af00..969444df4be30f8413895f7aab77a66a30a56693 100644
--- a/av1/encoder/context_tree.c
+++ b/av1/encoder/context_tree.c
@@ -22,19 +22,14 @@ static const BLOCK_SIZE square[MAX_SB_SIZE_LOG2 - 1] = {
 #endif  // CONFIG_EXT_PARTITION
 };
 
-static void alloc_mode_context(AV1_COMMON *cm, int num_4x4_blk,
+static void alloc_mode_context(AV1_COMMON *cm, int num_pix,
 #if CONFIG_EXT_PARTITION_TYPES
                                PARTITION_TYPE partition,
 #endif
                                PICK_MODE_CONTEXT *ctx) {
-  const int num_blk = (num_4x4_blk < 4 ? 4 : num_4x4_blk);
-  const int num_pix = num_blk * tx_size_2d[0];
   int i;
-#if CONFIG_CB4X4 && CONFIG_VAR_TX
-  ctx->num_4x4_blk = num_blk / 4;
-#else
+  const int num_blk = num_pix / 16;
   ctx->num_4x4_blk = num_blk;
-#endif
 
 #if CONFIG_EXT_PARTITION_TYPES
   ctx->partition = partition;
@@ -110,72 +105,53 @@ static void free_mode_context(PICK_MODE_CONTEXT *ctx) {
 #endif  // CONFIG_MRC_TX
 }
 
-static void alloc_tree_contexts(AV1_COMMON *cm, PC_TREE *tree,
-                                int num_4x4_blk) {
+static void alloc_tree_contexts(AV1_COMMON *cm, PC_TREE *tree, int num_pix) {
 #if CONFIG_EXT_PARTITION_TYPES
-  alloc_mode_context(cm, num_4x4_blk, PARTITION_NONE, &tree->none);
-  alloc_mode_context(cm, num_4x4_blk / 2, PARTITION_HORZ, &tree->horizontal[0]);
-  alloc_mode_context(cm, num_4x4_blk / 2, PARTITION_VERT, &tree->vertical[0]);
-  alloc_mode_context(cm, num_4x4_blk / 2, PARTITION_VERT, &tree->horizontal[1]);
-  alloc_mode_context(cm, num_4x4_blk / 2, PARTITION_VERT, &tree->vertical[1]);
+  alloc_mode_context(cm, num_pix, PARTITION_NONE, &tree->none);
+  alloc_mode_context(cm, num_pix / 2, PARTITION_HORZ, &tree->horizontal[0]);
+  alloc_mode_context(cm, num_pix / 2, PARTITION_VERT, &tree->vertical[0]);
+  alloc_mode_context(cm, num_pix / 2, PARTITION_VERT, &tree->horizontal[1]);
+  alloc_mode_context(cm, num_pix / 2, PARTITION_VERT, &tree->vertical[1]);
 
-  alloc_mode_context(cm, num_4x4_blk / 4, PARTITION_HORZ_A,
-                     &tree->horizontala[0]);
-  alloc_mode_context(cm, num_4x4_blk / 4, PARTITION_HORZ_A,
-                     &tree->horizontala[1]);
-  alloc_mode_context(cm, num_4x4_blk / 2, PARTITION_HORZ_A,
-                     &tree->horizontala[2]);
-  alloc_mode_context(cm, num_4x4_blk / 2, PARTITION_HORZ_B,
-                     &tree->horizontalb[0]);
-  alloc_mode_context(cm, num_4x4_blk / 4, PARTITION_HORZ_B,
-                     &tree->horizontalb[1]);
-  alloc_mode_context(cm, num_4x4_blk / 4, PARTITION_HORZ_B,
-                     &tree->horizontalb[2]);
-  alloc_mode_context(cm, num_4x4_blk / 4, PARTITION_VERT_A,
-                     &tree->verticala[0]);
-  alloc_mode_context(cm, num_4x4_blk / 4, PARTITION_VERT_A,
-                     &tree->verticala[1]);
-  alloc_mode_context(cm, num_4x4_blk / 2, PARTITION_VERT_A,
-                     &tree->verticala[2]);
-  alloc_mode_context(cm, num_4x4_blk / 2, PARTITION_VERT_B,
-                     &tree->verticalb[0]);
-  alloc_mode_context(cm, num_4x4_blk / 4, PARTITION_VERT_B,
-                     &tree->verticalb[1]);
-  alloc_mode_context(cm, num_4x4_blk / 4, PARTITION_VERT_B,
-                     &tree->verticalb[2]);
+  alloc_mode_context(cm, num_pix / 4, PARTITION_HORZ_A, &tree->horizontala[0]);
+  alloc_mode_context(cm, num_pix / 4, PARTITION_HORZ_A, &tree->horizontala[1]);
+  alloc_mode_context(cm, num_pix / 2, PARTITION_HORZ_A, &tree->horizontala[2]);
+  alloc_mode_context(cm, num_pix / 2, PARTITION_HORZ_B, &tree->horizontalb[0]);
+  alloc_mode_context(cm, num_pix / 4, PARTITION_HORZ_B, &tree->horizontalb[1]);
+  alloc_mode_context(cm, num_pix / 4, PARTITION_HORZ_B, &tree->horizontalb[2]);
+  alloc_mode_context(cm, num_pix / 4, PARTITION_VERT_A, &tree->verticala[0]);
+  alloc_mode_context(cm, num_pix / 4, PARTITION_VERT_A, &tree->verticala[1]);
+  alloc_mode_context(cm, num_pix / 2, PARTITION_VERT_A, &tree->verticala[2]);
+  alloc_mode_context(cm, num_pix / 2, PARTITION_VERT_B, &tree->verticalb[0]);
+  alloc_mode_context(cm, num_pix / 4, PARTITION_VERT_B, &tree->verticalb[1]);
+  alloc_mode_context(cm, num_pix / 4, PARTITION_VERT_B, &tree->verticalb[2]);
   for (int i = 0; i < 4; ++i) {
-    alloc_mode_context(cm, num_4x4_blk / 4, PARTITION_HORZ_4,
+    alloc_mode_context(cm, num_pix / 4, PARTITION_HORZ_4,
                        &tree->horizontal4[i]);
-    alloc_mode_context(cm, num_4x4_blk / 4, PARTITION_HORZ_4,
-                       &tree->vertical4[i]);
+    alloc_mode_context(cm, num_pix / 4, PARTITION_HORZ_4, &tree->vertical4[i]);
   }
 #if CONFIG_SUPERTX
-  alloc_mode_context(cm, num_4x4_blk, PARTITION_HORZ,
-                     &tree->horizontal_supertx);
-  alloc_mode_context(cm, num_4x4_blk, PARTITION_VERT, &tree->vertical_supertx);
-  alloc_mode_context(cm, num_4x4_blk, PARTITION_SPLIT, &tree->split_supertx);
-  alloc_mode_context(cm, num_4x4_blk, PARTITION_HORZ_A,
-                     &tree->horizontala_supertx);
-  alloc_mode_context(cm, num_4x4_blk, PARTITION_HORZ_B,
-                     &tree->horizontalb_supertx);
-  alloc_mode_context(cm, num_4x4_blk, PARTITION_VERT_A,
-                     &tree->verticala_supertx);
-  alloc_mode_context(cm, num_4x4_blk, PARTITION_VERT_B,
-                     &tree->verticalb_supertx);
+  alloc_mode_context(cm, num_pix, PARTITION_HORZ, &tree->horizontal_supertx);
+  alloc_mode_context(cm, num_pix, PARTITION_VERT, &tree->vertical_supertx);
+  alloc_mode_context(cm, num_pix, PARTITION_SPLIT, &tree->split_supertx);
+  alloc_mode_context(cm, num_pix, PARTITION_HORZ_A, &tree->horizontala_supertx);
+  alloc_mode_context(cm, num_pix, PARTITION_HORZ_B, &tree->horizontalb_supertx);
+  alloc_mode_context(cm, num_pix, PARTITION_VERT_A, &tree->verticala_supertx);
+  alloc_mode_context(cm, num_pix, PARTITION_VERT_B, &tree->verticalb_supertx);
 #endif  // CONFIG_SUPERTX
 #else
-  alloc_mode_context(cm, num_4x4_blk, &tree->none);
-  alloc_mode_context(cm, num_4x4_blk / 2, &tree->horizontal[0]);
-  alloc_mode_context(cm, num_4x4_blk / 2, &tree->vertical[0]);
+  alloc_mode_context(cm, num_pix, &tree->none);
+  alloc_mode_context(cm, num_pix / 2, &tree->horizontal[0]);
+  alloc_mode_context(cm, num_pix / 2, &tree->vertical[0]);
 #if CONFIG_SUPERTX
-  alloc_mode_context(cm, num_4x4_blk, &tree->horizontal_supertx);
-  alloc_mode_context(cm, num_4x4_blk, &tree->vertical_supertx);
-  alloc_mode_context(cm, num_4x4_blk, &tree->split_supertx);
+  alloc_mode_context(cm, num_pix, &tree->horizontal_supertx);
+  alloc_mode_context(cm, num_pix, &tree->vertical_supertx);
+  alloc_mode_context(cm, num_pix, &tree->split_supertx);
 #endif
 
-  if (num_4x4_blk > 4) {
-    alloc_mode_context(cm, num_4x4_blk / 2, &tree->horizontal[1]);
-    alloc_mode_context(cm, num_4x4_blk / 2, &tree->vertical[1]);
+  if (num_pix > 16) {
+    alloc_mode_context(cm, num_pix / 2, &tree->horizontal[1]);
+    alloc_mode_context(cm, num_pix / 2, &tree->vertical[1]);
   } else {
     memset(&tree->horizontal[1], 0, sizeof(tree->horizontal[1]));
     memset(&tree->vertical[1], 0, sizeof(tree->vertical[1]));