/*  strassen.c   by C.W.Kessler 06/99 */

#include <fork.h>
#include <syscall.h>
#include <io.h>
#include <assert.h>

/* Matrix descriptor.  Elements are addressed through the per-row pointer
 * table `row` (see the Get/Set macros), so a descriptor can either own its
 * element storage (newMatrix sets `data`) or be a view onto the rows of
 * another matrix (splitMatrix fills `row` only and leaves `data` unset).
 * `Matrix` is the pointer type that all routines in this file pass around. */
typedef struct { 
  int nrows;       // number of rows
  int ncols;       // number of columns (or max. #cols if sparse)
  float **row;     // row[i] points to the first element of row i
  float *data;     // linearized buffer holding all elements (owning matrices only)
#ifdef SPARSE
  int *n;          // per-row count of stored nonzeros (if sparse)
  int nz;          // total number of stored nonzero elements (sparse)
  int *col;        // column indices of the stored nonzeros (sparse)
#endif
} matrix, *Matrix;

// constructors:
// newMatrix allocates an nr x nc matrix together with its element buffer.
extern sync Matrix newMatrix( sh int nr, sh int nc );
//extern sync Matrix copyMatrix( sh Matrix M );
// splitMatrix returns an array of 4 quadrant views of M, split at row x and
// column y; the views alias M's element storage (no data is copied).
extern sync Matrix *splitMatrix( sh Matrix M, int x, int y );
// destructor:
// releases the element buffer, the row table and the descriptor itself.
extern void freeMatrix( Matrix M );
// output:
extern void printMatrix( Matrix M );
// accessing single matrix elements:
// macro for void Set( Matrix M, int i, int j, float value ):
// (expands to a bare assignment expression; no bounds checking)
#define Set( M, i, j, value ) \
           (M)->row[i][j] = (value)
// macro for float Get( Matrix M, int i, int j ):
// (no bounds checking)
#define Get( M, i, j ) \
           ((M)->row[i][j])
// arithmetics (each writes its result into the pre-allocated matrix C):
extern sync void Add( sh Matrix A, sh Matrix B, sh Matrix C );  // C = A + B
extern sync void Sub( sh Matrix A, sh Matrix B, sh Matrix C );  // C = A - B
extern sync void Mul( sh Matrix A, sh Matrix B, sh Matrix C );  // standard method
extern sync void StrassenMul( sh Matrix A, sh Matrix B, sh Matrix C );

// All sync routines in this program expect that
// the group's processors are consecutively numbered 0,1,...,p-1.


/* newMatrix: allocate a dense nr x nc matrix.
 *
 * One processor of the group allocates the descriptor, the row-pointer
 * table and a single linear element buffer with shmalloc; the group then
 * fills the row table in parallel so that row[i] points at element i*nc of
 * the buffer.  The elements themselves are left uninitialized.
 *
 *   nr, nc   number of rows / columns, both must be > 0 (asserted)
 *   returns  the new matrix; release it with freeMatrix()
 *
 * The allocation asserts assume that shmalloc signals failure by returning
 * a null pointer -- TODO confirm against the Fork library documentation.
 */
sync Matrix newMatrix( sh int nr, sh int nc )
{
 sh Matrix M;
 sh int p = groupsize();
 sh float *buf;
 int i;
 
 seq {
   assert( nr > 0 );
   assert( nc > 0 );
   M = (Matrix)shmalloc( sizeof(matrix) );
   assert( M );
   M->row = (float **)shmalloc( nr * sizeof(float *) );
   assert( M->row );
   buf = (float *)shmalloc( nr * nc * sizeof(float) );
   assert( buf );
 }
 M->nrows = nr;
 M->ncols = nc;
 // distribute the row-table initialization over the p processors
 farm
   forall( i, 0, nr, p )
     M->row[i] = buf + i * nc;  
 M->data = buf;
 return M;
} 


