/***************************************************************************
 * RCS INFORMATION:
 *
 *	$RCSfile: generic_redn.p,v $
 *	$Author: narain $	$Locker:  $		$State: Exp $
 *	$Revision: 1.2 $	$Date: 1995/04/13 04:01:36 $
 *
 ***************************************************************************
 * DESCRIPTION:
 *
 ***************************************************************************
 * REVISION HISTORY:
 *
 * $Log: generic_redn.p,v $
 * Revision 1.2  1995/04/13  04:01:36  narain
 * Mc -> Cmi
 *
 * Revision 1.1  1994/11/07  16:11:49  brunner
 * Initial revision
 *
 ***************************************************************************/
#include "generic_redn.h"

#include intFile
#include pgIntFile

/* ReductionOperation(x, y, n, first): fold the n elements of x into the
 * accumulator y by making a PrivateCall to RedOpName(x, y, n, first).
 * RedOpName is not defined in this file; presumably generic_redn.h maps it
 * to one of the private R* functions of the reduce BOC (Rmax, Rmin, Rsum,
 * Rprod, Rcount) -- TODO confirm.  Each of those copies x into y when
 * first is nonzero and combines x into y otherwise. */
#define ReductionOperation(a,b,c,d) PrivateCall(RedOpName(a,b,c,d))

