
#include <iostream>
#include <string>
#include <algorithm>
#include <numeric> // std::reduce is in c++-17
#include <chrono>

#include <thrust/host_vector.h>
#include <thrust/device_vector.h>

#include "cub/cub/cub.cuh"

#include "adding-up-bond-restraint.hh"

// Sums `val` across the 32 lanes of the calling warp using shuffle-down steps.
// All 32 lanes must call this together (the shuffle mask is the full warp,
// 0xffffffff). After the loop, lane 0 holds the warp total; other lanes hold
// partial sums only.
// The parameter is float: an int parameter would truncate every partial sum.
// __shfl_down_sync requires CUDA 9+ (mask-less __shfl_down is removed for SM70+).
__inline__ __device__
float warpReduceSum(float val) {
  for (int offset = warpSize/2; offset > 0; offset /= 2)
    val += __shfl_down_sync(0xffffffff, val, offset);
  return val;
}


// Sums `val` across all threads of the block; the full block total is valid in
// thread 0 only.
// Preconditions:
//  - every thread of the block calls this, because it contains __syncthreads();
//  - blockDim.x <= 1024 (one shared slot per warp, 32 slots);
//  - blockDim.x should be a multiple of warpSize, because warpReduceSum
//    shuffles with a full-warp mask.
// The shared buffer is reused on every call. Calling this twice in one kernel
// needs a __syncthreads() between the calls.
__inline__ __device__
float blockReduceSum(float val) {

  // One partial sum per warp. It must be float: an int buffer truncates the sums.
  static __shared__ float shared[32];
  int lane = threadIdx.x % warpSize;
  int wid = threadIdx.x / warpSize;

  val = warpReduceSum(val);     // Each warp performs partial reduction

  if (lane==0) shared[wid]=val; // Write reduced value to shared memory

  __syncthreads();              // Wait for all partial reductions

  // Number of warps in the block, rounded up so a trailing partial warp is not
  // silently dropped. Read from shared memory only if that warp existed.
  const unsigned int n_warps = (blockDim.x + warpSize - 1) / warpSize;
  val = (threadIdx.x < n_warps) ? shared[lane] : 0.0f;

  if (wid==0) val = warpReduceSum(val); // Final reduce within first warp

  return val;
}

// First/second pass of a two-pass device sum. Each thread accumulates
// in[i] over a grid-stride loop of [0, N). Each block then writes its total
// to out[blockIdx.x].
// Launch requirements (from blockReduceSum): blockDim.x <= 1024, and
// preferably a multiple of warpSize.
// `in` and `out` may alias only for a single-block launch: all reads finish
// before the block barrier inside blockReduceSum, and only thread 0 writes
// afterwards.
// Define DEVICE_REDUCE_DEBUG to get per-element / per-block printf tracing.
// Device printf serializes the kernel, so tracing is off by default.
__global__
void deviceReduceKernel(float *in, float* out, int N) {
   float sum = 0.0f;   // float literal: avoid double arithmetic in the kernel
   //reduce multiple elements per thread
   for (int i = blockIdx.x * blockDim.x + threadIdx.x;
        i < N;
        i += blockDim.x * gridDim.x) {
#ifdef DEVICE_REDUCE_DEBUG
      printf("deviceReduceKernel adding %f to %f\n", in[i], sum);
#endif
      sum += in[i];
   }
   sum = blockReduceSum(sum);
   if (threadIdx.x==0) {
#ifdef DEVICE_REDUCE_DEBUG
      printf("deviceReduceKernel set out[blockIdx.x] %u to %f\n", blockIdx.x, sum);
#endif
      out[blockIdx.x]=sum;
   }
}


// Sums in[0..N) on the device and leaves the total in out[0].
// Both kernels are launched on the default stream and are asynchronous:
// synchronize before reading out[0] on the host.
// `out` must be device-accessible and hold at least
// min(ceil(N/512), 1024) floats, one partial sum per first-pass block.
// When N <= 0, out[0] is set to 0.0f (out must still hold one float).
// Launch failures are reported on std::cerr.
void deviceReduce(float *in, float* out, int N) {

   const int threads = 512;

   if (N <= 0) {
      // A 0-block grid is an invalid launch configuration. The empty sum is 0,
      // and all-zero bytes encode 0.0f.
      cudaError_t err = cudaMemset(out, 0, sizeof(float));
      if (err != cudaSuccess)
         std::cerr << "deviceReduce: cudaMemset failed: " << cudaGetErrorString(err) << std::endl;
      return;
   }

   // Overflow-safe ceil-divide (N > 0 here), capped at 1024 blocks. The kernel's
   // strided loop covers any elements beyond blocks * threads.
   const int blocks = std::min((N - 1) / threads + 1, 1024);

   printf("deviceReduce blocks threads %d %d\n", blocks, threads);

   deviceReduceKernel<<<blocks, threads>>>(in, out, N);
   cudaError_t err = cudaGetLastError();
   if (err != cudaSuccess) {
      std::cerr << "deviceReduce: first-pass launch failed: " << cudaGetErrorString(err) << std::endl;
      return;
   }

   // Second pass: a single block folds the `blocks` partial sums into out[0].
   // In-place use of `out` is only valid because this launch has one block.
   deviceReduceKernel<<<1, 1024>>>(out, out, blocks);
   err = cudaGetLastError();
   if (err != cudaSuccess)
      std::cerr << "deviceReduce: second-pass launch failed: " << cudaGetErrorString(err) << std::endl;
}

