001    /*
002     * Copyright (C) 2008-2010 by Holger Arndt
003     *
004     * This file is part of the Universal Java Matrix Package (UJMP).
005     * See the NOTICE file distributed with this work for additional
006     * information regarding copyright ownership and licensing.
007     *
008     * UJMP is free software; you can redistribute it and/or modify
009     * it under the terms of the GNU Lesser General Public License as
010     * published by the Free Software Foundation; either version 2
011     * of the License, or (at your option) any later version.
012     *
013     * UJMP is distributed in the hope that it will be useful,
014     * but WITHOUT ANY WARRANTY; without even the implied warranty of
015     * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
016     * GNU Lesser General Public License for more details.
017     *
018     * You should have received a copy of the GNU Lesser General Public
019     * License along with UJMP; if not, write to the
020     * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
021     * Boston, MA  02110-1301  USA
022     */
023    
024    package org.ujmp.parallelcolt;
025    
026    import org.ujmp.core.Matrix;
027    import org.ujmp.core.doublematrix.DenseDoubleMatrix2D;
028    import org.ujmp.core.doublematrix.stub.AbstractDenseDoubleMatrix2D;
029    import org.ujmp.core.exceptions.MatrixException;
030    import org.ujmp.core.interfaces.HasColumnMajorDoubleArray1D;
031    import org.ujmp.core.interfaces.HasRowMajorDoubleArray2D;
032    import org.ujmp.core.interfaces.Wrapper;
033    import org.ujmp.parallelcolt.calculation.Solve;
034    
035    import cern.colt.matrix.tdouble.DoubleFactory2D;
036    import cern.colt.matrix.tdouble.DoubleMatrix2D;
037    import cern.colt.matrix.tdouble.algo.DenseDoubleAlgebra;
038    import cern.colt.matrix.tdouble.algo.decomposition.DenseDoubleCholeskyDecomposition;
039    import cern.colt.matrix.tdouble.algo.decomposition.DenseDoubleEigenvalueDecomposition;
040    import cern.colt.matrix.tdouble.algo.decomposition.DenseDoubleLUDecomposition;
041    import cern.colt.matrix.tdouble.algo.decomposition.DenseDoubleQRDecomposition;
042    import cern.colt.matrix.tdouble.algo.decomposition.DenseDoubleSingularValueDecomposition;
043    import cern.jet.math.tdouble.DoubleFunctions;
044    
045    public class ParallelColtDenseDoubleMatrix2D extends
046                    AbstractDenseDoubleMatrix2D implements
047                    Wrapper<cern.colt.matrix.tdouble.impl.DenseDoubleMatrix2D> {
048            private static final long serialVersionUID = -1941030601886654699L;
049    
050            public static final DenseDoubleAlgebra ALG = new DenseDoubleAlgebra();
051    
052            private cern.colt.matrix.tdouble.impl.DenseDoubleMatrix2D matrix;
053    
054            public ParallelColtDenseDoubleMatrix2D(long... size) {
055                    this.matrix = new cern.colt.matrix.tdouble.impl.DenseDoubleMatrix2D(
056                                    (int) size[ROW], (int) size[COLUMN]);
057            }
058    
059            public ParallelColtDenseDoubleMatrix2D(DoubleMatrix2D m) {
060                    if (m instanceof DenseDoubleMatrix2D) {
061                            this.matrix = (cern.colt.matrix.tdouble.impl.DenseDoubleMatrix2D) m;
062                    } else {
063                            this.matrix = new cern.colt.matrix.tdouble.impl.DenseDoubleMatrix2D(
064                                            m.toArray());
065                    }
066            }
067    
068            public ParallelColtDenseDoubleMatrix2D(
069                            cern.colt.matrix.tdouble.impl.DenseDoubleMatrix2D m) {
070                    this.matrix = m;
071            }
072    
073            public ParallelColtDenseDoubleMatrix2D(Matrix source)
074                            throws MatrixException {
075                    if (source instanceof HasColumnMajorDoubleArray1D) {
076                            final double[] data = ((HasColumnMajorDoubleArray1D) source)
077                                            .getColumnMajorDoubleArray1D();
078                            this.matrix = new cern.colt.matrix.tdouble.impl.DenseDoubleMatrix2D(
079                                            (int) source.getRowCount(), (int) source.getColumnCount(),
080                                            data, 0, 0, 1, (int) source.getRowCount(), false);
081                    } else if (source instanceof HasRowMajorDoubleArray2D) {
082                            final double[][] data = ((HasRowMajorDoubleArray2D) source)
083                                            .getRowMajorDoubleArray2D();
084                            this.matrix = new cern.colt.matrix.tdouble.impl.DenseDoubleMatrix2D(
085                                            data);
086                    } else if (source instanceof DenseDoubleMatrix2D) {
087                            this.matrix = new cern.colt.matrix.tdouble.impl.DenseDoubleMatrix2D(
088                                            (int) source.getRowCount(), (int) source.getColumnCount());
089                            final DenseDoubleMatrix2D m2 = (DenseDoubleMatrix2D) source;
090                            for (int r = (int) source.getRowCount(); --r >= 0;) {
091                                    for (int c = (int) source.getColumnCount(); --c >= 0;) {
092                                            matrix.setQuick(r, c, m2.getDouble(r, c));
093                                    }
094                            }
095                    } else {
096                            this.matrix = new cern.colt.matrix.tdouble.impl.DenseDoubleMatrix2D(
097                                            (int) source.getRowCount(), (int) source.getColumnCount());
098                            for (long[] c : source.availableCoordinates()) {
099                                    setDouble(source.getAsDouble(c), c);
100                            }
101                    }
102            }
103    
104            public double getDouble(long row, long column) {
105                    return matrix.getQuick((int) row, (int) column);
106            }
107    
108            public double getDouble(int row, int column) {
109                    return matrix.getQuick(row, column);
110            }
111    
112            public long[] getSize() {
113                    return new long[] { matrix.rows(), matrix.columns() };
114            }
115    
116            public void setDouble(double value, long row, long column) {
117                    matrix.setQuick((int) row, (int) column, value);
118            }
119    
120            public void setDouble(double value, int row, int column) {
121                    matrix.setQuick(row, column, value);
122            }
123    
124            public cern.colt.matrix.tdouble.impl.DenseDoubleMatrix2D getWrappedObject() {
125                    return matrix;
126            }
127    
128            public void setWrappedObject(
129                            cern.colt.matrix.tdouble.impl.DenseDoubleMatrix2D object) {
130                    this.matrix = object;
131            }
132    
133            public Matrix plus(double value) {
134                    return new ParallelColtDenseDoubleMatrix2D(
135                                    (cern.colt.matrix.tdouble.impl.DenseDoubleMatrix2D) matrix
136                                                    .copy().assign(DoubleFunctions.plus(value)));
137            }
138    
139            public Matrix inv() {
140                    return new ParallelColtDenseDoubleMatrix2D(
141                                    (cern.colt.matrix.tdouble.impl.DenseDoubleMatrix2D) ALG
142                                                    .inverse(matrix));
143            }
144    
145            public Matrix times(double value) {
146                    return new ParallelColtDenseDoubleMatrix2D(
147                                    (cern.colt.matrix.tdouble.impl.DenseDoubleMatrix2D) matrix
148                                                    .copy().assign(DoubleFunctions.mult(value)));
149            }
150    
151            public Matrix transpose() {
152                    return new ParallelColtDenseDoubleMatrix2D(
153                                    (cern.colt.matrix.tdouble.impl.DenseDoubleMatrix2D) matrix
154                                                    .viewDice().copy());
155            }
156    
157            public Matrix plus(Matrix m) {
158                    if (m instanceof ParallelColtDenseDoubleMatrix2D) {
159                            DoubleMatrix2D result = matrix.copy();
160                            result.assign(((ParallelColtDenseDoubleMatrix2D) m)
161                                            .getWrappedObject(), DoubleFunctions.plus);
162                            return new ParallelColtDenseDoubleMatrix2D(result);
163                    } else {
164                            return super.plus(m);
165                    }
166            }
167    
168            public Matrix minus(Matrix m) {
169                    if (m instanceof ParallelColtDenseDoubleMatrix2D) {
170                            DoubleMatrix2D result = matrix.copy();
171                            result.assign(((ParallelColtDenseDoubleMatrix2D) m)
172                                            .getWrappedObject(), DoubleFunctions.minus);
173                            return new ParallelColtDenseDoubleMatrix2D(result);
174                    } else {
175                            return super.minus(m);
176                    }
177            }
178    
179            public Matrix mtimes(Matrix m) throws MatrixException {
180                    if (m instanceof ParallelColtDenseDoubleMatrix2D) {
181                            cern.colt.matrix.tdouble.impl.DenseDoubleMatrix2D ret = new cern.colt.matrix.tdouble.impl.DenseDoubleMatrix2D(
182                                            (int) getRowCount(), (int) m.getColumnCount());
183                            matrix.zMult(((ParallelColtDenseDoubleMatrix2D) m).matrix, ret);
184                            return new ParallelColtDenseDoubleMatrix2D(ret);
185                    } else {
186                            return super.mtimes(m);
187                    }
188            }
189    
190            public Matrix[] svd() {
191                    DenseDoubleSingularValueDecomposition svd = new DenseDoubleSingularValueDecomposition(
192                                    matrix, true, false);
193                    Matrix u = new ParallelColtDenseDoubleMatrix2D(svd.getU());
194                    Matrix s = new ParallelColtDenseDoubleMatrix2D(svd.getS());
195                    Matrix v = new ParallelColtDenseDoubleMatrix2D(svd.getV());
196                    return new Matrix[] { u, s, v };
197            }
198    
199            public Matrix[] eig() {
200                    DenseDoubleEigenvalueDecomposition eig = new DenseDoubleEigenvalueDecomposition(
201                                    matrix);
202                    Matrix v = new ParallelColtDenseDoubleMatrix2D(eig.getV());
203                    Matrix d = new ParallelColtDenseDoubleMatrix2D(eig.getD());
204                    return new Matrix[] { v, d };
205            }
206    
207            public Matrix[] qr() {
208                    DenseDoubleQRDecomposition qr = new DenseDoubleQRDecomposition(matrix);
209                    Matrix q = new ParallelColtDenseDoubleMatrix2D(qr.getQ(false));
210                    Matrix r = new ParallelColtDenseDoubleMatrix2D(qr.getR(false));
211                    return new Matrix[] { q, r };
212            }
213    
214            public Matrix[] lu() {
215                    if (getRowCount() >= getColumnCount()) {
216                            DenseDoubleLUDecomposition lu = new DenseDoubleLUDecomposition(
217                                            matrix);
218                            Matrix l = new ParallelColtDenseDoubleMatrix2D(lu.getL());
219                            Matrix u = new ParallelColtDenseDoubleMatrix2D(lu.getU().viewPart(
220                                            0, 0, (int) getColumnCount(), (int) getColumnCount()));
221                            int m = (int) getRowCount();
222                            int[] piv = lu.getPivot();
223                            Matrix p = new ParallelColtDenseDoubleMatrix2D(m, m);
224                            for (int i = 0; i < m; i++) {
225                                    p.setAsDouble(1, i, piv[i]);
226                            }
227                            return new Matrix[] { l, u, p };
228                    } else {
229                            throw new MatrixException("only supported for matrices m>=n");
230                    }
231            }
232    
233            public Matrix chol() {
234                    DenseDoubleCholeskyDecomposition chol = new DenseDoubleCholeskyDecomposition(
235                                    matrix);
236                    Matrix r = new ParallelColtDenseDoubleMatrix2D(chol.getL());
237                    return r;
238            }
239    
240            public Matrix copy() {
241                    Matrix m = new ParallelColtDenseDoubleMatrix2D(
242                                    (cern.colt.matrix.tdouble.impl.DenseDoubleMatrix2D) matrix
243                                                    .copy());
244                    if (getAnnotation() != null) {
245                            m.setAnnotation(getAnnotation().clone());
246                    }
247                    return m;
248            }
249    
250            public Matrix solve(Matrix b) {
251                    return Solve.INSTANCE.calc(this, b);
252            }
253    
254            public Matrix solveSPD(Matrix b) {
255                    if (b instanceof ParallelColtDenseDoubleMatrix2D) {
256                            ParallelColtDenseDoubleMatrix2D b2 = new ParallelColtDenseDoubleMatrix2D(
257                                            b);
258                            DenseDoubleCholeskyDecomposition chol = new DenseDoubleCholeskyDecomposition(
259                                            matrix);
260                            chol.solve(b2.matrix);
261                            return b2;
262                    } else {
263                            return super.solve(b);
264                    }
265            }
266    
267            public Matrix invSPD() {
268                    DenseDoubleCholeskyDecomposition chol = new DenseDoubleCholeskyDecomposition(
269                                    matrix);
270                    DoubleMatrix2D ret = DoubleFactory2D.dense.identity(matrix.rows());
271                    chol.solve(ret);
272                    return new ParallelColtDenseDoubleMatrix2D(ret);
273            }
274    }