001    package org.ujmp.core.doublematrix.calculation.general.decomposition;
002    
003    import org.ujmp.core.Matrix;
004    import org.ujmp.core.MatrixFactory;
005    import org.ujmp.core.calculation.Calculation.Ret;
006    import org.ujmp.core.exceptions.MatrixException;
007    import org.ujmp.core.util.DecompositionOps;
008    import org.ujmp.core.util.MathUtil;
009    import org.ujmp.core.util.UJMPSettings;
010    
/**
 * QR Decomposition.
 * <P>
 * For an m-by-n matrix A with m >= n, the QR decomposition is an m-by-n
 * matrix Q with orthonormal columns and an n-by-n upper triangular matrix R
 * so that A = Q*R.
 * <P>
 * The QR decomposition always exists, even if the matrix does not have full
 * rank, so rank deficiency does not make the decomposition itself fail. (The
 * built-in UJMP implementation does, however, reject matrices with m < n.)
 * The primary use of the QR decomposition is in the least squares solution
 * of nonsquare systems of simultaneous linear equations; solving fails if the
 * matrix is rank deficient, i.e. if isFullRank() returns false.
 */
022    
023    public interface QR<T> {
024    
025            public static int THRESHOLD = 100;
026    
027            public T[] calc(T source);
028    
029            public T solve(T source, T b);
030    
031            public static final QR<Matrix> MATRIX = new QR<Matrix>() {
032    
033                    public final Matrix[] calc(Matrix source) {
034                            if (UJMPSettings.getNumberOfThreads() == 1) {
035                                    if (source.getRowCount() >= THRESHOLD && source.getColumnCount() >= THRESHOLD) {
036                                            return MATRIXLARGESINGLETHREADED.calc(source);
037                                    } else {
038                                            return MATRIXSMALLSINGLETHREADED.calc(source);
039                                    }
040                            } else {
041                                    if (source.getRowCount() >= THRESHOLD && source.getColumnCount() >= THRESHOLD) {
042                                            return MATRIXLARGEMULTITHREADED.calc(source);
043                                    } else {
044                                            return MATRIXSMALLMULTITHREADED.calc(source);
045                                    }
046                            }
047                    }
048    
049                    public final Matrix solve(Matrix source, Matrix b) {
050                            if (UJMPSettings.getNumberOfThreads() == 1) {
051                                    if (source.getRowCount() >= THRESHOLD && source.getColumnCount() >= THRESHOLD) {
052                                            return MATRIXLARGESINGLETHREADED.solve(source, b);
053                                    } else {
054                                            return MATRIXSMALLSINGLETHREADED.solve(source, b);
055                                    }
056                            } else {
057                                    if (source.getRowCount() >= THRESHOLD && source.getColumnCount() >= THRESHOLD) {
058                                            return MATRIXLARGEMULTITHREADED.solve(source, b);
059                                    } else {
060                                            return MATRIXSMALLMULTITHREADED.solve(source, b);
061                                    }
062                            }
063                    }
064            };
065    
066            public static final QR<Matrix> MATRIXLARGESINGLETHREADED = new QR<Matrix>() {
067                    public final Matrix[] calc(Matrix source) {
068                            QR<Matrix> qr = null;
069                            if (UJMPSettings.isUseOjalgo()) {
070                                    qr = DecompositionOps.QR_OJALGO;
071                            }
072                            if (qr == null && UJMPSettings.isUseEJML()) {
073                                    qr = DecompositionOps.QR_EJML;
074                            }
075                            if (qr == null && UJMPSettings.isUseMTJ()) {
076                                    qr = DecompositionOps.QR_MTJ;
077                            }
078                            if (qr == null) {
079                                    qr = UJMP;
080                            }
081                            return qr.calc(source);
082                    }
083    
084                    public final Matrix solve(Matrix source, Matrix b) {
085                            QR<Matrix> qr = null;
086                            if (UJMPSettings.isUseOjalgo()) {
087                                    qr = DecompositionOps.QR_OJALGO;
088                            }
089                            if (qr == null && UJMPSettings.isUseEJML()) {
090                                    qr = DecompositionOps.QR_EJML;
091                            }
092                            if (qr == null && UJMPSettings.isUseMTJ()) {
093                                    qr = DecompositionOps.QR_MTJ;
094                            }
095                            if (qr == null) {
096                                    qr = UJMP;
097                            }
098                            return qr.solve(source, b);
099                    }
100            };
101    
102            public static final QR<Matrix> MATRIXLARGEMULTITHREADED = new QR<Matrix>() {
103                    public final Matrix[] calc(Matrix source) {
104                            QR<Matrix> qr = null;
105                            if (UJMPSettings.isUseOjalgo()) {
106                                    qr = DecompositionOps.QR_OJALGO;
107                            }
108                            if (qr == null && UJMPSettings.isUseEJML()) {
109                                    qr = DecompositionOps.QR_EJML;
110                            }
111                            if (qr == null && UJMPSettings.isUseMTJ()) {
112                                    qr = DecompositionOps.QR_MTJ;
113                            }
114                            if (qr == null) {
115                                    qr = UJMP;
116                            }
117                            return qr.calc(source);
118                    }
119    
120                    public final Matrix solve(Matrix source, Matrix b) {
121                            QR<Matrix> qr = null;
122                            if (UJMPSettings.isUseOjalgo()) {
123                                    qr = DecompositionOps.QR_OJALGO;
124                            }
125                            if (qr == null && UJMPSettings.isUseEJML()) {
126                                    qr = DecompositionOps.QR_EJML;
127                            }
128                            if (qr == null && UJMPSettings.isUseMTJ()) {
129                                    qr = DecompositionOps.QR_MTJ;
130                            }
131                            if (qr == null) {
132                                    qr = UJMP;
133                            }
134                            return qr.solve(source, b);
135                    }
136            };
137    
138            public static final QR<Matrix> INSTANCE = MATRIX;
139    
140            public static final QR<Matrix> UJMP = new QR<Matrix>() {
141    
142                    public final Matrix[] calc(Matrix source) {
143                            if (source.getRowCount() >= source.getColumnCount()) {
144                                    QRMatrix qr = new QRMatrix(source);
145                                    return new Matrix[] { qr.getQ(), qr.getR() };
146                            } else {
147                                    throw new MatrixException("only matrices m>=n are allowed");
148                            }
149                    }
150    
151                    public final Matrix solve(Matrix source, Matrix b) {
152                            if (source.getRowCount() >= source.getColumnCount()) {
153                                    QRMatrix qr = new QRMatrix(source);
154                                    return qr.solve(b);
155                            } else {
156                                    throw new MatrixException("only matrices m>=n are allowed");
157                            }
158                    }
159            };
160    
161            public static final QR<Matrix> MATRIXSMALLMULTITHREADED = UJMP;
162    
163            public static final QR<Matrix> MATRIXSMALLSINGLETHREADED = UJMP;
164    
165            public class QRMatrix {
166                    private static final long serialVersionUID = 2137461328307048867L;
167    
168                    /**
169                     * Array for internal storage of decomposition.
170                     * 
171                     * @serial internal array storage.
172                     */
173                    private final double[][] QR;
174    
175                    /**
176                     * Row and column dimensions.
177                     * 
178                     * @serial column dimension.
179                     * @serial row dimension.
180                     */
181                    private final int m, n;
182    
183                    /**
184                     * Array for internal storage of diagonal of R.
185                     * 
186                     * @serial diagonal of R.
187                     */
188                    private final double[] Rdiag;
189    
190                    /*
191                     * ------------------------ Constructor ------------------------
192                     */
193    
194                    /**
195                     * QR Decomposition, computed by Householder reflections.
196                     * 
197                     * @param A
198                     *            Rectangular matrix
199                     * @return Structure to access R and the Householder vectors and compute
200                     *         Q.
201                     */
202    
203                    public QRMatrix(Matrix A) {
204                            QR = A.toDoubleArray();
205                            m = (int) A.getRowCount();
206                            n = (int) A.getColumnCount();
207                            Rdiag = new double[n];
208    
209                            // Main loop.
210                            for (int k = 0; k < n; k++) {
211                                    // Compute 2-norm of k-th column without under/overflow.
212                                    double nrm = 0;
213                                    for (int i = k; i < m; i++) {
214                                            nrm = MathUtil.hypot(nrm, QR[i][k]);
215                                    }
216    
217                                    if (nrm != 0.0) {
218                                            // Form k-th Householder vector.
219                                            if (QR[k][k] < 0) {
220                                                    nrm = -nrm;
221                                            }
222                                            for (int i = k; i < m; i++) {
223                                                    QR[i][k] /= nrm;
224                                            }
225                                            QR[k][k] += 1.0;
226    
227                                            // Apply transformation to remaining columns.
228                                            for (int j = k + 1; j < n; j++) {
229                                                    double s = 0.0;
230                                                    for (int i = k; i < m; i++) {
231                                                            s += QR[i][k] * QR[i][j];
232                                                    }
233                                                    s = -s / QR[k][k];
234                                                    for (int i = k; i < m; i++) {
235                                                            QR[i][j] += s * QR[i][k];
236                                                    }
237                                            }
238                                    }
239                                    Rdiag[k] = -nrm;
240                            }
241                    }
242    
243                    /*
244                     * ------------------------ Public Methods ------------------------
245                     */
246    
247                    /**
248                     * Is the matrix full rank?
249                     * 
250                     * @return true if R, and hence A, has full rank.
251                     */
252    
253                    public final boolean isFullRank() {
254                            for (int j = 0; j < n; j++) {
255                                    if (Rdiag[j] == 0)
256                                            return false;
257                            }
258                            return true;
259                    }
260    
261                    /**
262                     * Return the Householder vectors
263                     * 
264                     * @return Lower trapezoidal matrix whose columns define the reflections
265                     */
266    
267                    public final Matrix getH() {
268                            final double[][] H = new double[m][n];
269                            for (int i = 0; i < m; i++) {
270                                    for (int j = 0; j < n; j++) {
271                                            if (i >= j) {
272                                                    H[i][j] = QR[i][j];
273                                            }
274                                    }
275                            }
276                            return MatrixFactory.linkToArray(H);
277                    }
278    
279                    /**
280                     * Return the upper triangular factor
281                     * 
282                     * @return R
283                     */
284    
285                    public final Matrix getR() {
286                            final double[][] R = new double[n][n];
287                            for (int i = 0; i < n; i++) {
288                                    for (int j = 0; j < n; j++) {
289                                            if (i < j) {
290                                                    R[i][j] = QR[i][j];
291                                            } else if (i == j) {
292                                                    R[i][j] = Rdiag[i];
293                                            } else {
294                                                    R[i][j] = 0.0;
295                                            }
296                                    }
297                            }
298                            return MatrixFactory.linkToArray(R);
299                    }
300    
301                    /**
302                     * Generate and return the (economy-sized) orthogonal factor
303                     * 
304                     * @return Q
305                     */
306    
307                    public final Matrix getQ() {
308                            final double[][] Q = new double[m][n];
309                            for (int k = n - 1; k >= 0; k--) {
310                                    for (int i = 0; i < m; i++) {
311                                            Q[i][k] = 0.0;
312                                    }
313                                    Q[k][k] = 1.0;
314                                    for (int j = k; j < n; j++) {
315                                            if (QR[k][k] != 0) {
316                                                    double s = 0.0;
317                                                    for (int i = k; i < m; i++) {
318                                                            s += QR[i][k] * Q[i][j];
319                                                    }
320                                                    s = -s / QR[k][k];
321                                                    for (int i = k; i < m; i++) {
322                                                            Q[i][j] += s * QR[i][k];
323                                                    }
324                                            }
325                                    }
326                            }
327                            return MatrixFactory.linkToArray(Q);
328                    }
329    
330                    /**
331                     * Least squares solution of A*X = B
332                     * 
333                     * @param B
334                     *            A Matrix with as many rows as A and any number of columns.
335                     * @return X that minimizes the two norm of Q*R*X-B.
336                     * @exception IllegalArgumentException
337                     *                Matrix row dimensions must agree.
338                     * @exception RuntimeException
339                     *                Matrix is rank deficient.
340                     */
341    
342                    public final Matrix solve(Matrix B) {
343                            if (B.getRowCount() != m) {
344                                    throw new IllegalArgumentException("Matrix row dimensions must agree.");
345                            }
346                            if (!this.isFullRank()) {
347                                    throw new RuntimeException("Matrix is rank deficient.");
348                            }
349    
350                            // Copy right hand side
351                            final int nx = (int) B.getColumnCount();
352                            final double[][] X = B.toDoubleArray();
353    
354                            // Compute Y = transpose(Q)*B
355                            for (int k = 0; k < n; k++) {
356                                    for (int j = 0; j < nx; j++) {
357                                            double s = 0.0;
358                                            for (int i = k; i < m; i++) {
359                                                    s += QR[i][k] * X[i][j];
360                                            }
361                                            s = -s / QR[k][k];
362                                            for (int i = k; i < m; i++) {
363                                                    X[i][j] += s * QR[i][k];
364                                            }
365                                    }
366                            }
367                            // Solve R*X = Y;
368                            for (int k = n - 1; k >= 0; k--) {
369                                    for (int j = 0; j < nx; j++) {
370                                            X[k][j] /= Rdiag[k];
371                                    }
372                                    for (int i = 0; i < k; i++) {
373                                            for (int j = 0; j < nx; j++) {
374                                                    X[i][j] -= X[k][j] * QR[i][k];
375                                            }
376                                    }
377                            }
378                            return MatrixFactory.linkToArray(X).subMatrix(Ret.NEW, 0, 0, n - 1, nx - 1);
379                    }
380            }
381    }