/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.meta.ensembleSelection;

import java.util.Random;
import weka.classifiers.Evaluation;
import weka.classifiers.meta.ensembleSelection.EnsembleMetricHelper;
import weka.core.Instances;
import weka.core.RevisionHandler;
import weka.core.RevisionUtils;

/**
 * A bag of candidate models for ensemble selection.
 * <p>
 * Each model is represented only by its predictions: {@code models[m][i]} is the
 * prediction vector of model {@code m} for instance {@code i}. That vector is passed
 * directly to {@code Evaluation.evaluateModelOnceAndRecordPrediction}.
 * <p>
 * Internally, models are addressed through a <em>virtual</em> order:
 * {@code m_modelIndex[v]} is the real index (into {@code models}) of the model at
 * virtual position {@code v}. Only the first {@code m_bagSize} virtual positions take
 * part in selection. The working weights ({@code m_timesChosen}) are kept in virtual
 * order. The weights reported by {@link #getModelWeights()} are in real order.
 * <p>
 * The ensemble's prediction is the average of its members' predictions, each weighted
 * by how many times it was chosen. Selection always keeps the candidate with the
 * largest metric value, so larger metric values are treated as better.
 */
public class ModelBag implements RevisionHandler {
    /** Predictions of every model, indexed by real model index. */
    private final double[][][] m_models;
    /** Maps a virtual position to the real index of the model stored there. */
    private final int[] m_modelIndex;
    /** Number of leading virtual positions that take part in selection. */
    private final int m_bagSize;
    /** Total weight of the current ensemble; kept equal to the sum of m_timesChosen. */
    private int m_numChosen;
    /** How many times the model at each virtual position is in the current ensemble. */
    private final int[] m_timesChosen;
    /** Whether to print progress messages to standard output. */
    private final boolean m_debug;
    /** Best metric value reached so far by the hill-climbing steps. */
    private double m_bestPerformance;
    /** Weights (real model order) reported by getModelWeights(). */
    private int[] m_bestTimesChosen;

    /**
     * Creates a bag over the given model predictions. No model is chosen initially.
     *
     * @param models      predictions of each model; must contain at least one model
     * @param bag_percent fraction of the models that take part in selection; the bag
     *                    size is {@code (int) (models.length * bag_percent)}
     * @param debug       whether to print progress messages to standard output
     * @throws IllegalArgumentException if {@code models} is empty, or if the resulting
     *                                  bag size is negative or exceeds the number of models
     */
    public ModelBag(double[][][] models, double bag_percent, boolean debug) {
        m_debug = debug;
        if (models.length == 0) {
            throw new IllegalArgumentException("ModelBag needs at least 1 model.");
        }
        int bagSize = (int) ((double) models.length * bag_percent);
        if (bagSize < 0 || bagSize > models.length) {
            // Every selection loop indexes positions [0, bagSize); fail early instead of
            // with an ArrayIndexOutOfBoundsException later.
            throw new IllegalArgumentException("Bag size " + bagSize + " (bag_percent="
                    + bag_percent + ") must be between 0 and " + models.length + ".");
        }
        m_bagSize = bagSize;
        m_models = models;
        m_modelIndex = new int[models.length];
        m_timesChosen = new int[models.length]; // all zero: nothing chosen yet
        // Deliberately a separate array. Aliasing m_timesChosen would make
        // getModelWeights() expose the live working weights, which are in virtual
        // (possibly shuffled) order rather than real model order.
        m_bestTimesChosen = new int[models.length];
        m_bestPerformance = 0.0;
        m_numChosen = 0;
        for (int i = 0; i < models.length; ++i) {
            m_modelIndex[i] = i;
        }
    }

    /**
     * Swaps the models at two virtual positions.
     * Each model's weight moves with it.
     */
    private void swap(int i, int j) {
        if (i == j) {
            return;
        }
        int tempIndex = m_modelIndex[i];
        m_modelIndex[i] = m_modelIndex[j];
        m_modelIndex[j] = tempIndex;
        int tempWeight = m_timesChosen[i];
        m_timesChosen[i] = m_timesChosen[j];
        m_timesChosen[j] = tempWeight;
    }

