/**************************************************************************
*                                                                         *
*  Author      : Dr. Thomas Brandes, GMD, SCAI.LAB                        *
*                                                                         *
*  Copyright   : GMD St. Augustin, Germany                                *
*  Date        : Mar 98                                                   *
*  Last Update : Mar 98                                                   *
*                                                                         *
*  This Module is part of the DALIB                                       *
*                                                                         *
*  Module      : loc_section.m4                                           *
*                                                                         *
*  Function:  managing descriptors for local parts of array sections      *
*                                                                         *
*  Note:  local_range in section_info is not sufficient for arbitrary     *
*         distributions                                                   *
*                                                                         *
***************************************************************************/

#include "dalib.h"

#undef DEBUG

/**************************************************************************
*                                                                         *
*   LocalDimInfo : describes information about one index                  *
*                                                                         *
*   -1 : undefined                                                        *
*    0 : single element                                                   *
*    1 : slice or range                                                   *
*    2 : arbitrary indexes                                                *
*                                                                         *
**************************************************************************/

typedef struct
{
    /* what this dimension describes:
       -1 undefined, 0 single element, 1 range, 2 list of indexes */
    int kind;

    /* kind 0: range[0] is the selected element
       kind 1: range[0] lower bound, range[1] upper bound, range[2] stride */
    int range [3];

    /* kind 2: number of entries in indexes */
    int no_indexes;

    /* kind 2: the index values; only the pointer is kept in this record */
    int *indexes;

} LocalDimInfo;

typedef struct
{
    /* descriptor of the array this local part belongs to */
    array_info array_id;

    /* one entry per array dimension; a descriptor allocated by
       dalib_local_dsp_create only provides storage for the first
       'rank' entries (see dalib_local_dsp_size)                    */
    LocalDimInfo dimensions [MAX_DIMENSIONS];

} LocalRecord;

/* handle used by the routines of this module */
typedef LocalRecord *local_info;

/**************************************************************************
*                                                                         *
*                                                                         *
*                                                                         *
***************************************************************************/

/* Print a one-line description of the local descriptor local_id on
   stdout: for every dimension of the underlying array the local
   selection (element, range or index list) followed by the local
   and global size of that array dimension.                          */

void dalib_print_local_info (local_id)

local_info local_id;

{ array_info array_id;

  LocalDimInfo  *local_dims;
  DimInfo       *array_dims;

  int i, j, rank;
  int kind;

  array_id = local_id->array_id;
  rank = array_id->rank;

  /* the descriptor is a pointer: print it with %p, passing a pointer
     to %d is a format/argument mismatch                              */

  printf ("%d: local part (dsp=%p,rank=%d) is ",
           pcb.i, (void *) local_id, rank);

  local_dims = local_id->dimensions;
  array_dims = array_id->dimensions;

  for (i=0; i<rank; i++)

    { kind = local_dims->kind;

      if (kind == 0)

         printf ("%d", local_dims->range[0]);

       else if (kind == 1)

         printf ("%d:%d:%d", local_dims->range[0],
                  local_dims->range[1], local_dims->range[2]);

       else if (kind == 2)

         { for (j=0; j<local_dims->no_indexes; j++)
             printf ("%d,", local_dims->indexes[j]);
         }

       else

         { /* undefined dimension (kind == -1, as set for mapped
              dimensions by dalib_local_section): no_indexes and
              indexes carry no valid data, so do not touch them   */

           printf ("?");
         }

      printf (" of %d:%d:%d in %d:%d ", array_dims->local_size[0],
                array_dims->local_size[1], array_dims->local_size[2],
                array_dims->global_size[0], array_dims->global_size[1]);

      if (i<rank-1) printf (",");

      local_dims++; array_dims++;
    }

  printf (")\n");

} /* dalib_print_local_info */

/**************************************************************************
*                                                                         *
*                                                                         *
*                                                                         *
***************************************************************************/

