001 /* 002 * Copyright (C) 2008-2010 by Holger Arndt 003 * 004 * This file is part of the Universal Java Matrix Package (UJMP). 005 * See the NOTICE file distributed with this work for additional 006 * information regarding copyright ownership and licensing. 007 * 008 * UJMP is free software; you can redistribute it and/or modify 009 * it under the terms of the GNU Lesser General Public License as 010 * published by the Free Software Foundation; either version 2 011 * of the License, or (at your option) any later version. 012 * 013 * UJMP is distributed in the hope that it will be useful, 014 * but WITHOUT ANY WARRANTY; without even the implied warranty of 015 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 016 * GNU Lesser General Public License for more details. 017 * 018 * You should have received a copy of the GNU Lesser General Public 019 * License along with UJMP; if not, write to the 020 * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, 021 * Boston, MA 02110-1301 USA 022 */ 023 024 package org.ujmp.core.doublematrix.calculation.general.missingvalues; 025 026 import java.util.ArrayList; 027 import java.util.List; 028 import java.util.concurrent.Callable; 029 import java.util.concurrent.ExecutorService; 030 import java.util.concurrent.Executors; 031 import java.util.concurrent.Future; 032 033 import org.ujmp.core.Matrix; 034 import org.ujmp.core.MatrixFactory; 035 import org.ujmp.core.doublematrix.calculation.AbstractDoubleCalculation; 036 import org.ujmp.core.doublematrix.calculation.general.missingvalues.Impute.ImputationMethod; 037 import org.ujmp.core.exceptions.MatrixException; 038 import org.ujmp.core.util.MathUtil; 039 040 public class ImputeEM extends AbstractDoubleCalculation { 041 private static final long serialVersionUID = -1272010036598212696L; 042 043 private Matrix bestGuess = null; 044 045 private Matrix imputed = null; 046 047 private double delta = 1e-6; 048 049 private final double decay = 0.66; 050 051 public ImputeEM(Matrix matrix) { 052 super(matrix); 053 } 054 055 public ImputeEM(Matrix matrix, Matrix firstGuess) { 056 super(matrix); 057 this.bestGuess = firstGuess; 058 } 059 060 public ImputeEM(Matrix matrix, Matrix firstGuess, double delta) { 061 super(matrix); 062 this.bestGuess = firstGuess; 063 this.delta = delta; 064 } 065 066 public double getDouble(long... coordinates) throws MatrixException { 067 if (imputed == null) { 068 createMatrix(); 069 } 070 double v = getSource().getAsDouble(coordinates); 071 if (MathUtil.isNaNOrInfinite(v)) { 072 return imputed.getAsDouble(coordinates); 073 } else { 074 return v; 075 } 076 } 077 078 private void createMatrix() { 079 try { 080 ExecutorService executor = Executors.newFixedThreadPool(1); 081 082 Matrix x = getSource(); 083 084 double valueCount = x.getValueCount(); 085 long missingCount = (long) x.countMissing(Ret.NEW, Matrix.ALL).getEuklideanValue(); 086 double percent = ((int) Math.round((missingCount * 1000.0 / valueCount))) / 10.0; 087 System.out.println("missing values: " + missingCount + " (" + percent + "%)"); 088 System.out.println("============================================"); 089 090 if (bestGuess == null) { 091 bestGuess = getSource().impute(Ret.NEW, ImputationMethod.RowMean); 092 } 093 094 int run = 0; 095 096 while (true) { 097 098 System.out.println("Iteration " + run++); 099 100 List<Future<Long>> futures = new ArrayList<Future<Long>>(); 101 102 imputed = Matrix.factory.zeros(x.getSize()); 103 104 long t0 = System.currentTimeMillis(); 105 106 for (long c = 0; c < x.getColumnCount(); c++) { 107 futures.add(executor.submit(new PredictColumn(c))); 108 } 109 110 for (Future<Long> f : futures) { 111 Long completedCols = f.get(); 112 long elapsedTime = System.currentTimeMillis() - t0; 113 long remainingCols = x.getColumnCount() - completedCols; 114 double colsPerMillisecond = (double) (completedCols + 1) / (double) elapsedTime; 115 long remainingTime = (long) (remainingCols / colsPerMillisecond / 1000.0); 116 System.out.println((completedCols * 1000 / x.getColumnCount() / 10.0) 117 + "% completed (" + remainingTime + " seconds remaining)"); 118 } 119 120 double d = imputed.euklideanDistanceTo(bestGuess, true) / missingCount; 121 System.out.println("delta: " + d); 122 System.out.println("============================================"); 123 124 bestGuess = bestGuess.times(decay).plus(imputed.times(1 - decay)); 125 126 if (d < delta) { 127 break; 128 } 129 130 } 131 132 executor.shutdown(); 133 134 imputed = bestGuess; 135 136 if (imputed.containsMissingValues()) { 137 throw new MatrixException("Matrix has still missing values after imputation"); 138 } 139 140 } catch (Exception e) { 141 throw new MatrixException(e); 142 } 143 } 144 145 class PredictColumn implements Callable<Long> { 146 147 long column = 0; 148 149 public PredictColumn(long column) { 150 this.column = column; 151 } 152 153 public Long call() throws Exception { 154 Matrix newColumn = replaceInColumn(getSource(), bestGuess, column); 155 for (int r = 0; r < newColumn.getRowCount(); r++) { 156 imputed.setAsDouble(newColumn.getAsDouble(r, 0), r, column); 157 } 158 return column; 159 } 160 161 } 162 163 private static Matrix replaceInColumn(Matrix original, Matrix firstGuess, long column) 164 throws MatrixException { 165 166 Matrix x = firstGuess.deleteColumns(Ret.NEW, column); 167 Matrix y = original.selectColumns(Ret.NEW, column); 168 169 List<Long> missingRows = new ArrayList<Long>(); 170 for (long i = y.getRowCount(); --i >= 0;) { 171 double v = y.getAsDouble(i, 0); 172 if (MathUtil.isNaNOrInfinite(v)) { 173 missingRows.add(i); 174 } 175 } 176 177 if (missingRows.isEmpty()) { 178 return y; 179 } 180 181 Matrix xdel = x.deleteRows(Ret.NEW, missingRows); 182 Matrix bias1 = Matrix.factory.ones(xdel.getRowCount(), 1); 183 Matrix xtrain = MatrixFactory.horCat(xdel, bias1); 184 Matrix ytrain = y.deleteRows(Ret.NEW, missingRows); 185 186 Matrix xinv = xtrain.pinv(); 187 Matrix b = xinv.mtimes(ytrain); 188 Matrix bias2 = Matrix.factory.ones(x.getRowCount(), 1); 189 Matrix yPredicted = MatrixFactory.horCat(x, bias2).mtimes(b); 190 191 // set non-missing values back to original values 192 for (int row = 0; row < y.getRowCount(); row++) { 193 double v = y.getAsDouble(row, 0); 194 if (!Double.isNaN(v)) { 195 yPredicted.setAsDouble(v, row, 0); 196 } 197 } 198 199 return yPredicted; 200 } 201 202 }