#include <CL/cl.h>

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <emi.h>

#include <cassert>
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

/*
 * Check an OpenCL status value. If it is not CL_SUCCESS, print the source
 * file, line and numeric status to stderr and exit(1).
 *
 * The argument is parenthesised and copied into a local once, so passing an
 * expression (rather than a plain variable) neither mis-parses nor evaluates
 * it twice.
 */
#define OCL_ERRCK_VAR(var) \
  { cl_int ocl_errck_var_status_ = (var); \
    if (ocl_errck_var_status_ != CL_SUCCESS) { fprintf(stderr, "OpenCL Error (%s: %d): %d\n", __FILE__, __LINE__, ocl_errck_var_status_); exit(1); } }
  
/*
 * Evaluate the OpenCL call `s` once and check the status it returns. If the
 * status is not CL_SUCCESS, print the source file, line and numeric status
 * to stderr and exit(1).
 */
#define OCL_ERRCK_RETVAL(s) \
  { \
    cl_int clerr = (s); \
    if (clerr != CL_SUCCESS) { \
      fprintf(stderr, "OpenCL Error (%s: %d): %d\n", __FILE__, __LINE__, clerr); \
      exit(1); \
    } \
  }

void clinfo() {
  cl_uint num_platforms;
  cl_device_id clDevice; //hack
  OCL_ERRCK_RETVAL(
    clGetPlatformIDs(/*num_entries=*/0, /*platforms=*/NULL, &num_platforms));
  assert(0 < num_platforms && "No platforms found!");
  cl_platform_id *platforms = new cl_platform_id[num_platforms];
  OCL_ERRCK_RETVAL(
    clGetPlatformIDs(num_platforms, platforms, /*num_platforms=*/NULL));
  std::stringstream ss;
  ss << "# Found " << num_platforms << " OpenCL platform" << (num_platforms == 1 ?  "":"s") << "\n";
  char platform_name[1024];
  char platform_version[1024];
  char device_name[1024];
  char device_vendor[1024];
  cl_uint num_cores;
  cl_uint clk_freq;
  cl_long global_mem_size;
  cl_ulong local_mem_size;
  for (int i=0; i<(int)num_platforms; i++) {
    OCL_ERRCK_RETVAL(
      clGetPlatformInfo(platforms[i], CL_PLATFORM_NAME, sizeof(platform_name), platform_name, /*param_value_size_ret=*/NULL));
    OCL_ERRCK_RETVAL(
      clGetPlatformInfo(platforms[i], CL_PLATFORM_VERSION, sizeof(platform_version), platform_version, /*param_value_size_ret=*/NULL));
    cl_uint num_devices;
    OCL_ERRCK_RETVAL(
      clGetDeviceIDs(platforms[i], CL_DEVICE_TYPE_ALL, /*num_entries=*/0, /*devices=*/NULL, &num_devices));
    ss << "# Platform " << i << "\n";
    ss << "# Name: " << platform_name << "\n";
    ss << "# Version: " << platform_version << "\n";
    ss << "# Number of devices: " << num_devices << "\n";

    cl_device_id *devices = new cl_device_id[num_devices];
    OCL_ERRCK_RETVAL(
      clGetDeviceIDs(platforms[i], CL_DEVICE_TYPE_ALL, num_devices, devices, /*num_devices=*/NULL));
    for (int j=0; j<(int)num_devices; j++) {
      OCL_ERRCK_RETVAL(
          clGetDeviceInfo(devices[j], CL_DEVICE_NAME, sizeof(device_name), device_name, /*param_value_size_ret=*/NULL));
      OCL_ERRCK_RETVAL(
          clGetDeviceInfo(devices[j], CL_DEVICE_VENDOR, sizeof(device_vendor), device_vendor, /*param_value_size_ret=*/NULL));
      OCL_ERRCK_RETVAL(
          clGetDeviceInfo(devices[j], CL_DEVICE_MAX_COMPUTE_UNITS, sizeof(num_cores), &num_cores, /*param_value_size_ret=*/NULL));
      OCL_ERRCK_RETVAL(
          clGetDeviceInfo(devices[j], CL_DEVICE_MAX_CLOCK_FREQUENCY, sizeof(clk_freq), &clk_freq, /*param_value_size_ret=*/NULL));
      OCL_ERRCK_RETVAL(
          clGetDeviceInfo(devices[j], CL_DEVICE_GLOBAL_MEM_SIZE, sizeof(global_mem_size), &global_mem_size, /*param_value_size_ret=*/NULL));
      OCL_ERRCK_RETVAL(
          clGetDeviceInfo(devices[j], CL_DEVICE_LOCAL_MEM_SIZE, sizeof(local_mem_size), &local_mem_size, /*param_value_size_ret=*/NULL));

      ss << "# Device " << j << "\n";
      ss << "# \tName: " << device_name << "\n";
      ss << "# \tVendor: " << device_vendor << "\n";
      ss << "# \tCompute units: " << num_cores << "\n";
      ss << "# \tClock frequency: " << clk_freq << " MHz\n";
      ss << "# \tGlobal memory: " << (global_mem_size>>30) << "GB\n";
      ss << "# \tLocal memory: " <<  (local_mem_size>>10) << "KB\n";
    }
    delete[] devices;
  }
  delete[] platforms;
  std::cout << ss.str();
}
// Run-time options. Unless noted otherwise they are set by parseExtraArgs()
// from the command line; the values below are the defaults.

