/*
 * Copyright (C) 2008-2010 by Holger Arndt
 *
 * This file is part of the Universal Java Matrix Package (UJMP).
 * See the NOTICE file distributed with this work for additional
 * information regarding copyright ownership and licensing.
 *
 * UJMP is free software; you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as
 * published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version.
 *
 * UJMP is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with UJMP; if not, write to the
 * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
 * Boston, MA 02110-1301 USA
 */

package org.ujmp.core.doublematrix.calculation.general.statistical;

import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

import org.ujmp.core.Matrix;
import org.ujmp.core.MatrixFactory;
import org.ujmp.core.doublematrix.DoubleMatrix2D;
import org.ujmp.core.doublematrix.calculation.AbstractDoubleCalculation;
import org.ujmp.core.enums.ValueType;
import org.ujmp.core.exceptions.MatrixException;
import org.ujmp.core.intmatrix.IntMatrix2D;
import org.ujmp.core.intmatrix.impl.DefaultDenseIntMatrix2D;
import org.ujmp.core.util.MathUtil;

/**
 * Pairwise mutual information (base-2 logarithm) between the columns of a
 * source matrix.
 * <p>
 * Each row of the source is treated as one sample and each column as one
 * variable. The calculation has size {@code columns x columns}; entry
 * {@code (i, j)} is the mutual information between source columns {@code i}
 * and {@code j}, estimated from relative frequencies of the observed values.
 */
public class MutualInformation extends AbstractDoubleCalculation {
    private static final long serialVersionUID = -4891250637894943873L;

    /**
     * Creates the calculation for the given source matrix.
     *
     * @param matrix
     *            source matrix (rows are samples, columns are variables)
     */
    public MutualInformation(Matrix matrix) {
        super(matrix);
    }

    /**
     * Returns the mutual information between two columns of the source.
     *
     * @param coordinates
     *            the {@code ROW} and {@code COLUMN} entries select the two
     *            source columns to compare
     * @return result of {@link #calculate(long, long, Matrix)} for those two
     *         columns
     */
    public double getDouble(long... coordinates) throws MatrixException {
        return calculate(coordinates[ROW], coordinates[COLUMN], getSource());
    }

    /**
     * Returns the size of this calculation: the source's column count in both
     * dimensions.
     */
    public long[] getSize() {
        return new long[] { getSource().getColumnCount(), getSource().getColumnCount() };
    }

    /**
     * Computes the mutual information between two columns of a matrix from
     * the relative frequencies of the values read with {@code getAsDouble}.
     * <p>
     * Only value pairs that actually occur contribute to the sum. Small
     * negative results caused by floating-point rounding are clamped to 0,
     * matching {@link #calcNew(IntMatrix2D)}. A matrix without rows yields 0.
     *
     * @param var1
     *            index of the first column
     * @param var2
     *            index of the second column
     * @param matrix
     *            matrix holding one sample per row
     * @return mutual information, using {@code MathUtil.log2}
     */
    public static final double calculate(long var1, long var2, Matrix matrix) {
        final long rowCount = matrix.getRowCount();
        final double count = rowCount;

        Map<Double, Double> count1 = new HashMap<Double, Double>();
        Map<Double, Double> count2 = new HashMap<Double, Double>();
        // joint counts, keyed first by the value in var1 and then by the value in var2
        Map<Double, Map<Double, Double>> count12 = new HashMap<Double, Map<Double, Double>>();

        // count absolute frequencies; a long index avoids overflow on very tall matrices
        for (long r = 0; r < rowCount; r++) {
            double value1 = matrix.getAsDouble(r, var1);
            double value2 = matrix.getAsDouble(r, var2);

            increment(count1, value1);
            increment(count2, value2);

            Map<Double, Double> jointRow = count12.get(value1);
            if (jointRow == null) {
                jointRow = new HashMap<Double, Double>();
                count12.put(value1, jointRow);
            }
            increment(jointRow, value2);
        }

        // sum p12 * log2(p12 / (p1 * p2)) over the observed pairs only
        double mutualInformation = 0.0;
        for (Map.Entry<Double, Map<Double, Double>> first : count12.entrySet()) {
            double p1 = count1.get(first.getKey()) / count;
            for (Map.Entry<Double, Double> second : first.getValue().entrySet()) {
                double p2 = count2.get(second.getKey()) / count;
                double p12 = second.getValue() / count;
                mutualInformation += p12 * MathUtil.log2(p12 / (p1 * p2));
            }
        }

        return (mutualInformation < 0) ? 0 : mutualInformation;
    }

