/* This code is in the public domain.  It is provided with no strings
   attached, and absolutely no guarantees.
*/

#include <errno.h>
#include <fcntl.h>
#include <netdb.h>
#include <netinet/in.h>
#include <setjmp.h>
#include <signal.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/socket.h>
#include <sys/time.h>
#include <sys/types.h>
#include <sys/wait.h>
#include <time.h>
#include <unistd.h>

#ifndef FD_SET
typedef int fd_set;
#define FD_SET(fd,set) *set |= (1 << (fd))
#define FD_CLR(fd,set) *set &= ~(1 << (fd))
#define FD_ISSET(fd,set) (*set & (1 << (fd)))
#define FD_ZERO(set) *set = 0
#endif

#ifndef EINTR
#define EINTR EAGAIN
#endif

#define CHUNK_TIME 300         /* how big to make a chunk */
#define WAIT_TIME 300          /* how long to make slaves wait */
#define DEAD_TIME 5000         /* how long before a slave is dead */
#define PERIOD_TIME 120        /* how often to do periodic stuff */
#define REASSIGN_TIME 1800     /* how long to wait before redoing a chunk */
#define SELECT_TIME 60         /* how long select should wait */
#define MAX_SLAVES 300         /* maximum number of slaves */
#define PROTO_TIME 20          /* timeout for protocol */

/* Opaque parameter block sent verbatim to every slave that is given a chunk:
   after the "I" reply, protocol () writes globlen bytes starting at glob.
   Its meaning is defined by the slave program (the field names suggest a
   search pattern and a goal value -- not verifiable from this file).  */
struct {
  char pattern [49];
  char goal [17];
} globstruct = {
  {
    0x0e, 0x89, 0x94, 0xb8, 0xbf, 0x0e, 0xb9, 0x2e,
    0x50, 0x44, 0x07, 0x00, 0x00, 0x00, 0x00, 0x00,
    '0',
    0xaf, 0x84, 0xa7, 0x79, 0xf8, 0x13, 0x69, 0x20,
    0x25, 0x9b, 0x53, 0xa0, 0x60, 0xae, 0x75, 0x51,
    0xbf, 0xeb, 0x90, 0xf8, 0x2c, 0x0c, 0xe1, 0xea,
    0x18, 0xac, 0x11, 0x4c, 0x83, 0x14, 0x21, 0xb6,
  },
  {
    0x74,
    0x0f, 0x32, 0x88, 0xcf, 0x1a, 0x22, 0xcd, 0x0a,
    0xa4, 0x48, 0x76, 0xc1, 0x38, 0x12, 0x85, 0xa5,
  }
};
char *glob = (char *) &globstruct;   /* the block viewed as raw bytes */
long globlen = sizeof (globstruct);  /* its size in bytes */

/* Show the two accepted command-line forms on stderr, then exit with
   status 2.  */
void usage (void)
{
  static const int usage_status = 2;

  fputs ("usage: master [-v|-V] socket begin end\n", stderr);
  fputs ("   or: master [-v|-V] socket -r\n", stderr);
  exit (usage_status);
}

/* Search range, as hex numbers from the command line or recovered from a
   checkpoint.  */
unsigned long start_point, end_point;
/* Progress mark: get_result () advances it past every chunk reported done
   contiguously from here; periodic () writes it to the checkpoint.  */
unsigned long cur_point;
/* Checkpoint file name; the digit at index 4 is overwritten to select one
   of ckp.0 .. ckp.9.  */
char ckpname [6] = "ckp.0";
int verbose = 0;            /* 0 = quiet, 1 = -v, 2 = -V */
int mysock;                 /* listening socket, set by init_socket () */
unsigned long start_date;   /* time (s) serving started; used for the average speed */
int found = 0;              /* set once a slave reports a result */

/* display messages depending on verbosity level.
   NOTE: each expands to a bare `if', so never use one as the body of an
   `if' that has an `else' -- the `else' would bind to the macro's `if'.  */