// --platform N: index into the platform list used by initCL(). May be
// overwritten by setPlatformDeviceFromDeviceName().
int PLATFORM = 0;
// --device M: index into the selected platform's device list used by
// initCL(). May be overwritten by setPlatformDeviceFromDeviceName().
int DEVICE = 0;
// --name NAME: substring the selected device's name must contain (checked in
// initCL()). Points directly at the argv entry; NULL when not given.
char *DEVICE_NAME = NULL;
// --set_device_from_name: make initCL() search for a platform/device whose
// name contains DEVICE_NAME before selecting.
bool SET_DEVICE_FROM_NAME = false;
// --substitution [0|1]: when false, compileProgram() adds -D NO_SUBSTITUTION.
bool SUBSTITUTION = true;
// --emi_block N: replaces the text EMI_BLOCK in the kernel source and is also
// passed as -D EMI_BLOCK=N by compileProgram().
int EMI_BLOCK = 0;
// --optimisations [0|1]: when false, compileProgram() adds -cl-opt-disable.
bool OPTIMISATIONS = true;
// --snapshot N: not read in the visible part of this file.
int SNAPSHOT = -1;
// --load_snapshot ARG: not read in the visible part of this file.
char *LOAD_SNAPSHOT = NULL;
// --single: not read in the visible part of this file.
bool RUN_SINGLE_ITERATION = false;
// Neither set nor read in the visible part of this file (--clinfo calls
// clinfo() directly and exits).
bool CLINFO = false;
// --print_kernel: compileProgram() prints the kernel source before building.
bool PRINT_KERNEL = false;
// --verbose: enables extra "# ..." diagnostics in initCL() and
// setPlatformDeviceFromDeviceName().
bool EMIVERBOSE = false;
// --no_while: compileProgram() adds -D NO_WHILE_TRUE.
bool NO_WHILE = false;
/*
 * Return the value that follows the option at argv[i] and advance i onto it.
 * Prints an error to stderr and exits if the option is the last argument,
 * instead of reading past the end of the argument list.
 */
static char *nextOptionValue(int argc, char **argv, int &i) {
  if (i + 1 >= argc) {
    std::cerr << "Missing value for option " << argv[i] << std::endl;
    exit(1);
  }
  return argv[++i];
}

/*
 * Parse the EMI-specific command-line options into the global option
 * variables.
 *
 * Every element of argv is examined; arguments that are not recognised are
 * ignored so that callers can mix in their own options. Options that take a
 * value consume the following argument. --help/-h prints the usage text and
 * exits, and --clinfo prints the platform/device listing and exits.
 *
 * --snapshot, --load_snapshot, --single and --verbose are accepted but are
 * not listed in the --help text.
 */
