001    /*
002     * Copyright (C) 2008-2010 by Holger Arndt
003     *
004     * This file is part of the Universal Java Matrix Package (UJMP).
005     * See the NOTICE file distributed with this work for additional
006     * information regarding copyright ownership and licensing.
007     *
008     * UJMP is free software; you can redistribute it and/or modify
009     * it under the terms of the GNU Lesser General Public License as
010     * published by the Free Software Foundation; either version 2
011     * of the License, or (at your option) any later version.
012     *
013     * UJMP is distributed in the hope that it will be useful,
014     * but WITHOUT ANY WARRANTY; without even the implied warranty of
015     * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
016     * GNU Lesser General Public License for more details.
017     *
018     * You should have received a copy of the GNU Lesser General Public
019     * License along with UJMP; if not, write to the
020     * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
021     * Boston, MA  02110-1301  USA
022     */
023    
024    package org.ujmp.core.doublematrix.calculation.general.missingvalues;
025    
026    import java.util.ArrayList;
027    import java.util.List;
028    import java.util.concurrent.Callable;
029    import java.util.concurrent.ExecutorService;
030    import java.util.concurrent.Executors;
031    import java.util.concurrent.Future;
032    
033    import org.ujmp.core.Matrix;
034    import org.ujmp.core.MatrixFactory;
035    import org.ujmp.core.doublematrix.calculation.AbstractDoubleCalculation;
036    import org.ujmp.core.doublematrix.calculation.general.missingvalues.Impute.ImputationMethod;
037    import org.ujmp.core.exceptions.MatrixException;
038    import org.ujmp.core.util.MathUtil;
039    
040    public class ImputeEM extends AbstractDoubleCalculation {
041            private static final long serialVersionUID = -1272010036598212696L;
042    
043            private Matrix bestGuess = null;
044    
045            private Matrix imputed = null;
046    
047            private double delta = 1e-6;
048    
049            private final double decay = 0.66;
050    
051            public ImputeEM(Matrix matrix) {
052                    super(matrix);
053            }
054    
055            public ImputeEM(Matrix matrix, Matrix firstGuess) {
056                    super(matrix);
057                    this.bestGuess = firstGuess;
058            }
059    
060            public ImputeEM(Matrix matrix, Matrix firstGuess, double delta) {
061                    super(matrix);
062                    this.bestGuess = firstGuess;
063                    this.delta = delta;
064            }
065    
066            public double getDouble(long... coordinates) throws MatrixException {
067                    if (imputed == null) {
068                            createMatrix();
069                    }
070                    double v = getSource().getAsDouble(coordinates);
071                    if (MathUtil.isNaNOrInfinite(v)) {
072                            return imputed.getAsDouble(coordinates);
073                    } else {
074                            return v;
075                    }
076            }
077    
078            private void createMatrix() {
079                    try {
080                            ExecutorService executor = Executors.newFixedThreadPool(1);
081    
082                            Matrix x = getSource();
083    
084                            double valueCount = x.getValueCount();
085                            long missingCount = (long) x.countMissing(Ret.NEW, Matrix.ALL).getEuklideanValue();
086                            double percent = ((int) Math.round((missingCount * 1000.0 / valueCount))) / 10.0;
087                            System.out.println("missing values: " + missingCount + " (" + percent + "%)");
088                            System.out.println("============================================");
089    
090                            if (bestGuess == null) {
091                                    bestGuess = getSource().impute(Ret.NEW, ImputationMethod.RowMean);
092                            }
093    
094                            int run = 0;
095    
096                            while (true) {
097    
098                                    System.out.println("Iteration " + run++);
099    
100                                    List<Future<Long>> futures = new ArrayList<Future<Long>>();
101    
102                                    imputed = Matrix.factory.zeros(x.getSize());
103    
104                                    long t0 = System.currentTimeMillis();
105    
106                                    for (long c = 0; c < x.getColumnCount(); c++) {
107                                            futures.add(executor.submit(new PredictColumn(c)));
108                                    }
109    
110                                    for (Future<Long> f : futures) {
111                                            Long completedCols = f.get();
112                                            long elapsedTime = System.currentTimeMillis() - t0;
113                                            long remainingCols = x.getColumnCount() - completedCols;
114                                            double colsPerMillisecond = (double) (completedCols + 1) / (double) elapsedTime;
115                                            long remainingTime = (long) (remainingCols / colsPerMillisecond / 1000.0);
116                                            System.out.println((completedCols * 1000 / x.getColumnCount() / 10.0)
117                                                            + "% completed (" + remainingTime + " seconds remaining)");
118                                    }
119    
120                                    double d = imputed.euklideanDistanceTo(bestGuess, true) / missingCount;
121                                    System.out.println("delta: " + d);
122                                    System.out.println("============================================");
123    
124                                    bestGuess = bestGuess.times(decay).plus(imputed.times(1 - decay));
125    
126                                    if (d < delta) {
127                                            break;
128                                    }
129    
130                            }
131    
132                            executor.shutdown();
133    
134                            imputed = bestGuess;
135    
136                            if (imputed.containsMissingValues()) {
137                                    throw new MatrixException("Matrix has still missing values after imputation");
138                            }
139    
140                    } catch (Exception e) {
141                            throw new MatrixException(e);
142                    }
143            }
144    
145            class PredictColumn implements Callable<Long> {
146    
147                    long column = 0;
148    
149                    public PredictColumn(long column) {
150                            this.column = column;
151                    }
152    
153                    public Long call() throws Exception {
154                            Matrix newColumn = replaceInColumn(getSource(), bestGuess, column);
155                            for (int r = 0; r < newColumn.getRowCount(); r++) {
156                                    imputed.setAsDouble(newColumn.getAsDouble(r, 0), r, column);
157                            }
158                            return column;
159                    }
160    
161            }
162    
163            private static Matrix replaceInColumn(Matrix original, Matrix firstGuess, long column)
164                            throws MatrixException {
165    
166                    Matrix x = firstGuess.deleteColumns(Ret.NEW, column);
167                    Matrix y = original.selectColumns(Ret.NEW, column);
168    
169                    List<Long> missingRows = new ArrayList<Long>();
170                    for (long i = y.getRowCount(); --i >= 0;) {
171                            double v = y.getAsDouble(i, 0);
172                            if (MathUtil.isNaNOrInfinite(v)) {
173                                    missingRows.add(i);
174                            }
175                    }
176    
177                    if (missingRows.isEmpty()) {
178                            return y;
179                    }
180    
181                    Matrix xdel = x.deleteRows(Ret.NEW, missingRows);
182                    Matrix bias1 = Matrix.factory.ones(xdel.getRowCount(), 1);
183                    Matrix xtrain = MatrixFactory.horCat(xdel, bias1);
184                    Matrix ytrain = y.deleteRows(Ret.NEW, missingRows);
185    
186                    Matrix xinv = xtrain.pinv();
187                    Matrix b = xinv.mtimes(ytrain);
188                    Matrix bias2 = Matrix.factory.ones(x.getRowCount(), 1);
189                    Matrix yPredicted = MatrixFactory.horCat(x, bias2).mtimes(b);
190    
191                    // set non-missing values back to original values
192                    for (int row = 0; row < y.getRowCount(); row++) {
193                            double v = y.getAsDouble(row, 0);
194                            if (!Double.isNaN(v)) {
195                                    yPredicted.setAsDouble(v, row, 0);
196                            }
197                    }
198    
199                    return yPredicted;
200            }
201    
202    }