001 /* 002 * Copyright (C) 2008-2010 by Holger Arndt, Frode Carlsen 003 * 004 * This file is part of the Universal Java Matrix Package (UJMP). 005 * See the NOTICE file distributed with this work for additional 006 * information regarding copyright ownership and licensing. 007 * 008 * UJMP is free software; you can redistribute it and/or modify 009 * it under the terms of the GNU Lesser General Public License as 010 * published by the Free Software Foundation; either version 2 011 * of the License, or (at your option) any later version. 012 * 013 * UJMP is distributed in the hope that it will be useful, 014 * but WITHOUT ANY WARRANTY; without even the implied warranty of 015 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 016 * GNU Lesser General Public License for more details. 017 * 018 * You should have received a copy of the GNU Lesser General Public 019 * License along with UJMP; if not, write to the 020 * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, 021 * Boston, MA 02110-1301 USA 022 */ 023 024 package org.ujmp.core.calculation; 025 026 import static org.ujmp.core.util.VerifyUtil.assertTrue; 027 028 import java.util.Arrays; 029 import java.util.LinkedList; 030 import java.util.List; 031 import java.util.concurrent.Callable; 032 import java.util.concurrent.ExecutionException; 033 import java.util.concurrent.Future; 034 035 import org.ujmp.core.Matrix; 036 import org.ujmp.core.doublematrix.DenseDoubleMatrix2D; 037 import org.ujmp.core.doublematrix.impl.BlockDenseDoubleMatrix2D; 038 import org.ujmp.core.doublematrix.impl.BlockMatrixLayout; 039 import org.ujmp.core.doublematrix.impl.BlockMultiply; 040 import org.ujmp.core.doublematrix.impl.BlockMatrixLayout.BlockOrder; 041 import org.ujmp.core.interfaces.HasColumnMajorDoubleArray1D; 042 import org.ujmp.core.interfaces.HasRowMajorDoubleArray2D; 043 import org.ujmp.core.matrix.DenseMatrix; 044 import org.ujmp.core.matrix.DenseMatrix2D; 045 import org.ujmp.core.matrix.SparseMatrix; 046 import org.ujmp.core.util.AbstractPlugin; 047 import org.ujmp.core.util.UJMPSettings; 048 import org.ujmp.core.util.VerifyUtil; 049 import org.ujmp.core.util.concurrent.PFor; 050 import org.ujmp.core.util.concurrent.UJMPThreadPoolExecutor; 051 052 public class Mtimes { 053 public static int THRESHOLD = 100; 054 055 public static final MtimesCalculation<Matrix, Matrix, Matrix> MATRIX = new MtimesMatrix(); 056 057 public static final MtimesCalculation<DenseMatrix, DenseMatrix, DenseMatrix> DENSEMATRIX = new MtimesDenseMatrix(); 058 059 public static final MtimesCalculation<DenseMatrix2D, DenseMatrix2D, DenseMatrix2D> DENSEMATRIX2D = new MtimesDenseMatrix2D(); 060 061 public static final MtimesCalculation<DenseDoubleMatrix2D, DenseDoubleMatrix2D, DenseDoubleMatrix2D> DENSEDOUBLEMATRIX2D = new MtimesDenseDoubleMatrix2D(); 062 063 public static final MtimesCalculation<SparseMatrix, Matrix, Matrix> SPARSEMATRIX1 = new MtimesSparseMatrix1(); 064 065 public static final MtimesCalculation<Matrix, SparseMatrix, Matrix> SPARSEMATRIX2 = new MtimesSparseMatrix2(); 066 067 public static MtimesCalculation<Matrix, Matrix, Matrix> MTIMES_JBLAS = null; 068 069 public static final boolean RESET_BLOCK_ORDER = false; 070 071 static { 072 init(); 073 } 074 075 @SuppressWarnings("unchecked") 076 public static void init() { 077 try { 078 AbstractPlugin p = (AbstractPlugin) Class.forName("org.ujmp.jblas.Plugin") 079 .newInstance(); 080 if (p.isAvailable()) { 081 MTIMES_JBLAS = (MtimesCalculation<Matrix, Matrix, Matrix>) Class.forName( 082 "org.ujmp.jblas.calculation.Mtimes").newInstance(); 083 } 084 } catch (Throwable t) { 085 } 086 } 087 088 } 089 090 class MtimesMatrix implements MtimesCalculation<Matrix, Matrix, Matrix> { 091 092 public final void calc(final Matrix source1, final Matrix source2, final Matrix target) { 093 if (source1 instanceof DenseDoubleMatrix2D && source2 instanceof DenseDoubleMatrix2D 094 && target instanceof DenseDoubleMatrix2D) { 095 Mtimes.DENSEDOUBLEMATRIX2D.calc((DenseDoubleMatrix2D) source1, 096 (DenseDoubleMatrix2D) source2, (DenseDoubleMatrix2D) target); 097 } else if (source1 instanceof DenseMatrix2D && source2 instanceof DenseMatrix2D 098 && target instanceof DenseMatrix2D) { 099 Mtimes.DENSEMATRIX2D.calc((DenseMatrix2D) source1, (DenseMatrix2D) source2, 100 (DenseMatrix2D) target); 101 } else if (source1 instanceof DenseMatrix && source2 instanceof DenseMatrix 102 && target instanceof DenseMatrix) { 103 Mtimes.DENSEMATRIX.calc((DenseMatrix) source1, (DenseMatrix) source2, 104 (DenseMatrix) target); 105 } else if (source1 instanceof SparseMatrix) { 106 Mtimes.SPARSEMATRIX1.calc((SparseMatrix) source1, source2, target); 107 } else if (source2 instanceof SparseMatrix) { 108 Mtimes.SPARSEMATRIX2.calc(source1, (SparseMatrix) source2, target); 109 } else { 110 gemm(source1, source2, target); 111 } 112 } 113 114 private final void gemm(final Matrix A, final Matrix B, final Matrix C) { 115 VerifyUtil.assert2D(A); 116 VerifyUtil.assert2D(B); 117 VerifyUtil.assert2D(C); 118 final int m1RowCount = (int) A.getRowCount(); 119 final int m1ColumnCount = (int) A.getColumnCount(); 120 final int m2RowCount = (int) B.getRowCount(); 121 final int m2ColumnCount = (int) B.getColumnCount(); 122 VerifyUtil.assertEquals(m1ColumnCount, m2RowCount, "matrices have wrong sizes"); 123 VerifyUtil.assertEquals(m1RowCount, C.getRowCount(), "matrices have wrong sizes"); 124 VerifyUtil.assertEquals(m2ColumnCount, C.getColumnCount(), "matrices have wrong sizes"); 125 126 if (m1RowCount >= Mtimes.THRESHOLD && m1ColumnCount >= Mtimes.THRESHOLD 127 && m2ColumnCount >= Mtimes.THRESHOLD) { 128 new PFor(0, m2ColumnCount - 1) { 129 130 @Override 131 public void step(int i) { 132 for (int irow = 0; irow < m1RowCount; ++irow) { 133 C.setAsDouble(0.0d, irow, i); 134 } 135 for (int lcol = 0; lcol < m1ColumnCount; ++lcol) { 136 final double temp = B.getAsDouble(lcol, i); 137 if (temp != 0.0d) { 138 for (int irow = 0; irow < m1RowCount; ++irow) { 139 C.setAsDouble(C.getAsDouble(irow, i) + A.getAsDouble(irow, lcol) 140 * temp, irow, i); 141 } 142 } 143 } 144 } 145 }; 146 } else { 147 for (int i = 0; i < m2ColumnCount; i++) { 148 for (int irow = 0; irow < m1RowCount; ++irow) { 149 C.setAsDouble(0.0d, irow, i); 150 } 151 for (int lcol = 0; lcol < m1ColumnCount; ++lcol) { 152 final double temp = B.getAsDouble(lcol, i); 153 if (temp != 0.0d) { 154 for (int irow = 0; irow < m1RowCount; ++irow) { 155 C.setAsDouble( 156 C.getAsDouble(irow, i) + A.getAsDouble(irow, lcol) * temp, 157 irow, i); 158 } 159 } 160 } 161 } 162 } 163 } 164 }; 165 166 class MtimesDenseMatrix implements MtimesCalculation<DenseMatrix, DenseMatrix, DenseMatrix> { 167 168 public final void calc(final DenseMatrix source1, final DenseMatrix source2, 169 final DenseMatrix target) { 170 if (source1 instanceof DenseMatrix2D && source2 instanceof DenseMatrix2D 171 && target instanceof DenseMatrix2D) { 172 Mtimes.DENSEMATRIX2D.calc((DenseMatrix2D) source1, (DenseMatrix2D) source2, 173 (DenseMatrix2D) target); 174 } else { 175 gemm(source1, source2, target); 176 } 177 } 178 179 private final void gemm(final DenseMatrix A, final DenseMatrix B, final DenseMatrix C) { 180 VerifyUtil.assert2D(A); 181 VerifyUtil.assert2D(B); 182 VerifyUtil.assert2D(C); 183 final int m1RowCount = (int) A.getRowCount(); 184 final int m1ColumnCount = (int) A.getColumnCount(); 185 final int m2RowCount = (int) B.getRowCount(); 186 final int m2ColumnCount = (int) B.getColumnCount(); 187 VerifyUtil.assertEquals(m1ColumnCount, m2RowCount, "matrices have wrong sizes"); 188 VerifyUtil.assertEquals(m1RowCount, C.getRowCount(), "matrices have wrong sizes"); 189 VerifyUtil.assertEquals(m2ColumnCount, C.getColumnCount(), "matrices have wrong sizes"); 190 191 if (m1RowCount >= Mtimes.THRESHOLD && m1ColumnCount >= Mtimes.THRESHOLD 192 && m2ColumnCount >= Mtimes.THRESHOLD) { 193 new PFor(0, m2ColumnCount - 1) { 194 195 @Override 196 public void step(int i) { 197 for (int irow = 0; irow < m1RowCount; ++irow) { 198 C.setAsDouble(0.0d, irow, i); 199 } 200 for (int lcol = 0; lcol < m1ColumnCount; ++lcol) { 201 final double temp = B.getAsDouble(lcol, i); 202 if (temp != 0.0d) { 203 for (int irow = 0; irow < m1RowCount; ++irow) { 204 C.setAsDouble(C.getAsDouble(irow, i) + A.getAsDouble(irow, lcol) 205 * temp, irow, i); 206 } 207 } 208 } 209 } 210 }; 211 } else { 212 for (int i = 0; i < m2ColumnCount; i++) { 213 for (int irow = 0; irow < m1RowCount; ++irow) { 214 C.setAsDouble(0.0d, irow, i); 215 } 216 for (int lcol = 0; lcol < m1ColumnCount; ++lcol) { 217 final double temp = B.getAsDouble(lcol, i); 218 if (temp != 0.0d) { 219 for (int irow = 0; irow < m1RowCount; ++irow) { 220 C.setAsDouble( 221 C.getAsDouble(irow, i) + A.getAsDouble(irow, lcol) * temp, 222 irow, i); 223 } 224 } 225 } 226 } 227 } 228 } 229 }; 230 231 class MtimesSparseMatrix1 implements MtimesCalculation<SparseMatrix, Matrix, Matrix> { 232 233 public final void calc(final SparseMatrix source1, final Matrix source2, final Matrix target) { 234 VerifyUtil.assert2D(source1); 235 VerifyUtil.assert2D(source2); 236 VerifyUtil.assert2D(target); 237 VerifyUtil.assertEquals(source1.getColumnCount(), source2.getRowCount(), 238 "matrices have wrong sizes"); 239 VerifyUtil.assertEquals(target.getRowCount(), source1.getRowCount(), 240 "matrices have wrong sizes"); 241 VerifyUtil.assertEquals(target.getColumnCount(), source2.getColumnCount(), 242 "matrices have wrong sizes"); 243 target.clear(); 244 for (long[] c1 : source1.availableCoordinates()) { 245 final double v1 = source1.getAsDouble(c1); 246 if (v1 != 0.0d) { 247 for (long col2 = source2.getColumnCount(); --col2 != -1;) { 248 final double v2 = source2.getAsDouble(c1[1], col2); 249 final double temp = v1 * v2; 250 if (temp != 0.0d) { 251 final double v3 = target.getAsDouble(c1[0], col2); 252 target.setAsDouble(v3 + temp, c1[0], col2); 253 } 254 } 255 } 256 } 257 } 258 }; 259 260 class MtimesSparseMatrix2 implements MtimesCalculation<Matrix, SparseMatrix, Matrix> { 261 262 public final void calc(final Matrix source1, final SparseMatrix source2, final Matrix target) { 263 VerifyUtil.assert2D(source1); 264 VerifyUtil.assert2D(source2); 265 VerifyUtil.assert2D(target); 266 VerifyUtil.assertEquals(source1.getColumnCount(), source2.getRowCount(), 267 "matrices have wrong sizes"); 268 VerifyUtil.assertEquals(target.getRowCount(), source1.getRowCount(), 269 "matrices have wrong sizes"); 270 VerifyUtil.assertEquals(target.getColumnCount(), source2.getColumnCount(), 271 "matrices have wrong sizes"); 272 target.clear(); 273 for (long[] c2 : source2.availableCoordinates()) { 274 final double v2 = source2.getAsDouble(c2); 275 if (v2 != 0.0d) { 276 for (long row1 = source1.getRowCount(); --row1 != -1;) { 277 final double v1 = source1.getAsDouble(row1, c2[0]); 278 final double temp = v1 * v2; 279 if (temp != 0.0d) { 280 final double v3 = target.getAsDouble(row1, c2[1]); 281 target.setAsDouble(v3 + temp, row1, c2[1]); 282 } 283 } 284 } 285 } 286 } 287 }; 288 289 class MtimesDenseMatrix2D implements MtimesCalculation<DenseMatrix2D, DenseMatrix2D, DenseMatrix2D> { 290 291 public final void calc(final DenseMatrix2D source1, final DenseMatrix2D source2, 292 final DenseMatrix2D target) { 293 if (source1 instanceof DenseDoubleMatrix2D && source2 instanceof DenseDoubleMatrix2D 294 && target instanceof DenseDoubleMatrix2D) { 295 Mtimes.DENSEDOUBLEMATRIX2D.calc((DenseDoubleMatrix2D) source1, 296 (DenseDoubleMatrix2D) source2, (DenseDoubleMatrix2D) target); 297 } else { 298 gemm(source1, source2, target); 299 } 300 } 301 302 private final void gemm(final DenseMatrix2D A, final DenseMatrix2D B, final DenseMatrix2D C) { 303 VerifyUtil.assert2D(A); 304 VerifyUtil.assert2D(B); 305 VerifyUtil.assert2D(C); 306 final int m1RowCount = (int) A.getRowCount(); 307 final int m1ColumnCount = (int) A.getColumnCount(); 308 final int m2RowCount = (int) B.getRowCount(); 309 final int m2ColumnCount = (int) B.getColumnCount(); 310 VerifyUtil.assertEquals(m1ColumnCount, m2RowCount, "matrices have wrong size"); 311 VerifyUtil.assertEquals(m1RowCount, C.getRowCount(), "matrices have wrong size"); 312 VerifyUtil.assertEquals(m2ColumnCount, C.getColumnCount(), "matrices have wrong size"); 313 314 if (m1RowCount >= Mtimes.THRESHOLD && m1ColumnCount >= Mtimes.THRESHOLD 315 && m2ColumnCount >= Mtimes.THRESHOLD) { 316 new PFor(0, m2ColumnCount - 1) { 317 318 @Override 319 public void step(int i) { 320 for (int irow = 0; irow < m1RowCount; ++irow) { 321 C.setAsDouble(0.0d, irow, i); 322 } 323 for (int lcol = 0; lcol < m1ColumnCount; ++lcol) { 324 final double temp = B.getAsDouble(lcol, i); 325 if (temp != 0.0d) { 326 for (int irow = 0; irow < m1RowCount; ++irow) { 327 C.setAsDouble(C.getAsDouble(irow, i) + A.getAsDouble(irow, lcol) 328 * temp, irow, i); 329 } 330 } 331 } 332 } 333 }; 334 } else { 335 for (int i = 0; i < m2ColumnCount; i++) { 336 for (int irow = 0; irow < m1RowCount; ++irow) { 337 C.setAsDouble(0.0d, irow, i); 338 } 339 for (int lcol = 0; lcol < m1ColumnCount; ++lcol) { 340 final double temp = B.getAsDouble(lcol, i); 341 if (temp != 0.0d) { 342 for (int irow = 0; irow < m1RowCount; ++irow) { 343 C.setAsDouble( 344 C.getAsDouble(irow, i) + A.getAsDouble(irow, lcol) * temp, 345 irow, i); 346 } 347 } 348 } 349 } 350 } 351 } 352 }; 353 354 /** 355 * Contains matrix multiplication methods for different matrix implementations 356 * 357 * @author Holger Arndt 358 * @author Frode Carlsen 359 * 360 */ 361 class MtimesDenseDoubleMatrix2D implements 362 MtimesCalculation<DenseDoubleMatrix2D, DenseDoubleMatrix2D, DenseDoubleMatrix2D> { 363 364 public final void calc(final DenseDoubleMatrix2D source1, final DenseDoubleMatrix2D source2, 365 final DenseDoubleMatrix2D target) { 366 assertTrue(source1 != null, "a == null"); 367 assertTrue(source2 != null, "b == null"); 368 assertTrue(target != null, "c == null"); 369 assertTrue(source1.getColumnCount() == source2.getRowCount(), "a.cols!=b.rows"); 370 assertTrue(source1.getRowCount() == target.getRowCount(), "a.rows!=c.rows"); 371 assertTrue(source2.getColumnCount() == target.getColumnCount(), "a.cols!=c.cols"); 372 if (source1.getRowCount() >= Mtimes.THRESHOLD 373 && source1.getColumnCount() >= Mtimes.THRESHOLD 374 && source2.getColumnCount() >= Mtimes.THRESHOLD) { 375 if (Mtimes.MTIMES_JBLAS != null && UJMPSettings.isUseJBlas()) { 376 Mtimes.MTIMES_JBLAS.calc((DenseDoubleMatrix2D) source1, 377 (DenseDoubleMatrix2D) source2, (DenseDoubleMatrix2D) target); 378 } else if (UJMPSettings.isUseBlockMatrixMultiply()) { 379 calcBlockMatrixMultiThreaded(source1, source2, target); 380 } else if (source1 instanceof HasColumnMajorDoubleArray1D 381 && source2 instanceof HasColumnMajorDoubleArray1D 382 && target instanceof HasColumnMajorDoubleArray1D) { 383 calcDoubleArrayMultiThreaded(((HasColumnMajorDoubleArray1D) source1) 384 .getColumnMajorDoubleArray1D(), (int) source1.getRowCount(), (int) source1 385 .getColumnCount(), ((HasColumnMajorDoubleArray1D) source2) 386 .getColumnMajorDoubleArray1D(), (int) source2.getRowCount(), (int) source2 387 .getColumnCount(), ((HasColumnMajorDoubleArray1D) target) 388 .getColumnMajorDoubleArray1D()); 389 } else if (source1 instanceof HasRowMajorDoubleArray2D 390 && source2 instanceof HasRowMajorDoubleArray2D 391 && target instanceof HasRowMajorDoubleArray2D) { 392 calcDoubleArray2DMultiThreaded(((HasRowMajorDoubleArray2D) source1) 393 .getRowMajorDoubleArray2D(), ((HasRowMajorDoubleArray2D) source2) 394 .getRowMajorDoubleArray2D(), ((HasRowMajorDoubleArray2D) target) 395 .getRowMajorDoubleArray2D()); 396 } else { 397 calcDenseDoubleMatrix2DMultiThreaded(source1, source2, target); 398 } 399 } else { 400 if (source1 instanceof HasColumnMajorDoubleArray1D 401 && source2 instanceof HasColumnMajorDoubleArray1D 402 && target instanceof HasColumnMajorDoubleArray1D) { 403 gemmDoubleArraySingleThreaded(((HasColumnMajorDoubleArray1D) source1) 404 .getColumnMajorDoubleArray1D(), (int) source1.getRowCount(), (int) source1 405 .getColumnCount(), ((HasColumnMajorDoubleArray1D) source2) 406 .getColumnMajorDoubleArray1D(), (int) source2.getRowCount(), (int) source2 407 .getColumnCount(), ((HasColumnMajorDoubleArray1D) target) 408 .getColumnMajorDoubleArray1D()); 409 } else if (source1 instanceof HasRowMajorDoubleArray2D 410 && source2 instanceof HasRowMajorDoubleArray2D 411 && target instanceof HasRowMajorDoubleArray2D) { 412 calcDoubleArray2DSingleThreaded(((HasRowMajorDoubleArray2D) source1) 413 .getRowMajorDoubleArray2D(), ((HasRowMajorDoubleArray2D) source2) 414 .getRowMajorDoubleArray2D(), ((HasRowMajorDoubleArray2D) target) 415 .getRowMajorDoubleArray2D()); 416 } else { 417 calcDenseDoubleMatrix2DSingleThreaded(source1, source2, target); 418 } 419 } 420 } 421 422 private void calcBlockMatrixMultiThreaded(DenseDoubleMatrix2D source1, 423 DenseDoubleMatrix2D source2, DenseDoubleMatrix2D target) { 424 BlockDenseDoubleMatrix2D a = null; 425 BlockDenseDoubleMatrix2D b = null; 426 BlockDenseDoubleMatrix2D c = null; 427 if (source1 instanceof BlockDenseDoubleMatrix2D) { 428 a = (BlockDenseDoubleMatrix2D) source1; 429 } else { 430 a = new BlockDenseDoubleMatrix2D(source1); 431 } 432 if (source2 instanceof BlockDenseDoubleMatrix2D 433 && a.getBlockStripeSize() == ((BlockDenseDoubleMatrix2D) source2) 434 .getBlockStripeSize()) { 435 b = (BlockDenseDoubleMatrix2D) source2; 436 } else { 437 b = new BlockDenseDoubleMatrix2D(source2, a.getBlockStripeSize(), 438 BlockOrder.COLUMNMAJOR); 439 } 440 final int arows = (int) a.getRowCount(); 441 final int bcols = (int) b.getColumnCount(); 442 if (target instanceof BlockDenseDoubleMatrix2D 443 && a.getBlockStripeSize() == ((BlockDenseDoubleMatrix2D) target) 444 .getBlockStripeSize()) { 445 c = (BlockDenseDoubleMatrix2D) target; 446 } else { 447 c = new BlockDenseDoubleMatrix2D(arows, bcols, a.getBlockStripeSize(), 448 BlockOrder.ROWMAJOR); 449 } 450 451 // force optimal block order 452 BlockOrder prevA = a.setBlockOrder(BlockOrder.ROWMAJOR); 453 BlockOrder prevB = b.setBlockOrder(BlockOrder.COLUMNMAJOR); 454 455 blockMultiplyMultiThreaded(a, b, c); 456 457 if (c != target) { 458 for (int j = bcols; --j != -1;) { 459 for (int i = arows; --i != -1;) { 460 target.setDouble(c.getDouble(i, j), i, j); 461 } 462 } 463 } 464 465 // reset block order 466 if (Mtimes.RESET_BLOCK_ORDER) { 467 a.setBlockOrder(prevA); 468 b.setBlockOrder(prevB); 469 } 470 } 471 472 private final void gemmDoubleArraySingleThreaded(final double[] A, final int m1RowCount, 473 final int m1ColumnCount, final double[] B, final int m2RowCount, 474 final int m2ColumnCount, final double[] C) { 475 476 for (int j = 0; j < m2ColumnCount; j++) { 477 final int jcolTimesM1RowCount = j * m1RowCount; 478 final int jcolTimesM1ColumnCount = j * m1ColumnCount; 479 Arrays.fill(C, jcolTimesM1RowCount, jcolTimesM1RowCount + m1RowCount, 0.0d); 480 for (int lcol = 0; lcol < m1ColumnCount; ++lcol) { 481 final double temp = B[lcol + jcolTimesM1ColumnCount]; 482 if (temp != 0.0d) { 483 final int lcolTimesM1RowCount = lcol * m1RowCount; 484 for (int irow = 0; irow < m1RowCount; ++irow) { 485 C[irow + jcolTimesM1RowCount] += A[irow + lcolTimesM1RowCount] * temp; 486 } 487 } 488 } 489 } 490 } 491 492 private final void calcDoubleArrayMultiThreaded(final double[] A, final int m1RowCount, 493 final int m1ColumnCount, final double[] B, final int m2RowCount, 494 final int m2ColumnCount, final double[] C) { 495 new PFor(0, m2ColumnCount - 1) { 496 @Override 497 public void step(int i) { 498 final int jcolTimesM1RowCount = i * m1RowCount; 499 final int jcolTimesM1ColumnCount = i * m1ColumnCount; 500 Arrays.fill(C, jcolTimesM1RowCount, jcolTimesM1RowCount + m1RowCount, 0.0d); 501 for (int lcol = 0; lcol < m1ColumnCount; ++lcol) { 502 final double temp = B[lcol + jcolTimesM1ColumnCount]; 503 if (temp != 0.0d) { 504 final int lcolTimesM1RowCount = lcol * m1RowCount; 505 for (int irow = 0; irow < m1RowCount; ++irow) { 506 C[irow + jcolTimesM1RowCount] += A[irow + lcolTimesM1RowCount] * temp; 507 } 508 } 509 } 510 } 511 }; 512 } 513 514 private final void calcDoubleArray2DSingleThreaded(final double[][] m1, final double[][] m2, 515 final double[][] ret) { 516 final int columnCount = m1[0].length; 517 final double[] columns = new double[columnCount]; 518 519 for (int c = m2[0].length; --c != -1;) { 520 for (int k = columnCount; --k != -1;) { 521 columns[k] = m2[k][c]; 522 } 523 for (int r = m1.length; --r != -1;) { 524 double sum = 0.0d; 525 final double[] row = m1[r]; 526 for (int k = columnCount; --k != -1;) { 527 sum += row[k] * columns[k]; 528 } 529 ret[r][c] = sum; 530 } 531 } 532 } 533 534 private final void calcDoubleArray2DMultiThreaded(final double[][] m1, final double[][] m2, 535 final double[][] ret) { 536 final int columnCount = m1[0].length; 537 final double[] columns = new double[columnCount]; 538 539 new PFor(0, m2[0].length - 1) { 540 @Override 541 public void step(int i) { 542 for (int k = columnCount; --k != -1;) { 543 columns[k] = m2[k][i]; 544 } 545 for (int r = m1.length; --r != -1;) { 546 double sum = 0.0d; 547 final double[] row = m1[r]; 548 for (int k = columnCount; --k != -1;) { 549 sum += row[k] * columns[k]; 550 } 551 ret[r][i] = sum; 552 } 553 } 554 }; 555 } 556 557 private final void calcDenseDoubleMatrix2DSingleThreaded(final DenseDoubleMatrix2D A, 558 final DenseDoubleMatrix2D B, final DenseDoubleMatrix2D C) { 559 final int m1RowCount = (int) A.getRowCount(); 560 final int m1ColumnCount = (int) A.getColumnCount(); 561 final int m2ColumnCount = (int) B.getColumnCount(); 562 563 for (int i = 0; i < m2ColumnCount; i++) { 564 for (int irow = 0; irow < m1RowCount; ++irow) { 565 C.setDouble(0.0d, irow, i); 566 } 567 for (int lcol = 0; lcol < m1ColumnCount; ++lcol) { 568 final double temp = B.getDouble(lcol, i); 569 if (temp != 0.0d) { 570 for (int irow = 0; irow < m1RowCount; ++irow) { 571 C.setDouble(C.getDouble(irow, i) + A.getDouble(irow, lcol) * temp, irow, i); 572 } 573 } 574 } 575 } 576 } 577 578 private final void calcDenseDoubleMatrix2DMultiThreaded(final DenseDoubleMatrix2D A, 579 final DenseDoubleMatrix2D B, final DenseDoubleMatrix2D C) { 580 final int m1RowCount = (int) A.getRowCount(); 581 final int m1ColumnCount = (int) A.getColumnCount(); 582 final int m2ColumnCount = (int) B.getColumnCount(); 583 584 new PFor(0, m2ColumnCount - 1) { 585 @Override 586 public void step(int i) { 587 for (int irow = 0; irow < m1RowCount; ++irow) { 588 C.setDouble(0.0d, irow, i); 589 } 590 for (int lcol = 0; lcol < m1ColumnCount; ++lcol) { 591 final double temp = B.getDouble(lcol, i); 592 if (temp != 0.0d) { 593 for (int irow = 0; irow < m1RowCount; ++irow) { 594 C.setDouble(C.getDouble(irow, i) + A.getDouble(irow, lcol) * temp, 595 irow, i); 596 } 597 } 598 } 599 } 600 }; 601 } 602 603 /** 604 * Multiply two matrices concurrently with the given Executor to handle 605 * parallel tasks. 606 * 607 * @param b 608 * - matrix to multiply this with. 609 * @param executorService 610 * - to handle concurrent multiplication tasks. 611 * @return new matrix C containing result of matrix multiplication C = A x 612 * B. 613 */ 614 /** 615 * @param a 616 * @param b 617 * @param c 618 * @return 619 */ 620 private BlockDenseDoubleMatrix2D blockMultiplyMultiThreaded(final BlockDenseDoubleMatrix2D a, 621 final BlockDenseDoubleMatrix2D b, final BlockDenseDoubleMatrix2D c) { 622 final BlockMatrixLayout al = a.getBlockLayout(); 623 final BlockMatrixLayout bl = b.getBlockLayout(); 624 assertTrue(al.columns == bl.rows, "b.rows != this.columns"); 625 assertTrue(al.blockStripe == bl.blockStripe, "block sizes differ: %s != %s", 626 al.blockStripe, bl.blockStripe); 627 628 final List<Callable<Void>> tasks = new LinkedList<Callable<Void>>(); 629 630 final int kMax = (int) b.getColumnCount(); 631 final int jMax = (int) a.getColumnCount(); 632 final int iMax = (int) a.getRowCount(); 633 634 final int bColSlice = Math.min(al.blockStripe, kMax); 635 final int aColSlice = Math.min(al.blockStripe, jMax); 636 final int aRowSlice = Math.min(al.blockStripe, iMax); 637 638 // Number of blocks to take for each concurrent task. 639 final int blocksPerTask = 1; 640 final int blocksPerTaskDimJ = selectBlocksPerTaskDimJ(al.blockStripe, iMax, jMax, kMax); 641 642 for (int k = 0, kStride; k < kMax; k += kStride) { 643 kStride = Math.min(blocksPerTask * bColSlice, kMax - k); 644 645 for (int j = 0, jStride; j < jMax; j += jStride) { 646 jStride = Math.min(blocksPerTaskDimJ * aColSlice, jMax - j); 647 648 for (int i = 0, iStride; i < iMax; i += iStride) { 649 iStride = Math.min(blocksPerTask * aRowSlice, iMax - i); 650 651 tasks.add(new BlockMultiply(a, b, c, i, (i + iStride), j, (j + jStride), k, 652 (k + kStride))); 653 } 654 655 } 656 } 657 658 // wait for all tasks to complete. 659 try { 660 for (Future<Void> f : UJMPThreadPoolExecutor.getInstance().invokeAll(tasks)) { 661 f.get(); 662 } 663 } catch (ExecutionException e) { 664 StringBuilder sb = new StringBuilder( 665 "Execution exception - while awaiting completion of matrix multiplication [" 666 + e.getMessage() + "]:"); 667 if (e.getCause() != null) { 668 for (StackTraceElement stackTraceElement : e.getCause().getStackTrace()) { 669 sb.append(stackTraceElement).append(" * "); 670 } 671 } 672 throw new RuntimeException(sb.toString(), e.getCause()); 673 } catch (final InterruptedException e) { 674 String msg = "Interrupted - while awaiting completion of matrix multiplication."; 675 throw new RuntimeException(msg + ": cause [" + e.getMessage() + "]", e); 676 } 677 678 return c; 679 } 680 681 // pick a suitable number of blocks to process per task for dimension J 682 // - if too small , then incurs extra gc and contention for synchronization 683 // - if set too large, then may not fully exploit parallelism 684 private int selectBlocksPerTaskDimJ(int blockStripe, int iMax, int jMax, int kMax) { 685 int adjust = (jMax % blockStripe > 0) ? 1 : 0; 686 if (jMax < (5 * blockStripe) || jMax <= iMax) { 687 // do not break this dimension into parallel tasks 688 return jMax / blockStripe + adjust; 689 } else { 690 // assume 2 parallell tasks 691 return Math.max(1, (jMax / blockStripe + adjust) / 2); 692 } 693 // may need something if jMax >>> iMax 694 } 695 };