001    /*
002     * Copyright (C) 2008-2010 by Holger Arndt, Frode Carlsen
003     *
004     * This file is part of the Universal Java Matrix Package (UJMP).
005     * See the NOTICE file distributed with this work for additional
006     * information regarding copyright ownership and licensing.
007     *
008     * UJMP is free software; you can redistribute it and/or modify
009     * it under the terms of the GNU Lesser General Public License as
010     * published by the Free Software Foundation; either version 2
011     * of the License, or (at your option) any later version.
012     *
013     * UJMP is distributed in the hope that it will be useful,
014     * but WITHOUT ANY WARRANTY; without even the implied warranty of
015     * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
016     * GNU Lesser General Public License for more details.
017     *
018     * You should have received a copy of the GNU Lesser General Public
019     * License along with UJMP; if not, write to the
020     * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
021     * Boston, MA  02110-1301  USA
022     */
023    
024    package org.ujmp.core.calculation;
025    
026    import static org.ujmp.core.util.VerifyUtil.assertTrue;
027    
028    import java.util.Arrays;
029    import java.util.LinkedList;
030    import java.util.List;
031    import java.util.concurrent.Callable;
032    import java.util.concurrent.ExecutionException;
033    import java.util.concurrent.Future;
034    
035    import org.ujmp.core.Matrix;
036    import org.ujmp.core.doublematrix.DenseDoubleMatrix2D;
037    import org.ujmp.core.doublematrix.impl.BlockDenseDoubleMatrix2D;
038    import org.ujmp.core.doublematrix.impl.BlockMatrixLayout;
039    import org.ujmp.core.doublematrix.impl.BlockMultiply;
040    import org.ujmp.core.doublematrix.impl.BlockMatrixLayout.BlockOrder;
041    import org.ujmp.core.interfaces.HasColumnMajorDoubleArray1D;
042    import org.ujmp.core.interfaces.HasRowMajorDoubleArray2D;
043    import org.ujmp.core.matrix.DenseMatrix;
044    import org.ujmp.core.matrix.DenseMatrix2D;
045    import org.ujmp.core.matrix.SparseMatrix;
046    import org.ujmp.core.util.AbstractPlugin;
047    import org.ujmp.core.util.UJMPSettings;
048    import org.ujmp.core.util.VerifyUtil;
049    import org.ujmp.core.util.concurrent.PFor;
050    import org.ujmp.core.util.concurrent.UJMPThreadPoolExecutor;
051    
052    public class Mtimes {
053            public static int THRESHOLD = 100;
054    
055            public static final MtimesCalculation<Matrix, Matrix, Matrix> MATRIX = new MtimesMatrix();
056    
057            public static final MtimesCalculation<DenseMatrix, DenseMatrix, DenseMatrix> DENSEMATRIX = new MtimesDenseMatrix();
058    
059            public static final MtimesCalculation<DenseMatrix2D, DenseMatrix2D, DenseMatrix2D> DENSEMATRIX2D = new MtimesDenseMatrix2D();
060    
061            public static final MtimesCalculation<DenseDoubleMatrix2D, DenseDoubleMatrix2D, DenseDoubleMatrix2D> DENSEDOUBLEMATRIX2D = new MtimesDenseDoubleMatrix2D();
062    
063            public static final MtimesCalculation<SparseMatrix, Matrix, Matrix> SPARSEMATRIX1 = new MtimesSparseMatrix1();
064    
065            public static final MtimesCalculation<Matrix, SparseMatrix, Matrix> SPARSEMATRIX2 = new MtimesSparseMatrix2();
066    
067            public static MtimesCalculation<Matrix, Matrix, Matrix> MTIMES_JBLAS = null;
068    
069            public static final boolean RESET_BLOCK_ORDER = false;
070    
071            static {
072                    init();
073            }
074    
075            @SuppressWarnings("unchecked")
076            public static void init() {
077                    try {
078                            AbstractPlugin p = (AbstractPlugin) Class.forName("org.ujmp.jblas.Plugin")
079                                            .newInstance();
080                            if (p.isAvailable()) {
081                                    MTIMES_JBLAS = (MtimesCalculation<Matrix, Matrix, Matrix>) Class.forName(
082                                                    "org.ujmp.jblas.calculation.Mtimes").newInstance();
083                            }
084                    } catch (Throwable t) {
085                    }
086            }
087    
088    }
089    
090    class MtimesMatrix implements MtimesCalculation<Matrix, Matrix, Matrix> {
091    
092            public final void calc(final Matrix source1, final Matrix source2, final Matrix target) {
093                    if (source1 instanceof DenseDoubleMatrix2D && source2 instanceof DenseDoubleMatrix2D
094                                    && target instanceof DenseDoubleMatrix2D) {
095                            Mtimes.DENSEDOUBLEMATRIX2D.calc((DenseDoubleMatrix2D) source1,
096                                            (DenseDoubleMatrix2D) source2, (DenseDoubleMatrix2D) target);
097                    } else if (source1 instanceof DenseMatrix2D && source2 instanceof DenseMatrix2D
098                                    && target instanceof DenseMatrix2D) {
099                            Mtimes.DENSEMATRIX2D.calc((DenseMatrix2D) source1, (DenseMatrix2D) source2,
100                                            (DenseMatrix2D) target);
101                    } else if (source1 instanceof DenseMatrix && source2 instanceof DenseMatrix
102                                    && target instanceof DenseMatrix) {
103                            Mtimes.DENSEMATRIX.calc((DenseMatrix) source1, (DenseMatrix) source2,
104                                            (DenseMatrix) target);
105                    } else if (source1 instanceof SparseMatrix) {
106                            Mtimes.SPARSEMATRIX1.calc((SparseMatrix) source1, source2, target);
107                    } else if (source2 instanceof SparseMatrix) {
108                            Mtimes.SPARSEMATRIX2.calc(source1, (SparseMatrix) source2, target);
109                    } else {
110                            gemm(source1, source2, target);
111                    }
112            }
113    
114            private final void gemm(final Matrix A, final Matrix B, final Matrix C) {
115                    VerifyUtil.assert2D(A);
116                    VerifyUtil.assert2D(B);
117                    VerifyUtil.assert2D(C);
118                    final int m1RowCount = (int) A.getRowCount();
119                    final int m1ColumnCount = (int) A.getColumnCount();
120                    final int m2RowCount = (int) B.getRowCount();
121                    final int m2ColumnCount = (int) B.getColumnCount();
122                    VerifyUtil.assertEquals(m1ColumnCount, m2RowCount, "matrices have wrong sizes");
123                    VerifyUtil.assertEquals(m1RowCount, C.getRowCount(), "matrices have wrong sizes");
124                    VerifyUtil.assertEquals(m2ColumnCount, C.getColumnCount(), "matrices have wrong sizes");
125    
126                    if (m1RowCount >= Mtimes.THRESHOLD && m1ColumnCount >= Mtimes.THRESHOLD
127                                    && m2ColumnCount >= Mtimes.THRESHOLD) {
128                            new PFor(0, m2ColumnCount - 1) {
129    
130                                    @Override
131                                    public void step(int i) {
132                                            for (int irow = 0; irow < m1RowCount; ++irow) {
133                                                    C.setAsDouble(0.0d, irow, i);
134                                            }
135                                            for (int lcol = 0; lcol < m1ColumnCount; ++lcol) {
136                                                    final double temp = B.getAsDouble(lcol, i);
137                                                    if (temp != 0.0d) {
138                                                            for (int irow = 0; irow < m1RowCount; ++irow) {
139                                                                    C.setAsDouble(C.getAsDouble(irow, i) + A.getAsDouble(irow, lcol)
140                                                                                    * temp, irow, i);
141                                                            }
142                                                    }
143                                            }
144                                    }
145                            };
146                    } else {
147                            for (int i = 0; i < m2ColumnCount; i++) {
148                                    for (int irow = 0; irow < m1RowCount; ++irow) {
149                                            C.setAsDouble(0.0d, irow, i);
150                                    }
151                                    for (int lcol = 0; lcol < m1ColumnCount; ++lcol) {
152                                            final double temp = B.getAsDouble(lcol, i);
153                                            if (temp != 0.0d) {
154                                                    for (int irow = 0; irow < m1RowCount; ++irow) {
155                                                            C.setAsDouble(
156                                                                            C.getAsDouble(irow, i) + A.getAsDouble(irow, lcol) * temp,
157                                                                            irow, i);
158                                                    }
159                                            }
160                                    }
161                            }
162                    }
163            }
164    };
165    
166    class MtimesDenseMatrix implements MtimesCalculation<DenseMatrix, DenseMatrix, DenseMatrix> {
167    
168            public final void calc(final DenseMatrix source1, final DenseMatrix source2,
169                            final DenseMatrix target) {
170                    if (source1 instanceof DenseMatrix2D && source2 instanceof DenseMatrix2D
171                                    && target instanceof DenseMatrix2D) {
172                            Mtimes.DENSEMATRIX2D.calc((DenseMatrix2D) source1, (DenseMatrix2D) source2,
173                                            (DenseMatrix2D) target);
174                    } else {
175                            gemm(source1, source2, target);
176                    }
177            }
178    
179            private final void gemm(final DenseMatrix A, final DenseMatrix B, final DenseMatrix C) {
180                    VerifyUtil.assert2D(A);
181                    VerifyUtil.assert2D(B);
182                    VerifyUtil.assert2D(C);
183                    final int m1RowCount = (int) A.getRowCount();
184                    final int m1ColumnCount = (int) A.getColumnCount();
185                    final int m2RowCount = (int) B.getRowCount();
186                    final int m2ColumnCount = (int) B.getColumnCount();
187                    VerifyUtil.assertEquals(m1ColumnCount, m2RowCount, "matrices have wrong sizes");
188                    VerifyUtil.assertEquals(m1RowCount, C.getRowCount(), "matrices have wrong sizes");
189                    VerifyUtil.assertEquals(m2ColumnCount, C.getColumnCount(), "matrices have wrong sizes");
190    
191                    if (m1RowCount >= Mtimes.THRESHOLD && m1ColumnCount >= Mtimes.THRESHOLD
192                                    && m2ColumnCount >= Mtimes.THRESHOLD) {
193                            new PFor(0, m2ColumnCount - 1) {
194    
195                                    @Override
196                                    public void step(int i) {
197                                            for (int irow = 0; irow < m1RowCount; ++irow) {
198                                                    C.setAsDouble(0.0d, irow, i);
199                                            }
200                                            for (int lcol = 0; lcol < m1ColumnCount; ++lcol) {
201                                                    final double temp = B.getAsDouble(lcol, i);
202                                                    if (temp != 0.0d) {
203                                                            for (int irow = 0; irow < m1RowCount; ++irow) {
204                                                                    C.setAsDouble(C.getAsDouble(irow, i) + A.getAsDouble(irow, lcol)
205                                                                                    * temp, irow, i);
206                                                            }
207                                                    }
208                                            }
209                                    }
210                            };
211                    } else {
212                            for (int i = 0; i < m2ColumnCount; i++) {
213                                    for (int irow = 0; irow < m1RowCount; ++irow) {
214                                            C.setAsDouble(0.0d, irow, i);
215                                    }
216                                    for (int lcol = 0; lcol < m1ColumnCount; ++lcol) {
217                                            final double temp = B.getAsDouble(lcol, i);
218                                            if (temp != 0.0d) {
219                                                    for (int irow = 0; irow < m1RowCount; ++irow) {
220                                                            C.setAsDouble(
221                                                                            C.getAsDouble(irow, i) + A.getAsDouble(irow, lcol) * temp,
222                                                                            irow, i);
223                                                    }
224                                            }
225                                    }
226                            }
227                    }
228            }
229    };
230    
231    class MtimesSparseMatrix1 implements MtimesCalculation<SparseMatrix, Matrix, Matrix> {
232    
233            public final void calc(final SparseMatrix source1, final Matrix source2, final Matrix target) {
234                    VerifyUtil.assert2D(source1);
235                    VerifyUtil.assert2D(source2);
236                    VerifyUtil.assert2D(target);
237                    VerifyUtil.assertEquals(source1.getColumnCount(), source2.getRowCount(),
238                                    "matrices have wrong sizes");
239                    VerifyUtil.assertEquals(target.getRowCount(), source1.getRowCount(),
240                                    "matrices have wrong sizes");
241                    VerifyUtil.assertEquals(target.getColumnCount(), source2.getColumnCount(),
242                                    "matrices have wrong sizes");
243                    target.clear();
244                    for (long[] c1 : source1.availableCoordinates()) {
245                            final double v1 = source1.getAsDouble(c1);
246                            if (v1 != 0.0d) {
247                                    for (long col2 = source2.getColumnCount(); --col2 != -1;) {
248                                            final double v2 = source2.getAsDouble(c1[1], col2);
249                                            final double temp = v1 * v2;
250                                            if (temp != 0.0d) {
251                                                    final double v3 = target.getAsDouble(c1[0], col2);
252                                                    target.setAsDouble(v3 + temp, c1[0], col2);
253                                            }
254                                    }
255                            }
256                    }
257            }
258    };
259    
260    class MtimesSparseMatrix2 implements MtimesCalculation<Matrix, SparseMatrix, Matrix> {
261    
262            public final void calc(final Matrix source1, final SparseMatrix source2, final Matrix target) {
263                    VerifyUtil.assert2D(source1);
264                    VerifyUtil.assert2D(source2);
265                    VerifyUtil.assert2D(target);
266                    VerifyUtil.assertEquals(source1.getColumnCount(), source2.getRowCount(),
267                                    "matrices have wrong sizes");
268                    VerifyUtil.assertEquals(target.getRowCount(), source1.getRowCount(),
269                                    "matrices have wrong sizes");
270                    VerifyUtil.assertEquals(target.getColumnCount(), source2.getColumnCount(),
271                                    "matrices have wrong sizes");
272                    target.clear();
273                    for (long[] c2 : source2.availableCoordinates()) {
274                            final double v2 = source2.getAsDouble(c2);
275                            if (v2 != 0.0d) {
276                                    for (long row1 = source1.getRowCount(); --row1 != -1;) {
277                                            final double v1 = source1.getAsDouble(row1, c2[0]);
278                                            final double temp = v1 * v2;
279                                            if (temp != 0.0d) {
280                                                    final double v3 = target.getAsDouble(row1, c2[1]);
281                                                    target.setAsDouble(v3 + temp, row1, c2[1]);
282                                            }
283                                    }
284                            }
285                    }
286            }
287    };
288    
289    class MtimesDenseMatrix2D implements MtimesCalculation<DenseMatrix2D, DenseMatrix2D, DenseMatrix2D> {
290    
291            public final void calc(final DenseMatrix2D source1, final DenseMatrix2D source2,
292                            final DenseMatrix2D target) {
293                    if (source1 instanceof DenseDoubleMatrix2D && source2 instanceof DenseDoubleMatrix2D
294                                    && target instanceof DenseDoubleMatrix2D) {
295                            Mtimes.DENSEDOUBLEMATRIX2D.calc((DenseDoubleMatrix2D) source1,
296                                            (DenseDoubleMatrix2D) source2, (DenseDoubleMatrix2D) target);
297                    } else {
298                            gemm(source1, source2, target);
299                    }
300            }
301    
302            private final void gemm(final DenseMatrix2D A, final DenseMatrix2D B, final DenseMatrix2D C) {
303                    VerifyUtil.assert2D(A);
304                    VerifyUtil.assert2D(B);
305                    VerifyUtil.assert2D(C);
306                    final int m1RowCount = (int) A.getRowCount();
307                    final int m1ColumnCount = (int) A.getColumnCount();
308                    final int m2RowCount = (int) B.getRowCount();
309                    final int m2ColumnCount = (int) B.getColumnCount();
310                    VerifyUtil.assertEquals(m1ColumnCount, m2RowCount, "matrices have wrong size");
311                    VerifyUtil.assertEquals(m1RowCount, C.getRowCount(), "matrices have wrong size");
312                    VerifyUtil.assertEquals(m2ColumnCount, C.getColumnCount(), "matrices have wrong size");
313    
314                    if (m1RowCount >= Mtimes.THRESHOLD && m1ColumnCount >= Mtimes.THRESHOLD
315                                    && m2ColumnCount >= Mtimes.THRESHOLD) {
316                            new PFor(0, m2ColumnCount - 1) {
317    
318                                    @Override
319                                    public void step(int i) {
320                                            for (int irow = 0; irow < m1RowCount; ++irow) {
321                                                    C.setAsDouble(0.0d, irow, i);
322                                            }
323                                            for (int lcol = 0; lcol < m1ColumnCount; ++lcol) {
324                                                    final double temp = B.getAsDouble(lcol, i);
325                                                    if (temp != 0.0d) {
326                                                            for (int irow = 0; irow < m1RowCount; ++irow) {
327                                                                    C.setAsDouble(C.getAsDouble(irow, i) + A.getAsDouble(irow, lcol)
328                                                                                    * temp, irow, i);
329                                                            }
330                                                    }
331                                            }
332                                    }
333                            };
334                    } else {
335                            for (int i = 0; i < m2ColumnCount; i++) {
336                                    for (int irow = 0; irow < m1RowCount; ++irow) {
337                                            C.setAsDouble(0.0d, irow, i);
338                                    }
339                                    for (int lcol = 0; lcol < m1ColumnCount; ++lcol) {
340                                            final double temp = B.getAsDouble(lcol, i);
341                                            if (temp != 0.0d) {
342                                                    for (int irow = 0; irow < m1RowCount; ++irow) {
343                                                            C.setAsDouble(
344                                                                            C.getAsDouble(irow, i) + A.getAsDouble(irow, lcol) * temp,
345                                                                            irow, i);
346                                                    }
347                                            }
348                                    }
349                            }
350                    }
351            }
352    };
353    
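/**
 * Matrix multiplication for {@link DenseDoubleMatrix2D} operands. Depending on
 * the matrix sizes, the settings in {@link UJMPSettings} and the storage layout
 * of the operands, the product is computed with JBLAS, with a block matrix
 * multiplication, or with single- or multi-threaded kernels that work either on
 * the underlying arrays or on the matrices themselves.
 * 
 * @author Holger Arndt
 * @author Frode Carlsen
 */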
361    class MtimesDenseDoubleMatrix2D implements
362                    MtimesCalculation<DenseDoubleMatrix2D, DenseDoubleMatrix2D, DenseDoubleMatrix2D> {
363    
364            public final void calc(final DenseDoubleMatrix2D source1, final DenseDoubleMatrix2D source2,
365                            final DenseDoubleMatrix2D target) {
366                    assertTrue(source1 != null, "a == null");
367                    assertTrue(source2 != null, "b == null");
368                    assertTrue(target != null, "c == null");
369                    assertTrue(source1.getColumnCount() == source2.getRowCount(), "a.cols!=b.rows");
370                    assertTrue(source1.getRowCount() == target.getRowCount(), "a.rows!=c.rows");
371                    assertTrue(source2.getColumnCount() == target.getColumnCount(), "a.cols!=c.cols");
372                    if (source1.getRowCount() >= Mtimes.THRESHOLD
373                                    && source1.getColumnCount() >= Mtimes.THRESHOLD
374                                    && source2.getColumnCount() >= Mtimes.THRESHOLD) {
375                            if (Mtimes.MTIMES_JBLAS != null && UJMPSettings.isUseJBlas()) {
376                                    Mtimes.MTIMES_JBLAS.calc((DenseDoubleMatrix2D) source1,
377                                                    (DenseDoubleMatrix2D) source2, (DenseDoubleMatrix2D) target);
378                            } else if (UJMPSettings.isUseBlockMatrixMultiply()) {
379                                    calcBlockMatrixMultiThreaded(source1, source2, target);
380                            } else if (source1 instanceof HasColumnMajorDoubleArray1D
381                                            && source2 instanceof HasColumnMajorDoubleArray1D
382                                            && target instanceof HasColumnMajorDoubleArray1D) {
383                                    calcDoubleArrayMultiThreaded(((HasColumnMajorDoubleArray1D) source1)
384                                                    .getColumnMajorDoubleArray1D(), (int) source1.getRowCount(), (int) source1
385                                                    .getColumnCount(), ((HasColumnMajorDoubleArray1D) source2)
386                                                    .getColumnMajorDoubleArray1D(), (int) source2.getRowCount(), (int) source2
387                                                    .getColumnCount(), ((HasColumnMajorDoubleArray1D) target)
388                                                    .getColumnMajorDoubleArray1D());
389                            } else if (source1 instanceof HasRowMajorDoubleArray2D
390                                            && source2 instanceof HasRowMajorDoubleArray2D
391                                            && target instanceof HasRowMajorDoubleArray2D) {
392                                    calcDoubleArray2DMultiThreaded(((HasRowMajorDoubleArray2D) source1)
393                                                    .getRowMajorDoubleArray2D(), ((HasRowMajorDoubleArray2D) source2)
394                                                    .getRowMajorDoubleArray2D(), ((HasRowMajorDoubleArray2D) target)
395                                                    .getRowMajorDoubleArray2D());
396                            } else {
397                                    calcDenseDoubleMatrix2DMultiThreaded(source1, source2, target);
398                            }
399                    } else {
400                            if (source1 instanceof HasColumnMajorDoubleArray1D
401                                            && source2 instanceof HasColumnMajorDoubleArray1D
402                                            && target instanceof HasColumnMajorDoubleArray1D) {
403                                    gemmDoubleArraySingleThreaded(((HasColumnMajorDoubleArray1D) source1)
404                                                    .getColumnMajorDoubleArray1D(), (int) source1.getRowCount(), (int) source1
405                                                    .getColumnCount(), ((HasColumnMajorDoubleArray1D) source2)
406                                                    .getColumnMajorDoubleArray1D(), (int) source2.getRowCount(), (int) source2
407                                                    .getColumnCount(), ((HasColumnMajorDoubleArray1D) target)
408                                                    .getColumnMajorDoubleArray1D());
409                            } else if (source1 instanceof HasRowMajorDoubleArray2D
410                                            && source2 instanceof HasRowMajorDoubleArray2D
411                                            && target instanceof HasRowMajorDoubleArray2D) {
412                                    calcDoubleArray2DSingleThreaded(((HasRowMajorDoubleArray2D) source1)
413                                                    .getRowMajorDoubleArray2D(), ((HasRowMajorDoubleArray2D) source2)
414                                                    .getRowMajorDoubleArray2D(), ((HasRowMajorDoubleArray2D) target)
415                                                    .getRowMajorDoubleArray2D());
416                            } else {
417                                    calcDenseDoubleMatrix2DSingleThreaded(source1, source2, target);
418                            }
419                    }
420            }
421    
422            private void calcBlockMatrixMultiThreaded(DenseDoubleMatrix2D source1,
423                            DenseDoubleMatrix2D source2, DenseDoubleMatrix2D target) {
424                    BlockDenseDoubleMatrix2D a = null;
425                    BlockDenseDoubleMatrix2D b = null;
426                    BlockDenseDoubleMatrix2D c = null;
427                    if (source1 instanceof BlockDenseDoubleMatrix2D) {
428                            a = (BlockDenseDoubleMatrix2D) source1;
429                    } else {
430                            a = new BlockDenseDoubleMatrix2D(source1);
431                    }
432                    if (source2 instanceof BlockDenseDoubleMatrix2D
433                                    && a.getBlockStripeSize() == ((BlockDenseDoubleMatrix2D) source2)
434                                                    .getBlockStripeSize()) {
435                            b = (BlockDenseDoubleMatrix2D) source2;
436                    } else {
437                            b = new BlockDenseDoubleMatrix2D(source2, a.getBlockStripeSize(),
438                                            BlockOrder.COLUMNMAJOR);
439                    }
440                    final int arows = (int) a.getRowCount();
441                    final int bcols = (int) b.getColumnCount();
442                    if (target instanceof BlockDenseDoubleMatrix2D
443                                    && a.getBlockStripeSize() == ((BlockDenseDoubleMatrix2D) target)
444                                                    .getBlockStripeSize()) {
445                            c = (BlockDenseDoubleMatrix2D) target;
446                    } else {
447                            c = new BlockDenseDoubleMatrix2D(arows, bcols, a.getBlockStripeSize(),
448                                            BlockOrder.ROWMAJOR);
449                    }
450    
451                    // force optimal block order
452                    BlockOrder prevA = a.setBlockOrder(BlockOrder.ROWMAJOR);
453                    BlockOrder prevB = b.setBlockOrder(BlockOrder.COLUMNMAJOR);
454    
455                    blockMultiplyMultiThreaded(a, b, c);
456    
457                    if (c != target) {
458                            for (int j = bcols; --j != -1;) {
459                                    for (int i = arows; --i != -1;) {
460                                            target.setDouble(c.getDouble(i, j), i, j);
461                                    }
462                            }
463                    }
464    
465                    // reset block order
466                    if (Mtimes.RESET_BLOCK_ORDER) {
467                            a.setBlockOrder(prevA);
468                            b.setBlockOrder(prevB);
469                    }
470            }
471    
472            private final void gemmDoubleArraySingleThreaded(final double[] A, final int m1RowCount,
473                            final int m1ColumnCount, final double[] B, final int m2RowCount,
474                            final int m2ColumnCount, final double[] C) {
475    
476                    for (int j = 0; j < m2ColumnCount; j++) {
477                            final int jcolTimesM1RowCount = j * m1RowCount;
478                            final int jcolTimesM1ColumnCount = j * m1ColumnCount;
479                            Arrays.fill(C, jcolTimesM1RowCount, jcolTimesM1RowCount + m1RowCount, 0.0d);
480                            for (int lcol = 0; lcol < m1ColumnCount; ++lcol) {
481                                    final double temp = B[lcol + jcolTimesM1ColumnCount];
482                                    if (temp != 0.0d) {
483                                            final int lcolTimesM1RowCount = lcol * m1RowCount;
484                                            for (int irow = 0; irow < m1RowCount; ++irow) {
485                                                    C[irow + jcolTimesM1RowCount] += A[irow + lcolTimesM1RowCount] * temp;
486                                            }
487                                    }
488                            }
489                    }
490            }
491    
492            private final void calcDoubleArrayMultiThreaded(final double[] A, final int m1RowCount,
493                            final int m1ColumnCount, final double[] B, final int m2RowCount,
494                            final int m2ColumnCount, final double[] C) {
495                    new PFor(0, m2ColumnCount - 1) {
496                            @Override
497                            public void step(int i) {
498                                    final int jcolTimesM1RowCount = i * m1RowCount;
499                                    final int jcolTimesM1ColumnCount = i * m1ColumnCount;
500                                    Arrays.fill(C, jcolTimesM1RowCount, jcolTimesM1RowCount + m1RowCount, 0.0d);
501                                    for (int lcol = 0; lcol < m1ColumnCount; ++lcol) {
502                                            final double temp = B[lcol + jcolTimesM1ColumnCount];
503                                            if (temp != 0.0d) {
504                                                    final int lcolTimesM1RowCount = lcol * m1RowCount;
505                                                    for (int irow = 0; irow < m1RowCount; ++irow) {
506                                                            C[irow + jcolTimesM1RowCount] += A[irow + lcolTimesM1RowCount] * temp;
507                                                    }
508                                            }
509                                    }
510                            }
511                    };
512            }
513    
514            private final void calcDoubleArray2DSingleThreaded(final double[][] m1, final double[][] m2,
515                            final double[][] ret) {
516                    final int columnCount = m1[0].length;
517                    final double[] columns = new double[columnCount];
518    
519                    for (int c = m2[0].length; --c != -1;) {
520                            for (int k = columnCount; --k != -1;) {
521                                    columns[k] = m2[k][c];
522                            }
523                            for (int r = m1.length; --r != -1;) {
524                                    double sum = 0.0d;
525                                    final double[] row = m1[r];
526                                    for (int k = columnCount; --k != -1;) {
527                                            sum += row[k] * columns[k];
528                                    }
529                                    ret[r][c] = sum;
530                            }
531                    }
532            }
533    
534            private final void calcDoubleArray2DMultiThreaded(final double[][] m1, final double[][] m2,
535                            final double[][] ret) {
536                    final int columnCount = m1[0].length;
537                    final double[] columns = new double[columnCount];
538    
539                    new PFor(0, m2[0].length - 1) {
540                            @Override
541                            public void step(int i) {
542                                    for (int k = columnCount; --k != -1;) {
543                                            columns[k] = m2[k][i];
544                                    }
545                                    for (int r = m1.length; --r != -1;) {
546                                            double sum = 0.0d;
547                                            final double[] row = m1[r];
548                                            for (int k = columnCount; --k != -1;) {
549                                                    sum += row[k] * columns[k];
550                                            }
551                                            ret[r][i] = sum;
552                                    }
553                            }
554                    };
555            }
556    
557            private final void calcDenseDoubleMatrix2DSingleThreaded(final DenseDoubleMatrix2D A,
558                            final DenseDoubleMatrix2D B, final DenseDoubleMatrix2D C) {
559                    final int m1RowCount = (int) A.getRowCount();
560                    final int m1ColumnCount = (int) A.getColumnCount();
561                    final int m2ColumnCount = (int) B.getColumnCount();
562    
563                    for (int i = 0; i < m2ColumnCount; i++) {
564                            for (int irow = 0; irow < m1RowCount; ++irow) {
565                                    C.setDouble(0.0d, irow, i);
566                            }
567                            for (int lcol = 0; lcol < m1ColumnCount; ++lcol) {
568                                    final double temp = B.getDouble(lcol, i);
569                                    if (temp != 0.0d) {
570                                            for (int irow = 0; irow < m1RowCount; ++irow) {
571                                                    C.setDouble(C.getDouble(irow, i) + A.getDouble(irow, lcol) * temp, irow, i);
572                                            }
573                                    }
574                            }
575                    }
576            }
577    
578            private final void calcDenseDoubleMatrix2DMultiThreaded(final DenseDoubleMatrix2D A,
579                            final DenseDoubleMatrix2D B, final DenseDoubleMatrix2D C) {
580                    final int m1RowCount = (int) A.getRowCount();
581                    final int m1ColumnCount = (int) A.getColumnCount();
582                    final int m2ColumnCount = (int) B.getColumnCount();
583    
584                    new PFor(0, m2ColumnCount - 1) {
585                            @Override
586                            public void step(int i) {
587                                    for (int irow = 0; irow < m1RowCount; ++irow) {
588                                            C.setDouble(0.0d, irow, i);
589                                    }
590                                    for (int lcol = 0; lcol < m1ColumnCount; ++lcol) {
591                                            final double temp = B.getDouble(lcol, i);
592                                            if (temp != 0.0d) {
593                                                    for (int irow = 0; irow < m1RowCount; ++irow) {
594                                                            C.setDouble(C.getDouble(irow, i) + A.getDouble(irow, lcol) * temp,
595                                                                            irow, i);
596                                                    }
597                                            }
598                                    }
599                            }
600                    };
601            }
602    
603            /**
604             * Multiply two matrices concurrently with the given Executor to handle
605             * parallel tasks.
606             * 
607             * @param b
608             *            - matrix to multiply this with.
609             * @param executorService
610             *            - to handle concurrent multiplication tasks.
611             * @return new matrix C containing result of matrix multiplication C = A x
612             *         B.
613             */
614            /**
615             * @param a
616             * @param b
617             * @param c
618             * @return
619             */
620            private BlockDenseDoubleMatrix2D blockMultiplyMultiThreaded(final BlockDenseDoubleMatrix2D a,
621                            final BlockDenseDoubleMatrix2D b, final BlockDenseDoubleMatrix2D c) {
622                    final BlockMatrixLayout al = a.getBlockLayout();
623                    final BlockMatrixLayout bl = b.getBlockLayout();
624                    assertTrue(al.columns == bl.rows, "b.rows != this.columns");
625                    assertTrue(al.blockStripe == bl.blockStripe, "block sizes differ: %s != %s",
626                                    al.blockStripe, bl.blockStripe);
627    
628                    final List<Callable<Void>> tasks = new LinkedList<Callable<Void>>();
629    
630                    final int kMax = (int) b.getColumnCount();
631                    final int jMax = (int) a.getColumnCount();
632                    final int iMax = (int) a.getRowCount();
633    
634                    final int bColSlice = Math.min(al.blockStripe, kMax);
635                    final int aColSlice = Math.min(al.blockStripe, jMax);
636                    final int aRowSlice = Math.min(al.blockStripe, iMax);
637    
638                    // Number of blocks to take for each concurrent task.
639                    final int blocksPerTask = 1;
640                    final int blocksPerTaskDimJ = selectBlocksPerTaskDimJ(al.blockStripe, iMax, jMax, kMax);
641    
642                    for (int k = 0, kStride; k < kMax; k += kStride) {
643                            kStride = Math.min(blocksPerTask * bColSlice, kMax - k);
644    
645                            for (int j = 0, jStride; j < jMax; j += jStride) {
646                                    jStride = Math.min(blocksPerTaskDimJ * aColSlice, jMax - j);
647    
648                                    for (int i = 0, iStride; i < iMax; i += iStride) {
649                                            iStride = Math.min(blocksPerTask * aRowSlice, iMax - i);
650    
651                                            tasks.add(new BlockMultiply(a, b, c, i, (i + iStride), j, (j + jStride), k,
652                                                            (k + kStride)));
653                                    }
654    
655                            }
656                    }
657    
658                    // wait for all tasks to complete.
659                    try {
660                            for (Future<Void> f : UJMPThreadPoolExecutor.getInstance().invokeAll(tasks)) {
661                                    f.get();
662                            }
663                    } catch (ExecutionException e) {
664                            StringBuilder sb = new StringBuilder(
665                                            "Execution exception - while awaiting completion of matrix multiplication ["
666                                                            + e.getMessage() + "]:");
667                            if (e.getCause() != null) {
668                                    for (StackTraceElement stackTraceElement : e.getCause().getStackTrace()) {
669                                            sb.append(stackTraceElement).append("  *  ");
670                                    }
671                            }
672                            throw new RuntimeException(sb.toString(), e.getCause());
673                    } catch (final InterruptedException e) {
674                            String msg = "Interrupted - while awaiting completion of matrix multiplication.";
675                            throw new RuntimeException(msg + ": cause [" + e.getMessage() + "]", e);
676                    }
677    
678                    return c;
679            }
680    
681            // pick a suitable number of blocks to process per task for dimension J
682            // - if too small , then incurs extra gc and contention for synchronization
683            // - if set too large, then may not fully exploit parallelism
684            private int selectBlocksPerTaskDimJ(int blockStripe, int iMax, int jMax, int kMax) {
685                    int adjust = (jMax % blockStripe > 0) ? 1 : 0;
686                    if (jMax < (5 * blockStripe) || jMax <= iMax) {
687                            // do not break this dimension into parallel tasks
688                            return jMax / blockStripe + adjust;
689                    } else {
690                            // assume 2 parallell tasks
691                            return Math.max(1, (jMax / blockStripe + adjust) / 2);
692                    }
693                    // may need something if jMax >>> iMax
694            }
695    };