/*
 * tcp_sync.c
 * june 1995
 * joseph aadler and sandeep gupta
 *
 * Copyright (C) 1995 Massachusetts Institute of Technology
 *
 * Permission to use, copy, modify, distribute, and sell this software
 * and its documentation for any purpose is hereby granted without
 * fee, provided that the above copyright notice appear in all copies
 * and that both that copyright notice and this permission notice
 * appear in supporting documentation. The author makes no
 * representations about the suitability of this software for any
 * purpose. It is provided "as is" without express or implied
 * warranty.
 *
 * THE AUTHORS DISCLAIM ALL WARRANTIES WITH REGARD TO THIS SOFTWARE,
 * INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS, IN
 * NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, INDIRECT OR
 * CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS
 * OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT,
 * NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
 * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 *
 * RCS $Id: tcp_sync.c,v 1.12 1995/08/22 21:27:49 tuna Exp $
 */

/*
 * This file contains synchronization functions specific to the TCP/UNIX
 * version of CRL.  It is included by sync.c.  Among other things,
 * the region barriers, broadcasts, and reduction functions are
 * implemented here.
 */

#include "crl_int.h"

/* Two-argument maximum and minimum.  Both are plain macros, so each
 * argument may be evaluated twice: do not pass expressions that have
 * side effects.  When the arguments compare equal (or unordered), Max
 * yields its second argument and Min yields its first.
 */
#define Max(a,b) ((a) > (b) ? (a) : (b))
#define Min(a,b) ((a) > (b) ? (b) : (a))

/* One node of the singly linked queue that buffers incoming broadcasts. */
typedef struct bdcst
{
  struct bdcst *next; /* pointer to the next Bdcst structure, or NULL if
		       * this is the last node in the queue
		       */
  int lngth;           /* length of broadcast message, in bytes */
  void *buf;           /* pointer to broadcast message (heap copy made by
			* bdcst_handler, released by socket_broadcast_recv)
			*/
} Bdcst;

/* Queue of incoming broadcast messages.  The list always ends in an
 * empty placeholder node, so `current->next == NULL` means that no
 * message is waiting.
 */
static Bdcst   *current; /* head of the queue: oldest pending broadcast */
static Bdcst   *last;    /* trailing placeholder; assigned by bdcst_handler */

/* Used for barrier */
static int      barrier_count = 0; /* arrivals counted so far: bumped by
				    * barrier_begin and by rgn_barrier for
				    * the local node
				    */
static int      barrier_dropped;   /* set to 1 by barrier_end; polled by
				    * rgn_barrier on nodes other than 0
				    */

/* Used for reductions */
typedef double (*red_fun)();       /* reduction operator, invoked as
				    * fun(value, accumulator)
				    */

static double  *reduction_values;    /* contribution of each node, indexed
				      * by sender address
				      */
static red_fun *reduction_functions; /* operator each node asked for,
				      * indexed by sender address
				      */
static int      num_responses;       /* contributions counted for the
				      * reduction in progress
				      */


/* Defined further down; both take (length in bytes, message buffer). */
static void socket_broadcast_send(int,void *);
static void socket_broadcast_recv(int,void *);

void init_sync(void)
{
  /*  Set up the file-level state that the reduction and broadcast
   *  routines below depend on.
   */

  /* one slot per node for the value and the operator of a reduction */

  reduction_values    =
    (double *) safe_malloc(sizeof(*reduction_values) * crl_num_nodes);
  reduction_functions =
    (red_fun *) safe_malloc(sizeof(*reduction_functions) * crl_num_nodes);

  /* the broadcast queue starts out as a single node with no successor,
   * which is how the rest of the file recognizes an empty queue
   */

  current       = (Bdcst *) safe_malloc(sizeof(*current));
  current->next = NULL;
}

void exit_sync(void)
{
  /* Intentionally empty for now: nothing allocated by init_sync is
   * released here.  Cleanup code is meant to be added later.
   */
}

static void bdcst_handler(int sender, int msg_lngth, char *buffer) {
  /* deals with a broadcast message */

  int oldmask;
  oldmask = crl_sigblock(sigmask(SIGIO));

  if (current->next == NULL) { /* queue is empty */
    current->next       = (Bdcst *) safe_malloc(sizeof(Bdcst));
    current->next->next = NULL;
    current->lngth      = msg_lngth;
    current->buf        = safe_malloc(msg_lngth);
    memcpy(current->buf, buffer, msg_lngth);
    last                = current->next;
    
  } else { /* queue is not empty */
    last->next          = (Bdcst *) safe_malloc(sizeof(Bdcst));
    last->lngth         = msg_lngth;
    last->buf           = safe_malloc(msg_lngth);
    memcpy(last->buf, buffer, msg_lngth);
    last                = last->next;
    last->next          = NULL;

  }
  crl_sigsetmask(oldmask);
}