/* Size in bytes of a LocalRecord that is cut off after the first
   'rank' entries of its dimensions array.                         */

static int dalib_local_dsp_size (rank)

int rank;

{
  int unused_bytes;

  /* bytes of the trailing entries that a descriptor of this rank
     does not need                                                 */

  unused_bytes = (MAX_DIMENSIONS - rank) * sizeof (LocalDimInfo);

  return sizeof (LocalRecord) - unused_bytes;

} /* dalib_local_dsp_size */

/**************************************************************************
*                                                                         *
*  void dalib_local_array (local_info local_id, array_info array_id)      *
*                                                                         *
*  - sets the local descriptor for a full array                           *
*                                                                         *
**************************************************************************/

/* Describe the complete local part of array_id in local_id: every
   dimension becomes a range local_size[0]:local_size[1] with stride 1.
   Only the dimension entries are written; local_id->array_id is not
   modified here.                                                       */

void dalib_local_array (local_id, array_id)

local_info  local_id;
array_info  array_id;

{
  int dim;
  int n_dims = array_id->rank;

  for (dim = 0; dim < n_dims; dim++)
    {
      DimInfo      *src = &array_id->dimensions[dim];
      LocalDimInfo *dst = &local_id->dimensions[dim];

      dst->kind     = 1;                     /* always a range */
      dst->range[0] = src->local_size[0];
      dst->range[1] = src->local_size[1];
      dst->range[2] = 1;
    }

} /* dalib_local_array */

/****************************************************************************
*                                                                           *
*  void dalib_local_section (local_info local_id, section_info section_id)  *
*                                                                           *
*  - sets the local descriptor for a section                                *
*                                                                           *
****************************************************************************/

/* Fill the dimension entries of local_id for the section section_id.
   The rank and the local sizes are taken from the array the section
   refers to.                                                         */

void dalib_local_section (local_id, section_id)

local_info   local_id;
section_info section_id;

{
  array_info owner  = section_id->array_id;   /* array of the section */
  int        n_dims = owner->rank;
  int        dim;

  for (dim = 0; dim < n_dims; dim++)
    {
      DimInfo      *adim = &owner->dimensions[dim];
      SecDimInfo   *sdim = &section_id->dimensions[dim];
      LocalDimInfo *ldim = &local_id->dimensions[dim];

      if (adim->map_flag)
        {
          /* a mapped dimension would need a translation of the global
             range into a local one; the entry is left undefined      */

          ldim->kind = -1;
          continue;
        }

      /* range if the section selects a range, single element otherwise */

      ldim->kind = sdim->is_range ? 1 : 0;

      dalib_intersect_sections (sdim->global_range,
                                adim->local_size,
                                ldim->range);
    }

} /* dalib_local_section */

/**************************************************************************
*                                                                         *
*  void dalib_local_dsp_create (local_info *local_id,                     *
*                               section_info section_id)                  *
*                                                                         *
*  - creating a new descriptor for a local part of a section              *
*  - takes over the information of local range of the section             *
*                                                                         *
**************************************************************************/

/* Allocate a local descriptor for section_id and return it in *local_id.
   The descriptor is sized for the rank of the section's array and each
   dimension is initialised from the local_range stored in the section.
   dalib_local_dsp_free releases a descriptor created here.              */

void dalib_local_dsp_create (local_id, section_id)

local_info   *local_id;
section_info section_id;

