diff --git a/celt/kiss_fft.c b/celt/kiss_fft.c
index d37f4ed4e4537c43f2cbd6b13ee6e72e4ac8a57b..e67309b293b9ed7cea004651f6cf496cb0ae7542 100644
--- a/celt/kiss_fft.c
+++ b/celt/kiss_fft.c
@@ -600,7 +600,7 @@ void opus_fft_free(const kiss_fft_state *cfg)
 
 #endif /* CUSTOM_MODES */
 
-void opus_fft(const kiss_fft_state *st,const kiss_fft_cpx *fin,kiss_fft_cpx *fout)
+void opus_fft_impl(const kiss_fft_state *st,kiss_fft_cpx *fout)
 {
     int m2, m;
     int p;
@@ -608,34 +608,10 @@ void opus_fft(const kiss_fft_state *st,const kiss_fft_cpx *fin,kiss_fft_cpx *fou
     int fstride[MAXFACTORS];
     int i;
     int shift;
-#ifdef FIXED_POINT
-    /* FIXME: This should eventually just go in the state. */
-    opus_val16 scale;
-    int scale_shift;
-    scale_shift = celt_ilog2(st->nfft);
-    if (st->nfft == 1<<scale_shift)
-       scale = Q15ONE;
-    else
-       scale = (1073741824+st->nfft/2)/st->nfft>>(15-scale_shift);
-#endif
 
     /* st->shift can be -1 */
     shift = st->shift>0 ? st->shift : 0;
 
-    celt_assert2 (fin != fout, "In-place FFT not supported");
-    /* Bit-reverse the input */
-    for (i=0;i<st->nfft;i++)
-    {
-       kiss_fft_cpx x = fin[i];
-#ifdef FIXED_POINT
-       fout[st->bitrev[i]].r = SHR32(MULT16_32_Q15(scale, x.r), scale_shift);
-       fout[st->bitrev[i]].i = SHR32(MULT16_32_Q15(scale, x.i), scale_shift);
-#else
-       fout[st->bitrev[i]].r = st->scale*x.r;
-       fout[st->bitrev[i]].i = st->scale*x.i;
-#endif
-    }
-
     fstride[0] = 1;
     L=0;
     do {
@@ -672,6 +648,36 @@ void opus_fft(const kiss_fft_state *st,const kiss_fft_cpx *fin,kiss_fft_cpx *fou
     }
 }
 
+void opus_fft(const kiss_fft_state *st,const kiss_fft_cpx *fin,kiss_fft_cpx *fout)
+{
+   int i;
+#ifdef FIXED_POINT
+   /* FIXME: This should eventually just go in the state. */
+   opus_val16 scale;
+   int scale_shift;
+   scale_shift = celt_ilog2(st->nfft);
+   if (st->nfft == 1<<scale_shift)
+      scale = Q15ONE;
+   else
+      scale = (1073741824+st->nfft/2)/st->nfft>>(15-scale_shift);
+#endif
+
+   celt_assert2 (fin != fout, "In-place FFT not supported");
+   /* Bit-reverse the input */
+   for (i=0;i<st->nfft;i++)
+   {
+      kiss_fft_cpx x = fin[i];
+#ifdef FIXED_POINT
+      fout[st->bitrev[i]].r = SHR32(MULT16_32_Q15(scale, x.r), scale_shift);
+      fout[st->bitrev[i]].i = SHR32(MULT16_32_Q15(scale, x.i), scale_shift);
+#else
+      fout[st->bitrev[i]].r = st->scale*x.r;
+      fout[st->bitrev[i]].i = st->scale*x.i;
+#endif
+   }
+   opus_fft_impl(st, fout);
+}
+
 void opus_ifft_impl(const kiss_fft_state *st,kiss_fft_cpx *fout)
 {
    int m2, m;
diff --git a/celt/kiss_fft.h b/celt/kiss_fft.h
index ee5aae55ff5ddb855615faf09384d8573afe366c..67ac5f375e8b922d682af49bfd57e7161ba2ca64 100644
--- a/celt/kiss_fft.h
+++ b/celt/kiss_fft.h
@@ -130,6 +130,7 @@ kiss_fft_state *opus_fft_alloc(int nfft,void * mem,size_t * lenmem);
 void opus_fft(const kiss_fft_state *cfg,const kiss_fft_cpx *fin,kiss_fft_cpx *fout);
 void opus_ifft(const kiss_fft_state *cfg,const kiss_fft_cpx *fin,kiss_fft_cpx *fout);
 
+void opus_fft_impl(const kiss_fft_state *st,kiss_fft_cpx *fout);
 void opus_ifft_impl(const kiss_fft_state *st,kiss_fft_cpx *fout);
 
 void opus_fft_free(const kiss_fft_state *cfg);
diff --git a/celt/mdct.c b/celt/mdct.c
index b4209195780e159e4eef60cbf8ad485654a23e80..a6bd6b4cf219b4e2920facf1c84303af724c598c 100644
--- a/celt/mdct.c
+++ b/celt/mdct.c
@@ -109,14 +109,25 @@ void clt_mdct_forward(const mdct_lookup *l, kiss_fft_scalar *in, kiss_fft_scalar
    int N, N2, N4;
    kiss_twiddle_scalar sine;
    VARDECL(kiss_fft_scalar, f);
-   VARDECL(kiss_fft_scalar, f2);
+   VARDECL(kiss_fft_cpx, f2);
+   const kiss_fft_state *st = l->kfft[shift];
+#ifdef FIXED_POINT
+   /* FIXME: This should eventually just go in the state. */
+   opus_val16 scale;
+   int scale_shift;
+   scale_shift = celt_ilog2(st->nfft);
+   if (st->nfft == 1<<scale_shift)
+      scale = Q15ONE;
+   else
+      scale = (1073741824+st->nfft/2)/st->nfft>>(15-scale_shift);
+#endif
    SAVE_STACK;
    N = l->n;
    N >>= shift;
    N2 = N>>1;
    N4 = N>>2;
    ALLOC(f, N2, kiss_fft_scalar);
-   ALLOC(f2, N2, kiss_fft_scalar);
+   ALLOC(f2, N2, kiss_fft_cpx);
    /* sin(x) ~= x here */
 #ifdef FIXED_POINT
    sine = TRIG_UPSCALE*(QCONST16(0.7853981f, 15)+N2)/N;
@@ -170,24 +181,33 @@ void clt_mdct_forward(const mdct_lookup *l, kiss_fft_scalar *in, kiss_fft_scalar
       const kiss_twiddle_scalar *t = &l->trig[0];
       for(i=0;i<N4;i++)
       {
+         kiss_fft_cpx yc;
          kiss_fft_scalar re, im, yr, yi;
-         re = yp[0];
-         im = yp[1];
+         re = *yp++;
+         im = *yp++;
          yr = -S_MUL(re,t[i<<shift])  -  S_MUL(im,t[(N4-i)<<shift]);
          yi = -S_MUL(im,t[i<<shift])  +  S_MUL(re,t[(N4-i)<<shift]);
          /* works because the cos is nearly one */
-         *yp++ = yr + S_MUL(yi,sine);
-         *yp++ = yi - S_MUL(yr,sine);
+         yc.r = yr + S_MUL(yi,sine);
+         yc.i = yi - S_MUL(yr,sine);
+#ifdef FIXED_POINT
+         yc.r = SHR32(MULT16_32_Q15(scale, yc.r), scale_shift);
+         yc.i = SHR32(MULT16_32_Q15(scale, yc.i), scale_shift);
+#else
+         yc.r *= st->scale;
+         yc.i *= st->scale;
+#endif
+         f2[st->bitrev[i]] = yc;
       }
    }
 
    /* N/4 complex FFT, down-scales by 4/N */
-   opus_fft(l->kfft[shift], (kiss_fft_cpx *)f, (kiss_fft_cpx *)f2);
+   opus_fft_impl(st, f2);
 
    /* Post-rotate */
    {
       /* Temp pointers to make it really clear to the compiler what we're doing */
-      const kiss_fft_scalar * OPUS_RESTRICT fp = f2;
+      const kiss_fft_cpx * OPUS_RESTRICT fp = f2;
       kiss_fft_scalar * OPUS_RESTRICT yp1 = out;
       kiss_fft_scalar * OPUS_RESTRICT yp2 = out+stride*(N2-1);
       const kiss_twiddle_scalar *t = &l->trig[0];
@@ -195,12 +215,12 @@ void clt_mdct_forward(const mdct_lookup *l, kiss_fft_scalar *in, kiss_fft_scalar
       for(i=0;i<N4;i++)
       {
          kiss_fft_scalar yr, yi;
-         yr = S_MUL(fp[1],t[(N4-i)<<shift]) + S_MUL(fp[0],t[i<<shift]);
-         yi = S_MUL(fp[0],t[(N4-i)<<shift]) - S_MUL(fp[1],t[i<<shift]);
+         yr = S_MUL(fp->i,t[(N4-i)<<shift]) + S_MUL(fp->r,t[i<<shift]);
+         yi = S_MUL(fp->r,t[(N4-i)<<shift]) - S_MUL(fp->i,t[i<<shift]);
          /* works because the cos is nearly one */
          *yp1 = yr - S_MUL(yi,sine);
          *yp2 = yi + S_MUL(yr,sine);;
-         fp += 2;
+         fp++;
          yp1 += 2*stride;
          yp2 -= 2*stride;
       }