/*  strassen.c   by C.W.Kessler 03/95 */

#include <fork.h>
#include <syscall.h>
#include <io.h>
#include <assert.h>

/* Group-synchronous routines defined later in this file. */
extern sync void output_array( sh int *arr, sh int n );
extern sync void strassen_mult( sh int *a, sh int sa,
                                sh int *b, sh int sb,
                                sh int *c, sh int sc,
                                sh int n );
/* Sequential (single-processor) kernel used at the bottom of the recursion. */
async void mult_directly( int *a, int sa, int *b, int sb, int *c, int sc, int n );

/* Operand matrices A, B and the product C; each is N x N words, allocated in main(). */
sh int *A;
sh int *B;
sh int *C;
sh int N = 16;      /* matrix dimension; the recursive halving requires a power of 2 */
sh int WLOCK = 0;   /* simple lock -- NOTE(review): not referenced in the code visible here */


main() {
 pr int i, j, p;
 /* Use (at most) 343 = 7^3 processors: if all 343 are present, the group
  * size stays divisible by 7 through three levels of strassen_mult's
  * 7-way fork (343 -> 49 -> 7 -> 1). */
 start if ($<343) {
   A = (int *) shalloc(N*N);
   B = (int *) shalloc(N*N);
   C = (int *) shalloc(N*N);
   /* print the program title once: */
   farm if ($==0) prS("STRASSEN-MATRIXMULTIPLIKATION\n\n");
   /* initialize the input arrays (A = all ones, B = all twos);
    * rows are dealt out cyclically to the group members: */
   p = groupsize();
   farm
      for (i=$; i<N; i+=p )
         for (j=0; j<N; j++) 
            { A[i*N+j] = 1; B[i*N+j] = 2; }
   farm if ($==0) prS("\nArray A:\n");
   output_array( A, N );                  /*print the original array*/
   farm if ($==0) prS("\nArray B:\n");
   output_array( B, N );                  /*print the original array*/
   strassen_mult( A, N, B, N, C, N, N );  /* C = A * B */
   farm if ($==0) prS("\nMatrix Product:\n");
   output_array( C, N );                  /*print the resulting array*/
 }
}


sync void add(
   sh int *a, sh int sa,      /* first summand, row stride sa */
   sh int *b, sh int sb,      /* second summand, row stride sb */
   sh int *c, sh int sc,      /* destination, row stride sc */
   sh int n )                 /* the blocks are n x n */
{ /* c = a + b, elementwise, over n x n row-major blocks.
   * Row r is handled by the group member whose rank is congruent to r
   * modulo the group size (cyclic row distribution). */
  pr int row, col;
  sh int nprocs = groupsize();

  farm
    for (row = $; row < n; row += nprocs)
      for (col = 0; col < n; col++)
        c[row*sc + col] = a[row*sa + col] + b[row*sb + col];
}


sync void inv( sh int *a, sh int sa, sh int *c, sh int sc, sh int n )
{
  /* c = -a (additive inverse), elementwise, over n x n row-major blocks
   * with row strides sa (source) and sc (destination).
   * Rows are dealt out cyclically to the group members. */
  pr int row, col;
  sh int nprocs = groupsize();

  farm
    for (row = $; row < n; row += nprocs) {
      for (col = 0; col < n; col++) {
        c[row*sc + col] = -a[row*sa + col];
      }
    }
}


async void mult_directly( int *a, int sa, int *b, int sb, int *c, int sc, int n )
{
  /* Sequential matrix multiplication c = a * b for n x n blocks.
   * All three blocks are row-major: element (i,j) of a lives at
   * a[i*sa + j], and likewise with strides sb and sc. The strides matter
   * because callers pass sub-blocks (e.g. a22) of larger matrices. */
  pr int i, j, k;
  pr int val;

  for (i = 0; i < n; i++)
     for (j = 0; j < n; j++) {
        val = 0;
        for (k = 0; k < n; k++)
           val += a[i*sa + k] * b[k*sb + j];   /* row i of a times column j of b */
        c[i*sc + j] = val;
     }
}
          