// Element-wise SAXPY operation: maps (x, y) to a * x + y with a fixed scale `a`.
// Callable from both host and device code (e.g. inside thrust::transform).
struct saxpy_functor {
    const float a;   // scale factor applied to x

    saxpy_functor(float scale) : a{scale} {}

    __host__ __device__
    float operator()(const float& x, const float& y) const {
        return a * x + y;
    }
};

// In-place SAXPY on the device: every Y[i] becomes A * X[i] + Y[i].
// Y is both the second input range and the output range, so Y is assumed to
// hold at least X.size() elements.
void saxpy_fast(float A, thrust::device_vector<float>& X, thrust::device_vector<float>& Y) {
    const saxpy_functor op(A);
    auto y_first = Y.begin();
    thrust::transform(X.begin(), X.end(), y_first, y_first, op);
}

// Unary functor (host or device) that evaluates a restraint's distortion
// against a fixed atom-position buffer.
// The pointer is captured by value, so the buffer must be accessible wherever
// the functor runs (main passes cudaMallocManaged memory).
struct my_calc_functor {
    const float *atom_positions;

    my_calc_functor(const float *positions) : atom_positions(positions) {}

    __host__ __device__ float operator()(const coot::bond_restraint &restraint) const {
        return restraint.distortion(atom_positions);
    }
};

// Host-only counterpart of my_calc_functor, used with std::transform:
// evaluates a restraint's distortion against a fixed atom-position buffer.
struct std_calc_functor {
    const float *atom_positions;

    std_calc_functor(const float *positions) : atom_positions(positions) {}

    float operator()(const coot::bond_restraint &restraint) const {
        return restraint.distortion(atom_positions);
    }
};