    /** Adds one to the counter stored under {@code key}, starting at one if absent. */
    private static void increment(Map<Double, Double> counts, double key) {
        Double current = counts.get(key);
        counts.put(key, (current == null) ? 1.0 : current + 1.0);
    }

    /**
     * Computes the full symmetric matrix of pairwise mutual information
     * between all columns, reading every cell as an int via {@code getAsInt}.
     * <p>
     * Delegates straight to the shared implementation instead of converting
     * the matrix and dispatching through the overloads again, which could
     * resolve back to this very method.
     *
     * @param matrix
     *            matrix holding one sample per row; values must be
     *            non-negative when read as int
     * @return {@code columns x columns} matrix of mutual information values
     * @throws IllegalArgumentException
     *             if a cell value is negative
     */
    public static DoubleMatrix2D calcNew(Matrix matrix) {
        return calcFromIntValues(matrix);
    }

    /**
     * Computes the full symmetric matrix of pairwise mutual information
     * between all columns of an int matrix. Works for any
     * {@code IntMatrix2D} implementation.
     *
     * @param matrix
     *            matrix holding one sample per row; values must be
     *            non-negative
     * @return {@code columns x columns} matrix of mutual information values
     * @throws IllegalArgumentException
     *             if a cell value is negative
     */
    public static DoubleMatrix2D calcNew(IntMatrix2D matrix) {
        return calcFromIntValues(matrix);
    }

    /**
     * Shared implementation of both {@code calcNew} overloads: copies the
     * matrix column-wise into primitive arrays once, then evaluates every
     * column pair.
     */
    private static DoubleMatrix2D calcFromIntValues(Matrix matrix) {
        final int columns = (int) matrix.getColumnCount();
        final int samples = (int) matrix.getRowCount();

        // data[c][r] holds the value of column c in row r;
        // states[c] is (largest value in column c) + 1, i.e. the histogram size
        int[][] data = new int[columns][samples];
        int[] states = new int[columns];
        for (int c = 0; c < columns; c++) {
            int max = -1;
            for (int r = 0; r < samples; r++) {
                int value = matrix.getAsInt(r, c);
                if (value < 0) {
                    // values are used as array indices below
                    throw new IllegalArgumentException(
                            "mutual information requires non-negative integer values, found "
                                    + value + " at row " + r + ", column " + c);
                }
                data[c][r] = value;
                if (value > max) {
                    max = value;
                }
            }
            states[c] = max + 1;
        }

        DoubleMatrix2D result = (DoubleMatrix2D) MatrixFactory.zeros(ValueType.DOUBLE, columns,
                columns);
        for (int a = 0; a < columns; a++) {
            // the result is symmetric, so only the lower triangle is computed
            for (int b = 0; b <= a; b++) {
                double mutual = pairMutualInformation(data[a], states[a], data[b], states[b]);
                result.setDouble(mutual, a, b);
                result.setDouble(mutual, b, a);
            }
        }
        return result;
    }

    /**
     * Mutual information between two equally long columns of non-negative
     * ints, where every value of {@code colA} is below {@code statesA} and
     * every value of {@code colB} is below {@code statesB}.
     */
    private static double pairMutualInformation(int[] colA, int statesA, int[] colB, int statesB) {
        final int samples = colA.length;
        final double log2 = Math.log(2);

        // absolute frequencies: joint and marginal
        double[][] nab = new double[statesA][statesB];
        double[] na = new double[statesA];
        double[] nb = new double[statesB];
        for (int k = samples - 1; k >= 0; k--) {
            int aVal = colA[k];
            int bVal = colB[k];
            na[aVal]++;
            nb[bVal]++;
            nab[aVal][bVal]++;
        }

        // relative frequencies and their natural logarithms (left at 0 for unseen values)
        double[] naLog = new double[statesA];
        double[] nbLog = new double[statesB];
        for (int j = statesB - 1; j >= 0; j--) {
            nb[j] /= samples;
            if (nb[j] != 0) {
                nbLog[j] = Math.log(nb[j]);
            }
        }

        double mutual = 0;
        for (int i = statesA - 1; i >= 0; i--) {
            na[i] /= samples;
            if (na[i] != 0) {
                naLog[i] = Math.log(na[i]);
            }
            for (int j = statesB - 1; j >= 0; j--) {
                nab[i][j] /= samples;
                if (na[i] != 0 && nb[j] != 0 && nab[i][j] != 0) {
                    // dividing by ln(2) converts the natural logarithm to base 2
                    mutual += nab[i][j] * (Math.log(nab[i][j]) - naLog[i] - nbLog[j]) / log2;
                }
            }
        }

        // rounding can push the sum slightly below zero
        return (mutual < 0) ? 0 : mutual;
    }
}