import Utilities.*;
import Synchronization.*;

/**
 * One cell of the systolic array.
 *
 * Each instance holds a single coefficient and four channels, and runs in its
 * own thread, which the constructor starts.  On every step the cell takes a
 * value from its north channel, forwards that value unchanged to its south
 * channel, takes a partial sum from its east channel, adds
 * (coefficient * value) to it and sends the new partial sum to its west
 * channel.
 */
class Multiply extends MyObject implements Runnable {

   // how many steps run() performs before the thread finishes
   private final int n;
   // position of this cell; only used to build the name handed to MyObject
   private final int id, jd;
   // the coefficient this cell multiplies every incoming value by
   private final double a;
   // inputs: north (values) and east (partial sums)
   // outputs: south (values passed through) and west (updated partial sums)
   private final MessagePassing north, east, south, west;

   /**
    * Stores the arguments and immediately starts a thread executing run().
    *
    * @param n     number of steps to perform
    * @param id    first index of this cell, used only for naming
    * @param jd    second index of this cell, used only for naming
    * @param a     coefficient held by this cell
    * @param north channel values are received from
    * @param east  channel partial sums are received from
    * @param south channel received values are forwarded to
    * @param west  channel updated partial sums are sent to
    */
   public Multiply(int n, int id, int jd, double a, MessagePassing north,
         MessagePassing east, MessagePassing south, MessagePassing west) {
      super("Multiply n=" + n + " id=" + id + " jd=" + jd + " a=" + a);
      this.n = n;
      this.id = id;
      this.jd = jd;
      this.a = a;
      this.north = north;
      this.east = east;
      this.south = south;
      this.west = west;
      // must stay last: the new thread reads the fields assigned above
      new Thread(this).start();
   }

   /**
    * Performs n steps of receive-north, send-south, receive-east, send-west
    * and then returns, ending the thread.
    */
   public void run() {
      int remaining = n;
      while (remaining-- > 0) {
         // pass the incoming value straight on before touching the sum
         double operand = receiveDouble(north);
         send(south, operand);
         // add this cell's contribution to the partial sum travelling west
         double partial = receiveDouble(east);
         partial += a * operand;
         send(west, partial);
      }
   }
}

/**
 * Driver that multiplies an L x M matrix a by an M x N matrix b using an
 * L x M grid of Multiply threads connected by message passing channels.
 *
 * Columns of b enter along the top of the grid and travel south; zeros enter
 * along the east edge and travel west, picking up a[l][m] * b[m][n] in each
 * cell they pass through, so that the values leaving the west edge of row l
 * are the entries c[l][n] of the product.
 */
class SystolicMatrixMultiply extends MyObject {