/* printMatrix: print the dimensions of M and, if M has at most 8 rows,
 * all of its elements, one matrix row per output line.  For larger
 * matrices only the dimension header is printed. */
void printMatrix( Matrix M )
{
 int rows = M->nrows;
 int cols = M->ncols;
 int r, c;

 printf("Matrix %dX%d:\n", rows, cols );
 if (rows <= 8) {
    r = 0;
    while (r < rows) {
       c = 0;
       while (c < cols) {
          printf(" %2.2f", Get( M, r, c ));
          c++;
       }
       printf("\n");
       r++;
    }
    printf("\n");
 }
}


/* freeMatrix: release a matrix: first its element buffer, then its
 * row-pointer table, and finally the descriptor itself.
 * M must be non-null (asserted). */
void freeMatrix( Matrix M )
{
 float *elements;
 float **rowTable;

 assert( M );
 elements = M->data;
 rowTable = M->row;
 shfree( elements );
 shfree( rowTable );
 shfree( M );
}


/* splitMatrix: create 4 quarter-submatrix views of M (no copy of the data).
 *
 * M is cut at row x and column y.  The returned array holds, in order,
 *   [0] upper left   x     x  y
 *   [1] upper right  x     x (n-y)
 *   [2] lower left  (m-x)  x  y
 *   [3] lower right (m-x)  x (n-y)
 * where m x n is the size of M.  Each view gets its own descriptor and
 * row-pointer table, but its row pointers point into the rows of M.
 *
 * The `data` field of the views is NOT set, so a view must not be handed
 * to freeMatrix(); free view->row, the view and the returned array
 * individually instead.
 *
 * Requires 0 < x < m and 0 < y < n (asserted).  x and y are private
 * parameters: the seq sections use one processor's values while the
 * parallel loops use each processor's own copy, so all processors are
 * expected to pass identical values.
 *
 * The allocation asserts assume that shmalloc signals failure by returning
 * a null pointer -- TODO confirm against the Fork library documentation.
 */
sync Matrix *splitMatrix( sh Matrix M, int x, int y ) 
{
 sh Matrix *sM;
 sh int p = groupsize();
 sh int m, n;
 int i;
 
 seq {
   assert( M );
   assert( x > 0 );
   assert( y > 0 );
   m = M->nrows;
   n = M->ncols;
   assert( x < m );
   assert( y < n );
   //pprintf("splitMatrix( %dx%d -> %dx%d etc.\n", m, n, x, y );
   sM = (Matrix *)shmalloc( 4 * sizeof(Matrix) );
   assert( sM );
 }
 // allocate the 4 view descriptors in parallel
 farm
   forall( i, 0, 4, p ) {
     sM[i] = (Matrix)shmalloc( sizeof(matrix) );
     assert( sM[i] );
   }
 seq {
   sM[0]->nrows = x;
   sM[0]->ncols = y;
   sM[0]->row = (float **)shmalloc( x * sizeof(float *) );
   assert( sM[0]->row );
   sM[1]->nrows = x;
   sM[1]->ncols = n-y;
   sM[1]->row = (float **)shmalloc( x * sizeof(float *) );
   assert( sM[1]->row );
   sM[2]->nrows = m-x;
   sM[2]->ncols = y;
   sM[2]->row = (float **)shmalloc( (m-x) * sizeof(float *) );
   assert( sM[2]->row );
   sM[3]->nrows = m-x;
   sM[3]->ncols = n-y;
   sM[3]->row = (float **)shmalloc( (m-x) * sizeof(float *) );
   assert( sM[3]->row );
 }
 farm {
   // upper half: rows 0..x-1 of M, right quadrant shifted by y columns
   forall( i, 0, x, p ) {
      sM[0]->row[i] = M->row[i];
      sM[1]->row[i] = M->row[i] + y;
   }
   // lower half: rows x..m-1 of M
   forall( i, 0, m-x, p ) {
      sM[2]->row[i] = M->row[i+x];
      sM[3]->row[i] = M->row[i+x] + y;
   }
 }
 return sM;
}



