java實(shí)現(xiàn)任意矩陣Strassen算法
本例輸入為兩個(gè)任意尺寸的矩陣m * n, n * m,輸出為兩個(gè)矩陣的乘積。計(jì)算任意尺寸矩陣相乘時(shí),使用了Strassen算法。程序?yàn)樽跃帲?jīng)過(guò)測(cè)試,請(qǐng)放心使用?;舅惴ㄊ牵?br />
1.對(duì)于方陣(正方形矩陣),找到最大的l, 使得l = 2 ^ k, k為整數(shù)并且l < m。邊長(zhǎng)為l的方形矩陣則采用Strassen算法,其余部分以及方形矩陣中遺漏的部分用蠻力法。
2.對(duì)于非方陣,依照行列相應(yīng)添加0使其成為方陣。
StrassenMethodTest.java
package matrixalgorithm;
import java.util.Scanner;
public class StrassenMethodTest {
private StrassenMethod strassenMultiply;
StrassenMethodTest(){
strassenMultiply = new StrassenMethod();
}//end cons
public static void main(String[] args){
Scanner input = new Scanner(System.in);
System.out.println("Input row size of the first matrix: ");
int arow = input.nextInt();
System.out.println("Input column size of the first matrix: ");
int acol = input.nextInt();
System.out.println("Input row size of the second matrix: ");
int brow = input.nextInt();
System.out.println("Input column size of the second matrix: ");
int bcol = input.nextInt();
double[][] A = new double[arow][acol];
double[][] B = new double[brow][bcol];
double[][] C = new double[arow][bcol];
System.out.println("Input data for matrix A: ");
/*In all of the codes later in this project,
r means row while c means column.
*/
for (int r = 0; r < arow; r++) {
for (int c = 0; c < acol; c++) {
System.out.printf("Data of A[%d][%d]: ", r, c);
A[r][c] = input.nextDouble();
}//end inner loop
}//end loop
System.out.println("Input data for matrix B: ");
for (int r = 0; r < brow; r++) {
for (int c = 0; c < bcol; c++) {
System.out.printf("Data of A[%d][%d]: ", r, c);
B[r][c] = input.nextDouble();
}//end inner loop
}//end loop
StrassenMethodTest algorithm = new StrassenMethodTest();
C = algorithm.multiplyRectMatrix(A, B, arow, acol, brow, bcol);
//Display the calculation result:
System.out.println("Result from matrix C: ");
for (int r = 0; r < arow; r++) {
for (int c = 0; c < bcol; c++) {
System.out.printf("Data of C[%d][%d]: %f\n", r, c, C[r][c]);
}//end inner loop
}//end outter loop
}//end main
//Deal with matrices that are not square:
public double[][] multiplyRectMatrix(double[][] A, double[][] B,
int arow, int acol, int brow, int bcol) {
if (arow != bcol) //Invalid multiplicatio
return new double[][]{{0}};
double[][] C = new double[arow][bcol];
if (arow < acol) {
double[][] newA = new double[acol][acol];
double[][] newB = new double[brow][brow];
int n = acol;
for (int r = 0; r < acol; r++)
for (int c = 0; c < acol; c++)
newA[r][c] = 0.0;
for (int r = 0; r < brow; r++)
for (int c = 0; c < brow; c++)
newB[r][c] = 0.0;
for (int r = 0; r < arow; r++)
for (int c = 0; c < acol; c++)
newA[r][c] = A[r][c];
for (int r = 0; r < brow; r++)
for (int c = 0; c < bcol; c++)
newB[r][c] = B[r][c];
double[][] C2 = multiplySquareMatrix(newA, newB, n);
for(int r = 0; r < arow; r++)
for(int c = 0; c < bcol; c++)
C[r][c] = C2[r][c];
}//end if
else if(arow == acol)
C = multiplySquareMatrix(A, B, arow);
else {
int n = arow;
double[][] newA = new double[arow][arow];
double[][] newB = new double[bcol][bcol];
for (int r = 0; r < arow; r++)
for (int c = 0; c < arow; c++)
newA[r][c] = 0.0;
for (int r = 0; r < bcol; r++)
for (int c = 0; c < bcol; c++)
newB[r][c] = 0.0;
for (int r = 0; r < arow; r++)
for (int c = 0; c < acol; c++)
newA[r][c] = A[r][c];
for (int r = 0; r < brow; r++)
for (int c = 0; c < bcol; c++)
newB[r][c] = B[r][c];
double[][] C2 = multiplySquareMatrix(newA, newB, n);
for(int r = 0; r < arow; r++)
for(int c = 0; c < bcol; c++)
C[r][c] = C2[r][c];
}//end else
return C;
}//end method
//Deal with matrices that are square matrices.
public double[][] multiplySquareMatrix(double[][] A2, double[][] B2, int n){
double[][] C2 = new double[n][n];
for(int r = 0; r < n; r++)
for(int c = 0; c < n; c++)
C2[r][c] = 0;
if(n == 1){
C2[0][0] = A2[0][0] * B2[0][0];
return C2;
}//end if
int exp2k = 2;
while(exp2k <= (n / 2) ){
exp2k *= 2;
}//end loop
if(exp2k == n){
C2 = strassenMultiply.strassenMultiplyMatrix(A2, B2, n);
return C2;
}//end else
//The "biggest" strassen matrix:
double[][][] A = new double[6][exp2k][exp2k];
double[][][] B = new double[6][exp2k][exp2k];
double[][][] C = new double[6][exp2k][exp2k];
for(int r = 0; r < exp2k; r++){
for(int c = 0; c < exp2k; c++){
A[0][r][c] = A2[r][c];
B[0][r][c] = B2[r][c];
}//end inner loop
}//end outter loop
C[0] = strassenMultiply.strassenMultiplyMatrix(A[0], B[0], exp2k);
for(int r = 0; r < exp2k; r++)
for(int c = 0; c < exp2k; c++)
C2[r][c] = C[0][r][c];
int middle = exp2k / 2;
for(int r = 0; r < middle; r++){
for(int c = exp2k; c < n; c++){
A[1][r][c - exp2k] = A2[r][c];
B[3][r][c - exp2k] = B2[r][c];
}//end inner loop
}//end outter loop
for(int r = exp2k; r < n; r++){
for(int c = 0; c < middle; c++){
A[3][r - exp2k][c] = A2[r][c];
B[1][r - exp2k][c] = B2[r][c];
}//end inner loop
}//end outter loop
for(int r = middle; r < exp2k; r++){
for(int c = exp2k; c < n; c++){
A[2][r - middle][c - exp2k] = A2[r][c];
B[4][r - middle][c - exp2k] = B2[r][c];
}//end inner loop
}//end outter loop
for(int r = exp2k; r < n; r++){
for(int c = middle; c < n - exp2k + 1; c++){
A[4][r - exp2k][c - middle] = A2[r][c];
B[2][r - exp2k][c - middle] = B2[r][c];
}//end inner loop
}//end outter loop
for(int i = 1; i <= 4; i++)
C[i] = multiplyRectMatrix(A[i], B[i], middle, A[i].length, A[i].length, middle);
/*
Calculate the final results of grids in the "biggest 2^k square,
according to the rules of matrice multiplication.
*/
for (int row = 0; row < exp2k; row++) {
for (int col = 0; col < exp2k; col++) {
for (int k = exp2k; k < n; k++) {
C2[row][col] += A2[row][k] * B2[k][col];
}//end loop
}//end inner loop
}//end outter loop
//Use brute force to solve the rest, will be improved later:
for(int col = exp2k; col < n; col++){
for(int row = 0; row < n; row++){
for(int k = 0; k < n; k++)
C2[row][col] += A2[row][k] * B2[k][row];
}//end inner loop
}//end outter loop
for(int row = exp2k; row < n; row++){
for(int col = 0; col < exp2k; col++){
for(int k = 0; k < n; k++)
C2[row][col] += A2[row][k] * B2[k][row];
}//end inner loop
}//end outter loop
return C2;
}//end method
}//end class
StrassenMethod.java
package matrixalgorithm;
import java.util.Scanner;
public class StrassenMethod {
private double[][][][] A = new double[2][2][][];
private double[][][][] B = new double[2][2][][];
private double[][][][] C = new double[2][2][][];
/*//Codes for testing this class:
public static void main(String[] args) {
Scanner input = new Scanner(System.in);
System.out.println("Input size of the matrix: ");
int n = input.nextInt();
double[][] A = new double[n][n];
double[][] B = new double[n][n];
double[][] C = new double[n][n];
System.out.println("Input data for matrix A: ");
for (int r = 0; r < n; r++) {
for (int c = 0; c < n; c++) {
System.out.printf("Data of A[%d][%d]: ", r, c);
A[r][c] = input.nextDouble();
}//end inner loop
}//end loop
System.out.println("Input data for matrix B: ");
for (int r = 0; r < n; r++) {
for (int c = 0; c < n; c++) {
System.out.printf("Data of A[%d][%d]: ", r, c);
B[r][c] = input.nextDouble();
}//end inner loop
}//end loop
StrassenMethod algorithm = new StrassenMethod();
C = algorithm.strassenMultiplyMatrix(A, B, n);
System.out.println("Result from matrix C: ");
for (int r = 0; r < n; r++) {
for (int c = 0; c < n; c++) {
System.out.printf("Data of C[%d][%d]: %f\n", r, c, C[r][c]);
}//end inner loop
}//end outter loop
}//end main*/
public double[][] strassenMultiplyMatrix(double[][] A2, double B2[][], int n){
double[][] C2 = new double[n][n];
//Initialize the matrix:
for(int rowIndex = 0; rowIndex < n; rowIndex++)
for(int colIndex = 0; colIndex < n; colIndex++)
C2[rowIndex][colIndex] = 0.0;
if(n == 1)
C2[0][0] = A2[0][0] * B2[0][0];
//"Slice matrices into 2 * 2 parts:
else{
double[][][][] A = new double[2][2][n / 2][n / 2];
double[][][][] B = new double[2][2][n / 2][n / 2];
double[][][][] C = new double[2][2][n / 2][n / 2];
for(int r = 0; r < n / 2; r++){
for(int c = 0; c < n / 2; c++){
A[0][0][r][c] = A2[r][c];
A[0][1][r][c] = A2[r][n / 2 + c];
A[1][0][r][c] = A2[n / 2 + r][c];
A[1][1][r][c] = A2[n / 2 + r][n / 2 + c];
B[0][0][r][c] = B2[r][c];
B[0][1][r][c] = B2[r][n / 2 + c];
B[1][0][r][c] = B2[n / 2 + r][c];
B[1][1][r][c] = B2[n / 2 + r][n / 2 + c];
}//end loop
}//end loop
n = n / 2;
double[][][] S = new double[10][n][n];
S[0] = minusMatrix(B[0][1], B[1][1], n);
S[1] = addMatrix(A[0][0], A[0][1], n);
S[2] = addMatrix(A[1][0], A[1][1], n);
S[3] = minusMatrix(B[1][0], B[0][0], n);
S[4] = addMatrix(A[0][0], A[1][1], n);
S[5] = addMatrix(B[0][0], B[1][1], n);
S[6] = minusMatrix(A[0][1], A[1][1], n);
S[7] = addMatrix(B[1][0], B[1][1], n);
S[8] = minusMatrix(A[0][0], A[1][0], n);
S[9] = addMatrix(B[0][0], B[0][1], n);
double[][][] P = new double[7][n][n];
P[0] = strassenMultiplyMatrix(A[0][0], S[0], n);
P[1] = strassenMultiplyMatrix(S[1], B[1][1], n);
P[2] = strassenMultiplyMatrix(S[2], B[0][0], n);
P[3] = strassenMultiplyMatrix(A[1][1], S[3], n);
P[4] = strassenMultiplyMatrix(S[4], S[5], n);
P[5] = strassenMultiplyMatrix(S[6], S[7], n);
P[6] = strassenMultiplyMatrix(S[8], S[9], n);
C[0][0] = addMatrix(minusMatrix(addMatrix(P[4], P[3], n), P[1], n), P[5], n);
C[0][1] = addMatrix(P[0], P[1], n);
C[1][0] = addMatrix(P[2], P[3], n);
C[1][1] = minusMatrix(minusMatrix(addMatrix(P[4], P[0], n), P[2], n), P[6], n);
n *= 2;
for(int r = 0; r < n / 2; r++){
for(int c = 0; c < n / 2; c++){
C2[r][c] = C[0][0][r][c];
C2[r][n / 2 + c] = C[0][1][r][c];
C2[n / 2 + r][c] = C[1][0][r][c];
C2[n / 2 + r][n / 2 + c] = C[1][1][r][c];
}//end inner loop
}//end outter loop
}//end else
return C2;
}//end method
//Add two matrices according to matrix addition.
private double[][] addMatrix(double[][] A, double[][] B, int n){
double C[][] = new double[n][n];
for(int r = 0; r < n; r++)
for(int c = 0; c < n; c++)
C[r][c] = A[r][c] + B[r][c];
return C;
}//end method
//Substract two matrices according to matrix addition.
private double[][] minusMatrix(double[][] A, double[][] B, int n){
double C[][] = new double[n][n];
for(int r = 0; r < n; r++)
for(int c = 0; c < n; c++)
C[r][c] = A[r][c] - B[r][c];
return C;
}//end method
}//end class
希望本文所述對(duì)大家學(xué)習(xí)java程序設(shè)計(jì)有所幫助。
- java 二維數(shù)組矩陣乘法的實(shí)現(xiàn)方法
- Java矩陣連乘問(wèn)題(動(dòng)態(tài)規(guī)劃)算法實(shí)例分析
- Java實(shí)現(xiàn)的矩陣乘法示例
- Java實(shí)現(xiàn)的求逆矩陣算法示例
- Java實(shí)現(xiàn)輸出回環(huán)數(shù)(螺旋矩陣)的方法示例
- Java實(shí)現(xiàn)矩陣加減乘除及轉(zhuǎn)制等運(yùn)算功能示例
- Java實(shí)現(xiàn)的按照順時(shí)針或逆時(shí)針?lè)较蜉敵鲆粋€(gè)數(shù)字矩陣功能示例
- Java實(shí)現(xiàn)矩陣順時(shí)針旋轉(zhuǎn)90度的示例
- java實(shí)現(xiàn)的n*n矩陣求值及求逆矩陣算法示例
- 使用java寫(xiě)的矩陣乘法實(shí)例(Strassen算法)
相關(guān)文章
一篇文章告訴你如何在Java數(shù)組中插入一個(gè)字符
本篇文章主要介紹了Java數(shù)組中插入一個(gè)字符的相關(guān)方法,具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下,希望能夠給你帶來(lái)幫助2021-10-10
JAVA生產(chǎn)者消費(fèi)者(線(xiàn)程同步)代碼學(xué)習(xí)示例
這篇文章主要介紹了JAVA線(xiàn)程同步的代碼學(xué)習(xí)示例,大家參考使用吧2013-11-11
JavaWeb倉(cāng)庫(kù)管理系統(tǒng)詳解
這篇文章主要為大家詳細(xì)介紹了JavaWeb倉(cāng)庫(kù)管理系統(tǒng),文中示例代碼介紹的非常詳細(xì),具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下2021-09-09
springboot實(shí)現(xiàn)將自定義日志格式存儲(chǔ)到mongodb中
這篇文章主要介紹了springboot實(shí)現(xiàn)將自定義日志格式存儲(chǔ)到mongodb中的操作,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2021-07-07
JavaWeb頁(yè)面中防止點(diǎn)擊Backspace網(wǎng)頁(yè)后退情況
當(dāng)鍵盤(pán)敲下后退鍵(Backspace)后怎么防止網(wǎng)頁(yè)后退情況呢?今天小編通過(guò)本文給大家詳細(xì)介紹下,感興趣的朋友一起看看吧2016-11-11
mybatis中<if>標(biāo)簽bool值類(lèi)型為false判斷方法
這篇文章主要給大家介紹了關(guān)于mybatis中<if>標(biāo)簽bool值類(lèi)型為false判斷方法,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家學(xué)習(xí)或者使用mybatis具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2019-08-08
線(xiàn)程池FutureTask異步執(zhí)行多任務(wù)實(shí)現(xiàn)詳解
這篇文章主要為大家介紹了線(xiàn)程池FutureTask異步執(zhí)行多任務(wù)實(shí)現(xiàn)詳解,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪2023-11-11