void parseExtraArgs(int argc, char** argv)  {
  for (int i=0; i<argc; i++) {
    const char *arg = argv[i];
    if (strcmp(arg, "--help") == 0 || strcmp(arg, "-h") == 0) {
      std::cout << "--clinfo                           Print info about all platforms and devices" << std::endl;
      std::cout << "--platform N --device M            Select a given platform and device" << std::endl;
      std::cout << "--name NAME                        Ensure the selected device name includes the string NAME" << std::endl;
      std::cout << "--emi_block N                      Inject emi block N" << std::endl;
      std::cout << "--substitution [0|1]               Turn substitutions on/off (on by default)" << std::endl;
      std::cout << "--optimisations [0|1]              Turn compile optimisations on/off (on by default)" << std::endl;
      std::cout << "--print_kernel                     Print kernel before compilation" << std::endl;
      std::cout << "--no_while                         Wrap EMI blocks in for loop instead of while" << std::endl;
      std::cout << "--set_device_from_name             Attempt to find matching platform/device for NAME" << std::endl;
      exit(0);
    }
    else if (strcmp(arg, "--platform") == 0) {
      PLATFORM = atoi(nextOptionValue(argc, argv, i));
    }
    else if (strcmp(arg, "--device") == 0) {
      DEVICE = atoi(nextOptionValue(argc, argv, i));
    }
    else if (strcmp(arg, "--name") == 0) {
      DEVICE_NAME = nextOptionValue(argc, argv, i);
    }
    else if (strcmp(arg, "--set_device_from_name") == 0) {
      SET_DEVICE_FROM_NAME = true;
    }
    else if (strcmp(arg, "--substitution") == 0) {
      SUBSTITUTION = atoi(nextOptionValue(argc, argv, i)) != 0;
    }
    else if (strcmp(arg, "--emi_block") == 0) {
      EMI_BLOCK = atoi(nextOptionValue(argc, argv, i));
    }
    else if (strcmp(arg, "--optimisations") == 0) {
      OPTIMISATIONS = atoi(nextOptionValue(argc, argv, i)) != 0;
    }
    else if (strcmp(arg, "--snapshot") == 0) {
      SNAPSHOT = atoi(nextOptionValue(argc, argv, i));
    }
    else if (strcmp(arg, "--load_snapshot") == 0) {
      // Consume the value like every other value-taking option so it is not
      // re-examined as an option on the next iteration.
      LOAD_SNAPSHOT = nextOptionValue(argc, argv, i);
    }
    else if (strcmp(arg, "--single") == 0) {
      RUN_SINGLE_ITERATION = true;
    }
    else if (strcmp(arg, "--clinfo") == 0) {
      clinfo();
      exit(0);
    }
    else if (strcmp(arg, "--print_kernel") == 0) {
      PRINT_KERNEL = true;
    }
    else if (strcmp(arg, "--verbose") == 0) {
      EMIVERBOSE = true;
    }
    else if (strcmp(arg, "--no_while") == 0) {
      NO_WHILE = true;
    }
    // else ignored
  }
}

/*
 * Try to set platform-id and device-id based on the device name.
 * Returns bool if successful match is found or not.
 */
bool setPlatformDeviceFromDeviceName() {
  cl_uint num_platforms;
  unsigned p, d;
  OCL_ERRCK_RETVAL(
    clGetPlatformIDs(0, NULL, &num_platforms));
  cl_platform_id *platforms = new cl_platform_id[num_platforms];
  assert(platforms);
  OCL_ERRCK_RETVAL(clGetPlatformIDs(num_platforms, platforms, NULL));
  bool match = false;
  for (p=0; !match && p<num_platforms; ++p) {
    cl_platform_id platform = platforms[p];
    cl_uint num_devices;
    OCL_ERRCK_RETVAL(clGetDeviceIDs(platform, CL_DEVICE_TYPE_ALL, 0, NULL, &num_devices));
    cl_device_id *devices = new cl_device_id[num_devices];
    assert(devices);
    OCL_ERRCK_RETVAL(clGetDeviceIDs(platform, CL_DEVICE_TYPE_ALL, num_devices, devices, NULL));
    for (d=0; !match && d<num_devices; ++d) {
      cl_device_id device = devices[d];
      char name[65536];
      size_t size;
      OCL_ERRCK_RETVAL(clGetDeviceInfo(device, CL_DEVICE_NAME, sizeof(name), name, &size));
      assert(size < sizeof(name));
      if (EMIVERBOSE) {
        printf("# At platform %d and device %d with name [%s]]\n", p, d, name);
      }
      if (strstr(name, DEVICE_NAME) != NULL) {
        match = true;
        if ((PLATFORM != p) || (DEVICE != d)) {
          if (EMIVERBOSE) {
            printf("# Set platform %d and device %d to match %s\n", p, d, DEVICE_NAME);
          }
          PLATFORM = p;
          DEVICE = d;
        }
      }
    }
    delete[] devices;
  }
  delete[] platforms;
  return match;
}