/* Add: element-wise sum C = A + B, computed in parallel by the group.
 *
 * A, B and C must be non-null and have identical dimensions (asserted).
 * C is pre-allocated by the caller.
 *
 * The n*m elements are enumerated by a single linear index ij that is
 * dealt out cyclically to the p processors; ij is split into (row, column)
 * using the number of COLUMNS m, so the routine is also correct for
 * non-square matrices.
 */
sync void Add( sh Matrix A, sh Matrix B, sh Matrix C )
{
 sh int p = groupsize();
 sh int n, m;
 int ij, i, j;

 farm {
   assert(A);
   assert(B);
   assert(C);
   n = A->nrows;
   m = A->ncols;
   assert( n == B->nrows );
   assert( n == C->nrows );
   assert( m == B->ncols );
   assert( m == C->ncols );
   forall( ij, 0, n*m, p ) {
     i = ij / m;    // row index: each row holds m elements
     j = ij % m;    // column index within that row
     Set( C, i,j, Get(A,i,j) + Get(B,i,j) );
   }
 }
}


/* Sub: element-wise difference C = A - B, computed in parallel by the group.
 *
 * A, B and C must be non-null and have identical dimensions (asserted,
 * in the same way as in Add).  C is pre-allocated by the caller.
 *
 * The n*m elements are enumerated by a single linear index ij that is
 * dealt out cyclically to the p processors; ij is split into (row, column)
 * using the number of COLUMNS m, so the routine is also correct for
 * non-square matrices.
 */
sync void Sub( sh Matrix A, sh Matrix B, sh Matrix C )
{
 sh int p = groupsize();
 sh int n, m;
 int ij, i, j;

 farm {
   // validate the operands before any of them is dereferenced
   assert(A);
   assert(B);
   assert(C);
   n = A->nrows;
   m = A->ncols;
   assert( n == B->nrows );
   assert( n == C->nrows );
   assert( m == B->ncols );
   assert( m == C->ncols );
   forall( ij, 0, n*m, p ) {
     i = ij / m;    // row index: each row holds m elements
     j = ij % m;    // column index within that row
     Set( C, i,j, Get(A,i,j) - Get(B,i,j) );
   }
 }
}


/* Mul: standard (triple-loop) matrix product C = A * B.
 *
 * A is n x r, B is r x m and the pre-allocated C is n x m; all three
 * must be non-null and conformant (asserted by one processor).
 * The n*m result elements are dealt out cyclically to the processors
 * of the group; each element is an inner product of length r.
 */
sync void Mul( sh Matrix A, sh Matrix B, sh Matrix C )
{
 sh int nprocs = groupsize();
 sh int n, m, r;
 int cell, rowIdx, colIdx, k;
 float acc;

 seq {
   assert(A);
   assert(B);
   assert(C);
   n = A->nrows;
   m = B->ncols;
   r = A->ncols;
   assert( r == B->nrows );
   assert( n == C->nrows );
   assert( m == C->ncols );
 }
 farm {
   forall( cell, 0, n*m, nprocs ) {
     // linear cell number -> (row, column) of C
     rowIdx = cell / m;
     colIdx = cell % m;
     acc = 0.0;
     k = 0;
     while (k < r) {
        acc += Get(A,rowIdx,k) * Get(B,k,colIdx);
        k++;
     }
     Set( C, rowIdx,colIdx, acc );
   }
 }
}



/* comp1(), ..., comp7()  implement Strassen's set of formulae: */

/* comp1: q1 = (a11 + a22) * (b11 + b22).
 * Returns a newly allocated matrix holding q1; the two temporary sums
 * are freed before returning, the result is left for the caller to free. */