   /**
    * Parses the options, reads both matrices from the remaining command line
    * arguments, runs the systolic array and prints a, b and the product c.
    *
    * Options: -U prints the usage string and exits with status 0;
    * -l, -m and -n give the dimensions L, M and N, each of which must be
    * supplied and at least 1.  Any other option prints the usage string and
    * exits with status 1.  The entries of a (row by row) and then b (row by
    * row) follow the options; go.tryArg is given 0.0 as the default for
    * entries that are not supplied.
    *
    * @param args command line options followed by the matrix entries
    */
   public static void main(String[] args) {

      // parse command line options, if any, to override defaults
      GetOpt go = new GetOpt(args, "Ul:m:n:");
      String usage = "Usage: -l L -m M -n N"
         + " a[l,m] l=0..L-1 m=0..M-1, b[m,n] m=0..M-1 n=0..N-1";
      go.optErr = false;
      int ch = -1;
      int L = 0, M = 0, N = 0;
      while ((ch = go.getopt()) != go.optEOF) {
         switch ((char) ch) {
            case 'U':
               System.out.println(usage);  System.exit(0);
               break;
            case 'l':
               L = go.processArg(go.optArgGet(), L);
               exitUnlessPositive("L", L);
               break;
            case 'm':
               M = go.processArg(go.optArgGet(), M);
               exitUnlessPositive("M", M);
               break;
            case 'n':
               N = go.processArg(go.optArgGet(), N);
               exitUnlessPositive("N", N);
               break;
            default:
               System.err.println(usage);  System.exit(1);
         }
      }
      // The dimensions start at 0 and the checks above only run for options
      // that were actually given, so an omitted -l, -m or -n would otherwise
      // slip through and produce empty matrices.  Reject that case here.
      if (L < 1 || M < 1 || N < 1) {
         System.err.println("SystolicMatrixMultiply,"
            + " -l, -m and -n are all required and must be >= 1");
         System.err.println(usage);
         System.exit(1);
      }
      System.out.println("MatrixMultiply: L=" + L + " M=" + M + " N=" + N);
      double[][] a = new double[L][M];
      double[][] b = new double[M][N];
      double[][] c = new double[L][N];
      // get the matrices to multiply from the command line
      int argNum = go.optIndexGet();
      for (int l = 0; l < L; l++) for (int m = 0; m < M; m++)
         a[l][m] = go.tryArg(argNum++, 0.0 /*default*/);
      for (int m = 0; m < M; m++) for (int n = 0; n < N; n++)
         b[m][n] = go.tryArg(argNum++, 0.0 /*default*/);

      // print out the matrices to be multiplied
      dumpMatrix("a", a);
      dumpMatrix("b", b);

      // create the communication channels
      // channelN[l][m] is the north input of cell (l,m) and the south output
      // of cell (l-1,m); channelW[l][m] is the west output of cell (l,m) and
      // the east input of cell (l,m-1).  Both grids are allocated one larger
      // than the cell grid in each direction; column M of channelN and row L
      // of channelW are created but never used below.
      MessagePassing[][] channelN  = new MessagePassing[L+1][M+1];
      MessagePassing[][] channelW  = new MessagePassing[L+1][M+1];
      for (int l = 0; l <= L; l++) for (int m = 0; m <= M; m++) {
         channelN[l][m] = new AsyncMessagePassing();
         channelW[l][m] = new AsyncMessagePassing();
      }

      // start a thread for each a[l,m] (the constructor starts the thread)
      for (int l = 0; l < L; l++) for (int m = 0; m < M; m++) {
         new Multiply(N, l, m, a[l][m],
         channelN[l][m], channelW[l][m+1], channelN[l+1][m], channelW[l][m]);
      }

      // send columns of b[][] into the source channels along the top
      // NOTE(review): every input is sent before anything is received, which
      // presumably relies on AsyncMessagePassing sends not blocking -- confirm
      for (int n = 0; n < N; n++) for (int m = 0; m < M; m++)
         send(channelN[0][m], b[m][n]);

      // send zeros into the east (right-hand) edge of the systolic array;
      // they are the initial partial sums that travel west across each row
      for (int n = 0; n < N; n++) for (int l = 0; l < L; l++)
         send(channelW[l][M], 0.0);

      // throw away the stuff coming out the bottom (was put in the top)
      for (int n = 0; n < N; n++) for (int m = 0; m < M; m++)
         receiveDouble(channelN[L][m]);

      // gather the results into c[l,n] from the west edge of each row
      for (int n = 0; n < N; n++) for (int l = 0; l < L; l++)
         c[l][n] = receiveDouble(channelW[l][0]);

      // print out the result of the matrix multiply
      dumpMatrix("c", c);
      System.out.println("age()=" + age() + " SystolicMatrixMultiply done");
      System.exit(0);
   }

   // Exits with status 1 and an error message if a dimension is below 1.
   private static void exitUnlessPositive(String name, int value) {
      if (value < 1) {
         System.err.println("SystolicMatrixMultiply, " + name + " < 1");
         System.exit(1);
      }
   }

   // Prints "<label> =" followed by one line per row, each entry preceded
   // by a single space.
   private static void dumpMatrix(String label, double[][] matrix) {
      System.out.println(label + " =");
      for (int row = 0; row < matrix.length; row++) {
         for (int col = 0; col < matrix[row].length; col++)
            System.out.print(" " + matrix[row][col]);
         System.out.println();
      }
   }
}

/* ............... Example compile and run(s)

D:\>javac mats.java

D:\>java SystolicMatrixMultiply \
   -l2 -m3 -n4 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
MatrixMultiply: L=2 M=3 N=4
a =
 1 2 3
 4 5 6
b =
 7 8 9 10
 11 12 13 14
 15 16 17 18
c =
 74 80 86 92
 173 188 203 218
age()=210 SystolicMatrixMultiply done
                                            ... end of example run(s)  */
