/*
 * Copyright (C) 2008-2010 by Holger Arndt
 *
 * This file is part of the Universal Java Matrix Package (UJMP).
 * See the NOTICE file distributed with this work for additional
 * information regarding copyright ownership and licensing.
 *
 * UJMP is free software; you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as
 * published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version.
 *
 * UJMP is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with UJMP; if not, write to the
 * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
 * Boston, MA 02110-1301 USA
 */

package org.ujmp.core.doublematrix.calculation.general.missingvalues;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

import org.ujmp.core.Matrix;
import org.ujmp.core.MatrixFactory;
import org.ujmp.core.doublematrix.calculation.AbstractDoubleCalculation;
import org.ujmp.core.doublematrix.calculation.general.missingvalues.Impute.ImputationMethod;
import org.ujmp.core.exceptions.MatrixException;
import org.ujmp.core.util.MathUtil;

/**
 * Replaces missing cells (NaN or infinite) of the source matrix with values
 * predicted by a per-column linear least-squares fit.
 * <p>
 * For every column, the remaining columns of a fully populated "first guess"
 * matrix (plus a constant bias column) are used as predictors. The
 * coefficients are obtained through the pseudo-inverse of the rows in which
 * the target column is present. Cells that are present in the source matrix
 * are always returned unchanged.
 * <p>
 * The imputed matrix is computed lazily on the first call to
 * {@link #getDouble(long...)}. The lazy initialisation is not synchronised.
 */
public class ImputeRegression extends AbstractDoubleCalculation {
	private static final long serialVersionUID = 2147234720707721364L;

	/**
	 * Matrix supplying the predictor values. When left <code>null</code> it
	 * is derived from the source via row-mean imputation on first use.
	 */
	Matrix firstGuess = null;

	/** Cache of the predicted values; <code>null</code> until computed. */
	Matrix imputed = null;

	/**
	 * Creates the calculation with a first guess that is derived from the
	 * source matrix (row-mean imputation) when it is first needed.
	 *
	 * @param matrix
	 *            source matrix containing the missing values
	 */
	public ImputeRegression(Matrix matrix) {
		super(matrix);
	}

	/**
	 * Creates the calculation with an explicit first guess.
	 *
	 * @param matrix
	 *            source matrix containing the missing values
	 * @param firstGuess
	 *            matrix used as predictor data; may be <code>null</code>, in
	 *            which case it is derived from the source matrix
	 */
	public ImputeRegression(Matrix matrix, Matrix firstGuess) {
		super(matrix);
		this.firstGuess = firstGuess;
	}

	/**
	 * Returns the source value at the given position, or the regression
	 * prediction when the source value is NaN or infinite.
	 *
	 * @param coordinates
	 *            position of the requested cell
	 * @return the original or the imputed value
	 * @throws MatrixException
	 *             if the imputed matrix could not be computed
	 */
	public double getDouble(long... coordinates) throws MatrixException {
		if (imputed == null) {
			createMatrix();
		}
		double v = getSource().getAsDouble(coordinates);
		if (MathUtil.isNaNOrInfinite(v)) {
			return imputed.getAsDouble(coordinates);
		} else {
			return v;
		}
	}

