استراسن الگوریتم

الگوریتم استراسن جهت ضرب ماتریس ها

یکی از راه های طراحی الگوریتم شکستن مسیله به زیرمسیله های کوچکتر می باشد، در کتاب های الگوریتم  این روش divide and conquer نامیده می شود. در این پست ابتدا به شرح الگوریتم استراسن و سپس پیاده سازی آن در جاوا خواهیم پرداخت.

در این لینک کد جاوا و سی پلاس پلاس الگوریتم استراسن آورده شده است.

  1. public static int[][] ikjAlgorithm(int[][] A, int[][] B) {
  2. int n = A.length;
  3.  
  4. // initialise C
  5. int[][] C = new int[n][n];
  6.  
  7. for (int i = 0; i < n; i++) {
  8. for (int k = 0; k < n; k++) {
  9. for (int j = 0; j < n; j++) {
  10. C[i][j] += A[i][k] * B[k][j];
  11. }
  12. }
  13. }
  14. return C;
  15. }
  16.  
  17. private static int[][] add(int[][] A, int[][] B) {
  18. int n = A.length;
  19. int[][] C = new int[n][n];
  20. for (int i = 0; i < n; i++) {
  21. for (int j = 0; j < n; j++) {
  22. C[i][j] = A[i][j] + B[i][j];
  23. }
  24. }
  25. return C;
  26. }
  27.  
  28. private static int[][] subtract(int[][] A, int[][] B) {
  29. int n = A.length;
  30. int[][] C = new int[n][n];
  31. for (int i = 0; i < n; i++) {
  32. for (int j = 0; j < n; j++) {
  33. C[i][j] = A[i][j] - B[i][j];
  34. }
  35. }
  36. return C;
  37. }
  38.  
  39. private static int nextPowerOfTwo(int n) {
  40. int log2 = (int) Math.ceil(Math.log(n) / Math.log(2));
  41. return (int) Math.pow(2, log2);
  42. }
  43.  
  44. public static int[][] strassen(ArrayList<arraylist> A,
  45. ArrayList<arraylist> B) {
  46. // Make the matrices bigger so that you can apply the strassen
  47. // algorithm recursively without having to deal with odd
  48. // matrix sizes
  49. int n = A.size();
  50. int m = nextPowerOfTwo(n);
  51. int[][] APrep = new int[m][m];
  52. int[][] BPrep = new int[m][m];
  53. for (int i = 0; i < n; i++) {
  54. for (int j = 0; j < n; j++) {
  55. APrep[i][j] = A.get(i).get(j);
  56. BPrep[i][j] = B.get(i).get(j);
  57. }
  58. }
  59.  
  60. int[][] CPrep = strassenR(APrep, BPrep);
  61. int[][] C = new int[n][n];
  62. for (int i = 0; i < n; i++) {
  63. for (int j = 0; j < n; j++) {
  64. C[i][j] = CPrep[i][j];
  65. }
  66. }
  67. return C;
  68. }
  69.  
  70. private static int[][] strassenR(int[][] A, int[][] B) {
  71. int n = A.length;
  72.  
  73. if (n <= LEAF_SIZE) {
  74. return ikjAlgorithm(A, B);
  75. } else {
  76. // initializing the new sub-matrices
  77. int newSize = n / 2;
  78. int[][] a11 = new int[newSize][newSize];
  79. int[][] a12 = new int[newSize][newSize];
  80. int[][] a21 = new int[newSize][newSize];
  81. int[][] a22 = new int[newSize][newSize];
  82.  
  83. int[][] b11 = new int[newSize][newSize];
  84. int[][] b12 = new int[newSize][newSize];
  85. int[][] b21 = new int[newSize][newSize];
  86. int[][] b22 = new int[newSize][newSize];
  87.  
  88. int[][] aResult = new int[newSize][newSize];
  89. int[][] bResult = new int[newSize][newSize];
  90.  
  91. // dividing the matrices in 4 sub-matrices:
  92. for (int i = 0; i < newSize; i++) {
  93. for (int j = 0; j < newSize; j++) {
  94. a11[i][j] = A[i][j]; // top left
  95. a12[i][j] = A[i][j + newSize]; // top right
  96. a21[i][j] = A[i + newSize][j]; // bottom left
  97. a22[i][j] = A[i + newSize][j + newSize]; // bottom right
  98.  
  99. b11[i][j] = B[i][j]; // top left
  100. b12[i][j] = B[i][j + newSize]; // top right
  101. b21[i][j] = B[i + newSize][j]; // bottom left
  102. b22[i][j] = B[i + newSize][j + newSize]; // bottom right
  103. }
  104. }
  105.  
  106. // Calculating p1 to p7:
  107. aResult = add(a11, a22);
  108. bResult = add(b11, b22);
  109. int[][] p1 = strassenR(aResult, bResult);
  110. // p1 = (a11+a22) * (b11+b22)
  111.  
  112. aResult = add(a21, a22); // a21 + a22
  113. int[][] p2 = strassenR(aResult, b11); // p2 = (a21+a22) * (b11)
  114.  
  115. bResult = subtract(b12, b22); // b12 - b22
  116. int[][] p3 = strassenR(a11, bResult);
  117. // p3 = (a11) * (b12 - b22)
  118.  
  119. bResult = subtract(b21, b11); // b21 - b11
  120. int[][] p4 = strassenR(a22, bResult);
  121. // p4 = (a22) * (b21 - b11)
  122.  
  123. aResult = add(a11, a12); // a11 + a12
  124. int[][] p5 = strassenR(aResult, b22);
  125. // p5 = (a11+a12) * (b22)
  126.  
  127. aResult = subtract(a21, a11); // a21 - a11
  128. bResult = add(b11, b12); // b11 + b12
  129. int[][] p6 = strassenR(aResult, bResult);
  130. // p6 = (a21-a11) * (b11+b12)
  131.  
  132. aResult = subtract(a12, a22); // a12 - a22
  133. bResult = add(b21, b22); // b21 + b22
  134. int[][] p7 = strassenR(aResult, bResult);
  135. // p7 = (a12-a22) * (b21+b22)
  136.  
  137. // calculating c21, c21, c11 e c22:
  138. int[][] c12 = add(p3, p5); // c12 = p3 + p5
  139. int[][] c21 = add(p2, p4); // c21 = p2 + p4
  140.  
  141. aResult = add(p1, p4); // p1 + p4
  142. bResult = add(aResult, p7); // p1 + p4 + p7
  143. int[][] c11 = subtract(bResult, p5);
  144. // c11 = p1 + p4 - p5 + p7
  145.  
  146. aResult = add(p1, p3); // p1 + p3
  147. bResult = add(aResult, p6); // p1 + p3 + p6
  148. int[][] c22 = subtract(bResult, p2);
  149. // c22 = p1 + p3 - p2 + p6
  150.  
  151. // Grouping the results obtained in a single matrix:
  152. int[][] C = new int[n][n];
  153. for (int i = 0; i < newSize; i++) {
  154. for (int j = 0; j < newSize; j++) {
  155. C[i][j] = c11[i][j];
  156. C[i][j + newSize] = c12[i][j];
  157. C[i + newSize][j] = c21[i][j];
  158. C[i + newSize][j + newSize] = c22[i][j];
  159. }
  160. }
  161. return C;
  162. }

Sharing is caring!

پاسخ دهید

نشانی ایمیل شما منتشر نخواهد شد. بخش‌های موردنیاز علامت‌گذاری شده‌اند *