001    /*
002     * Copyright (C) 2008-2010 by Holger Arndt
003     *
004     * This file is part of the Universal Java Matrix Package (UJMP).
005     * See the NOTICE file distributed with this work for additional
006     * information regarding copyright ownership and licensing.
007     *
008     * UJMP is free software; you can redistribute it and/or modify
009     * it under the terms of the GNU Lesser General Public License as
010     * published by the Free Software Foundation; either version 2
011     * of the License, or (at your option) any later version.
012     *
013     * UJMP is distributed in the hope that it will be useful,
014     * but WITHOUT ANY WARRANTY; without even the implied warranty of
015     * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
016     * GNU Lesser General Public License for more details.
017     *
018     * You should have received a copy of the GNU Lesser General Public
019     * License along with UJMP; if not, write to the
020     * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
021     * Boston, MA  02110-1301  USA
022     */
023    
024    package org.ujmp.core.doublematrix.calculation.general.missingvalues;
025    
026    import java.util.ArrayList;
027    import java.util.List;
028    import java.util.concurrent.Callable;
029    import java.util.concurrent.ExecutorService;
030    import java.util.concurrent.Executors;
031    import java.util.concurrent.Future;
032    
033    import org.ujmp.core.Matrix;
034    import org.ujmp.core.MatrixFactory;
035    import org.ujmp.core.doublematrix.calculation.AbstractDoubleCalculation;
036    import org.ujmp.core.doublematrix.calculation.general.missingvalues.Impute.ImputationMethod;
037    import org.ujmp.core.exceptions.MatrixException;
038    import org.ujmp.core.util.MathUtil;
039    
040    public class ImputeRegression extends AbstractDoubleCalculation {
041            private static final long serialVersionUID = 2147234720707721364L;
042    
043            Matrix firstGuess = null;
044    
045            Matrix imputed = null;
046    
047            public ImputeRegression(Matrix matrix) {
048                    super(matrix);
049            }
050    
051            public ImputeRegression(Matrix matrix, Matrix firstGuess) {
052                    super(matrix);
053                    this.firstGuess = firstGuess;
054            }
055    
056            public double getDouble(long... coordinates) throws MatrixException {
057                    if (imputed == null) {
058                            createMatrix();
059                    }
060                    double v = getSource().getAsDouble(coordinates);
061                    if (MathUtil.isNaNOrInfinite(v)) {
062                            return imputed.getAsDouble(coordinates);
063                    } else {
064                            return v;
065                    }
066            }
067    
068            private void createMatrix() {
069                    try {
070                            Matrix x = getSource();
071    
072                            if (firstGuess == null) {
073                                    firstGuess = getSource().impute(Ret.NEW, ImputationMethod.RowMean);
074                            }
075    
076                            imputed = Matrix.factory.zeros(x.getSize());
077    
078                            ExecutorService executor = Executors.newFixedThreadPool(1);
079                            List<Future<Long>> futures = new ArrayList<Future<Long>>();
080    
081                            long t0 = System.currentTimeMillis();
082    
083                            for (long c = 0; c < x.getColumnCount(); c++) {
084                                    futures.add(executor.submit(new PredictColumn(c)));
085                            }
086    
087                            for (Future<Long> f : futures) {
088                                    Long completedCols = f.get();
089                                    long elapsedTime = System.currentTimeMillis() - t0;
090                                    long remainingCols = x.getColumnCount() - completedCols;
091                                    double colsPerMillisecond = (double) (completedCols + 1) / (double) elapsedTime;
092                                    long remainingTime = (long) (remainingCols / colsPerMillisecond / 1000.0);
093                                    System.out.println((completedCols * 1000 / x.getColumnCount() / 10.0)
094                                                    + "% completed (" + remainingTime + " seconds remaining)");
095                            }
096    
097                            executor.shutdown();
098    
099                    } catch (Exception e) {
100                            throw new MatrixException(e);
101                    }
102            }
103    
104            class PredictColumn implements Callable<Long> {
105    
106                    long column = 0;
107    
108                    public PredictColumn(long column) {
109                            this.column = column;
110                    }
111    
112                    public Long call() throws Exception {
113                            Matrix newColumn = replaceInColumn(getSource(), firstGuess, column);
114                            for (int r = 0; r < newColumn.getRowCount(); r++) {
115                                    imputed.setAsDouble(newColumn.getAsDouble(r, 0), r, column);
116                            }
117                            return column;
118                    }
119    
120            }
121    
122            private static Matrix replaceInColumn(Matrix original, Matrix firstGuess, long column)
123                            throws MatrixException {
124    
125                    Matrix x = firstGuess.deleteColumns(Ret.NEW, column);
126                    Matrix y = original.selectColumns(Ret.NEW, column);
127    
128                    List<Long> missingRows = new ArrayList<Long>();
129                    for (long i = y.getRowCount(); --i >= 0;) {
130                            double v = y.getAsDouble(i, 0);
131                            if (MathUtil.isNaNOrInfinite(v)) {
132                                    missingRows.add(i);
133                            }
134                    }
135    
136                    if (missingRows.isEmpty()) {
137                            return y;
138                    }
139    
140                    Matrix xdel = x.deleteRows(Ret.NEW, missingRows);
141                    Matrix bias1 = Matrix.factory.ones(xdel.getRowCount(), 1);
142                    Matrix xtrain = MatrixFactory.horCat(xdel, bias1);
143                    Matrix ytrain = y.deleteRows(Ret.NEW, missingRows);
144    
145                    Matrix xinv = xtrain.pinv();
146                    Matrix b = xinv.mtimes(ytrain);
147                    Matrix bias2 = Matrix.factory.ones(x.getRowCount(), 1);
148                    Matrix yPredicted = MatrixFactory.horCat(x, bias2).mtimes(b);
149    
150                    // set non-missing values back to original values
151                    for (int row = 0; row < y.getRowCount(); row++) {
152                            double v = y.getAsDouble(row, 0);
153                            if (!Double.isNaN(v)) {
154                                    yPredicted.setAsDouble(v, row, 0);
155                            }
156                    }
157    
158                    return yPredicted;
159            }
160    
161    }