#include <errno.h>
#include <float.h>
#include <limits.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <alloca.h>
#include <error.h>

/* Number of lanes processed per vector operation in tsp().  tsp() spells out
   SIMD_SIZE elements literally in its vector initializers, so changing this
   value requires updating those initializers as well. */
#define SIMD_SIZE 8
/* GNU vector-extension types holding SIMD_SIZE doubles / SIMD_SIZE longs.
   Arithmetic and comparisons on them operate lane by lane. */
typedef double vd __attribute__ ((vector_size (SIMD_SIZE*sizeof(double))));
typedef long vl __attribute__ ((vector_size (SIMD_SIZE*sizeof(long))));

/* A 2-D point; used for city coordinates and for tour entries. */
typedef struct {
  double x;
  double y;
} point;

/* Returns the square of a scalar double. */
double sqr(double x)
{
  double squared = x * x;
  return squared;
}

/* Returns the lane-wise square of a vector of SIMD_SIZE doubles. */
vd vsqr(vd x)
{
  vd squared = x * x;
  return squared;
}

/* Euclidean distance between cities[i] and cities[j]. */
double dist(point cities[], int i, int j) {
  double dx = cities[i].x - cities[j].x;
  double dy = cities[i].y - cities[j].y;
  return sqrt(sqr(dx) + sqr(dy));
}

/* Squared Euclidean distance between cities[i] and cities[j]
   (dist() without the final square root). */
double DistSqrd(point cities[], int i, int j) {
  const point *a = &cities[i];
  const point *b = &cities[j];
  return sqr(a->x - b->x) + sqr(a->y - b->y);
}

/* Exchanges the two points referenced by p and q. */
void swap(point *p, point *q)
{
  point saved = *q;
  *q = *p;
  *p = saved;
}

/* Exchanges the two doubles referenced by p and q. */
void swapd(double *p, double *q)
{
  double saved = *q;
  *q = *p;
  *p = saved;
}

/* Builds a greedy nearest-neighbour tour through all cities.
 *
 * The tour starts at cities[ncities-1].  Each following entry of tour[] is
 * the not-yet-visited city closest (by squared Euclidean distance) to the
 * previously appended one.  cities[] is only read.  tour[] must have room
 * for ncities points.  Does nothing when ncities <= 0.
 *
 * The coordinates are copied into separate x[] and y[] arrays so that the
 * nearest-neighbour search can test SIMD_SIZE candidates per step with GNU
 * vector extensions.  Inside the main loop, indices 0..i-1 of x/y hold the
 * cities already placed in the tour, in tour order.  Indices i..ncities-1
 * hold the remaining candidates.
 *
 * Build-time switches:
 *   VECTOR_LOADS_ARE_64_BYTE_ALIGNED - load candidate blocks by
 *       dereferencing a vd pointer instead of via memcpy.
 *   ONLY_8_BYTE_ALIGNED - allocate one extra double, which shifts the
 *       arrays so the candidate blocks are NOT SIMD-aligned.  This makes
 *       the direct vd dereference misaligned, so do not combine it with
 *       VECTOR_LOADS_ARE_64_BYTE_ALIGNED.
 *
 * Terminates the program via error() if posix_memalign fails.
 */
