001    /*
002     * Copyright (C) 2010 by Frode Carlsen
003     *
004     * This file is part of the Universal Java Matrix Package (UJMP).
005     * See the NOTICE file distributed with this work for additional
006     * information regarding copyright ownership and licensing.
007     *
008     * UJMP is free software; you can redistribute it and/or modify
009     * it under the terms of the GNU Lesser General Public License as
010     * published by the Free Software Foundation; either version 2
011     * of the License, or (at your option) any later version.
012     *
013     * UJMP is distributed in the hope that it will be useful,
014     * but WITHOUT ANY WARRANTY; without even the implied warranty of
015     * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
016     * GNU Lesser General Public License for more details.
017     *
018     * You should have received a copy of the GNU Lesser General Public
019     * License along with UJMP; if not, write to the
020     * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
021     * Boston, MA  02110-1301  USA
022     */
023    package org.ujmp.core.doublematrix.impl;
024    
025    import static org.ujmp.core.util.VerifyUtil.assertTrue;
026    
027    import java.util.concurrent.Callable;
028    
029    /**
030     * Multiply blocks of A and B in the specified range(fromM->toM, fromN->toN,
031     * fromK->toK), <br>
032     * and add to matrix C.
033     * <p>
034     * <code>
035     * C(fromM->toM, fromK->toK) += <br>&nbsp&nbsp&nbsp;&nbsp&nbsp&nbsp;
036     * A(fromM->toM, fromN->toN) x B(fromN->toN, fromK->toK)
037     * </code>
038     * <p>
039     * All blocks must be square blocks of the same size, with length of one side =
040     * {@link #blockStripeSize}
041     * 
042     * @author Frode Carlsen
043     */
044    public class BlockMultiply implements Callable<Void> {
045    
046            /** Length of one side of a block of data. */
047            private final int blockStripeSize;
048    
049            /** range of data in matrix to be processed by this instance. */
050            private final int fromM, toM, fromN, toN, fromK, toK;
051    
052            /** Source matrices to be processed. */
053            private final BlockDenseDoubleMatrix2D matrixA, matrixB, matrixC;
054    
055            /**
056             * Constructor taking the two matrices being multiplied, the target matrix C
057             * and the range of rows and columns to multiply.
058             * 
059             * @param a
060             *            - matrix A, size (M, N)
061             * @param b
062             *            - matrix B, size (N, K)
063             * @param c
064             *            - result matrix C, size (M, K)
065             * @param fromM
066             *            - start row M in matrix A
067             * @param toM
068             *            - end row M in A
069             * @param fromN
070             *            - start column N in A (or start row N in B)
071             * @param toN
072             *            - end row N
073             * @param fromK
074             *            - start column K in B
075             * @param toK
076             *            - end column K in B
077             */
078            public BlockMultiply(final BlockDenseDoubleMatrix2D a, final BlockDenseDoubleMatrix2D b,
079                            final BlockDenseDoubleMatrix2D c, final int fromM, final int toM, final int fromN,
080                            final int toN, final int fromK, final int toK) {
081                    super();
082    
083                    verifyInput(a, b, c, fromM, toM, fromN, toN, fromK, toK);
084    
085                    this.matrixA = a;
086                    this.matrixB = b;
087                    this.matrixC = c;
088                    this.fromM = fromM;
089                    this.toM = toM;
090                    this.fromN = fromN;
091                    this.toN = toN;
092                    this.fromK = fromK;
093                    this.toK = toK;
094    
095                    this.blockStripeSize = a.layout.blockStripe;
096            }
097    
098            public Void call() {
099                    multiply();
100                    return null;
101            }
102    
103            /**
104             * Multiply blocks of two matrices A,B and add to C.
105             * <p>
106             * Blocks of Matrix B are transformed to column-major layout (if not
107             * already) to facilitate multiplication.<br>
108             * (If matrices have been created optimally, B should already be
109             * column-major)
110             */
111            protected final void multiply() {
112                    final int step = blockStripeSize, blockSize = blockStripeSize * blockStripeSize;
113    
114                    for (int m = fromM; m < toM; m += step) {
115                            final int aRows = matrixA.layout.getRowsInBlock(m);
116    
117                            for (int k = fromK; k < toK; k += step) {
118                                    final int bCols = matrixB.layout.getColumnsInBlock(k);
119    
120                                    final double[] cBlock = new double[aRows * bCols];
121    
122                                    for (int n = fromN; n < toN; n += step) {
123    
124                                            // ensure a and b are in optimal block order before
125                                            // multiplication
126                                            final double[] aBlock = matrixA.layout.toRowMajorBlock(matrixA, m, n);
127                                            final double[] bBlock = matrixB.layout.toColMajorBlock(matrixB, n, k);
128    
129                                            if (aBlock != null && bBlock != null) {
130                                                    if (aBlock.length == blockSize && bBlock.length == blockSize) {
131                                                            multiplyAxB(aBlock, bBlock, cBlock, step);
132                                                    } else {
133                                                            int aCols = aBlock.length / aRows;
134                                                            int bRows = bBlock.length / bCols;
135                                                            assertTrue(aCols == bRows, "aCols!=bRows");
136                                                            multiplyRowMajorTimesColumnMajorBlocks(aBlock, bBlock, cBlock, aRows,
137                                                                            aCols, bCols);
138                                                    }
139                                            }
140                                    }
141    
142                                    matrixC.addBlockData(m, k, cBlock);
143                            }
144                    }
145            }
146    
147            /**
148             * Multiply row-major block (a) x column-major block (b), and add to block
149             * c.
150             * 
151             * @param a
152             *            - block from {@link #matrixA}
153             * @param b
154             *            - block from {@link #matrixB}
155             * @param c
156             *            - block from result matrix {@link #matrixC}
157             */
158            private static void multiplyAxB(final double[] aBlock, final double[] bBlock,
159                            final double[] cBlock, final int step) {
160                    final int blockStripeMini = step % 3;
161                    final int blockStripeMaxi = step / 3;
162                    final int blockArea = step * step;
163    
164                    for (int iL = 0; iL < blockArea; iL += step) {
165                            int rc = iL;
166    
167                            for (int kL = 0; kL < blockArea; kL += step) {
168                                    int ra = iL;
169                                    int rb = kL;
170                                    double sum = 0.0d;
171    
172                                    for (int jL = blockStripeMini; --jL >= 0;) {
173                                            sum += aBlock[ra++] * bBlock[rb++];
174                                    }
175    
176                                    // loop unrolling
177                                    for (int jL = blockStripeMaxi; --jL >= 0;) {
178                                            sum += aBlock[ra++] * bBlock[rb++] //
179                                                            + aBlock[ra++] * bBlock[rb++] //
180                                                            + aBlock[ra++] * bBlock[rb++];
181                                    }
182    
183                                    cBlock[rc++] += sum;
184                            }
185                    }
186            }
187    
188            public void multiplyRowMajorTimesColumnMajorBlocks(double[] aBlock, double[] bBlock,
189                            double[] cBlock, int aRows, int bRows, int bCols) {
190                    final int aCols = bRows;
191    
192                    for (int i = 0; i < aRows; i++) {
193                            for (int k = 0; k < bCols; k++) {
194                                    double sum = 0.0d;
195                                    for (int j = 0; j < bRows; j++) {
196                                            sum += aBlock[i * aCols + j] * bBlock[k * bRows + j];
197                                    }
198                                    cBlock[i * bCols + k] += sum;
199                            }
200                    }
201            }
202    
203            private static void verifyInput(final BlockDenseDoubleMatrix2D a,
204                            final BlockDenseDoubleMatrix2D b, final BlockDenseDoubleMatrix2D c, final int fromM,
205                            final int toM, final int fromN, final int toN, final int fromK, final int toK) {
206                    assertTrue(a != null, "a cannot be null");
207                    assertTrue(b != null, "b cannot be null");
208                    assertTrue(c != null, "c cannot be null");
209                    assertTrue(fromM <= a.getRowCount() && fromM >= 0, "Invalid argument : fromM");
210                    assertTrue(toM <= a.getRowCount() && toM >= fromM, "Invalid argument : fromM/toM");
211                    assertTrue(fromN <= a.getColumnCount() && fromN >= 0, "Invalid argument : fromN");
212                    assertTrue(toN <= a.getColumnCount() && toN >= fromN, "Invalid argument : fromN/toN");
213                    assertTrue(fromK <= b.getColumnCount() && fromK >= 0, "Invalid argument : fromK");
214                    assertTrue(toK <= b.getColumnCount() && toK >= fromK, "Invalid argument : fromK/toK");
215                    assertTrue(a.getColumnCount() == b.getRowCount(), "Invalid argument : a.columns != b.rows");
216                    assertTrue(a.getRowCount() == c.getRowCount(), "Invalid argument : a.rows != c.rows");
217                    assertTrue(b.getColumnCount() == c.getColumnCount(),
218                                    "Invalid argument : b.columns != c.columns");
219            }
220    
221    }