    /**
     * Randomly permutes the virtual order of all models, not only those in the bag.
     * A random subset therefore ends up in the first {@code m_bagSize} positions.
     * Each model keeps its current weight.
     * <p>
     * Uses Fisher-Yates, so every permutation is equally likely. The former "swap every
     * slot with some other slot" loop was biased: with two models it always restored
     * the original order. A given seed therefore yields a different permutation than
     * that older implementation did.
     *
     * @param rand source of randomness
     */
    public void shuffle(Random rand) {
        for (int i = m_models.length - 1; i > 0; --i) {
            swap(i, rand.nextInt(i + 1));
        }
    }

    /** Converts weights indexed by virtual position into weights indexed by real model. */
    private int[] virtualToRealWeights(int[] virtual_weights) {
        int[] real_weights = new int[virtual_weights.length];
        for (int i = 0; i < real_weights.length; ++i) {
            real_weights[m_modelIndex[i]] = virtual_weights[i];
        }
        return real_weights;
    }

    /** Snapshots the current ensemble's weights (in real order) as the reported weights. */
    private void updateBestTimesChosen() {
        m_bestTimesChosen = virtualToRealWeights(m_timesChosen);
    }

    /**
     * Sort initialization: ranks the bag by individual performance and seeds the
     * ensemble with the top {@code num} models.
     * <p>
     * The top {@code num} models are moved, best first, into the first {@code num}
     * virtual positions (partial selection sort).
     * <ul>
     * <li>If {@code greedy} is false, each of them is added once.</li>
     * <li>If {@code greedy} is true, the best model is added. The following ones are
     * added in rank order only while each addition strictly improves the ensemble's
     * metric. The baseline is the best model's own performance.</li>
     * </ul>
     *
     * @param num       number of top models to rank (and, when not greedy, to add)
     * @param greedy    whether to stop adding once a model does not help
     * @param instances the instances the predictions refer to
     * @param metric    metric id passed to {@code EnsembleMetricHelper.getMetric}
     * @return the real indices of the {@code num} best-performing models, best first
     * @throws IllegalArgumentException if {@code num} is negative or exceeds the bag size
     * @throws Exception                if evaluation fails
     */
    public int[] sortInitialize(int num, boolean greedy, Instances instances, int metric) throws Exception {
        if (num < 0 || num > m_bagSize) {
            // Checked before any swap or weight change so a bad call leaves the bag intact.
            throw new IllegalArgumentException("Cannot sort-initialize with " + num
                    + " models: bag size is " + m_bagSize + ".");
        }
        double[] performance = getIndividualPerformance(instances, metric);
        int[] bestModels = new int[num];
        for (int i = 0; i < num; ++i) {
            int maxIndex = i;
            for (int j = i + 1; j < m_bagSize; ++j) {
                if (performance[j] > performance[maxIndex]) {
                    maxIndex = j;
                }
            }
            swap(i, maxIndex);
            // Keep the performance array aligned with the new virtual order.
            double tempPerf = performance[i];
            performance[i] = performance[maxIndex];
            performance[maxIndex] = tempPerf;
            bestModels[i] = m_modelIndex[i];
            if (!greedy) {
                ++m_timesChosen[i];
                ++m_numChosen;
            }
        }
        if (greedy && num > 0) {
            ++m_timesChosen[0];
            ++m_numChosen;
            updateBestTimesChosen();
            // The one-model ensemble performs exactly like the best model alone. A further
            // model is only kept if averaging it in beats that (the old 0.0 baseline
            // accepted the second model unconditionally).
            double bestPerformance = performance[0];
            for (int i = 1; i < num; ++i) {
                double value = evaluatePredictions(instances, computePredictions(i, true), metric);
                if (!(value > bestPerformance)) {
                    break; // written this way so a NaN metric also stops the loop
                }
                bestPerformance = value;
                ++m_timesChosen[i];
                ++m_numChosen;
                updateBestTimesChosen();
            }
        }
        updateBestTimesChosen();
        if (m_debug) {
            System.out.println("Sort Initialization added best " + m_numChosen + " models to the bag.");
        }
        return bestModels;
    }