#define Verb1 if (verbose >= 1) printf
#define Verb2 if (verbose >= 2) printf

/* Abort the program: print SYSCALL followed by the text for the current
   errno on stderr (via perror), then exit with status 3.  */
void fatal (char *syscall)
{
  const int exit_status = 3;

  perror (syscall);
  exit (exit_status);
}

/* Print MSG on stderr, prefixed with CLIENTADDR in dotted-quad form.
   CLIENTADDR is in network byte order (as in sin_addr.s_addr); 0 prints
   as 0.0.0.0.  */
void log (char *msg, unsigned long clientaddr)
{
  unsigned long adr = ntohl (clientaddr);

  /* %lu, not %d: the octets are unsigned long expressions, and a
     mismatched conversion specifier is undefined behaviour.  */
  fprintf (stderr, "%lu.%lu.%lu.%lu: %s\n", (adr >> 24) & 255,
	   (adr >> 16) & 255, (adr >> 8) & 255, adr & 255, msg);
}

/* Return the current time in whole seconds (gettimeofday's tv_sec); exits
   via fatal () if the clock cannot be read.  */
unsigned long get_date (void)
{
  struct timeval tv;

  /* POSIX leaves the behaviour unspecified when the timezone argument is
     not NULL (struct timezone is obsolete), so pass NULL.  */
  if (gettimeofday (&tv, NULL) != 0) fatal ("gettimeofday");
  return tv.tv_sec;
}

/* read the checkpoint files; find the most recent one */
void recover (void)
{
  char i;
  FILE *f;
  unsigned long maxlow = 0;
  unsigned long maxhigh, low, high;
  int found = 0;

  for (i = '0'; i <= '9'; i++){
    ckpname [4] = i;
    f = fopen (ckpname, "r");
    if (f == NULL) continue;
    low = 0; high = 0;
    fscanf (f, "%lx %lx\n", &low, &high);
    Verb2 ("recover: file %c; low = %lx; high = %lx\n", i, low, high);
    if (low >= maxlow){
      maxlow = low;
      maxhigh = high;
    }
    fclose (f);
    found = 1;
  }
  if (!found) fatal ("cannot find any checkpoint file");
  start_point = maxlow;
  end_point = maxhigh;
  Verb1 ("recover: start = %lx; end = %lx\n", start_point, end_point);
}

/* Create the listening TCP socket on PORTNUM (all local IPv4 interfaces,
   backlog 8) and store it in the global mysock.  Exits via fatal () on any
   failure.  */
void init_socket (int portnum)
{
  struct sockaddr_in myaddr;
  int opt = 1;
  int s;

  s = socket (AF_INET, SOCK_STREAM, 0);
  if (s == -1) fatal ("socket");
  /* The fallback fd_set defined at the top of the file is a single int
     bitmask, so valid descriptors are 0 .. 8*sizeof(int)-1.  The old test
     used `>' and let the first out-of-range descriptor through.  */
  if (s >= (int) (8 * sizeof (int))) fatal ("socket fd too high");
  /* Best effort: allow a quick restart while old connections linger in
     TIME_WAIT; failure is not fatal.  */
  setsockopt (s, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof (opt));
  memset (&myaddr, 0, sizeof (myaddr));   /* also clears sin_zero */
  myaddr.sin_family = AF_INET;
  myaddr.sin_port = htons (portnum);
  myaddr.sin_addr.s_addr = htonl (INADDR_ANY);
  if (bind (s, (struct sockaddr *) &myaddr, sizeof (myaddr)) == -1)
    fatal ("bind");
  if (listen (s, 8) == -1) fatal ("listen");
  mysock = s;
}

/* A chunk of the key space, start..end, kept in singly linked lists.
   NOTE: `elem' is a pointer type (struct elem *).
   The meaning of `date' depends on the list.  In run_list it is the
   hand-out time in seconds (get_chunk).  In todo_list and done_list it is
   the chunk's start point, so those lists are ordered by position.  */
