diff --git a/celt/bands.c b/celt/bands.c
index 44e9074dfbe376b5925998143776b69074b5ad49..5088ee85efc78099255699e0e766c6598eacafce 100644
--- a/celt/bands.c
+++ b/celt/bands.c
@@ -1336,7 +1336,7 @@ void quant_all_bands(int encode, const CELTMode *m, int start, int end,
       const celt_ener *bandE, int *pulses, int shortBlocks, int spread,
       int dual_stereo, int intensity, int *tf_res, opus_int32 total_bits,
       opus_int32 balance, ec_ctx *ec, int LM, int codedBands,
-      opus_uint32 *seed, int arch)
+      opus_uint32 *seed, int complexity, int arch)
 {
    int i;
    opus_int32 remaining_bits;
@@ -1344,6 +1344,8 @@ void quant_all_bands(int encode, const CELTMode *m, int start, int end,
    celt_norm * OPUS_RESTRICT norm, * OPUS_RESTRICT norm2;
    VARDECL(celt_norm, _norm);
    VARDECL(celt_norm, _lowband_scratch);
+   VARDECL(celt_norm, X_save);
+   VARDECL(celt_norm, Y_save);
    int resynth_alloc;
    celt_norm *lowband_scratch;
    int B;
@@ -1352,10 +1354,11 @@ void quant_all_bands(int encode, const CELTMode *m, int start, int end,
    int update_lowband = 1;
    int C = Y_ != NULL ? 2 : 1;
    int norm_offset;
+   int theta_rdo = encode && Y_!=NULL && !dual_stereo && complexity>=8;
 #ifdef RESYNTH
    int resynth = 1;
 #else
-   int resynth = !encode;
+   int resynth = !encode || theta_rdo;
 #endif
    struct band_ctx ctx;
    SAVE_STACK;
@@ -1381,6 +1384,8 @@ void quant_all_bands(int encode, const CELTMode *m, int start, int end,
       lowband_scratch = _lowband_scratch;
    else
       lowband_scratch = X_+M*eBands[m->nbEBands-1];
+   ALLOC(X_save, resynth_alloc, celt_norm);
+   ALLOC(Y_save, resynth_alloc, celt_norm);
 
    lowband_offset = 0;
    ctx.bandE = bandE;
@@ -1392,6 +1397,7 @@ void quant_all_bands(int encode, const CELTMode *m, int start, int end,
    ctx.spread = spread;
    ctx.arch = arch;
    ctx.resynth = resynth;
+   ctx.theta_round = 0;
    for (i=start;i<end;i++)
    {
       opus_int32 tell;
@@ -1441,7 +1447,7 @@ void quant_all_bands(int encode, const CELTMode *m, int start, int end,
             Y = norm;
          lowband_scratch = NULL;
       }
-      if (i==end-1)
+      if (last && !theta_rdo)
          lowband_scratch = NULL;
 
       /* Get a conservative estimate of the collapse_mask's for the bands we're
@@ -1489,10 +1495,51 @@ void quant_all_bands(int encode, const CELTMode *m, int start, int end,
       } else {
          if (Y!=NULL)
          {
-            ctx.theta_round = 0;
-            x_cm = quant_band_stereo(&ctx, X, Y, N, b, B,
-                  effective_lowband != -1 ? norm+effective_lowband : NULL, LM,
-                  last?NULL:norm+M*eBands[i]-norm_offset, lowband_scratch, x_cm|y_cm);
+            if (theta_rdo)
+            {
+               ec_ctx ec_save;
+               struct band_ctx ctx_save;
+               opus_val32 dist0, dist1;
+               unsigned cm;
+               /* Make a copy. */
+               cm = x_cm|y_cm;
+               ec_save = *ec;
+               ctx_save = ctx;
+               OPUS_COPY(X_save, X, N);
+               OPUS_COPY(Y_save, Y, N);
+               /* Encode and round down. */
+               ctx.theta_round = -1;
+               x_cm = quant_band_stereo(&ctx, X, Y, N, b, B,
+                     effective_lowband != -1 ? norm+effective_lowband : NULL, LM,
+                     last?NULL:norm+M*eBands[i]-norm_offset, lowband_scratch, cm);
+               dist0 = celt_inner_prod(X_save, X, N, arch) + celt_inner_prod(Y_save, Y, N, arch);
+               /* Restore */
+               *ec = ec_save;
+               ctx = ctx_save;
+               OPUS_COPY(X, X_save, N);
+               OPUS_COPY(Y, Y_save, N);
+               /* Encode and round up. */
+               ctx.theta_round = 1;
+               x_cm = quant_band_stereo(&ctx, X, Y, N, b, B,
+                     effective_lowband != -1 ? norm+effective_lowband : NULL, LM,
+                     last?NULL:norm+M*eBands[i]-norm_offset, lowband_scratch, cm);
+               dist1 = celt_inner_prod(X_save, X, N, arch) + celt_inner_prod(Y_save, Y, N, arch);
+               /* Restore */
+               *ec = ec_save;
+               ctx = ctx_save;
+               OPUS_COPY(X, X_save, N);
+               OPUS_COPY(Y, Y_save, N);
+               /* Encode with best choice. */
+               ctx.theta_round = dist0 >= dist1 ? -1 : 1;
+               x_cm = quant_band_stereo(&ctx, X, Y, N, b, B,
+                     effective_lowband != -1 ? norm+effective_lowband : NULL, LM,
+                     last?NULL:norm+M*eBands[i]-norm_offset, lowband_scratch, cm);
+            } else {
+               ctx.theta_round = 0;
+               x_cm = quant_band_stereo(&ctx, X, Y, N, b, B,
+                     effective_lowband != -1 ? norm+effective_lowband : NULL, LM,
+                     last?NULL:norm+M*eBands[i]-norm_offset, lowband_scratch, x_cm|y_cm);
+            }
          } else {
             x_cm = quant_band(&ctx, X, N, b, B,
                   effective_lowband != -1 ? norm+effective_lowband : NULL, LM,
diff --git a/celt/bands.h b/celt/bands.h
index e8bef4bad0e0b254d91062c36b7aae25d31dc78a..41d080f6709a00d0c5e57e584c781cb118867642 100644
--- a/celt/bands.h
+++ b/celt/bands.h
@@ -105,7 +105,7 @@ void quant_all_bands(int encode, const CELTMode *m, int start, int end,
       const celt_ener *bandE, int *pulses, int shortBlocks, int spread,
       int dual_stereo, int intensity, int *tf_res, opus_int32 total_bits,
       opus_int32 balance, ec_ctx *ec, int M, int codedBands, opus_uint32 *seed,
-      int arch);
+      int complexity, int arch);
 
 void anti_collapse(const CELTMode *m, celt_norm *X_,
       unsigned char *collapse_masks, int LM, int C, int size, int start,
diff --git a/celt/celt_decoder.c b/celt/celt_decoder.c
index 90fee158798b630fa252b3cd69e28040132e9baa..4ab89339e2a602997651acffb75fb55fa12ce280 100644
--- a/celt/celt_decoder.c
+++ b/celt/celt_decoder.c
@@ -1005,7 +1005,7 @@ int celt_decode_with_ec(CELTDecoder * OPUS_RESTRICT st, const unsigned char *dat
 
    quant_all_bands(0, mode, start, end, X, C==2 ? X+N : NULL, collapse_masks,
          NULL, pulses, shortBlocks, spread_decision, dual_stereo, intensity, tf_res,
-         len*(8<<BITRES)-anti_collapse_rsv, balance, dec, LM, codedBands, &st->rng, st->arch);
+         len*(8<<BITRES)-anti_collapse_rsv, balance, dec, LM, codedBands, &st->rng, 0, st->arch);
 
    if (anti_collapse_rsv > 0)
    {
diff --git a/celt/celt_encoder.c b/celt/celt_encoder.c
index 4473b37bf6774430bd0d6dbd6ff55868f8d7d55c..1a031d9e1417adcb0d505b42861fbb630bd39fd7 100644
--- a/celt/celt_encoder.c
+++ b/celt/celt_encoder.c
@@ -2079,7 +2079,7 @@ int celt_encode_with_ec(CELTEncoder * OPUS_RESTRICT st, const opus_val16 * pcm,
    quant_all_bands(1, mode, start, end, X, C==2 ? X+N : NULL, collapse_masks,
          bandE, pulses, shortBlocks, st->spread_decision,
          dual_stereo, st->intensity, tf_res, nbCompressedBytes*(8<<BITRES)-anti_collapse_rsv,
-         balance, enc, LM, codedBands, &st->rng, st->arch);
+         balance, enc, LM, codedBands, &st->rng, st->complexity, st->arch);
 
    if (anti_collapse_rsv > 0)
    {