module ModuleName {

/* Creation message for every branch of the reduce BOC (built by Create and
 * CreateOverGroup, consumed by reduce@init).  gid selects the spanning tree:
 * 0 means the machine-wide tree, anything else the processor group gid
 * managed by the PG BOC whose number is pgBoc. */
message {
    int gid ;
    ChareNumType pgBoc ;
} REDUCE_INIT ;
    
/* Result message sent to the user's entry point (SendMsg in distribute)
 * when the reduction was started through DepositDataMsg / f_msg. */
message {
    varSize DataType data[] ;
} REDN_MSG ;

/* Internal message: carries a contribution up the tree to collect and the
 * final result back down to distribute.  size is the number of elements
 * in data. */
message {
    int size ;
    varSize DataType data[] ;
} REDN_MSG_INTERNAL ;


/* Bookkeeping for one in-progress reduction on this branch, identified by
 * the caller-supplied reference number.  Instances are chained through
 * next into the reduce BOC's refList. */
typedef struct ReductionRefInstance {
    int refnum ;            /* reference number identifying this reduction */
    int r_type ;            /* R_BY_MESSAGE or R_BY_FUNCTION */
    int send_result_flag ;  /* 0 when the depositor passed id == NULL */
    int leftToCollect ;     /* contributions (self + children) still expected */
    int is_first ;          /* nonzero until the first contribution lands in y */
    DataType *y ;           /* running partial result, numEls elements */
    DataType *r_z ;         /* caller's result buffer (R_BY_FUNCTION only) */
    int numBuffered ;       /* NOTE(review): numBuffered and buffered[] are */
    DataType *buffered[4] ; /* not referenced anywhere in this file.  The 4 */
                            /* was meant to be the maximum number of children */
    int numEls ;            /*      that a node can have.  numEls = #elements */
    void (* r_function)() ; /* callback for R_BY_FUNCTION */
    EntryPointType r_ep ;   /* destination entry point for R_BY_MESSAGE */
    ChareIDType r_cid ;     /* destination chare for R_BY_MESSAGE */
    ChareNumType r_bocnum ; /* BOC on which r_function is BranchCall'ed */
    struct ReductionRefInstance *next ;
} ReductionRefInstance ;

/* Values for ReductionRefInstance.r_type: how the result is returned. */
#define R_BY_MESSAGE  0
#define R_BY_FUNCTION 1

BranchOffice reduce {
    /* Per-branch state:
     *   me           - this processor (CmiMyPe())
     *   gid          - processor group; 0 selects the machine-wide spanning tree
     *   myParent     - spanning-tree parent; -1 is treated as "I am the root"
     *   numChildren  - number of spanning-tree children of this branch
     *   numToCollect - contributions expected per reduction (children + self)
     *   pgBoc        - processor-group BOC, consulted only when gid != 0
     *   refList      - in-progress reductions, one ReductionRefInstance each
     */
    int me, gid, myParent, numChildren, numToCollect ;
    ChareNumType pgBoc ;
    ReductionRefInstance *refList ;

    /* Creation-time entry: remember the group and look up this branch's
     * position in the spanning tree. */
    entry init : (message REDUCE_INIT *msg)
    {
        gid = msg->gid ;
        pgBoc = msg->pgBoc ;

        me = CmiMyPe() ;
        if(gid == 0) {
            myParent = CmiSpanTreeParent(me) ;
            numChildren = CmiNumSpanTreeChildren(me) ;
        }
        else {
            myParent = PG::PgMySpanTreeParent(pgBoc, gid) ;
            numChildren = PG::PgMyNumSpanTreeChildren(pgBoc, gid) ;
        }
        numToCollect = numChildren+1 ; /* I send a message to myself */
        refList = NULL ;
        CkFreeMsg(msg);
    }

    /* Fold one contribution (this branch's own, or a child's partial result)
     * into the accumulator for its reference number.  When all numToCollect
     * contributions have arrived, the message is reused to carry the partial
     * result to the parent, or, at the root, to every branch's distribute. */
    entry collect : (message REDN_MSG_INTERNAL *cMsg)
    {
        int i, refnum ;
        private void RedOpName() ;
        private ReductionRefInstance *FindRef() ;
        private void AddRef() ;
        ReductionRefInstance *ref ;

        refnum = GetRefNumber(cMsg) ;

        ref = (ReductionRefInstance*)PrivateCall(FindRef(refnum)) ;

        /* ref may be null if I get messages from my children before
         * I join up.  So create a base structure and go on */
        if(ref==NULL) {
            ref=(ReductionRefInstance*)CkAlloc(sizeof(ReductionRefInstance)) ;
            ref->numEls = cMsg->size ;
            ref->refnum = refnum ;
            ref->leftToCollect = numToCollect ;
            ref->is_first = 1 ;
            ref->send_result_flag = 0 ; /* set for real by f / f_msg */
            ref->y = (DataType *)CkAlloc(sizeof(DataType)*ref->numEls) ;
            ref->next = NULL ;
            PrivateCall(AddRef(ref,refnum)) ;
        }

        /* The reduction functions only read their first argument, so the
         * contribution is used straight from the message (the old scratch
         * copy was allocated per message and never freed). */
        ReductionOperation(cMsg->data,ref->y,ref->numEls,ref->is_first) ;
        ref->is_first = 0 ;
        ref->leftToCollect-- ;

        if (ref->leftToCollect == 0) {
            /* Reuse the incoming message to carry the combined result. */
            for(i=0;i<ref->numEls;i++) 
                cMsg->data[i] = (DataType) ref->y[i] ;
            SetRefNumber(cMsg, refnum) ;
            if (myParent == -1) {
                /* Root: the reduction is done, send it to every branch */
                if(gid == 0) 
                    BroadcastMsgBranch(distribute, cMsg, MyBocNum()) ;
                else 
                    PG::Multicast(pgBoc, gid, cMsg, distribute, MyBocNum()) ;
            }
            else {
                /* Interior node: keep sending it up the tree */
                SendMsgBranch(collect, cMsg, myParent) ;
            }
        }
        else { 
            CkFreeMsg(cMsg);  /* Only free it if we're not reusing it */
        }
    }

    /* Receive the final result, deliver it as requested by f / f_msg on
     * this branch, and drop the bookkeeping entry for its refnum. */
    entry distribute : (message REDN_MSG_INTERNAL *dMsg)
    {
        int refnum, i, size ;
        ReductionRefInstance *ref ;
        private ReductionRefInstance *FindRef() ;
        private void DeleteRef() ;
        void (*fptr)() ;
        REDN_MSG *msg ;

        refnum = GetRefNumber(dMsg) ;
        /* Look up the reply mechanism for this branch */
        ref = (ReductionRefInstance*)PrivateCall(FindRef(refnum)) ;

        /* Check for error */
        if (!ref) {
            CkPrintf("Error in %s.  Lost reference number %d\n",
                     "ModuleName", refnum) ;
            CkFreeMsg(dMsg) ;  /* was leaked on this path */
            ChareExit() ;
            return ;
        }

        /* Figure out what/where to send it */
        if (ref->send_result_flag) {
            if (ref->r_type == R_BY_FUNCTION) {
                /* Copy into the caller's buffer, then run the callback */
                for(i=0;i<ref->numEls;i++) 
                    ref->r_z[i] = (DataType)dMsg->data[i] ;
                fptr = ref->r_function ;
                BranchCall(ref->r_bocnum, fptr(MyBocNum())) ;
            }
            else {
                size = ref->numEls ;
                msg = (REDN_MSG*)CkAllocMsg(REDN_MSG,&size) ;
                for(i=0;i<ref->numEls;i++) 
                    msg->data[i] = dMsg->data[i] ;
                SetRefNumber(msg,refnum) ;
                SendMsg(ref->r_ep, msg, &(ref->r_cid)) ;
            }
        }
        CkFreeMsg(dMsg) ;

        /* Delete this refnum info (and its accumulator) from the list */
        PrivateCall(DeleteRef(refnum)) ;
    }

    /***********************************************************************/

    /* Deposit this branch's contribution and ask for the result by function
     * call.  With id != NULL, distribute copies the result into z and then
     * BranchCall's fptr on the BOC whose number *id points to; with
     * id == NULL this branch contributes but receives nothing. */
    public f(x,z,size,refnum,fptr,id)
    DataType x[],*z ;
    int size, refnum ;
    void (*fptr)() ;
    void *id ;
    {
        ReductionRefInstance *ref ;
        private void AddRef() ;
        private ReductionRefInstance *FindRef() ;
        REDN_MSG_INTERNAL *cMsg ;
        int i, varSizes[1] ;

        ref = (ReductionRefInstance*)PrivateCall(FindRef(refnum)) ;

        /* Children may already have contributed (collect creates the entry
         * in that case); otherwise create it here. */
        if(ref==NULL) {
            ref=(ReductionRefInstance*)CkAlloc(sizeof(ReductionRefInstance)) ;
            ref->numEls = size ;
            ref->leftToCollect = numToCollect ;
            ref->refnum = refnum ;
            ref->is_first = 1 ;
            ref->y = (DataType *)CkAlloc(sizeof(DataType)*ref->numEls) ;
            ref->next = NULL ;
            PrivateCall(AddRef(ref,refnum)) ;
        }

        /* Fill in the return information */
        ref->r_type = R_BY_FUNCTION ;
        if (id == NULL) 
            ref->send_result_flag = 0 ;
        else {
            ref->r_z = z ;
            ref->send_result_flag = 1 ;
            ref->r_function = fptr ;
            ref->r_bocnum = *((ChareNumType *) id) ;
        }

        /* Package the local contribution.  collect folds it into ref->y
         * like any child's contribution; it must NOT be written into
         * ref->y here, or partial results from children that arrived
         * earlier would be overwritten (and x would be counted twice). */
        varSizes[0]=ref->numEls ;
        cMsg = (REDN_MSG_INTERNAL *)CkAllocMsg(REDN_MSG_INTERNAL, varSizes) ;
        cMsg->size = ref->numEls ;
        for(i=0;i<ref->numEls;i++)
            cMsg->data[i] = (DataType) x[i] ;
        SetRefNumber(cMsg, refnum) ;

        /* numToCollect is numChildren+1 (see init), so the local portion
         * normally goes to collect on this processor. */
        if(numToCollect == 0) 
            SendMsgBranch(collect,cMsg,myParent) ;
        else
            SendMsgBranch(collect,cMsg,me) ;
    }

    /***********************************************************************/

    /* Deposit this branch's contribution and ask for the result by message:
     * with id != NULL, distribute sends a REDN_MSG to entry ep of chare *id;
     * with id == NULL this branch contributes but receives nothing. */
    public f_msg(x,size,refnum,ep,id)
    DataType      x[] ;
    int           size ;
    int           refnum ;
    EntryNumType  ep;
    ChareIDType   *id;
    {
        private void AddRef() ;
        private ReductionRefInstance *FindRef() ;
        int i, varSizes[1] ;
        ReductionRefInstance *ref ;
        REDN_MSG_INTERNAL *cMsg ;

        ref = (ReductionRefInstance*)PrivateCall(FindRef(refnum)) ;

        if(ref==NULL) {
            ref=(ReductionRefInstance*)CkAlloc(sizeof(ReductionRefInstance)) ;
            ref->numEls = size ;
            ref->leftToCollect = numToCollect ;
            ref->refnum = refnum ;
            ref->is_first = 1 ;
            ref->y = (DataType *)CkAlloc(sizeof(DataType)*ref->numEls) ;
            ref->next = NULL ;
            PrivateCall(AddRef(ref,refnum)) ;
        }

        /* Fill in the return information.  is_first is deliberately left
         * alone: resetting it would make collect overwrite the children's
         * partial results already accumulated in ref->y. */
        ref->r_type = R_BY_MESSAGE ;
        if (id == NULL) 
            ref->send_result_flag = 0 ;
        else {
            ref->send_result_flag = 1 ;
            ref->r_ep = ep ;
            ref->r_cid = *id ;
        }

        /* Package the local contribution; collect folds it into ref->y. */
        varSizes[0]=ref->numEls ;
        cMsg = (REDN_MSG_INTERNAL *)CkAllocMsg(REDN_MSG_INTERNAL, varSizes) ;
        cMsg->size = ref->numEls ;
        for(i=0;i<ref->numEls;i++)
            cMsg->data[i] = (DataType) x[i] ;
        SetRefNumber(cMsg, refnum) ;

        /* numToCollect is numChildren+1 (see init), so the local portion
         * normally goes to collect on this processor. */
        if(numToCollect == 0) 
            SendMsgBranch(collect,cMsg,myParent) ;
        else 
            SendMsgBranch(collect,cMsg,me) ;
    }

    /* Reduction operations: copy x into y when first is nonzero,
     * otherwise combine each x[i] into y[i]. */

    private void Rmax(x,y,n,first)
    DataType x[], y[] ;
    int n,first ;
    {
        int i ;
        if(first) 
            for(i=0 ; i<n ; i++)
                y[i] = x[i] ;
        else 
            for(i=0 ; i<n ; i++)
                if (x[i] > y[i]) y[i] = x[i] ;
    }

    private void Rmin(x,y,n,first)
    DataType x[], y[] ;
    int n,first ;
    {
        int i ;
        if(first) 
            for(i=0 ; i<n ; i++)
                y[i] = x[i] ;
        else 
            for(i=0 ; i<n ; i++) 
                if (x[i] < y[i]) y[i] = x[i] ;
    }

    private void Rsum(x,y,n,first)
    DataType x[], y[] ;
    int n,first ;
    {
        int i ;
        if(first) 
            for(i=0 ; i<n ; i++)
                y[i] = x[i] ;
        else 
            for(i=0 ; i<n ; i++)
                y[i] += x[i] ;
    }

    private void Rprod(x,y,n,first)
    DataType x[], y[] ;
    int n,first ;
    {
        int i ;
        if(first) 
            for(i=0 ; i<n ; i++)
                y[i] = x[i] ;
        else 
            for(i=0 ; i<n ; i++)
                y[i] *= x[i] ;
    }

    /* Count nonzero elements.
     * NOTE(review): every contribution -- including a child's already-counted
     * partial result -- is collapsed to 0/1 here, so counts above 1 coming up
     * from a subtree are lost on trees deeper than one level.  A fix needs the
     * local deposit converted to 0/1 before collect and a plain sum here.
     * Also note x and y are declared int rather than DataType. */
    private void Rcount(x,y,n,first)
    int x[], y[], n, first ;
    {  
        int i ;
        if(first) 
            for(i=0 ; i<n ; i++)
                y[i] = (x[i]==0) ? 0 : 1 ;
        else 
            for(i=0 ; i<n ; i++)
                y[i] += (x[i]==0) ? 0 : 1 ;
    }


    /********************************************************/
    /* Data structure access functions **********************/
    /********************************************************/

    /* Append newRef to refList.  A duplicate refnum is reported and newRef
     * is not linked in (callers look the refnum up with FindRef first). */
    private void AddRef(newRef, refnum)
    ReductionRefInstance *newRef ;
    int refnum ;
    {
        ReductionRefInstance *temp ;
        temp = refList ;
        if (temp == NULL)
            refList = newRef ;
        else {
            while ( (temp->next) && (temp->refnum != refnum) )
                temp=temp->next ;
            if (temp->refnum == refnum) 
                CkPrintf("ERROR in ModuleName.  Reference number (%d) used in overlapping reductions\n",refnum) ;
            else 
                temp->next = newRef ;
        }
    }

    /* Return the entry for refNum, or NULL if there is none. */
    private ReductionRefInstance *FindRef(refNum)
    int refNum ;
    {
        ReductionRefInstance *ref ;

        ref = refList ;
        while( (ref != NULL) && (ref->refnum != refNum) ) ref = ref->next ;
        return ref ;
    }

    /* Unlink the entry for refNum and free it together with its accumulator.
     * An empty list or unknown refNum is reported instead of dereferencing
     * NULL. */
    private void DeleteRef(refNum)
    int refNum ;
    {
        ReductionRefInstance *prev, *cur ;

        prev = NULL ;
        cur = refList ;
        while ( (cur != NULL) && (cur->refnum != refNum) ) {
            prev = cur ;
            cur = cur->next ;
        }
        if (cur == NULL) {
            CkPrintf("ERROR in ModuleName.  DeleteRef: unknown reference number %d\n",refNum) ;
            return ;
        }
        if (prev == NULL)
            refList = cur->next ;
        else
            prev->next = cur->next ;
        CkFree(cur->y) ;  /* allocated with CkAlloc in collect / f / f_msg */
        CkFree(cur) ;
    }
} /* end BOC */

