/* -*- c++ -*- */
/*
 * Copyright 2005 Free Software Foundation, Inc.
 * 
 * This file is part of GNU Radio
 * 
 * GNU Radio is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2, or (at your option)
 * any later version.
 * 
 * GNU Radio is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 * 
 * You should have received a copy of the GNU General Public License
 * along with GNU Radio; see the file COPYING.  If not, write to
 * the Free Software Foundation, Inc., 59 Temple Place - Suite 330,
 * Boston, MA 02111-1307, USA.
 */

#ifndef _GNU_SOURCE
#define _GNU_SOURCE
#endif

#include <iostream>
#include <string>
#include <fstream>
#include <unistd.h>
#include <stdlib.h>
#include <getopt.h>
#include <boost/scoped_array.hpp>
#include <gr_complex.h>
#include <gr_fxpt_nco.h>
#include "time_series.h"

#include <gri_fft.h>


/*!
 * \brief conjugate dot product of two complex vectors
 *
 * \param a	first vector
 * \param b	second vector; every element is conjugated before multiplying
 * \param n	number of elements read from both a and b
 * \returns	the sum of a[k] * conj(b[k]) for k in [0, n); 0 when n is 0
 */
gr_complex
complex_conj_dotprod(const gr_complex *a, const gr_complex *b, unsigned n)
{
  gr_complex sum = 0;
  const gr_complex *lhs = a;
  const gr_complex *rhs = b;

  for (unsigned remaining = n; remaining != 0; remaining--){
    sum += *lhs * conj(*rhs);
    lhs++;
    rhs++;
  }

  return sum;
}

/*!
 * \brief pick an FFT size >= min_fft_size that is a product of small primes
 *
 * Rationale (from the original author): FFTW has efficient implementations
 * for sizes that factor into small primes. Powers of two are best, but
 * sizes built from the primes 2, 3 and 5 also work well.
 *
 * Algorithm:
 *   1. Find the smallest power of two P (at least 2) with P >= min_fft_size.
 *   2. Unless only_powers_of_two is set, and provided P >= 32, try the sizes
 *      (P/32) * f for f in {18, 20, 24, 25, 27, 30, 32} in increasing order
 *      and take the first one that is >= min_fft_size.
 *
 *        18 = 2*3*3      20 = 2*2*5    24 = 2*2*2*3    25 = 5*5
 *        27 = 3*3*3      30 = 2*3*5    32 = 2*2*2*2*2
 *
 *      Factors of 16 and below need no checking: (P/32)*16 is P/2, which is
 *      already known to be smaller than min_fft_size.
 *
 * The chosen size is also printed on stderr.
 *
 * \param min_fft_size		smallest acceptable size
 * \param only_powers_of_two	when true, skip step 2 and return P
 * \returns the chosen size. Doubling stops at 2^30 so that the loop always
 *          terminates and the result fits the int return type (for 32-bit
 *          int); for requests above 2^30 the result is therefore 2^30 and
 *          smaller than min_fft_size.
 */
int
get_efficient_fft_size(unsigned int min_fft_size,bool only_powers_of_two=false)
{
  // Without this cap the doubling below would wrap to 0 and spin forever
  // for requests above 2^31.
  const unsigned int largest_pow2 = 1u << 30;

  unsigned int pow2 = 2;
  while (pow2 < min_fft_size && pow2 < largest_pow2)
    pow2 *= 2;

  unsigned int fftsize = pow2;

  if (pow2 >= 32 && !only_powers_of_two){
    // combinations of powers of 2, 3 and 5, in increasing order
    // TODO check the speed difference between powers of two and products of
    // small primes: does it pay off to use 30 or 27 instead of 32?
    static const unsigned int small_factor[] = {18, 20, 24, 25, 27, 30, 32};
    const unsigned int nfactors = sizeof(small_factor) / sizeof(small_factor[0]);
    const unsigned int base = pow2 / 32;

    for (unsigned int i = 0; i < nfactors; i++){
      const unsigned int candidate = base * small_factor[i];
      if (candidate >= min_fft_size){
        fftsize = candidate;
        break;
      }
    }
  }

  fprintf(stderr, "fftsize   = %u\n", fftsize);
  return (int) fftsize;
}