typedef struct elem {
  struct elem *next;
  unsigned long start;
  unsigned long end;
  unsigned long date;
} *elem;

/* Each list has a dummy head node; real entries start at head.next and are
   kept in ascending `date' order by insert ().  */
struct elem run_head = { NULL, 0, 0, 0 };
struct elem todo_head = { NULL, 0, 0, 0 };
struct elem done_head = { NULL, 0, 0, 0 };
elem run_list = &run_head;    /* chunks handed out, not yet reported */
elem todo_list = &todo_head;  /* chunks still to hand out */
elem done_list = &done_head;  /* finished chunks not yet contiguous with cur_point */

elem new (void)
{
  elem result = (elem) malloc (sizeof (struct elem));
  if (result == NULL) fatal ("malloc");
  return result;
}

/* Return the first real entry of LIST (LIST is the dummy head node), or
   NULL when the list is empty.  */
elem first (elem list)
{
  elem head = list;
  elem front = head->next;

  return front;
}

/* Return the entry of LIST whose bounds are exactly ST..ND, or NULL if
   there is none.  */
elem find (elem list, unsigned long st, unsigned long nd)
{
  elem node;

  for (node = list->next; node != NULL; node = node->next)
    if (node->start == st && node->end == nd)
      return node;
  return NULL;
}

/* Link E into LIST in ascending `date' order: E goes immediately before the
   first entry whose date is >= its own (or at the end).  */
void insert (elem list, elem e)
{
  elem *link = &list->next;

  while (*link != NULL && (*link)->date < e->date)
    link = &(*link)->next;
  e->next = *link;
  *link = e;
}

/* Unlink E from LIST if it is there; E itself is not freed.  */
void delete (elem list, elem e)
{
  elem *link;

  for (link = &list->next; *link != NULL; link = &(*link)->next){
    if (*link == e){
      *link = e->next;
      return;
    }
  }
}

/* Disabled (compiled out by #if 0): would call F on every entry of LIST.
   Nothing in this file can use it while it is disabled.  */
#if 0
void enumerate (elem list, void (*f)(elem e))
{
  elem prev, cur;

  prev = list;
  while (1){
    cur = prev->next;
    if (cur == NULL) return;
    (*f) (cur);
    prev = cur;
  }
}
#endif

/* What the master knows about one slave, keyed by its address.  */
struct machine {
  unsigned long addr;    /* IPv4 address, network byte order */
  unsigned long date;    /* time (s) of the last status report; 0 if none yet */
  unsigned long speed;   /* speed sent with the last "g" request */
  unsigned long nproc;   /* process count sent with the last "s" report */
  char status;           /* last status letter; 'r' counts toward total speed */
};

struct machine slaves [MAX_SLAVES];   /* known slaves; entries 0 .. num_slaves-1 are used */
int num_slaves;

void update_slave (unsigned long addr, unsigned long date, char status,
		   unsigned long speed, unsigned long nproc)
{
  int i;
  for (i = 0; i < num_slaves; i++){
    if (slaves [i].addr == addr){
      if (date != 0) slaves [i].date = date;
      if (status != '?') slaves [i].status = status;
      if (speed != 0) slaves [i].speed = speed;
      if (nproc != 0) slaves [i].nproc = nproc;
      return;
    }
  }
  if (num_slaves >= MAX_SLAVES) return;
  slaves [num_slaves].addr = addr;
  slaves [num_slaves].date = date;
  slaves [num_slaves].status = status;
  slaves [num_slaves].speed = speed;
  slaves [num_slaves].nproc = nproc;
  ++num_slaves;
}

/* get_chunk () results: Interval = a chunk was assigned; Wait = nothing left
   to hand out but chunks are still running; Done = all work finished.  */
enum chunk_reply { Interval, Wait, Done };