/*
 * Select the OpenCL platform and device given by PLATFORM / DEVICE and
 * create a context and a profiling-enabled command queue for that device.
 *
 * Outputs (by reference): clPlatform, clDevice, clContext, clCommandQueue.
 *
 * If SET_DEVICE_FROM_NAME is set, PLATFORM / DEVICE are first replaced by
 * the indices of the first device whose name contains DEVICE_NAME. If
 * DEVICE_NAME is set, the finally selected device's name must contain it.
 *
 * Any failure (no platform/device, index out of range, name mismatch, or an
 * OpenCL error) prints a message and exits the process with status 1.
 */
void initCL(cl_platform_id &clPlatform, cl_device_id &clDevice, cl_context &clContext, cl_command_queue &clCommandQueue) {
  // fix up platform/device choice if --set_device_from_name specified
  if (SET_DEVICE_FROM_NAME) {
    if (DEVICE_NAME == NULL) {
      // The command-line parser only recognises the long form "--name".
      std::cout << "Must give '--name NAME' to use --set_device_from_name\n" << std::endl;
      exit(1);
    }
    if (!setPlatformDeviceFromDeviceName()) {
      std::cout << "# No matching platform/device found for name " << DEVICE_NAME << std::endl;
      exit(1);
    }
  }

  if (EMIVERBOSE) {
    std::cout << "# Will use platform " << PLATFORM << " and device " << DEVICE << std::endl;
    std::cout << "# Substitution is " << SUBSTITUTION << std::endl;
  }

  // get platform
  cl_uint num_platforms = 0;
  OCL_ERRCK_RETVAL(
    clGetPlatformIDs(/*num_entries=*/0, /*platforms=*/NULL, &num_platforms));
  // Run-time checks (not assert) so they still guard NDEBUG builds.
  if (num_platforms == 0) {
    std::cerr << "No platforms found!" << std::endl;
    exit(1);
  }
  // PLATFORM comes from the command line; never index out of range with it.
  if (PLATFORM < 0 || static_cast<cl_uint>(PLATFORM) >= num_platforms) {
    std::cerr << "Platform index " << PLATFORM << " is out of range (found "
              << num_platforms << " platform(s))" << std::endl;
    exit(1);
  }
  std::vector<cl_platform_id> platforms(num_platforms);
  OCL_ERRCK_RETVAL(
    clGetPlatformIDs(num_platforms, &platforms[0], /*num_platforms=*/NULL));
  clPlatform = platforms[PLATFORM];

  // get device on platform
  cl_uint num_devices = 0;
  OCL_ERRCK_RETVAL(
    clGetDeviceIDs(clPlatform, CL_DEVICE_TYPE_ALL, /*num_entries=*/0, /*devices=*/NULL, &num_devices));
  if (num_devices == 0) {
    std::cerr << "No devices found!" << std::endl;
    exit(1);
  }
  // DEVICE comes from the command line as well.
  if (DEVICE < 0 || static_cast<cl_uint>(DEVICE) >= num_devices) {
    std::cerr << "Device index " << DEVICE << " is out of range (platform "
              << PLATFORM << " has " << num_devices << " device(s))" << std::endl;
    exit(1);
  }
  std::vector<cl_device_id> devices(num_devices);
  OCL_ERRCK_RETVAL(
    clGetDeviceIDs(clPlatform, CL_DEVICE_TYPE_ALL, num_devices, &devices[0], /*num_devices=*/NULL));
  clDevice = devices[DEVICE];

  // setup context and command queue
  cl_int clStatus;
  clContext = clCreateContext(/*properties=*/NULL, /*num_devices=*/1, &clDevice, NULL, NULL, &clStatus);
  OCL_ERRCK_VAR(clStatus);
  clCommandQueue = clCreateCommandQueue(clContext,clDevice,CL_QUEUE_PROFILING_ENABLE,&clStatus);
  OCL_ERRCK_VAR(clStatus);

  // Check that the requested device name matches
  if (DEVICE_NAME != NULL) {
    char name[65536];
    size_t size = 0;
    OCL_ERRCK_RETVAL(clGetDeviceInfo(clDevice, CL_DEVICE_NAME, sizeof(name), name, &size));
    assert(size <= sizeof(name));
    // Guarantee termination before handing the buffer to strstr.
    name[sizeof(name) - 1] = '\0';
    if (strstr(name, DEVICE_NAME) == NULL) {
      std::cout << "Given name [" << DEVICE_NAME << "] not found in device name [" << name << "]" << std::endl;
      exit(1);
    }
  }
}