/*!
 * \brief frequency translate src by normalized_freq
 *
 * Each input sample is multiplied by a phasor (cos, sin) taken from a
 * fixed-point NCO whose frequency is set to 2*pi*normalized_freq and which
 * is stepped once per sample.
 *
 * \param dst			destination, receives n samples
 * \param src			source, n samples are read
 * \param n			length of src and dst in samples
 * \param normalized_freq	[-1/2, +1/2]
 */
void
freq_xlate(gr_complex *dst, const gr_complex *src, unsigned n, float normalized_freq)
{
  gr_fxpt_nco	oscillator;
  oscillator.set_freq(2 * M_PI * normalized_freq);

  unsigned int k = 0;
  while (k < n){
    // sample the oscillator before advancing it, so sample 0 sees the
    // oscillator's initial phase
    const gr_complex rotation(oscillator.cos(), oscillator.sin());
    dst[k] = src[k] * rotation;
    oscillator.step();
    k++;
  }
}

/*!
 * \brief append one float to output as raw bytes in host representation
 *
 * The result of fwrite is not checked.
 *
 * \param output	stream to write to
 * \param x		value to write
 */
inline void
write_float(FILE *output, float x)
{
  const float value = x;
  fwrite(&value, sizeof(float), 1, output);
}

/*!
 * \brief append n floats to output as raw bytes in host representation
 *
 * The result of fwrite is not checked.
 *
 * \param output	stream to write to
 * \param x		first of the n values to write
 * \param n		number of floats to write
 */
void
write_floats(FILE *output, float *x,int n)
{
  float *const first = x;
  fwrite(first, sizeof(float), n, output);
}

// Write the 8-float file header: the number of doppler bins, the minimum
// and the maximum doppler, followed by five zero-valued floats.
static void
write_header (FILE *output, float ndoppler_bins, float min_doppler, float max_doppler)
{
  const int header_length = 8;
  const float header[header_length] = {
    ndoppler_bins, min_doppler, max_doppler,
    0, 0, 0, 0, 0
  };

  for (int k = 0; k < header_length; k++)
    write_float(output, header[k]);
}

/*!
 * \brief compute the spectrum of a Wiener deconvolution filter
 *
 * For every bin: result = conj(H) / (H*conj(H) + alpha*sigma*sigma),
 * where H is the corresponding bin of ref_spectrum.
 * All inputs and outputs are in the frequency domain.
 *
 * \param result_spectrum	output, N bins
 * \param N			number of bins
 * \param ref_spectrum		estimated blur H, N bins
 * \param sigma			noise level; it is squared before use
 * \param alpha			noise scale factor applied to sigma*sigma
 */
void calculate_wiener_filter_spectrum(gr_complex * result_spectrum, int N,gr_complex *ref_spectrum /*estimated blur*/,double sigma /*noise power*/, double alpha=1.0 /*noise scale factor*/)
{
  // computed in double, then narrowed once to float
  const float noise_term = alpha * sigma * sigma;

  for (int bin = 0; bin < N; bin++){
    const gr_complex h = ref_spectrum[bin];
    // |H|^2 is formed as H*conj(H). The original author notes that
    // abs(H*H) is wrong here and that abs(H)*abs(H) is correct but slow.
    result_spectrum[bin] = conj(h) / (h * conj(h) + noise_term);
  }
}

/*!
 * \brief apply a Wiener deconvolution filter to a signal spectrum
 *
 * For every bin: result = S * conj(H) / (H*conj(H) + alpha*sigma*sigma),
 * where H is the bin of ref_spectrum and S the bin of signal_spectrum.
 * All inputs and outputs are in the frequency domain.
 *
 * \param result_spectrum	output, N bins
 * \param N			number of bins
 * \param ref_spectrum		estimated blur H, N bins
 * \param signal_spectrum	spectrum S to be filtered, N bins
 * \param sigma			noise level; it is squared before use
 * \param alpha			noise scale factor applied to sigma*sigma
 */