/* Assign the next chunk of work to the slave at CLIENTADDR (network byte
   order), which reports SPEED; the speed is recorded in slaves [].
   The chunk length is SPEED * CHUNK_TIME / 4096, at least 1 and at most
   end_point - start_point.  The front of the first to-do chunk is split
   off when that chunk is larger.
   On Interval, *P_START and *P_END receive the chunk, which moves to
   run_list stamped with the current time.
   Returns Interval, Wait (nothing to hand out, but chunks are still
   running) or Done.  */
int get_chunk (unsigned long clientaddr, unsigned long speed,
	       unsigned long *p_start, unsigned long *p_end)
{
  /* double, not float: a float loses precision on large ranges, and
     converting a float that rounded up past the range of unsigned long back
     to an integer is undefined behaviour.  The clamp is done so that the
     final conversion is always in range.  */
  double tg = (double) speed * CHUNK_TIME / 4096;   /* 5 mn of work */
  unsigned long range = end_point - start_point;
  unsigned long togive;
  elem e = first (todo_list);

  if (tg < 1) tg = 1;
  togive = (tg >= (double) range) ? range : (unsigned long) tg;
  update_slave (clientaddr, 0, '?', speed, 0);
  if (e == NULL){
    /* no more work to hand out */
    return (first (run_list) == NULL) ? Done : Wait;
  }
  Verb2 ("first chunk to do: from %lx to %lx\n", e->start, e->end);
  if (e->end - e->start > togive){
    /* Split: hand out the front part, keep the rest in todo_list.  */
    elem e2 = new ();
    e2->start = e->start;
    e2->end = e->start + togive;
    e->start = e2->end;
    e->date = e->start;          /* todo_list is ordered by start */
    if (e2->end > end_point || e2->end <= e2->start){
      e2->end = end_point;
      delete (todo_list, e);
      free (e);
    }
    e = e2;
  }else{
    delete (todo_list, e);
  }
  /* e is what we want to give to the process */
  *p_start = e->start;
  *p_end = e->end;
  e->date = get_date ();
  insert (run_list, e);
  return Interval;
}

/* Note a status report from CLIENTADDR: stamp the slave with the current
   time and store STATUS and NPROC; its speed is left untouched.  */
void get_status (unsigned long clientaddr, char status, unsigned long nproc)
{
  unsigned long now = get_date ();

  update_slave (clientaddr, now, status, 0, nproc);
}

/* Record the outcome the slave at CLIENTADDR reported for chunk ST..ND.
   LEN bytes at BUF are a found result; LEN == 0 means nothing was found.

   A found result is appended to the file "results" and stops the search:
   `found' is set and all chunks still waiting in todo_list are dropped.
   If ST..ND is a chunk in run_list, three more things happen: it is logged
   to "log" (marked " *" when something was found); it moves to done_list;
   and cur_point advances over every finished chunk now contiguous with it.
   A report for a chunk not in run_list changes no list state.  That
   happens, for example, when periodic () has already reassigned the
   chunk.  */
void get_result (unsigned long clientaddr, unsigned long st, unsigned long nd,
		 unsigned long len, char *buf)
{
  elem e = find (run_list, st, nd);
  elem pending;
  FILE *f;

  if (len > 0){
    /* Saved even when the chunk is no longer in run_list: a late report
       from a slave whose chunk was reassigned must not lose the answer.  */
    f = fopen ("results", "a");
    if (f == NULL) fatal ("fopen");
    if (fwrite (buf, 1, len, f) != len) fatal ("fwrite");
    fclose (f);
    log ("got result", clientaddr);
    found = 1;
    /* Stop handing out work, freeing the pending chunks instead of
       leaking them.  */
    while ((pending = first (todo_list)) != NULL){
      delete (todo_list, pending);
      free (pending);
    }
  }
  if (e == NULL) return;
  f = fopen ("log", "a");
  if (f != NULL){
    fprintf (f, "%8.8lx %8.8lx%s\n", st, nd, (len != 0) ? " *" : "");
    fclose (f);
  }
  if (len > 0) cur_point = e->start;
  delete (run_list, e);
  e->date = e->start;            /* done_list is ordered by start */
  insert (done_list, e);
  /* Advance cur_point across the finished chunks that now touch it.  */
  while ((e = first (done_list)) != NULL && e->start <= cur_point){
    cur_point = e->end;
    delete (done_list, e);
    free (e);
  }
}