    /**
     * Adds {@code weight} to the weight of every model in the bag.
     * The result becomes the reported weights.
     *
     * @param weight amount added to each bag member's weight
     */
    public void weightAll(int weight) {
        for (int i = 0; i < m_bagSize; ++i) {
            m_timesChosen[i] += weight;
        }
        m_numChosen += weight * m_bagSize;
        updateBestTimesChosen();
    }

    /**
     * One forward-selection step: adds the bag model whose addition yields the highest
     * metric value.
     * Candidates scoring -1.0 or less are never selected.
     * NOTE(review): this assumes the metric stays above -1.0; confirm against
     * EnsembleMetricHelper.
     *
     * @param withReplacement whether a model already in the ensemble may be added again
     * @param instances       the instances the predictions refer to
     * @param metric          metric id passed to {@code EnsembleMetricHelper.getMetric}
     * @throws Exception if evaluation fails
     */
    public void forwardSelect(boolean withReplacement, Instances instances, int metric) throws Exception {
        double bestPerformance = -1.0;
        int bestIndex = -1;
        for (int i = 0; i < m_bagSize; ++i) {
            if (m_timesChosen[i] != 0 && !withReplacement) {
                continue; // already in the ensemble and repeats are not allowed
            }
            double value = evaluatePredictions(instances, computePredictions(i, true), metric);
            if (value > bestPerformance) {
                bestIndex = i;
                bestPerformance = value;
            }
        }
        if (bestIndex == -1) {
            if (m_debug) {
                System.out.println("Couldn't add model.  No action performed.");
            }
            return;
        }
        ++m_timesChosen[bestIndex];
        ++m_numChosen;
        recordIfBest(bestPerformance);
    }

    /**
     * One backward-elimination step: removes one copy of the ensemble member whose
     * removal yields the highest metric value.
     * Does nothing if the ensemble's total weight is 1 or less.
     *
     * @param instances the instances the predictions refer to
     * @param metric    metric id passed to {@code EnsembleMetricHelper.getMetric}
     * @throws Exception if evaluation fails
     */
    public void backwardEliminate(Instances instances, int metric) throws Exception {
        if (m_numChosen <= 1) {
            return; // removing the last model would leave an empty ensemble
        }
        double bestPerformance = -1.0;
        int bestIndex = -1;
        for (int i = 0; i < m_bagSize; ++i) {
            if (m_timesChosen[i] <= 0) {
                continue;
            }
            double value = evaluatePredictions(instances, computePredictions(i, false), metric);
            if (value > bestPerformance) {
                bestIndex = i;
                bestPerformance = value;
            }
        }
        if (bestIndex == -1) {
            if (m_debug) {
                System.out.println("Couldn't remove model.  No action performed.");
            }
            return;
        }
        --m_timesChosen[bestIndex];
        --m_numChosen;
        if (m_debug) {
            System.out.println("Removing model " + m_modelIndex[bestIndex] + " (" + bestIndex + ") " + bestPerformance);
        }
        recordIfBest(bestPerformance);
    }

    /**
     * One combined step: considers every single addition and every single removal, and
     * applies whichever yields the highest metric value.
     * Removals are only considered while the ensemble's total weight exceeds 1, the same
     * rule as {@link #backwardEliminate}. Removing the last model would otherwise make
     * {@link #computePredictions} divide by zero.
     *
     * @param with_replacement whether a model already in the ensemble may be added again
     * @param instances        the instances the predictions refer to
     * @param metric           metric id passed to {@code EnsembleMetricHelper.getMetric}
     * @throws Exception if evaluation fails
     */
    public void forwardSelectOrBackwardEliminate(boolean with_replacement, Instances instances, int metric) throws Exception {
        double bestPerformance = -1.0;
        int bestIndex = -1;
        boolean added = true;
        boolean canRemove = m_numChosen > 1;
        for (int i = 0; i < m_bagSize; ++i) {
            if (canRemove && m_timesChosen[i] > 0) {
                double value = evaluatePredictions(instances, computePredictions(i, false), metric);
                if (value > bestPerformance) {
                    bestIndex = i;
                    bestPerformance = value;
                    added = false;
                }
            }
            if (m_timesChosen[i] == 0 || with_replacement) {
                double value = evaluatePredictions(instances, computePredictions(i, true), metric);
                if (value > bestPerformance) {
                    bestIndex = i;
                    bestPerformance = value;
                    added = true;
                }
            }
        }
        if (bestIndex == -1) {
            if (m_debug) {
                System.out.println("Couldn't add or remove model.  No action performed.");
            }
            return;
        }
        int changeInWeight = added ? 1 : -1;
        m_timesChosen[bestIndex] += changeInWeight;
        m_numChosen += changeInWeight;
        recordIfBest(bestPerformance);
    }