{ int i, rank;
  array_info array_id;
  local_info dsp;

  SecDimInfo   *sdims;
  LocalDimInfo *ldims;

  /* a full array descriptor is rejected with its own message */

  if (dalib_is_array_info (section_id))

     { dalib_internal_error ("local_dsp_create: only for sections");
       dalib_stop ();
     }

  if (!dalib_is_section_info (section_id))

     { dalib_internal_error ("local_dsp_create: not a section");
       dalib_stop ();
     }

  array_id = section_id->array_id;
  rank = array_id->rank;

  /* only the first 'rank' dimension entries are allocated */

  dsp = (local_info)
    dalib_malloc (dalib_local_dsp_size (rank), "local_dsp_create");

  dsp->array_id = array_id;

  sdims = section_id->dimensions;
  ldims = dsp->dimensions;

  for (i=0; i<rank; i++)

     { /* single element unless the section has a range here */

       ldims->kind = 0;
       if (sdims->is_range) ldims->kind = 1;

       ldims->range[0] = sdims->local_range[0];
       ldims->range[1] = sdims->local_range[1];
       ldims->range[2] = sdims->local_range[2];
       ldims++; sdims++;
     }

  *local_id = dsp;

#ifdef DEBUG
  /* dsp is a pointer, so it is printed with %p instead of %d */
  printf ("%d: local descriptor (addr = %p) of rank %d created\n", 
           pcb.i, (void *) dsp, rank);
#endif

} /* dalib_local_dsp_create */

/**************************************************************************
*                                                                         *
*  void dalib_local_dsp_free (local_info local_id)                        *
*                                                                         *
*  - frees a descriptor for a local part of a section                     *
*                                                                         *
***************************************************************************/

/* Release the local descriptor local_id. */

void dalib_local_dsp_free (local_id)

local_info local_id;

{
  int rank;
  int n_bytes;

  /* dalib_free is given the descriptor size, which depends on the
     rank of the array the descriptor refers to                     */

  rank    = local_id->array_id->rank;
  n_bytes = dalib_local_dsp_size (rank);

  dalib_free (local_id, n_bytes);

} /* dalib_local_dsp_free */

/**************************************************************************
*                                                                         *
*                                                                         *
*                                                                         *
***************************************************************************/

/* Make dimension dim of local_id the range lb:ub:str.
   dim is counted from 1; it is not checked against the rank. */

void dalib_local_set_range (local_id, dim, lb, ub, str)

local_info local_id;

int dim, lb, ub, str;

{
  LocalDimInfo *entry = &local_id->dimensions[dim - 1];

  entry->kind     = 1;      /* range */
  entry->range[0] = lb;
  entry->range[1] = ub;
  entry->range[2] = str;

} /* dalib_local_set_range */

/**************************************************************************
*                                                                         *
*                                                                         *
*                                                                         *
***************************************************************************/

/* Make dimension dim of local_id the single element val.
   dim is counted from 1; it is not checked against the rank.
   The stride entry range[2] is left as it is.                 */

void dalib_local_set_dim (local_id, dim, val)

local_info local_id;
int dim;
int val;

{
  LocalDimInfo *entry = &local_id->dimensions[dim - 1];

  entry->kind     = 0;      /* single element */
  entry->range[0] = val;
  entry->range[1] = val;

} /* dalib_local_set_dim */

/**************************************************************************
*                                                                         *
*                                                                         *
*                                                                         *
***************************************************************************/

/* Make dimension dim of local_id an index list with N entries.
   dim is counted from 1; it is not checked against the rank.
   The indexes array is not copied, only its address is stored,
   so it has to stay valid as long as the descriptor is used.    */

void dalib_local_set_indirect (local_id, dim, N, indexes)

local_info local_id;
int        dim;
int        N;
int        indexes[];

{
  LocalDimInfo *entry = &local_id->dimensions[dim - 1];

  entry->kind       = 2;    /* index list */
  entry->no_indexes = N;
  entry->indexes    = indexes;

} /* dalib_local_set_indirect */

/**************************************************************************
*                                                                         *
*                                                                         *
*                                                                         *
***************************************************************************/

/* Permute the first n entries of data in place:
   afterwards data[k] holds the former data[perm[k]].
   n must not exceed MAX_DIMENSIONS (size of the scratch copy). */

static void dalib_apply_int_perm (n, data, perm)

int n;
int data[];
int perm[];

