001    /*
002     * Copyright (C) 2008-2010 by Holger Arndt
003     *
004     * This file is part of the Universal Java Matrix Package (UJMP).
005     * See the NOTICE file distributed with this work for additional
006     * information regarding copyright ownership and licensing.
007     *
008     * UJMP is free software; you can redistribute it and/or modify
009     * it under the terms of the GNU Lesser General Public License as
010     * published by the Free Software Foundation; either version 2
011     * of the License, or (at your option) any later version.
012     *
013     * UJMP is distributed in the hope that it will be useful,
014     * but WITHOUT ANY WARRANTY; without even the implied warranty of
015     * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
016     * GNU Lesser General Public License for more details.
017     *
018     * You should have received a copy of the GNU Lesser General Public
019     * License along with UJMP; if not, write to the
020     * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
021     * Boston, MA  02110-1301  USA
022     */
023    
024    package org.ujmp.core.calculation;
025    
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.List;

import org.ujmp.core.Matrix;
import org.ujmp.core.doublematrix.DenseDoubleMatrix2D;
import org.ujmp.core.interfaces.HasColumnMajorDoubleArray1D;
import org.ujmp.core.interfaces.HasRowMajorDoubleArray2D;
import org.ujmp.core.matrix.DenseMatrix;
import org.ujmp.core.matrix.DenseMatrix2D;
import org.ujmp.core.matrix.SparseMatrix;
import org.ujmp.core.util.MathUtil;
import org.ujmp.core.util.UJMPSettings;
import org.ujmp.core.util.VerifyUtil;
import org.ujmp.core.util.concurrent.PForEquidistant;
039    
040    public class PlusMatrix {
041            public static final PlusMatrixCalculation<Matrix, Matrix, Matrix> MATRIX = new PlusMatrixMatrix();
042    
043            public static final PlusMatrixCalculation<DenseMatrix, DenseMatrix, DenseMatrix> DENSEMATRIX = new PlusMatrixDenseMatrix();
044    
045            public static final PlusMatrixCalculation<DenseMatrix2D, DenseMatrix2D, DenseMatrix2D> DENSEMATRIX2D = new PlusMatrixDenseMatrix2D();
046    
047            public static final PlusMatrixCalculation<DenseDoubleMatrix2D, DenseDoubleMatrix2D, DenseDoubleMatrix2D> DENSEDOUBLEMATRIX2D = new PlusMatrixDenseDoubleMatrix2D();
048    
049            public static final PlusMatrixCalculation<SparseMatrix, SparseMatrix, SparseMatrix> SPARSEMATRIX = new PlusMatrixSparseMatrix();
050    }
051    
052    class PlusMatrixMatrix implements PlusMatrixCalculation<Matrix, Matrix, Matrix> {
053    
054            public final void calc(final Matrix source1, final Matrix source2, final Matrix target) {
055                    if (source1 instanceof DenseDoubleMatrix2D && source2 instanceof DenseDoubleMatrix2D
056                                    && target instanceof DenseDoubleMatrix2D) {
057                            PlusMatrix.DENSEDOUBLEMATRIX2D.calc((DenseDoubleMatrix2D) source1,
058                                            (DenseDoubleMatrix2D) source2, (DenseDoubleMatrix2D) target);
059                    } else if (source1 instanceof DenseMatrix2D && source2 instanceof DenseMatrix2D
060                                    && target instanceof DenseMatrix2D) {
061                            PlusMatrix.DENSEMATRIX2D.calc((DenseMatrix2D) source1, (DenseMatrix2D) source2,
062                                            (DenseMatrix2D) target);
063                    } else if (source1 instanceof DenseMatrix && source2 instanceof DenseMatrix
064                                    && target instanceof DenseMatrix) {
065                            PlusMatrix.DENSEMATRIX.calc((DenseMatrix) source1, (DenseMatrix) source2,
066                                            (DenseMatrix) target);
067                    } else if (source1 instanceof SparseMatrix && source2 instanceof SparseMatrix
068                                    && target instanceof SparseMatrix) {
069                            PlusMatrix.SPARSEMATRIX.calc((SparseMatrix) source1, (SparseMatrix) source2,
070                                            (SparseMatrix) target);
071                    } else {
072                            VerifyUtil.assertSameSize(source1, source2, target);
073                            for (long[] c : source1.allCoordinates()) {
074                                    BigDecimal v1 = source1.getAsBigDecimal(c);
075                                    BigDecimal v2 = source2.getAsBigDecimal(c);
076                                    BigDecimal result = MathUtil.plus(v1, v2);
077                                    target.setAsBigDecimal(result, c);
078                            }
079                    }
080            }
081    };
082    
083    class PlusMatrixDenseMatrix implements PlusMatrixCalculation<DenseMatrix, DenseMatrix, DenseMatrix> {
084    
085            public final void calc(final DenseMatrix source1, final DenseMatrix source2,
086                            final DenseMatrix target) {
087                    if (source1 instanceof DenseMatrix2D && source2 instanceof DenseMatrix2D
088                                    && target instanceof DenseMatrix2D) {
089                            PlusMatrix.DENSEMATRIX2D.calc((DenseMatrix2D) source1, (DenseMatrix2D) source2,
090                                            (DenseMatrix2D) target);
091                    } else {
092                            VerifyUtil.assertSameSize(source1, source2, target);
093                            for (long[] c : source1.allCoordinates()) {
094                                    BigDecimal v1 = source1.getAsBigDecimal(c);
095                                    BigDecimal v2 = source2.getAsBigDecimal(c);
096                                    BigDecimal result = MathUtil.plus(v1, v2);
097                                    target.setAsBigDecimal(result, c);
098                            }
099                    }
100            }
101    };
102    
103    class PlusMatrixSparseMatrix implements
104                    PlusMatrixCalculation<SparseMatrix, SparseMatrix, SparseMatrix> {
105    
106            public final void calc(final SparseMatrix source1, final SparseMatrix source2,
107                            final SparseMatrix target) {
108                    VerifyUtil.assertSameSize(source1, source2, target);
109                    // copy all elements in source1 to target
110                    for (long[] c : source1.availableCoordinates()) {
111                            BigDecimal svalue = source1.getAsBigDecimal(c);
112                            target.setAsBigDecimal(svalue, c);
113                    }
114                    // calculate sum with source2
115                    for (long[] c : source2.availableCoordinates()) {
116                            BigDecimal v1 = target.getAsBigDecimal(c);
117                            BigDecimal v2 = source2.getAsBigDecimal(c);
118                            BigDecimal result = MathUtil.plus(v1, v2);
119                            target.setAsBigDecimal(result, c);
120                    }
121            }
122    
123    };
124    
125    class PlusMatrixDenseMatrix2D implements
126                    PlusMatrixCalculation<DenseMatrix2D, DenseMatrix2D, DenseMatrix2D> {
127    
128            public final void calc(final DenseMatrix2D source1, final DenseMatrix2D source2,
129                            final DenseMatrix2D target) {
130                    if (source1 instanceof DenseDoubleMatrix2D && source2 instanceof DenseDoubleMatrix2D
131                                    && target instanceof DenseDoubleMatrix2D) {
132                            PlusMatrix.DENSEDOUBLEMATRIX2D.calc((DenseDoubleMatrix2D) source1,
133                                            (DenseDoubleMatrix2D) source2, (DenseDoubleMatrix2D) target);
134                    } else {
135                            VerifyUtil.assertSameSize(source1, source2, target);
136                            for (int r = (int) source1.getRowCount(); --r != -1;) {
137                                    for (int c = (int) source1.getColumnCount(); --c != -1;) {
138                                            BigDecimal v1 = source1.getAsBigDecimal(r, c);
139                                            BigDecimal v2 = source2.getAsBigDecimal(r, c);
140                                            BigDecimal result = MathUtil.plus(v1, v2);
141                                            target.setAsBigDecimal(result, r, c);
142                                    }
143                            }
144                    }
145            }
146    };
147    
148    class PlusMatrixDenseDoubleMatrix2D implements
149                    PlusMatrixCalculation<DenseDoubleMatrix2D, DenseDoubleMatrix2D, DenseDoubleMatrix2D> {
150    
151            public final void calc(final DenseDoubleMatrix2D source1, final DenseDoubleMatrix2D source2,
152                            final DenseDoubleMatrix2D target) {
153                    if (source1 instanceof HasColumnMajorDoubleArray1D
154                                    && source2 instanceof HasColumnMajorDoubleArray1D
155                                    && target instanceof HasColumnMajorDoubleArray1D) {
156                            calc(((HasColumnMajorDoubleArray1D) source1).getColumnMajorDoubleArray1D(),
157                                            ((HasColumnMajorDoubleArray1D) source2).getColumnMajorDoubleArray1D(),
158                                            ((HasColumnMajorDoubleArray1D) target).getColumnMajorDoubleArray1D());
159                    } else if (source1 instanceof HasRowMajorDoubleArray2D
160                                    && source2 instanceof HasRowMajorDoubleArray2D
161                                    && target instanceof HasRowMajorDoubleArray2D) {
162                            calc(((HasRowMajorDoubleArray2D) source1).getRowMajorDoubleArray2D(),
163                                            ((HasRowMajorDoubleArray2D) source2).getRowMajorDoubleArray2D(),
164                                            ((HasRowMajorDoubleArray2D) target).getRowMajorDoubleArray2D());
165                    } else {
166                            VerifyUtil.assertSameSize(source1, source2, target);
167                            for (int r = (int) source1.getRowCount(); --r != -1;) {
168                                    for (int c = (int) source1.getColumnCount(); --c != -1;) {
169                                            target.setDouble(source1.getDouble(r, c) + source2.getDouble(r, c), r, c);
170                                    }
171                            }
172                    }
173            }
174    
175            private final void calc(final double[][] source1, final double[][] source2,
176                            final double[][] target) {
177                    VerifyUtil.assertSameSize(source1, source2, target);
178                    final int rows = source1.length;
179                    final int cols = source1[0].length;
180                    if (UJMPSettings.getNumberOfThreads() > 1 && rows >= 100 && cols >= 100) {
181                            new PForEquidistant(0, rows - 1) {
182                                    public void step(int i) {
183                                            double[] v1 = source1[i];
184                                            double[] v2 = source2[i];
185                                            double[] t = target[i];
186                                            for (int c = 0; c < cols; c++) {
187                                                    t[c] = v1[c] + v2[c];
188                                            }
189                                    }
190                            };
191                    } else {
192                            double[] v1 = null;
193                            double[] v2 = null;
194                            double[] t = null;
195                            for (int r = 0; r < rows; r++) {
196                                    v1 = source1[r];
197                                    v2 = source2[r];
198                                    t = target[r];
199                                    for (int c = 0; c < cols; c++) {
200                                            t[c] = v1[c] + v2[c];
201                                    }
202                            }
203                    }
204            }
205    
206            private final void calc(final double[] source1, final double[] source2, final double[] target) {
207                    VerifyUtil.assertSameSize(source1, source2, target);
208                    final int length = source1.length;
209                    for (int i = 0; i < length; i++) {
210                            target[i] = source1[i] + source2[i];
211                    }
212            }
213    
214    };