void compileProgram(const char *fname, const char *kernelname, cl_device_id clDevice, cl_context clContext, cl_program &clProgram, cl_kernel &clKernel, std::string extraflags) {
  std::ifstream f(fname, std::ios::in);
  std::stringstream sstr;
  sstr << f.rdbuf();
  std::string str = sstr.str();
  size_t pos = 0;
  std::string old("EMI_BLOCK");
  std::stringstream blockss;
  blockss << EMI_BLOCK;
  std::string news = blockss.str();
  while ((pos = str.find(old, pos)) != std::string::npos) {
    str.replace(pos, old.length(), news);
    pos += old.length();
  }
  if (PRINT_KERNEL) {
    std::cout << str;
  }
  cl_int clStatus;
  const char *clSource = str.c_str();
  clProgram = clCreateProgramWithSource(clContext, 1, (const char **)&clSource, NULL, &clStatus);
  OCL_ERRCK_VAR(clStatus);

  std::stringstream flags;
  flags << extraflags;
  flags << " -I src";
  if (!SUBSTITUTION) {
    flags << " -D NO_SUBSTITUTION";
  }
  if (!OPTIMISATIONS) {
    flags << " -cl-opt-disable";
  }
  flags << " -D EMI_BLOCK=" << EMI_BLOCK;
  if (NO_WHILE)  {
    flags << " -D NO_WHILE_TRUE";
  }
  std::cout << "# OpenCL compiler flags are [" << flags.str() << "]" << std::endl;
  clStatus = clBuildProgram(clProgram,1,&clDevice,flags.str().c_str(),NULL,NULL);

  if (clStatus != CL_SUCCESS) {
    char *build_log;
    size_t ret_val_size;
    clGetProgramBuildInfo(clProgram, clDevice, CL_PROGRAM_BUILD_LOG, 0, NULL, &ret_val_size);  
    build_log = (char *)malloc(ret_val_size+1);
    clGetProgramBuildInfo(clProgram, clDevice, CL_PROGRAM_BUILD_LOG, ret_val_size, build_log, NULL);
    // there's no information in the reference whether the string is 0 terminated or not
    build_log[ret_val_size] = '\0';
    printf("%s\n", build_log );
    free(build_log);
    exit(1);
  }

  clKernel = clCreateKernel(clProgram,kernelname,&clStatus);
  OCL_ERRCK_VAR(clStatus);
}
#define EMI_DATA_LEN 256
/*
 * Create the device-side EMI data buffer and fill it with the values
 * 0 .. EMI_DATA_LEN-1 (one int per element).
 *
 * The buffer is created read/write in clContext and returned through
 * d_emi_data. The host data is uploaded on clCommandQueue with the
 * blocking flag set to CL_TRUE. Any OpenCL error terminates the process
 * via the OCL_ERRCK_* macros.
 */
void initEMI(cl_context clContext, cl_command_queue clCommandQueue, cl_mem &d_emi_data) {
  const size_t num_bytes = EMI_DATA_LEN*sizeof(int);

  // Host-side staging copy: element k holds the value k.
  int host_data[EMI_DATA_LEN];
  int idx = 0;
  while (idx < EMI_DATA_LEN) {
    host_data[idx] = idx;
    ++idx;
  }

  cl_int status;
  d_emi_data = clCreateBuffer(clContext, CL_MEM_READ_WRITE, num_bytes, NULL, &status);
  OCL_ERRCK_VAR(status);

  OCL_ERRCK_RETVAL(
    clEnqueueWriteBuffer(clCommandQueue, d_emi_data, CL_TRUE, 0, num_bytes, host_data, 0, NULL, NULL));
}