/* comp1(), ..., comp7()  install Strassen's set of formulae: */


sync void comp1( sh int *a11, sh int *a22, sh int sa,
                 sh int *b11, sh int *b22, sh int sb,
                 sh int *q1, sh int ndiv2 )
{
  /* Strassen product Q1 = (A11 + A22) * (B11 + B22); q1 has row stride ndiv2. */
  sh int *sumA = (int *) shalloc(ndiv2 * ndiv2);
  sh int *sumB = (int *) shalloc(ndiv2 * ndiv2);

  add( a11, sa, a22, sa, sumA, ndiv2, ndiv2 );   /* sumA = A11 + A22 */
  add( b11, sb, b22, sb, sumB, ndiv2, ndiv2 );   /* sumB = B11 + B22 */
  strassen_mult( sumA, ndiv2, sumB, ndiv2, q1, ndiv2, ndiv2 );
}
  

sync void comp2( sh int *a21, sh int *a22, sh int sa,
                 sh int *b11, sh int sb,
                 sh int *q2, sh int ndiv2 )
{
  /* Strassen product Q2 = (A21 + A22) * B11; q2 has row stride ndiv2. */
  sh int *sumA = (int *) shalloc(ndiv2 * ndiv2);

  add( a21, sa, a22, sa, sumA, ndiv2, ndiv2 );   /* sumA = A21 + A22 */
  strassen_mult( sumA, ndiv2, b11, sb, q2, ndiv2, ndiv2 );
}
  

sync void comp5( sh int *a11, sh int *a12, sh int sa,
                 sh int *b22, sh int sb,
                 sh int *q5, sh int ndiv2 )
{
  /* Strassen product Q5 = (A11 + A12) * B22; q5 has row stride ndiv2. */
  sh int *sumA = (int *) shalloc(ndiv2 * ndiv2);

  add( a11, sa, a12, sa, sumA, ndiv2, ndiv2 );   /* sumA = A11 + A12 */
  strassen_mult( sumA, ndiv2, b22, sb, q5, ndiv2, ndiv2 );
}
  

sync void comp4( sh int *a22, sh int sa,
                 sh int *b11, sh int *b21, sh int sb,
                 sh int *q4, sh int ndiv2 )
{
  /* Strassen product Q4 = A22 * (B21 - B11); q4 has row stride ndiv2. */
  sh int *negB11 = (int *) shalloc(ndiv2 * ndiv2);
  sh int *diffB  = (int *) shalloc(ndiv2 * ndiv2);

  inv( b11, sb, negB11, ndiv2, ndiv2 );                    /* negB11 = -B11 */
  add( negB11, ndiv2, b21, sb, diffB, ndiv2, ndiv2 );      /* diffB = B21 - B11 */
  strassen_mult( a22, sa, diffB, ndiv2, q4, ndiv2, ndiv2 );
}


sync void comp6( sh int *a11, sh int *a21, sh int sa,
                 sh int *b11, sh int *b12, sh int sb,
                 sh int *q6, sh int ndiv2 )
{
  /* Strassen product Q6 = (A21 - A11) * (B11 + B12); q6 has row stride ndiv2.
   * Two scratch blocks suffice: rhs first holds -A11, then is reused. */
  sh int *lhs = (int *) shalloc(ndiv2 * ndiv2);
  sh int *rhs = (int *) shalloc(ndiv2 * ndiv2);

  inv( a11, sa, rhs, ndiv2, ndiv2 );                  /* rhs = -A11 (temporary) */
  add( rhs, ndiv2, a21, sa, lhs, ndiv2, ndiv2 );      /* lhs = A21 - A11 */
  add( b11, sb, b12, sb, rhs, ndiv2, ndiv2 );         /* rhs = B11 + B12 */
  strassen_mult( lhs, ndiv2, rhs, ndiv2, q6, ndiv2, ndiv2 );
}