void apply_wiener_filter_spectrum(gr_complex * result_spectrum, int N,gr_complex *ref_spectrum /*estimated blur*/,gr_complex *signal_spectrum,double sigma /*noise power*/,  double alpha=1.0 /*noise scale factor*/)
{
  // computed in double, then narrowed once to float
  const float noise_term = alpha * sigma * sigma;

  for (int bin = 0; bin < N; bin++){
    const gr_complex h = ref_spectrum[bin];
    const gr_complex s = signal_spectrum[bin];
    // |H|^2 is formed as H*conj(H). The original author notes that
    // abs(H*H) is wrong here and that abs(H)*abs(H) is correct but slow.
    result_spectrum[bin] = s * conj(h) / (h * conj(h) + noise_term);
  }
}

/*!
 * \brief compute the filter spectrum for a cross-correlation
 *
 * For every bin: result = conj(H), where H is the bin of ref_spectrum.
 * All inputs and outputs are in the frequency domain.
 *
 * \param result_spectrum	output, N bins
 * \param N			number of bins
 * \param ref_spectrum		spectrum to correlate with, N bins
 * \param sigma			not used
 * \param alpha			not used
 */
void calculate_crosscorrelate_spectrum(gr_complex * result_spectrum, int N,gr_complex *ref_spectrum /*spectrum to correlate with*/,double sigma=0.0 /*noise power, not used*/, double alpha=1.0 /*noise scale factor, not used*/)
{
  int bin = 0;
  while (bin < N){
    result_spectrum[bin] = conj(ref_spectrum[bin]);
    bin++;
  }
}

/*!
 * \brief cross-correlate a signal spectrum with a reference spectrum
 *
 * For every bin: result = S * conj(H), where H is the bin of ref_spectrum
 * and S the bin of signal_spectrum.
 * All inputs and outputs are in the frequency domain.
 *
 * \param result_spectrum	output, N bins
 * \param N			number of bins
 * \param ref_spectrum		spectrum to correlate with, N bins
 * \param signal_spectrum	spectrum of the signal, N bins
 * \param sigma			not used
 * \param alpha			not used
 */
void apply_crosscorrelate_spectrum(gr_complex * result_spectrum, int N,gr_complex *ref_spectrum /*spectrum to correlate with*/,gr_complex *signal_spectrum,double sigma=0.0 /*noise power, not used*/, double alpha=1.0 /*noise scale factor, not used*/)
{
  int bin = 0;
  while (bin < N){
    result_spectrum[bin] = signal_spectrum[bin] * conj(ref_spectrum[bin]);
    bin++;
  }
}



/*!
 * \brief compute the filter spectrum for a convolution
 *
 * This is only a copy of ref_spectrum into result_spectrum. It exists so
 * that the different ways of combining two spectra share the same API.
 * All inputs and outputs are in the frequency domain.
 *
 * \param result_spectrum	output, N bins
 * \param N			number of bins
 * \param ref_spectrum		spectrum to convolve with, N bins
 * \param sigma			not used
 * \param alpha			not used
 */
void calculate_convolute_spectrum(gr_complex * result_spectrum, int N,gr_complex *ref_spectrum /*spectrum to convulute with*/,double sigma=0.0 /*noise power, not used*/, double alpha=1.0 /*noise scale factor, not used*/)
{
  int bin = 0;
  while (bin < N){
    result_spectrum[bin] = ref_spectrum[bin];
    bin++;
  }
}

/*!
 * \brief convolve a signal with a reference by multiplying their spectra
 *
 * For every bin: result = S * H, where H is the bin of ref_spectrum and S
 * the bin of signal_spectrum.
 * All inputs and outputs are in the frequency domain.
 *
 * \param result_spectrum	output, N bins
 * \param N			number of bins
 * \param ref_spectrum		spectrum to convolve with, N bins
 * \param signal_spectrum	spectrum of the signal, N bins
 * \param sigma			not used
 * \param alpha			not used
 */