sync Matrix comp1( sh Matrix a11, sh Matrix a22,
                   sh Matrix b11, sh Matrix b22 )
{
  sh int half;
  sh Matrix sumA, sumB, product;

  half = a11->nrows;
  sumA = newMatrix( half, half );
  sumB = newMatrix( half, half );
  product = newMatrix( half, half );

  Add( a11, a22, sumA );
  Add( b11, b22, sumB );
  StrassenMul( sumA, sumB, product );

  seq {
    freeMatrix( sumA );
    freeMatrix( sumB );
  }
  return product;
}


/* comp2: q2 = (a21 + a22) * b11.
 * Returns a newly allocated matrix holding q2; the temporary sum is
 * freed before returning, the result is left for the caller to free. */
sync Matrix comp2( sh Matrix a21, sh Matrix a22, sh Matrix b11 )
{
  sh int half;
  sh Matrix sumA, product;

  half = a21->nrows;
  sumA = newMatrix( half, half );
  product = newMatrix( half, half );

  Add( a21, a22, sumA );
  StrassenMul( sumA, b11, product );

  seq {
    freeMatrix( sumA );
  }
  return product;
}
  

/* comp5: q5 = (a11 + a12) * b22.
 * Returns a newly allocated matrix holding q5; the temporary sum is
 * freed before returning, the result is left for the caller to free. */
sync Matrix comp5( sh Matrix a11, sh Matrix a12, sh Matrix b22 )
{
  sh int half;
  sh Matrix sumA, product;

  half = a11->nrows;
  sumA = newMatrix( half, half );
  product = newMatrix( half, half );

  Add( a11, a12, sumA );
  StrassenMul( sumA, b22, product );

  seq {
    freeMatrix( sumA );
  }
  return product;
}
  

/* comp4: q4 = a22 * (b21 - b11).
 * Returns a newly allocated matrix holding q4; the temporary difference
 * is freed before returning, the result is left for the caller to free. */
sync Matrix comp4( sh Matrix a22, sh Matrix b11, sh Matrix b21 )
{
  sh int half;
  sh Matrix diffB, product;

  half = a22->nrows;
  diffB = newMatrix( half, half );
  product = newMatrix( half, half );

  Sub( b21, b11, diffB );
  StrassenMul( a22, diffB, product );

  seq {
    freeMatrix( diffB );
  }
  return product;
}


/* comp6: q6 = (a21 - a11) * (b11 + b12).
 * Returns a newly allocated matrix holding q6; both temporaries are
 * freed before returning, the result is left for the caller to free. */
sync Matrix comp6( sh Matrix a11, sh Matrix a21,
                   sh Matrix b11, sh Matrix b12 )
{
  sh int half;
  sh Matrix diffA, sumB, product;

  half = a11->nrows;
  diffA = newMatrix( half, half );
  sumB = newMatrix( half, half );
  product = newMatrix( half, half );

  Sub( a21, a11, diffA );
  Add( b11, b12, sumB );
  StrassenMul( diffA, sumB, product );

  seq {
    freeMatrix( diffA );
    freeMatrix( sumB );
  }
  return product;
}


/* comp3: q3 = a11 * (b12 - b22).
 * Returns a newly allocated matrix holding q3; the temporary difference
 * is freed before returning, the result is left for the caller to free. */
sync Matrix comp3( sh Matrix a11, sh Matrix b12, sh Matrix b22 )
{
  sh int half;
  sh Matrix diffB, product;

  half = a11->nrows;
  diffB = newMatrix( half, half );
  product = newMatrix( half, half );

  Sub( b12, b22, diffB );
  StrassenMul( a11, diffB, product );

  seq {
    freeMatrix( diffB );
  }
  return product;
}


/* comp7: q7 = (a12 - a22) * (b21 + b22).
 * Returns a newly allocated matrix holding q7; both temporaries are
 * freed before returning, the result is left for the caller to free. */