sync void comp3( sh int *a11, sh int sa,
                 sh int *b12, sh int *b22, sh int sb,
                 sh int *q3, sh int ndiv2 )
{
  /* Strassen product Q3 = A11 * (B12 - B22); q3 has row stride ndiv2. */
  sh int *negB22 = (int *) shalloc(ndiv2 * ndiv2);
  sh int *diffB  = (int *) shalloc(ndiv2 * ndiv2);

  inv( b22, sb, negB22, ndiv2, ndiv2 );                    /* negB22 = -B22 */
  add( b12, sb, negB22, ndiv2, diffB, ndiv2, ndiv2 );      /* diffB = B12 - B22 */
  strassen_mult( a11, sa, diffB, ndiv2, q3, ndiv2, ndiv2 );
}


sync void comp7( sh int *a12, sh int *a22, sh int sa,
                 sh int *b21, sh int *b22, sh int sb,
                 sh int *q7, sh int ndiv2 )
{
  /* Strassen product Q7 = (A12 - A22) * (B21 + B22); q7 has row stride ndiv2.
   * Two scratch blocks suffice: rhs first holds -A22, then is reused. */
  sh int *lhs = (int *) shalloc(ndiv2 * ndiv2);
  sh int *rhs = (int *) shalloc(ndiv2 * ndiv2);

  inv( a22, sa, rhs, ndiv2, ndiv2 );                  /* rhs = -A22 (temporary) */
  add( rhs, ndiv2, a12, sa, lhs, ndiv2, ndiv2 );      /* lhs = A12 - A22 */
  add( b21, sb, b22, sb, rhs, ndiv2, ndiv2 );         /* rhs = B21 + B22 */
  strassen_mult( lhs, ndiv2, rhs, ndiv2, q7, ndiv2, ndiv2 );
}