{
  int saved [MAX_DIMENSIONS];
  int k;

  /* take a snapshot first, data is overwritten below */

  for (k = 0; k < n; k++)
    saved[k] = data[k];

  for (k = 0; k < n; k++)
    data[k] = saved[perm[k]];

} 

/**************************************************************************
*                                                                         *
*                                                                         *
*                                                                         *
***************************************************************************/

/* Same as dalib_apply_int_perm, but for an array of int pointers:
   afterwards data[k] holds the former data[perm[k]].
   n must not exceed MAX_DIMENSIONS (size of the scratch copy).    */

static void dalib_apply_addr_perm (n, data, perm)

int n;
int *data[];
int perm[];

{
  int *saved [MAX_DIMENSIONS];
  int k;

  /* take a snapshot first, data is overwritten below */

  for (k = 0; k < n; k++)
    saved[k] = data[k];

  for (k = 0; k < n; k++)
    data[k] = saved[perm[k]];

} 

/**************************************************************************
*                                                                         *
*                                                                         *
*                                                                         *
***************************************************************************/

/* Translate one dimension of a local descriptor into addressing data.

   kind 0 (single element): only *first is advanced, no section
                            dimension is produced.
   kind 1 (range)         : *first is advanced to the lower bound and
                            one entry is appended to kind/inc/n.
   kind 2 (index list)    : one entry is appended to kind/inc/n/vals.
   any other kind         : nothing is done.

   Entries are appended at position *srank, which is then incremented
   exactly once.  *zero_section is set to 0 when a range turns out to
   be empty; it is never set to another value here.                    */

static void get_dim (gdim, sdim, size, srank, 
                     kind, inc, n, vals, first, zero_section)

DimInfo *gdim;        /* array dimension, not consulted by this routine */
LocalDimInfo *sdim;   /* local selection in this dimension              */
int size;             /* diff for two successive elements in this dim   */
int *srank;           /* in/out: number of section dimensions so far    */
int inc[], n[];       /* out: increment and extent per section dim      */
int kind[];           /* out: 1 = range, 2 = index list                 */
int *vals[];          /* out: index values, set for kind 2 only         */
int *first, *zero_section;

{
   if (sdim->kind == 0)

      { /* A (....,val,....), only first has to be updated */

        *first += sdim->range[0] * size;

        /* QUESTION: have we to make sure that val is really local */

        return;
      }

   if (sdim->kind == 1)

     { int lb, ub, str, sn;

       /* first  will be address of A (lb1,lb2,...,lbn) */

       lb   = sdim->range[0];
       ub   = sdim->range[1];
       str  = sdim->range[2];

       *first += lb * size;

       sn   = dalib_range_size (lb, ub, str);

       if (sn == 0) *zero_section = 0;     /* empty local section */

       /* An empty range still occupies exactly one section dimension.
          All three entries are filled in either case, so that a later
          permutation never copies unset values and *srank cannot run
          ahead of the number of dimensions.                           */

       kind[*srank] = 1;
       inc [*srank] = size * str;
       n   [*srank] = sn;              /* local section size */

       (*srank)++;

       return;
     }

   if (sdim->kind == 2)

      { /* indirect addressing */

        n   [*srank] = sdim->no_indexes;
        inc [*srank] = size;
        vals[*srank] = sdim->indexes;
        kind[*srank] = 2;

        (*srank)++;
      }

}  /* get_dim */

/**************************************************************************
*                                                                         *
*                                                                         *
*                                                                         *
**************************************************************************/

/* Move the entries of all range dimensions (local_kind[i] == 1) to the
   front of local_N, local_val and local_inc, keeping their relative
   order; *section_rank receives the number of range dimensions.

   local_kind is only read: it is not reordered, so after the call it
   no longer lines up with the three other arrays.                      */

void dalib_split_indirect (section_rank, local_rank, 
                           local_kind, local_N, local_val, local_inc)

int *section_rank;
int local_rank;
int local_kind[];
int local_N[];
int local_inc[];
int *local_val[];