/* Jump target used to abandon a client exchange when PROTO_TIME expires.
   It is a sigjmp_buf saved together with the signal mask.  glibc's setjmp ()
   does not save the mask, so longjmp ()ing out of the SIGALRM handler would
   leave SIGALRM blocked.  Every later protocol time-out would then be lost,
   letting a single stalled client hang the master forever.  */
sigjmp_buf jump_timeout;

/* SIGALRM handler: abort the exchange in progress by jumping back into
   protocol ().  SIG is unused; it gives the handler the type signal ()
   expects.  */
void alarm_handler (int sig)
{
  (void) sig;
  siglongjmp (jump_timeout, 1);
}

/* Serve one client on socket FD until it disconnects, sends something
   invalid, or PROTO_TIME seconds elapse.  CLIENTADDR is the peer address in
   network byte order; "g" and "s" commands may override it with an address
   given in hex (host order).

   Commands, numbers in hex:
     g SPEED [ADDR]         ask for work; answered "I start end len" followed
                            by len bytes of glob, "W wait", or "F"
     s STATUS [ADDR NPROC]  status report (no reply)
     r START END LEN        result for START..END, followed by LEN raw bytes

   FD is closed (through its stdio streams) before returning.  */
void protocol (int fd, unsigned long clientaddr)
{
  int c;
  unsigned long param1 = 0, param2 = 0, param3 = 0;
  int res;
  /* volatile: assigned after sigsetjmp () and freed after a time-out.  */
  char * volatile buf = NULL;
  FILE *in, *out;
  char paramc;

  in = fdopen (fd, "r");
  out = fdopen (fd, "w");
  if (in == NULL || out == NULL) fatal ("fdopen");
  signal (SIGALRM, alarm_handler);
  if (sigsetjmp (jump_timeout, 1) == 0){
    alarm (PROTO_TIME);
    while (1){
      c = getc (in);
      switch (c){
      case 'g':
	param1 = 0;      /* speed stays 0 ("unknown") if it cannot be parsed */
	fscanf (in, " %lx", &param1);
	c = getc (in);   /* eat the end-of-line, or the space before ADDR */
	if (c == ' '){
	  fscanf (in, "%lx", &clientaddr);
	  getc (in);
	  clientaddr = htonl (clientaddr);
	}
	Verb2 ("got command: g %lx %lx\n", param1, clientaddr);
	res = get_chunk (clientaddr, param1, &param2, &param3);
	switch (res){
	case Interval:
	  fprintf (out, "I %lx %lx %lx\n", param2, param3,
		   (unsigned long) globlen);
	  fwrite (glob, 1, globlen, out);
	  break;
	case Wait:
	  /* %lx needs an unsigned long; WAIT_TIME is an int constant.  */
	  fprintf (out, "W %lx\n", (unsigned long) WAIT_TIME);
	  break;
	case Done: default:
	  fprintf (out, "F\n");
	  break;
	}
	fflush (out);
	break;
      case 's':
	paramc = '?';    /* '?' means "status unknown" to update_slave () */
	fscanf (in, " %c", &paramc);
	c = getc (in);   /* eat the end-of-line */
	param1 = 1;
	if (c == ' '){
	  fscanf (in, "%lx %lx", &clientaddr, &param1);
	  getc (in);
	  clientaddr = htonl (clientaddr);
	}
	Verb2 ("got command: s %c %lx %lx\n", paramc, clientaddr, param1);
	get_status (clientaddr, paramc, param1);
	break;
      case 'r':
	res = fscanf (in, " %lx %lx %lx", &param1, &param2, &param3);
	Verb2 ("got command: r %lx %lx %lx\n", param1, param2, param3);
	getc (in);   /* eat the end-of-line */
	if (res != 3){
	  log ("ill-formed result code", clientaddr);
	  goto protocol_done;
	}
	if (param3 != 0){
	  Verb2 ("got result, size = %lu\n", param3);
	  buf = malloc (param3);
	  if (buf == NULL){
	    log ("out of memory for result string", clientaddr);
	    goto protocol_done;
	  }
	  if (fread (buf, 1, param3, in) != param3){
	    log ("truncated result", clientaddr);
	    goto protocol_done;
	  }
	}
	get_result (clientaddr, param1, param2, param3, buf);
	free (buf);      /* get_result () does not keep the buffer */
	buf = NULL;
	break;
      case '\n': case '\r': case ' ': break;
      case EOF: goto protocol_done;
      default:
	log ("protocol error", clientaddr);
	goto protocol_done;
      }
    }
  protocol_done: ;
  }else{
    Verb1 ("time-out in protocol\n");
  }
  alarm (0);
  free (buf);    /* result buffer abandoned by an error or a time-out */
  /* Both streams share FD.  Close the output stream first so anything
     still buffered is flushed while FD is open; the second fclose then
     only releases its stream (its close of FD fails harmlessly).  */
  fclose (out);
  fclose (in);
}