sync Matrix comp7( sh Matrix a12, sh Matrix a22,
                   sh Matrix b21, sh Matrix b22 )
{
  sh int half;
  sh Matrix diffA, sumB, product;

  half = a12->nrows;
  diffA = newMatrix( half, half );
  sumB = newMatrix( half, half );
  product = newMatrix( half, half );

  Sub( a12, a22, diffA );
  Add( b21, b22, sumB );
  StrassenMul( diffA, sumB, product );

  seq {
    freeMatrix( diffA );
    freeMatrix( sumB );
  }
  return product;
}


/* issmall: recursion cut-off test for StrassenMul.
 * Returns 1 if a problem of size n is handled directly (n <= 2),
 * 0 otherwise. */
sync int issmall( sh int n )
{
  return (n <= 2) ? 1 : 0;
}


/* StrassenMul: c = a * b using Strassen's recursive scheme.
 *
 *   a, b   square n x n operands (non-null, asserted)
 *   c      pre-allocated n x n result (non-null, asserted)
 *
 * The recursion falls back to the standard method Mul() when
 *   - the group has fewer than 7 processors (not enough to give each of
 *     the 7 sub-products its own subgroup),
 *   - the problem is small (issmall), or
 *   - n is odd, so the matrices cannot be cut into four equal quadrants.
 * Otherwise the operands are split into quadrant views, the group forks
 * into 7 subgroups that compute q1..q7 (comp1..comp7), and then into 2
 * subgroups that combine them:
 *   c11 = q1 + q4 - q5 + q7      c12 = q3 + q5
 *   c21 = q2 + q4                c22 = q1 - q2 + q3 + q6
 * All temporaries, the q matrices and the quadrant views are released
 * before returning.
 *
 * Like all sync routines in this file it expects the group's processors
 * to be consecutively numbered 0,1,...,p-1.
 */
sync void StrassenMul( sh Matrix a, sh Matrix b, sh Matrix c )
{
  // Matrix c is pre-allocated
  sh int n, ndiv2;
  sh Matrix *sa, *sb, *sc;
  sh Matrix a11, a12, a21, a22;
  sh Matrix b11, b12, b21, b22;
  sh Matrix c11, c12, c21, c22;
  sh Matrix q1, q2, q3, q4, q5, q6, q7;
  sh Matrix t11, t12, t21, t22;
  sh int p = groupsize();
  sh int issm;
  int k;

  //seq pprintf("Strassen %d with %d procs\n", n, p );
  // validate the operands before a is dereferenced
  seq {
    assert(a);
    assert(b);
    assert(c);
  }
  n = a->nrows;
  ndiv2 = n / 2;

  if (p < 7) {    /* not enough processors available: */
     Mul( a, b, c );
     return;
  }
  issm = issmall(n);
  if (issm || (n % 2) != 0) {   /* the trivial case, or n not halvable */
     Mul( a, b, c );
     return;
  }

  // create 4 quarter-submatrix views of each matrix:
  sa = splitMatrix( a, ndiv2, ndiv2 ); 
  a11 = sa[0]; a12 = sa[1]; a21 = sa[2]; a22 = sa[3];
  sb = splitMatrix( b, ndiv2, ndiv2 ); 
  b11 = sb[0]; b12 = sb[1]; b21 = sb[2]; b22 = sb[3];
  sc = splitMatrix( c, ndiv2, ndiv2 ); 
  c11 = sc[0]; c12 = sc[1]; c21 = sc[2]; c22 = sc[3];
  
  //seq printf("FORK 7 for %d\n", p);

  fork ( 7; @ = $%7; $=$/7 )
  { // the @+1-th subgroup computes q@:
    if (@==0)       q1 = comp1( a11, a22, b11, b22 );
    else if (@==1)  q2 = comp2( a21, a22, b11 );
    else if (@==2)  q3 = comp3( a11, b12, b22 );
    else if (@==3)  q4 = comp4( a22, b11, b21 );
    else if (@==4)  q5 = comp5( a11, a12, b22 );   // q5 = (a11+a12)*b22
    else if (@==5)  q6 = comp6( a11, a21, b11, b12 );
    else            q7 = comp7( a12, a22, b21, b22 );
  } /* end of fork */
  
  /* explicitly reopen 2 subgroups: */
  // seq printf("FORK 2 for %d\n", p);

  fork( 2; @ = $%2; $ = $/2 ) {
    if (@==0) {
      // the first subgroup computes c11 and c21
      t11 = newMatrix( ndiv2, ndiv2 );
      t12 = newMatrix( ndiv2, ndiv2 );
      Add( q1, q4, t11 );
      Sub( q7, q5, t12 );
      Add( t11, t12, c11 );
      Add( q2, q4, c21 );
      // q4 and q7 are used by this subgroup only
      seq { freeMatrix( t11 ); freeMatrix( t12 ); 
            freeMatrix( q4 );  freeMatrix( q7 );  }
    }
    else {
      // the second subgroup computes c12 and c22
      t21 = newMatrix( ndiv2, ndiv2 );
      t22 = newMatrix( ndiv2, ndiv2 );
      Add( q3, q5, c12 );
      Add( q1, q3, t21 );
      Sub( q6, q2, t22 );
      Add( t21, t22, c22 );
      // q3 and q6 are used by this subgroup only
      seq { freeMatrix( t21 ); freeMatrix( t22 );
            freeMatrix( q3 ); freeMatrix( q6 ); }
    }
  }

  seq {
    // q1, q2 and q5 were read by both subgroups, so they are freed here
    freeMatrix( q1 ); freeMatrix( q2 ); freeMatrix( q5 );
    // Release the quadrant views.  They do not own element storage (their
    // data field is never set), so freeMatrix() must not be used on them:
    // free only the row table and the descriptor of each view.
    for (k = 0; k < 4; k++) {
      shfree( sa[k]->row );  shfree( sa[k] );
      shfree( sb[k]->row );  shfree( sb[k] );
      shfree( sc[k]->row );  shfree( sc[k] );
    }
    shfree( sa );
    shfree( sb );
    shfree( sc );
  }
}
 