    /********************************************************/
    /* Functions accessible from outside the module are below here */
    /********************************************************/
    
    /* Create a reduce BOC over the machine-wide spanning tree (gid == 0)
     * and return its BOC number. */
    Create()
    {
        REDUCE_INIT *msg;
        msg = (REDUCE_INIT *) CkAllocMsg(REDUCE_INIT);
        msg->gid = 0 ;
        /* Unused when gid == 0, but init copies it, so don't leave it
         * uninitialised. */
        msg->pgBoc = 0 ;
        return CreateBoc(ModuleName::reduce,ModuleName::reduce@init,msg);
    }
    
    /* Create a reduce BOC whose reductions run over processor group gid,
     * using the spanning tree kept by the PG BOC numbered pgBoc.  Returns
     * the new BOC number. */
    CreateOverGroup(gid, pgBoc)
    int          gid ;
    ChareNumType pgBoc ;
    {
        REDUCE_INIT *initMsg = (REDUCE_INIT *) CkAllocMsg(REDUCE_INIT);

        initMsg->gid = gid ;
        initMsg->pgBoc = pgBoc ;
        return CreateBoc(ModuleName::reduce,ModuleName::reduce@init,initMsg);
    }


    /* Contribute nelements values from x to reduction number ref of the
     * reduce BOC boc on this processor (forwards to reduce@f).  If id is
     * non-NULL it must point to a ChareNumType: the result is copied into z
     * and fptr is BranchCall'ed on that BOC when the reduction completes. */
    DepositData(boc,x,z,nelements,ref,fptr,id)
    ChareNumType boc ;
    DataType     x[] ;
    DataType     z[] ;
    int          nelements, ref ;
    void         (*fptr)() ;
    void         *id ;
    {
        BranchCall(boc,ModuleName::reduce@f(x,z,nelements,ref,fptr,id));
    }


    
    /* Contribute nelements values from x to reduction number ref of the
     * reduce BOC boc on this processor (forwards to reduce@f_msg).  If id is
     * non-NULL, the result arrives as a REDN_MSG at entry point ep of the
     * chare *id; with id == NULL nothing is sent back to this branch. */
    DepositDataMsg(boc,x,nelements,ref,ep,id)
    ChareNumType boc ;
    DataType     x[] ;
    int          nelements, ref ;
    EntryNumType ep ;
    ChareIDType  *id ;
    {
        BranchCall(boc,ModuleName::reduce@f_msg(x,nelements,ref,ep,id));
    }


