001    /*
002     * Copyright (C) 2008-2010 by Holger Arndt
003     *
004     * This file is part of the Universal Java Matrix Package (UJMP).
005     * See the NOTICE file distributed with this work for additional
006     * information regarding copyright ownership and licensing.
007     *
008     * UJMP is free software; you can redistribute it and/or modify
009     * it under the terms of the GNU Lesser General Public License as
010     * published by the Free Software Foundation; either version 2
011     * of the License, or (at your option) any later version.
012     *
013     * UJMP is distributed in the hope that it will be useful,
014     * but WITHOUT ANY WARRANTY; without even the implied warranty of
015     * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
016     * GNU Lesser General Public License for more details.
017     *
018     * You should have received a copy of the GNU Lesser General Public
019     * License along with UJMP; if not, write to the
020     * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
021     * Boston, MA  02110-1301  USA
022     */
023    
024    package org.ujmp.core.doublematrix.calculation.general.statistical;
025    
026    import java.util.Arrays;
027    import java.util.HashMap;
028    import java.util.Map;
029    
030    import org.ujmp.core.Matrix;
031    import org.ujmp.core.MatrixFactory;
032    import org.ujmp.core.doublematrix.DoubleMatrix2D;
033    import org.ujmp.core.doublematrix.calculation.AbstractDoubleCalculation;
034    import org.ujmp.core.enums.ValueType;
035    import org.ujmp.core.exceptions.MatrixException;
036    import org.ujmp.core.intmatrix.IntMatrix2D;
037    import org.ujmp.core.intmatrix.impl.DefaultDenseIntMatrix2D;
038    import org.ujmp.core.util.MathUtil;
039    
040    public class MutualInformation extends AbstractDoubleCalculation {
041            private static final long serialVersionUID = -4891250637894943873L;
042    
043            public MutualInformation(Matrix matrix) {
044                    super(matrix);
045            }
046    
047            
048            public double getDouble(long... coordinates) throws MatrixException {
049                    return calculate(coordinates[ROW], coordinates[COLUMN], getSource());
050            }
051    
052            
053            public long[] getSize() {
054                    return new long[] { getSource().getColumnCount(), getSource().getColumnCount() };
055            }
056    
057            public static final double calculate(long var1, long var2, Matrix matrix) {
058                    double count = matrix.getRowCount();
059    
060                    Map<Double, Double> count1 = new HashMap<Double, Double>();
061                    Map<Double, Double> count2 = new HashMap<Double, Double>();
062                    Map<String, Double> count12 = new HashMap<String, Double>();
063    
064                    // count absolute frequency
065                    for (int r = 0; r < matrix.getRowCount(); r++) {
066                            double value1 = matrix.getAsDouble(r, var1);
067                            double value2 = matrix.getAsDouble(r, var2);
068    
069                            Double c1 = count1.get(value1);
070                            c1 = (c1 == null) ? 0.0 : c1;
071                            count1.put(value1, c1 + 1.0);
072    
073                            Double c2 = count2.get(value2);
074                            c2 = (c2 == null) ? 0.0 : c2;
075                            count2.put(value2, c2 + 1);
076    
077                            Double c12 = count12.get(value1 + "," + value2);
078                            c12 = (c12 == null) ? 0.0 : c12;
079                            count12.put(value1 + "," + value2, c12 + 1);
080                    }
081    
082                    // calculate relative frequency
083                    for (Double value1 : count1.keySet()) {
084                            Double c1 = count1.get(value1);
085                            count1.put(value1, c1 / count);
086                    }
087    
088                    for (Double value2 : count2.keySet()) {
089                            Double c2 = count2.get(value2);
090                            count2.put(value2, c2 / count);
091                    }
092    
093                    for (String value12 : count12.keySet()) {
094                            Double c12 = count12.get(value12);
095                            count12.put(value12, c12 / count);
096                    }
097    
098                    // calculate mutual information
099                    double mutualInformation = 0.0;
100                    for (Double value1 : count1.keySet()) {
101                            double p1 = count1.get(value1);
102                            for (Double value2 : count2.keySet()) {
103                                    double p2 = count2.get(value2);
104                                    Double p12 = count12.get(value1 + "," + value2);
105                                    if (p12 != null) {
106                                            mutualInformation += p12 * MathUtil.log2(p12 / (p1 * p2));
107                                    }
108                            }
109                    }
110    
111                    // System.out.println(count1);
112                    // System.out.println(count2);
113                    // System.out.println(count12);
114                    // System.out.println(mutualInformation);
115    
116                    return mutualInformation;
117            }
118    
119            public static DoubleMatrix2D calcNew(Matrix matrix) {
120                    return calcNew(matrix.convert(ValueType.INT));
121            }
122    
123            public static DoubleMatrix2D calcNew(IntMatrix2D matrix) {
124                    DefaultDenseIntMatrix2D matrix2 = (DefaultDenseIntMatrix2D) matrix;
125                    long count = matrix.getColumnCount();
126                    int samples = (int) matrix.getRowCount();
127                    DoubleMatrix2D result = (DoubleMatrix2D) MatrixFactory
128                                    .zeros(ValueType.DOUBLE, count, count);
129                    int[] d_dc = new int[(int) count];
130                    // int[][] matrixInt = matrix.toIntArray();
131                    Arrays.fill(d_dc, (int) matrix.getMaxValue() + 1);
132                    int aVal, bVal;
133                    for (int a = 0; a < count; a++) {
134                            for (int b = 0; b <= a; b++) {
135                                    double mutual = 0;
136    
137                                    double[][] Nab = new double[d_dc[a]][d_dc[b]];
138                                    double[] Na = new double[d_dc[a]];
139                                    double[] Nb = new double[d_dc[b]];
140                                    for (int k = (int) matrix.getRowCount() - 1; k >= 0; k--) {
141                                            aVal = matrix2.getInt(k, a);// dataset[aIndex][k];
142                                            bVal = matrix2.getInt(k, b);// dataset[bIndex][k];
143                                            // aVal = matrixInt[k][a];
144                                            // bVal = matrixInt[k][b];
145                                            Na[aVal]++;
146                                            Nb[bVal]++;
147                                            Nab[aVal][bVal]++;
148                                    }
149                                    double[] NaLog = new double[d_dc[a]];
150                                    double[] NbLog = new double[d_dc[b]];
151                                    double log2 = Math.log(2);
152                                    for (int j = d_dc[b] - 1; j >= 0; j--) {
153                                            Nb[j] /= samples;
154                                            if (Nb[j] != 0)
155                                                    NbLog[j] = Math.log(Nb[j]);
156                                    }
157                                    for (int i = d_dc[a] - 1; i >= 0; i--) {
158                                            Na[i] /= samples;
159                                            if (Na[i] != 0)
160                                                    NaLog[i] = Math.log(Na[i]);
161                                            for (int j = d_dc[b] - 1; j >= 0; j--) {
162                                                    Nab[i][j] /= samples;
163    
164                                                    if (Na[i] != 0 && Nb[j] != 0 && Nab[i][j] != 0) {
165                                                            mutual += Nab[i][j] * (Math.log(Nab[i][j]) - NaLog[i] - NbLog[j])
166                                                                            / log2;
167                                                    }
168                                            }
169                                    }
170                                    mutual = (mutual < 0) ? 0 : mutual;
171                                    result.setDouble(mutual, a, b);
172                                    result.setDouble(mutual, b, a);
173                            }
174    
175                    }
176    
177                    return result;
178            }
179    }