void apply_convolute_spectrum(gr_complex * result_spectrum, int N,gr_complex *ref_spectrum /*spectrum to convolute with*/,gr_complex *signal_spectrum,double sigma=0.0 /*noise power, not used*/, double alpha=1.0 /*noise scale factor, not used*/)
{
  int bin = 0;
  while (bin < N){
    result_spectrum[bin] = signal_spectrum[bin] * ref_spectrum[bin];
    bin++;
  }
}

/*!
 * \brief filter a time-domain signal by multiplying its spectrum with a filter spectrum
 *
 * The signal is copied into the forward FFT's input buffer and transformed.
 * Each output bin is multiplied by the matching bin of spectrum, and the
 * product is run through the inverse FFT. No scaling is applied here.
 *
 * \param result	output, time domain, N samples
 * \param N		number of samples / bins processed
 * \param signal	input, time domain, N samples
 * \param spectrum	filter spectrum, frequency domain, N bins
 * \param fwdfft	forward FFT; its input and output buffers are overwritten
 * \param invfft	inverse FFT; its input and output buffers are overwritten
 */
void apply_fft_filter(gr_complex * result /* time domain */,int N,gr_complex * signal/*time domain*/, gr_complex * spectrum /* filter spectrum, freq domain */, gri_fft_complex *fwdfft,gri_fft_complex *invfft)
{
  // forward transform of the input signal
  memcpy(fwdfft->get_inbuf(), signal, N *sizeof(gr_complex));
  fwdfft->execute();

  // multiply bin by bin, writing straight into the inverse FFT's input
  const gr_complex *signal_bins = fwdfft->get_outbuf();
  gr_complex *filtered_bins = invfft->get_inbuf();
  for (int bin = 0; bin < N; bin++)
    filtered_bins[bin] = spectrum[bin] * signal_bins[bin];

  // inverse transform back to the time domain
  invfft->execute();
  memcpy(result, invfft->get_outbuf(), N *sizeof(gr_complex));
}

/*!
 * \brief RMS magnitude of a block of complex samples
 *
 * Despite the name, the value returned is the square root of the mean of
 * |signal[k]|^2, i.e. the RMS value rather than the average power.
 * The input may be in the time or in the frequency domain.
 *
 * \param N		number of samples; with N == 0 the mean is 0/0
 * \param signal	samples to measure
 * \returns		sqrt(sum(|signal[k]|^2) / N)
 */
double
calculate_average_power(int N,gr_complex * signal/*time or freq domain*/)
{
  double sum_of_squares = 0.0;

  int k = 0;
  while (k < N){
    sum_of_squares += (signal[k]*conj(signal[k])).real();
    k++;
  }

  const double mean_power = sum_of_squares/N;
  // the sqrt turns the mean power into an RMS value; drop it to get the
  // average power instead
  return sqrt(mean_power);
}