int main(int argc, char **argv) {

   // Benchmark driver. It evaluates bond_restraint::distortion() over a batch of
   // restraints and sums the results several ways:
   //  - GPU: thrust transform + reduce, thrust transform_reduce, CUB DeviceReduce::Sum;
   //  - CPU: std::transform and an on-the-fly loop.
   // Each sum and its wall-clock time (microseconds) is printed.

   // Print a CUDA error tagged with the failing step; returns true on success.
   auto cuda_ok = [](cudaError_t err, const char *what) -> bool {
      if (err == cudaSuccess) return true;
      std::cerr << "CUDA error in " << what << ": " << cudaGetErrorString(err) << std::endl;
      return false;
   };

   const int N_atoms = 4;
   float *atom_positions = nullptr;
   if (!cuda_ok(cudaMallocManaged(&atom_positions, N_atoms*3*sizeof(float)), "cudaMallocManaged(atom_positions)"))
      return 1;
   // The buffer holds 3 floats per atom (presumably x,y,z). Fill all of them
   // so distortion() never reads uninitialized memory.
   for (int i=0; i<N_atoms*3; i++)
      atom_positions[i] = static_cast<float>(i);

   coot::restraint_type_t b   = coot::BOND_RESTRAINT;
   coot::restraint_type_t nbc = coot::NON_BONDED_CONTACT_RESTRAINT;

   coot::bond_restraint ap1(b,   0,1, 2,0.1);
   coot::bond_restraint ap2(nbc, 0,3, 2,0.1);
   coot::bond_restraint ap3(nbc, 2,3, 2,0.1);
   coot::bond_restraint ap4(nbc, 3,1, 2,0.1);
   thrust::host_vector<coot::bond_restraint> vh;
   std::vector<coot::bond_restraint> std_v;

   const unsigned int ManyN=3000;
   for(unsigned int i=0; i<ManyN; i++) {
      // make some restraints
      vh.push_back(ap1); vh.push_back(ap2);
      vh.push_back(ap3); vh.push_back(ap4);
      std_v.push_back(ap1); std_v.push_back(ap2);
      std_v.push_back(ap3); std_v.push_back(ap4);
   }

   thrust::device_vector<coot::bond_restraint> dvd = vh;
   std::vector<float> std_results(dvd.size());
   float *d_results = nullptr;
   if (!cuda_ok(cudaMallocManaged(&d_results, dvd.size() * sizeof(float)), "cudaMallocManaged(d_results)"))
      return 1;
   thrust::device_ptr<float> d_results_ptr(d_results);
   thrust::plus<float> plus;

   auto tp_0 = std::chrono::high_resolution_clock::now();
   thrust::transform(dvd.begin(), dvd.end(), d_results_ptr, my_calc_functor(atom_positions));
   if (!cuda_ok(cudaDeviceSynchronize(), "transform")) return 1;

   if (false)
      for (unsigned int i=0; i<std_v.size(); i++)
         std::cout << " after first transform " << i << " " << d_results[i]  << std::endl;

   auto tp_1 = std::chrono::high_resolution_clock::now(); // post-transform
   // float init value: thrust accumulates (and returns) the init value's type,
   // so a double 0.0 would make the reduction run in double.
   float sum_reduce = thrust::reduce(d_results_ptr, d_results_ptr+dvd.size(), 0.0f, plus);
   if (!cuda_ok(cudaDeviceSynchronize(), "reduce")) return 1;
   std::cout << "Result for sum_reduce: value " << sum_reduce << std::endl;
   auto tp_2 = std::chrono::high_resolution_clock::now(); // post reduce

   float sum_tr = thrust::transform_reduce(dvd.begin(), dvd.end(), my_calc_functor(atom_positions), 0.0f, plus);
   if (!cuda_ok(cudaDeviceSynchronize(), "transform_reduce")) return 1;
   auto tp_3 = std::chrono::high_resolution_clock::now(); // post transform_reduce
   std::cout << "Result for sum_tr: value " << sum_tr << std::endl;

   auto d10 = std::chrono::duration_cast<std::chrono::microseconds>(tp_1 - tp_0).count();
   auto d21 = std::chrono::duration_cast<std::chrono::microseconds>(tp_2 - tp_1).count();
   auto d32 = std::chrono::duration_cast<std::chrono::microseconds>(tp_3 - tp_2).count();
   std::cout << "Timings (microseconds):\n   transform:" << d10 << "\n   reduce: " << d21 << "\n   transform_reduce: " << d32 << std::endl;

   // ---- repeat the transform, then sum with CUB's DeviceReduce::Sum ----
   thrust::transform(dvd.begin(), dvd.end(), d_results_ptr, my_calc_functor(atom_positions));
   if (!cuda_ok(cudaDeviceSynchronize(), "transform (repeat)")) return 1;

   std::cout << "Now to add thrust transformed values: size " << dvd.size() << std::endl;
   auto tp_5 = std::chrono::high_resolution_clock::now(); // post transform
   size_t d_temp_storage_bytes = 0;
   void *d_temp_storage = nullptr;

   float *d_sum_parallel_reduce_p = nullptr;
   if (!cuda_ok(cudaMallocManaged(&d_sum_parallel_reduce_p, sizeof(float)), "cudaMallocManaged(d_sum_parallel_reduce_p)"))
      return 1;
   // With d_temp_storage == nullptr, this first call only reports the required temp-storage size.
   if (!cuda_ok(cub::DeviceReduce::Sum(d_temp_storage, d_temp_storage_bytes, d_results, d_sum_parallel_reduce_p, dvd.size()),
                "cub::DeviceReduce::Sum (size query)"))
      return 1;
   if (!cuda_ok(cudaMalloc(&d_temp_storage, d_temp_storage_bytes), "cudaMalloc(d_temp_storage)"))
      return 1;
   if (!cuda_ok(cub::DeviceReduce::Sum(d_temp_storage, d_temp_storage_bytes, d_results, d_sum_parallel_reduce_p, dvd.size()),
                "cub::DeviceReduce::Sum"))
      return 1;

   if (!cuda_ok(cudaDeviceSynchronize(), "cub::DeviceReduce::Sum (sync)")) return 1;
   auto tp_6 = std::chrono::high_resolution_clock::now();

   auto d65 = std::chrono::duration_cast<std::chrono::microseconds>(tp_6 - tp_5).count();
   std::cout << "Timings:\n   cub::Sum()         " << d65 << " microseconds, value " << *d_sum_parallel_reduce_p << std::endl;

   auto tp_7 = std::chrono::high_resolution_clock::now();
   std::transform(std_v.begin(), std_v.end(), std_results.begin(), std_calc_functor(atom_positions));
   auto tp_8 = std::chrono::high_resolution_clock::now();

   float sum_fly = 0;
   // on the fly addition
   for (const auto &restraint : std_v)
      sum_fly += restraint.distortion(atom_positions);
   auto tp_9 = std::chrono::high_resolution_clock::now();

   auto d87 = std::chrono::duration_cast<std::chrono::microseconds>(tp_8 - tp_7).count();
   auto d98 = std::chrono::duration_cast<std::chrono::microseconds>(tp_9 - tp_8).count();
   std::cout << "   CPU std::transform " << d87 << " microseconds" << std::endl;
   std::cout << "   CPU on-the-fly-sum " << d98 << " microseconds, value " << sum_fly << std::endl;

   cudaFree(d_temp_storage);
   cudaFree(atom_positions);
   cudaFree(d_results);
   cudaFree(d_sum_parallel_reduce_p);

   // return (not exit) so local objects such as the device_vector are destroyed.
   return 0;
}