    /**
     * Records the current ensemble as the best one if {@code performance} beats
     * m_bestPerformance.
     */
    private void recordIfBest(double performance) {
        if (performance > m_bestPerformance) {
            updateBestTimesChosen();
            m_bestPerformance = performance;
        }
    }

    /**
     * Returns the reported ensemble weights.
     * Entry {@code m} is the number of times real model {@code m} was chosen.
     *
     * @return a copy of the reported weights, indexed by real model index
     */
    public int[] getModelWeights() {
        return m_bestTimesChosen.clone();
    }

    /** Returns the predictions of the model at virtual position {@code index}. */
    private double[][] model(int index) {
        return m_models[m_modelIndex[index]];
    }

    /**
     * Computes the averaged predictions of the current ensemble with one copy of the
     * model at virtual position {@code index_to_change} added ({@code add}) or removed.
     * The ensemble itself is not modified.
     * <p>
     * Each member's predictions are weighted by its times-chosen count, and the sum is
     * divided by the new total weight {@code m_numChosen + 1} or {@code m_numChosen - 1}.
     * NOTE(review): this assumes every model has the same shape as {@code m_models[0]};
     * that is not checked here.
     */
    private double[][] computePredictions(int index_to_change, boolean add) {
        double[][] reference = m_models[0];
        int numInstances = reference.length;
        int change = add ? 1 : -1;
        // Rows are allocated individually so zero instances or differently sized rows work.
        double[][] predictions = new double[numInstances][];
        for (int j = 0; j < numInstances; ++j) {
            predictions[j] = new double[reference[j].length];
        }
        for (int i = 0; i < m_bagSize; ++i) {
            if (m_timesChosen[i] <= 0) {
                continue;
            }
            double[][] member = model(i);
            double weight = m_timesChosen[i];
            for (int j = 0; j < numInstances; ++j) {
                double[] row = predictions[j];
                for (int k = 0; k < row.length; ++k) {
                    row[k] += member[j][k] * weight;
                }
            }
        }
        double[][] changed = model(index_to_change);
        double denominator = m_numChosen + change;
        for (int j = 0; j < numInstances; ++j) {
            double[] row = predictions[j];
            for (int k = 0; k < row.length; ++k) {
                row[k] = (row[k] + change * changed[j][k]) / denominator;
            }
        }
        return predictions;
    }

    /**
     * Evaluates a set of predictions against the instances and returns the selected
     * metric. Row {@code i} of {@code temp_predictions} is used for instance {@code i}.
     */
    private double evaluatePredictions(Instances instances, double[][] temp_predictions, int metric) throws Exception {
        Evaluation eval = new Evaluation(instances);
        for (int i = 0; i < instances.numInstances(); ++i) {
            eval.evaluateModelOnceAndRecordPrediction(temp_predictions[i], instances.instance(i));
        }
        return EnsembleMetricHelper.getMetric(eval, metric);
    }

    /**
     * Evaluates each bag model on its own.
     *
     * @param instances the instances the predictions refer to
     * @param metric    metric id passed to {@code EnsembleMetricHelper.getMetric}
     * @return the metric value of each bag model, indexed by virtual position
     * @throws Exception if evaluation fails
     */
    public double[] getIndividualPerformance(Instances instances, int metric) throws Exception {
        double[] performance = new double[m_bagSize];
        for (int i = 0; i < m_bagSize; ++i) {
            performance[i] = evaluatePredictions(instances, model(i), metric);
        }
        return performance;
    }

    @Override
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 1.2 $");
    }
}