void tsp(point cities[], point tour[], int ncities)
{
  long i,k;
  double CloseDist;
  long ClosePt;
  vl baseidx;
  double *x, *y;
  void *x1, *y1;
  int err;
  size_t size;

  /* Nothing to order.  This also prevents reading cities[-1] and writing
     tour[0] below when the arrays are empty. */
  if (ncities <= 0)
    return;

  /* (ncities+SIMD_SIZE) doubles, rounded down to a multiple of SIMD_SIZE
     doubles.  The mask assumes SIMD_SIZE*sizeof(double) is a power of two.
     The result is the smallest multiple of SIMD_SIZE above ncities, so
     there is always room for ncities elements. */
  size=(sizeof(double)*(ncities+SIMD_SIZE)&~(sizeof(double)*SIMD_SIZE-1))
#ifdef ONLY_8_BYTE_ALIGNED
    +sizeof(double)
#endif
    ;
  /* Each array is placed so that element ncities-1 is the last double of
     its buffer.  Without ONLY_8_BYTE_ALIGNED, x+ncities is then aligned
     like the buffer start.  As a result, every block x[ncities-m*SIMD_SIZE..]
     read by the vector loop below is SIMD-aligned. */
  if ((err=posix_memalign(&x1,sizeof(double)*SIMD_SIZE,size)))
    error(1,err,"tsp");
  x = ((double *)x1)+size/sizeof(double)-ncities;
  if ((err=posix_memalign(&y1,sizeof(double)*SIMD_SIZE,size)))
    error(1,err,"tsp");
  y = ((double *)y1)+size/sizeof(double)-ncities;

  /* Slot 0 holds the starting city cities[ncities-1].  The remaining
     cities follow in their original order as candidates. */
  for (i=1; i<ncities; i++) {
    x[i]=cities[i-1].x;
    y[i]=cities[i-1].y;
  }
  x[0] = cities[ncities-1].x;
  y[0] = cities[ncities-1].y;
  tour[0] = cities[ncities-1];
  /* Lane offsets 0..SIMD_SIZE-1, added to a block's base index j. */
  for (k=0; k<SIMD_SIZE; k++)
    baseidx[k]=k;

  for (i=1; i<ncities; i++) {
    /* Last city placed in the tour; find the candidate nearest to it. */
    double ThisX = x[i-1];
    double ThisY = y[i-1];
    /* Per-lane best squared distance so far, and the index it came from. */
    vd closedist;
    vl closeidx;
    vd dist;
    long j;
    /* GNU C allows assigning the scalar to the vector, GNU C++ does
       not; clang implements the same */
    closedist = (vd){DBL_MAX,DBL_MAX,DBL_MAX,DBL_MAX,DBL_MAX,DBL_MAX,DBL_MAX,DBL_MAX,};
    closeidx = (vl){-1,-1,-1,-1,-1,-1,-1,-1,};
    /* Scan candidates from the top down in whole blocks
       x[j..j+SIMD_SIZE-1], all of which lie within i..ncities-1. */
    for (j = ncities-SIMD_SIZE; j >= i; j -= SIMD_SIZE) {
#ifdef VECTOR_LOADS_ARE_64_BYTE_ALIGNED
      /* does not buy any speedup on Rocket Lake */
      vd vx=*(vd *)&x[j];
      vd vy=*(vd *)&y[j];
#else
      vd vx, vy;
      memcpy(&vx,&x[j],sizeof(vx));
      memcpy(&vy,&y[j],sizeof(vy));
#endif
      dist = vsqr(vx-ThisX)+vsqr(vy-ThisY);
      /* GNU C does not allow ? : on vectors, GNU C++ does; clang
         implements the same.  NOTE(review): the element-wise selects below
         therefore need a compiler/mode that accepts vector ?:. */
      closeidx =  dist<closedist ? (baseidx+j) : closeidx;
      closedist = dist<closedist ? dist : closedist;
    }
    /* Here i-SIMD_SIZE <= j < i.  The leftover candidates
       x[i..j+SIMD_SIZE-1] do not fill a whole block, so they are handled
       one lane at a time.  Lane k only ever receives index j+k. */
    for (k=SIMD_SIZE-1; j+k>=i; k--) {
      dist[k] = sqr(x[k+j]-ThisX)+sqr(y[k+j]-ThisY);
      if (dist[k]<closedist[k]) {
        closedist[k] = dist[k];
        closeidx[k] = j+k;
      }
    }
    /* Reduce the per-lane minima to a single nearest candidate. */
    CloseDist = closedist[0];
    ClosePt =   closeidx[0];
    for (k=1; k<SIMD_SIZE; k++) {
      if (closedist[k]<CloseDist) {
        CloseDist = closedist[k];
        ClosePt = closeidx[k];
      }
    }
    /* ClosePt is still -1 only if no lane was ever updated.  That happens
       when no candidate's squared distance compared below DBL_MAX (it
       overflowed to inf or is NaN).  In that case take the next candidate
       in array order instead of swapping with x[-1], which lies outside
       the data. */
    if (ClosePt < 0)
      ClosePt = i;
    /* Move the chosen city into slot i, appending it to the tour. */
    swapd(&x[i],&x[ClosePt]);
    swapd(&y[i],&y[ClosePt]);
    tour[i].x = x[i];
    tour[i].y = y[i];
  }
  free(x1);
  free(y1);
}

int main(int argc, char *argv[])
{
  int i, ncities;
  point *cities;
  point *tour;
  FILE *psfile;
  double sumdist = 0.0;

  if (argc!=2) {
    fprintf(stderr, "usage: %s <ncities>", argv[0]);
    exit(1);
  }
  ncities = atoi(argv[1]);
  cities = (point *)alloca(ncities*sizeof(point));
  tour = (point *)alloca(ncities*sizeof(point));
  for (i=0; i<ncities; i++) {
    cities[i].x = ((double)(random()))/(double)(1U<<31);
    cities[i].y = ((double)(random()))/(double)(1U<<31);
  }
  tsp(cities,tour,ncities);
  psfile = fopen("tsp.eps","w");
  fprintf(psfile, "%%!PS-Adobe-2.0 EPSF-1.2\n%%%%BoundingBox: 0 0 300 300\n");
  fprintf(psfile, "1 setlinejoin\n0 setlinewidth\n");
  fprintf(psfile, "%f %f moveto\n",
	  300.0*tour[0].x, 300.0*tour[0].y);
  for (i=1; i<ncities; i++) {
    fprintf(psfile, "%f %f lineto\n",
	    300.0*tour[i].x, 300.0*tour[i].y);
    sumdist += dist(tour, i-1, i);
  }
  fprintf(psfile,"stroke\n");
  printf("sumdist = %f\n", sumdist);
  exit(0);
}