static void socket_broadcast_send(int lngth, void *buf) {
  /*  Deliver the lngth-byte message in buf to every node except this
   *  one, using bdcst_handler as the handler on the receiving side.
   */

  char *payload = (char *) buf;
  int   node;

  for (node = 0; node < crl_num_nodes; node++) {
    if (node == crl_self_addr) {
      continue; /* never send to ourselves */
    }
    tcp_am_send(node, bdcst_handler, lngth, payload);
  }
}

static void socket_broadcast_recv(int lngth, void *buf) {
  /*  sleep until the message arrives,  stick the message into the
   *  appropriate buffer, unless, of course, it's too big
   */

  int    oldmask;
  Bdcst *prev;

  oldmask = crl_sigblock(sigmask(SIGIO));

  while (current->next == NULL) { /* while the queue is empty, */
    sigpause(0);
  }

  /*  once the queue is no longer empty, process only the first broadcast
   *  message in the queue.
   */
  assert(current->lngth <= lngth);
  memcpy(buf, current->buf, lngth);
  safe_free(current->buf);
  prev    = current;
  current = current->next;
  safe_free(prev);

  crl_sigsetmask(oldmask);
}

static void reduction_handler(int sender, int lngth, char *buffer) {
  /*  Handler that socket_reduce_dgeneral hands to tcp_am_send.  Records
   *  one node's contribution to the reduction in progress.
   *
   *  The message layout is the one packed by socket_reduce_dgeneral:
   *  a red_fun function pointer followed immediately by a double.
   *  Both are extracted with memcpy because nothing guarantees that
   *  buffer (or buffer + sizeof(red_fun)) is suitably aligned.
   *
   *  sender indexes the per-node tables; lngth is not used.
   */

  double  arg;
  red_fun fun;

  memcpy(&fun, buffer, sizeof(fun));
  memcpy(&arg, buffer + sizeof(fun), sizeof(arg));

  reduction_values[sender]    = arg;
  reduction_functions[sender] = fun;
  num_responses++;
}

static double socket_reduce_dgeneral (double arg, red_fun fun) {
  /*  Combine one double from every node with the binary operator fun
   *  and return the combined value on every node.
   *
   *  Initial algorithm: node 0 receives all messages, returns the answer to
   *  each node.
   *
   *  fun is a pointer to the desired reduction function; it is packed
   *  into all messages so that node 0 can check that every node asked
   *  for the same reduction.
   */

  int           i;
  int           oldmask;
  double        ans;
  char          msg[sizeof(double)+sizeof(red_fun)];

  if (crl_self_addr == 0) {
    /* wait until all responses are collected from other nodes */
    oldmask = crl_sigblock(sigmask(SIGIO));
    ans     = arg;
    num_responses++;   /* count our own contribution */
    while (num_responses < crl_num_nodes) {
      /* block until a message arrives */
      sigpause(0);
    }

    for (i = 1; i < crl_num_nodes ; i++) {
      /* check to see that all nodes are doing the same reduction: */
      assert(fun == reduction_functions[i]);
      ans = (fun)(reduction_values[i], ans);
    }

    /*  Reset the counter before releasing the other nodes, so that a
     *  contribution to the next reduction can never be wiped out by
     *  the reset (rgn_barrier resets barrier_count in the same order).
     */
    num_responses = 0;

    /*  now that we have an answer, send it as a broadcast message to
     *  all nodes:
     */
    socket_broadcast_send(sizeof(double), (void *) &ans);
    crl_sigsetmask(oldmask);

  } else {
    /*  send out a message to node 0, then wait for node 0 to
     *  broadcast the answer
     */

    /*  msg is a plain char array with no alignment guarantee for a
     *  function pointer or a double, so pack it with memcpy instead of
     *  storing through casted pointers.
     */
    memcpy(msg, &fun, sizeof(fun));
    memcpy(msg + sizeof(fun), &arg, sizeof(arg));

    tcp_am_send(0, reduction_handler, sizeof(double)+sizeof(red_fun) , msg);

    socket_broadcast_recv(sizeof(double), (void *) &ans);
  }
  return ans;
}

/* Reduction operator for rgn_reduce_dadd: the sum of its two arguments. */
static double addfun(double a, double b) {
  double sum = a + b;

  return sum;
}

/* Global sum: returns the result of combining the `arg` values supplied
 * by all nodes with addfun.
 */
double rgn_reduce_dadd (double arg) {
  return socket_reduce_dgeneral(arg, addfun);
}

/* Reduction operator for rgn_reduce_dmax: returns a when a > b, and b
 * in every other case.
 */
static double maxfun (double a, double b) {
  if (a > b) {
    return a;
  }
  return b;
}

/* Global maximum: returns the result of combining the `arg` values
 * supplied by all nodes with maxfun.
 */
double rgn_reduce_dmax(double arg) {
  return socket_reduce_dgeneral(arg, maxfun);
}