void
main_loop(FILE *output, time_series &ref_ts, time_series &scat0_ts,
	  unsigned nranges, unsigned correlation_window_size,
	  float min_doppler, float max_doppler, int ndoppler_bins,double sigma=-1.0)
{
  fprintf(stderr, "ndoppler_bins = %10d\n", ndoppler_bins);
  fprintf(stderr, "min_doppler   = %10f\n", min_doppler);
  fprintf(stderr, "max_doppler   = %10f\n", max_doppler);


  int fftsize=correlation_window_size;//Should this multiplied by two because of the pos/neg freqs in the fft and we only use the positive ones?
                                      //Most efficient when this number can be factored into small primes
                                      //TODO maybe increase to the next power of two if it is not.
                                      //Then also pad with zero's or use some other window
  // float scale_factor = 1.0/correlation_window_size;	// FIXME, not sure this is right
  //float scale_factor = 1.0;				// FIXME, not sure this is right
  float scale_factor = 1.0;///fftsize;

  boost::scoped_array<gr_complex>  shifted(new gr_complex[correlation_window_size]);

  const gr_complex *ref = (const gr_complex *) ref_ts.seek(0, correlation_window_size);
  const gr_complex *scat0 = (const gr_complex *) scat0_ts.seek(0, correlation_window_size);
  // gr_complex shifted[correlation_window_size];		// doppler shifted reference

  //The FFt's
  /*FFTW computes an unnormalized transform, in that there is no coefficient in front of the summation in the DFT. In other words, applying the forward and then the backward transform will multiply the input by n.

From above, an FFTW_FORWARD transform corresponds to a sign of -1 in the exponent of the DFT. Note also that we use the standard “in-order” output ordering—the k-th output corresponds to the frequency k/n (or k/T, where T is your total sampling period). For those who like to think in terms of positive and negative frequencies, this means that the positive frequencies are stored in the first half of the output and the negative frequencies are stored in backwards order in the second half of the output. (The frequency -k/n is the same as the frequency (n-k)/n.) 
  */
  gri_fft_complex	  *fwdfft_shifted_ref;		// forward "plan"
  gri_fft_complex	  *fwdfft_scat;		// forward "plan"
  gri_fft_complex	  *invfft;		// inverse "plan"

  fwdfft_shifted_ref = new gri_fft_complex(fftsize, true);
  fwdfft_scat = new gri_fft_complex(fftsize, true);
  invfft = new gri_fft_complex(fftsize, false);
  /*fft implementation pseudocode:
     fft_ref=fft(ref_ts)
     loop over dopppler
       shifted=translatefreq(ref,current_doppler)
       fft_shifted=fft(shifted)
       loop over freqs
         fft_out[freq]=conjugatemult(fft_shifted[current_freq],fft_ref[current_freq])
       out[current_doppler]=ifft(fft_out)
   */
  float doppler_incr = 0;
  if (ndoppler_bins == 1){
    min_doppler = 0;
    max_doppler = 0;
  }
  else
    doppler_incr = (max_doppler - min_doppler) / (ndoppler_bins - 1);

  write_header(output, ndoppler_bins, min_doppler, max_doppler);
  int j = 0;
  for (j = 0 /*correlation_window_size*/; j < fftsize; j++)
    fwdfft_scat->get_inbuf()[j] = 0;
  memcpy(&fwdfft_scat->get_inbuf()[0], scat0, correlation_window_size * 1 *sizeof(gr_complex));
  //for (j = correlation_window_size; j < fftsize; j++)
  //  fwdfft_scat->get_inbuf()[j] = 0;
  fwdfft_scat->execute();	// compute fwd xform
  //Estimate the noise power by taking the power of the scat signal
  //This is only correct if the scat signal has the direct signal completely removed
  double avg_scat_power=calculate_average_power(fftsize,fwdfft_scat->get_outbuf());
  printf("avg_scat_power = %f\n",avg_scat_power);

  if(sigma<0.0)
  {
    double sigma=avg_scat_power;//1380.0;
    printf("sigma is set to %f which is the avg_scat_power\n",sigma);
  } else
  {
   //sigma=30.0;//1400.0;
    printf("noise power sigma is manually overridden with value%f\n",sigma);
  }

  for (j = 0; j < fftsize; j++)
      fwdfft_shifted_ref->get_inbuf()[j] = 0;//Should only need to do this once
  float *out=new float[nranges*ndoppler_bins];
  for (int nd = 0; nd < ndoppler_bins; nd++){
    float fdop = min_doppler + doppler_incr * nd;
    //fprintf(stderr, "fdop = %10g\n", fdop);
    freq_xlate(&fwdfft_shifted_ref->get_inbuf()[0], ref, correlation_window_size*1, fdop);	// generated doppler shifted reference
    //fft_shifted=fft(shifted)

    fwdfft_shifted_ref->execute();	// compute fwd xform
    if(0==nd)
    {
      double avg_ref_power=calculate_average_power(fftsize,fwdfft_shifted_ref->get_outbuf());
      printf("avg_ref_power= %f\n",avg_ref_power);
      printf("avg_scat_power/avg_ref_power= %f\n",avg_scat_power/avg_ref_power);
    }
    gr_complex *a = fwdfft_shifted_ref->get_outbuf();
    gr_complex *b = fwdfft_scat->get_outbuf();
    gr_complex *c = invfft->get_inbuf();

//    for (j = 0; j < fftsize; j++)	// cross-correlate in the freq domain
//      c[j] = a[j] * conj(b[j]);
    apply_wiener_filter_spectrum(c, fftsize,a /*estimated blur*/,b/*spectrum of signal*/, sigma /*noise power*/,  1.0 /*noise scale factor*/);

    invfft->execute();
    gr_complex *unscaled_out=invfft->get_outbuf();
    int fft_bin;
    for (j = 0; j < (int)nranges; j++)
    {
        if (j<=0)
          fft_bin=j;
        else
          fft_bin=correlation_window_size-j;
        out[j*ndoppler_bins+nd]=norm(unscaled_out[fft_bin]) * scale_factor;
        //out[j*ndoppler_bins+nd]=norm(unscaled_out[j]) * scale_factor;
    }
  }
  write_floats(output,out,ndoppler_bins*nranges);
  delete [] out;
  delete fwdfft_shifted_ref;
  delete fwdfft_scat;
  delete invfft;
  /*unsigned long long ro = 0;	// reference offset
  unsigned long long so = 0;	// scatter offset

  for (unsigned int n = 0; n < nranges; n++){
    if (0){
      fprintf(stdout, "n =  %6d\n", n);
      fprintf(stdout, "ro = %6lld\n", ro);
      fprintf(stdout, "so = %6lld\n", so);
    }
    const gr_complex *ref = (const gr_complex *) ref_ts.seek(ro, correlation_window_size);
    const gr_complex *scat0 = (const gr_complex *) scat0_ts.seek(so, correlation_window_size);
    if (ref == 0 || scat0 == 0)
      return;

    for (int nd = 0; nd < ndoppler_bins; nd++){
      float fdop = min_doppler + doppler_incr * nd;
      //fprintf(stderr, "fdop = %10g\n", fdop);
      freq_xlate(&shifted[0], ref, correlation_window_size, fdop);	// generated doppler shifted reference

      gr_complex ccor = complex_conj_dotprod(&shifted[0], scat0, correlation_window_size);
      float out = norm(ccor) * scale_factor;

      // fprintf(output, "%12g\n", out);
      write_float(output, out);
    }

    so += 1;
  }*/
}

