package weighted_hierarchical;

import org.apache.commons.math3.exception.ConvergenceException;
import org.apache.commons.math3.exception.MathIllegalArgumentException;
import org.apache.commons.math3.ml.clustering.Cluster;
import org.apache.commons.math3.ml.clustering.Clusterer;
import org.apache.commons.math3.ml.distance.EuclideanDistance;
import weka.clusterers.SimpleKMeans;
import weka.core.Attribute;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;

import weka.clusterers.HierarchicalClusterer;
import weka.core.Attribute;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;
import org.apache.commons.math3.exception.ConvergenceException;
import org.apache.commons.math3.exception.MathIllegalArgumentException;
import org.apache.commons.math3.ml.clustering.Cluster;
import org.apache.commons.math3.ml.clustering.Clusterer;
import org.apache.commons.math3.ml.distance.EuclideanDistance;

/**
 * Weighted version of the hierarchical clustering algorithm (single-link),
 * implemented by delegating to Weka's {@link HierarchicalClusterer}.
 *
 * <p>Each {@link WeightedPoint} is converted to a Weka instance carrying the
 * first {@code numAttrs} coordinates of the point and the point's weight as
 * the instance weight.
 *
 * <p>NOTE(review): whether Weka's {@code HierarchicalClusterer} actually takes
 * instance weights into account when linking is not visible from this file;
 * confirm against the Weka version in use.
 */
public class HierarchicalWeighted extends Clusterer<WeightedPoint> {

    /** Number of clusters requested from Weka; also the size of the returned list. */
    private final int k;

    /** Number of coordinates of each point that are copied into a Weka instance. */
    private final int numAttrs;

    /**
     * Creates a single-link hierarchical clusterer.
     *
     * @param k        number of clusters to produce
     * @param numAttrs number of coordinates read from every point
     */
    public HierarchicalWeighted(int k, int numAttrs) {
        super(new EuclideanDistance());
        this.k = k;
        this.numAttrs = numAttrs;
    }

    /**
     * Clusters the given points into {@code k} groups using single-link
     * hierarchical clustering.
     *
     * <p>The returned list always has exactly {@code k} entries; entries to which
     * Weka assigned no point are left empty. The clusters contain the very
     * {@code WeightedPoint} objects that were passed in, not reconstructed copies.
     *
     * @param points the points to cluster
     * @return a list of {@code k} clusters, indexed by the Weka cluster id
     * @throws IllegalStateException if Weka reports a checked failure while
     *                               configuring, building or querying the clusterer
     */
    @Override
    public List<? extends Cluster<WeightedPoint>> cluster(Collection<WeightedPoint> points)
            throws MathIllegalArgumentException, ConvergenceException {

        // Snapshot the input so that position i in this list and instance i in the
        // Weka dataset are guaranteed to describe the same point.
        List<WeightedPoint> ordered = new ArrayList<WeightedPoint>(points);
        Instances data = toInstances(ordered);

        List<Cluster<WeightedPoint>> clusters = new ArrayList<Cluster<WeightedPoint>>();
        for (int c = 0; c < this.k; c++) {
            clusters.add(new Cluster<WeightedPoint>());
        }

        try {
            HierarchicalClusterer clusterer = new HierarchicalClusterer();
            // setOptions comes first: the explicit setters below must be applied
            // after it so they are not overwritten by option parsing.
            clusterer.setOptions(new String[]{"-L", "SINGLE"});
            clusterer.setNumClusters(this.k);
            // NOTE(review): Weka's EuclideanDistance may normalize attribute ranges
            // by default, unlike the Commons Math measure passed to super(); confirm.
            clusterer.setDistanceFunction(new weka.core.EuclideanDistance());
            clusterer.setDistanceIsBranchLength(true);

            clusterer.buildClusterer(data);

            for (int i = 0; i < ordered.size(); i++) {
                int idCluster = clusterer.clusterInstance(data.instance(i));
                // Add the caller's original point instead of rebuilding one from the
                // instance, which would truncate the weight to long and drop any
                // other state held by the point.
                clusters.get(idCluster).addPoint(ordered.get(i));
            }
        } catch (RuntimeException e) {
            // Programming errors and math exceptions propagate unchanged.
            throw e;
        } catch (Exception e) {
            // Never hand back a partially filled result as if clustering succeeded.
            throw new IllegalStateException(
                    "Weka hierarchical clustering failed for k=" + this.k
                            + ", numAttrs=" + this.numAttrs
                            + ", points=" + ordered.size(), e);
        }

        return clusters;
    }

    /** Builds the Weka dataset: one numeric attribute per coordinate, one weighted instance per point. */
    private Instances toInstances(List<WeightedPoint> ordered) {
        ArrayList<Attribute> atts = new ArrayList<Attribute>();
        for (int i = 0; i < this.numAttrs; i++) {
            atts.add(new Attribute("attr" + i, i));
        }

        // The third argument is the initial capacity, so size it by the point count.
        Instances data = new Instances("Table", atts, ordered.size());

        for (WeightedPoint p : ordered) {
            double[] coords = p.getPoint();
            Instance record = new DenseInstance(this.numAttrs);
            for (int i = 0; i < this.numAttrs; i++) {
                record.setValue(atts.get(i), coords[i]);
            }
            record.setWeight(p.getWeight());
            data.add(record);
        }
        return data;
    }

}