/* main: interactive driver.
 * Reads the matrix size N, fills A with 1.0 and B with 2.0, multiplies
 * them once with StrassenMul and once with the standard Mul, and prints
 * both products together with the elapsed getct() count shifted right
 * by 8 bits (reported as msecs). */
void main( void ) 
{
 start {  // if ($<343)       /* this starts 7^3 processors */
   sh int N = 8;  // must be a power of 2 or at least contain 2^(log_7 p)
   sh Matrix A, B, C;
   sh int nprocs = groupsize();
   sh int ticks;
   int r, c;
  
   seq {
      prS("STRASSEN MATRIX-MATRIX MULTIPLICATION\n\n");
      printf("Enter matrix size N: ");
      scanf("%d", &N);
   }

   A = newMatrix( N, N );
   B = newMatrix( N, N );
   C = newMatrix( N, N );

   /* preset the input arrays, rows dealt out to the processors: */
   farm
      forall( r, 0, N, nprocs ) {
         c = 0;
         while (c < N) {
            Set(A, r,c, 1.0);
            Set(B, r,c, 2.0);
            c++;
         }
      }

   seq {
      prS("\nMatrix A:\n");
      printMatrix( A );
      prS("\nMatrix B:\n");
      printMatrix( B );
      ticks = getct();      /* start of the Strassen timing */
   }

   StrassenMul( A, B, C );

   seq {
      ticks = getct() - ticks;
      prS("\nMatrix Product:\n");
      printMatrix( C );
      printf("\nTime: %d msecs\n", ticks>>8 );
      ticks = getct();      /* start of the standard-method timing */
   } 

   Mul( A, B, C );

   seq {
      ticks = getct() - ticks;
      prS("\nStandard Matrix Product:\n");
      printMatrix( C );
      printf("\nO(n^3) Time: %d msecs\n<press key to continue>", ticks>>8 );
      scanf("%c", &ticks);  /* wait for a key; ticks is reused as scratch */
   } 
 }
}

