001    /*
002     * Copyright (C) 2008-2010 by Holger Arndt
003     *  
004     * This file is part of the Universal Java Matrix Package (UJMP).
005     * See the NOTICE file distributed with this work for additional
006     * information regarding copyright ownership and licensing.
007     *
008     * UJMP is free software; you can redistribute it and/or modify
009     * it under the terms of the GNU Lesser General Public License as
010     * published by the Free Software Foundation; either version 2
011     * of the License, or (at your option) any later version.
012     *
013     * UJMP is distributed in the hope that it will be useful,
014     * but WITHOUT ANY WARRANTY; without even the implied warranty of
015     * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
016     * GNU Lesser General Public License for more details.
017     *
018     * You should have received a copy of the GNU Lesser General Public
019     * License along with UJMP; if not, write to the
020     * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
021     * Boston, MA  02110-1301  USA
022     */
023    
024    package org.ujmp.core.benchmark;
025    
026    import java.io.File;
027    import java.util.List;
028    
029    import org.ujmp.core.Coordinates;
030    import org.ujmp.core.Matrix;
031    import org.ujmp.core.MatrixFactory;
032    import org.ujmp.core.doublematrix.DenseDoubleMatrix2D;
033    import org.ujmp.core.enums.FileFormat;
034    import org.ujmp.core.enums.ValueType;
035    import org.ujmp.core.matrix.Matrix2D;
036    import org.ujmp.core.util.MathUtil;
037    
038    public abstract class AbstractBenchmarkTask {
039    
040            private BenchmarkConfig config = null;
041    
042            private Class<? extends Matrix> matrixClass = null;
043    
044            private List<long[]> sizes = null;
045    
046            private long benchmarkSeed = 0;
047    
048            public AbstractBenchmarkTask(long benchmarkSeed, Class<? extends Matrix> matrixClass,
049                            List<long[]> sizes, BenchmarkConfig config) {
050                    this.matrixClass = matrixClass;
051                    this.config = config;
052                    this.sizes = sizes;
053            }
054    
055            public BenchmarkConfig getConfig() {
056                    return config;
057            }
058    
059            public void run() {
060                    File timeFile = new File(BenchmarkUtil.getResultDir(getConfig()) + getMatrixLabel() + "/"
061                                    + getTaskName() + ".csv");
062                    File diffFile = new File(BenchmarkUtil.getResultDir(getConfig()) + getMatrixLabel() + "/"
063                                    + getTaskName() + "-diff.csv");
064                    File memFile = new File(BenchmarkUtil.getResultDir(getConfig()) + getMatrixLabel() + "/"
065                                    + getTaskName() + "-mem.csv");
066                    if (timeFile.exists()) {
067                            System.out.println("old results available, skipping " + getTaskName() + " for "
068                                            + getMatrixLabel());
069                            return;
070                    }
071                    Matrix2D resultTime = (Matrix2D) MatrixFactory.zeros(ValueType.STRING, config.getRuns(),
072                                    sizes.size());
073                    Matrix2D resultDiff = (Matrix2D) MatrixFactory.zeros(ValueType.STRING, config.getRuns(),
074                                    sizes.size());
075                    Matrix2D resultMem = (Matrix2D) MatrixFactory.zeros(ValueType.STRING, config.getRuns(),
076                                    sizes.size());
077    
078                    resultTime.setLabel(getMatrixLabel() + "-" + getTaskName());
079                    resultDiff.setLabel(getMatrixLabel() + "-" + getTaskName() + "-diff");
080                    resultMem.setLabel(getMatrixLabel() + "-" + getTaskName() + "-mem");
081    
082                    // create column labels for all sizes
083                    for (int s = 0; s < sizes.size(); s++) {
084                            long[] size = sizes.get(s);
085                            resultTime.setColumnLabel(s, String.valueOf(size[Matrix.ROW]));
086                            resultDiff.setColumnLabel(s, String.valueOf(size[Matrix.ROW]));
087                            resultMem.setColumnLabel(s, String.valueOf(size[Matrix.ROW]));
088                    }
089    
090                    boolean stopped = false;
091                    for (int s = 0; !stopped && s < sizes.size(); s++) {
092                            long[] size = sizes.get(s);
093                            double bestStd = Double.MAX_VALUE;
094                            int tmpTrialCount = config.getMinTrialCount();
095                            DenseDoubleMatrix2D curTime = DenseDoubleMatrix2D.factory.zeros(config.getRuns(), 1);
096                            DenseDoubleMatrix2D bestTime = DenseDoubleMatrix2D.factory.zeros(config.getRuns(), 1);
097                            DenseDoubleMatrix2D curDiff = DenseDoubleMatrix2D.factory.zeros(config.getRuns(), 1);
098                            DenseDoubleMatrix2D bestDiff = DenseDoubleMatrix2D.factory.zeros(config.getRuns(), 1);
099                            DenseDoubleMatrix2D curMem = DenseDoubleMatrix2D.factory.zeros(config.getRuns(), 1);
100                            DenseDoubleMatrix2D bestMem = DenseDoubleMatrix2D.factory.zeros(config.getRuns(), 1);
101                            for (int c = 0; !stopped && c < tmpTrialCount; c++) {
102                                    System.out.print(getTaskName() + " [" + Coordinates.toString('x', size) + "] ");
103                                    System.out.print((c + 1) + "/" + tmpTrialCount + ": ");
104                                    System.out.flush();
105    
106                                    for (int i = 0; !stopped && i < config.getBurnInRuns(); i++) {
107                                            long t0 = System.currentTimeMillis();
108                                            BenchmarkResult r = task(matrixClass, benchmarkSeed + c, i, size);
109                                            double t = r.getTime();
110                                            long t1 = System.currentTimeMillis();
111                                            if (t == 0.0 || Double.isNaN(t) || t1 - t0 > config.getMaxTime()) {
112                                                    stopped = true;
113                                            }
114                                            System.out.print("#");
115                                            System.out.flush();
116                                    }
117                                    for (int i = 0; !stopped && i < config.getRuns(); i++) {
118                                            long t0 = System.currentTimeMillis();
119                                            BenchmarkResult r = task(matrixClass, benchmarkSeed + c, i, size);
120                                            double t = r.getTime();
121                                            double diff = r.getDifference();
122                                            long mem = r.getMem();
123                                            long t1 = System.currentTimeMillis();
124                                            if (t == 0.0 || Double.isNaN(t) || t1 - t0 > config.getMaxTime()) {
125                                                    stopped = true;
126                                            }
127                                            curTime.setAsDouble(t, i, 0);
128                                            curDiff.setAsDouble(diff, i, 0);
129                                            curMem.setAsLong(mem, i, 0);
130                                            System.out.print(".");
131                                            System.out.flush();
132                                    }
133    
134                                    double meanTime = curTime.getMeanValue();
135                                    double meanDiff = curDiff.getMeanValue();
136                                    double meanMem = curMem.getMeanValue();
137                                    double stdTime = curTime.getStdValue();
138                                    double percentStd = stdTime / meanTime * 100.0;
139                                    System.out.print(" " + MathUtil.round(meanTime, 3) + "+-"
140                                                    + MathUtil.round(stdTime, 3) + "ms (+-" + MathUtil.round(percentStd, 1)
141                                                    + "%)");
142                                    if (!MathUtil.isNaNOrInfinite(meanDiff)) {
143                                            System.out.print(" diff:" + meanDiff + " ");
144                                    }
145                                    System.out.print(" mem:" + (int) meanMem + " Bytes ");
146                                    if (percentStd > config.getMaxStd()) {
147                                            System.out.print(" standard deviation too large, result discarded");
148                                            if (tmpTrialCount < config.getMaxTrialCount()) {
149                                                    tmpTrialCount++;
150                                            }
151                                    }
152                                    if (percentStd < bestStd) {
153                                            bestStd = percentStd;
154                                            for (int i = 0; i < config.getRuns(); i++) {
155                                                    bestTime.setDouble(curTime.getDouble(i, 0), i, 0);
156                                                    bestDiff.setDouble(curDiff.getDouble(i, 0), i, 0);
157                                                    bestMem.setDouble(curMem.getDouble(i, 0), i, 0);
158                                            }
159                                    }
160                                    System.out.println();
161                            }
162    
163                            for (int i = 0; !stopped && i < config.getRuns(); i++) {
164                                    resultTime.setAsDouble(bestTime.getDouble(i, 0), i, s);
165                                    resultDiff.setAsDouble(bestDiff.getDouble(i, 0), i, s);
166                                    resultMem.setAsDouble(bestMem.getDouble(i, 0), i, s);
167                            }
168                    }
169    
170                    Matrix temp = MatrixFactory.vertCat(resultTime.getAnnotation().getDimensionMatrix(
171                                    Matrix.ROW), resultTime);
172                    Matrix diff = MatrixFactory.vertCat(resultDiff.getAnnotation().getDimensionMatrix(
173                                    Matrix.ROW), resultDiff);
174                    Matrix mem = MatrixFactory.vertCat(
175                                    resultMem.getAnnotation().getDimensionMatrix(Matrix.ROW), resultMem);
176                    try {
177                            temp.exportToFile(FileFormat.CSV, timeFile);
178                            mem.exportToFile(FileFormat.CSV, memFile);
179                            if (!diff.containsMissingValues()) {
180                                    diff.exportToFile(FileFormat.CSV, diffFile);
181                            }
182                    } catch (Exception e) {
183                            e.printStackTrace();
184                    }
185            }
186    
187            public abstract BenchmarkResult task(Class<? extends Matrix> matrixClass, long benchmarkSeed,
188                            int run, long[] size);
189    
190            public abstract String getTaskName();
191    
192            public String getMatrixLabel() {
193                    return matrixClass.getSimpleName();
194            }
195    
196    }