首页 > 编程 > Java > 正文

java实现任意矩阵Strassen算法

2019-11-26 14:36:02
字体:
来源:转载
供稿:网友

本例输入为两个任意尺寸的矩阵m * n, n * m,输出为两个矩阵的乘积。计算任意尺寸矩阵相乘时,使用了Strassen算法。程序为自编,经过测试,请放心使用。基本算法是:
1.对于方阵(正方形矩阵,设边长为m),找到最大的l, 使得l = 2 ^ k, k为整数并且l ≤ m。当l = m时整个矩阵直接采用Strassen算法;否则左上角边长为l的子方阵采用Strassen算法,其余部分以及该子方阵中遗漏的乘积项用蛮力法补齐。
2.对于非方阵,依照行列相应添加0使其成为方阵。
StrassenMethodTest.java

package matrixalgorithm; import java.util.Scanner; public class StrassenMethodTest {   private StrassenMethod strassenMultiply;      StrassenMethodTest(){    strassenMultiply = new StrassenMethod();  }//end cons    public static void main(String[] args){    Scanner input = new Scanner(System.in);    System.out.println("Input row size of the first matrix: ");    int arow = input.nextInt();    System.out.println("Input column size of the first matrix: ");    int acol = input.nextInt();    System.out.println("Input row size of the second matrix: ");    int brow = input.nextInt();    System.out.println("Input column size of the second matrix: ");    int bcol = input.nextInt();     double[][] A = new double[arow][acol];    double[][] B = new double[brow][bcol];    double[][] C = new double[arow][bcol];    System.out.println("Input data for matrix A: ");         /*In all of the codes later in this project,    r means row while c means column.    */    for (int r = 0; r < arow; r++) {      for (int c = 0; c < acol; c++) {        System.out.printf("Data of A[%d][%d]: ", r, c);        A[r][c] = input.nextDouble();      }//end inner loop    }//end loop     System.out.println("Input data for matrix B: ");    for (int r = 0; r < brow; r++) {      for (int c = 0; c < bcol; c++) {        System.out.printf("Data of A[%d][%d]: ", r, c);        B[r][c] = input.nextDouble();      }//end inner loop    }//end loop     StrassenMethodTest algorithm = new StrassenMethodTest();    C = algorithm.multiplyRectMatrix(A, B, arow, acol, brow, bcol);     //Display the calculation result:    System.out.println("Result from matrix C: ");    for (int r = 0; r < arow; r++) {      for (int c = 0; c < bcol; c++) {        System.out.printf("Data of C[%d][%d]: %f/n", r, c, C[r][c]);      }//end inner loop    }//end outter loop   }//end main     //Deal with matrices that are not square:  public double[][] multiplyRectMatrix(double[][] A, double[][] B,      int arow, int acol, int brow, int bcol) {    if 
(arow != bcol) //Invalid multiplicatio      return new double[][]{{0}};        double[][] C = new double[arow][bcol];     if (arow < acol) {             double[][] newA = new double[acol][acol];      double[][] newB = new double[brow][brow];       int n = acol;             for (int r = 0; r < acol; r++)         for (int c = 0; c < acol; c++)           newA[r][c] = 0.0;       for (int r = 0; r < brow; r++)         for (int c = 0; c < brow; c++)           newB[r][c] = 0.0;       for (int r = 0; r < arow; r++)         for (int c = 0; c < acol; c++)           newA[r][c] = A[r][c];       for (int r = 0; r < brow; r++)         for (int c = 0; c < bcol; c++)           newB[r][c] = B[r][c];             double[][] C2 = multiplySquareMatrix(newA, newB, n);      for(int r = 0; r < arow; r++)        for(int c = 0; c < bcol; c++)          C[r][c] = C2[r][c];    }//end if         else if(arow == acol)      C = multiplySquareMatrix(A, B, arow);           else {      int n = arow;      double[][] newA = new double[arow][arow];      double[][] newB = new double[bcol][bcol];       for (int r = 0; r < arow; r++)         for (int c = 0; c < arow; c++)           newA[r][c] = 0.0;       for (int r = 0; r < bcol; r++)         for (int c = 0; c < bcol; c++)           newB[r][c] = 0.0;        for (int r = 0; r < arow; r++)         for (int c = 0; c < acol; c++)           newA[r][c] = A[r][c];       for (int r = 0; r < brow; r++)        for (int c = 0; c < bcol; c++)           newB[r][c] = B[r][c];       double[][] C2 = multiplySquareMatrix(newA, newB, n);      for(int r = 0; r < arow; r++)        for(int c = 0; c < bcol; c++)          C[r][c] = C2[r][c];    }//end else            return C;   }//end method     //Deal with matrices that are square matrices.    
public double[][] multiplySquareMatrix(double[][] A2, double[][] B2, int n){           double[][] C2 = new double[n][n];         for(int r = 0; r < n; r++)       for(int c = 0; c < n; c++)         C2[r][c] = 0;          if(n == 1){      C2[0][0] = A2[0][0] * B2[0][0];      return C2;     }//end if               int exp2k = 2;          while(exp2k <= (n / 2) ){       exp2k *= 2;     }//end loop          if(exp2k == n){       C2 = strassenMultiply.strassenMultiplyMatrix(A2, B2, n);       return C2;     }//end else          //The "biggest" strassen matrix:     double[][][] A = new double[6][exp2k][exp2k];     double[][][] B = new double[6][exp2k][exp2k];     double[][][] C = new double[6][exp2k][exp2k];          for(int r = 0; r < exp2k; r++){       for(int c = 0; c < exp2k; c++){         A[0][r][c] = A2[r][c];         B[0][r][c] = B2[r][c];       }//end inner loop     }//end outter loop         C[0] = strassenMultiply.strassenMultiplyMatrix(A[0], B[0], exp2k);         for(int r = 0; r < exp2k; r++)      for(int c = 0; c < exp2k; c++)        C2[r][c] = C[0][r][c];         int middle = exp2k / 2;         for(int r = 0; r < middle; r++){      for(int c = exp2k; c < n; c++){        A[1][r][c - exp2k] = A2[r][c];        B[3][r][c - exp2k] = B2[r][c];      }//end inner loop         }//end outter loop         for(int r = exp2k; r < n; r++){      for(int c = 0; c < middle; c++){        A[3][r - exp2k][c] = A2[r][c];        B[1][r - exp2k][c] = B2[r][c];      }//end inner loop         }//end outter loop         for(int r = middle; r < exp2k; r++){      for(int c = exp2k; c < n; c++){        A[2][r - middle][c - exp2k] = A2[r][c];        B[4][r - middle][c - exp2k] = B2[r][c];      }//end inner loop         }//end outter loop         for(int r = exp2k; r < n; r++){      for(int c = middle; c < n - exp2k + 1; c++){        A[4][r - exp2k][c - middle] = A2[r][c];        B[2][r - exp2k][c - middle] = B2[r][c];           }//end inner loop         }//end outter loop        for(int i 
= 1; i <= 4; i++)      C[i] = multiplyRectMatrix(A[i], B[i], middle, A[i].length, A[i].length, middle);         /*    Calculate the final results of grids in the "biggest 2^k square,    according to the rules of matrice multiplication.    */    for (int row = 0; row < exp2k; row++) {       for (int col = 0; col < exp2k; col++) {         for (int k = exp2k; k < n; k++) {           C2[row][col] += A2[row][k] * B2[k][col];         }//end loop       }//end inner loop     }//end outter loop         //Use brute force to solve the rest, will be improved later:    for(int col = exp2k; col < n; col++){      for(int row = 0; row < n; row++){        for(int k = 0; k < n; k++)          C2[row][col] += A2[row][k] * B2[k][row];      }//end inner loop    }//end outter loop         for(int row = exp2k; row < n; row++){      for(int col = 0; col < exp2k; col++){        for(int k = 0; k < n; k++)          C2[row][col] += A2[row][k] * B2[k][row];      }//end inner loop    }//end outter loop            return C2;   }//end method   }//end class

StrassenMethod.java

package matrixalgorithm; import java.util.Scanner; public class StrassenMethod {   private double[][][][] A = new double[2][2][][];  private double[][][][] B = new double[2][2][][];  private double[][][][] C = new double[2][2][][];   /*//Codes for testing this class:    public static void main(String[] args) {    Scanner input = new Scanner(System.in);    System.out.println("Input size of the matrix: ");    int n = input.nextInt();     double[][] A = new double[n][n];    double[][] B = new double[n][n];    double[][] C = new double[n][n];    System.out.println("Input data for matrix A: ");    for (int r = 0; r < n; r++) {      for (int c = 0; c < n; c++) {        System.out.printf("Data of A[%d][%d]: ", r, c);        A[r][c] = input.nextDouble();      }//end inner loop    }//end loop     System.out.println("Input data for matrix B: ");    for (int r = 0; r < n; r++) {      for (int c = 0; c < n; c++) {        System.out.printf("Data of A[%d][%d]: ", r, c);        B[r][c] = input.nextDouble();      }//end inner loop    }//end loop     StrassenMethod algorithm = new StrassenMethod();    C = algorithm.strassenMultiplyMatrix(A, B, n);     System.out.println("Result from matrix C: ");    for (int r = 0; r < n; r++) {      for (int c = 0; c < n; c++) {        System.out.printf("Data of C[%d][%d]: %f/n", r, c, C[r][c]);      }//end inner loop    }//end outter loop   }//end main*/      public double[][] strassenMultiplyMatrix(double[][] A2, double B2[][], int n){    double[][] C2 = new double[n][n];    //Initialize the matrix:    for(int rowIndex = 0; rowIndex < n; rowIndex++)      for(int colIndex = 0; colIndex < n; colIndex++)        C2[rowIndex][colIndex] = 0.0;     if(n == 1)      C2[0][0] = A2[0][0] * B2[0][0];    //"Slice matrices into 2 * 2 parts:     else{      double[][][][] A = new double[2][2][n / 2][n / 2];      double[][][][] B = new double[2][2][n / 2][n / 2];      double[][][][] C = new double[2][2][n / 2][n / 2];             for(int r = 0; r < n / 2; r++){  
      for(int c = 0; c < n / 2; c++){                    A[0][0][r][c] = A2[r][c];          A[0][1][r][c] = A2[r][n / 2 + c];          A[1][0][r][c] = A2[n / 2 + r][c];          A[1][1][r][c] = A2[n / 2 + r][n / 2 + c];                     B[0][0][r][c] = B2[r][c];          B[0][1][r][c] = B2[r][n / 2 + c];          B[1][0][r][c] = B2[n / 2 + r][c];          B[1][1][r][c] = B2[n / 2 + r][n / 2 + c];        }//end loop      }//end loop             n = n / 2;             double[][][] S = new double[10][n][n];      S[0] = minusMatrix(B[0][1], B[1][1], n);      S[1] = addMatrix(A[0][0], A[0][1], n);      S[2] = addMatrix(A[1][0], A[1][1], n);      S[3] = minusMatrix(B[1][0], B[0][0], n);      S[4] = addMatrix(A[0][0], A[1][1], n);      S[5] = addMatrix(B[0][0], B[1][1], n);      S[6] = minusMatrix(A[0][1], A[1][1], n);      S[7] = addMatrix(B[1][0], B[1][1], n);      S[8] = minusMatrix(A[0][0], A[1][0], n);      S[9] = addMatrix(B[0][0], B[0][1], n);             double[][][] P = new double[7][n][n];      P[0] = strassenMultiplyMatrix(A[0][0], S[0], n);      P[1] = strassenMultiplyMatrix(S[1], B[1][1], n);      P[2] = strassenMultiplyMatrix(S[2], B[0][0], n);      P[3] = strassenMultiplyMatrix(A[1][1], S[3], n);      P[4] = strassenMultiplyMatrix(S[4], S[5], n);      P[5] = strassenMultiplyMatrix(S[6], S[7], n);      P[6] = strassenMultiplyMatrix(S[8], S[9], n);             C[0][0] = addMatrix(minusMatrix(addMatrix(P[4], P[3], n), P[1], n), P[5], n);      C[0][1] = addMatrix(P[0], P[1], n);      C[1][0] = addMatrix(P[2], P[3], n);      C[1][1] = minusMatrix(minusMatrix(addMatrix(P[4], P[0], n), P[2], n), P[6], n);             n *= 2;              for(int r = 0; r < n / 2; r++){        for(int c = 0; c < n / 2; c++){          C2[r][c] = C[0][0][r][c];          C2[r][n / 2 + c] = C[0][1][r][c];          C2[n / 2 + r][c] = C[1][0][r][c];          C2[n / 2 + r][n / 2 + c] = C[1][1][r][c];        }//end inner loop      }//end outter loop    }//end else          return C2;  
}//end method      //Add two matrices according to matrix addition.   private double[][] addMatrix(double[][] A, double[][] B, int n){    double C[][] = new double[n][n];         for(int r = 0; r < n; r++)      for(int c = 0; c < n; c++)        C[r][c] = A[r][c] + B[r][c];         return C;  }//end method       //Substract two matrices according to matrix addition.   private double[][] minusMatrix(double[][] A, double[][] B, int n){    double C[][] = new double[n][n];         for(int r = 0; r < n; r++)      for(int c = 0; c < n; c++)        C[r][c] = A[r][c] - B[r][c];         return C;  }//end method   }//end class

希望本文所述对大家学习java程序设计有所帮助。

发表评论 共有条评论
用户名: 密码:
验证码: 匿名发表