java實(shí)現(xiàn)任意矩陣Strassen算法
本例輸入為兩個(gè)任意尺寸的矩陣m * n, n * m,輸出為兩個(gè)矩陣的乘積。計(jì)算任意尺寸矩陣相乘時(shí),使用了Strassen算法。程序?yàn)樽跃?,?jīng)過(guò)測(cè)試,請(qǐng)放心使用?;舅惴ㄊ牵?br />
1.對(duì)于方陣(正方形矩陣),找到最大的l, 使得l = 2 ^ k, k為整數(shù)并且l < m。邊長(zhǎng)為l的方形矩陣則采用Strassen算法,其余部分以及方形矩陣中遺漏的部分用蠻力法。
2.對(duì)于非方陣,依照行列相應(yīng)添加0使其成為方陣。
StrassenMethodTest.java
package matrixalgorithm; import java.util.Scanner; public class StrassenMethodTest { private StrassenMethod strassenMultiply; StrassenMethodTest(){ strassenMultiply = new StrassenMethod(); }//end cons public static void main(String[] args){ Scanner input = new Scanner(System.in); System.out.println("Input row size of the first matrix: "); int arow = input.nextInt(); System.out.println("Input column size of the first matrix: "); int acol = input.nextInt(); System.out.println("Input row size of the second matrix: "); int brow = input.nextInt(); System.out.println("Input column size of the second matrix: "); int bcol = input.nextInt(); double[][] A = new double[arow][acol]; double[][] B = new double[brow][bcol]; double[][] C = new double[arow][bcol]; System.out.println("Input data for matrix A: "); /*In all of the codes later in this project, r means row while c means column. */ for (int r = 0; r < arow; r++) { for (int c = 0; c < acol; c++) { System.out.printf("Data of A[%d][%d]: ", r, c); A[r][c] = input.nextDouble(); }//end inner loop }//end loop System.out.println("Input data for matrix B: "); for (int r = 0; r < brow; r++) { for (int c = 0; c < bcol; c++) { System.out.printf("Data of A[%d][%d]: ", r, c); B[r][c] = input.nextDouble(); }//end inner loop }//end loop StrassenMethodTest algorithm = new StrassenMethodTest(); C = algorithm.multiplyRectMatrix(A, B, arow, acol, brow, bcol); //Display the calculation result: System.out.println("Result from matrix C: "); for (int r = 0; r < arow; r++) { for (int c = 0; c < bcol; c++) { System.out.printf("Data of C[%d][%d]: %f\n", r, c, C[r][c]); }//end inner loop }//end outter loop }//end main //Deal with matrices that are not square: public double[][] multiplyRectMatrix(double[][] A, double[][] B, int arow, int acol, int brow, int bcol) { if (arow != bcol) //Invalid multiplicatio return new double[][]{{0}}; double[][] C = new double[arow][bcol]; if (arow < acol) { double[][] newA = new double[acol][acol]; double[][] newB = new double[brow][brow]; int n = acol; for (int r = 0; r < acol; r++) for (int c = 0; c < acol; c++) newA[r][c] = 0.0; for (int r = 0; r < brow; r++) for (int c = 0; c < brow; c++) newB[r][c] = 0.0; for (int r = 0; r < arow; r++) for (int c = 0; c < acol; c++) newA[r][c] = A[r][c]; for (int r = 0; r < brow; r++) for (int c = 0; c < bcol; c++) newB[r][c] = B[r][c]; double[][] C2 = multiplySquareMatrix(newA, newB, n); for(int r = 0; r < arow; r++) for(int c = 0; c < bcol; c++) C[r][c] = C2[r][c]; }//end if else if(arow == acol) C = multiplySquareMatrix(A, B, arow); else { int n = arow; double[][] newA = new double[arow][arow]; double[][] newB = new double[bcol][bcol]; for (int r = 0; r < arow; r++) for (int c = 0; c < arow; c++) newA[r][c] = 0.0; for (int r = 0; r < bcol; r++) for (int c = 0; c < bcol; c++) newB[r][c] = 0.0; for (int r = 0; r < arow; r++) for (int c = 0; c < acol; c++) newA[r][c] = A[r][c]; for (int r = 0; r < brow; r++) for (int c = 0; c < bcol; c++) newB[r][c] = B[r][c]; double[][] C2 = multiplySquareMatrix(newA, newB, n); for(int r = 0; r < arow; r++) for(int c = 0; c < bcol; c++) C[r][c] = C2[r][c]; }//end else return C; }//end method //Deal with matrices that are square matrices. public double[][] multiplySquareMatrix(double[][] A2, double[][] B2, int n){ double[][] C2 = new double[n][n]; for(int r = 0; r < n; r++) for(int c = 0; c < n; c++) C2[r][c] = 0; if(n == 1){ C2[0][0] = A2[0][0] * B2[0][0]; return C2; }//end if int exp2k = 2; while(exp2k <= (n / 2) ){ exp2k *= 2; }//end loop if(exp2k == n){ C2 = strassenMultiply.strassenMultiplyMatrix(A2, B2, n); return C2; }//end else //The "biggest" strassen matrix: double[][][] A = new double[6][exp2k][exp2k]; double[][][] B = new double[6][exp2k][exp2k]; double[][][] C = new double[6][exp2k][exp2k]; for(int r = 0; r < exp2k; r++){ for(int c = 0; c < exp2k; c++){ A[0][r][c] = A2[r][c]; B[0][r][c] = B2[r][c]; }//end inner loop }//end outter loop C[0] = strassenMultiply.strassenMultiplyMatrix(A[0], B[0], exp2k); for(int r = 0; r < exp2k; r++) for(int c = 0; c < exp2k; c++) C2[r][c] = C[0][r][c]; int middle = exp2k / 2; for(int r = 0; r < middle; r++){ for(int c = exp2k; c < n; c++){ A[1][r][c - exp2k] = A2[r][c]; B[3][r][c - exp2k] = B2[r][c]; }//end inner loop }//end outter loop for(int r = exp2k; r < n; r++){ for(int c = 0; c < middle; c++){ A[3][r - exp2k][c] = A2[r][c]; B[1][r - exp2k][c] = B2[r][c]; }//end inner loop }//end outter loop for(int r = middle; r < exp2k; r++){ for(int c = exp2k; c < n; c++){ A[2][r - middle][c - exp2k] = A2[r][c]; B[4][r - middle][c - exp2k] = B2[r][c]; }//end inner loop }//end outter loop for(int r = exp2k; r < n; r++){ for(int c = middle; c < n - exp2k + 1; c++){ A[4][r - exp2k][c - middle] = A2[r][c]; B[2][r - exp2k][c - middle] = B2[r][c]; }//end inner loop }//end outter loop for(int i = 1; i <= 4; i++) C[i] = multiplyRectMatrix(A[i], B[i], middle, A[i].length, A[i].length, middle); /* Calculate the final results of grids in the "biggest 2^k square, according to the rules of matrice multiplication. */ for (int row = 0; row < exp2k; row++) { for (int col = 0; col < exp2k; col++) { for (int k = exp2k; k < n; k++) { C2[row][col] += A2[row][k] * B2[k][col]; }//end loop }//end inner loop }//end outter loop //Use brute force to solve the rest, will be improved later: for(int col = exp2k; col < n; col++){ for(int row = 0; row < n; row++){ for(int k = 0; k < n; k++) C2[row][col] += A2[row][k] * B2[k][row]; }//end inner loop }//end outter loop for(int row = exp2k; row < n; row++){ for(int col = 0; col < exp2k; col++){ for(int k = 0; k < n; k++) C2[row][col] += A2[row][k] * B2[k][row]; }//end inner loop }//end outter loop return C2; }//end method }//end class
StrassenMethod.java
package matrixalgorithm; import java.util.Scanner; public class StrassenMethod { private double[][][][] A = new double[2][2][][]; private double[][][][] B = new double[2][2][][]; private double[][][][] C = new double[2][2][][]; /*//Codes for testing this class: public static void main(String[] args) { Scanner input = new Scanner(System.in); System.out.println("Input size of the matrix: "); int n = input.nextInt(); double[][] A = new double[n][n]; double[][] B = new double[n][n]; double[][] C = new double[n][n]; System.out.println("Input data for matrix A: "); for (int r = 0; r < n; r++) { for (int c = 0; c < n; c++) { System.out.printf("Data of A[%d][%d]: ", r, c); A[r][c] = input.nextDouble(); }//end inner loop }//end loop System.out.println("Input data for matrix B: "); for (int r = 0; r < n; r++) { for (int c = 0; c < n; c++) { System.out.printf("Data of A[%d][%d]: ", r, c); B[r][c] = input.nextDouble(); }//end inner loop }//end loop StrassenMethod algorithm = new StrassenMethod(); C = algorithm.strassenMultiplyMatrix(A, B, n); System.out.println("Result from matrix C: "); for (int r = 0; r < n; r++) { for (int c = 0; c < n; c++) { System.out.printf("Data of C[%d][%d]: %f\n", r, c, C[r][c]); }//end inner loop }//end outter loop }//end main*/ public double[][] strassenMultiplyMatrix(double[][] A2, double B2[][], int n){ double[][] C2 = new double[n][n]; //Initialize the matrix: for(int rowIndex = 0; rowIndex < n; rowIndex++) for(int colIndex = 0; colIndex < n; colIndex++) C2[rowIndex][colIndex] = 0.0; if(n == 1) C2[0][0] = A2[0][0] * B2[0][0]; //"Slice matrices into 2 * 2 parts: else{ double[][][][] A = new double[2][2][n / 2][n / 2]; double[][][][] B = new double[2][2][n / 2][n / 2]; double[][][][] C = new double[2][2][n / 2][n / 2]; for(int r = 0; r < n / 2; r++){ for(int c = 0; c < n / 2; c++){ A[0][0][r][c] = A2[r][c]; A[0][1][r][c] = A2[r][n / 2 + c]; A[1][0][r][c] = A2[n / 2 + r][c]; A[1][1][r][c] = A2[n / 2 + r][n / 2 + c]; B[0][0][r][c] = B2[r][c]; B[0][1][r][c] = B2[r][n / 2 + c]; B[1][0][r][c] = B2[n / 2 + r][c]; B[1][1][r][c] = B2[n / 2 + r][n / 2 + c]; }//end loop }//end loop n = n / 2; double[][][] S = new double[10][n][n]; S[0] = minusMatrix(B[0][1], B[1][1], n); S[1] = addMatrix(A[0][0], A[0][1], n); S[2] = addMatrix(A[1][0], A[1][1], n); S[3] = minusMatrix(B[1][0], B[0][0], n); S[4] = addMatrix(A[0][0], A[1][1], n); S[5] = addMatrix(B[0][0], B[1][1], n); S[6] = minusMatrix(A[0][1], A[1][1], n); S[7] = addMatrix(B[1][0], B[1][1], n); S[8] = minusMatrix(A[0][0], A[1][0], n); S[9] = addMatrix(B[0][0], B[0][1], n); double[][][] P = new double[7][n][n]; P[0] = strassenMultiplyMatrix(A[0][0], S[0], n); P[1] = strassenMultiplyMatrix(S[1], B[1][1], n); P[2] = strassenMultiplyMatrix(S[2], B[0][0], n); P[3] = strassenMultiplyMatrix(A[1][1], S[3], n); P[4] = strassenMultiplyMatrix(S[4], S[5], n); P[5] = strassenMultiplyMatrix(S[6], S[7], n); P[6] = strassenMultiplyMatrix(S[8], S[9], n); C[0][0] = addMatrix(minusMatrix(addMatrix(P[4], P[3], n), P[1], n), P[5], n); C[0][1] = addMatrix(P[0], P[1], n); C[1][0] = addMatrix(P[2], P[3], n); C[1][1] = minusMatrix(minusMatrix(addMatrix(P[4], P[0], n), P[2], n), P[6], n); n *= 2; for(int r = 0; r < n / 2; r++){ for(int c = 0; c < n / 2; c++){ C2[r][c] = C[0][0][r][c]; C2[r][n / 2 + c] = C[0][1][r][c]; C2[n / 2 + r][c] = C[1][0][r][c]; C2[n / 2 + r][n / 2 + c] = C[1][1][r][c]; }//end inner loop }//end outter loop }//end else return C2; }//end method //Add two matrices according to matrix addition. private double[][] addMatrix(double[][] A, double[][] B, int n){ double C[][] = new double[n][n]; for(int r = 0; r < n; r++) for(int c = 0; c < n; c++) C[r][c] = A[r][c] + B[r][c]; return C; }//end method //Substract two matrices according to matrix addition. private double[][] minusMatrix(double[][] A, double[][] B, int n){ double C[][] = new double[n][n]; for(int r = 0; r < n; r++) for(int c = 0; c < n; c++) C[r][c] = A[r][c] - B[r][c]; return C; }//end method }//end class
希望本文所述對(duì)大家學(xué)習(xí)java程序設(shè)計(jì)有所幫助。
- java 二維數(shù)組矩陣乘法的實(shí)現(xiàn)方法
- Java矩陣連乘問(wèn)題(動(dòng)態(tài)規(guī)劃)算法實(shí)例分析
- Java實(shí)現(xiàn)的矩陣乘法示例
- Java實(shí)現(xiàn)的求逆矩陣算法示例
- Java實(shí)現(xiàn)輸出回環(huán)數(shù)(螺旋矩陣)的方法示例
- Java實(shí)現(xiàn)矩陣加減乘除及轉(zhuǎn)制等運(yùn)算功能示例
- Java實(shí)現(xiàn)的按照順時(shí)針或逆時(shí)針?lè)较蜉敵鲆粋€(gè)數(shù)字矩陣功能示例
- Java實(shí)現(xiàn)矩陣順時(shí)針旋轉(zhuǎn)90度的示例
- java實(shí)現(xiàn)的n*n矩陣求值及求逆矩陣算法示例
- 使用java寫(xiě)的矩陣乘法實(shí)例(Strassen算法)
相關(guān)文章
一篇文章告訴你如何在Java數(shù)組中插入一個(gè)字符
本篇文章主要介紹了Java數(shù)組中插入一個(gè)字符的相關(guān)方法,具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下,希望能夠給你帶來(lái)幫助2021-10-10JAVA生產(chǎn)者消費(fèi)者(線程同步)代碼學(xué)習(xí)示例
這篇文章主要介紹了JAVA線程同步的代碼學(xué)習(xí)示例,大家參考使用吧2013-11-11JavaWeb倉(cāng)庫(kù)管理系統(tǒng)詳解
這篇文章主要為大家詳細(xì)介紹了JavaWeb倉(cāng)庫(kù)管理系統(tǒng),文中示例代碼介紹的非常詳細(xì),具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下2021-09-09springboot實(shí)現(xiàn)將自定義日志格式存儲(chǔ)到mongodb中
這篇文章主要介紹了springboot實(shí)現(xiàn)將自定義日志格式存儲(chǔ)到mongodb中的操作,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2021-07-07JavaWeb頁(yè)面中防止點(diǎn)擊Backspace網(wǎng)頁(yè)后退情況
當(dāng)鍵盤(pán)敲下后退鍵(Backspace)后怎么防止網(wǎng)頁(yè)后退情況呢?今天小編通過(guò)本文給大家詳細(xì)介紹下,感興趣的朋友一起看看吧2016-11-11mybatis中<if>標(biāo)簽bool值類型為false判斷方法
這篇文章主要給大家介紹了關(guān)于mybatis中<if>標(biāo)簽bool值類型為false判斷方法,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家學(xué)習(xí)或者使用mybatis具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2019-08-08線程池FutureTask異步執(zhí)行多任務(wù)實(shí)現(xiàn)詳解
這篇文章主要為大家介紹了線程池FutureTask異步執(zhí)行多任務(wù)實(shí)現(xiàn)詳解,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪2023-11-11