{
  int src;            /* dimension currently looked at          */
  int front = 0;      /* number of range dimensions moved so far */

  for (src = 0; src < local_rank; src++)
    {
      int  tmp_int;
      int *tmp_ptr;

      if (local_kind[src] != 1)
        continue;                     /* not a range, stays behind */

      /* exchange slot 'front' with slot 'src' in all three arrays */

      tmp_int          = local_inc[front];
      local_inc[front] = local_inc[src];
      local_inc[src]   = tmp_int;

      tmp_int          = local_N[front];
      local_N[front]   = local_N[src];
      local_N[src]     = tmp_int;

      tmp_ptr          = local_val[front];
      local_val[front] = local_val[src];
      local_val[src]   = tmp_ptr;

      front++;
    }

  *section_rank = front;
}

/**************************************************************************
*                                                                         *
*                                                                         *
*                                                                         *
**************************************************************************/

/* Build in *new_ddt a ddt for the local part described by local_id.

   new_ddt  : receives the new ddt
   local_id : local descriptor; its array supplies the addressing data
   perm     : permutation applied to the section dimensions (one entry
              per dimension that is not a single element), or a null
              pointer for no permutation

   Steps: collect start offset, increments and extents per dimension
   (get_dim), optionally permute them, move the range dimensions to the
   front (dalib_split_indirect), define the section over the range
   dimensions and add one tensor level per index-list dimension.        */

void dalib_make_perm_local_ddt (new_ddt, local_id, perm)

dd_type    *new_ddt;
local_info local_id;
int        perm[];

{ int local_kind [MAX_DIMENSIONS];
  int local_inc  [MAX_DIMENSIONS];
  int local_N    [MAX_DIMENSIONS];
  int *local_val [MAX_DIMENSIONS];

  array_info array_id;
  int        array_rank;
  char       *array_data;   /* base address, advanced to first element */
  DimInfo    *array_dims;

  int        array_first;
  int        array_total[MAX_DIMENSIONS+1];

  int        i;
  int        zero_section;

  LocalDimInfo *local_dims;
  int local_rank; 
  int section_rank;
  int array_size;

  array_id  = local_id->array_id; 

  dalib_array_addressing (array_id, pcb.i,
                          &array_data, &array_first, array_total);

  array_dims = array_id->dimensions;
  array_rank = array_id->rank;
  array_size = array_id->size;

#ifdef DEBUG
  printf ("%d: make_perm_local_ddt\n", pcb.i);
  dalib_print_local_info (local_id);
  printf ("%d: array size = %d,  first = %d, total = ", 
          pcb.i, array_size, array_first);
  for (i=0; i<=array_rank; i++) printf ("%d ", array_total[i]);
  printf ("\n");
#endif 

  array_first  = -array_first;   /* subtract it from start point  */
  zero_section = 1;              /* not a zero section by default */

  local_dims = local_id->dimensions;
  local_rank = 0;   /* increase only if not fixed element */

  for (i=0; i<array_rank; i++)

    { get_dim (array_dims++, local_dims++, array_total[i], &local_rank, 
               local_kind, local_inc, local_N, local_val,
               &array_first, &zero_section);
    }

  /* array_first now counts elements, array_size scales it to the
     unit array_data is addressed in                               */

  array_data += array_first * array_size;

  if (perm != (int *) 0)

     { /* apply permutation to all the arrays */

       dalib_apply_int_perm  (local_rank, local_kind, perm); 
       dalib_apply_int_perm  (local_rank, local_inc,  perm);
       dalib_apply_int_perm  (local_rank, local_N,    perm);
       dalib_apply_addr_perm (local_rank, local_val,  perm); 

     } /* permutations have been applied */

  if (zero_section == 0)

     { /* empty local part: section of size 0 and rank 0; the increment
          and extent arrays are passed as in the regular call below so
          that both calls supply the same number of arguments           */

       dalib_ddt_def_section (new_ddt, array_data, 0, 0,
                              local_inc, local_N); 
       return;
     }

#ifdef DEBUG
  printf ("make local ddt, first = %d\n", array_first);

  for (i=0; i<local_rank; i++)

    { if (local_kind[i] == 2)
         printf ("dim = %d, indirect, N = %d, inc = %d\n", 
                  i+1, local_N[i], local_inc[i]);
       else
         printf ("dim = %d, range, N = %d, inc = %d\n", 
                  i+1, local_N[i], local_inc[i]);
    }
#endif

   /* Argument order has to follow the definition of
      dalib_split_indirect: (..., local_N, local_val, local_inc).
      There is no prototype, so swapping the int array and the pointer
      array is not diagnosed and makes the routine exchange elements
      of the wrong width.                                              */

   dalib_split_indirect (&section_rank, local_rank, local_kind,
                         local_N, local_val, local_inc); 

#ifdef DEBUG
   printf ("split indirect, section rank = %d\n", section_rank);
#endif

   /* range dimensions are now in front: they define the section */

   dalib_ddt_def_section (new_ddt, array_data, array_size, section_rank,
                          local_inc, local_N);

   /* every remaining (index list) dimension adds one tensor level */

   for (i=section_rank; i<local_rank; i++)

      dalib_ddt_def_tensor (new_ddt, *new_ddt, array_size, local_inc[i],
                            local_N[i], local_val[i]);

   dalib_ddt_set_data (*new_ddt, array_data, array_size);

} /* make_perm_local_ddt */