/*!
 * \brief print the command line synopsis and option list on stderr
 *
 * \param argv0	program path as invoked; only the part after the last '/'
 *		is shown
 */
static void
usage(const char *argv0)
{
  // strip any leading directory components from the program name
  const char *slash = std::strrchr(argv0, '/');
  const char *progname = (slash != 0) ? slash + 1 : argv0;

  fprintf(stderr, "usage: %s [options] ref_file scatter_file\n", progname);
  fprintf(stderr, "    -h show this help\n");
  fprintf(stderr, "    -o OUTPUT_FILENAME [default=mdvh-xambi.out]\n");
  fprintf(stderr, "    -m MIN_RANGE [default=0]\n");
  fprintf(stderr, "    -M MAX_RANGE [default=300]\n");
  fprintf(stderr, "    -w CORRELATION_WINDOW_SIZE [default=2048]\n");
  fprintf(stderr, "    -s NSAMPLES_TO_SKIP [default=0]\n");
  fprintf(stderr, "    -S SIGMA (noise level for the Wiener filter; negative = estimate from the scatter signal) [default=-1]\n");
  fprintf(stderr, "    -d max_doppler (normalized: [0, +1/2)) [default=.0012]\n");
  fprintf(stderr, "    -n ndoppler_bins [default=31]\n");
}

/*!
 * \brief parse the command line, open the inputs and output, and run main_loop
 *
 * Expects two positional arguments after the options: the reference file
 * and the scatter file. Exits with status 1 on a usage error, when the
 * output file cannot be opened or written, or when a std::string is thrown
 * while the inputs are being processed.
 */