char cur_ckp = '0';
unsigned long last_check = 0;

void periodic (void)
{
  elem e, e2;
  FILE *ckp_file;
  unsigned long total_speed;
  unsigned long now;
  int i;

  now = get_date ();
  if (now - last_check < PERIOD_TIME) return;
  last_check = now;
  while (1){
    e = first (run_list);
    if (e == NULL || now - e->date <= REASSIGN_TIME) break;
    Verb2 ("reassigning interval: %lx..%lx\n", e->start, e->end);
    delete (run_list, e);
    e2 = new ();
    e2->start = e->start;
    e2->end = e->end;
    e2->date = e2->start;
    insert (todo_list, e2);
  }
  ckpname [4] = cur_ckp;
  ++cur_ckp;
  if (cur_ckp > '9') cur_ckp = '0';
  Verb2 ("writing checkpoint to file %s\n", ckpname);
  ckp_file = fopen (ckpname, "w");
  if (ckp_file == NULL){
    log ("cannot write checkpoint file", 0);
  }else{
    fprintf (ckp_file, "%lx %lx\n", cur_point, end_point);
    fclose (ckp_file);
  }
  /* display stats and slave list */
  {
    time_t date = get_date ();
    printf ("*************************** %s", ctime (&date));
  }
  printf ("name                 address         status    speed processes\n");
  total_speed = 0;
  for (i = 0; i < num_slaves; i++){
    char isdead;
    unsigned long a;
    struct hostent *host;
    char *name;
    
    a = ntohl (slaves [i].addr);
    host = gethostbyaddr (&(slaves [i].addr), 4, AF_INET);
    if (host != NULL){
      name = host->h_name;
    }else{
      name = "?";
    }
    isdead = ' ';
    if (now - slaves [i].date > DEAD_TIME) isdead = '*';
    if (slaves [i].status == 'r' && isdead == ' '){
      total_speed += slaves [i].speed * slaves [i].nproc;
    }
    printf ("%-20.20s %3.3d.%3.3d.%3.3d.%3.3d   %c%c   %8d   %3d\n",
	    name, (a>>24)&255, (a>>16)&255, (a>>8)&255, a&255,
	    slaves [i].status, isdead, slaves [i].speed, slaves [i].nproc);
  }

  {
    float avg_speed;
    printf ("\ntotal peak speed = %lu\n", total_speed);
    if (now - start_date != 0){
      avg_speed = 4096.0 * (cur_point - start_point) / (now - start_date);
      printf ("average speed = %.0f\n", avg_speed);
    }
    printf ("current point = %lx\n", cur_point);
    printf ("%.1f%% done;      ", cur_point * 100.0 / end_point);
    if (avg_speed > 1){
      printf ("estimated time to completion = %.1f days\n",
	      4096.0 * (end_point - cur_point) / avg_speed / 86400);
    }else{
      printf ("\n");
    }
    if (found) printf ("**** result found ****\n");
  }
}