/**************************************************************************
*                                                                         *
*                                                                         *
*                                                                         *
**************************************************************************/

/* Send the local part described by local_id to process pid, with the
   section dimensions permuted by permutation.  The ddt built for the
   transfer is freed again before returning.                           */

void dalib_local_perm_send (pid, local_id, permutation)

int pid;
local_info local_id;
int permutation [];

{
  dd_type msg_ddt;

  dalib_make_perm_local_ddt (&msg_ddt, local_id, permutation);

  dalib_send_ddt (pid, msg_ddt);

  dalib_ddt_free (msg_ddt);

} /* dalib_local_perm_send */

/**************************************************************************
*                                                                         *
*                                                                         *
*                                                                         *
**************************************************************************/

/* Send the local part described by local_id to process pid without
   permuting its dimensions.  The ddt built for the transfer is freed
   again before returning.                                            */

void dalib_local_send (pid, local_id)

int pid;
local_info local_id;

{
  dd_type msg_ddt;

  /* null pointer: no permutation */

  dalib_make_perm_local_ddt (&msg_ddt, local_id, (int *) 0);

#ifdef DEBUG
  printf ("%d: send local part to %d\n", pcb.i, pid);
  dalib_ddt_print (msg_ddt);
#endif

  dalib_send_ddt (pid, msg_ddt);

  dalib_ddt_free (msg_ddt);

} /* dalib_local_send */

/**************************************************************************
*                                                                         *
*                                                                         *
*                                                                         *
**************************************************************************/

/* Receive data from process pid into the local part described by
   local_id, without permuting its dimensions.  The ddt built for the
   transfer is freed again before returning.                          */

void dalib_local_recv (pid, local_id)

int pid;
local_info local_id;

{
  dd_type target_ddt;

  /* null pointer: no permutation */

  dalib_make_perm_local_ddt (&target_ddt, local_id, (int *) 0);

#ifdef DEBUG
  printf ("%d: receive local part from %d\n", pcb.i, pid);
  dalib_ddt_print (target_ddt);
#endif

  /* third argument 0 is the op selector of dalib_recv_ddt_op;
     NOTE(review): presumably "plain copy, no reduction" - confirm */

  dalib_recv_ddt_op (pid, target_ddt, 0);

  dalib_ddt_free (target_ddt);

} /* dalib_local_recv */