    /* One of the functions below will be preprocessed into a "LocalReduction"
     * call, which will be visible from outside the module. */
    
    /* Local maximum: store the largest of x[0..n-1] in *y.
     * x[0] is always read, so callers should pass n >= 1. */
    void Lmax(x,y,n)
    DataType x[],*y ;
    int n ;
    {
        int k = 1 ;

        *y = x[0] ;
        while (k < n) {
            if (*y < x[k])
                *y = x[k] ;
            k++ ;
        }
    }

    /* Local minimum: store the smallest of x[0..n-1] in *y.
     * x[0] is always read, so callers should pass n >= 1. */
    void Lmin(x,y,n)
    DataType x[],*y ;
    int n ;
    {
        int k = 1 ;

        *y = x[0] ;
        while (k < n) {
            if (*y > x[k])
                *y = x[k] ;
            k++ ;
        }
    }

    /* Local sum: store x[0] + x[1] + ... + x[n-1] in *y, adding left to
     * right.  x[0] is always read, so callers should pass n >= 1. */
    void Lsum(x,y,n)
    DataType x[],*y ;
    int n ;
    {
        int k ;

        *y = x[0] ;
        for (k = 1 ; k < n ; ++k)
            *y = *y + x[k] ;
    }

    /* Local product: store x[0] * x[1] * ... * x[n-1] in *y, multiplying
     * left to right.  x[0] is always read, so callers should pass n >= 1. */
    void Lprod(x,y,n)
    DataType x[],*y ;
    int n ;
    {
        int k ;

        *y = x[0] ;
        for (k = 1 ; k < n ; ++k)
            *y = *y * x[k] ;
    }

    /* Local count: store in *y how many of x[0..n-1] are nonzero
     * (0 when n <= 0).  Previously *y was seeded with the raw value x[0],
     * so the first element's value, not its nonzero-ness, was counted. */
    void Lcount(x,y,n)
    DataType x[],*y ;
    int n ;
    {
        int i ;
        *y = 0 ;
        for(i=0;i<n;i++) 
            if (x[i] != 0)
                (*y) += 1 ;
    }

} /* end module ModuleName */