/* Reduction operator for rgn_reduce_dmin: returns b when a > b, and a
 * in every other case.
 */
static double minfun (double a, double b) {
  if (a > b) {
    return b;
  }
  return a;
}

/* Global minimum: returns the result of combining the `arg` values
 * supplied by all nodes with minfun.
 */
double rgn_reduce_dmin (double arg) {
  return socket_reduce_dgeneral(arg, minfun);
}

/* Sending half of a broadcast: ship the nbytes-byte message in buf to
 * every other node, then wait in a barrier.
 */
void rgn_bcast_send(int nbytes, void *buf) {
  socket_broadcast_send(nbytes, buf);

  rgn_barrier();
}

/* Receiving half of a broadcast: wait for the next queued broadcast
 * message, copy it into buf (capacity nbytes), then wait in a barrier.
 */
void rgn_bcast_recv(int nbytes, void *buf) {
  socket_broadcast_recv(nbytes, buf);

  rgn_barrier();
}

/* Handler that nodes other than 0 hand to tcp_am_send (destination 0)
 * when they reach a barrier: counts one more arrival.
 *
 * node is the first argument supplied by the messaging layer; lngth and
 * buffer are unused.  The parameter list matches the other handlers in
 * this file (bdcst_handler, reduction_handler), which receive
 * (sender, length, buffer), so that the call made through the handler
 * pointer agrees with this definition.
 * NOTE(review): the handler type is declared outside this file; confirm
 * against the tcp_am_send prototype.
 */
static void barrier_begin(int node, int lngth, char *buffer)
{
  (void) lngth;
  (void) buffer;
#if defined(CRL_DEBUG)
  printf("Barrier reached by node %d\n",node);
  fflush(stdout);
#else
  (void) node;
#endif
  barrier_count++;
}

/* Handler that node 0 hands to tcp_am_send for every other node once
 * all arrivals are counted: sets barrier_dropped so that the waiting
 * loop in rgn_barrier terminates.
 *
 * node is the first argument supplied by the messaging layer; lngth and
 * buffer are unused.  The parameter list matches the other handlers in
 * this file (bdcst_handler, reduction_handler), which receive
 * (sender, length, buffer), so that the call made through the handler
 * pointer agrees with this definition.
 * NOTE(review): the handler type is declared outside this file; confirm
 * against the tcp_am_send prototype.
 */
static void barrier_end(int node, int lngth, char *buffer)
{
  (void) lngth;
  (void) buffer;
#if defined(CRL_DEBUG)
  printf("Barrier ended by node %d\n",node);
  fflush(stdout);
#else
  (void) node;
#endif
  barrier_dropped = 1;
}

/* Barrier across all nodes, coordinated by node 0.
 *
 * Every other node sends a barrier_begin message to node 0 and then
 * sleeps until barrier_end has set barrier_dropped.  Node 0 counts
 * itself plus one barrier_begin per peer, resets the count, and then
 * sends barrier_end to each peer.
 */
void rgn_barrier(void)
{
  char token[2] = "t";  /* two-byte message body sent with each message */
  int  saved_mask;
  int  peer;

  if (crl_self_addr != 0) {
    /* clear the release flag before node 0 can possibly answer */
    barrier_dropped = 0;

    /* Send message to node 0 */
#ifdef CRL_DEBUG
    printf("Sending begin message\n");
    fflush(stdout);
#endif
    tcp_am_send(0, barrier_begin, 2, token);

    /* Wait for everybody else */
#ifdef CRL_DEBUG
    printf("Waiting for others\n");
    fflush(stdout);
#endif
    saved_mask = crl_sigblock(sigmask(SIGIO));

    while (barrier_dropped == 0) {
      sigpause(0);
    }
  }
  else {
    saved_mask = crl_sigblock(sigmask(SIGIO));

    /* Count self */
    barrier_count += 1;

#ifdef CRL_DEBUG
    printf("Barrier count is %d\n",barrier_count);
    fflush(stdout);
#endif

    /* Wait for everybody else */
    while (barrier_count < crl_num_nodes) {
#ifdef CRL_DEBUG
      printf("Barrier count is %d\n",barrier_count);
      fflush(stdout);
#endif
      sigpause(0);
    }

#ifdef CRL_DEBUG
    printf("Barrier dropped.  Barrier count is %d\n",barrier_count);
    fflush(stdout);
#endif

    /* start the next barrier from zero before anybody is released */
    barrier_count = 0;

    /* Tell everybody to continue */
    for (peer = 1; peer < crl_num_nodes; peer++) {
#ifdef CRL_DEBUG
      printf("Told node %d to continue\n",peer);
      fflush(stdout);
#endif
      tcp_am_send(peer, barrier_end, 2, token);
    }
  }

  /* common tail of both roles: restore the signal mask */
  crl_sigsetmask(saved_mask);
#ifdef CRL_DEBUG
  printf("Done with barrier\n");
  fflush(stdout);
#endif
}