/* Main server loop.  Wait up to SELECT_TIME seconds for a connection on
   mysock and serve it with protocol ().  Then call periodic (), which itself
   acts at most every PERIOD_TIME seconds.  Never returns; exits via
   fatal () on an unexpected select () error.  */
void dispatch (void)
{
  fd_set readfds;
  struct timeval timeout;
  int res;
  struct sockaddr_in clientaddr;
  socklen_t clientaddrlen;
  int clientsocket;

  while (1){
    Verb1 ("waiting for connection\n");
    FD_ZERO (&readfds);
    FD_SET (mysock, &readfds);
    timeout.tv_sec = SELECT_TIME;
    timeout.tv_usec = 0;
    /* Only mysock is watched, so mysock + 1 is the correct nfds.  */
    res = select (mysock + 1, (void *) &readfds, NULL, NULL, &timeout);
    if (res == 1 && FD_ISSET (mysock, &readfds)){
      clientaddrlen = sizeof (clientaddr);
      clientsocket = accept (mysock, (struct sockaddr *) &clientaddr,
			     &clientaddrlen);
      if (clientsocket == -1){
	log ("accept failed", 0);
	continue;
      }
      Verb2 ("got connection from %lx\n",
	     (unsigned long) clientaddr.sin_addr.s_addr);
      /* protocol () closes clientsocket itself (by fclose-ing its stdio
	 streams), so it must not be closed a second time here.  */
      protocol (clientsocket, clientaddr.sin_addr.s_addr);
    }else if (res != 0){
      if (errno != EINTR) fatal ("select");
    }
    Verb2 ("doing periodic action\n");
    periodic ();
  }
}

/* Seed todo_list with the single chunk ST..ND.  Its date is its start
   point, matching how todo_list is ordered elsewhere.  */
void init_lists (unsigned long st, unsigned long nd)
{
  elem whole = new ();

  whole->date = st;
  whole->start = st;
  whole->end = nd;
  insert (todo_list, whole);
}

/* Parse the command line, set up the search (a fresh hex range, or a
   recovered checkpoint with -r), open the listening socket and serve
   forever.

   usage: master [-v|-V] socket begin end
          master [-v|-V] socket -r                                         */
int main (int argc, char **argv)
{
  int portnum;

  if (argc > 1){
    if (!strcmp (argv [1], "-v")){
      verbose = 1;
      --argc;
      ++argv;
    }else if (!strcmp (argv [1], "-V")){
      verbose = 2;
      --argc;
      ++argv;
    }
  }
  if (argc < 2) usage ();
  portnum = 0;
  sscanf (argv [1], "%d", &portnum);
  if (portnum <= 0 || portnum > 65535) usage ();
  if (argc == 3 && !strcmp (argv [2], "-r")){
    recover ();
  }else{
    char i;

    /* Validate before deleting anything: a mistyped command line must
       neither read a missing (NULL) argv entry nor destroy the existing
       checkpoints.  */
    if (argc != 4
	|| sscanf (argv [2], "%lx", &start_point) != 1
	|| sscanf (argv [3], "%lx", &end_point) != 1
	|| end_point < start_point)
      usage ();
    for (i = '0'; i <= '9'; i++){
      ckpname [4] = i;
      unlink (ckpname);
    }
  }
  Verb1 ("main: start = %lx; end = %lx\n", start_point, end_point);
  init_socket (portnum);
  init_lists (start_point, end_point);
  start_date = get_date ();
  cur_point = start_point;
  dispatch ();
  return 0;
}

