diff --git a/celt/mdct.c b/celt/mdct.c
index 10ec8026b946dac71d9ae3a8b2067aca4832bed0..15c7ffd7fed61bbd0abfc1681ae19a486e124b69 100644
--- a/celt/mdct.c
+++ b/celt/mdct.c
@@ -119,15 +119,20 @@ void clt_mdct_forward(const mdct_lookup *l, kiss_fft_scalar *in, kiss_fft_scalar
    VARDECL(kiss_fft_cpx, f2);
    const kiss_fft_state *st = l->kfft[shift];
    const kiss_twiddle_scalar *trig;
+   opus_val16 scale;
 #ifdef FIXED_POINT
    /* FIXME: This should eventually just go in the state. */
-   opus_val16 scale;
    int scale_shift;
    scale_shift = celt_ilog2(st->nfft);
    if (st->nfft == 1<<scale_shift)
       scale = Q15ONE;
    else
       scale = (1073741824+st->nfft/2)/st->nfft>>(15-scale_shift);
+   /* Allows us to scale with MULT16_32_Q16(), which is faster than
+      MULT16_32_Q15() on ARM. */
+   scale_shift--;
+#else
+   scale = st->scale;
 #endif
    SAVE_STACK;
 
@@ -195,28 +200,19 @@ void clt_mdct_forward(const mdct_lookup *l, kiss_fft_scalar *in, kiss_fft_scalar
          kiss_fft_scalar re, im, yr, yi;
          t0 = t[i];
          t1 = t[N4+i];
-#ifdef FIXED_POINT
-         t0 = MULT16_16_P15(t0, scale);
-         t1 = MULT16_16_P15(t1, scale);
-#else
-         t0 *= st->scale;
-         t1 *= st->scale;
-#endif
          re = *yp++;
          im = *yp++;
          yr = -S_MUL(re,t0)  +  S_MUL(im,t1);
          yi = -S_MUL(im,t0)  -  S_MUL(re,t1);
          yc.r = yr;
          yc.i = yi;
-#ifdef FIXED_POINT
-         yc.r = SHR32(yc.r, scale_shift);
-         yc.i = SHR32(yc.i, scale_shift);
-#endif
+         yc.r = PSHR32(MULT16_32_Q16(scale, yc.r), scale_shift);
+         yc.i = PSHR32(MULT16_32_Q16(scale, yc.i), scale_shift);
          f2[st->bitrev[i]] = yc;
       }
    }
 
-   /* N/4 complex FFT, down-scales by 4/N */
+   /* N/4 complex FFT, does not downscale anymore */
    opus_fft_impl(st, f2);
 
    /* Post-rotate */