	/**
	 * Fills {@link #imputed} by predicting every column of the source matrix
	 * and prints a progress line to standard output after each column.
	 * <p>
	 * On any failure the cache is reset to <code>null</code> so that a later
	 * call retries instead of serving a partially filled matrix, and the
	 * worker thread is always released.
	 */
	private void createMatrix() {
		ExecutorService executor = null;
		boolean success = false;
		try {
			Matrix x = getSource();

			if (firstGuess == null) {
				firstGuess = x.impute(Ret.NEW, ImputationMethod.RowMean);
			}

			imputed = Matrix.factory.zeros(x.getSize());

			long columnCount = x.getColumnCount();
			executor = Executors.newFixedThreadPool(1);
			List<Future<Long>> futures = new ArrayList<Future<Long>>();

			long t0 = System.currentTimeMillis();

			for (long c = 0; c < columnCount; c++) {
				futures.add(executor.submit(new PredictColumn(c)));
			}

			for (Future<Long> f : futures) {
				// the task returns its 0-based column index, and futures are
				// consumed in submission order, so index + 1 columns are done
				long completedCols = f.get() + 1;
				long elapsedTime = System.currentTimeMillis() - t0;
				long remainingCols = columnCount - completedCols;
				// elapsedTime may be 0; the double division then yields
				// infinity and the remaining time collapses to 0
				double colsPerMillisecond = (double) completedCols / (double) elapsedTime;
				long remainingTime = (long) (remainingCols / colsPerMillisecond / 1000.0);
				System.out.println((completedCols * 1000 / columnCount / 10.0)
						+ "% completed (" + remainingTime + " seconds remaining)");
			}

			success = true;
		} catch (Exception e) {
			if (e instanceof InterruptedException) {
				// keep the interrupt visible to callers further up the stack
				Thread.currentThread().interrupt();
			}
			throw new MatrixException(e);
		} finally {
			if (!success) {
				// never leave a half-populated cache behind
				imputed = null;
			}
			if (executor != null) {
				// all tasks are finished on success; on failure this also
				// discards the columns that are still queued
				executor.shutdownNow();
			}
		}
	}

	/**
	 * Task that predicts one column and copies the result into
	 * {@link ImputeRegression#imputed}.
	 */
	class PredictColumn implements Callable<Long> {

		long column = 0;

		/**
		 * @param column
		 *            0-based index of the column to predict
		 */
		public PredictColumn(long column) {
			this.column = column;
		}

		/**
		 * Predicts the column and stores it in the imputed matrix.
		 *
		 * @return the index of the processed column
		 */
		public Long call() throws Exception {
			Matrix newColumn = replaceInColumn(getSource(), firstGuess, column);
			long rowCount = newColumn.getRowCount();
			for (long r = 0; r < rowCount; r++) {
				imputed.setAsDouble(newColumn.getAsDouble(r, 0), r, column);
			}
			return column;
		}

	}

	/**
	 * Computes one column in which every missing cell (NaN or infinite) is
	 * replaced by a least-squares prediction from the other columns of
	 * <code>firstGuess</code>.
	 *
	 * @param original
	 *            matrix with missing values; supplies the target column
	 * @param firstGuess
	 *            matrix supplying the predictor columns
	 * @param column
	 *            index of the column to process
	 * @return the untouched target column if nothing is missing, otherwise a
	 *         single-column matrix holding the original values where present
	 *         and predictions elsewhere
	 * @throws MatrixException
	 *             if one of the matrix operations fails
	 */
	private static Matrix replaceInColumn(Matrix original, Matrix firstGuess, long column)
			throws MatrixException {

		Matrix x = firstGuess.deleteColumns(Ret.NEW, column);
		Matrix y = original.selectColumns(Ret.NEW, column);

		List<Long> missingRows = new ArrayList<Long>();
		for (long i = y.getRowCount(); --i >= 0;) {
			double v = y.getAsDouble(i, 0);
			if (MathUtil.isNaNOrInfinite(v)) {
				missingRows.add(i);
			}
		}

		if (missingRows.isEmpty()) {
			return y;
		}

		// train only on the rows in which the target value is present
		Matrix xdel = x.deleteRows(Ret.NEW, missingRows);
		Matrix bias1 = Matrix.factory.ones(xdel.getRowCount(), 1);
		Matrix xtrain = MatrixFactory.horCat(xdel, bias1);
		Matrix ytrain = y.deleteRows(Ret.NEW, missingRows);

		// least-squares coefficients via the pseudo-inverse
		Matrix xinv = xtrain.pinv();
		Matrix b = xinv.mtimes(ytrain);
		Matrix bias2 = Matrix.factory.ones(x.getRowCount(), 1);
		Matrix yPredicted = MatrixFactory.horCat(x, bias2).mtimes(b);

		// set non-missing values back to original values; the test must be
		// the same one used to collect missingRows, otherwise infinite cells
		// would overwrite their own prediction
		long rowCount = y.getRowCount();
		for (long row = 0; row < rowCount; row++) {
			double v = y.getAsDouble(row, 0);
			if (!MathUtil.isNaNOrInfinite(v)) {
				yPredicted.setAsDouble(v, row, 0);
			}
		}

		return yPredicted;
	}

}