/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.bayes;

import java.util.Enumeration;
import java.util.Vector;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.UpdateableClassifier;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;

/**
 * AODEsr: AODE (Averaged One-Dependence Estimators) augmented with Subsumption Resolution.
 *
 * <p>Training accumulates weighted co-occurrence counts for every pair of attribute values of an
 * instance, both per class ({@code m_CondiCounts}) and ignoring the class
 * ({@code m_CondiCountsNoClass}). Attribute values are addressed through a flattened index
 * space: attribute {@code a} owns the slots {@code m_StartAttIndex[a]} through
 * {@code m_StartAttIndex[a] + m_NumAttValues[a]} inclusive, where the last slot collects
 * missing values.
 *
 * <p>At classification time, when every training instance holding some value j of the test
 * instance also holds another value i of it, and j's weighted frequency exceeds the critical
 * value, i is treated as a generalization of j and ignored. The class distribution is then the
 * average of the one-dependence estimates obtained by letting each remaining, sufficiently
 * frequent value act as the super-parent. If no value qualifies, a naive Bayes estimate is used
 * instead.
 *
 * <p>The model honours instance weights and can be updated incrementally via
 * {@link #updateClassifier(Instance)}. Only nominal attributes and a nominal class are supported.
 */
public class AODEsr extends AbstractClassifier
        implements OptionHandler, WeightedInstancesHandler, UpdateableClassifier,
                TechnicalInformationHandler {
    static final long serialVersionUID = 5602143019183068848L;

    /** Weighted co-occurrence counts [class][value slot][value slot]. */
    private double[][][] m_CondiCounts;
    /** Weighted co-occurrence counts [value slot][value slot], ignoring the class. */
    private double[][] m_CondiCountsNoClass;
    /** Weighted count of training instances per class. */
    private double[] m_ClassCounts;
    /** Weighted count [class][attribute] of training instances whose attribute is not missing. */
    private double[][] m_SumForCounts;
    /** Number of class values. */
    private int m_NumClasses;
    /** Number of attributes, including the class attribute. */
    private int m_NumAttributes;
    /** Number of training instances with a known class value added so far. */
    private int m_NumInstances;
    /** Index of the class attribute. */
    private int m_ClassIndex;
    /** Empty copy of the training header; {@code null} until the model has been built. */
    private Instances m_Instances;
    /** Size of the flattened value space: each non-class attribute contributes numValues + 1. */
    private int m_TotalAttValues;
    /** Offset of each attribute's first value slot in the flattened value space. */
    private int[] m_StartAttIndex;
    /** Number of values per attribute (the number of classes for the class attribute). */
    private int[] m_NumAttValues;
    /** Weighted frequency of each value slot, including the missing-value slots. */
    private double[] m_Frequencies;
    /** Total weight of the training instances with a known class value. */
    private double m_SumInstances;
    /** Minimum weighted frequency a value needs in order to act as a super-parent. */
    private int m_Limit = 1;
    /** Debug flag set by the -D option. */
    private boolean m_Debug = false;
    /** The weight m of the m-estimate. */
    protected double m_MWeight = 1.0;
    /** Whether Laplace estimation is used instead of m-estimation. */
    private boolean m_Laplace = false;
    /** Frequency a specialization must exceed before its generalization is dropped. */
    private int m_Critical = 50;

    /**
     * Returns a textual description of this classifier, including its bibliographic reference.
     *
     * @return the description
     */
    public String globalInfo() {
        return "AODEsr augments AODE with Subsumption Resolution. AODEsr detects specializations between two attribute values at classification time and deletes the generalization attribute value.\nFor more information, see:\n" + this.getTechnicalInformation().toString();
    }

    /**
     * Returns the bibliographic reference (Zheng and Webb, ICML 2006) this classifier is based on.
     *
     * @return the technical information
     */
    @Override
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation result = new TechnicalInformation(TechnicalInformation.Type.INPROCEEDINGS);
        result.setValue(TechnicalInformation.Field.AUTHOR, "Fei Zheng and Geoffrey I. Webb");
        result.setValue(TechnicalInformation.Field.YEAR, "2006");
        result.setValue(TechnicalInformation.Field.TITLE, "Efficient Lazy Elimination for Averaged-One Dependence Estimators");
        result.setValue(TechnicalInformation.Field.PAGES, "1113-1120");
        result.setValue(TechnicalInformation.Field.BOOKTITLE, "Proceedings of the Twenty-third International Conference on Machine  Learning (ICML 2006)");
        result.setValue(TechnicalInformation.Field.PUBLISHER, "ACM Press");
        result.setValue(TechnicalInformation.Field.ISBN, "1-59593-383-2");
        return result;
    }

    /**
     * Returns the capabilities of this classifier.
     *
     * <p>These are nominal attributes, a nominal class, and missing attribute and class values.
     * Zero training instances are allowed.
     *
     * @return the capabilities
     */
    @Override
    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();
        result.disableAll();
        result.enable(Capabilities.Capability.NOMINAL_ATTRIBUTES);
        result.enable(Capabilities.Capability.MISSING_VALUES);
        result.enable(Capabilities.Capability.NOMINAL_CLASS);
        result.enable(Capabilities.Capability.MISSING_CLASS_VALUES);
        result.setMinimumNumberInstances(0);
        return result;
    }

    /**
     * Builds the model from the given training data, discarding any previous model.
     *
     * @param instances the training data; instances with a missing class value are ignored
     * @throws Exception if the data does not match {@link #getCapabilities()}
     */
    @Override
    public void buildClassifier(Instances instances) throws Exception {
        this.getCapabilities().testWithFail(instances);
        this.m_Instances = new Instances(instances);
        this.m_Instances.deleteWithMissingClass();
        this.m_SumInstances = 0.0;
        // Counted up by addToCounts, so that updateClassifier keeps it current as well.
        this.m_NumInstances = 0;
        this.m_ClassIndex = instances.classIndex();
        this.m_NumAttributes = instances.numAttributes();
        this.m_NumClasses = instances.numClasses();
        this.m_StartAttIndex = new int[this.m_NumAttributes];
        this.m_NumAttValues = new int[this.m_NumAttributes];
        this.m_TotalAttValues = 0;
        for (int att = 0; att < this.m_NumAttributes; att++) {
            if (att == this.m_ClassIndex) {
                this.m_NumAttValues[att] = this.m_NumClasses;
            } else {
                this.m_StartAttIndex[att] = this.m_TotalAttValues;
                this.m_NumAttValues[att] = this.m_Instances.attribute(att).numValues();
                // One extra slot per attribute collects its missing values.
                this.m_TotalAttValues += this.m_NumAttValues[att] + 1;
            }
        }
        this.m_CondiCounts = new double[this.m_NumClasses][this.m_TotalAttValues][this.m_TotalAttValues];
        this.m_ClassCounts = new double[this.m_NumClasses];
        this.m_SumForCounts = new double[this.m_NumClasses][this.m_NumAttributes];
        this.m_Frequencies = new double[this.m_TotalAttValues];
        this.m_CondiCountsNoClass = new double[this.m_TotalAttValues][this.m_TotalAttValues];
        int numTraining = this.m_Instances.numInstances();
        for (int k = 0; k < numTraining; k++) {
            this.addToCounts(this.m_Instances.instance(k));
        }
        // Only the header is kept after training.
        this.m_Instances = new Instances(this.m_Instances, 0);
    }

    /**
     * Adds one more training instance to the model.
     *
     * @param instance the new instance; ignored if its class value is missing
     */
    @Override
    public void updateClassifier(Instance instance) {
        this.addToCounts(instance);
    }

    /**
     * Adds the given instance, weighted by its weight, to all counts.
     *
     * <p>Instances with a missing class value are ignored.
     *
     * @param instance the instance to add
     */
    private void addToCounts(Instance instance) {
        if (instance.classIsMissing()) {
            return;
        }
        int classVal = (int) instance.classValue();
        double weight = instance.weight();
        this.m_ClassCounts[classVal] += weight;
        this.m_SumInstances += weight;
        this.m_NumInstances++;
        // Flattened value slot per attribute; -1 marks the class attribute.
        int[] attIndex = new int[this.m_NumAttributes];
        for (int att = 0; att < this.m_NumAttributes; att++) {
            if (att == this.m_ClassIndex) {
                attIndex[att] = -1;
            } else if (instance.isMissing(att)) {
                attIndex[att] = this.m_StartAttIndex[att] + this.m_NumAttValues[att];
            } else {
                attIndex[att] = this.m_StartAttIndex[att] + (int) instance.value(att);
            }
        }
        for (int att1 = 0; att1 < this.m_NumAttributes; att1++) {
            if (attIndex[att1] == -1) {
                continue;
            }
            this.m_Frequencies[attIndex[att1]] += weight;
            if (!instance.isMissing(att1)) {
                this.m_SumForCounts[classVal][att1] += weight;
            }
            double[] countsPointer = this.m_CondiCounts[classVal][attIndex[att1]];
            double[] countsNoClassPointer = this.m_CondiCountsNoClass[attIndex[att1]];
            for (int att2 = 0; att2 < this.m_NumAttributes; att2++) {
                if (attIndex[att2] == -1) {
                    continue;
                }
                countsPointer[attIndex[att2]] += weight;
                countsNoClassPointer[attIndex[att2]] += weight;
            }
        }
    }

    /**
     * Computes the class distribution for the given instance.
     *
     * <p>A value acts in turn as the super-parent if it is not missing, has a training frequency
     * of at least {@link #getFrequencyLimit()}, and was not identified as a generalization. The
     * resulting one-dependence estimates are averaged per class. If no value qualifies for a
     * class, {@link #NBconditionalProb(Instance, int)} is used for that class. The result is
     * normalized with {@code Utils.normalize}.
     *
     * @param instance the instance to classify
     * @return the normalized class probabilities
     * @throws Exception as declared by {@link #NBconditionalProb(Instance, int)}
     */
    @Override
    public double[] distributionForInstance(Instance instance) throws Exception {
        double[] probs = new double[this.m_NumClasses];
        // Flattened value slot per attribute; -1 marks the class attribute and missing values.
        int[] attIndex = new int[this.m_NumAttributes];
        for (int att = 0; att < this.m_NumAttributes; att++) {
            attIndex[att] = instance.isMissing(att) || att == this.m_ClassIndex
                    ? -1
                    : this.m_StartAttIndex[att] + (int) instance.value(att);
        }
        int[] specialOf = this.findGeneralizations(attIndex);
        for (int classVal = 0; classVal < this.m_NumClasses; classVal++) {
            double sum = 0.0;
            int parentCount = 0;
            double[][] countsForClass = this.m_CondiCounts[classVal];
            for (int parent = 0; parent < this.m_NumAttributes; parent++) {
                int pIndex = attIndex[parent];
                if (pIndex == -1 || this.m_Frequencies[pIndex] < this.m_Limit || specialOf[parent] != -1) {
                    continue;
                }
                parentCount++;
                double[] countsForClassParent = countsForClass[pIndex];
                double classParentFreq = countsForClassParent[pIndex];
                double missingForParentAtt =
                        this.m_Frequencies[this.m_StartAttIndex[parent] + this.m_NumAttValues[parent]];
                // P(parent value, class)
                double x = this.estimate(classParentFreq, this.m_SumInstances - missingForParentAtt,
                        this.m_NumClasses * this.m_NumAttValues[parent]);
                for (int att = 0; att < this.m_NumAttributes; att++) {
                    // The parent is never its own child; generalizations are dropped.
                    if (att == parent || attIndex[att] == -1 || specialOf[att] != -1) {
                        continue;
                    }
                    double missingForParentAndChild =
                            countsForClassParent[this.m_StartAttIndex[att] + this.m_NumAttValues[att]];
                    x *= this.estimate(countsForClassParent[attIndex[att]],
                            classParentFreq - missingForParentAndChild, this.m_NumAttValues[att]);
                }
                sum += x;
            }
            probs[classVal] = parentCount < 1
                    ? this.NBconditionalProb(instance, classVal)
                    : sum / (double) parentCount;
        }
        Utils.normalize(probs);
        return probs;
    }

    /**
     * Finds the attribute values of an instance that are generalizations of other values in it.
     *
     * <p>Value i is a generalization of value j when j's weighted frequency exceeds the critical
     * value and every training instance holding j also holds i.
     *
     * @param attIndex flattened value slot per attribute, or -1 for the class attribute and
     *     missing values
     * @return for each attribute i, the attribute j that i generalizes, or -1 if there is none
     */
    private int[] findGeneralizations(int[] attIndex) {
        int[] specialOf = new int[this.m_NumAttributes];
        for (int i = 0; i < this.m_NumAttributes; i++) {
            specialOf[i] = -1;
        }
        for (int i = 0; i < this.m_NumAttributes; i++) {
            if (attIndex[i] == -1) {
                continue;
            }
            double[] countsForAtti = this.m_CondiCountsNoClass[attIndex[i]];
            for (int j = 0; j < this.m_NumAttributes; j++) {
                if (attIndex[j] == -1 || i == j || specialOf[j] == i) {
                    continue;
                }
                double freqJ = this.m_CondiCountsNoClass[attIndex[j]][attIndex[j]];
                if (!(freqJ > this.m_Critical)) {
                    continue;
                }
                // j must never occur without i.
                if (freqJ != countsForAtti[attIndex[j]]) {
                    continue;
                }
                // i and j specialize each other: mark only the higher-indexed one, so that
                // the two are not both dropped.
                if (freqJ == countsForAtti[attIndex[i]] && i < j) {
                    continue;
                }
                specialOf[i] = j;
                break;
            }
        }
        return specialOf;
    }

    /**
     * Applies the configured estimator: Laplace if enabled, m-estimate otherwise.
     *
     * @param frequency the observed frequency
     * @param total the total the frequency is relative to
     * @param numValues the number of possible values
     * @return the estimated probability
     */
    private double estimate(double frequency, double total, double numValues) {
        return this.m_Laplace
                ? this.LaplaceEstimate(frequency, total, numValues)
                : this.MEstimate(frequency, total, numValues);
    }

    /**
     * Computes the unnormalized naive Bayes score of a class for the given instance.
     *
     * <p>This is the estimated class prior times, for each non-missing attribute, the estimated
     * probability of its value given the class.
     *
     * @param instance the instance
     * @param classVal the class index
     * @return the unnormalized score
     * @throws Exception declared for compatibility; not thrown by the visible implementation
     */
    public double NBconditionalProb(Instance instance, int classVal) throws Exception {
        double prob = this.estimate(this.m_ClassCounts[classVal], this.m_SumInstances, this.m_NumClasses);
        double[][] pointer = this.m_CondiCounts[classVal];
        for (int att = 0; att < this.m_NumAttributes; att++) {
            if (att == this.m_ClassIndex || instance.isMissing(att)) {
                continue;
            }
            int valueIndex = this.m_StartAttIndex[att] + (int) instance.value(att);
            prob *= this.estimate(pointer[valueIndex][valueIndex], this.m_SumForCounts[classVal][att],
                    this.m_NumAttValues[att]);
        }
        return prob;
    }

    /**
     * Computes the m-estimate {@code (frequency + m / numValues) / (total + m)}.
     *
     * @param frequency the observed frequency
     * @param total the total the frequency is relative to
     * @param numValues the number of possible values
     * @return the estimated probability
     */
    public double MEstimate(double frequency, double total, double numValues) {
        return (frequency + this.m_MWeight / numValues) / (total + this.m_MWeight);
    }

    /**
     * Computes the Laplace estimate {@code (frequency + 1) / (total + numValues)}.
     *
     * @param frequency the observed frequency
     * @param total the total the frequency is relative to
     * @param numValues the number of possible values
     * @return the estimated probability
     */
    public double LaplaceEstimate(double frequency, double total, double numValues) {
        return (frequency + 1.0) / (total + numValues);
    }

    /**
     * Lists the options understood by {@link #setOptions(String[])}.
     *
     * @return an enumeration of the options
     */
    @Override
    public Enumeration<Option> listOptions() {
        Vector<Option> newVector = new Vector<Option>(5);
        newVector.addElement(new Option("\tOutput debugging information\n", "D", 0, "-D"));
        newVector.addElement(new Option("\tImpose a critical value for specialization-generalization relationship\n\t(default is 50)", "C", 1, "-C <num>"));
        newVector.addElement(new Option("\tImpose a frequency limit for superParents\n\t(default is 1)", "F", 1, "-F <num>"));
        newVector.addElement(new Option("\tUsing Laplace estimation\n\t(default is m-estimation (m=1))", "L", 0, "-L"));
        newVector.addElement(new Option("\tWeight value for m-estimation\n\t(default is 1.0)", "M", 1, "-M <num>"));
        return newVector.elements();
    }

    /**
     * Parses the given options. Options that are absent revert to their defaults.
     *
     * <pre>
     * -D        output debugging information
     * -C &lt;num&gt;  critical value for the specialization-generalization relationship (default 50)
     * -F &lt;num&gt;  frequency limit for super-parents (default 1)
     * -L        use Laplace estimation instead of m-estimation
     * -M &lt;num&gt;  weight for m-estimation (default 1.0); must be positive, not allowed with -L
     * </pre>
     *
     * @param options the options to parse
     * @throws Exception if -M is combined with -L, the -M weight is not positive, a numeric
     *     value cannot be parsed, or Utils.checkForRemainingOptions rejects leftover options
     */
    @Override
    public void setOptions(String[] options) throws Exception {
        this.m_Debug = Utils.getFlag('D', options);
        String critical = Utils.getOption('C', options);
        this.m_Critical = critical.length() != 0 ? Integer.parseInt(critical) : 50;
        String freq = Utils.getOption('F', options);
        this.m_Limit = freq.length() != 0 ? Integer.parseInt(freq) : 1;
        this.m_Laplace = Utils.getFlag('L', options);
        String mWeight = Utils.getOption('M', options);
        if (mWeight.length() != 0) {
            if (this.m_Laplace) {
                throw new Exception("weight for m-estimate is pointless if using laplace estimation!");
            }
            double weight = Double.parseDouble(mWeight);
            // Same rule as setMestWeight; a zero weight can make MEstimate evaluate 0/0.
            if (!(weight > 0.0)) {
                throw new Exception("M-Estimate Weight must be greater than 0!");
            }
            this.m_MWeight = weight;
        } else {
            this.m_MWeight = 1.0;
        }
        Utils.checkForRemainingOptions(options);
    }

    /**
     * Returns the current settings as an option array accepted by {@link #setOptions(String[])}.
     *
     * @return the current options
     */
    @Override
    public String[] getOptions() {
        Vector<String> result = new Vector<String>();
        if (this.m_Debug) {
            result.add("-D");
        }
        result.add("-F");
        result.add("" + this.m_Limit);
        if (this.m_Laplace) {
            result.add("-L");
        } else {
            result.add("-M");
            result.add("" + this.m_MWeight);
        }
        result.add("-C");
        result.add("" + this.m_Critical);
        return result.toArray(new String[result.size()]);
    }

    /**
     * Returns the tip text for the mestWeight property.
     *
     * @return the tip text
     */
    public String mestWeightTipText() {
        return "Set the weight for m-estimate.";
    }

    /**
     * Sets the m-estimate weight.
     *
     * <p>The value is ignored, with a message on standard output, when Laplace estimation is
     * enabled or when it is not positive.
     *
     * @param w the new weight
     */
    public void setMestWeight(double w) {
        if (this.getUseLaplace()) {
            System.out.println("Weight is only used in conjunction with m-estimate - ignored!");
        } else if (w > 0.0) {
            this.m_MWeight = w;
        } else {
            System.out.println("M-Estimate Weight must be greater than 0!");
        }
    }

    /**
     * Returns the m-estimate weight.
     *
     * @return the weight
     */
    public double getMestWeight() {
        return this.m_MWeight;
    }

    /**
     * Returns the tip text for the useLaplace property.
     *
     * @return the tip text
     */
    public String useLaplaceTipText() {
        return "Use Laplace correction instead of m-estimation.";
    }

    /**
     * Returns whether Laplace estimation is used instead of m-estimation.
     *
     * @return true if Laplace estimation is used
     */
    public boolean getUseLaplace() {
        return this.m_Laplace;
    }

    /**
     * Sets whether Laplace estimation is used instead of m-estimation.
     *
     * @param value true to use Laplace estimation
     */
    public void setUseLaplace(boolean value) {
        this.m_Laplace = value;
    }

    /**
     * Returns the tip text for the frequencyLimit property.
     *
     * @return the tip text
     */
    public String frequencyLimitTipText() {
        return "Attributes with a frequency in the train set below this value aren't used as parents.";
    }

    /**
     * Sets the minimum frequency a value needs in order to act as a super-parent.
     *
     * @param f the frequency limit
     */
    public void setFrequencyLimit(int f) {
        this.m_Limit = f;
    }

    /**
     * Returns the minimum frequency a value needs in order to act as a super-parent.
     *
     * @return the frequency limit
     */
    public int getFrequencyLimit() {
        return this.m_Limit;
    }

    /**
     * Returns the tip text for the criticalValue property.
     *
     * @return the tip text
     */
    public String criticalValueTipText() {
        return "Specify critical value for specialization-generalization relationship (default 50).";
    }

    /**
     * Sets the frequency a specialization must exceed before its generalization is dropped.
     *
     * @param c the critical value
     */
    public void setCriticalValue(int c) {
        this.m_Critical = c;
    }

    /**
     * Returns the frequency a specialization must exceed before its generalization is dropped.
     *
     * @return the critical value
     */
    public int getCriticalValue() {
        return this.m_Critical;
    }

    /**
     * Returns a textual summary of the model.
     *
     * <p>The summary covers the class priors, dataset size and settings. The printed priors
     * always use a Laplace-corrected estimate, regardless of the configured estimator.
     *
     * @return the model description
     */
    @Override
    public String toString() {
        StringBuffer text = new StringBuffer();
        text.append("The AODEsr Classifier");
        if (this.m_Instances == null) {
            text.append(": No model built yet.");
        } else {
            try {
                for (int i = 0; i < this.m_NumClasses; ++i) {
                    text.append("\nClass " + this.m_Instances.classAttribute().value(i) + ": Prior probability = " + Utils.doubleToString((this.m_ClassCounts[i] + 1.0) / (this.m_SumInstances + (double) this.m_NumClasses), 4, 2) + "\n\n");
                }
                text.append("Dataset: " + this.m_Instances.relationName() + "\n" + "Instances: " + this.m_NumInstances + "\n" + "Attributes: " + this.m_NumAttributes + "\n" + "Frequency limit for superParents: " + this.m_Limit + "\n" + "Critical value for the specializtion-generalization " + "relationship: " + this.m_Critical + "\n");
                if (this.m_Laplace) {
                    text.append("Using LapLace estimation.");
                } else {
                    text.append("Using m-estimation, m = " + this.m_MWeight);
                }
            }
            catch (Exception ex) {
                text.append(ex.getMessage());
            }
        }
        return text.toString();
    }

    /**
     * Returns the revision string.
     *
     * @return the revision
     */
    @Override
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 5928 $");
    }

    /**
     * Runs this classifier from the command line.
     *
     * @param argv the command-line options
     */
    public static void main(String[] argv) {
        AODEsr.runClassifier(new AODEsr(), argv);
    }
}