int
main(int argc, char **argv)
{
  int	ch;
  int min_range =    0;
  int max_range =  300;
  const char *ref_filename = 0;
  const char *scatter_filename = 0;
  const char *output_filename = "mdvh-xambi.out";
  unsigned int correlation_window_size = 2048;
  long long int nsamples_to_skip = 0;
  double max_doppler = 0.0012;
  int ndoppler_bins = 31;
  double sigma=-1.0;	// negative: let main_loop estimate it

  while ((ch = getopt(argc, argv, "m:M:ho:w:s:S:d:n:")) != -1){
    switch (ch){
    case 'm':
      min_range = strtol(optarg, 0, 0);
      break;

    case 'M':
      max_range = strtol(optarg, 0, 0);
      break;

    case 'w':
      {
	// Validate as a signed value first: storing a negative number
	// directly in the unsigned variable would turn it into a huge
	// size that passes the range check.
	const long window = strtol(optarg, 0, 0);
	if (window <= 1){
	  usage(argv[0]);
	  fprintf(stderr, "    correlation_window_size must be > 1\n");
	  exit(1);
	}
	correlation_window_size = window;
      }
      break;

    case 'o':
      output_filename = optarg;
      break;

    case 's':
      // Parsed as floating point so that forms like 1e6 are accepted;
      // double (rather than float) keeps large counts exact.
      nsamples_to_skip = (long long) strtod(optarg, 0);
      if (nsamples_to_skip < 0){
	usage(argv[0]);
	fprintf(stderr, "    nsamples_to_skip must be >= 0\n");
	exit(1);
      }
      break;

    case 'S':
      sigma = strtod(optarg, 0);
      if (sigma < 0.0){
	// not fatal: a negative sigma selects the automatic estimate
	fprintf(stderr, "    sigma must be >= 0.0\n");
	fprintf(stderr, "    since you set sigma <0.0 the avg power of the scat signal will be used for sigma\n");
      }
      break;

    case 'd':
      max_doppler = strtod(optarg, 0);
      if (max_doppler < 0 || max_doppler >= 0.5){
	usage(argv[0]);
	fprintf(stderr, "    max_doppler must be in [0, 0.5)\n");
	exit(1);
      }
      break;

    case 'n':
      ndoppler_bins = strtol(optarg, 0, 0);
      if (ndoppler_bins < 1){
	usage(argv[0]);
	fprintf(stderr, "    ndoppler_bins must >= 1\n");
	exit(1);
      }
      break;

    case '?':
    case 'h':
    default:
      usage(argv[0]);
      exit(1);
    }
  } // while getopt

  if (argc - optind != 2){
    usage(argv[0]);
    exit(1);
  }

  if (max_range < min_range){
    usage(argv[0]);
    fprintf(stderr, "    max_range must be >= min_range\n");
    exit(1);
  }
  unsigned int nranges = max_range - min_range + 1;

  ref_filename = argv[optind++];
  scatter_filename = argv[optind++];

  FILE *output = fopen(output_filename, "wb");
  if (output == 0){
    perror(output_filename);
    exit(1);
  }

  // A negative min_range is realised by starting later in the reference,
  // a positive one by starting later in the scatter signal.
  unsigned long long ref_starting_offset = 0;
  unsigned long long scatter_starting_offset = 0;

  if (min_range < 0){
    ref_starting_offset = -min_range;
    scatter_starting_offset = 0;
  }
  else {
    ref_starting_offset = 0;
    scatter_starting_offset = min_range;
  }

  ref_starting_offset += nsamples_to_skip;
  scatter_starting_offset += nsamples_to_skip;

  try {
    time_series ref(sizeof(gr_complex), ref_filename, ref_starting_offset, 0);
    time_series scat0(sizeof(gr_complex), scatter_filename, scatter_starting_offset, 0);

    main_loop(output, ref, scat0, nranges, correlation_window_size,
	      -max_doppler, max_doppler, ndoppler_bins,sigma);
  }
  catch (std::string &s){
    std::cerr << s << std::endl;
    exit(1);
  }

  // Close explicitly so that a failed write or flush is reported instead of
  // silently producing a truncated file.
  const bool write_error = ferror(output) != 0;
  if (fclose(output) != 0 || write_error){
    fprintf(stderr, "%s: error writing output\n", output_filename);
    exit(1);
  }

  return 0;
}