sync void strassen_mult(
  sh int *a,     /* left operand: n x n block, row-major, row stride sa */
  sh int sa,
  sh int *b,     /* right operand: n x n block, row-major, row stride sb */
  sh int sb,
  sh int *c,     /* result c = a * b: n x n block, row-major, row stride sc */
  sh int sc,
  sh int n )     /* problem size; the halving below assumes a power of 2 */
{
  /* Parallel Strassen multiplication. The group splits into 7 subgroups,
   * each computing one of the products Q1..Q7 recursively on n/2 x n/2
   * quadrants. The group then splits into 2 subgroups that combine the
   * Q's into the four quadrants of c. */
  sh int ndiv2 = n>>1;
  sh int *a11, *a12, *a21, *a22;
  sh int *b11, *b12, *b21, *b22;
  sh int *c11, *c12, *c21, *c22;
  sh int *q1, *q2, *q3, *q4, *q5, *q6, *q7;
  sh int *t11, *t12, *t21, *t22;

  sh int p = groupsize();

  farm assert( p );

  /* Fewer than 7 processors cannot staff all 7 subgroups of the fork below:
   * empty subgroups would leave some q's uncomputed. So multiply
   * sequentially instead. With p > 1, every member computes the same product
   * and writes identical values into c. */
  if (p < 7) {
       farm mult_directly( a, sa, b, sb, c, sc, n );
       return;
  }

  if (n==1) {   /* the trivial case */
     *c = *a * *b;
     farm prI( *c, 0 );   /* NOTE(review): looks like debug output; printed by every group member */
     return;
  }

  /* Quadrant views into the operands (no copying): row-major with the
   * caller's row strides, so X12 is offset by ndiv2 columns and X21 by
   * ndiv2 rows. */
  a11 = a;               a12 = a + ndiv2;
  a21 = a + sa*ndiv2;    a22 = a + sa*ndiv2 + ndiv2; 
  b11 = b;               b12 = b + ndiv2;
  b21 = b + sb*ndiv2;    b22 = b + sb*ndiv2 + ndiv2; 
  c11 = c;               c12 = c + ndiv2;
  c21 = c + sc*ndiv2;    c22 = c + sc*ndiv2 + ndiv2; 

  /* contiguous ndiv2 x ndiv2 buffers (row stride ndiv2) for Q1..Q7 */
  q1 = (int *) shalloc(ndiv2 * ndiv2);
  q2 = (int *) shalloc(ndiv2 * ndiv2);
  q3 = (int *) shalloc(ndiv2 * ndiv2);
  q4 = (int *) shalloc(ndiv2 * ndiv2);
  q5 = (int *) shalloc(ndiv2 * ndiv2);
  q6 = (int *) shalloc(ndiv2 * ndiv2);
  q7 = (int *) shalloc(ndiv2 * ndiv2);

  farm if ($==0) prS("\nFORK\n");

  /* explicitly open 7 subgroups: */

  fork ( 7; @ = $%7; $=$/7 ) {           /* subgroup @ computes q(@+1): */

     if (@==0)      /* shared if condition -> no group frame construction */
        comp1( a11, a22, sa, b11, b22, sb, q1, ndiv2 );   /* Q1 = (A11+A22)(B11+B22) */
     else
     if (@==1) 
        comp2( a21, a22, sa, b11, sb, q2, ndiv2 );        /* Q2 = (A21+A22) B11 */
     else
     if (@==2) 
        comp3( a11, sa, b12, b22, sb, q3, ndiv2 );        /* Q3 = A11 (B12-B22) */
     else
     if (@==3) 
        comp4( a22, sa, b11, b21, sb, q4, ndiv2 );        /* Q4 = A22 (B21-B11) */
     else
     if (@==4) 
        comp5( a11, a12, sa, b22, sb, q5, ndiv2 );        /* Q5 = (A11+A12) B22 */
     else
     if (@==5) 
        comp6( a11, a21, sa, b11, b12, sb, q6, ndiv2 );   /* Q6 = (A21-A11)(B11+B12) */
     else
        comp7( a12, a22, sa, b21, b22, sb, q7, ndiv2 );   /* Q7 = (A12-A22)(B21+B22) */
  } /* end of fork */
  
  /* explicitly reopen 2 subgroups: */

  fork( 2; @ = $%2; $ = $/2 )
      if (@==0) {    /* shared if condition */
                /* the first subgroup computes c11 = q1 + q4 - q5 + q7
                   and c21 = q2 + q4 */
                t11 = (int *) shalloc(ndiv2 * ndiv2);
                t12 = (int *) shalloc(ndiv2 * ndiv2);
                inv( q5, ndiv2, t11, ndiv2, ndiv2 );
                add( t11, ndiv2, q7, ndiv2, t12, ndiv2, ndiv2 );
                add( t12, ndiv2, q1, ndiv2, t11, ndiv2, ndiv2 );
                add( t11, ndiv2, q4, ndiv2, c11, sc, ndiv2 );
                add( q2, ndiv2, q4, ndiv2, c21, sc, ndiv2 );
      }
      else {             /* the second subgroup computes c22 = q1 - q2 + q3 + q6
                            and c12 = q3 + q5 */
                         t21 = (int *) shalloc(ndiv2 * ndiv2);
                         t22 = (int *) shalloc(ndiv2 * ndiv2);
                         inv( q2, ndiv2, t21, ndiv2, ndiv2 );
                         add( t21, ndiv2, q6, ndiv2, t22, ndiv2, ndiv2 );
                         add( t22, ndiv2, q1, ndiv2, t21, ndiv2, ndiv2 );
                         add( t21, ndiv2, q3, ndiv2, c22, sc, ndiv2 );
                         add( q3, ndiv2, q5, ndiv2, c12, sc, ndiv2 );
      }
}
 

sync void output_array (
  sh int *arr,    /* row-major n x n matrix to print */
  sh int n )      /* matrix dimension (arr holds n*n entries) */
{
  /* Print arr one row per line, entries separated by a blank.
   * Only the group member with rank 0 prints, so the matrix appears once. */
  pr int row, col;

  farm if ($ == 0) {
    for (row = 0; row < n; row++) {
      for (col = 0; col < n; col++) {
        prI( arr[row*n + col], 0 );
        write(1, " ", 1);
      }
      write(1, "\n", 1);
    }
  }
}

