前言

文章标题的两个概念也许对于许多同学们来说都相对比较陌生,都比较偏向于于理论方面的知识,但是这个算法非常的强大,在很多方面都会存在他的影子。2个概念,1个维特比算法,1个隐马尔可夫模型。你很难想象,输入法的设计也会用到其中的一些知识。
HMM-隐马尔可夫模型

隐马尔可夫模型如果真的要展开来讲,那短短的一篇文章当然无法阐述的清,所以我会以最简单的方式解释。隐马尔可夫模型简称HMM,根据百度百科中的描述,隐马尔可夫模型描述的是一个含有隐含未知参数的马尔可夫模型。模型的本质是从观察的参数中获取隐含的参数信息。一般的在马尔可夫模型中,前后之间的特征会存在部分的依赖影响。示例图如下:


隐马尔可夫模型在语音识别中有广泛的应用。其实在输入法中也有用到。举个例子,假设我输入wszs,分别代表4个字,所以观察特征就是w, s, z, s,那么我想挖掘出他所想表达的信息,也就是我想打出的字是什么,可能是"我是张三",又可能是“晚上再说”,这个就是可能的信息,最可能的信息,就会被放在输入选择越靠前的位置。在这里我们就会用到一个叫维特比的算法了。
Viterbi-维特比算法

维特比算法这个名字的产生是以著名科学家维特比命名的,维特比算法是数字通信中非常常用的一种算法。那么维特比算法到底是做什么的呢。简单的说,他是一种特殊的动态规划算法,也就是DP问题。但是这里可不是单纯的寻找最短路径这些问题,可能他是需要寻找各个条件因素下最大概率的一条路径,假设针对观察特征,会有多个隐含特征值的情况。比如下面这个是多种隐含变量的组合情况,形成了一个密集的篱笆网络。


于是问题就转变成了,如何在这么多的路径中找到最佳路径。如果这是输入法的例子,上面的每一列的值就是某个拼音下对应的可能的字。于是我们就很容易联想到可以用dp的思想去做,每次求得相邻变量之间求得后最佳的值,存在下一列的节点上,而不是组合这么多种情况去算。时间复杂度能降低不少。但是在马尔可夫模型中,你还要考虑一些别的因素。所以总的来说,维特比算法就是一种利用动态规划算法寻找最有可能产生观察序列的隐含信息序列,尤其是在类似于隐马尔可夫模型的应用中。
算法实例

下面给出一个实际例子,来说明一下维特比算法到底怎么用,如果你用过动态规划算法,相信一定能很迅速的理解我想表达的意思。下面这个例子讲的是海藻的观察特征与天气的关系,通过观测海藻的特征状态,退出当天天气的状况。当然当天天气的预测还可能受昨天的天气影响,所以这是一个很棒的隐马尔可夫模型问题。问题的描述是下面这段话:

假设连续观察3天的海藻湿度为(Dry,Damp,Soggy),求这三天最可能的天气情况。天气只有三类(Sunny,Cloudy,Rainy),而且海藻湿度和天气有一定的关系。问题具体描述,链接在此

ok,状态转移概率矩阵和混淆矩阵都已给出,详细代码以及相关数据,请点击链接:https://github.com/linyiqun/DataMiningAlgorithm/tree/master/Others/DataMining_Viterbi

直接给出代码解答,主算法类ViterbiTool.java:

package DataMining_Viterbi;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;

/**
 * 维特比算法工具类
 *
 * @author lyq
 *
 */
public class ViterbiTool {
    // 状态转移概率矩阵文件地址
    private String stmFilePath;
    // 混淆矩阵文件地址
    private String confusionFilePath;
    // 初始状态概率
    private double[] initStatePro;
    // 观察到的状态序列
    public String[] observeStates;
    // 状态转移矩阵值
    private double[][] stMatrix;
    // 混淆矩阵值
    private double[][] confusionMatrix;
    // 各个条件下的潜在特征概率值
    private double[][] potentialValues;
    // 潜在特征
    private ArrayList<String> potentialAttrs;
    // 属性值列坐标映射图
    private HashMap<String, Integer> name2Index;
    // 列坐标属性值映射图
    private HashMap<Integer, String> index2name;

    public ViterbiTool(String stmFilePath, String confusionFilePath,
            double[] initStatePro, String[] observeStates) {
        this.stmFilePath = stmFilePath;
        this.confusionFilePath = confusionFilePath;
        this.initStatePro = initStatePro;
        this.observeStates = observeStates;

        initOperation();
    }

    /**
     * 初始化数据操作
     */
    private void initOperation() {
        double[] temp;
        int index;
        ArrayList<String[]> smtDatas;
        ArrayList<String[]> cfDatas;

        smtDatas = readDataFile(stmFilePath);
        cfDatas = readDataFile(confusionFilePath);

        index = 0;
        this.stMatrix = new double[smtDatas.size()][];
        for (String[] array : smtDatas) {
            temp = new double[array.length];
            for (int i = 0; i < array.length; i++) {
                try {
                    temp[i] = Double.parseDouble(array[i]);
                } catch (NumberFormatException e) {
                    temp[i] = -1;
                }
            }

            // 将转换后的值赋给数组中
            this.stMatrix[index] = temp;
            index++;
        }

        index = 0;
        this.confusionMatrix = new double[cfDatas.size()][];
        for (String[] array : cfDatas) {
            temp = new double[array.length];
            for (int i = 0; i < array.length; i++) {
                try {
                    temp[i] = Double.parseDouble(array[i]);
                } catch (NumberFormatException e) {
                    temp[i] = -1;
                }
            }

            // 将转换后的值赋给数组中
            this.confusionMatrix[index] = temp;
            index++;
        }

        this.potentialAttrs = new ArrayList<>();
        // 添加潜在特征属性
        for (String s : smtDatas.get(0)) {
            this.potentialAttrs.add(s);
        }
        // 去除首列无效列
        potentialAttrs.remove(0);

        this.name2Index = new HashMap<>();
        this.index2name = new HashMap<>();

        // 添加名称下标映射关系
        for (int i = 1; i < smtDatas.get(0).length; i++) {
            this.name2Index.put(smtDatas.get(0)[i], i);
            // 添加下标到名称的映射
            this.index2name.put(i, smtDatas.get(0)[i]);
        }

        for (int i = 1; i < cfDatas.get(0).length; i++) {
            this.name2Index.put(cfDatas.get(0)[i], i);
        }
    }

    /**
     * 从文件中读取数据
     */
    private ArrayList<String[]> readDataFile(String filePath) {
        File file = new File(filePath);
        ArrayList<String[]> dataArray = new ArrayList<String[]>();

        try {
            BufferedReader in = new BufferedReader(new FileReader(file));
            String str;
            String[] tempArray;
            while ((str = in.readLine()) != null) {
                tempArray = str.split(" ");
                dataArray.add(tempArray);
            }
            in.close();
        } catch (IOException e) {
            e.getStackTrace();
        }

        return dataArray;
    }

    /**
     * 根据观察特征计算隐藏的特征概率矩阵
     */
    private void calPotencialProMatrix() {
        String curObserveState;
        // 观察特征和潜在特征的下标
        int osIndex;
        int psIndex;
        double temp;
        double maxPro;
        // 混淆矩阵概率值,就是相关影响的因素概率
        double confusionPro;

        this.potentialValues = new double[observeStates.length][potentialAttrs
                .size() + 1];
        for (int i = 0; i < this.observeStates.length; i++) {
            curObserveState = this.observeStates[i];
            osIndex = this.name2Index.get(curObserveState);
            maxPro = -1;

            // 因为是第一个观察特征,没有前面的影响,根据初始状态计算
            if (i == 0) {
                for (String attr : this.potentialAttrs) {
                    psIndex = this.name2Index.get(attr);
                    confusionPro = this.confusionMatrix[psIndex][osIndex];

                    temp = this.initStatePro[psIndex - 1] * confusionPro;
                    this.potentialValues[BaseNames.DAY1][psIndex] = temp;
                }
            } else {
                // 后面的潜在特征受前一个特征的影响,以及当前的混淆因素影响
                for (String toDayAttr : this.potentialAttrs) {
                    psIndex = this.name2Index.get(toDayAttr);
                    confusionPro = this.confusionMatrix[psIndex][osIndex];

                    int index;
                    maxPro = -1;
                    // 通过昨天的概率计算今天此特征的最大概率
                    for (String yAttr : this.potentialAttrs) {
                        index = this.name2Index.get(yAttr);
                        temp = this.potentialValues[i - 1][index]
                                * this.stMatrix[index][psIndex];

                        // 计算得到今天此潜在特征的最大概率
                        if (temp > maxPro) {
                            maxPro = temp;
                        }
                    }

                    this.potentialValues[i][psIndex] = maxPro * confusionPro;
                }
            }
        }
    }

    /**
     * 根据同时期最大概率值输出潜在特征值
     */
    private void outputResultAttr() {
        double maxPro;
        int maxIndex;
        ArrayList<String> psValues;

        psValues = new ArrayList<>();
        for (int i = 0; i < this.potentialValues.length; i++) {
            maxPro = -1;
            maxIndex = 0;

            for (int j = 0; j < potentialValues[i].length; j++) {
                if (this.potentialValues[i][j] > maxPro) {
                    maxPro = potentialValues[i][j];
                    maxIndex = j;
                }
            }

            // 取出最大概率下标对应的潜在特征
            psValues.add(this.index2name.get(maxIndex));
        }

        System.out.println("观察特征为:");
        for (String s : this.observeStates) {
            System.out.print(s + ", ");
        }
        System.out.println();

        System.out.println("潜在特征为:");
        for (String s : psValues) {
            System.out.print(s + ", ");
        }
        System.out.println();
    }

    /**
     * 根据观察属性,得到潜在属性信息
     */
    public void calHMMObserve() {
        calPotencialProMatrix();
        outputResultAttr();
    }
}

测试结果输出:

观察特征为:
Dry, Damp, Soggy,
潜在特征为:
Sunny, Cloudy, Rainy,


参考文献

百度百科-隐马尔可夫模型

百度百科-维特比

<<数学之美>>第二版-吴军

http://blog.csdn.net/jeiwt/article/details/8076739
作者:Androidlushangderen 发表于2015/8/3 23:09:39 原文链接
阅读:475 评论:0 查看评论
再学贝叶斯网络--TAN树型朴素贝叶斯算法
2015年7月5日 15:18
前言

在前面的时间里已经学习过了NB朴素贝叶斯算法, 又刚刚初步的学习了贝叶斯网络的一些基本概念和常用的计算方法。于是就有了上篇初识贝叶斯网络的文章,由于本人最近一直在研究学习<<贝叶斯网引论>>,也接触到了许多与贝叶斯网络相关的知识,可以说朴素贝叶斯算法这些只是我们所了解贝叶斯知识的很小的一部分。今天我要总结的学习成果就是基于NB算法的,叫做Tree Augmented Naive Bays,中文意思就是树型朴素贝叶斯算法,简单理解就是树增强型NB算法,那么问题来了,他是如何增强的呢,请继续往下正文的描述。
朴素贝叶斯算法

又得要从朴素贝叶斯算法开始讲起了,因为在前言中已经说了,TAN算法是对NB算法的增强,了解过NB算法的,一定知道NB算法在使用的时候是假设属性事件是相互独立的,而决策属性的分类结果是依赖于各个条件属性的情况的,最后选择分类属性中拥有最大后验概率的值为决策属性。比如下面这个模型可以描述一个简单的模型,


上面账号是否真实的依赖属性条件有3个,好友密度,是否使用真实头像,日志密度,假设这3个属性是相互独立的,但是事实上,在这里的头像是否真实和好友密度其实是有关联的,所以更加真实的情况是下面这张情况;


OK,TAN的出现就解决了条件间的部分属性依赖的问题。在上面的例子中我们是根据自己的主观意识判断出头像和好友密度的关系,但是在真实算法中,我们当然希望机器能够自己根据所给数据集帮我们得出这样的关系,令人高兴的事,TAN帮我们做到了这点。
TAN算法
互信息值

互信息值,在百度百科中的解释如下:

互信息值是信息论中一个有用的信息度量。它可以看出是一个信息量里包含另一个随机变量的信息量。

用图线来表示就是下面这样。


中间的I(x;y)就是互信息值,X,Y代表的2种属性。于是下面这个属性就很好理解了,互信息值越大,就代表2个属性关联性越大。互信息值的标准公式如下:


但是在TAN中会有少许的不一样,会有类变量属性的加入,因为属性之间的关联性的前提是要在某一分类属性确定下进行重新计算,不同的类属性值会有不同的属性关联性。下面是TAN中的I(x;Y)计算公式:


现在看不懂不要紧,后面在给出的程序代码中可自行调试。
算法实现过程

TAN的算法过程其实并不简单,在计算完各个属性对的互信息值之后,要进行贝叶斯网络的构建,这个是TAN中最难的部分,这个部分有下面几个阶段。

1、根据各个属性对的互信息值降序排序,依次取出其中的节点对,遵循不产生环路的原则,构造最大权重跨度树,直到选择完n-1条边为止(因为总共n个属性节点,n-1条边即可确定)。按照互信息值从高到低选择的原因就是要保留关联性更高的关联依赖性的边。

2、上述过程构成的是一个无向图,接下来为整个无向图确定边的方向。选择任意一个属性节点作为根节点,由根节点向外的方向为属性节点之间的方向。

3、为每一个属性节点添加父节点,父节点就是分类属性节点,至此贝叶斯网络结构构造完毕。

为了方便大家理解,我在网上截了几张图,下面这张是在5个属性节点中优先选择了互信息值最大的4条作为无向图:


上述带了箭头是因为,我选择的A作为树的根节点,然后方向就全部确定了,因为A直接连着4个属性节点,然后再此基础上添加父节点,就是下面这个样子了。


OK,这样应该就比较好理解了吧,如果还不理解,请仔细分析我写的程序,从代码中去理解这个过程也可以。
分类结果概率的计算

分类结果概率的计算其实非常简单,只要把查询的条件属性传入分类模型中,然后计算不同类属性下的概率值,拥有最大概率值的分类属性值为最终的分类结果。下面是计算公式,就是联合概率分布公式:


代码实现

测试数据集input.txt:

OutLook Temperature Humidity Wind PlayTennis
Sunny Hot High Weak No
Sunny Hot High Strong No
Overcast Hot High Weak Yes
Rainy Mild High Weak Yes
Rainy Cool Normal Weak Yes
Rainy Cool Normal Strong No
Overcast Cool Normal Strong Yes
Sunny Mild High Weak No
Sunny Cool Normal Weak Yes
Rainy Mild Normal Weak Yes
Sunny Mild Normal Strong Yes
Overcast Mild High Strong Yes
Overcast Hot Normal Weak Yes
Rainy Mild High Strong No

节点类Node.java:

package DataMining_TAN;

import java.util.ArrayList;

/**
 * 贝叶斯网络节点类
 *
 * @author lyq
 *
 */
public class Node {
    //节点唯一id,方便后面节点连接方向的确定
    int id;
    // 节点的属性名称
    String name;
    // 该节点所连续的节点
    ArrayList<Node> connectedNodes;

    public Node(int id, String name) {
        this.id = id;
        this.name = name;

        // 初始化变量
        this.connectedNodes = new ArrayList<>();
    }

    /**
     * 将自身节点连接到目标给定的节点
     *
     * @param node
     *            下游节点
     */
    public void connectNode(Node node) {
        //避免连接自身
        if(this.id == node.id){
            return;
        }
        
        // 将节点加入自身节点的节点列表中
        this.connectedNodes.add(node);
        // 将自身节点加入到目标节点的列表中
        node.connectedNodes.add(this);
    }

    /**
     * 判断与目标节点是否相同,主要比较名称是否相同即可
     *
     * @param node
     *            目标结点
     * @return
     */
    public boolean isEqual(Node node) {
        boolean isEqual;

        isEqual = false;
        // 节点名称相同则视为相等
        if (this.id == node.id) {
            isEqual = true;
        }

        return isEqual;
    }
}

互信息值类.java:

package DataMining_TAN;

/**
 * 属性之间的互信息值,表示属性之间的关联性大小
 * @author lyq
 *
 */
public class AttrMutualInfo implements Comparable<AttrMutualInfo>{
    //互信息值
    Double value;
    //关联属性值对
    Node[] nodeArray;
    
    public AttrMutualInfo(double value, Node node1, Node node2){
        this.value = value;
        
        this.nodeArray = new Node[2];
        this.nodeArray[0] = node1;
        this.nodeArray[1] = node2;
    }

    @Override
    public int compareTo(AttrMutualInfo o) {
        // TODO Auto-generated method stub
        return o.value.compareTo(this.value);
    }
    
}



算法主程序类TANTool.java:

package DataMining_TAN;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;

/**
 * TAN树型朴素贝叶斯算法工具类
 *
 * @author lyq
 *
 */
public class TANTool {
    // 测试数据集地址
    private String filePath;
    // 数据集属性总数,其中一个个分类属性
    private int attrNum;
    // 分类属性名
    private String classAttrName;
    // 属性列名称行
    private String[] attrNames;
    // 贝叶斯网络边的方向,数组内的数值为节点id,从i->j
    private int[][] edges;
    // 属性名到列下标的映射
    private HashMap<String, Integer> attr2Column;
    // 属性,属性对取值集合映射对
    private HashMap<String, ArrayList<String>> attr2Values;
    // 贝叶斯网络总节点列表
    private ArrayList<Node> totalNodes;
    // 总的测试数据
    private ArrayList<String[]> totalDatas;

    public TANTool(String filePath) {
        this.filePath = filePath;

        readDataFile();
    }

    /**
     * 从文件中读取数据
     */
    private void readDataFile() {
        File file = new File(filePath);
        ArrayList<String[]> dataArray = new ArrayList<String[]>();

        try {
            BufferedReader in = new BufferedReader(new FileReader(file));
            String str;
            String[] array;

            while ((str = in.readLine()) != null) {
                array = str.split(" ");
                dataArray.add(array);
            }
            in.close();
        } catch (IOException e) {
            e.getStackTrace();
        }

        this.totalDatas = dataArray;
        this.attrNames = this.totalDatas.get(0);
        this.attrNum = this.attrNames.length;
        this.classAttrName = this.attrNames[attrNum - 1];

        Node node;
        this.edges = new int[attrNum][attrNum];
        this.totalNodes = new ArrayList<>();
        this.attr2Column = new HashMap<>();
        this.attr2Values = new HashMap<>();

        // 分类属性节点id最小设为0
        node = new Node(0, attrNames[attrNum - 1]);
        this.totalNodes.add(node);
        for (int i = 0; i < attrNames.length; i++) {
            if (i < attrNum - 1) {
                // 创建贝叶斯网络节点,每个属性一个节点
                node = new Node(i + 1, attrNames[i]);
                this.totalNodes.add(node);
            }

            // 添加属性到列下标的映射
            this.attr2Column.put(attrNames[i], i);
        }

        String[] temp;
        ArrayList<String> values;
        // 进行属性名,属性值对的映射匹配
        for (int i = 1; i < this.totalDatas.size(); i++) {
            temp = this.totalDatas.get(i);

            for (int j = 0; j < temp.length; j++) {
                // 判断map中是否包含此属性名
                if (this.attr2Values.containsKey(attrNames[j])) {
                    values = this.attr2Values.get(attrNames[j]);
                } else {
                    values = new ArrayList<>();
                }

                if (!values.contains(temp[j])) {
                    // 加入新的属性值
                    values.add(temp[j]);
                }

                this.attr2Values.put(attrNames[j], values);
            }
        }
    }

    /**
     * 根据条件互信息度对构建最大权重跨度树,返回第一个节点为根节点
     *
     * @param iArray
     */
    private Node constructWeightTree(ArrayList<Node[]> iArray) {
        Node node1;
        Node node2;
        Node root;
        ArrayList<Node> existNodes;

        existNodes = new ArrayList<>();

        for (Node[] i : iArray) {
            node1 = i[0];
            node2 = i[1];

            // 将2个节点进行连接
            node1.connectNode(node2);
            // 避免出现环路现象
            addIfNotExist(node1, existNodes);
            addIfNotExist(node2, existNodes);

            if (existNodes.size() == attrNum - 1) {
                break;
            }
        }

        // 返回第一个作为根节点
        root = existNodes.get(0);
        return root;
    }

    /**
     * 为树型结构确定边的方向,方向为属性根节点方向指向其他属性节点方向
     *
     * @param root
     *            当前遍历到的节点
     */
    private void confirmGraphDirection(Node currentNode) {
        int i;
        int j;
        ArrayList<Node> connectedNodes;

        connectedNodes = currentNode.connectedNodes;

        i = currentNode.id;
        for (Node n : connectedNodes) {
            j = n.id;

            // 判断连接此2节点的方向是否被确定
            if (edges[i][j] == 0 && edges[j][i] == 0) {
                // 如果没有确定,则制定方向为i->j
                edges[i][j] = 1;

                // 递归继续搜索
                confirmGraphDirection(n);
            }
        }
    }

    /**
     * 为属性节点添加分类属性节点为父节点
     *
     * @param parentNode
     *            父节点
     * @param nodeList
     *            子节点列表
     */
    private void addParentNode() {
        // 分类属性节点
        Node parentNode;

        parentNode = null;
        for (Node n : this.totalNodes) {
            if (n.id == 0) {
                parentNode = n;
                break;
            }
        }

        for (Node child : this.totalNodes) {
            parentNode.connectNode(child);

            if (child.id != 0) {
                // 确定连接方向
                this.edges[0][child.id] = 1;
            }
        }
    }

    /**
     * 在节点集合中添加节点
     *
     * @param node
     *            待添加节点
     * @param existNodes
     *            已存在的节点列表
     * @return
     */
    public boolean addIfNotExist(Node node, ArrayList<Node> existNodes) {
        boolean canAdd;

        canAdd = true;
        for (Node n : existNodes) {
            // 如果节点列表中已经含有节点,则算添加失败
            if (n.isEqual(node)) {
                canAdd = false;
                break;
            }
        }

        if (canAdd) {
            existNodes.add(node);
        }

        return canAdd;
    }

    /**
     * 计算节点条件概率
     *
     * @param node
     *            关于node的后验概率
     * @param queryParam
     *            查询的属性参数
     * @return
     */
    private double calConditionPro(Node node, HashMap<String, String> queryParam) {
        int id;
        double pro;
        String value;
        String[] attrValue;

        ArrayList<String[]> priorAttrInfos;
        ArrayList<String[]> backAttrInfos;
        ArrayList<Node> parentNodes;

        pro = 1;
        id = node.id;
        parentNodes = new ArrayList<>();
        priorAttrInfos = new ArrayList<>();
        backAttrInfos = new ArrayList<>();

        for (int i = 0; i < this.edges.length; i++) {
            // 寻找父节点id
            if (this.edges[i][id] == 1) {
                for (Node temp : this.totalNodes) {
                    // 寻找目标节点id
                    if (temp.id == i) {
                        parentNodes.add(temp);
                        break;
                    }
                }
            }
        }

        // 获取先验属性的属性值,首先添加先验属性
        value = queryParam.get(node.name);
        attrValue = new String[2];
        attrValue[0] = node.name;
        attrValue[1] = value;
        priorAttrInfos.add(attrValue);

        // 逐一添加后验属性
        for (Node p : parentNodes) {
            value = queryParam.get(p.name);
            attrValue = new String[2];
            attrValue[0] = p.name;
            attrValue[1] = value;

            backAttrInfos.add(attrValue);
        }

        pro = queryConditionPro(priorAttrInfos, backAttrInfos);

        return pro;
    }

    /**
     * 查询条件概率
     *
     * @param attrValues
     *            条件属性值
     * @return
     */
    private double queryConditionPro(ArrayList<String[]> priorValues,
            ArrayList<String[]> backValues) {
        // 判断是否满足先验属性值条件
        boolean hasPrior;
        // 判断是否满足后验属性值条件
        boolean hasBack;
        int attrIndex;
        double backPro;
        double totalPro;
        double pro;
        String[] tempData;

        pro = 0;
        totalPro = 0;
        backPro = 0;

        // 跳过第一行的属性名称行
        for (int i = 1; i < this.totalDatas.size(); i++) {
            tempData = this.totalDatas.get(i);

            hasPrior = true;
            hasBack = true;

            // 判断是否满足先验条件
            for (String[] array : priorValues) {
                attrIndex = this.attr2Column.get(array[0]);

                // 判断值是否满足条件
                if (!tempData[attrIndex].equals(array[1])) {
                    hasPrior = false;
                    break;
                }
            }

            // 判断是否满足后验条件
            for (String[] array : backValues) {
                attrIndex = this.attr2Column.get(array[0]);

                // 判断值是否满足条件
                if (!tempData[attrIndex].equals(array[1])) {
                    hasBack = false;
                    break;
                }
            }

            // 进行计数统计,分别计算满足后验属性的值和同时满足条件的个数
            if (hasBack) {
                backPro++;
                if (hasPrior) {
                    totalPro++;
                }
            } else if (hasPrior && backValues.size() == 0) {
                // 如果只有先验概率则为纯概率的计算
                totalPro++;
                backPro = 1.0;
            }
        }

        if (backPro == 0) {
            pro = 0;
        } else {
            // 计算总的概率=都发生概率/只发生后验条件的时间概率
            pro = totalPro / backPro;
        }

        return pro;
    }

    /**
     * 输入查询条件参数,计算发生概率
     *
     * @param queryParam
     *            条件参数
     * @return
     */
    public double calHappenedPro(String queryParam) {
        double result;
        double temp;
        // 分类属性值
        String classAttrValue;
        String[] array;
        String[] array2;
        HashMap<String, String> params;

        result = 1;
        params = new HashMap<>();

        // 进行查询字符的参数分解
        array = queryParam.split(",");
        for (String s : array) {
            array2 = s.split("=");
            params.put(array2[0], array2[1]);
        }

        classAttrValue = params.get(classAttrName);
        // 构建贝叶斯网络结构
        constructBayesNetWork(classAttrValue);

        for (Node n : this.totalNodes) {
            temp = calConditionPro(n, params);

            // 为了避免出现条件概率为0的现象,进行轻微矫正
            if (temp == 0) {
                temp = 0.001;
            }

            // 按照联合概率公式,进行乘积运算
            result *= temp;
        }

        return result;
    }

    /**
     * 构建树型贝叶斯网络结构
     *
     * @param value
     *            类别量值
     */
    private void constructBayesNetWork(String value) {
        Node rootNode;
        ArrayList<AttrMutualInfo> mInfoArray;
        // 互信息度对
        ArrayList<Node[]> iArray;

        iArray = null;
        rootNode = null;

        // 在每次重新构建贝叶斯网络结构的时候,清空原有的连接结构
        for (Node n : this.totalNodes) {
            n.connectedNodes.clear();
        }
        this.edges = new int[attrNum][attrNum];

        // 从互信息对象中取出属性值对
        iArray = new ArrayList<>();
        mInfoArray = calAttrMutualInfoArray(value);
        for (AttrMutualInfo v : mInfoArray) {
            iArray.add(v.nodeArray);
        }

        // 构建最大权重跨度树
        rootNode = constructWeightTree(iArray);
        // 为无向图确定边的方向
        confirmGraphDirection(rootNode);
        // 为每个属性节点添加分类属性父节点
        addParentNode();
    }

    /**
     * 给定分类变量值,计算属性之间的互信息值
     *
     * @param value
     *            分类变量值
     * @return
     */
    private ArrayList<AttrMutualInfo> calAttrMutualInfoArray(String value) {
        double iValue;
        Node node1;
        Node node2;
        AttrMutualInfo mInfo;
        ArrayList<AttrMutualInfo> mInfoArray;

        mInfoArray = new ArrayList<>();

        for (int i = 0; i < this.totalNodes.size() - 1; i++) {
            node1 = this.totalNodes.get(i);
            // 跳过分类属性节点
            if (node1.id == 0) {
                continue;
            }

            for (int j = i + 1; j < this.totalNodes.size(); j++) {
                node2 = this.totalNodes.get(j);
                // 跳过分类属性节点
                if (node2.id == 0) {
                    continue;
                }

                // 计算2个属性节点之间的互信息值
                iValue = calMutualInfoValue(node1, node2, value);
                mInfo = new AttrMutualInfo(iValue, node1, node2);
                mInfoArray.add(mInfo);
            }
        }

        // 将结果进行降序排列,让互信息值高的优先用于构建树
        Collections.sort(mInfoArray);

        return mInfoArray;
    }

    /**
     * 计算2个属性节点的互信息值
     *
     * @param node1
     *            节点1
     * @param node2
     *            节点2
     * @param vlaue
     *            分类变量值
     */
    private double calMutualInfoValue(Node node1, Node node2, String value) {
        double iValue;
        double temp;
        // 三种不同条件的后验概率
        double pXiXj;
        double pXi;
        double pXj;
        String[] array1;
        String[] array2;
        ArrayList<String> attrValues1;
        ArrayList<String> attrValues2;
        ArrayList<String[]> priorValues;
        // 后验概率,在这里就是类变量值
        ArrayList<String[]> backValues;

        array1 = new String[2];
        array2 = new String[2];
        priorValues = new ArrayList<>();
        backValues = new ArrayList<>();

        iValue = 0;
        array1[0] = classAttrName;
        array1[1] = value;
        // 后验属性都是类属性
        backValues.add(array1);

        // 获取节点属性的属性值集合
        attrValues1 = this.attr2Values.get(node1.name);
        attrValues2 = this.attr2Values.get(node2.name);

        for (String v1 : attrValues1) {
            for (String v2 : attrValues2) {
                priorValues.clear();

                array1 = new String[2];
                array1[0] = node1.name;
                array1[1] = v1;
                priorValues.add(array1);

                array2 = new String[2];
                array2[0] = node2.name;
                array2[1] = v2;
                priorValues.add(array2);

                // 计算3种条件下的概率
                pXiXj = queryConditionPro(priorValues, backValues);

                priorValues.clear();
                priorValues.add(array1);
                pXi = queryConditionPro(priorValues, backValues);

                priorValues.clear();
                priorValues.add(array2);
                pXj = queryConditionPro(priorValues, backValues);

                // 如果出现其中一个计数概率为0,则直接赋值为0处理
                if (pXiXj == 0 || pXi == 0 || pXj == 0) {
                    temp = 0;
                } else {
                    // 利用公式计算针对此属性值对组合的概率
                    temp = pXiXj * Math.log(pXiXj / (pXi * pXj)) / Math.log(2);
                }

                // 进行和属性值对组合的累加即为整个属性的互信息值
                iValue += temp;
            }
        }

        return iValue;
    }
}

场景测试类client.java:

package DataMining_TAN;

/**
 * TAN树型朴素贝叶斯算法
 *
 * @author lyq
 *
 */
public class Client {
    public static void main(String[] args) {
        String filePath = "C:\\Users\\lyq\\Desktop\\icon\\input.txt";
        // 条件查询语句
        String queryStr;
        // 分类结果概率1
        double classResult1;
        // 分类结果概率2
        double classResult2;

        TANTool tool = new TANTool(filePath);
        queryStr = "OutLook=Sunny,Temperature=Hot,Humidity=High,Wind=Weak,PlayTennis=No";
        classResult1 = tool.calHappenedPro(queryStr);

        queryStr = "OutLook=Sunny,Temperature=Hot,Humidity=High,Wind=Weak,PlayTennis=Yes";
        classResult2 = tool.calHappenedPro(queryStr);

        System.out.println(String.format("类别为%s所求得的概率为%s", "PlayTennis=No",
                classResult1));
        System.out.println(String.format("类别为%s所求得的概率为%s", "PlayTennis=Yes",
                classResult2));
        if (classResult1 > classResult2) {
            System.out.println("分类类别为PlayTennis=No");
        } else {
            System.out.println("分类类别为PlayTennis=Yes");
        }
    }
}

结果输出:

类别为PlayTennis=No所求得的概率为0.09523809523809525
类别为PlayTennis=Yes所求得的概率为3.571428571428571E-5
分类类别为PlayTennis=No


参考文献

百度百科

贝叶斯网络分类器与应用,作者:余民杰

用于数据挖掘的TAN分类器的研究和应用,作者:孙笑徽等4人


更多数据挖掘算法

https://github.com/linyiqun/DataMiningAlgorithm


作者:Androidlushangderen 发表于2015/7/5 15:18:09 原文链接
阅读:638 评论:0 查看评论
初识贝叶斯网络
2015年6月29日 16:38
前言

一看到贝叶斯网络,马上让人联想到的是5个字,朴素贝叶斯,在所难免,NaiveByes的知名度确实会被贝叶斯网络算法更高一点。其实不管是朴素贝叶斯算法,还是今天我打算讲述的贝叶斯网络算法也罢,归根结底来说都是贝叶斯系列分类算法,他的核心思想就是基于概率学的知识进行分类判断,至于分类得到底准不准,大家尽可以自己用数据集去测试测试。OK,下面进入正题--贝叶斯网络算法。
朴素贝叶斯

一般我在介绍某种算法之前,都事先会学习一下相关的算法,以便于新算法的学习,而与贝叶斯网络算法相关性比较大的在我看来就是朴素贝叶斯算法,而且前段时间也恰好学习过,简单的来说,朴素贝叶斯算法的假设条件是各个事件相互独立,然后利用贝叶斯定理,做概率的计算,于是这个算法的核心就是就是这个贝叶斯定理的运用了喽,不错,贝叶斯定理的确很有用,他是基于条件概率的先验概率和后验概率的转换公式,这么说有点抽象,下面是公式的表达式:

大学里概率学的课本上都有介绍过的,这个公式的好处在于对于一些比较难直接得出的概率通过转换后的概率计算可得,一般是把决策属性值放在先验属性中,当做目标值,然后通过决策属性值的后验概率计算所得。具体请查看我的朴素贝叶斯算法介绍。
贝叶斯网络

下面这个部分就是文章的主题了,贝叶斯网络,里面有2个字非常关键,就是网络,网络代表的潜在意思有2点,第一是有结构的,第二存在关联,我们可以马上联想到DAG有向无环图。不错,存在关联的这个特点就是与朴素贝叶斯算法最大的一个不同点,因为朴素贝叶斯算法在计算概率值上是假设各个事务属性是相互独立的,但是理性的思考一下,其实这个很难做到,任何事务,如果你仔细去想想,其实都还是有点联系的。比如这里有个例子:

在SNS社区中检验账号的真实性

如果用朴素贝叶斯来做的话,就会是这样的假设:

i、真实账号比非真实账号平均具有更大的日志密度、各大的好友密度以及更多的使用真实头像。
ii、日志密度、好友密度和是否使用真实头像在账号真实性给定的条件下是独立的。

但是其实往深入一想,使用真实的头像其实是会提高人家添加你为好友的概率的,所以在这个条件的独立其实是有问题的,所以在贝叶斯网络中是允许关联的存在的,假设就变为如下:

i、真实账号比非真实账号平均具有更大的日志密度、各大的好友密度以及更多的使用真实头像。
ii、日志密度与好友密度、日志密度与是否使用真实头像在账号真实性给定的条件下是独立的。
iii、使用真实头像的用户比使用非真实头像的用户平均有更大的好友密度。

在贝叶斯网络中,会用一张DAG来表示,每个节点代表某个属性事件,每条边代表其中的条件概率,如下:


贝叶斯网络概率的计算

贝叶斯网络概率的计算很简单,是从联合概率分布公式中变换所得,下面是联合概率分布公式:


而在贝叶斯网络中,由于存在前述的关系存在,该公式就被简化为了如下:


其中Parent(xi),表示的是xi的前驱结点,如果还不理解,可以对照我后面的代码,自行调试分析。
代码实现

需要输入2部分的数据,依赖关系,用于构建贝叶斯网络图,第二个是测试数据集,算法总代码地址:

https://github.com/linyiqun/DataMiningAlgorithm/tree/master/Others/DataMining_BayesNetwork

依赖关系数据如下:

B A
E A
A M
A J

测试数据集:

B E A M J P
y y y y y 0.00012
y y y y n 0.000051
y y y n y 0.000013
y y y n n 0.0000057
y y n y y 0.000000005
y y n y n 0.00000049
y y n n y 0.000000095
y y n n n 0.0000094
y n y y y 0.0058
y n y y n 0.0025
y n y n y 0.00065
y n y n n 0.00028
y n n y y 0.00000029
y n n y n 0.000029
y n n n y 0.0000056
y n n n n 0.00055
n y y y y 0.0036
n y y y n 0.0016
n y y n y 0.0004
n y y n n 0.00017
n y n y y 0.000007
n y n y n 0.00069
n y n n y 0.00013
n y n n n 0.013
n n y y y 0.00061
n n y y n 0.00026
n n y n y 0.000068
n n y n n 0.000029
n n n y y 0.00048
n n n y n 0.048
n n n n y 0.0092
n n n n n 0.91

节点类Node.java:

package DataMining_BayesNetwork;

import java.util.ArrayList;

/**
 * 贝叶斯网络节点类
 *
 * @author lyq
 *
 */
public class Node {
    // 节点的属性名称
    String name;
    // 节点的父亲节点,也就是上游节点,可能多个
    ArrayList<Node> parentNodes;
    // 节点的子节点,也就是下游节点,可能多个
    ArrayList<Node> childNodes;

    public Node(String name) {
        this.name = name;

        // 初始化变量
        this.parentNodes = new ArrayList<>();
        this.childNodes = new ArrayList<>();
    }

    /**
     * 将自身节点连接到目标给定的节点
     *
     * @param node
     *            下游节点
     */
    public void connectNode(Node node) {
        // 将下游节点加入自身节点的孩子节点中
        this.childNodes.add(node);
        // 将自身节点加入到下游节点的父节点中
        node.parentNodes.add(this);
    }

    /**
     * 判断与目标节点是否相同,主要比较名称是否相同即可
     *
     * @param node
     *            目标结点
     * @return
     */
    public boolean isEqual(Node node) {
        boolean isEqual;

        isEqual = false;
        // 节点名称相同则视为相等
        if (this.name.equals(node.name)) {
            isEqual = true;
        }

        return isEqual;
    }
}


算法类BayesNetworkTool.java:

package DataMining_BayesNetwork;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;

/**
 * 贝叶斯网络算法工具类
 *
 * @author lyq
 *
 */
public class BayesNetWorkTool {
    // 联合概率分布数据文件地址
    private String dataFilePath;
    // 事件关联数据文件地址
    private String attachFilePath;
    // 属性列列数
    private int columns;
    // 概率分布数据
    private String[][] totalData;
    // 关联数据对
    private ArrayList<String[]> attachData;
    // 节点存放列表
    private ArrayList<Node> nodes;
    // 属性名与列数之间的对应关系
    private HashMap<String, Integer> attr2Column;

    public BayesNetWorkTool(String dataFilePath, String attachFilePath) {
        this.dataFilePath = dataFilePath;
        this.attachFilePath = attachFilePath;

        initDatas();
    }

    /**
     * 初始化关联数据和概率分布数据
     */
    private void initDatas() {
        String[] columnValues;
        String[] array;
        ArrayList<String> datas;
        ArrayList<String> adatas;

        // 从文件中读取数据
        datas = readDataFile(dataFilePath);
        adatas = readDataFile(attachFilePath);

        columnValues = datas.get(0).split(" ");
        // 属性割名称代表事件B(盗窃),E(地震),A(警铃响).M(接到M的电话),J同M的意思,
        // 属性值都是y,n代表yes发生和no不发生
        this.attr2Column = new HashMap<>();
        for (int i = 0; i < columnValues.length; i++) {
            // 从数据中取出属性名称行,列数值存入图中
            this.attr2Column.put(columnValues[i], i);
        }

        this.columns = columnValues.length;
        this.totalData = new String[datas.size()][columns];
        for (int i = 0; i < datas.size(); i++) {
            this.totalData[i] = datas.get(i).split(" ");
        }

        this.attachData = new ArrayList<>();
        // 解析关联数据对
        for (String str : adatas) {
            array = str.split(" ");
            this.attachData.add(array);
        }

        // 构造贝叶斯网络结构图
        constructDAG();
    }

    /**
     * 从文件中读取数据
     */
    private ArrayList<String> readDataFile(String filePath) {
        File file = new File(filePath);
        ArrayList<String> dataArray = new ArrayList<String>();

        try {
            BufferedReader in = new BufferedReader(new FileReader(file));
            String str;
            while ((str = in.readLine()) != null) {
                dataArray.add(str);
            }
            in.close();
        } catch (IOException e) {
            e.getStackTrace();
        }

        return dataArray;
    }

    /**
     * 根据关联数据构造贝叶斯网络无环有向图
     */
    private void constructDAG() {
        // 节点存在标识
        boolean srcExist;
        boolean desExist;
        String name1;
        String name2;
        Node srcNode;
        Node desNode;

        this.nodes = new ArrayList<>();
        for (String[] array : this.attachData) {
            srcExist = false;
            desExist = false;

            name1 = array[0];
            name2 = array[1];

            // 新建节点
            srcNode = new Node(name1);
            desNode = new Node(name2);

            for (Node temp : this.nodes) {
                // 如果找到相同节点,则取出
                if (srcNode.isEqual(temp)) {
                    srcExist = true;
                    srcNode = temp;
                } else if (desNode.isEqual(temp)) {
                    desExist = true;
                    desNode = temp;
                }

                // 如果2个节点都已找到,则跳出循环
                if (srcExist && desExist) {
                    break;
                }
            }

            // 将2个节点进行连接
            srcNode.connectNode(desNode);

            // 根据标识判断是否需要加入列表容器中
            if (!srcExist) {
                this.nodes.add(srcNode);
            }

            if (!desExist) {
                this.nodes.add(desNode);
            }
        }
    }

    /**
     * 查询条件概率
     *
     * @param attrValues
     *            条件属性值
     * @return
     */
    private double queryConditionPro(ArrayList<String[]> attrValues) {
        // 判断是否满足先验属性值条件
        boolean hasPrior;
        // 判断是否满足后验属性值条件
        boolean hasBack;
        int priorIndex;
        int attrIndex;
        double backPro;
        double totalPro;
        double pro;
        double currentPro;
        // 先验属性
        String[] priorValue;
        String[] tempData;

        pro = 0;
        totalPro = 0;
        backPro = 0;
        attrValues.get(0);
        priorValue = attrValues.get(0);
        // 得到后验概率
        attrValues.remove(0);

        // 取出先验属性的列数
        priorIndex = this.attr2Column.get(priorValue[0]);
        // 跳过第一行的属性名称行
        for (int i = 1; i < this.totalData.length; i++) {
            tempData = this.totalData[i];

            hasPrior = false;
            hasBack = true;

            // 当前行的概率
            currentPro = Double.parseDouble(tempData[this.columns - 1]);
            // 判断是否满足先验条件
            if (tempData[priorIndex].equals(priorValue[1])) {
                hasPrior = true;
            }

            for (String[] array : attrValues) {
                attrIndex = this.attr2Column.get(array[0]);

                // 判断值是否满足条件
                if (!tempData[attrIndex].equals(array[1])) {
                    hasBack = false;
                    break;
                }
            }

            // 进行计数统计,分别计算满足后验属性的值和同时满足条件的个数
            if (hasBack) {
                backPro += currentPro;
                if (hasPrior) {
                    totalPro += currentPro;
                }
            } else if (hasPrior && attrValues.size() == 0) {
                // 如果只有先验概率则为纯概率的计算
                totalPro += currentPro;
                backPro = 1.0;
            }
        }

        // 计算总的概率=都发生概率/只发生后验条件的时间概率
        pro = totalPro / backPro;

        return pro;
    }

    /**
     * 根据贝叶斯网络计算概率
     *
     * @param queryStr
     *            查询条件串
     * @return
     */
    public double calProByNetWork(String queryStr) {
        double temp;
        double pro;
        String[] array;
        // 先验条件值
        String[] preValue;
        // 后验条件值
        String[] backValue;
        // 所有先验条件和后验条件值的属性值的汇总
        ArrayList<String[]> attrValues;

        // 判断是否满足网络结构
        if (!satisfiedNewWork(queryStr)) {
            return -1;
        }

        pro = 1;
        // 首先做查询条件的分解
        array = queryStr.split(",");

        // 概率的初值等于第一个事件发生的随机概率
        attrValues = new ArrayList<>();
        attrValues.add(array[0].split("="));
        pro = queryConditionPro(attrValues);

        for (int i = 0; i < array.length - 1; i++) {
            attrValues.clear();

            // 下标小的在前面的属于后验属性
            backValue = array[i].split("=");
            preValue = array[i + 1].split("=");
            attrValues.add(preValue);
            attrValues.add(backValue);

            // 算出此种情况的概率值
            temp = queryConditionPro(attrValues);
            // 进行积的相乘
            pro *= temp;
        }

        return pro;
    }

    /**
     * 验证事件的查询因果关系是否满足贝叶斯网络
     *
     * @param queryStr
     *            查询字符串
     * @return
     */
    private boolean satisfiedNewWork(String queryStr) {
        String attrName;
        String[] array;
        boolean isExist;
        boolean isSatisfied;
        // 当前节点
        Node currentNode;
        // 候选节点列表
        ArrayList<Node> nodeList;

        isSatisfied = true;
        currentNode = null;
        // 做查询字符串的分解
        array = queryStr.split(",");
        nodeList = this.nodes;

        for (String s : array) {
            // 开始时默认属性对应的节点不存在
            isExist = false;
            // 得到属性事件名
            attrName = s.split("=")[0];

            for (Node n : nodeList) {
                if (n.name.equals(attrName)) {
                    isExist = true;

                    currentNode = n;
                    // 下一轮的候选节点为当前节点的孩子节点
                    nodeList = currentNode.childNodes;

                    break;
                }
            }

            // 如果存在未找到的节点,则说明不满足依赖结构跳出循环
            if (!isExist) {
                isSatisfied = false;
                break;
            }
        }

        return isSatisfied;
    }
}

场景测试类Client.java:

package DataMining_BayesNetwork;

import java.text.MessageFormat;

/**
 * 贝叶斯网络场景测试类
 *
 * @author lyq
 *
 */
public class Client {
    public static void main(String[] args) {
        String dataFilePath = "C:\\Users\\lyq\\Desktop\\icon\\input.txt";
        String attachFilePath = "C:\\Users\\lyq\\Desktop\\icon\\attach.txt";
        // 查询串语句
        String queryStr;
        // 结果概率
        double result;

        // 查询语句的描述的事件是地震发生了,导致响铃响了,导致接到Mary的电话
        queryStr = "E=y,A=y,M=y";
        BayesNetWorkTool tool = new BayesNetWorkTool(dataFilePath,
                attachFilePath);
        result = tool.calProByNetWork(queryStr);

        if (result == -1) {
            System.out.println("所描述的事件不满足贝叶斯网络的结构,无法求其概率");
        } else {
            System.out.println(String.format("事件%s发生的概率为%s", queryStr, result));
        }
    }
}

输出结果:

事件E=y,A=y,M=y发生的概率为0.005373075715453122


参考文献

百度百科

http://www.cnblogs.com/leoo2sk/archive/2010/09/18/bayes-network.html


更多数据挖掘算法

https://github.com/linyiqun/DataMiningAlgorithm
作者:Androidlushangderen 发表于2015/6/29 16:38:45 原文链接
阅读:626 评论:0 查看评论
ACO蚁群算法解决TSP旅行商问题
2015年4月30日 15:31
前言

蚁群算法也是一种利用了大自然规律的启发式算法,与之前学习过的GA遗传算法类似,遗传算法是用了生物进行理论,把更具适应性的基因传给下一代,最后就能得到一个最优解,常常用来寻找问题的最优解。当然,本篇文章不会主讲GA算法的,想要了解的同学可以查看,我的遗传算法学习和遗传算法在走迷宫中的应用。话题重新回到蚁群算法,蚁群算法是一个利用了蚂蚁寻找食物的原理。不知道小时候有没有发现,当一个蚂蚁发现了地上的食物,然后非常迅速的,就有其他的蚂蚁聚拢过来,最后把食物抬回家,这里面其实有着非常多的道理的,在ACO中就用到了这个机理用于解决实际生活中的一些问题。
蚂蚁找食物

首先我们要具体说说一个有意思的事情,就是蚂蚁找食物的问题,理解了这个原理之后,对于理解ACO算法就非常容易了。蚂蚁作为那么小的动物,在地上漫无目的的寻找食物,起初都是没有目标的,他从蚂蚁洞中走出,随机的爬向各个方向,在这期间他会向外界播撒一种化学物质,姑且就叫做信息素,所以这里就可以得到的一个前提,越多蚂蚁走过的路径,信息素浓度就会越高,那么某条路径信息素浓度高了,自然就会有越多的蚂蚁感觉到了,就会聚集过来了。所以当众多蚂蚁中的一个找到食物之后,他就会在走过的路径中放出信息素浓度,因此就会有很多的蚂蚁赶来了。类似下面的场景:


至于蚂蚁是如何感知这个信息素,这个就得问生物学家了,我也没做过研究。
算法介绍

OK,有了上面这个自然生活中的生物场景之后,我们再来切入文章主题来学习一下蚁群算法,百度百科中对应蚁群算法是这么介绍的:蚁群算法是一种在图中寻找优化路径的机率型算法。他的灵感就是来自于蚂蚁发现食物的行为。蚁群算法是一种新的模拟进化优化的算法,与遗传算法有很多相似的地方。蚁群算法在比较早的时候成功解决了TSP旅行商的问题(在后面的例子中也会以这个例子)。要用算法去模拟蚂蚁的这种行为,关键在于信息素的在算法中的设计,以及路径中信息素浓度越大的路径,将会有更高的概率被蚂蚁所选择到。

算法原理

要想实现上面的几个模拟行为,需要借助几个公式,当然公式不是我自己定义的,主要有3个,如下图:


上图中所出现的alpha,beita,p等数字都是控制因子,所以可不必理会,Tij(n)的意思是在时间为n的时候,从城市i到城市j的路径的信息素浓度。类似于nij的字母是城市i到城市j距离的倒数。就是下面这个公式。


所以所有的公式都是为第一个公式服务的,第一个公式的意思是指第k只蚂蚁选择从城市i到城市j的概率,可以见得,这个受距离和信息素浓度的双重影响,距离越远,去此城市的概率自然也低,所以nij会等于距离的倒数,而且在算信息素浓度的时候,也考虑到了信息素浓度衰减的问题,所以会在上次的浓度值上乘以一个衰减因子P。另外还要加上本轮搜索增加的信息素浓度(假如有蚂蚁经过此路径的话),所以这几个公式的整体设计思想还是非常棒的。
算法的代码实现

由于本身我这里没有什么真实的测试数据,就随便自己构造了一个简单的数据,输入如下,分为城市名称和城市之间的距离,用#符号做区分标识,大家应该可以看得懂吧

# CityName
1
2
3
4
# Distance
1 2 1
1 3 1.4
1 4 1
2 3 1
2 4 1
3 4 1

蚂蚁类Ant.java:

package DataMining_ACO;

import java.util.ArrayList;

/**
 * 蚂蚁类,进行路径搜索的载体
 *
 * @author lyq
 *
 */
public class Ant implements Comparable<Ant> {
    // 蚂蚁当前所在城市
    String currentPos;
    // 蚂蚁遍历完回到原点所用的总距离
    Double sumDistance;
    // 城市间的信息素浓度矩阵,随着时间的增多而减少
    double[][] pheromoneMatrix;
    // 蚂蚁已经走过的城市集合
    ArrayList<String> visitedCitys;
    // 还未走过的城市集合
    ArrayList<String> nonVisitedCitys;
    // 蚂蚁当前走过的路径
    ArrayList<String> currentPath;

    public Ant(double[][] pheromoneMatrix, ArrayList<String> nonVisitedCitys) {
        this.pheromoneMatrix = pheromoneMatrix;
        this.nonVisitedCitys = nonVisitedCitys;

        this.visitedCitys = new ArrayList<>();
        this.currentPath = new ArrayList<>();
    }

    /**
     * 计算路径的总成本(距离)
     *
     * @return
     */
    public double calSumDistance() {
        sumDistance = 0.0;
        String lastCity;
        String currentCity;

        for (int i = 0; i < currentPath.size() - 1; i++) {
            lastCity = currentPath.get(i);
            currentCity = currentPath.get(i + 1);

            // 通过距离矩阵进行计算
            sumDistance += ACOTool.disMatrix[Integer.parseInt(lastCity)][Integer
                    .parseInt(currentCity)];
        }

        return sumDistance;
    }

    /**
     * 蚂蚁选择前往下一个城市
     *
     * @param city
     *            所选的城市
     */
    public void goToNextCity(String city) {
        this.currentPath.add(city);
        this.currentPos = city;
        this.nonVisitedCitys.remove(city);
        this.visitedCitys.add(city);
    }

    /**
     * 判断蚂蚁是否已经又重新回到起点
     *
     * @return
     */
    public boolean isBack() {
        boolean isBack = false;
        String startPos;
        String endPos;

        if (currentPath.size() == 0) {
            return isBack;
        }

        startPos = currentPath.get(0);
        endPos = currentPath.get(currentPath.size() - 1);
        if (currentPath.size() > 1 && startPos.equals(endPos)) {
            isBack = true;
        }

        return isBack;
    }

    /**
     * 判断蚂蚁在本次的走过的路径中是否包含从城市i到城市j
     *
     * @param cityI
     *            城市I
     * @param cityJ
     *            城市J
     * @return
     */
    public boolean pathContained(String cityI, String cityJ) {
        String lastCity;
        String currentCity;
        boolean isContained = false;

        for (int i = 0; i < currentPath.size() - 1; i++) {
            lastCity = currentPath.get(i);
            currentCity = currentPath.get(i + 1);

            // 如果某一段路径的始末位置一致,则认为有经过此城市
            if ((lastCity.equals(cityI) && currentCity.equals(cityJ))
                    || (lastCity.equals(cityJ) && currentCity.equals(cityI))) {
                isContained = true;
                break;
            }
        }

        return isContained;
    }

    @Override
    public int compareTo(Ant o) {
        // TODO Auto-generated method stub
        return this.sumDistance.compareTo(o.sumDistance);
    }
}

蚁群算法工具类ACOTool.java:

package DataMining_ACO;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.text.MessageFormat;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;

/**
 * 蚁群算法工具类
 *
 * @author lyq
 *
 */
public class ACOTool {
    // 输入数据类型
    public static final int INPUT_CITY_NAME = 1;
    public static final int INPUT_CITY_DIS = 2;

    // 城市间距离邻接矩阵
    public static double[][] disMatrix;
    // 当前时间
    public static int currentTime;

    // 测试数据地址
    private String filePath;
    // 蚂蚁数量
    private int antNum;
    // 控制参数
    private double alpha;
    private double beita;
    private double p;
    private double Q;
    // 随机数产生器
    private Random random;
    // 城市名称集合,这里为了方便,将城市用数字表示
    private ArrayList<String> totalCitys;
    // 所有的蚂蚁集合
    private ArrayList<Ant> totalAnts;
    // 城市间的信息素浓度矩阵,随着时间的增多而减少
    private double[][] pheromoneMatrix;
    // 目标的最短路径,顺序为从集合的前部往后挪动
    private ArrayList<String> bestPath;
    // 信息素矩阵存储图,key采用的格式(i,j,t)->value
    private Map<String, Double> pheromoneTimeMap;

    public ACOTool(String filePath, int antNum, double alpha, double beita,
            double p, double Q) {
        this.filePath = filePath;
        this.antNum = antNum;
        this.alpha = alpha;
        this.beita = beita;
        this.p = p;
        this.Q = Q;
        this.currentTime = 0;

        readDataFile();
    }

    /**
     * 从文件中读取数据
     */
    private void readDataFile() {
        File file = new File(filePath);
        ArrayList<String[]> dataArray = new ArrayList<String[]>();

        try {
            BufferedReader in = new BufferedReader(new FileReader(file));
            String str;
            String[] tempArray;
            while ((str = in.readLine()) != null) {
                tempArray = str.split(" ");
                dataArray.add(tempArray);
            }
            in.close();
        } catch (IOException e) {
            e.getStackTrace();
        }

        int flag = -1;
        int src = 0;
        int des = 0;
        int size = 0;
        // 进行城市名称种数的统计
        this.totalCitys = new ArrayList<>();
        for (String[] array : dataArray) {
            if (array[0].equals("#") && totalCitys.size() == 0) {
                flag = INPUT_CITY_NAME;

                continue;
            } else if (array[0].equals("#") && totalCitys.size() > 0) {
                size = totalCitys.size();
                // 初始化距离矩阵
                this.disMatrix = new double[size + 1][size + 1];
                this.pheromoneMatrix = new double[size + 1][size + 1];

                // 初始值-1代表此对应位置无值
                for (int i = 0; i < size; i++) {
                    for (int j = 0; j < size; j++) {
                        this.disMatrix[i][j] = -1;
                        this.pheromoneMatrix[i][j] = -1;
                    }
                }

                flag = INPUT_CITY_DIS;
                continue;
            }

            if (flag == INPUT_CITY_NAME) {
                this.totalCitys.add(array[0]);
            } else {
                src = Integer.parseInt(array[0]);
                des = Integer.parseInt(array[1]);

                this.disMatrix[src][des] = Double.parseDouble(array[2]);
                this.disMatrix[des][src] = Double.parseDouble(array[2]);
            }
        }
    }

    /**
     * 计算从蚂蚁城市i到j的概率
     *
     * @param cityI
     *            城市I
     * @param cityJ
     *            城市J
     * @param currentTime
     *            当前时间
     * @return
     */
    private double calIToJProbably(String cityI, String cityJ, int currentTime) {
        double pro = 0;
        double n = 0;
        double pheromone;
        int i;
        int j;

        i = Integer.parseInt(cityI);
        j = Integer.parseInt(cityJ);

        pheromone = getPheromone(currentTime, cityI, cityJ);
        n = 1.0 / disMatrix[i][j];

        if (pheromone == 0) {
            pheromone = 1;
        }

        pro = Math.pow(n, alpha) * Math.pow(pheromone, beita);

        return pro;
    }

    /**
     * 计算综合概率蚂蚁从I城市走到J城市的概率
     *
     * @return
     */
    public String selectAntNextCity(Ant ant, int currentTime) {
        double randomNum;
        double tempPro;
        // 总概率指数
        double proTotal;
        String nextCity = null;
        ArrayList<String> allowedCitys;
        // 各城市概率集
        double[] proArray;

        // 如果是刚刚开始的时候,没有路过任何城市,则随机返回一个城市
        if (ant.currentPath.size() == 0) {
            nextCity = String.valueOf(random.nextInt(totalCitys.size()) + 1);

            return nextCity;
        } else if (ant.nonVisitedCitys.isEmpty()) {
            // 如果全部遍历完毕,则再次回到起点
            nextCity = ant.currentPath.get(0);

            return nextCity;
        }

        proTotal = 0;
        allowedCitys = ant.nonVisitedCitys;
        proArray = new double[allowedCitys.size()];

        for (int i = 0; i < allowedCitys.size(); i++) {
            nextCity = allowedCitys.get(i);
            proArray[i] = calIToJProbably(ant.currentPos, nextCity, currentTime);
            proTotal += proArray[i];
        }

        for (int i = 0; i < allowedCitys.size(); i++) {
            // 归一化处理
            proArray[i] /= proTotal;
        }

        // 用随机数选择下一个城市
        randomNum = random.nextInt(100) + 1;
        randomNum = randomNum / 100;
        // 因为1.0是无法判断到的,,总和会无限接近1.0取为0.99做判断
        if (randomNum == 1) {
            randomNum = randomNum - 0.01;
        }

        tempPro = 0;
        // 确定区间
        for (int j = 0; j < allowedCitys.size(); j++) {
            if (randomNum > tempPro && randomNum <= tempPro + proArray[j]) {
                // 采用拷贝的方式避免引用重复
                nextCity = allowedCitys.get(j);
                break;
            } else {
                tempPro += proArray[j];
            }
        }

        return nextCity;
    }

    /**
     * 获取给定时间点上从城市i到城市j的信息素浓度
     *
     * @param t
     * @param cityI
     * @param cityJ
     * @return
     */
    private double getPheromone(int t, String cityI, String cityJ) {
        double pheromone = 0;
        String key;

        // 上一周期需将时间倒回一周期
        key = MessageFormat.format("{0},{1},{2}", cityI, cityJ, t);

        if (pheromoneTimeMap.containsKey(key)) {
            pheromone = pheromoneTimeMap.get(key);
        }

        return pheromone;
    }

    /**
     * 每轮结束,刷新信息素浓度矩阵
     *
     * @param t
     */
    private void refreshPheromone(int t) {
        double pheromone = 0;
        // 上一轮周期结束后的信息素浓度,丛信息素浓度图中查找
        double lastTimeP = 0;
        // 本轮信息素浓度增加量
        double addPheromone;
        String key;

        for (String i : totalCitys) {
            for (String j : totalCitys) {
                if (!i.equals(j)) {
                    // 上一周期需将时间倒回一周期
                    key = MessageFormat.format("{0},{1},{2}", i, j, t - 1);

                    if (pheromoneTimeMap.containsKey(key)) {
                        lastTimeP = pheromoneTimeMap.get(key);
                    } else {
                        lastTimeP = 0;
                    }

                    addPheromone = 0;
                    for (Ant ant : totalAnts) {
                        if(ant.pathContained(i, j)){
                            // 每只蚂蚁传播的信息素为控制因子除以距离总成本
                            addPheromone += Q / ant.calSumDistance();
                        }
                    }

                    // 将上次的结果值加上递增的量,并存入图中
                    pheromone = p * lastTimeP + addPheromone;
                    key = MessageFormat.format("{0},{1},{2}", i, j, t);
                    pheromoneTimeMap.put(key, pheromone);
                }
            }
        }

    }

    /**
     * 蚁群算法迭代次数
     * @param loopCount
     * 具体遍历次数
     */
    public void antStartSearching(int loopCount) {
        // 蚁群寻找的总次数
        int count = 0;
        // 选中的下一个城市
        String selectedCity = "";

        pheromoneTimeMap = new HashMap<String, Double>();
        totalAnts = new ArrayList<>();
        random = new Random();

        while (count < loopCount) {
            initAnts();

            while (true) {
                for (Ant ant : totalAnts) {
                    selectedCity = selectAntNextCity(ant, currentTime);
                    ant.goToNextCity(selectedCity);
                }

                // 如果已经遍历完所有城市,则跳出此轮循环
                if (totalAnts.get(0).isBack()) {
                    break;
                }
            }

            // 周期时间叠加
            currentTime++;
            refreshPheromone(currentTime);
            count++;
        }

        // 根据距离成本,选出所花距离最短的一个路径
        Collections.sort(totalAnts);
        bestPath = totalAnts.get(0).currentPath;
        System.out.println(MessageFormat.format("经过{0}次循环遍历,最终得出的最佳路径:", count));
        System.out.print("entrance");
        for (String cityName : bestPath) {
            System.out.print(MessageFormat.format("-->{0}", cityName));
        }
    }

    /**
     * 初始化蚁群操作
     */
    private void initAnts() {
        Ant tempAnt;
        ArrayList<String> nonVisitedCitys;
        totalAnts.clear();

        // 初始化蚁群
        for (int i = 0; i < antNum; i++) {
            nonVisitedCitys = (ArrayList<String>) totalCitys.clone();
            tempAnt = new Ant(pheromoneMatrix, nonVisitedCitys);

            totalAnts.add(tempAnt);
        }
    }
}

场景测试类Client.java:

package DataMining_ACO;

/**
 * 蚁群算法测试类
 * @author lyq
 *
 */
public class Client {
    public static void main(String[] args){
        //测试数据
        String filePath = "C:\\Users\\lyq\\Desktop\\icon\\input.txt";
        //蚂蚁数量
        int antNum;
        //蚁群算法迭代次数
        int loopCount;
        //控制参数
        double alpha;
        double beita;
        double p;
        double Q;
        
        antNum = 3;
        alpha = 0.5;
        beita = 1;
        p = 0.5;
        Q = 5;
        loopCount = 5;
        
        ACOTool tool = new ACOTool(filePath, antNum, alpha, beita, p, Q);
        tool.antStartSearching(loopCount);
    }
}

算法的输出,就是在多次搜索之后,找到的路径中最短的一个路径:

经过5次循环遍历,最终得出的最佳路径:
entrance-->4-->1-->2-->3-->4

因为数据量比较小,并不能看出蚁群算法在这方面的优势,博友们可以再次基础上自行改造,并用大一点的数据做测试,其中的4个控制因子也可以调控。蚁群算法作为一种启发式算法,还可以和遗传算法结合,创造出更优的算法。蚁群算法可以解决许多这样的连通图路径优化问题。但是有的时候也会出现搜索时间过长的问题。


参考文献:百度百科.蚁群算法

我的数据挖掘算法库:https://github.com/linyiqun/DataMiningAlgorithm

我的算法库:https://github.com/linyiqun/lyq-algorithms-lib

作者:Androidlushangderen 发表于2015/4/30 15:31:45 原文链接
阅读:925 评论:0 查看评论
从Apriori到MS-Apriori算法
2015年4月16日 22:42
前言

最近的几个月一直在研究和学习各种经典的DM,机器学习的相关算法,收获还是挺多的,另外还整了一个DM算法库,集成了很多数据挖掘算法,放在了我的github上,博友的热度超出我的想象,有很多人给我点了star,在此感谢各大博友们,我将会继续更新我的DM算法库。也许这些算法还不能直接拿来用,但是可以给你提供思路,或变变数据的输入格式就能用了。好,扯得有点远了,现在说正题,本篇文章重新回到讲述Apriori算法,当然我这次不会讲之前说过的Apriori,(比较老套的那些东西网上也很多,我分析的也不一定是最好),本文的主题是Apriori算法的升级版本算法--MS-Apriori。在前面加了Ms,是什么意思呢,他可不是升级的意思,Ms是Mis的缩写,MIS的全称是Min Item Support,最小项目支持度。这有何Apriori算法有什么关系呢,在后面的正文中,我会主要解释这是什么意思,其实这只是其中的一个小的点,Ms-Apriori还是有很多对于Apriori算法的改进的。
Apriori

在了解Ms-Apriori算法之前,还是有必要重新回顾一下Apriori算法,Apriori算法是一种演绎算法,后一次的结果是依赖于上一次的计算结果的,算法的目的就是通过给定的数据挖掘出其中的频繁项,进而推导出关联规则,属于模式挖掘的范畴。Apriori算法的核心步骤可以概括为2个过程,1个是连接运算,1个剪枝运算,这具体的过程就不详细说了,如果想要了解的话,请点击我的Apriori算法分析。尽管Apriori算法在一定的程度上看起来非常的好用,但是他并不是十全十美的,首先在选择的类型上就存在限制,他无法照顾到不同类型的频繁项的挖掘。比如说一些稀有项目的挖掘,比如部分奢侈品。换句话说,如果最小支持度设置的过大,就会导致这些稀有的项集就很难被发现,于是我们就想把这个最小支持度值调得足够小不久OK了吗,事实并非这么简单,支持度调小的话,会造成巨大量的频繁项集合候选项的产生,同时会有大量的一些无关的关联规则被推导出来,当然这个问题就是ms-apriori所要解决的主要问题。下面看看ms-apropri给出了怎么样的解决办法。
Ms-Apriori

Ms-Apriori算法采用另外一种办法,既然统一的支持度值不能兼顾所有的情况,那我可以设置多个支持度值啊,每个种类项都有一个最小支持度阈值,然后一个频繁项的最小支持度阈值取其中项集元素中的最小支持度值作为该项集的最小支持度值。这样的话,如果一个频繁项中出现了稀有项集,则这个项集的最小支持度值就会被拉低,如果又有某个项都是出现频率很高的项构成的话,则支持度阈值又会被拉高。当然,如果出现了一个比较难变态的情况就是,频繁项中同时出现了稀有项和普通项,我们可以通过设置SDC支持度差别限制来限制这种情况的发生,使之挖掘的频繁项更加的合理。通过这里的描述,你就可以发现,当mis最小支持度阈值数组的个数只有1个的时候,ms-apriori算法就退化成了Apriori算法了。

其实ms-apriori算法在某些细节的处理上也有对原先的算法做过一定的优化,这里提及2点。
1、每个候选项的支持度值的统计

原先Apriori算法的操作是扫描整个数据集,进行计数的统计,说白了就是线性扫描一遍,效率自不必说,但是如果你自己思考,其实每次的统计的结果一定不会超过他的上一次子集的结果值,因为他是从上一次的计算过程演绎而来的,当前项集的结果是包含了子项集的结果的,所以改进的算法是每次从子集的计数量中再次计算支持度值,具体操作详见后面我的代码实现,效率还是提高了不少。
2、第二是关联规则的推导

找到了所有的频繁项,找出其中的关联规则最笨的办法就是一个个去算置信度,然后输出满足要求条件的规则,但是其实这里面也包含有上条规则中类似的特点,举个例子,如果已经有一条规则,{1}-->{2, 3, 4},代表在存在1的情况下能退出2,3,4,的存在,那么我们就一定能退出{1, 2}--->{3, 4},因为这是后者的情况其实是被包含于前者的情况的,如果你还不能理解,代入置信度计算的公式,分子相同的情况下,{1,2}发生的情况数一定小于或等于{1}的情况,于是整个置信度必定{1,2}的大于{1}的情况。
关联规则挖掘的数据格式

这里再随便说说关联规则的数据格式,也许在很多书中,用于进行Apriori这类算法的测试的数据都是事务型的数据,其实不是的关系表型的数据同样可以做关联规则的挖掘,不过这需要经过一步预处理的方式,让机器能够更好的识别,推荐一种常见的做法,就是采用属性名+属性值的方式,单单用属性值是不够的,因为属性值是在不同的属性中可能会有重,这点在CBA(基于关联规则分类算法)中也提到过一些,具体的可以查阅CBA基于关联规则分类。

MS-Apriori算法的代码实现

算法的测试我采用了2种类型数据做测试一种是事务型数据,一种是非事务型的数据,输入分别如下:

input.txt:

T1 1 2 5
T2 2 4
T3 2 3
T4 1 2 4
T5 1 3
T6 2 3
T7 1 3
T8 1 2 3 5
T9 1 2 3

input2.txt

Rid Age Income Student CreditRating BuysComputer
1 Youth High No Fair No
2 Youth High No Excellent No
3 MiddleAged High No Fair Yes
4 Senior Medium No Fair Yes
5 Senior Low Yes Fair Yes
6 Senior Low Yes Excellent No
7 MiddleAged Low Yes Excellent Yes
8 Youth Medium No Fair No
9 Youth Low Yes Fair Yes
10 Senior Medium Yes Fair Yes
11 Youth Medium Yes Excellent Yes
12 MiddleAged Medium No Excellent Yes
13 MiddleAged High Yes Fair Yes
14 Senior Medium No Excellent No

算法工具类MSAprioriTool.java:

package DataMining_MSApriori;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.text.MessageFormat;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

import DataMining_Apriori.FrequentItem;

/**
 * 基于多支持度的Apriori算法工具类
 *
 * @author lyq
 *
 */
public class MSAprioriTool {
    // 前件判断的结果值,用于关联规则的推导
    public static final int PREFIX_NOT_SUB = -1;
    public static final int PREFIX_EQUAL = 1;
    public static final int PREFIX_IS_SUB = 2;

    // 是否读取的是事务型数据
    private boolean isTransaction;
    // 最大频繁k项集的k值
    private int initFItemNum;
    // 事务数据文件地址
    private String filePath;
    // 最小支持度阈值
    private double minSup;
    // 最小置信度率
    private double minConf;
    // 最大支持度差别阈值
    private double delta;
    // 多项目的最小支持度数,括号中的下标代表的是商品的ID
    private double[] mis;
    // 每个事务中的商品ID
    private ArrayList<String[]> totalGoodsIDs;
    // 关系表数据所转化的事务数据
    private ArrayList<String[]> transactionDatas;
    // 过程中计算出来的所有频繁项集列表
    private ArrayList<FrequentItem> resultItem;
    // 过程中计算出来频繁项集的ID集合
    private ArrayList<String[]> resultItemID;
    // 属性到数字的映射图
    private HashMap<String, Integer> attr2Num;
    // 数字id对应属性的映射图
    private HashMap<Integer, String> num2Attr;
    // 频繁项集所覆盖的id数值
    private Map<String, int[]> fItem2Id;

    /**
     * 事务型数据关联挖掘算法
     *
     * @param filePath
     * @param minConf
     * @param delta
     * @param mis
     * @param isTransaction
     */
    public MSAprioriTool(String filePath, double minConf, double delta,
            double[] mis, boolean isTransaction) {
        this.filePath = filePath;
        this.minConf = minConf;
        this.delta = delta;
        this.mis = mis;
        this.isTransaction = isTransaction;
        this.fItem2Id = new HashMap<>();

        readDataFile();
    }

    /**
     * 非事务型关联挖掘
     *
     * @param filePath
     * @param minConf
     * @param minSup
     * @param isTransaction
     */
    public MSAprioriTool(String filePath, double minConf, double minSup,
            boolean isTransaction) {
        this.filePath = filePath;
        this.minConf = minConf;
        this.minSup = minSup;
        this.isTransaction = isTransaction;
        this.delta = 1.0;
        this.fItem2Id = new HashMap<>();

        readRDBMSData(filePath);
    }

    /**
     * 从文件中读取数据
     */
    private void readDataFile() {
        String[] temp = null;
        ArrayList<String[]> dataArray;

        dataArray = readLine(filePath);
        totalGoodsIDs = new ArrayList<>();

        for (String[] array : dataArray) {
            temp = new String[array.length - 1];
            System.arraycopy(array, 1, temp, 0, array.length - 1);

            // 将事务ID加入列表吧中
            totalGoodsIDs.add(temp);
        }
    }

    /**
     * 从文件中逐行读数据
     *
     * @param filePath
     *            数据文件地址
     * @return
     */
    private ArrayList<String[]> readLine(String filePath) {
        File file = new File(filePath);
        ArrayList<String[]> dataArray = new ArrayList<String[]>();

        try {
            BufferedReader in = new BufferedReader(new FileReader(file));
            String str;
            String[] tempArray;
            while ((str = in.readLine()) != null) {
                tempArray = str.split(" ");
                dataArray.add(tempArray);
            }
            in.close();
        } catch (IOException e) {
            e.getStackTrace();
        }

        return dataArray;
    }

    /**
     * 计算频繁项集
     */
    public void calFItems() {
        FrequentItem fItem;

        computeLink();
        printFItems();

        if (isTransaction) {
            fItem = resultItem.get(resultItem.size() - 1);
            // 取出最后一个频繁项集做关联规则的推导
            System.out.println("最后一个频繁项集做关联规则的推导结果:");
            printAttachRuls(fItem.getIdArray());
        }
    }

    /**
     * 输出频繁项集
     */
    private void printFItems() {
        if (isTransaction) {
            System.out.println("事务型数据频繁项集输出结果:");
        } else {
            System.out.println("非事务(关系)型数据频繁项集输出结果:");
        }

        // 输出频繁项集
        for (int k = 1; k <= initFItemNum; k++) {
            System.out.println("频繁" + k + "项集:");
            for (FrequentItem i : resultItem) {
                if (i.getLength() == k) {
                    System.out.print("{");
                    for (String t : i.getIdArray()) {
                        if (!isTransaction) {
                            // 如果原本是非事务型数据,需要重新做替换
                            t = num2Attr.get(Integer.parseInt(t));
                        }

                        System.out.print(t + ",");
                    }
                    System.out.print("},");
                }
            }
            System.out.println();
        }
    }

    /**
     * 项集进行连接运算
     */
    private void computeLink() {
        // 连接计算的终止数,k项集必须算到k-1子项集为止
        int endNum = 0;
        // 当前已经进行连接运算到几项集,开始时就是1项集
        int currentNum = 1;
        // 商品,1频繁项集映射图
        HashMap<String, FrequentItem> itemMap = new HashMap<>();
        FrequentItem tempItem;
        // 初始列表
        ArrayList<FrequentItem> list = new ArrayList<>();
        // 经过连接运算后产生的结果项集
        resultItem = new ArrayList<>();
        resultItemID = new ArrayList<>();
        // 商品ID的种类
        ArrayList<String> idType = new ArrayList<>();
        for (String[] a : totalGoodsIDs) {
            for (String s : a) {
                if (!idType.contains(s)) {
                    tempItem = new FrequentItem(new String[] { s }, 1);
                    idType.add(s);
                    resultItemID.add(new String[] { s });
                } else {
                    // 支持度计数加1
                    tempItem = itemMap.get(s);
                    tempItem.setCount(tempItem.getCount() + 1);
                }
                itemMap.put(s, tempItem);
            }
        }
        // 将初始频繁项集转入到列表中,以便继续做连接运算
        for (Map.Entry<String, FrequentItem> entry : itemMap.entrySet()) {
            tempItem = entry.getValue();

            // 判断1频繁项集是否满足支持度阈值的条件
            if (judgeFItem(tempItem.getIdArray())) {
                list.add(tempItem);
            }
        }

        // 按照商品ID进行排序,否则连接计算结果将会不一致,将会减少
        Collections.sort(list);
        resultItem.addAll(list);

        String[] array1;
        String[] array2;
        String[] resultArray;
        ArrayList<String> tempIds;
        ArrayList<String[]> resultContainer;
        // 总共要算到endNum项集
        endNum = list.size() - 1;
        initFItemNum = list.size() - 1;

        while (currentNum < endNum) {
            resultContainer = new ArrayList<>();
            for (int i = 0; i < list.size() - 1; i++) {
                tempItem = list.get(i);
                array1 = tempItem.getIdArray();

                for (int j = i + 1; j < list.size(); j++) {
                    tempIds = new ArrayList<>();
                    array2 = list.get(j).getIdArray();

                    for (int k = 0; k < array1.length; k++) {
                        // 如果对应位置上的值相等的时候,只取其中一个值,做了一个连接删除操作
                        if (array1[k].equals(array2[k])) {
                            tempIds.add(array1[k]);
                        } else {
                            tempIds.add(array1[k]);
                            tempIds.add(array2[k]);
                        }
                    }

                    resultArray = new String[tempIds.size()];
                    tempIds.toArray(resultArray);

                    boolean isContain = false;
                    // 过滤不符合条件的的ID数组,包括重复的和长度不符合要求的
                    if (resultArray.length == (array1.length + 1)) {
                        isContain = isIDArrayContains(resultContainer,
                                resultArray);
                        if (!isContain) {
                            resultContainer.add(resultArray);
                        }
                    }
                }
            }

            // 做频繁项集的剪枝处理,必须保证新的频繁项集的子项集也必须是频繁项集
            list = cutItem(resultContainer);
            currentNum++;
        }
    }

    /**
     * 对频繁项集做剪枝步骤,必须保证新的频繁项集的子项集也必须是频繁项集
     */
    private ArrayList<FrequentItem> cutItem(ArrayList<String[]> resultIds) {
        String[] temp;
        // 忽略的索引位置,以此构建子集
        int igNoreIndex = 0;
        FrequentItem tempItem;
        // 剪枝生成新的频繁项集
        ArrayList<FrequentItem> newItem = new ArrayList<>();
        // 不符合要求的id
        ArrayList<String[]> deleteIdArray = new ArrayList<>();
        // 子项集是否也为频繁子项集
        boolean isContain = true;

        for (String[] array : resultIds) {
            // 列举出其中的一个个的子项集,判断存在于频繁项集列表中
            temp = new String[array.length - 1];
            for (igNoreIndex = 0; igNoreIndex < array.length; igNoreIndex++) {
                isContain = true;
                for (int j = 0, k = 0; j < array.length; j++) {
                    if (j != igNoreIndex) {
                        temp[k] = array[j];
                        k++;
                    }
                }

                if (!isIDArrayContains(resultItemID, temp)) {
                    isContain = false;
                    break;
                }
            }

            if (!isContain) {
                deleteIdArray.add(array);
            }
        }

        // 移除不符合条件的ID组合
        resultIds.removeAll(deleteIdArray);

        // 移除支持度计数不够的id集合
        int tempCount = 0;
        boolean isSatisfied = false;
        for (String[] array : resultIds) {
            isSatisfied = judgeFItem(array);

            // 如果此频繁项集满足多支持度阈值限制条件和支持度差别限制条件,则添加入结果集中
            if (isSatisfied) {
                tempItem = new FrequentItem(array, tempCount);
                newItem.add(tempItem);
                resultItemID.add(array);
                resultItem.add(tempItem);
            }
        }

        return newItem;
    }

    /**
     * 判断列表结果中是否已经包含此数组
     *
     * @param container
     *            ID数组容器
     * @param array
     *            待比较数组
     * @return
     */
    private boolean isIDArrayContains(ArrayList<String[]> container,
            String[] array) {
        boolean isContain = true;
        if (container.size() == 0) {
            isContain = false;
            return isContain;
        }

        for (String[] s : container) {
            // 比较的视乎必须保证长度一样
            if (s.length != array.length) {
                continue;
            }

            isContain = true;
            for (int i = 0; i < s.length; i++) {
                // 只要有一个id不等,就算不相等
                if (s[i] != array[i]) {
                    isContain = false;
                    break;
                }
            }

            // 如果已经判断是包含在容器中时,直接退出
            if (isContain) {
                break;
            }
        }

        return isContain;
    }

    /**
     * 判断一个频繁项集是否满足条件
     *
     * @param frequentItem
     *            待判断频繁项集
     * @return
     */
    private boolean judgeFItem(String[] frequentItem) {
        boolean isSatisfied = true;
        int id;
        int count;
        double tempMinSup;
        // 最小的支持度阈值
        double minMis = Integer.MAX_VALUE;
        // 最大的支持度阈值
        double maxMis = -Integer.MAX_VALUE;

        // 如果是事务型数据,用mis数组判断,如果不是统一用同样的最小支持度阈值判断
        if (isTransaction) {
            // 寻找频繁项集中的最小支持度阈值
            for (int i = 0; i < frequentItem.length; i++) {
                id = i + 1;

                if (mis[id] < minMis) {
                    minMis = mis[id];
                }

                if (mis[id] > maxMis) {
                    maxMis = mis[id];
                }
            }
        } else {
            minMis = minSup;
            maxMis = minSup;
        }

        count = calSupportCount(frequentItem);
        tempMinSup = 1.0 * count / totalGoodsIDs.size();
        // 判断频繁项集的支持度阈值是否超过最小的支持度阈值
        if (tempMinSup < minMis) {
            isSatisfied = false;
        }

        // 如果误差超过了最大支持度差别,也算不满足条件
        if (Math.abs(maxMis - minMis) > delta) {
            isSatisfied = false;
        }

        return isSatisfied;
    }

    /**
     * 统计候选频繁项集的支持度数,利用他的子集进行技术,无须扫描整个数据集
     *
     * @param frequentItem
     *            待计算频繁项集
     * @return
     */
    private int calSupportCount(String[] frequentItem) {
        int count = 0;
        int[] ids;
        String key;
        String[] array;
        ArrayList<Integer> newIds;

        key = "";
        for (int i = 1; i < frequentItem.length; i++) {
            key += frequentItem[i];
        }

        newIds = new ArrayList<>();
        // 找出所属的事务ID
        ids = fItem2Id.get(key);

        // 如果没有找到子项集的事务id,则全盘扫描数据集
        if (ids == null || ids.length == 0) {
            for (int j = 0; j < totalGoodsIDs.size(); j++) {
                array = totalGoodsIDs.get(j);
                if (isStrArrayContain(array, frequentItem)) {
                    count++;
                    newIds.add(j);
                }
            }
        } else {
            for (int index : ids) {
                array = totalGoodsIDs.get(index);
                if (isStrArrayContain(array, frequentItem)) {
                    count++;
                    newIds.add(index);
                }
            }
        }

        ids = new int[count];
        for (int i = 0; i < ids.length; i++) {
            ids[i] = newIds.get(i);
        }

        key = frequentItem[0] + key;
        // 将所求值存入图中,便于下次的计数
        fItem2Id.put(key, ids);

        return count;
    }

    /**
     * 根据给定的频繁项集输出关联规则
     *
     * @param frequentItems
     *            频繁项集
     */
    public void printAttachRuls(String[] frequentItem) {
        // 关联规则前件,后件对
        Map<ArrayList<String>, ArrayList<String>> rules;
        // 前件搜索历史
        Map<ArrayList<String>, ArrayList<String>> searchHistory;
        ArrayList<String> prefix;
        ArrayList<String> suffix;

        rules = new HashMap<ArrayList<String>, ArrayList<String>>();
        searchHistory = new HashMap<>();

        for (int i = 0; i < frequentItem.length; i++) {
            suffix = new ArrayList<>();
            for (int j = 0; j < frequentItem.length; j++) {
                suffix.add(frequentItem[j]);
            }
            prefix = new ArrayList<>();

            recusiveFindRules(rules, searchHistory, prefix, suffix);
        }

        // 依次输出找到的关联规则
        for (Map.Entry<ArrayList<String>, ArrayList<String>> entry : rules
                .entrySet()) {
            prefix = entry.getKey();
            suffix = entry.getValue();

            printRuleDetail(prefix, suffix);
        }
    }

    /**
     * 根据前件后件,输出关联规则
     *
     * @param prefix
     * @param suffix
     */
    private void printRuleDetail(ArrayList<String> prefix,
            ArrayList<String> suffix) {
        // {A}-->{B}的意思为在A的情况下发生B的概率
        System.out.print("{");
        for (String s : prefix) {
            System.out.print(s + ", ");
        }
        System.out.print("}-->");
        System.out.print("{");
        for (String s : suffix) {
            System.out.print(s + ", ");
        }
        System.out.println("}");
    }

    /**
     * 递归扩展关联规则解
     *
     * @param rules
     *            关联规则结果集
     * @param history
     *            前件搜索历史
     * @param prefix
     *            关联规则前件
     * @param suffix
     *            关联规则后件
     */
    private void recusiveFindRules(
            Map<ArrayList<String>, ArrayList<String>> rules,
            Map<ArrayList<String>, ArrayList<String>> history,
            ArrayList<String> prefix, ArrayList<String> suffix) {
        int count1;
        int count2;
        int compareResult;
        // 置信度大小
        double conf;
        String[] temp1;
        String[] temp2;
        ArrayList<String> copyPrefix;
        ArrayList<String> copySuffix;

        // 如果后件只有1个,则函数返回
        if (suffix.size() == 1) {
            return;
        }

        for (String s : suffix) {
            count1 = 0;
            count2 = 0;

            copyPrefix = (ArrayList<String>) prefix.clone();
            copyPrefix.add(s);

            copySuffix = (ArrayList<String>) suffix.clone();
            // 将拷贝的后件移除添加的一项
            copySuffix.remove(s);

            compareResult = isSubSetInRules(history, copyPrefix);
            if (compareResult == PREFIX_EQUAL) {
                // 如果曾经已经被搜索过,则跳过
                continue;
            }

            // 判断是否为子集,如果是子集则无需计算
            compareResult = isSubSetInRules(rules, copyPrefix);
            if (compareResult == PREFIX_IS_SUB) {
                rules.put(copyPrefix, copySuffix);
                // 加入到搜索历史中
                history.put(copyPrefix, copySuffix);
                recusiveFindRules(rules, history, copyPrefix, copySuffix);
                continue;
            }

            // 暂时合并为总的集合
            copySuffix.addAll(copyPrefix);
            temp1 = new String[copyPrefix.size()];
            temp2 = new String[copySuffix.size()];
            copyPrefix.toArray(temp1);
            copySuffix.toArray(temp2);
            // 之后再次移除之前天剑的前件
            copySuffix.removeAll(copyPrefix);

            for (String[] a : totalGoodsIDs) {
                if (isStrArrayContain(a, temp1)) {
                    count1++;

                    // 在group1的条件下,统计group2的事件发生次数
                    if (isStrArrayContain(a, temp2)) {
                        count2++;
                    }
                }
            }

            conf = 1.0 * count2 / count1;
            if (conf > minConf) {
                // 设置此前件条件下,能导出关联规则
                rules.put(copyPrefix, copySuffix);
            }

            // 加入到搜索历史中
            history.put(copyPrefix, copySuffix);
            recusiveFindRules(rules, history, copyPrefix, copySuffix);
        }
    }

    /**
     * 判断当前的前件是否会关联规则的子集
     *
     * @param rules
     *            当前已经判断出的关联规则
     * @param prefix
     *            待判断的前件
     * @return
     */
    private int isSubSetInRules(
            Map<ArrayList<String>, ArrayList<String>> rules,
            ArrayList<String> prefix) {
        int result = PREFIX_NOT_SUB;
        String[] temp1;
        String[] temp2;
        ArrayList<String> tempPrefix;

        for (Map.Entry<ArrayList<String>, ArrayList<String>> entry : rules
                .entrySet()) {
            tempPrefix = entry.getKey();

            temp1 = new String[tempPrefix.size()];
            temp2 = new String[prefix.size()];

            tempPrefix.toArray(temp1);
            prefix.toArray(temp2);

            // 判断当前构造的前件是否已经是存在前件的子集
            if (isStrArrayContain(temp2, temp1)) {
                if (temp2.length == temp1.length) {
                    result = PREFIX_EQUAL;
                } else {
                    result = PREFIX_IS_SUB;
                }
            }

            if (result == PREFIX_EQUAL) {
                break;
            }
        }

        return result;
    }

    /**
     * 数组array2是否包含于array1中,不需要完全一样
     *
     * @param array1
     * @param array2
     * @return
     */
    private boolean isStrArrayContain(String[] array1, String[] array2) {
        boolean isContain = true;
        for (String s2 : array2) {
            isContain = false;
            for (String s1 : array1) {
                // 只要s2字符存在于array1中,这个字符就算包含在array1中
                if (s2.equals(s1)) {
                    isContain = true;
                    break;
                }
            }

            // 一旦发现不包含的字符,则array2数组不包含于array1中
            if (!isContain) {
                break;
            }
        }

        return isContain;
    }

    /**
     * 读关系表中的数据,并转化为事务数据
     *
     * @param filePath
     */
    private void readRDBMSData(String filePath) {
        String str;
        // 属性名称行
        String[] attrNames = null;
        String[] temp;
        String[] newRecord;
        ArrayList<String[]> datas = null;

        datas = readLine(filePath);

        // 获取首行
        attrNames = datas.get(0);
        this.transactionDatas = new ArrayList<>();

        // 去除首行数据
        for (int i = 1; i < datas.size(); i++) {
            temp = datas.get(i);

            // 过滤掉首列id列
            for (int j = 1; j < temp.length; j++) {
                str = "";
                // 采用属性名+属性值的形式避免数据的重复
                str = attrNames[j] + ":" + temp[j];
                temp[j] = str;
            }

            newRecord = new String[attrNames.length - 1];
            System.arraycopy(temp, 1, newRecord, 0, attrNames.length - 1);
            this.transactionDatas.add(newRecord);
        }

        attributeReplace();
        // 将事务数转到totalGoodsID中做统一处理
        this.totalGoodsIDs = transactionDatas;
    }

    /**
     * 属性值的替换,替换成数字的形式,以便进行频繁项的挖掘
     */
    private void attributeReplace() {
        int currentValue = 1;
        String s;
        // 属性名到数字的映射图
        attr2Num = new HashMap<>();
        num2Attr = new HashMap<>();

        // 按照1列列的方式来,从左往右边扫描,跳过列名称行和id列
        for (int j = 0; j < transactionDatas.get(0).length; j++) {
            for (int i = 0; i < transactionDatas.size(); i++) {
                s = transactionDatas.get(i)[j];

                if (!attr2Num.containsKey(s)) {
                    attr2Num.put(s, currentValue);
                    num2Attr.put(currentValue, s);

                    transactionDatas.get(i)[j] = currentValue + "";
                    currentValue++;
                } else {
                    transactionDatas.get(i)[j] = attr2Num.get(s) + "";
                }
            }
        }
    }
}


频繁项集类FrequentItem.java:

package DataMining_MSApriori;

/**
 * 频繁项集
 *
 * @author lyq
 *
 */
public class FrequentItem implements Comparable<FrequentItem>{
    // 频繁项集的集合ID
    private String[] idArray;
    // 频繁项集的支持度计数
    private int count;
    //频繁项集的长度,1项集或是2项集,亦或是3项集
    private int length;
    
    public FrequentItem(String[] idArray, int count){
        this.idArray = idArray;
        this.count = count;
        length = idArray.length;
    }

    public String[] getIdArray() {
        return idArray;
    }

    public void setIdArray(String[] idArray) {
        this.idArray = idArray;
    }

    public int getCount() {
        return count;
    }

    public void setCount(int count) {
        this.count = count;
    }

    public int getLength() {
        return length;
    }

    public void setLength(int length) {
        this.length = length;
    }

    @Override
    public int compareTo(FrequentItem o) {
        // TODO Auto-generated method stub
        Integer int1 = Integer.parseInt(this.getIdArray()[0]);
        Integer int2 = Integer.parseInt(o.getIdArray()[0]);
        
        return int1.compareTo(int2);
    }
    
}

测试类Client.java:

package DataMining_MSApriori;

/**
 * 基于多支持度的Apriori算法测试类
 * @author lyq
 *
 */
public class Client {
    public static void main(String[] args){
        //是否是事务型数据
        boolean isTransaction;
        //测试数据文件地址
        String filePath = "C:\\Users\\lyq\\Desktop\\icon\\input.txt";
        //关系表型数据文件地址
        String tableFilePath = "C:\\Users\\lyq\\Desktop\\icon\\input2.txt";
        //最小支持度阈值
        double minSup;
        // 最小置信度率
        double minConf;
        //最大支持度差别阈值
        double delta;
        //多项目的最小支持度数,括号中的下标代表的是商品的ID
        double[] mis;
        //msApriori算法工具类
        MSAprioriTool tool;
        
        //为了测试的方便,取一个偏低的置信度值0.3
        minConf = 0.3;
        minSup = 0.1;
        delta = 0.5;
        //每项的支持度率都默认为0.1,第一项不使用
        mis = new double[]{-1, 0.1, 0.1, 0.1, 0.1, 0.1};
        isTransaction = true;
        
        isTransaction = true;
        tool = new MSAprioriTool(filePath, minConf, delta, mis, isTransaction);
        tool.calFItems();
        System.out.println();
        
        isTransaction = false;
        //重新初始化数据
        tool = new MSAprioriTool(tableFilePath, minConf, minSup, isTransaction);
        tool.calFItems();
    }    
}

算法输出(输出的内容有点多):

事务型数据频繁项集输出结果:
频繁1项集:
{1,},{2,},{3,},{4,},{5,},
频繁2项集:
{1,2,},{1,3,},{1,4,},{1,5,},{2,3,},{2,4,},{2,5,},{3,5,},
频繁3项集:
{1,2,3,},{1,2,4,},{1,2,5,},{1,3,5,},{2,3,5,},
频繁4项集:
{1,2,3,5,},
最后一个频繁项集做关联规则的推导结果:
{2, 5, }-->{1, 3, }
{5, }-->{1, 2, 3, }
{3, 5, }-->{1, 2, }
{1, 5, }-->{2, 3, }
{2, 3, 5, }-->{1, }
{1, 2, 5, }-->{3, }
{1, 3, 5, }-->{2, }
{1, 2, 3, }-->{5, }

非事务(关系)型数据频繁项集输出结果:
频繁1项集:
{Age:Youth,},{Age:MiddleAged,},{Age:Senior,},{Income:High,},{Income:Medium,},{Income:Low,},{Student:No,},{Student:Yes,},{CreditRating:Fair,},{CreditRating:Excellent,},{BuysComputer:No,},{BuysComputer:Yes,},
频繁2项集:
{Age:Youth,Income:High,},{Age:Youth,Income:Medium,},{Age:Youth,Student:No,},{Age:Youth,Student:Yes,},{Age:Youth,CreditRating:Fair,},{Age:Youth,CreditRating:Excellent,},{Age:Youth,BuysComputer:No,},{Age:Youth,BuysComputer:Yes,},{Age:MiddleAged,Income:High,},{Age:MiddleAged,Student:No,},{Age:MiddleAged,Student:Yes,},{Age:MiddleAged,CreditRating:Fair,},{Age:MiddleAged,CreditRating:Excellent,},{Age:MiddleAged,BuysComputer:Yes,},{Age:Senior,Income:Medium,},{Age:Senior,Income:Low,},{Age:Senior,Student:No,},{Age:Senior,Student:Yes,},{Age:Senior,CreditRating:Fair,},{Age:Senior,CreditRating:Excellent,},{Age:Senior,BuysComputer:No,},{Age:Senior,BuysComputer:Yes,},{Income:High,Student:No,},{Income:High,CreditRating:Fair,},{Income:High,BuysComputer:No,},{Income:High,BuysComputer:Yes,},{Income:Medium,Student:No,},{Income:Medium,Student:Yes,},{Income:Medium,CreditRating:Fair,},{Income:Medium,CreditRating:Excellent,},{Income:Medium,BuysComputer:No,},{Income:Medium,BuysComputer:Yes,},{Income:Low,Student:Yes,},{Income:Low,CreditRating:Fair,},{Income:Low,CreditRating:Excellent,},{Income:Low,BuysComputer:Yes,},{Student:No,CreditRating:Fair,},{Student:No,CreditRating:Excellent,},{Student:No,BuysComputer:No,},{Student:No,BuysComputer:Yes,},{Student:Yes,CreditRating:Fair,},{Student:Yes,CreditRating:Excellent,},{Student:Yes,BuysComputer:Yes,},{CreditRating:Fair,BuysComputer:No,},{CreditRating:Fair,BuysComputer:Yes,},{CreditRating:Excellent,BuysComputer:No,},{CreditRating:Excellent,BuysComputer:Yes,},
频繁3项集:
{Age:Youth,Income:High,Student:No,},{Age:Youth,Income:High,BuysComputer:No,},{Age:Youth,Student:No,CreditRating:Fair,},{Age:Youth,Student:No,BuysComputer:No,},{Age:Youth,Student:Yes,BuysComputer:Yes,},{Age:Youth,CreditRating:Fair,BuysComputer:No,},{Age:MiddleAged,Income:High,CreditRating:Fair,},{Age:MiddleAged,Income:High,BuysComputer:Yes,},{Age:MiddleAged,Student:No,BuysComputer:Yes,},{Age:MiddleAged,Student:Yes,BuysComputer:Yes,},{Age:MiddleAged,CreditRating:Fair,BuysComputer:Yes,},{Age:MiddleAged,CreditRating:Excellent,BuysComputer:Yes,},{Age:Senior,Income:Medium,Student:No,},{Age:Senior,Income:Medium,CreditRating:Fair,},{Age:Senior,Income:Medium,BuysComputer:Yes,},{Age:Senior,Income:Low,Student:Yes,},{Age:Senior,Student:Yes,CreditRating:Fair,},{Age:Senior,Student:Yes,BuysComputer:Yes,},{Age:Senior,CreditRating:Fair,BuysComputer:Yes,},{Age:Senior,CreditRating:Excellent,BuysComputer:No,},{Income:High,Student:No,CreditRating:Fair,},{Income:High,Student:No,BuysComputer:No,},{Income:High,CreditRating:Fair,BuysComputer:Yes,},{Income:Medium,Student:No,CreditRating:Fair,},{Income:Medium,Student:No,CreditRating:Excellent,},{Income:Medium,Student:No,BuysComputer:No,},{Income:Medium,Student:No,BuysComputer:Yes,},{Income:Medium,Student:Yes,BuysComputer:Yes,},{Income:Medium,CreditRating:Fair,BuysComputer:Yes,},{Income:Medium,CreditRating:Excellent,BuysComputer:Yes,},{Income:Low,Student:Yes,CreditRating:Fair,},{Income:Low,Student:Yes,CreditRating:Excellent,},{Income:Low,Student:Yes,BuysComputer:Yes,},{Income:Low,CreditRating:Fair,BuysComputer:Yes,},{Student:No,CreditRating:Fair,BuysComputer:No,},{Student:No,CreditRating:Fair,BuysComputer:Yes,},{Student:No,CreditRating:Excellent,BuysComputer:No,},{Student:Yes,CreditRating:Fair,BuysComputer:Yes,},{Student:Yes,CreditRating:Excellent,BuysComputer:Yes,},
频繁4项集:
{Age:Youth,Income:High,Student:No,BuysComputer:No,},{Age:Youth,Student:No,CreditRating:Fair,BuysComputer:No,},{Age:MiddleAged,Income:High,CreditRating:Fair,BuysComputer:Yes,},{Age:Senior,Income:Medium,CreditRating:Fair,BuysComputer:Yes,},{Age:Senior,Student:Yes,CreditRating:Fair,BuysComputer:Yes,},{Income:Low,Student:Yes,CreditRating:Fair,BuysComputer:Yes,},
频繁5项集:

频繁6项集:

频繁7项集:

频繁8项集:

频繁9项集:

频繁10项集:

频繁11项集:


参考文献:刘兵.<<Web数据挖掘>> 第一部分.第二章.关联规则和序列模式

我的数据挖掘算法库:https://github.com/linyiqun/DataMiningAlgorithm

我的算法库:https://github.com/linyiqun/lyq-algorithms-lib
作者:Androidlushangderen 发表于2015/4/16 22:42:53 原文链接
阅读:594 评论:0 查看评论
多维空间分割树--KD树
2015年4月10日 21:39
算法介绍

KD树的全称为k-Dimension Tree的简称,是一种分割K维空间的数据结构,主要应用于关键信息的搜索。为什么说是K维的呢,因为这时候的空间不仅仅是2维度的,他可能是3维,4维度的或者是更多。我们举个例子,如果是二维的空间,对于其中的空间进行分割的就是一条条的分割线,比如说下面这个样子。


如果是3维的呢,那么分割的媒介就是一个平面了,下面是3维空间的分割


这就稍稍有点抽象了,如果是3维以上,我们把这样的分割媒介可以统统叫做超平面 。那么KD树算法有什么特别之处呢,还有他与K-NN算法之间又有什么关系呢,这将是下面所将要描述的。
KNN

KNN就是K最近邻算法,他是一个分类算法,因为算法简单,分类效果也还不错,也被许多人使用着,算法的原理就是选出与给定数据最近的k个数据,然后根据k个数据中占比最多的分类作为测试数据的最终分类。图示如下:


算法固然简单,但是其中通过逐个去比较的办法求得最近的k个数据点,效率太低,时间复杂度会随着训练数据数量的增多而线性增长。于是就需要一种更加高效快速的办法来找到所给查询点的最近邻,而KD树就是其中的一种行之有效的办法。但是不管是KNN算法还是KD树算法,他们都属于相似性查询中的K近邻查询的范畴。在相似性查询算法中还有一类查询是范围查询,就是给定距离阈值和查询点,dbscan算法可以说是一种范围查询,基于给定点进行局部密度范围的搜索。想要了解KNN算法或者是Dbscan算法的可以点击我的K-最近邻算法和Dbscan基于密度的聚类算法。
KD-Tree

在KNN算法中,针对查询点数据的查找采用的是线性扫描的方法,说白了就是暴力比较,KD树在这方面用了二分划分的思想,将数据进行逐层空间上的划分,大大的提高了查询的速度,可以理解为一个变形的二分搜索时间,只不过这个适用到了多维空间的层次上。下面是二维空间的情况下,数据的划分结果:


现在看到的图在逻辑上的意思就是一棵完整的二叉树,虚线上的点是叶子节点。
KD树的算法原理

KD树的算法的实现原理并不是那么好理解,主要分为树的构建和基于KD树进行最近邻的查询2个过程,后者比前者更加复杂。当然,要想实现最近点的查询,首先我们得先理解KD树的构建过程。下面是KD树节点的定义,摘自百度百科:

域名
    
数据类型
    
描述
Node-data
    
数据矢量
    
数据集中某个数据点,是n维矢量(这里也就是k维)
Range
    
空间矢量
    
该节点所代表的空间范围
split
    
整数
    
垂直于分割超平面的方向轴序号
Left
    
k-d树
    
由位于该节点分割超平面左子空间内所有数据点所构成的k-d树
Right
    
k-d树
    
由位于该节点分割超平面右子空间内所有数据点所构成的k-d树
parent
    
k-d树
    
父节点

变量还是有点多的,节点中有孩子节点和父亲节点,所以必然会用到递归。KD树的构建算法过程如下(这里假设构建的是2维KD树,简单易懂,后续同上):

1、首先将数据节点坐标中的X坐标和Y坐标进行方差计算,选出其中方差大的,作为分割线的方向,就是接下来将要创建点的split值。

2、将上面的数据点按照分割方向的维度进行排序,选出其中的中位数的点作为数据矢量,就是要分割的分割点。

3、同时进行空间矢量的再次划分,要在父亲节点的空间范围内再进行子分割,就是Range变量,不理解的话,可以阅读我的代码加以理解。

4、对剩余的节点进行左侧空间和右侧空间的分割,进行左孩子和右孩子节点的分割。

5、分割的终点是最终只剩下1个数据点或一侧没有数据点的情况。

在这里举个例子,给定6个数据点:

(2,3),(5,4),(9,6),(4,7),(8,1),(7,2)

对这6个数据点进行最终的KD树的构建效果图如下,左边是实际分割效果,右边是所构成的KD树:

       

x,y代表的是当前节点的分割方向。读者可以进行手动计算并验证,本人不再加以描述。

KD树构建完毕,之后就是对于给定查询点数据,进行此空间数据的最近数据点,大致过程如下:

1、从根节点开始,从上往下,根据分割方向,在对应维度的坐标点上,进行树的顺序查找,比如给定(3,1),首先来到(7,2),因为根节点的划分方向为X,因此只比较X坐标的划分,因为3<7,所以往左边走,后续的节点同样的道理,最终到达叶子节点为止。

2、当然以这种方式找到的点并不一定是最近的,也许在父节点的另外一个空间内存在更近的点呢,或者说另外一种情况,当前的叶子节点的父亲节点比叶子节点离查询点更近呢,这也是有可能的。

3、所以这个过程会有回溯的步骤,回溯到父节点时候,需要做2点,第一要和父节点比,谁里查询点更近,如果父节点更近,则更改当前找到的最近点,第二以查询点为圆心,当前查询点与最近点的距离为半径画个圆,判断是否与父节点的分割线是否相交,如果相交,则说明有存在父节点另外的孩子空间存在于查询距离更短的点,然后进行父节点空间的又一次深度优先遍历。在局部的遍历查找完毕,在于当前的最近点做比较,比较完之后,继续往上回溯。

下面给出基于上面例子的2个测试例子,查询点为(2.1,3.1)和(2,4.5),前者的例子用于理解一般过程,后面的测试点真正诠释了递归,回溯的过程。先看下(2.1,3.1)的情况:


因为没有碰到任何的父节点分割边界,所以就一直回溯到根节点,最近的节点就是叶子节点(2,3).下面(2,4.5)是需要重点理解的例子,中间出现了一次回溯,和一次再搜索:


在第一次回溯的时候,发现与y=4碰撞到了,进行了又一次的搜寻,结果发现存在更近的点,因此结果变化了,具体的过程可以详细查看百度百科-kd树对这个例子的描述。
算法的代码实现

许多资料都是只有理论,没有实践,本人基于上面的测试例子,自己写了一个,效果还行,基本上实现了上述的过程,不过貌似Range这个变量没有表现出用途来,可以我一番设计,例子完全是上面的例子,输入数据就不放出来了,就是给定的6个坐标点。

坐标点类Point.java:

package DataMining_KDTree;

/**
 * 坐标点类
 *
 * @author lyq
 *
 */
public class Point{
    // 坐标点横坐标
    Double x;
    // 坐标点纵坐标
    Double y;

    public Point(double x, double y){
        this.x = x;
        this.y = y;
    }
    
    public Point(String x, String y) {
        this.x = (Double.parseDouble(x));
        this.y = (Double.parseDouble(y));
    }

    /**
     * 计算当前点与制定点之间的欧式距离
     *
     * @param p
     *            待计算聚类的p点
     * @return
     */
    public double ouDistance(Point p) {
        double distance = 0;

        distance = (this.x - p.x) * (this.x - p.x) + (this.y - p.y)
                * (this.y - p.y);
        distance = Math.sqrt(distance);

        return distance;
    }

    /**
     * 判断2个坐标点是否为用个坐标点
     *
     * @param p
     *            待比较坐标点
     * @return
     */
    public boolean isTheSame(Point p) {
        boolean isSamed = false;

        if (this.x == p.x && this.y == p.y) {
            isSamed = true;
        }

        return isSamed;
    }
}

空间矢量类Range.java:

package DataMining_KDTree;

/**
 * 空间矢量,表示所代表的空间范围
 *
 * @author lyq
 *
 */
public class Range {
    // 边界左边界
    double left;
    // 边界右边界
    double right;
    // 边界上边界
    double top;
    // 边界下边界
    double bottom;

    public Range() {
        this.left = -Integer.MAX_VALUE;
        this.right = Integer.MAX_VALUE;
        this.top = Integer.MAX_VALUE;
        this.bottom = -Integer.MAX_VALUE;
    }

    public Range(int left, int right, int top, int bottom) {
        this.left = left;
        this.right = right;
        this.top = top;
        this.bottom = bottom;
    }

    /**
     * 空间矢量进行并操作
     *
     * @param range
     * @return
     */
    public Range crossOperation(Range r) {
        Range range = new Range();

        // 取靠近右侧的左边界
        if (r.left > this.left) {
            range.left = r.left;
        } else {
            range.left = this.left;
        }

        // 取靠近左侧的右边界
        if (r.right < this.right) {
            range.right = r.right;
        } else {
            range.right = this.right;
        }

        // 取靠近下侧的上边界
        if (r.top < this.top) {
            range.top = r.top;
        } else {
            range.top = this.top;
        }

        // 取靠近上侧的下边界
        if (r.bottom > this.bottom) {
            range.bottom = r.bottom;
        } else {
            range.bottom = this.bottom;
        }

        return range;
    }

    /**
     * 根据坐标点分割方向确定左侧空间矢量
     *
     * @param p
     *            数据矢量
     * @param dir
     *            分割方向
     * @return
     */
    public static Range initLeftRange(Point p, int dir) {
        Range range = new Range();

        if (dir == KDTreeTool.DIRECTION_X) {
            range.right = p.x;
        } else {
            range.bottom = p.y;
        }

        return range;
    }

    /**
     * 根据坐标点分割方向确定右侧空间矢量
     *
     * @param p
     *            数据矢量
     * @param dir
     *            分割方向
     * @return
     */
    public static Range initRightRange(Point p, int dir) {
        Range range = new Range();

        if (dir == KDTreeTool.DIRECTION_X) {
            range.left = p.x;
        } else {
            range.top = p.y;
        }

        return range;
    }
}

KD树节点类TreeNode.java:

package DataMining_KDTree;

/**
 * KD树节点
 * @author lyq
 *
 */
public class TreeNode {
    //数据矢量
    Point nodeData;
    //分割平面的分割线
    int spilt;
    //空间矢量,该节点所表示的空间范围
    Range range;
    //父节点
    TreeNode parentNode;
    //位于分割超平面左侧的孩子节点
    TreeNode leftNode;
    //位于分割超平面右侧的孩子节点
    TreeNode rightNode;
    //节点是否被访问过,用于回溯时使用
    boolean isVisited;
    
    public TreeNode(){
        this.isVisited = false;
    }
}

算法封装类KDTreeTool.java:

package DataMining_KDTree;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.Stack;

/**
 * KD树-k维空间关键数据检索算法工具类
 *
 * @author lyq
 *
 */
public class KDTreeTool {
    // 空间平面的方向
    public static final int DIRECTION_X = 0;
    public static final int DIRECTION_Y = 1;

    // 输入的测试数据坐标点文件
    private String filePath;
    // 原始所有数据点数据
    private ArrayList<Point> totalDatas;
    // KD树根节点
    private TreeNode rootNode;

    public KDTreeTool(String filePath) {
        this.filePath = filePath;

        readDataFile();
    }

    /**
     * 从文件中读取数据
     */
    private void readDataFile() {
        File file = new File(filePath);
        ArrayList<String[]> dataArray = new ArrayList<String[]>();

        try {
            BufferedReader in = new BufferedReader(new FileReader(file));
            String str;
            String[] tempArray;
            while ((str = in.readLine()) != null) {
                tempArray = str.split(" ");
                dataArray.add(tempArray);
            }
            in.close();
        } catch (IOException e) {
            e.getStackTrace();
        }

        Point p;
        totalDatas = new ArrayList<>();
        for (String[] array : dataArray) {
            p = new Point(array[0], array[1]);
            totalDatas.add(p);
        }
    }

    /**
     * 创建KD树
     *
     * @return
     */
    public TreeNode createKDTree() {
        ArrayList<Point> copyDatas;

        rootNode = new TreeNode();
        // 根据节点开始时所表示的空间时无限大的
        rootNode.range = new Range();
        copyDatas = (ArrayList<Point>) totalDatas.clone();
        recusiveConstructNode(rootNode, copyDatas);

        return rootNode;
    }

    /**
     * 递归进行KD树的构造
     *
     * @param node
     *            当前正在构造的节点
     * @param datas
     *            该节点对应的正在处理的数据
     * @return
     */
    private void recusiveConstructNode(TreeNode node, ArrayList<Point> datas) {
        int direction = 0;
        ArrayList<Point> leftSideDatas;
        ArrayList<Point> rightSideDatas;
        Point p;
        TreeNode leftNode;
        TreeNode rightNode;
        Range range;
        Range range2;

        // 如果划分的数据点集合只有1个数据,则不再划分
        if (datas.size() == 1) {
            node.nodeData = datas.get(0);
            return;
        }

        // 首先在当前的数据点集合中进行分割方向的选择
        direction = selectSplitDrc(datas);
        // 根据方向取出中位数点作为数据矢量
        p = getMiddlePoint(datas, direction);

        node.spilt = direction;
        node.nodeData = p;

        leftSideDatas = getLeftSideDatas(datas, p, direction);
        datas.removeAll(leftSideDatas);
        // 还要去掉自身
        datas.remove(p);
        rightSideDatas = datas;

        if (leftSideDatas.size() > 0) {
            leftNode = new TreeNode();
            leftNode.parentNode = node;
            range2 = Range.initLeftRange(p, direction);
            // 获取父节点的空间矢量,进行交集运算做范围拆分
            range = node.range.crossOperation(range2);
            leftNode.range = range;

            node.leftNode = leftNode;
            recusiveConstructNode(leftNode, leftSideDatas);
        }

        if (rightSideDatas.size() > 0) {
            rightNode = new TreeNode();
            rightNode.parentNode = node;
            range2 = Range.initRightRange(p, direction);
            // 获取父节点的空间矢量,进行交集运算做范围拆分
            range = node.range.crossOperation(range2);
            rightNode.range = range;

            node.rightNode = rightNode;
            recusiveConstructNode(rightNode, rightSideDatas);
        }
    }

    /**
     * 搜索出给定数据点的最近点
     *
     * @param p
     *            待比较坐标点
     */
    public Point searchNearestData(Point p) {
        // 节点距离给定数据点的距离
        TreeNode nearestNode = null;
        // 用栈记录遍历过的节点
        Stack<TreeNode> stackNodes;

        stackNodes = new Stack<>();
        findedNearestLeafNode(p, rootNode, stackNodes);

        // 取出叶子节点,作为当前找到的最近节点
        nearestNode = stackNodes.pop();
        nearestNode = dfsSearchNodes(stackNodes, p, nearestNode);

        return nearestNode.nodeData;
    }

    /**
     * 深度优先的方式进行最近点的查找
     *
     * @param stack
     *            KD树节点栈
     * @param desPoint
     *            给定的数据点
     * @param nearestNode
     *            当前找到的最近节点
     * @return
     */
    private TreeNode dfsSearchNodes(Stack<TreeNode> stack, Point desPoint,
            TreeNode nearestNode) {
        // 是否碰到父节点边界
        boolean isCollision;
        double minDis;
        double dis;
        TreeNode parentNode;

        // 如果栈内节点已经全部弹出,则遍历结束
        if (stack.isEmpty()) {
            return nearestNode;
        }

        // 获取父节点
        parentNode = stack.pop();

        minDis = desPoint.ouDistance(nearestNode.nodeData);
        dis = desPoint.ouDistance(parentNode.nodeData);

        // 如果与当前回溯到的父节点距离更短,则搜索到的节点进行更新
        if (dis < minDis) {
            minDis = dis;
            nearestNode = parentNode;
        }

        // 默认没有碰撞到
        isCollision = false;
        // 判断是否触碰到了父节点的空间分割线
        if (parentNode.spilt == DIRECTION_X) {
            if (parentNode.nodeData.x > desPoint.x - minDis
                    && parentNode.nodeData.x < desPoint.x + minDis) {
                isCollision = true;
            }
        } else {
            if (parentNode.nodeData.y > desPoint.y - minDis
                    && parentNode.nodeData.y < desPoint.y + minDis) {
                isCollision = true;
            }
        }

        // 如果触碰到父边界了,并且此节点的孩子节点还未完全遍历完,则可以继续遍历
        if (isCollision
                && (!parentNode.leftNode.isVisited || !parentNode.rightNode.isVisited)) {
            TreeNode newNode;
            // 新建当前的小局部节点栈
            Stack<TreeNode> otherStack = new Stack<>();
            // 从parentNode的树以下继续寻找
            findedNearestLeafNode(desPoint, parentNode, otherStack);
            newNode = dfsSearchNodes(otherStack, desPoint, otherStack.pop());

            dis = newNode.nodeData.ouDistance(desPoint);
            if (dis < minDis) {
                nearestNode = newNode;
            }
        }

        // 继续往上回溯
        nearestNode = dfsSearchNodes(stack, desPoint, nearestNode);

        return nearestNode;
    }

    /**
     * 找到与所给定节点的最近的叶子节点
     *
     * @param p
     *            待比较节点
     * @param node
     *            当前搜索到的节点
     * @param stack
     *            遍历过的节点栈
     */
    private void findedNearestLeafNode(Point p, TreeNode node,
            Stack<TreeNode> stack) {
        // 分割方向
        int splitDic;

        // 将遍历过的节点加入栈中
        stack.push(node);
        // 标记为访问过
        node.isVisited = true;
        // 如果此节点没有左右孩子节点说明已经是叶子节点了
        if (node.leftNode == null && node.rightNode == null) {
            return;
        }

        splitDic = node.spilt;
        // 选择一个符合分割范围的节点继续递归搜寻
        if ((splitDic == DIRECTION_X && p.x < node.nodeData.x)
                || (splitDic == DIRECTION_Y && p.y < node.nodeData.y)) {
            if (!node.leftNode.isVisited) {
                findedNearestLeafNode(p, node.leftNode, stack);
            } else {
                // 如果左孩子节点已经访问过,则访问另一边
                findedNearestLeafNode(p, node.rightNode, stack);
            }
        } else if ((splitDic == DIRECTION_X && p.x > node.nodeData.x)
                || (splitDic == DIRECTION_Y && p.y > node.nodeData.y)) {
            if (!node.rightNode.isVisited) {
                findedNearestLeafNode(p, node.rightNode, stack);
            } else {
                // 如果右孩子节点已经访问过,则访问另一边
                findedNearestLeafNode(p, node.leftNode, stack);
            }
        }
    }

    /**
     * 根据给定的数据点通过计算反差选择的分割点
     *
     * @param datas
     *            部分的集合点集合
     * @return
     */
    private int selectSplitDrc(ArrayList<Point> datas) {
        int direction = 0;
        double avgX = 0;
        double avgY = 0;
        double varianceX = 0;
        double varianceY = 0;

        for (Point p : datas) {
            avgX += p.x;
            avgY += p.y;
        }

        avgX /= datas.size();
        avgY /= datas.size();

        for (Point p : datas) {
            varianceX += (p.x - avgX) * (p.x - avgX);
            varianceY += (p.y - avgY) * (p.y - avgY);
        }

        // 求最后的方差
        varianceX /= datas.size();
        varianceY /= datas.size();

        // 通过比较方差的大小决定分割方向,选择波动较大的进行划分
        direction = varianceX > varianceY ? DIRECTION_X : DIRECTION_Y;

        return direction;
    }

    /**
     * 根据坐标点方位进行排序,选出中间点的坐标数据
     *
     * @param datas
     *            数据点集合
     * @param dir
     *            排序的坐标方向
     */
    private Point getMiddlePoint(ArrayList<Point> datas, int dir) {
        int index = 0;
        Point middlePoint;

        index = datas.size() / 2;
        if (dir == DIRECTION_X) {
            Collections.sort(datas, new Comparator<Point>() {

                @Override
                public int compare(Point o1, Point o2) {
                    // TODO Auto-generated method stub
                    return o1.x.compareTo(o2.x);
                }
            });
        } else {
            Collections.sort(datas, new Comparator<Point>() {

                @Override
                public int compare(Point o1, Point o2) {
                    // TODO Auto-generated method stub
                    return o1.y.compareTo(o2.y);
                }
            });
        }

        // 取出中位数
        middlePoint = datas.get(index);

        return middlePoint;
    }

    /**
     * 根据方向得到原部分节点集合左侧的数据点
     *
     * @param datas
     *            原始数据点集合
     * @param nodeData
     *            数据矢量
     * @param dir
     *            分割方向
     * @return
     */
    private ArrayList<Point> getLeftSideDatas(ArrayList<Point> datas,
            Point nodeData, int dir) {
        ArrayList<Point> leftSideDatas = new ArrayList<>();

        for (Point p : datas) {
            if (dir == DIRECTION_X && p.x < nodeData.x) {
                leftSideDatas.add(p);
            } else if (dir == DIRECTION_Y && p.y < nodeData.y) {
                leftSideDatas.add(p);
            }
        }

        return leftSideDatas;
    }
}

场景测试类Client.java:

package DataMining_KDTree;

import java.text.MessageFormat;

/**
 * KD树算法测试类
 *
 * @author lyq
 *
 */
public class Client {
    public static void main(String[] args) {
        String filePath = "C:\\Users\\lyq\\Desktop\\icon\\input.txt";
        Point queryNode;
        Point searchedNode;
        KDTreeTool tool = new KDTreeTool(filePath);

        // 进行KD树的构建
        tool.createKDTree();

        // 通过KD树进行数据点的最近点查询
        queryNode = new Point(2.1, 3.1);
        searchedNode = tool.searchNearestData(queryNode);
        System.out.println(MessageFormat.format(
                "距离查询点({0}, {1})最近的坐标点为({2}, {3})", queryNode.x, queryNode.y,
                searchedNode.x, searchedNode.y));
        
        //重新构造KD树,去除之前的访问记录
        tool.createKDTree();
        queryNode = new Point(2, 4.5);
        searchedNode = tool.searchNearestData(queryNode);
        System.out.println(MessageFormat.format(
                "距离查询点({0}, {1})最近的坐标点为({2}, {3})", queryNode.x, queryNode.y,
                searchedNode.x, searchedNode.y));
    }
}

算法的输出结果:

距离查询点(2.1, 3.1)最近的坐标点为(2, 3)
距离查询点(2, 4.5)最近的坐标点为(2, 3)

算法的输出结果与期望值还是一致的。

目前KD-Tree的使用场景是SIFT算法做特征点匹配的时候使用到了,特征点匹配指的是通过距离函数在高维矢量空间进行相似性检索。


参考文献:百度百科 http://baike.baidu.com

我的数据挖掘算法库:https://github.com/linyiqun/DataMiningAlgorithm

我的算法库:https://github.com/linyiqun/lyq-algorithms-lib
作者:Androidlushangderen 发表于2015/4/10 21:39:58 原文链接
阅读:581 评论:0 查看评论
随机森林和GBDT的学习
2015年3月30日 20:28

参考文献:http://www.zilhua.com/629.html
http://www.tuicool.com/articles/JvMJve
http://blog.sina.com.cn/s/blog_573085f70101ivj5.html
我的数据挖掘算法:https://github.com/linyiqun/DataMiningAlgorithm
我的算法库:https://github.com/linyiqun/lyq-algorithms-lib
前言

提到森林,就不得不联想到树,因为正是一棵棵的树构成了庞大的森林,而在本篇文章中的”树“,指的就是Decision Tree-----决策树。随机森林就是一棵棵决策树的组合,也就是说随机森林=boosting+决策树,这样就好理解多了吧,再来说说GBDT,GBDT全称是Gradient Boosting Decision Tree,就是梯度提升决策树,与随机森林的思想很像,但是比随机森林稍稍的难一点,当然效果相对于前者而言,也会好许多。由于本人才疏学浅,本文只会详细讲述Random Forest算法的部分,至于GBDT我会给出一小段篇幅做介绍引导,读者能够如果有兴趣的话,可以自行学习。
随机森林算法
决策树

要想理解随机森林算法,就不得不提决策树,什么是决策树,如何构造决策树,简单的回答就是数据的分类以树形结构的方式所展现,每个子分支都代表着不同的分类情况,比如下面的这个图所示:


当然决策树的每个节点分支不一定是三元的,可以有2个或者更多。分类的终止条件为,没有可以再拿来分类的属性条件或者说分到的数据的分类已经完全一致的情况。决策树分类的标准和依据是什么呢,下面介绍主要的2种划分标准。

1、信息增益。这是ID3算法系列所用的方法,C4.5算法在这上面做了少许的改进,用信息增益率来作为划分的标准,可以稍稍减小数据过于拟合的缺点。

2、基尼指数。这是CART分类回归树所用的方法。也是类似于信息增益的一个定义,最终都是根据数据划分后的纯度来做比较,这个纯度,你也可以理解为熵的变化,当然我们所希望的情况就是分类后数据的纯度更纯,也就是说,前后划分分类之后的熵的差越大越好。不过CART算法比较好的一点是树构造好后,还有剪枝的操作,剪枝操作的种类就比较多了,我之前在实现CART算法时用的是代价复杂度的剪枝方法。

这2种决策算法在我之前的博文中已经有所提及,不理解的可以点击我的ID3系列算法介绍和我的CART分类回归树算法。
Boosting

原本不打算将Boosting单独拉出来讲的,后来想想还是有很多内容可谈的。Boosting本身不是一种算法,他更应该说是一种思想,首先对数据构造n个弱分类器,最后通过组合n个弱分类器对于某个数据的判断结果作为最终的分类结果,就变成了一个强分类器,效果自然要好过单一分类器的分类效果。他可以理解为是一种提升算法,举一个比较常见的Boosting思想的算法AdaBoost,他在训练每个弱分类器的时候,提高了对于之前分错数据的权重值,最终能够组成一批相互互补的分类器集合。详细可以查看我的AdaBoost算法学习。

OK,2个重要的概念都已经介绍完毕,终于可以介绍主角Random Forest的出现了,正如前言中所说Random Forest=Decision Trees + Boosting,这里的每个弱分类器就是一个决策树了,不过这里的决策树都是二叉树,就是只有2个孩子分支,自然我立刻想到的做法就是用CART算法来构建,因为人家算法就是二元分支的。随机算法,随机算法,当然重在随机2个字上面,下面是2个方面体现了随机性。对于数据样本的采集量,比如我数据由100条,我可以每次随机取出其中的20条,作为我构造决策树的源数据,采取又放回的方式,并不是第一次抽到的数据,第二次不能重复,第二随机性体现在对于数据属性的随机采集,比如一行数据总共有10个特征属性,我每次随机采用其中的4个。正是由于对于数据的行压缩和列压缩,使得数据的随机性得以保证,就很难出现之前的数据过拟合的问题了,也就不需要在决策树最后进行剪枝操作了,这个是与一般的CART算法所不同的,尤其需要注意。

下面是随机森林算法的构造过程:

1、通过给定的原始数据,选出其中部分数据进行决策树的构造,数据选取是”有放回“的过程,我在这里用的是CART分类回归树。

2、随机森林构造完成之后,给定一组测试数据,使得每个分类器对其结果分类进行评估,最后取评估结果的众数最为最终结果。

算法非常的好理解,在Boosting算法和决策树之上做了一个集成,下面给出算法的实现,很多资料上只有大篇幅的理论,我还是希望能带给大家一点实在的东西。
随机算法的实现

输入数据(之前决策树算法时用过的)input.txt:

Rid Age Income Student CreditRating BuysComputer
1 Youth High No Fair No
2 Youth High No Excellent No
3 MiddleAged High No Fair Yes
4 Senior Medium No Fair Yes
5 Senior Low Yes Fair Yes
6 Senior Low Yes Excellent No
7 MiddleAged Low Yes Excellent Yes
8 Youth Medium No Fair No
9 Youth Low Yes Fair Yes
10 Senior Medium Yes Fair Yes
11 Youth Medium Yes Excellent Yes
12 MiddleAged Medium No Excellent Yes
13 MiddleAged High Yes Fair Yes
14 Senior Medium No Excellent No

树节点类TreeNode.java:

package DataMining_RandomForest;

import java.util.ArrayList;

/**
 * 回归分类树节点
 *
 * @author lyq
 *
 */
public class TreeNode {
    // 节点属性名字
    private String attrName;
    // 节点索引标号
    private int nodeIndex;
    //包含的叶子节点数
    private int leafNum;
    // 节点误差率
    private double alpha;
    // 父亲分类属性值
    private String parentAttrValue;
    // 孩子节点
    private TreeNode[] childAttrNode;
    // 数据记录索引
    private ArrayList<String> dataIndex;

    public String getAttrName() {
        return attrName;
    }

    public void setAttrName(String attrName) {
        this.attrName = attrName;
    }

    public int getNodeIndex() {
        return nodeIndex;
    }

    public void setNodeIndex(int nodeIndex) {
        this.nodeIndex = nodeIndex;
    }

    public double getAlpha() {
        return alpha;
    }

    public void setAlpha(double alpha) {
        this.alpha = alpha;
    }

    public String getParentAttrValue() {
        return parentAttrValue;
    }

    public void setParentAttrValue(String parentAttrValue) {
        this.parentAttrValue = parentAttrValue;
    }

    public TreeNode[] getChildAttrNode() {
        return childAttrNode;
    }

    public void setChildAttrNode(TreeNode[] childAttrNode) {
        this.childAttrNode = childAttrNode;
    }

    public ArrayList<String> getDataIndex() {
        return dataIndex;
    }

    public void setDataIndex(ArrayList<String> dataIndex) {
        this.dataIndex = dataIndex;
    }

    public int getLeafNum() {
        return leafNum;
    }

    public void setLeafNum(int leafNum) {
        this.leafNum = leafNum;
    }
    
    
    
}

决策树类DecisionTree.java:

package DataMining_RandomForest;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;

/**
 * 决策树
 *
 * @author lyq
 *
 */
public class DecisionTree {
    // 树的根节点
    TreeNode rootNode;
    // 数据的属性列名称
    String[] featureNames;
    // 这棵树所包含的数据
    ArrayList<String[]> datas;
    // 决策树构造的的工具类
    CARTTool tool;

    public DecisionTree(ArrayList<String[]> datas) {
        this.datas = datas;
        this.featureNames = datas.get(0);

        tool = new CARTTool(datas);
        // 通过CART工具类进行决策树的构建,并返回树的根节点
        rootNode = tool.startBuildingTree();
    }

    /**
     * 根据给定的数据特征描述进行类别的判断
     *
     * @param features
     * @return
     */
    public String decideClassType(String features) {
        String classType = "";
        // 查询属性组
        String[] queryFeatures;
        // 在本决策树中对应的查询的属性值描述
        ArrayList<String[]> featureStrs;

        featureStrs = new ArrayList<>();
        queryFeatures = features.split(",");

        String[] array;
        for (String name : featureNames) {
            for (String featureValue : queryFeatures) {
                array = featureValue.split("=");
                // 将对应的属性值加入到列表中
                if (array[0].equals(name)) {
                    featureStrs.add(array);
                }
            }
        }

        // 开始从根据节点往下递归搜索
        classType = recusiveSearchClassType(rootNode, featureStrs);

        return classType;
    }

    /**
     * 递归搜索树,查询属性的分类类别
     *
     * @param node
     *            当前搜索到的节点
     * @param remainFeatures
     *            剩余未判断的属性
     * @return
     */
    private String recusiveSearchClassType(TreeNode node,
            ArrayList<String[]> remainFeatures) {
        String classType = null;

        // 如果节点包含了数据的id索引,说明已经分类到底了
        if (node.getDataIndex() != null && node.getDataIndex().size() > 0) {
            classType = judgeClassType(node.getDataIndex());

            return classType;
        }

        // 取出剩余属性中的一个匹配属性作为当前的判断属性名称
        String[] currentFeature = null;
        for (String[] featureValue : remainFeatures) {
            if (node.getAttrName().equals(featureValue[0])) {
                currentFeature = featureValue;
                break;
            }
        }

        for (TreeNode childNode : node.getChildAttrNode()) {
            // 寻找子节点中属于此属性值的分支
            if (childNode.getParentAttrValue().equals(currentFeature[1])) {
                remainFeatures.remove(currentFeature);
                classType = recusiveSearchClassType(childNode, remainFeatures);

                // 如果找到了分类结果,则直接挑出循环
                break;
            }else{
                //进行第二种情况的判断加上!符号的情况
                String value = childNode.getParentAttrValue();
                
                if(value.charAt(0) == '!'){
                    //去掉第一个!字符
                    value = value.substring(1, value.length());
                    
                    if(!value.equals(currentFeature[1])){
                        remainFeatures.remove(currentFeature);
                        classType = recusiveSearchClassType(childNode, remainFeatures);

                        break;
                    }
                }
            }
        }

        return classType;
    }

    /**
     * 根据得到的数据行分类进行类别的决策
     *
     * @param dataIndex
     *            根据分类的数据索引号
     * @return
     */
    public String judgeClassType(ArrayList<String> dataIndex) {
        // 结果类型值
        String resultClassType = "";
        String classType = "";
        int count = 0;
        int temp = 0;
        Map<String, Integer> type2Num = new HashMap<String, Integer>();

        for (String index : dataIndex) {
            temp = Integer.parseInt(index);
            // 取最后一列的决策类别数据
            classType = datas.get(temp)[featureNames.length - 1];

            if (type2Num.containsKey(classType)) {
                // 如果类别已经存在,则使其计数加1
                count = type2Num.get(classType);
                count++;
            } else {
                count = 1;
            }

            type2Num.put(classType, count);
        }

        // 选出其中类别支持计数最多的一个类别值
        count = -1;
        for (Map.Entry entry : type2Num.entrySet()) {
            if ((int) entry.getValue() > count) {
                count = (int) entry.getValue();
                resultClassType = (String) entry.getKey();
            }
        }

        return resultClassType;
    }
}

随机森林算法工具类RandomForestTool.java:

package DataMining_RandomForest;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;

/**
 * 随机森林算法工具类
 *
 * @author lyq
 *
 */
public class RandomForestTool {
    // 测试数据文件地址
    private String filePath;
    // 决策树的样本占总数的占比率
    private double sampleNumRatio;
    // 样本数据的采集特征数量占总特征的比例
    private double featureNumRatio;
    // 决策树的采样样本数
    private int sampleNum;
    // 样本数据的采集采样特征数
    private int featureNum;
    // 随机森林中的决策树的数目,等于总的数据数/用于构造每棵树的数据的数量
    private int treeNum;
    // 随机数产生器
    private Random random;
    // 样本数据列属性名称行
    private String[] featureNames;
    // 原始的总的数据
    private ArrayList<String[]> totalDatas;
    // 决策树森林
    private ArrayList<DecisionTree> decisionForest;

    public RandomForestTool(String filePath, double sampleNumRatio,
            double featureNumRatio) {
        this.filePath = filePath;
        this.sampleNumRatio = sampleNumRatio;
        this.featureNumRatio = featureNumRatio;

        readDataFile();
    }

    /**
     * 从文件中读取数据
     */
    private void readDataFile() {
        File file = new File(filePath);
        ArrayList<String[]> dataArray = new ArrayList<String[]>();

        try {
            BufferedReader in = new BufferedReader(new FileReader(file));
            String str;
            String[] tempArray;
            while ((str = in.readLine()) != null) {
                tempArray = str.split(" ");
                dataArray.add(tempArray);
            }
            in.close();
        } catch (IOException e) {
            e.getStackTrace();
        }

        totalDatas = dataArray;
        featureNames = totalDatas.get(0);
        sampleNum = (int) ((totalDatas.size() - 1) * sampleNumRatio);
        //算属性数量的时候需要去掉id属性和决策属性,用条件属性计算
        featureNum = (int) ((featureNames.length -2) * featureNumRatio);
        // 算数量的时候需要去掉首行属性名称行
        treeNum = (totalDatas.size() - 1) / sampleNum;
    }

    /**
     * 产生决策树
     */
    private DecisionTree produceDecisionTree() {
        int temp = 0;
        DecisionTree tree;
        String[] tempData;
        //采样数据的随机行号组
        ArrayList<Integer> sampleRandomNum;
        //采样属性特征的随机列号组
        ArrayList<Integer> featureRandomNum;
        ArrayList<String[]> datas;
        
        sampleRandomNum = new ArrayList<>();
        featureRandomNum = new ArrayList<>();
        datas = new ArrayList<>();
        
        for(int i=0; i<sampleNum;){
            temp = random.nextInt(totalDatas.size());
            
            //如果是行首属性名称行,则跳过
            if(temp == 0){
                continue;
            }
            
            if(!sampleRandomNum.contains(temp)){
                sampleRandomNum.add(temp);
                i++;
            }
        }
        
        for(int i=0; i<featureNum;){
            temp = random.nextInt(featureNames.length);
            
            //如果是第一列的数据id号或者是决策属性列,则跳过
            if(temp == 0 || temp == featureNames.length-1){
                continue;
            }
            
            if(!featureRandomNum.contains(temp)){
                featureRandomNum.add(temp);
                i++;
            }
        }

        String[] singleRecord;
        String[] headCulumn = null;
        // 获取随机数据行
        for(int dataIndex: sampleRandomNum){
            singleRecord = totalDatas.get(dataIndex);
            
            //每行的列数=所选的特征数+id号
            tempData = new String[featureNum+2];
            headCulumn = new String[featureNum+2];
            
            for(int i=0,k=1; i<featureRandomNum.size(); i++,k++){
                temp = featureRandomNum.get(i);
                
                headCulumn[k] = featureNames[temp];
                tempData[k] = singleRecord[temp];
            }
            
            //加上id列的信息
            headCulumn[0] = featureNames[0];
            //加上决策分类列的信息
            headCulumn[featureNum+1] = featureNames[featureNames.length-1];
            tempData[featureNum+1] = singleRecord[featureNames.length-1];
            
            //加入此行数据
            datas.add(tempData);
        }
        
        //加入行首列出现名称
        datas.add(0, headCulumn);
        //对筛选出的数据重新做id分配
        temp = 0;
        for(String[] array: datas){
            //从第2行开始赋值
            if(temp > 0){
                array[0] = temp + "";
            }
            
            temp++;
        }
        
        tree = new DecisionTree(datas);
        
        return tree;
    }

    /**
     * 构造随机森林
     */
    public void constructRandomTree() {
        DecisionTree tree;
        random = new Random();
        decisionForest = new ArrayList<>();

        System.out.println("下面是随机森林中的决策树:");
        // 构造决策树加入森林中
        for (int i = 0; i < treeNum; i++) {
            System.out.println("\n决策树" + (i+1));
            tree = produceDecisionTree();
            decisionForest.add(tree);
        }
    }

    /**
     * 根据给定的属性条件进行类别的决策
     *
     * @param features
     *            给定的已知的属性描述
     * @return
     */
    public String judgeClassType(String features) {
        // 结果类型值
        String resultClassType = "";
        String classType = "";
        int count = 0;
        Map<String, Integer> type2Num = new HashMap<String, Integer>();

        for (DecisionTree tree : decisionForest) {
            classType = tree.decideClassType(features);
            if (type2Num.containsKey(classType)) {
                // 如果类别已经存在,则使其计数加1
                count = type2Num.get(classType);
                count++;
            } else {
                count = 1;
            }

            type2Num.put(classType, count);
        }

        // 选出其中类别支持计数最多的一个类别值
        count = -1;
        for (Map.Entry entry : type2Num.entrySet()) {
            if ((int) entry.getValue() > count) {
                count = (int) entry.getValue();
                resultClassType = (String) entry.getKey();
            }
        }

        return resultClassType;
    }
}

CART算法工具类CARTTool.java:

package DataMining_RandomForest;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.Queue;

/**
 * CART分类回归树算法工具类
 *
 * @author lyq
 *
 */
public class CARTTool {
    // 类标号的值类型
    private final String YES = "Yes";
    private final String NO = "No";

    // 所有属性的类型总数,在这里就是data源数据的列数
    private int attrNum;
    private String filePath;
    // 初始源数据,用一个二维字符数组存放模仿表格数据
    private String[][] data;
    // 数据的属性行的名字
    private String[] attrNames;
    // 每个属性的值所有类型
    private HashMap<String, ArrayList<String>> attrValue;

    public CARTTool(ArrayList<String[]> dataArray) {
        attrValue = new HashMap<>();
        readData(dataArray);
    }

    /**
     * 根据随机选取的样本数据进行初始化
     * @param dataArray
     * 已经读入的样本数据
     */
    public void readData(ArrayList<String[]> dataArray) {
        data = new String[dataArray.size()][];
        dataArray.toArray(data);
        attrNum = data[0].length;
        attrNames = data[0];
    }

    /**
     * 首先初始化每种属性的值的所有类型,用于后面的子类熵的计算时用
     */
    public void initAttrValue() {
        ArrayList<String> tempValues;

        // 按照列的方式,从左往右找
        for (int j = 1; j < attrNum; j++) {
            // 从一列中的上往下开始寻找值
            tempValues = new ArrayList<>();
            for (int i = 1; i < data.length; i++) {
                if (!tempValues.contains(data[i][j])) {
                    // 如果这个属性的值没有添加过,则添加
                    tempValues.add(data[i][j]);
                }
            }

            // 一列属性的值已经遍历完毕,复制到map属性表中
            attrValue.put(data[0][j], tempValues);
        }
    }

    /**
     * 计算机基尼指数
     *
     * @param remainData
     *            剩余数据
     * @param attrName
     *            属性名称
     * @param value
     *            属性值
     * @param beLongValue
     *            分类是否属于此属性值
     * @return
     */
    public double computeGini(String[][] remainData, String attrName,
            String value, boolean beLongValue) {
        // 实例总数
        int total = 0;
        // 正实例数
        int posNum = 0;
        // 负实例数
        int negNum = 0;
        // 基尼指数
        double gini = 0;

        // 还是按列从左往右遍历属性
        for (int j = 1; j < attrNames.length; j++) {
            // 找到了指定的属性
            if (attrName.equals(attrNames[j])) {
                for (int i = 1; i < remainData.length; i++) {
                    // 统计正负实例按照属于和不属于值类型进行划分
                    if ((beLongValue && remainData[i][j].equals(value))
                            || (!beLongValue && !remainData[i][j].equals(value))) {
                        if (remainData[i][attrNames.length - 1].equals(YES)) {
                            // 判断此行数据是否为正实例
                            posNum++;
                        } else {
                            negNum++;
                        }
                    }
                }
            }
        }

        total = posNum + negNum;
        double posProbobly = (double) posNum / total;
        double negProbobly = (double) negNum / total;
        gini = 1 - posProbobly * posProbobly - negProbobly * negProbobly;

        // 返回计算基尼指数
        return gini;
    }

    /**
     * 计算属性划分的最小基尼指数,返回最小的属性值划分和最小的基尼指数,保存在一个数组中
     *
     * @param remainData
     *            剩余谁
     * @param attrName
     *            属性名称
     * @return
     */
    public String[] computeAttrGini(String[][] remainData, String attrName) {
        String[] str = new String[2];
        // 最终该属性的划分类型值
        String spiltValue = "";
        // 临时变量
        int tempNum = 0;
        // 保存属性的值划分时的最小的基尼指数
        double minGini = Integer.MAX_VALUE;
        ArrayList<String> valueTypes = attrValue.get(attrName);
        // 属于此属性值的实例数
        HashMap<String, Integer> belongNum = new HashMap<>();

        for (String string : valueTypes) {
            // 重新计数的时候,数字归0
            tempNum = 0;
            // 按列从左往右遍历属性
            for (int j = 1; j < attrNames.length; j++) {
                // 找到了指定的属性
                if (attrName.equals(attrNames[j])) {
                    for (int i = 1; i < remainData.length; i++) {
                        // 统计正负实例按照属于和不属于值类型进行划分
                        if (remainData[i][j].equals(string)) {
                            tempNum++;
                        }
                    }
                }
            }

            belongNum.put(string, tempNum);
        }

        double tempGini = 0;
        double posProbably = 1.0;
        double negProbably = 1.0;
        for (String string : valueTypes) {
            tempGini = 0;

            posProbably = 1.0 * belongNum.get(string) / (remainData.length - 1);
            negProbably = 1 - posProbably;

            tempGini += posProbably
                    * computeGini(remainData, attrName, string, true);
            tempGini += negProbably
                    * computeGini(remainData, attrName, string, false);

            if (tempGini < minGini) {
                minGini = tempGini;
                spiltValue = string;
            }
        }

        str[0] = spiltValue;
        str[1] = minGini + "";

        return str;
    }

    public void buildDecisionTree(TreeNode node, String parentAttrValue,
            String[][] remainData, ArrayList<String> remainAttr,
            boolean beLongParentValue) {
        // 属性划分值
        String valueType = "";
        // 划分属性名称
        String spiltAttrName = "";
        double minGini = Integer.MAX_VALUE;
        double tempGini = 0;
        // 基尼指数数组,保存了基尼指数和此基尼指数的划分属性值
        String[] giniArray;

        if (beLongParentValue) {
            node.setParentAttrValue(parentAttrValue);
        } else {
            node.setParentAttrValue("!" + parentAttrValue);
        }

        if (remainAttr.size() == 0) {
            if (remainData.length > 1) {
                ArrayList<String> indexArray = new ArrayList<>();
                for (int i = 1; i < remainData.length; i++) {
                    indexArray.add(remainData[i][0]);
                }
                node.setDataIndex(indexArray);
            }
        //    System.out.println("attr remain null");
            return;
        }

        for (String str : remainAttr) {
            giniArray = computeAttrGini(remainData, str);
            tempGini = Double.parseDouble(giniArray[1]);

            if (tempGini < minGini) {
                spiltAttrName = str;
                minGini = tempGini;
                valueType = giniArray[0];
            }
        }
        // 移除划分属性
        remainAttr.remove(spiltAttrName);
        node.setAttrName(spiltAttrName);

        // 孩子节点,分类回归树中,每次二元划分,分出2个孩子节点
        TreeNode[] childNode = new TreeNode[2];
        String[][] rData;

        boolean[] bArray = new boolean[] { true, false };
        for (int i = 0; i < bArray.length; i++) {
            // 二元划分属于属性值的划分
            rData = removeData(remainData, spiltAttrName, valueType, bArray[i]);

            boolean sameClass = true;
            ArrayList<String> indexArray = new ArrayList<>();
            for (int k = 1; k < rData.length; k++) {
                indexArray.add(rData[k][0]);
                // 判断是否为同一类的
                if (!rData[k][attrNames.length - 1]
                        .equals(rData[1][attrNames.length - 1])) {
                    // 只要有1个不相等,就不是同类型的
                    sameClass = false;
                    break;
                }
            }

            childNode[i] = new TreeNode();
            if (!sameClass) {
                // 创建新的对象属性,对象的同个引用会出错
                ArrayList<String> rAttr = new ArrayList<>();
                for (String str : remainAttr) {
                    rAttr.add(str);
                }
                buildDecisionTree(childNode[i], valueType, rData, rAttr,
                        bArray[i]);
            } else {
                String pAtr = (bArray[i] ? valueType : "!" + valueType);
                childNode[i].setParentAttrValue(pAtr);
                childNode[i].setDataIndex(indexArray);
            }
        }

        node.setChildAttrNode(childNode);
    }

    /**
     * 属性划分完毕,进行数据的移除
     *
     * @param srcData
     *            源数据
     * @param attrName
     *            划分的属性名称
     * @param valueType
     *            属性的值类型
     * @parame beLongValue 分类是否属于此值类型
     */
    private String[][] removeData(String[][] srcData, String attrName,
            String valueType, boolean beLongValue) {
        String[][] desDataArray;
        ArrayList<String[]> desData = new ArrayList<>();
        // 待删除数据
        ArrayList<String[]> selectData = new ArrayList<>();
        selectData.add(attrNames);

        // 数组数据转化到列表中,方便移除
        for (int i = 0; i < srcData.length; i++) {
            desData.add(srcData[i]);
        }

        // 还是从左往右一列列的查找
        for (int j = 1; j < attrNames.length; j++) {
            if (attrNames[j].equals(attrName)) {
                for (int i = 1; i < desData.size(); i++) {
                    if (desData.get(i)[j].equals(valueType)) {
                        // 如果匹配这个数据,则移除其他的数据
                        selectData.add(desData.get(i));
                    }
                }
            }
        }

        if (beLongValue) {
            desDataArray = new String[selectData.size()][];
            selectData.toArray(desDataArray);
        } else {
            // 属性名称行不移除
            selectData.remove(attrNames);
            // 如果是划分不属于此类型的数据时,进行移除
            desData.removeAll(selectData);
            desDataArray = new String[desData.size()][];
            desData.toArray(desDataArray);
        }

        return desDataArray;
    }

    /**
     * 构造分类回归树,并返回根节点
     * @return
     */
    public TreeNode startBuildingTree() {
        initAttrValue();

        ArrayList<String> remainAttr = new ArrayList<>();
        // 添加属性,除了最后一个类标号属性
        for (int i = 1; i < attrNames.length - 1; i++) {
            remainAttr.add(attrNames[i]);
        }

        TreeNode rootNode = new TreeNode();
        buildDecisionTree(rootNode, "", data, remainAttr, false);
        setIndexAndAlpah(rootNode, 0, false);
        showDecisionTree(rootNode, 1);
        
        return rootNode;
    }

    /**
     * 显示决策树
     *
     * @param node
     *            待显示的节点
     * @param blankNum
     *            行空格符,用于显示树型结构
     */
    private void showDecisionTree(TreeNode node, int blankNum) {
        System.out.println();
        for (int i = 0; i < blankNum; i++) {
            System.out.print("    ");
        }
        System.out.print("--");
        // 显示分类的属性值
        if (node.getParentAttrValue() != null
                && node.getParentAttrValue().length() > 0) {
            System.out.print(node.getParentAttrValue());
        } else {
            System.out.print("--");
        }
        System.out.print("--");

        if (node.getDataIndex() != null && node.getDataIndex().size() > 0) {
            String i = node.getDataIndex().get(0);
            System.out.print("【" + node.getNodeIndex() + "】类别:"
                    + data[Integer.parseInt(i)][attrNames.length - 1]);
            System.out.print("[");
            for (String index : node.getDataIndex()) {
                System.out.print(index + ", ");
            }
            System.out.print("]");
        } else {
            // 递归显示子节点
            System.out.print("【" + node.getNodeIndex() + ":"
                    + node.getAttrName() + "】");
            if (node.getChildAttrNode() != null) {
                for (TreeNode childNode : node.getChildAttrNode()) {
                    showDecisionTree(childNode, 2 * blankNum);
                }
            } else {
                System.out.print("【  Child Null】");
            }
        }
    }

    /**
     * 为节点设置序列号,并计算每个节点的误差率,用于后面剪枝
     *
     * @param node
     *            开始的时候传入的是根节点
     * @param index
     *            开始的索引号,从1开始
     * @param ifCutNode
     *            是否需要剪枝
     */
    private void setIndexAndAlpah(TreeNode node, int index, boolean ifCutNode) {
        TreeNode tempNode;
        // 最小误差代价节点,即将被剪枝的节点
        TreeNode minAlphaNode = null;
        double minAlpah = Integer.MAX_VALUE;
        Queue<TreeNode> nodeQueue = new LinkedList<TreeNode>();

        nodeQueue.add(node);
        while (nodeQueue.size() > 0) {
            index++;
            // 从队列头部获取首个节点
            tempNode = nodeQueue.poll();
            tempNode.setNodeIndex(index);
            if (tempNode.getChildAttrNode() != null) {
                for (TreeNode childNode : tempNode.getChildAttrNode()) {
                    nodeQueue.add(childNode);
                }
                computeAlpha(tempNode);
                if (tempNode.getAlpha() < minAlpah) {
                    minAlphaNode = tempNode;
                    minAlpah = tempNode.getAlpha();
                } else if (tempNode.getAlpha() == minAlpah) {
                    // 如果误差代价值一样,比较包含的叶子节点个数,剪枝有多叶子节点数的节点
                    if (tempNode.getLeafNum() > minAlphaNode.getLeafNum()) {
                        minAlphaNode = tempNode;
                    }
                }
            }
        }

        if (ifCutNode) {
            // 进行树的剪枝,让其左右孩子节点为null
            minAlphaNode.setChildAttrNode(null);
        }
    }

    /**
     * 为非叶子节点计算误差代价,这里的后剪枝法用的是CCP代价复杂度剪枝
     *
     * @param node
     *            待计算的非叶子节点
     */
    private void computeAlpha(TreeNode node) {
        double rt = 0;
        double Rt = 0;
        double alpha = 0;
        // 当前节点的数据总数
        int sumNum = 0;
        // 最少的偏差数
        int minNum = 0;

        ArrayList<String> dataIndex;
        ArrayList<TreeNode> leafNodes = new ArrayList<>();

        addLeafNode(node, leafNodes);
        node.setLeafNum(leafNodes.size());
        for (TreeNode attrNode : leafNodes) {
            dataIndex = attrNode.getDataIndex();

            int num = 0;
            sumNum += dataIndex.size();
            for (String s : dataIndex) {
                // 统计分类数据中的正负实例数
                if (data[Integer.parseInt(s)][attrNames.length - 1].equals(YES)) {
                    num++;
                }
            }
            minNum += num;

            // 取小数量的值部分
            if (1.0 * num / dataIndex.size() > 0.5) {
                num = dataIndex.size() - num;
            }

            rt += (1.0 * num / (data.length - 1));
        }
        
        //同样取出少偏差的那部分
        if (1.0 * minNum / sumNum > 0.5) {
            minNum = sumNum - minNum;
        }

        Rt = 1.0 * minNum / (data.length - 1);
        alpha = 1.0 * (Rt - rt) / (leafNodes.size() - 1);
        node.setAlpha(alpha);
    }

    /**
     * 筛选出节点所包含的叶子节点数
     *
     * @param node
     *            待筛选节点
     * @param leafNode
     *            叶子节点列表容器
     */
    private void addLeafNode(TreeNode node, ArrayList<TreeNode> leafNode) {
        ArrayList<String> dataIndex;

        if (node.getChildAttrNode() != null) {
            for (TreeNode childNode : node.getChildAttrNode()) {
                dataIndex = childNode.getDataIndex();
                if (dataIndex != null && dataIndex.size() > 0) {
                    // 说明此节点为叶子节点
                    leafNode.add(childNode);
                } else {
                    // 如果还是非叶子节点则继续递归调用
                    addLeafNode(childNode, leafNode);
                }
            }
        }
    }

}

测试类Client.java:

package DataMining_RandomForest;

import java.text.MessageFormat;

/**
 * 随机森林算法测试场景
 *
 * @author lyq
 *
 */
public class Client {
    public static void main(String[] args) {
        String filePath = "C:\\Users\\lyq\\Desktop\\icon\\input.txt";
        String queryStr = "Age=Youth,Income=Low,Student=No,CreditRating=Fair";
        String resultClassType = "";
        // 决策树的样本占总数的占比率
        double sampleNumRatio = 0.4;
        // 样本数据的采集特征数量占总特征的比例
        double featureNumRatio = 0.5;

        RandomForestTool tool = new RandomForestTool(filePath, sampleNumRatio,
                featureNumRatio);
        tool.constructRandomTree();

        resultClassType = tool.judgeClassType(queryStr);

        System.out.println();
        System.out
                .println(MessageFormat.format(
                        "查询属性描述{0},预测的分类结果为BuysCompute:{1}", queryStr,
                        resultClassType));
    }
}


算法的输出

下面是随机森林中的决策树:

决策树1

    --!--【1:Income】
        --Medium--【2】类别:Yes[1, 2, ]
        --!Medium--【3:Student】
                --No--【4】类别:No[3, 5, ]
                --!No--【5】类别:Yes[4, ]
决策树2

    --!--【1:Student】
        --No--【2】类别:No[1, 3, ]
        --!No--【3】类别:Yes[2, 4, 5, ]
查询属性描述Age=Youth,Income=Low,Student=No,CreditRating=Fair,预测的分类结果为BuysCompute:No

输出的结果决策树建议从左往右看,从上往下,【】符号表示一个节点,---XX---表示属性值的划分,你就应该能看懂这棵树了,在console上想展示漂亮的树形效果的确很难。。。这里说一个算法的重大不足,数据太少,导致选择的样本数据不足,所选属性太少,,构造的决策树数量过少,自然分类的准确率不见得会有多准,博友只要能领会代码中所表达的算法的思想即可。
GBDT

下面来说说随机森林的兄弟算法GBDT,梯度提升决策树,他有很多的决策树,他也有组合的思想,但是他不是随机森林算法2,GBDT的关键在于Gradient Boosting,梯度提升。这个词语理解起来就不容易了。学术的描述,每一次建立模型是在之前建立模型的损失函数的梯度下降方向。GBDT的核心在于,每一棵树学的是之前所有树结论和的残差,这个残差你可以理解为与预测值的差值。举个例子:比如预测张三的年龄,张三的真实年龄18岁,第一棵树预测张的年龄12岁,此时残差为18-12=6岁,因此在第二棵树中,我们把张的年龄作为6岁去学习,如果预测成功了,则张的真实年龄就是A树和B树的结果预测值的和,但是如果B预测成了5岁,那么残差就变成了6-5=1岁,那么此时需要构建第三树对1岁做预测,后面一样的道理。每棵树都是对之前失败预测的一个补充,用公式的表达就是如下的这个样子:


F0在这里是初始值,Ti是一棵棵的决策树,不同的问题选择不同的损失函数和初始值。在阿里内部对于此算法的叫法为TreeLink。所以下次听到什么Treelink算法了指的就是梯度提升树算法,其实我在这里省略了很大篇幅的数学推导过程,再加上自己还不是专家,无法彻底解释清数学的部分,所以就没有提及,希望以后有时间可以深入学习此方面的知识。
作者:Androidlushangderen 发表于2015/3/30 20:28:53 原文链接
阅读:1064 评论:0 查看评论
遗传算法在走迷宫游戏中的应用
2015年3月26日 21:56

我的数据挖掘算法库:https://github.com/linyiqun/DataMiningAlgorithm
我的算法库:https://github.com/linyiqun/lyq-algorithms-lib
前言

遗传(GA)算法是一个非常有意思的算法,因为他利用了生物进化理论的知识进行问题的求解。算法的核心就是把拥有更好环境适应度的基因遗传给下一代,这就是其中的关键的选择操作,遗传算法整体的阶段分为选择,交叉和变异操作,选择操作和变异操作在其中又是比较重要的步骤。本篇文章不会讲述GA算法的具体细节,之前我曾经写过一篇专门的文章介绍过此算法,链接:http://blog.csdn.net/androidlushangderen/article/details/44041499,里面介绍了一些基本的概念和算法的原理过程,如果你对GA算法掌握的还不错的话,那么对于理解后面遗传算法在走迷宫的应用来说应该不是难事。
算法在迷宫游戏中的应用

先说说走迷宫游戏要解决的问题是什么, 走迷宫游戏说白了就是给定起点,终点,中间设置一堆的障碍,然后要求可达的路径,注意这里指的是可达路径,并没有说一定是最优路径,因为最优路径一定是用步数最少的,这一点还是很不同的。而另一方面,遗传算法也是用来搜索问题最优解的,所以刚刚好可以转移到这个问题上。用一个遗传算法去解决生活中的实际问题最关键的就是如何用遗传算法中的概念表示出来,比如遗传算法中核心的几个概念,基因编码,基因长度的设置,适应度函数的定义,3个概念每个都很重要。好的,目的要求已经慢慢的明确了,下面一个个问题的解决。

为了能让大家更好的理解,下面举出一个例子,如图所示:


图是自己做的,比较简略,以左边点的形式表示,从图中可以看出,起点位置(4, 4),出口左边为绿色区域位置(1,0),X符号表示的障碍区域,不允许经过,问题就转为搜索出从起点到终点位置的最短路径,因为本身例子构造的不是很复杂,我们按照对角线的方式出发,总共的步数=4-1 + 4-0=7步,只要中间不拐弯,每一步都是靠近目标点方向的移动就是最佳的方式。下面看看如何转化成遗传算法中的概念表示。
个体基因长度

首先是基于长度,因为最后筛选出的是一个个体,就是满足条件的个体,他的基因编码就是问题的最优解,所以就能联想把角色的每一步移动操作看出是一个基因编码,总共7步就需要7个基因值表示,所以基因的长度在本例子中就是7。
基因表示

已经将角色的每一次的移动步骤转化为基因的表示,每次的移动总共有4种可能,上下左右,基因编码是标准的二进制形式,所以可以取值为00代表向上,01向下,10向左,11向右,也就是说,每个基因组用2个编码表示,所以总共的编码数字就是2*7=14个,两两一对。
适应度函数

适应度函数的设置应该是在遗传算法中最重要了吧,以为他的设置好坏直接决定着遗传质量的好坏,基因组表示的移动的操作步骤,给定起点位置,通过基因组的编码组数据,我们可以计算出最终的抵达坐标,这里可以很容易的得出结论,如果最后的抵达坐标越接近出口坐标,就越是我们想要的结果,也就是适应值越高,所以我们可以用下面的公式作为适应度函数:



(x, y)为计算出的适应值的函数值在0到1之间波动,1为最大值,就是抵达的坐标恰好是出口位置的时候,当然适应度函数的表示不是唯一的。
算法的代码实现

算法地图数据的输入mapData.txt:

0 0 0 0 0
2 0 0 -1 0
0 0 0 0 0
0 -1 0 0 -1
0 0 0 0 1

就是上面图示的那个例子.

算法的主要实现类GATool.java:

package GA_Maze;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.text.MessageFormat;
import java.util.ArrayList;
import java.util.Random;

/**
 * 遗传算法在走迷宫游戏的应用-遗传算法工具类
 *
 * @author lyq
 *
 */
public class GATool {
    // 迷宫出入口标记
    public static final int MAZE_ENTRANCE_POS = 1;
    public static final int MAZE_EXIT_POS = 2;
    // 方向对应的编码数组
    public static final int[][] MAZE_DIRECTION_CODE = new int[][] { { 0, 0 },
            { 0, 1 }, { 1, 0 }, { 1, 1 }, };
    // 坐标点方向改变
    public static final int[][] MAZE_DIRECTION_CHANGE = new int[][] {
            { -1, 0 }, { 1, 0 }, { 0, -1 }, { 0, 1 }, };
    // 方向的文字描述
    public static final String[] MAZE_DIRECTION_LABEL = new String[] { "上",
            "下", "左", "右" };

    // 地图数据文件地址
    private String filePath;
    // 走迷宫的最短步数
    private int stepNum;
    // 初始个体的数量
    private int initSetsNum;
    // 迷宫入口位置
    private int[] startPos;
    // 迷宫出口位置
    private int[] endPos;
    // 迷宫地图数据
    private int[][] mazeData;
    // 初始个体集
    private ArrayList<int[]> initSets;
    // 随机数产生器
    private Random random;

    public GATool(String filePath, int initSetsNum) {
        this.filePath = filePath;
        this.initSetsNum = initSetsNum;

        readDataFile();
    }

    /**
     * 从文件中读取数据
     */
    public void readDataFile() {
        File file = new File(filePath);
        ArrayList<String[]> dataArray = new ArrayList<String[]>();

        try {
            BufferedReader in = new BufferedReader(new FileReader(file));
            String str;
            String[] tempArray;
            while ((str = in.readLine()) != null) {
                tempArray = str.split(" ");
                dataArray.add(tempArray);
            }
            in.close();
        } catch (IOException e) {
            e.getStackTrace();
        }

        int rowNum = dataArray.size();
        mazeData = new int[rowNum][rowNum];
        for (int i = 0; i < rowNum; i++) {
            String[] data = dataArray.get(i);
            for (int j = 0; j < data.length; j++) {
                mazeData[i][j] = Integer.parseInt(data[j]);

                // 赋值入口和出口位置
                if (mazeData[i][j] == MAZE_ENTRANCE_POS) {
                    startPos = new int[2];
                    startPos[0] = i;
                    startPos[1] = j;
                } else if (mazeData[i][j] == MAZE_EXIT_POS) {
                    endPos = new int[2];
                    endPos[0] = i;
                    endPos[1] = j;
                }
            }
        }

        // 计算走出迷宫的最短步数
        stepNum = Math.abs(startPos[0] - endPos[0])
                + Math.abs(startPos[1] - endPos[1]);
    }

    /**
     * 产生初始数据集
     */
    private void produceInitSet() {
        // 方向编码
        int directionCode = 0;
        random = new Random();
        initSets = new ArrayList<>();
        // 每个步骤的操作需要用2位数字表示
        int[] codeNum;

        for (int i = 0; i < initSetsNum; i++) {
            codeNum = new int[stepNum * 2];
            for (int j = 0; j < stepNum; j++) {
                directionCode = random.nextInt(4);
                codeNum[2 * j] = MAZE_DIRECTION_CODE[directionCode][0];
                codeNum[2 * j + 1] = MAZE_DIRECTION_CODE[directionCode][1];
            }

            initSets.add(codeNum);
        }
    }

    /**
     * 选择操作,把适值较高的个体优先遗传到下一代
     *
     * @param initCodes
     *            初始个体编码
     * @return
     */
    private ArrayList<int[]> selectOperate(ArrayList<int[]> initCodes) {
        double randomNum = 0;
        double sumFitness = 0;
        ArrayList<int[]> resultCodes = new ArrayList<>();
        double[] adaptiveValue = new double[initSetsNum];

        for (int i = 0; i < initSetsNum; i++) {
            adaptiveValue[i] = calFitness(initCodes.get(i));
            sumFitness += adaptiveValue[i];
        }

        // 转成概率的形式,做归一化操作
        for (int i = 0; i < initSetsNum; i++) {
            adaptiveValue[i] = adaptiveValue[i] / sumFitness;
        }

        for (int i = 0; i < initSetsNum; i++) {
            randomNum = random.nextInt(100) + 1;
            randomNum = randomNum / 100;
            //因为1.0是无法判断到的,,总和会无限接近1.0取为0.99做判断
            if(randomNum == 1){
                randomNum = randomNum - 0.01;
            }
            
            sumFitness = 0;
            // 确定区间
            for (int j = 0; j < initSetsNum; j++) {
                if (randomNum > sumFitness
                        && randomNum <= sumFitness + adaptiveValue[j]) {
                    // 采用拷贝的方式避免引用重复
                    resultCodes.add(initCodes.get(j).clone());
                    break;
                } else {
                    sumFitness += adaptiveValue[j];
                }
            }
        }

        return resultCodes;
    }

    /**
     * 交叉运算
     *
     * @param selectedCodes
     *            上步骤的选择后的编码
     * @return
     */
    private ArrayList<int[]> crossOperate(ArrayList<int[]> selectedCodes) {
        int randomNum = 0;
        // 交叉点
        int crossPoint = 0;
        ArrayList<int[]> resultCodes = new ArrayList<>();
        // 随机编码队列,进行随机交叉配对
        ArrayList<int[]> randomCodeSeqs = new ArrayList<>();

        // 进行随机排序
        while (selectedCodes.size() > 0) {
            randomNum = random.nextInt(selectedCodes.size());

            randomCodeSeqs.add(selectedCodes.get(randomNum));
            selectedCodes.remove(randomNum);
        }

        int temp = 0;
        int[] array1;
        int[] array2;
        // 进行两两交叉运算
        for (int i = 1; i < randomCodeSeqs.size(); i++) {
            if (i % 2 == 1) {
                array1 = randomCodeSeqs.get(i - 1);
                array2 = randomCodeSeqs.get(i);
                crossPoint = random.nextInt(stepNum - 1) + 1;

                // 进行交叉点位置后的编码调换
                for (int j = 0; j < 2 * stepNum; j++) {
                    if (j >= 2 * crossPoint) {
                        temp = array1[j];
                        array1[j] = array2[j];
                        array2[j] = temp;
                    }
                }

                // 加入到交叉运算结果中
                resultCodes.add(array1);
                resultCodes.add(array2);
            }
        }

        return resultCodes;
    }

    /**
     * 变异操作
     *
     * @param crossCodes
     *            交叉运算后的结果
     * @return
     */
    private ArrayList<int[]> variationOperate(ArrayList<int[]> crossCodes) {
        // 变异点
        int variationPoint = 0;
        ArrayList<int[]> resultCodes = new ArrayList<>();

        for (int[] array : crossCodes) {
            variationPoint = random.nextInt(stepNum);

            for (int i = 0; i < array.length; i += 2) {
                // 变异点进行变异
                if (i % 2 == 0 && i / 2 == variationPoint) {
                    array[i] = (array[i] == 0 ? 1 : 0);
                    array[i + 1] = (array[i + 1] == 0 ? 1 : 0);
                    break;
                }
            }

            resultCodes.add(array);
        }

        return resultCodes;
    }

    /**
     * 根据编码计算适值
     *
     * @param code
     *            当前的编码
     * @return
     */
    public double calFitness(int[] code) {
        double fintness = 0;
        // 由编码计算所得的终点横坐标
        int endX = 0;
        // 由编码计算所得的终点纵坐标
        int endY = 0;
        // 基于片段所代表的行走方向
        int direction = 0;
        // 临时坐标点横坐标
        int tempX = 0;
        // 临时坐标点纵坐标
        int tempY = 0;

        endX = startPos[0];
        endY = startPos[1];
        for (int i = 0; i < stepNum; i++) {
            direction = binaryArrayToNum(new int[] { code[2 * i],
                    code[2 * i + 1] });

            // 根据方向改变数组做坐标点的改变
            tempX = endX + MAZE_DIRECTION_CHANGE[direction][0];
            tempY = endY + MAZE_DIRECTION_CHANGE[direction][1];

            // 判断坐标点是否越界
            if (tempX >= 0 && tempX < mazeData.length && tempY >= 0
                    && tempY < mazeData[0].length) {
                // 判断坐标点是否走到阻碍块
                if (mazeData[tempX][tempY] != -1) {
                    endX = tempX;
                    endY = tempY;
                }
            }
        }

        // 根据适值函数进行适值的计算
        fintness = 1.0 / (Math.abs(endX - endPos[0])
                + Math.abs(endY - endPos[1]) + 1);

        return fintness;
    }

    /**
     * 根据当前编码判断是否已经找到出口位置
     *
     * @param code
     *            经过若干次遗传的编码
     * @return
     */
    private boolean ifArriveEndPos(int[] code) {
        boolean isArrived = false;
        // 由编码计算所得的终点横坐标
        int endX = 0;
        // 由编码计算所得的终点纵坐标
        int endY = 0;
        // 基于片段所代表的行走方向
        int direction = 0;
        // 临时坐标点横坐标
        int tempX = 0;
        // 临时坐标点纵坐标
        int tempY = 0;

        endX = startPos[0];
        endY = startPos[1];
        for (int i = 0; i < stepNum; i++) {
            direction = binaryArrayToNum(new int[] { code[2 * i],
                    code[2 * i + 1] });

            // 根据方向改变数组做坐标点的改变
            tempX = endX + MAZE_DIRECTION_CHANGE[direction][0];
            tempY = endY + MAZE_DIRECTION_CHANGE[direction][1];

            // 判断坐标点是否越界
            if (tempX >= 0 && tempX < mazeData.length && tempY >= 0
                    && tempY < mazeData[0].length) {
                // 判断坐标点是否走到阻碍块
                if (mazeData[tempX][tempY] != -1) {
                    endX = tempX;
                    endY = tempY;
                }
            }
        }

        if (endX == endPos[0] && endY == endPos[1]) {
            isArrived = true;
        }

        return isArrived;
    }

    /**
     * 二进制数组转化为数字
     *
     * @param binaryArray
     *            待转化二进制数组
     */
    private int binaryArrayToNum(int[] binaryArray) {
        int result = 0;

        for (int i = binaryArray.length - 1, k = 0; i >= 0; i--, k++) {
            if (binaryArray[i] == 1) {
                result += Math.pow(2, k);
            }
        }

        return result;
    }

    /**
     * 进行遗传算法走出迷宫
     */
    public void goOutMaze() {
        // 迭代遗传次数
        int loopCount = 0;
        boolean canExit = false;
        // 结果路径
        int[] resultCode = null;
        ArrayList<int[]> initCodes;
        ArrayList<int[]> selectedCodes;
        ArrayList<int[]> crossedCodes;
        ArrayList<int[]> variationCodes;

        // 产生初始数据集
        produceInitSet();
        initCodes = initSets;

        while (true) {
            for (int[] array : initCodes) {
                // 遗传迭代的终止条件为是否找到出口位置
                if (ifArriveEndPos(array)) {
                    resultCode = array;
                    canExit = true;
                    break;
                }
            }

            if (canExit) {
                break;
            }

            selectedCodes = selectOperate(initCodes);
            crossedCodes = crossOperate(selectedCodes);
            variationCodes = variationOperate(crossedCodes);
            initCodes = variationCodes;

            loopCount++;
            
            //如果遗传次数超过100次,则退出
            if(loopCount >= 100){
                break;
            }
        }

        System.out.println("总共遗传进化了" + loopCount + "次");
        printFindedRoute(resultCode);
    }

    /**
     * 输出找到的路径
     *
     * @param code
     */
    private void printFindedRoute(int[] code) {
        if(code == null){
            System.out.println("在有限的遗传进化次数内,没有找到最优路径");
            return;
        }
        
        int tempX = startPos[0];
        int tempY = startPos[1];
        int direction = 0;

        System.out.println(MessageFormat.format(
                "起始点位置({0},{1}), 出口点位置({2}, {3})", tempX, tempY, endPos[0],
                endPos[1]));
        
        System.out.print("搜索到的结果编码:");
        for(int value: code){
            System.out.print("" + value);
        }
        System.out.println();
        
        for (int i = 0, k = 1; i < code.length; i += 2, k++) {
            direction = binaryArrayToNum(new int[] { code[i], code[i + 1] });

            tempX += MAZE_DIRECTION_CHANGE[direction][0];
            tempY += MAZE_DIRECTION_CHANGE[direction][1];

            System.out.println(MessageFormat.format(
                    "第{0}步,编码为{1}{2},向{3}移动,移动后到达({4},{5})", k, code[i], code[i+1],
                    MAZE_DIRECTION_LABEL[direction],  tempX, tempY));
        }
    }

}

算法的调用类Client.java:

package GA_Maze;

/**
 * 遗传算法在走迷宫游戏的应用
 * @author lyq
 *
 */
public class Client {
    public static void main(String[] args) {
        //迷宫地图文件数据地址
        String filePath = "C:\\Users\\lyq\\Desktop\\icon\\mapData.txt";
        //初始个体数量
        int initSetsNum = 4;
        
        GATool tool = new GATool(filePath, initSetsNum);
        tool.goOutMaze();
    }

}

算法的输出:

我测了很多次的数据,因为有可能会一时半会搜索不出来,我设置了最大遗传次数100次。

总共遗传进化了2次
起始点位置(4,4), 出口点位置(1, 0)
搜索到的结果编码:10100000100010
第1步,编码为10,向左移动,移动后到达(4,3)
第2步,编码为10,向左移动,移动后到达(4,2)
第3步,编码为00,向上移动,移动后到达(3,2)
第4步,编码为00,向上移动,移动后到达(2,2)
第5步,编码为10,向左移动,移动后到达(2,1)
第6步,编码为00,向上移动,移动后到达(1,1)
第7步,编码为10,向左移动,移动后到达(1,0)

总共遗传进化了8次
起始点位置(4,4), 出口点位置(1, 0)
搜索到的结果编码:10001000101000
第1步,编码为10,向左移动,移动后到达(4,3)
第2步,编码为00,向上移动,移动后到达(3,3)
第3步,编码为10,向左移动,移动后到达(3,2)
第4步,编码为00,向上移动,移动后到达(2,2)
第5步,编码为10,向左移动,移动后到达(2,1)
第6步,编码为10,向左移动,移动后到达(2,0)
第7步,编码为00,向上移动,移动后到达(1,0)


总共遗传进化了100次
在有限的遗传进化次数内,没有找到最优路径

算法小结

遗传算法在走迷宫中的应用总体而言还是非常有意思的如果你去认真的体会的话,至少让我更加深入的理解了GA算法,如果博友向要亲自实现这算法,我给几点建议,第一是迷宫难度的和初始个体数量的设置,为什么要注意这2点呢,一个是这关系到遗传迭代的次数,在一段时间内有的时候遗传算法是找不出来的,如果找不出来,PC机的CPU会持续高速的计算,所以不要让遗传进行无限制的进行,最好做点次数限制,也可能是我的本本配置太烂了。。在算法的调试中修复了一个之前没发现的bug,就是选择阶段的时候对于随机数的判断少考虑了一种情形,当随机数取到1.0的时候,其实是不能判断到的,因为概念和只会无限接近1,就不知道被划分到哪个区域中了。
作者:Androidlushangderen 发表于2015/3/26 21:56:15 原文链接
阅读:1128 评论:0 查看评论
Chameleon两阶段聚类算法
2015年3月23日 20:43

参考文献:http://www.cnblogs.com/zhangchaoyang/articles/2182752.html(用了很多的图和思想)
博客园(华夏35度) 作者:Orisun
数据挖掘算法-Chameleon算法.百度文库

我的算法库:https://github.com/linyiqun/lyq-algorithms-lib(里面可能有你正想要的算法)
算法介绍

本篇文章讲述的还是聚类算法,也是属于层次聚类算法领域的,不过与上篇文章讲述的分裂实现聚类的方式不同,这次所讲的Chameleon算法是合并形成最终的聚类,恰巧相反。Chamelon的英文单词的意思是变色龙,所以这个算法又称之为变色龙算法,变色龙算法的过程如标题所描绘的那样,是分为2个主要阶段的,不过他可不是像BIRCH算法那样,是树的形式。继续看下面的原理介绍。
算法原理

先来张图来大致了解整个算法的过程。


上面图的显示过程虽然说有3个阶段,但是这其中概况起来就是两个阶段,第一个是形成小簇集的过程就是从Data Set 到k最近邻图到分裂成小聚餐,第二个阶段是合并这些小聚簇形成最终的结果聚簇。理解了算法的大致过程,下面看看里面定义的一些概念,还不少的样子。

为了引出变色龙算法的一些定义,这里先说一下以往的一些聚类算法的不足之处。

1、忽略簇与簇之间的互连性。就会导致最终的结果形成如下:


2、忽略簇与簇之间的近似性。就会导致最终的聚类结果变成这样“:

为什么提这些呢,因为Chameleon算法正好弥补了这2点要求,兼具互连性和近似性。在Chameleon算法中定义了相对互连性,RI表示和相对近似性,RC表示,最后通过一个度量函数:

function value = RI( Ci, Cj)× RC( Ci, Cj)α,α在这里表示的多少次方的意思,不是乘法。

来作为2个簇是否能够合并的标准,其实这些都是第二阶段做的事情了。

在第一阶段,所做的一件关键的事情就是形成小簇集,由零星的几个数据点连成小簇,官方的作法是用hMetic算法根据最小化截断的边的权重和来分割k-最近邻图,然后我网上找了一些资料,没有确切的hMetic算法,借鉴了网上其他人的一些办法,于是用了一个很简单的思路,就是给定一个点,把他离他最近的k个点连接起来,就算是最小簇了。事实证明,效果也不会太差,最近的点的换一个意思就是与其最大权重的边,采用距离的倒数最为权重的大小。因为后面的计算,用到的会是权重而不是距离。

我们再回过头来细说第二阶段所做的事情,首先是2个略复杂的公式(直接采用截图的方式):

                                                                              相对互连性RI=

相对近似性RC=

Ci,Cj表示的是i,j聚簇内的数据点的个数,EC(Ci)表示的Ci聚簇内的边的权重和,EC(Ci,Cj)表示的是连接2个聚簇的边的权重和。

后来我在查阅书籍和一些文库的时候发现,这个公式还不是那么的标准,因为他对分母,分子进行了部分的改变,但是大意上还是一致的,标准公式上用到的是平均权重,而这里用的是和的形式,差别不大,所以就用这个公式了。

那么合并的过程如下:

1、给定度量函数如下minMetric,

2、访问每个簇,计算他与邻近的每个簇的RC和RI,通过度量函数公式计算出值tempMetric。

3、找到最大的tempMetric,如果最大的tempMetric超过阈值minMetric,将簇与此值对应的簇合并

4、如果找到的最大的tempMetric没有超过阈值,则表明此聚簇已合并完成,移除聚簇列表,加入到结果聚簇中。

4、递归步骤2,直到待合并聚簇列表最终大小为空。
算法的实现

算法的输入依旧采用的是坐标点的形式graphData.txt:

0 2 2
1 3 1
2 3 4
3 3 14
4 5 3
5 8 3
6 8 6
7 9 8
8 10 4
9 10 7
10 10 10
11 10 14
12 11 13
13 12 8
14 12 15
15 14 7
16 14 9
17 14 15
18 15 8

算法坐标点数据Point.java:

package DataMining_Chameleon;



/**
 * 坐标点类
 * @author lyq
 *
 */
public class Point{
    //坐标点id号,id号唯一
    int id;
    //坐标横坐标
    Integer x;
    //坐标纵坐标
    Integer y;
    //是否已经被访问过
    boolean isVisited;
    
    public Point(String id, String x, String y){
        this.id = Integer.parseInt(id);
        this.x = Integer.parseInt(x);
        this.y = Integer.parseInt(y);
    }
    
    /**
     * 计算当前点与制定点之间的欧式距离
     *
     * @param p
     *            待计算聚类的p点
     * @return
     */
    public double ouDistance(Point p) {
        double distance = 0;

        distance = (this.x - p.x) * (this.x - p.x) + (this.y - p.y)
                * (this.y - p.y);
        distance = Math.sqrt(distance);

        return distance;
    }
    
    /**
     * 判断2个坐标点是否为用个坐标点
     *
     * @param p
     *            待比较坐标点
     * @return
     */
    public boolean isTheSame(Point p) {
        boolean isSamed = false;

        if (this.x == p.x && this.y == p.y) {
            isSamed = true;
        }

        return isSamed;
    }
}


簇类Cluster.java:

package DataMining_Chameleon;

import java.util.ArrayList;

/**
 * 聚簇类
 *
 * @author lyq
 *
 */
public class Cluster implements Cloneable{
    //簇唯一id标识号
    int id;
    // 聚簇内的坐标点集合
    ArrayList<Point> points;
    // 聚簇内的所有边的权重和
    double weightSum = 0;

    public Cluster(int id, ArrayList<Point> points) {
        this.id = id;
        this.points = points;
    }

    /**
     * 计算聚簇的内部的边权重和
     *
     * @return
     */
    public double calEC() {
        int id1 = 0;
        int id2 = 0;
        weightSum = 0;
        
        for (Point p1 : points) {
            for (Point p2 : points) {
                id1 = p1.id;
                id2 = p2.id;

                // 为了避免重复计算,取id1小的对应大的
                if (id1 < id2 && ChameleonTool.edges[id1][id2] == 1) {
                    weightSum += ChameleonTool.weights[id1][id2];
                }
            }
        }

        return weightSum;
    }

    /**
     * 计算2个簇之间最近的n条边
     *
     * @param otherCluster
     *            待比较的簇
     * @param n
     *            最近的边的数目
     * @return
     */
    public ArrayList<int[]> calNearestEdge(Cluster otherCluster, int n){
        int count = 0;
        double distance = 0;
        double minDistance = Integer.MAX_VALUE;
        Point point1 = null;
        Point point2 = null;
        ArrayList<int[]> edgeList = new ArrayList<>();
        ArrayList<Point> pointList1 = (ArrayList<Point>) points.clone();
        ArrayList<Point> pointList2 = null;
        Cluster c2 = null;
        
        try {
            c2 = (Cluster) otherCluster.clone();
            pointList2 = c2.points;
        } catch (CloneNotSupportedException e) {
            // TODO Auto-generated catch block
            e.printStackTrace();
        }

        int[] tempEdge;
        // 循环计算出每次的最近距离
        while (count < n) {
            tempEdge = new int[2];
            minDistance = Integer.MAX_VALUE;
            
            for (Point p1 : pointList1) {
                for (Point p2 :  pointList2) {
                    distance = p1.ouDistance(p2);
                    if (distance < minDistance) {
                        point1 = p1;
                        point2 = p2;
                        tempEdge[0] = p1.id;
                        tempEdge[1] = p2.id;

                        minDistance = distance;
                    }
                }
            }

            pointList1.remove(point1);
            pointList2.remove(point2);
            edgeList.add(tempEdge);
            count++;
        }

        return edgeList;
    }

    @Override
    protected Object clone() throws CloneNotSupportedException {
        // TODO Auto-generated method stub
        
        //引用需要再次复制,实现深拷贝
        ArrayList<Point> pointList = (ArrayList<Point>) this.points.clone();
        Cluster cluster = new Cluster(id, pointList);
        
        return cluster;
    }
    
    

}


算法工具类Chameleon.java:

package DataMining_Chameleon;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.text.MessageFormat;
import java.util.ArrayList;

/**
 * Chameleon 两阶段聚类算法工具类
 *
 * @author lyq
 *
 */
public class ChameleonTool {
    // 测试数据点文件地址
    private String filePath;
    // 第一阶段的k近邻的k大小
    private int k;
    // 簇度量函数阈值
    private double minMetric;
    // 总的坐标点的个数
    private int pointNum;
    // 总的连接矩阵的情况,括号表示的是坐标点的id号
    public static int[][] edges;
    // 点与点之间的边的权重
    public static double[][] weights;
    // 原始坐标点数据
    private ArrayList<Point> totalPoints;
    // 第一阶段产生的所有的连通子图作为最初始的聚类
    private ArrayList<Cluster> initClusters;
    // 结果簇结合
    private ArrayList<Cluster> resultClusters;

    public ChameleonTool(String filePath, int k, double minMetric) {
        this.filePath = filePath;
        this.k = k;
        this.minMetric = minMetric;

        readDataFile();
    }

    /**
     * 从文件中读取数据
     */
    private void readDataFile() {
        File file = new File(filePath);
        ArrayList<String[]> dataArray = new ArrayList<String[]>();

        try {
            BufferedReader in = new BufferedReader(new FileReader(file));
            String str;
            String[] tempArray;
            while ((str = in.readLine()) != null) {
                tempArray = str.split(" ");
                dataArray.add(tempArray);
            }
            in.close();
        } catch (IOException e) {
            e.getStackTrace();
        }

        Point p;
        totalPoints = new ArrayList<>();
        for (String[] array : dataArray) {
            p = new Point(array[0], array[1], array[2]);
            totalPoints.add(p);
        }
        pointNum = totalPoints.size();
    }

    /**
     * 递归的合并小聚簇
     */
    private void combineSubClusters() {
        Cluster cluster = null;

        resultClusters = new ArrayList<>();

        // 当最后的聚簇只剩下一个的时候,则退出循环
        while (initClusters.size() > 1) {
            cluster = initClusters.get(0);
            combineAndRemove(cluster, initClusters);
        }
    }

    /**
     * 递归的合并聚簇和移除聚簇
     *
     * @param clusterList
     */
    private ArrayList<Cluster> combineAndRemove(Cluster cluster,
            ArrayList<Cluster> clusterList) {
        ArrayList<Cluster> remainClusters;
        double metric = 0;
        double maxMetric = -Integer.MAX_VALUE;
        Cluster cluster1 = null;
        Cluster cluster2 = null;

        for (Cluster c2 : clusterList) {
            if (cluster.id == c2.id) {
                continue;
            }

            metric = calMetricfunction(cluster, c2, 1);

            if (metric > maxMetric) {
                maxMetric = metric;
                cluster1 = cluster;
                cluster2 = c2;
            }
        }

        // 如果度量函数值超过阈值,则进行合并,继续搜寻可以合并的簇
        if (maxMetric > minMetric) {
            clusterList.remove(cluster2);
            // 将边进行连接
            connectClusterToCluster(cluster1, cluster2);
            // 将簇1和簇2合并
            cluster1.points.addAll(cluster2.points);
            remainClusters = combineAndRemove(cluster1, clusterList);
        } else {
            clusterList.remove(cluster);
            remainClusters = clusterList;
            resultClusters.add(cluster);
        }

        return remainClusters;
    }

    /**
     * 将2个簇进行边的连接
     *
     * @param c1
     *            聚簇1
     * @param c2
     *            聚簇2
     */
    private void connectClusterToCluster(Cluster c1, Cluster c2) {
        ArrayList<int[]> connectedEdges;

        connectedEdges = c1.calNearestEdge(c2, 2);

        for (int[] array : connectedEdges) {
            edges[array[0]][array[1]] = 1;
            edges[array[1]][array[0]] = 1;
        }
    }

    /**
     * 算法第一阶段形成局部的连通图
     */
    private void connectedGraph() {
        double distance = 0;
        Point p1;
        Point p2;

        // 初始化权重矩阵和连接矩阵
        weights = new double[pointNum][pointNum];
        edges = new int[pointNum][pointNum];
        for (int i = 0; i < pointNum; i++) {
            for (int j = 0; j < pointNum; j++) {
                p1 = totalPoints.get(i);
                p2 = totalPoints.get(j);

                distance = p1.ouDistance(p2);
                if (distance == 0) {
                    // 如果点为自身的话,则权重设置为0
                    weights[i][j] = 0;
                } else {
                    // 边的权重采用的值为距离的倒数,距离越近,权重越大
                    weights[i][j] = 1.0 / distance;
                }
            }
        }

        double[] tempWeight;
        int[] ids;
        int id1 = 0;
        int id2 = 0;
        // 对每个id坐标点,取其权重前k个最大的点进行相连
        for (int i = 0; i < pointNum; i++) {
            tempWeight = weights[i];
            // 进行排序
            ids = sortWeightArray(tempWeight);

            // 取出前k个权重最大的边进行连接
            for (int j = 0; j < ids.length; j++) {
                if (j < k) {
                    id1 = i;
                    id2 = ids[j];

                    edges[id1][id2] = 1;
                    edges[id2][id1] = 1;
                }
            }
        }
    }

    /**
     * 权重的冒泡算法排序
     *
     * @param array
     *            待排序数组
     */
    private int[] sortWeightArray(double[] array) {
        double[] copyArray = array.clone();
        int[] ids = null;
        int k = 0;
        double maxWeight = -1;

        ids = new int[pointNum];
        for (int i = 0; i < pointNum; i++) {
            maxWeight = -1;

            for (int j = 0; j < copyArray.length; j++) {
                if (copyArray[j] > maxWeight) {
                    maxWeight = copyArray[j];
                    k = j;
                }
            }

            ids[i] = k;
            // 将当前找到的最大的值重置为-1代表已经找到过了
            copyArray[k] = -1;
        }

        return ids;
    }

    /**
     * 根据边的连通性去深度优先搜索所有的小聚簇
     */
    private void searchSmallCluster() {
        int currentId = 0;
        Point p;
        Cluster cluster;
        initClusters = new ArrayList<>();
        ArrayList<Point> pointList = null;

        // 以id的方式逐个去dfs搜索
        for (int i = 0; i < pointNum; i++) {
            p = totalPoints.get(i);

            if (p.isVisited) {
                continue;
            }

            pointList = new ArrayList<>();
            pointList.add(p);
            recusiveDfsSearch(p, -1, pointList);

            cluster = new Cluster(currentId, pointList);
            initClusters.add(cluster);

            currentId++;
        }
    }

    /**
     * 深度优先的方式找到边所连接着的所有坐标点
     *
     * @param p
     *            当前搜索的起点
     * @param lastId
     *            此点的父坐标点
     * @param pList
     *            坐标点列表
     */
    private void recusiveDfsSearch(Point p, int parentId, ArrayList<Point> pList) {
        int id1 = 0;
        int id2 = 0;
        Point newPoint;

        if (p.isVisited) {
            return;
        }

        p.isVisited = true;
        for (int j = 0; j < pointNum; j++) {
            id1 = p.id;
            id2 = j;

            if (edges[id1][id2] == 1 && id2 != parentId) {
                newPoint = totalPoints.get(j);
                pList.add(newPoint);
                // 以此点为起点,继续递归搜索
                recusiveDfsSearch(newPoint, id1, pList);
            }
        }
    }

    /**
     * 计算连接2个簇的边的权重
     *
     * @param c1
     *            聚簇1
     * @param c2
     *            聚簇2
     * @return
     */
    private double calEC(Cluster c1, Cluster c2) {
        double resultEC = 0;
        ArrayList<int[]> connectedEdges = null;

        connectedEdges = c1.calNearestEdge(c2, 2);

        // 计算连接2部分的边的权重和
        for (int[] array : connectedEdges) {
            resultEC += weights[array[0]][array[1]];
        }

        return resultEC;
    }

    /**
     * 计算2个簇的相对互连性
     *
     * @param c1
     * @param c2
     * @return
     */
    private double calRI(Cluster c1, Cluster c2) {
        double RI = 0;
        double EC1 = 0;
        double EC2 = 0;
        double EC1To2 = 0;

        EC1 = c1.calEC();
        EC2 = c2.calEC();
        EC1To2 = calEC(c1, c2);

        RI = 2 * EC1To2 / (EC1 + EC2);

        return RI;
    }

    /**
     * 计算簇的相对近似度
     *
     * @param c1
     *            簇1
     * @param c2
     *            簇2
     * @return
     */
    private double calRC(Cluster c1, Cluster c2) {
        double RC = 0;
        double EC1 = 0;
        double EC2 = 0;
        double EC1To2 = 0;
        int pNum1 = c1.points.size();
        int pNum2 = c2.points.size();

        EC1 = c1.calEC();
        EC2 = c2.calEC();
        EC1To2 = calEC(c1, c2);

        RC = EC1To2 * (pNum1 + pNum2) / (pNum2 * EC1 + pNum1 * EC2);

        return RC;
    }

    /**
     * 计算度量函数的值
     *
     * @param c1
     *            簇1
     * @param c2
     *            簇2
     * @param alpha
     *            幂的参数值
     * @return
     */
    private double calMetricfunction(Cluster c1, Cluster c2, int alpha) {
        // 度量函数值
        double metricValue = 0;
        double RI = 0;
        double RC = 0;

        RI = calRI(c1, c2);
        RC = calRC(c1, c2);
        // 如果alpha大于1,则更重视相对近似性,如果alpha逍遥于1,注重相对互连性
        metricValue = RI * Math.pow(RC, alpha);

        return metricValue;
    }

    /**
     * 输出聚簇列
     *
     * @param clusterList
     *            输出聚簇列
     */
    private void printClusters(ArrayList<Cluster> clusterList) {
        int i = 1;

        for (Cluster cluster : clusterList) {
            System.out.print("聚簇" + i + ":");
            for (Point p : cluster.points) {
                System.out.print(MessageFormat.format("({0}, {1}) ", p.x, p.y));
            }
            System.out.println();
            i++;
        }

    }

    /**
     * 创建聚簇
     */
    public void buildCluster() {
        // 第一阶段形成小聚簇
        connectedGraph();
        searchSmallCluster();
        System.out.println("第一阶段形成的小簇集合:");
        printClusters(initClusters);

        // 第二阶段根据RI和RC的值合并小聚簇形成最终结果聚簇
        combineSubClusters();
        System.out.println("最终的聚簇集合:");
        printClusters(resultClusters);
    }
}


调用类Client.java:

package DataMining_Chameleon;

/**
 * Chameleon(变色龙)两阶段聚类算法
 * @author lyq
 *
 */
public class Client {
    public static void main(String[] args){
        String filePath = "C:\\Users\\lyq\\Desktop\\icon\\graphData.txt";
        //k-近邻的k设置
        int k = 1;
        //度量函数阈值
        double minMetric = 0.1;
        
        ChameleonTool tool = new ChameleonTool(filePath, k, minMetric);
        tool.buildCluster();
    }
}

算法输出如下:

第一阶段形成的小簇集合:
聚簇1:(2, 2) (3, 1) (3, 4) (5, 3)
聚簇2:(3, 14) (10, 14) (11, 13)
聚簇3:(8, 3) (10, 4)
聚簇4:(8, 6) (9, 8) (10, 7) (12, 8) (10, 10)
聚簇5:(12, 15) (14, 15)
聚簇6:(14, 7) (15, 8) (14, 9)
最终的聚簇集合:
聚簇1:(2, 2) (3, 1) (3, 4) (5, 3) (8, 3) (10, 4)
聚簇2:(3, 14) (10, 14) (11, 13) (12, 15) (14, 15)
聚簇3:(8, 6) (9, 8) (10, 7) (12, 8) (10, 10) (14, 7) (15, 8) (14, 9)

图形展示情况如下:

首先是第一阶段形成小簇集的结果:


然后是第二阶段合并的结果:


与结果相对应,请读者细细比较。
算法总结

在算法的实现过程中遇到一个比较大的困惑点在于2个簇近和并的时候,合并边的选取,我是直接采用的是最近的2对顶点进行连接,显然这是不合理的,当簇与簇规模比较大的时候,这个连接边需要变多,我有想过做一个计算函数,帮我计算估计要连接几条边。这里再提几点变色龙算法的优缺点,首先是这个算法将互连性和近似性都考虑了进来,其次他能发现高质量的任意形状的簇,问题有,第一与KNN算法一样,这个k的取值永远是一个痛,时间复杂度高,有可能会达到O(n*n)的程度,细心的博友一定能观察到我好多地方用到了双次循环的操作了。
作者:Androidlushangderen 发表于2015/3/23 20:43:37 原文链接
阅读:829 评论:0 查看评论
dbscan基于密度的空间聚类算法
2015年3月16日 20:24

参考文献:百度百科 http://baike.baidu.com

我的算法库:https://github.com/linyiqun/lyq-algorithms-lib
算法介绍

说到聚类算法,大家如果有看过我写的一些关于机器学习的算法文章,一定都这类算法不会陌生,之前将的是划分算法(K均值算法)和层次聚类算法(BIRCH算法),各有优缺点和好坏。本文所述的算法是另外一类的聚类算法,他能够克服BIRCH算法对于形状的限制,因为BIRCH算法偏向于聚簇球形的聚类形成,而dbscan采用的是基于空间的密度的原理,所以可以适用于任何形状的数据聚类实现。
算法原理

在介绍算法原理之前,先介绍几个dbscan算法中的几个概念定义:

Ε领域:给定对象半径为Ε内的区域称为该对象的Ε领域;
核心对象:如果给定对象Ε领域内的样本点数大于等于MinPts,则称该对象为核心对象;
直接密度可达:对于样本集合D,如果样本点q在p的Ε领域内,并且p为核心对象,那么对象q从对象p直接密度可达。
密度可达:对于样本集合D,给定一串样本点p1,p2….pn,p= p1,q= pn,假如对象pi从pi-1直接密度可达,那么对象q从对象p密度可达。
密度相连:存在样本集合D中的一点o,如果对象o到对象p和对象q都是密度可达的,那么p和q密度相联。

下面是算法的过程(可能说的不是很清楚):

1、扫描原始数据,获取所有的数据点。

2、遍历数据点中的每个点,如果此点已经被访问(处理)过,则跳过,否则取出此点做聚类查找。

3、以步骤2中找到的点P为核心对象,找出在E领域内所有满足条件的点,如果个数大于等于MinPts,则此点为核心对象,加入到簇中。

4、再次P为核心对象的簇中的每个点,进行递归的扩增簇。如果P点的递归扩增结束,再次回到步骤2。

5、算法的终止条件为所有的点都被访问(处理过)。

算法可以理解为是一个DFS的深度优先扩展。
算法的实现

算法的输入Input(格式(x, y)):

2 2
3 1
3 4
3 14
5 3
8 3
8 6
9 8
10 4
10 7
10 10
10 14
11 13
12 8
12 15
14 7
14 9
14 15
15 8

坐标点类Point.java:

package DataMining_DBSCAN;

/**
 * 坐标点类
 *
 * @author lyq
 *
 */
public class Point {
    // 坐标点横坐标
    int x;
    // 坐标点纵坐标
    int y;
    // 此节点是否已经被访问过
    boolean isVisited;

    public Point(String x, String y) {
        this.x = (Integer.parseInt(x));
        this.y = (Integer.parseInt(y));
        this.isVisited = false;
    }

    /**
     * 计算当前点与制定点之间的欧式距离
     *
     * @param p
     *            待计算聚类的p点
     * @return
     */
    public double ouDistance(Point p) {
        double distance = 0;

        distance = (this.x - p.x) * (this.x - p.x) + (this.y - p.y)
                * (this.y - p.y);
        distance = Math.sqrt(distance);

        return distance;
    }

    /**
     * 判断2个坐标点是否为用个坐标点
     *
     * @param p
     *            待比较坐标点
     * @return
     */
    public boolean isTheSame(Point p) {
        boolean isSamed = false;

        if (this.x == p.x && this.y == p.y) {
            isSamed = true;
        }

        return isSamed;
    }
}

算法工具类DNSCANTool.java:

package DataMining_DBSCAN;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.text.MessageFormat;
import java.util.ArrayList;

/**
 * DBSCAN基于密度聚类算法工具类
 *
 * @author lyq
 *
 */
public class DBSCANTool {
    // 测试数据文件地址
    private String filePath;
    // 簇扫描半径
    private double eps;
    // 最小包含点数阈值
    private int minPts;
    // 所有的数据坐标点
    private ArrayList<Point> totalPoints;
    // 聚簇结果
    private ArrayList<ArrayList<Point>> resultClusters;
    //噪声数据
    private ArrayList<Point> noisePoint;

    public DBSCANTool(String filePath, double eps, int minPts) {
        this.filePath = filePath;
        this.eps = eps;
        this.minPts = minPts;
        readDataFile();
    }

    /**
     * 从文件中读取数据
     */
    public void readDataFile() {
        File file = new File(filePath);
        ArrayList<String[]> dataArray = new ArrayList<String[]>();

        try {
            BufferedReader in = new BufferedReader(new FileReader(file));
            String str;
            String[] tempArray;
            while ((str = in.readLine()) != null) {
                tempArray = str.split(" ");
                dataArray.add(tempArray);
            }
            in.close();
        } catch (IOException e) {
            e.getStackTrace();
        }

        Point p;
        totalPoints = new ArrayList<>();
        for (String[] array : dataArray) {
            p = new Point(array[0], array[1]);
            totalPoints.add(p);
        }
    }

    /**
     * 递归的寻找聚簇
     *
     * @param pointList
     *            当前的点列表
     * @param parentCluster
     *            父聚簇
     */
    private void recursiveCluster(Point point, ArrayList<Point> parentCluster) {
        double distance = 0;
        ArrayList<Point> cluster;

        // 如果已经访问过了,则跳过
        if (point.isVisited) {
            return;
        }

        point.isVisited = true;
        cluster = new ArrayList<>();
        for (Point p2 : totalPoints) {
            // 过滤掉自身的坐标点
            if (point.isTheSame(p2)) {
                continue;
            }

            distance = point.ouDistance(p2);
            if (distance <= eps) {
                // 如果聚类小于给定的半径,则加入簇中
                cluster.add(p2);
            }
        }

        if (cluster.size() >= minPts) {
            // 将自己也加入到聚簇中
            cluster.add(point);
            // 如果附近的节点个数超过最下值,则加入到父聚簇中,同时去除重复的点
            addCluster(parentCluster, cluster);

            for (Point p : cluster) {
                recursiveCluster(p, parentCluster);
            }
        }
    }

    /**
     * 往父聚簇中添加局部簇坐标点
     *
     * @param parentCluster
     *            原始父聚簇坐标点
     * @param cluster
     *            待合并的聚簇
     */
    private void addCluster(ArrayList<Point> parentCluster,
            ArrayList<Point> cluster) {
        boolean isCotained = false;
        ArrayList<Point> addPoints = new ArrayList<>();

        for (Point p : cluster) {
            isCotained = false;
            for (Point p2 : parentCluster) {
                if (p.isTheSame(p2)) {
                    isCotained = true;
                    break;
                }
            }

            if (!isCotained) {
                addPoints.add(p);
            }
        }

        parentCluster.addAll(addPoints);
    }

    /**
     * dbScan算法基于密度的聚类
     */
    public void dbScanCluster() {
        ArrayList<Point> cluster = null;
        resultClusters = new ArrayList<>();
        noisePoint = new ArrayList<>();
        
        for (Point p : totalPoints) {
            if(p.isVisited){
                continue;
            }
            
            cluster = new ArrayList<>();
            recursiveCluster(p, cluster);

            if (cluster.size() > 0) {
                resultClusters.add(cluster);
            }else{
                noisePoint.add(p);
            }
        }
        removeFalseNoise();
        
        printClusters();
    }
    
    /**
     * 移除被错误分类的噪声点数据
     */
    private void removeFalseNoise(){
        ArrayList<Point> totalCluster = new ArrayList<>();
        ArrayList<Point> deletePoints = new ArrayList<>();
        
        //将聚簇合并
        for(ArrayList<Point> list: resultClusters){
            totalCluster.addAll(list);
        }
        
        for(Point p: noisePoint){
            for(Point p2: totalCluster){
                if(p2.isTheSame(p)){
                    deletePoints.add(p);
                }
            }
        }
        
        noisePoint.removeAll(deletePoints);
    }

    /**
     * 输出聚类结果
     */
    private void printClusters() {
        int i = 1;
        for (ArrayList<Point> pList : resultClusters) {
            System.out.print("聚簇" + (i++) + ":");
            for (Point p : pList) {
                System.out.print(MessageFormat.format("({0},{1}) ", p.x, p.y));
            }
            System.out.println();
        }
        
        System.out.println();
        System.out.print("噪声数据:");
        for (Point p : noisePoint) {
            System.out.print(MessageFormat.format("({0},{1}) ", p.x, p.y));
        }
        System.out.println();
    }
}

测试类Client.java:

package DataMining_DBSCAN;

/**
 * Dbscan基于密度的聚类算法测试类
 * @author lyq
 *
 */
public class Client {
    public static void main(String[] args){
        String filePath = "C:\\Users\\lyq\\Desktop\\icon\\input.txt";
        //簇扫描半径
        double eps = 3;
        //最小包含点数阈值
        int minPts = 3;
        
        DBSCANTool tool = new DBSCANTool(filePath, eps, minPts);
        tool.dbScanCluster();
    }
}

算法的输出:

聚簇1:(2,2) (3,4) (5,3) (3,1) (8,3) (8,6) (10,4) (9,8) (10,7) (10,10) (12,8) (14,7) (14,9) (15,8)
聚簇2:(10,14) (11,13) (14,15) (12,15)

噪声数据:(3,14)

图示结果如下:

算法的缺点

dbscan虽说可以用于任何形状的聚类发现,但是对于密度分布不均衡的数据,变化比较大,分类的性能就不会特别好,还有1点是不能反映高尺寸数据。
作者:Androidlushangderen 发表于2015/3/16 20:24:44 原文链接
阅读:762 评论:2 查看评论
Genetic Algorithm遗传算法学习
2015年3月3日 18:34

参考资料:http://blog.csdn.net/b2b160/article/details/4680853/#comments(冒昧的用了链接下的几张图)
百度百科:http://baike.baidu.com/link?url=FcwTBx_yPcD5DDEnN1FqvTkG4QNllkB7Yis6qFOL65wpn6EdT5LXFxUCmv4JlUfV3LUPHQGdYbGj8kHVs3GuaK
算法介绍

遗传算法是模拟达尔文生物进化论的自然选择和遗传学进化机理的计算模型。运用到了生物学中“适者生存,优胜劣汰”的原理。在每一次的进化过程中,把拥有更好环境适应性的基因传给下一代,直到最后的个体满足特定的条件,代表进化的结束,GA(后面都以GA代称为遗传算法的意思)算法是一种利用生物进化理论来搜索最优解的一种算法。
算法原理
算法的基本框架

了解算法的基本框架是理解整个算法的基础,算法的框架包含编码、适应度函数、初始群体的选择。先假设本例子的目标函数如下,求出他的最大值

f(x) = x1 * x1 + x2 * x2; 1<= x1 <=7, 1<= x2<=7

1、适应度函数。适应度函数是用来计算个体的适应值计算,顾名思义,适应值越高的个体,环境适应性越好,自然就要有更多的机会把自己的基因传给下一代,所以,其实适应度函数在这里起着一个过滤条件的作用。在本例子,目标函数总为非负值,并且以函数最大化为优化目标,所以可以将函数的值作为适应值。

2、编码。编码指的是将个体的信息表示成编码的字符串形式。如果个体变量时数字的形式,可以转为二进制的方式。
算法的群体选择过程

这个过程是遗传算法的核心过程,在里面分为了3个小的步骤,选择,交叉,变异。

1、初始个体的选择过程。就是挑选哪些个体作为即将产生下一代的个体呢。过程如下:

(1).利用适值函数,计算每个个体的适值,计算每个个体的适值占总和的百分比。

(2).根据百分比为每个个体划定一定的所属区间。

(3).产生一个[0, 1]的小数,判断这个小数点落在哪个个体的区间内,就表明要选出这个个体。这里其实就已经蕴含着把高适值的个体优先传入下一代,因为适值高,有更高的几率小数是落在自己的区间内的。

用图示范的形式表现如下:


2、交叉运算。个体的交叉运算过程的步骤细节如下:

(1).首先对于上个选择步骤选择来的个体进行随机的两两配对。

(2).取出其中配对的一对个体,随机设定一个交叉点,2个个体的编码的交叉点后的编码值进行对调,生成新的2个个体编码。

(3).所有的配对的个体都执行步骤(2)操作,最后加入到一个结果集合中。

交叉运算的方式又很多,上面用的方法是其中比较常用的单点交叉方式。

用图示范的形式表现如下:


3.变异运算。变异运算过程的步骤细节如下:

(1).遍历从交叉运算所得结果的结果集,取出集中一个个体编码,准备做变异操作

(2).产生随机的一个变异点位置。所选个体的变异点位置的值做变异操作,将他的值取为反向的值。

(3).将所有的交叉运算所得的结果集中的元素都执行步骤(2)操作。

用图示范的形式如下:


整个遗传算法的原理过程,用一个流程图的表现形式如下:


算法代码实现

算法代码的测试用例正如算法原理所举的一样,遗传进化的阈值条件为:个体中产生了使目标函数最大化值的个体,就是基因为111111。

GATool.java:

package GA;

import java.util.ArrayList;
import java.util.Random;

/**
 * 遗传算法工具类
 *
 * @author lyq
 *
 */
public class GATool {
    // 变量最小值
    private int minNum;
    // 变量最大值
    private int maxNum;
    // 单个变量的编码位数
    private int codeNum;
    // 初始种群的数量
    private int initSetsNum;
    // 随机数生成器
    private Random random;
    // 初始群体
    private ArrayList<int[]> initSets;

    public GATool(int minNum, int maxNum, int initSetsNum) {
        this.minNum = minNum;
        this.maxNum = maxNum;
        this.initSetsNum = initSetsNum;

        this.random = new Random();
        produceInitSets();
    }

    /**
     * 产生初始化群体
     */
    private void produceInitSets() {
        this.codeNum = 0;
        int num = maxNum;
        int[] array;

        initSets = new ArrayList<>();

        // 确定编码位数
        while (num != 0) {
            codeNum++;
            num /= 2;
        }

        for (int i = 0; i < initSetsNum; i++) {
            array = produceInitCode();
            initSets.add(array);
        }
    }

    /**
     * 产生初始个体的编码
     *
     * @return
     */
    private int[] produceInitCode() {
        int num = 0;
        int num2 = 0;
        int[] tempArray;
        int[] array1;
        int[] array2;

        tempArray = new int[2 * codeNum];
        array1 = new int[codeNum];
        array2 = new int[codeNum];

        num = 0;
        while (num < minNum || num > maxNum) {
            num = random.nextInt(maxNum) + 1;
        }
        numToBinaryArray(array1, num);

        while (num2 < minNum || num2 > maxNum) {
            num2 = random.nextInt(maxNum) + 1;
        }
        numToBinaryArray(array2, num2);

        // 组成总的编码
        for (int i = 0, k = 0; i < tempArray.length; i++, k++) {
            if (k < codeNum) {
                tempArray[i] = array1[k];
            } else {
                tempArray[i] = array2[k - codeNum];
            }
        }

        return tempArray;
    }

    /**
     * 选择操作,把适值较高的个体优先遗传到下一代
     *
     * @param initCodes
     *            初始个体编码
     * @return
     */
    private ArrayList<int[]> selectOperate(ArrayList<int[]> initCodes) {
        double randomNum = 0;
        double sumAdaptiveValue = 0;
        ArrayList<int[]> resultCodes = new ArrayList<>();
        double[] adaptiveValue = new double[initSetsNum];

        for (int i = 0; i < initSetsNum; i++) {
            adaptiveValue[i] = calCodeAdaptiveValue(initCodes.get(i));
            sumAdaptiveValue += adaptiveValue[i];
        }

        // 转成概率的形式,做归一化操作
        for (int i = 0; i < initSetsNum; i++) {
            adaptiveValue[i] = adaptiveValue[i] / sumAdaptiveValue;
        }

        for (int i = 0; i < initSetsNum; i++) {
            randomNum = random.nextInt(100) + 1;
            randomNum = randomNum / 100;

            sumAdaptiveValue = 0;
            // 确定区间
            for (int j = 0; j < initSetsNum; j++) {
                if (randomNum > sumAdaptiveValue
                        && randomNum <= sumAdaptiveValue + adaptiveValue[j]) {
                    //采用拷贝的方式避免引用重复
                    resultCodes.add(initCodes.get(j).clone());
                    break;
                } else {
                    sumAdaptiveValue += adaptiveValue[j];
                }
            }
        }

        return resultCodes;
    }

    /**
     * 交叉运算
     *
     * @param selectedCodes
     *            上步骤的选择后的编码
     * @return
     */
    private ArrayList<int[]> crossOperate(ArrayList<int[]> selectedCodes) {
        int randomNum = 0;
        // 交叉点
        int crossPoint = 0;
        ArrayList<int[]> resultCodes = new ArrayList<>();
        // 随机编码队列,进行随机交叉配对
        ArrayList<int[]> randomCodeSeqs = new ArrayList<>();

        // 进行随机排序
        while (selectedCodes.size() > 0) {
            randomNum = random.nextInt(selectedCodes.size());

            randomCodeSeqs.add(selectedCodes.get(randomNum));
            selectedCodes.remove(randomNum);
        }

        int temp = 0;
        int[] array1;
        int[] array2;
        // 进行两两交叉运算
        for (int i = 1; i < randomCodeSeqs.size(); i++) {
            if (i % 2 == 1) {
                array1 = randomCodeSeqs.get(i - 1);
                array2 = randomCodeSeqs.get(i);
                crossPoint = random.nextInt(2 * codeNum - 1) + 1;

                // 进行交叉点位置后的编码调换
                for (int j = 0; j < 2 * codeNum; j++) {
                    if (j >= crossPoint) {
                        temp = array1[j];
                        array1[j] = array2[j];
                        array2[j] = temp;
                    }
                }

                // 加入到交叉运算结果中
                resultCodes.add(array1);
                resultCodes.add(array2);
            }
        }

        return resultCodes;
    }

    /**
     * 变异操作
     *
     * @param crossCodes
     *            交叉运算后的结果
     * @return
     */
    private ArrayList<int[]> variationOperate(ArrayList<int[]> crossCodes) {
        // 变异点
        int variationPoint = 0;
        ArrayList<int[]> resultCodes = new ArrayList<>();

        for (int[] array : crossCodes) {
            variationPoint = random.nextInt(codeNum * 2);

            for (int i = 0; i < array.length; i++) {
                // 变异点进行变异
                if (i == variationPoint) {
                    array[i] = (array[i] == 0 ? 1 : 0);
                    break;
                }
            }

            resultCodes.add(array);
        }

        return resultCodes;
    }

    /**
     * 数字转为二进制形式
     *
     * @param binaryArray
     *            转化后的二进制数组形式
     * @param num
     *            待转化数字
     */
    private void numToBinaryArray(int[] binaryArray, int num) {
        int index = 0;
        int temp = 0;
        while (num != 0) {
            binaryArray[index] = num % 2;
            index++;
            num /= 2;
        }
        
        //进行数组前和尾部的调换
        for(int i=0; i<binaryArray.length/2; i++){
            temp = binaryArray[i];
            binaryArray[i] = binaryArray[binaryArray.length - 1 - i];
            binaryArray[binaryArray.length - 1 - i] = temp;
        }
    }

    /**
     * 二进制数组转化为数字
     *
     * @param binaryArray
     *            待转化二进制数组
     */
    private int binaryArrayToNum(int[] binaryArray) {
        int result = 0;

        for (int i = binaryArray.length-1, k=0; i >=0 ; i--, k++) {
            if (binaryArray[i] == 1) {
                result += Math.pow(2, k);
            }
        }

        return result;
    }

    /**
     * 计算个体编码的适值
     *
     * @param codeArray
     */
    private int calCodeAdaptiveValue(int[] codeArray) {
        int result = 0;
        int x1 = 0;
        int x2 = 0;
        int[] array1 = new int[codeNum];
        int[] array2 = new int[codeNum];

        for (int i = 0, k = 0; i < codeArray.length; i++, k++) {
            if (k < codeNum) {
                array1[k] = codeArray[i];
            } else {
                array2[k - codeNum] = codeArray[i];
            }
        }

        // 进行适值的叠加
        x1 = binaryArrayToNum(array1);
        x2 = binaryArrayToNum(array2);
        result = x1 * x1 + x2 * x2;

        return result;
    }

    /**
     * 进行遗传算法计算
     */
    public void geneticCal() {
        // 最大适值
        int maxFitness;
        //迭代遗传次数
        int loopCount = 0;
        boolean canExit = false;
        ArrayList<int[]> initCodes;
        ArrayList<int[]> selectedCodes;
        ArrayList<int[]> crossedCodes;
        ArrayList<int[]> variationCodes;
        
        int[] maxCode = new int[2*codeNum];
        //计算最大适值
        for(int i=0; i<2*codeNum; i++){
            maxCode[i] = 1;
        }
        maxFitness = calCodeAdaptiveValue(maxCode);

        initCodes = initSets;
        while (true) {
            for (int[] array : initCodes) {
                // 遗传迭代的终止条件为存在编码达到最大适值
                if (maxFitness == calCodeAdaptiveValue(array)) {
                    canExit = true;
                    break;
                }
            }

            if (canExit) {
                break;
            }

            selectedCodes = selectOperate(initCodes);
            crossedCodes = crossOperate(selectedCodes);
            variationCodes = variationOperate(crossedCodes);
            initCodes = variationCodes;
            
            loopCount++;
        }

        System.out.println("总共遗传进化了" + loopCount +"次" );
        printFinalCodes(initCodes);
    }

    /**
     * 输出最后的编码集
     *
     * @param finalCodes
     *            最后的结果编码
     */
    private void printFinalCodes(ArrayList<int[]> finalCodes) {
        int j = 0;

        for (int[] array : finalCodes) {
            System.out.print("个体" + (j + 1) + ":");
            for (int i = 0; i < array.length; i++) {
                System.out.print(array[i]);
            }
            System.out.println();
            j++;
        }
    }

}

算法调用类Client.java:

package GA;

/**
 * Genetic遗传算法测试类
 * @author lyq
 *
 */
public class Client {
    public static void main(String[] args){
        //变量最小值和最大值
        int minNum = 1;
        int maxNum = 7;
        //初始群体规模
        int initSetsNum = 4;
        
        GATool tool = new GATool(minNum, maxNum, initSetsNum);
        tool.geneticCal();
    }
}

算法多次测试的输出结果:

测试1:

总共遗传进化了0次
个体1:111001
个体2:101010
个体3:101110
个体4:111111

测试2:

总共遗传进化了1次
个体1:101101
个体2:111111
个体3:100111
个体4:100111

测试3:

总共遗传进化了14次
个体1:110101
个体2:111111
个体3:110101
个体4:110011

算法结果分析
可以看到,遗传进化的循环次数还是存在着不确定定性的,原因在于测试的个体数太少,如果个体数比较多的话,几轮就可以出现111111这样的个体编码组了。从结果可以看出,总的还是能够向1多的方向发展的。

说说我对遗传算法的理解

通过实现了GA算法,觉得这有点集成算法的味道,因为这其实用到了跨学科的知识,用生物进化理论的知识,去作为一个搜索最优解的解决方案,而且算法本身理解和实现也不是特别的难。
作者:Androidlushangderen 发表于2015/3/3 18:34:34 原文链接
阅读:654 评论:0 查看评论
18大经典数据挖掘算法小结
2015年2月27日 10:04

本文所有涉及到的数据挖掘代码的都放在了我的github上了。

地址链接: https://github.com/linyiqun/DataMiningAlgorithm

大概花了将近2个月的时间,自己把18大数据挖掘的经典算法进行了学习并且进行了代码实现,涉及到了决策分类,聚类,链接挖掘,关联挖掘,模式挖掘等等方面。也算是对数据挖掘领域的小小入门了吧。下面就做个小小的总结,后面都是我自己相应算法的博文链接,希望能够帮助大家学习。

1.C4.5算法。C4.5算法与ID3算法一样,都是数学分类算法,C4.5算法是ID3算法的一个改进。ID3算法采用信息增益进行决策判断,而C4.5采用的是增益率。

详细介绍链接:http://blog.csdn.net/androidlushangderen/article/details/42395865

2.CART算法。CART算法的全称是分类回归树算法,他是一个二元分类,采用的是类似于熵的基尼指数作为分类决策,形成决策树后之后还要进行剪枝,我自己在实现整个算法的时候采用的是代价复杂度算法,

详细介绍链接:http://blog.csdn.net/androidlushangderen/article/details/42558235

3.KNN(K最近邻)算法。给定一些已经训练好的数据,输入一个新的测试数据点,计算包含于此测试数据点的最近的点的分类情况,哪个分类的类型占多数,则此测试点的分类与此相同,所以在这里,有的时候可以复制不同的分类点不同的权重。近的点的权重大点,远的点自然就小点。

详细介绍链接:http://blog.csdn.net/androidlushangderen/article/details/42613011

4.Naive Bayes(朴素贝叶斯)算法。朴素贝叶斯算法是贝叶斯算法里面一种比较简单的分类算法,用到了一个比较重要的贝叶斯定理,用一句简单的话概括就是条件概率的相互转换推导。

详细介绍链接:http://blog.csdn.net/androidlushangderen/article/details/42680161

5.SVM(支持向量机)算法。支持向量机算法是一种对线性和非线性数据进行分类的方法,非线性数据进行分类的时候可以通过核函数转为线性的情况再处理。其中的一个关键的步骤是搜索最大边缘超平面。

详细介绍链接:http://blog.csdn.net/androidlushangderen/article/details/42780439

6.EM(期望最大化)算法。期望最大化算法,可以拆分为2个算法,1个E-Step期望化步骤,和1个M-Step最大化步骤。他是一种算法框架,在每次计算结果之后,逼近统计模型参数的最大似然或最大后验估计。

详细介绍链接:http://blog.csdn.net/androidlushangderen/article/details/42921789

7.Apriori算法。Apriori算法是关联规则挖掘算法,通过连接和剪枝运算挖掘出频繁项集,然后根据频繁项集得到关联规则,关联规则的导出需要满足最小置信度的要求。

详细介绍链接:http://blog.csdn.net/androidlushangderen/article/details/43059211

8.FP-Tree(频繁模式树)算法。这个算法也有被称为FP-growth算法,这个算法克服了Apriori算法的产生过多侯选集的缺点,通过递归的产生频度模式树,然后对树进行挖掘,后面的过程与Apriori算法一致。

详细介绍链接:http://blog.csdn.net/androidlushangderen/article/details/43234309

9.PageRank(网页重要性/排名)算法。PageRank算法最早产生于Google,核心思想是通过网页的入链数作为一个网页好快的判定标准,如果1个网页内部包含了多个指向外部的链接,则PR值将会被均分,PageRank算法也会遭到Link Span攻击。

详细介绍链接:http://blog.csdn.net/androidlushangderen/article/details/43311943

10.HITS算法。HITS算法是另外一个链接算法,部分原理与PageRank算法是比较相似的,HITS算法引入了权威值和中心值的概念,HITS算法是受用户查询条件影响的,他一般用于小规模的数据链接分析,也更容易遭受到攻击。

详细介绍链接:http://blog.csdn.net/androidlushangderen/article/details/43311943

11.K-Means(K均值)算法。K-Means算法是聚类算法,k在在这里指的是分类的类型数,所以在开始设定的时候非常关键,算法的原理是首先假定k个分类点,然后根据欧式距离计算分类,然后去同分类的均值作为新的聚簇中心,循环操作直到收敛。

详细介绍链接:http://blog.csdn.net/androidlushangderen/article/details/43373159

12.BIRCH算法。BIRCH算法利用构建CF聚类特征树作为算法的核心,通过树的形式,BIRCH算法扫描数据库,在内存中建立一棵初始的CF-树,可以看做数据的多层压缩。

详细介绍链接:http://blog.csdn.net/androidlushangderen/article/details/43532111

13.AdaBoost算法。AdaBoost算法是一种提升算法,通过对数据的多次训练得到多个互补的分类器,然后组合多个分类器,构成一个更加准确的分类器。

详细介绍链接:http://blog.csdn.net/androidlushangderen/article/details/43635115

14.GSP算法。GSP算法是序列模式挖掘算法。GSP算法也是Apriori类算法,在算法的过程中也会进行连接和剪枝操作,不过在剪枝判断的时候还加上了一些时间上的约束等条件。

详细介绍链接:http://blog.csdn.net/androidlushangderen/article/details/43699083

15.PreFixSpan算法。PreFixSpan算法是另一个序列模式挖掘算法,在算法的过程中不会产生候选集,给定初始前缀模式,不断的通过后缀模式中的元素转到前缀模式中,而不断的递归挖掘下去。

详细介绍链接:http://blog.csdn.net/androidlushangderen/article/details/43766253

16.CBA(基于关联规则分类)算法。CBA算法是一种集成挖掘算法,因为他是建立在关联规则挖掘算法之上的,在已有的关联规则理论前提下,做分类判断,只是在算法的开始时对数据做处理,变成类似于事务的形式。

详细介绍链接:http://blog.csdn.net/androidlushangderen/article/details/43818787

17.RoughSets(粗糙集)算法。粗糙集理论是一个比较新颖的数据挖掘思想。这里使用的是用粗糙集进行属性约简的算法,通过上下近似集的判断删除无效的属性,进行规制的输出。

详细介绍链接:http://blog.csdn.net/androidlushangderen/article/details/43876001

18.gSpan算法。gSpan算法属于图挖掘算法领域。,主要用于频繁子图的挖掘,相较于其他的图算法,子图挖掘算法是他们的一个前提或基础算法。gSpan算法用到了DFS编码,和Edge五元组,最右路径子图扩展等概念,算法比较的抽象和复杂。

详细介绍链接:http://blog.csdn.net/androidlushangderen/article/details/43924273
作者:Androidlushangderen 发表于2015/2/27 10:04:01 原文链接
阅读:4367 评论:6 查看评论
gSpan频繁子图挖掘算法
2015年2月24日 9:37

参考资料:http://www.cs.ucsb.edu/~xyan/papers/gSpan.pdf
http://www.cs.ucsb.edu/~xyan/papers/gSpan-short.pdf
http://www.jos.org.cn/1000-9825/18/2469.pdf

http://blog.csdn.net/coolypf/article/details/8263176
更多挖掘算法:https://github.com/linyiqun/DataMiningAlgorithm
介绍

gSpan算法是图挖掘邻域的一个算法,而作为子图挖掘算法,又是其他图挖掘算法的基础,所以gSpan算法在图挖掘算法中还是非常重要的。gSpan算法在挖掘频繁子图的时候,用了和FP-grown中相似的原理,就是Pattern-Grown模式增长的方式,也用到了最小支持度计数作为一个过滤条件。图算法在程序上比其他的算法更加的抽象,在实现时更加需要空间想象能力。gSpan算法的核心就是给定n个图,然后从中挖掘出频繁出现的子图部分。
算法原理

说实话,gSpan算法在我最近学习的算法之中属于非常难的那种,因为要想实现他,必须要明白他的原理,而这就要花很多时间去明白算法的一些定义,比如dfs编码,最右路径这样的概念。所以,我们应该先知道算法整体的一个结构。

1、遍历所有的图,计算出所有的边和点的频度。

2、将频度与最小支持度数做比较,移除不频繁的边和点。

3、重新将剩下的点和边按照频度进行排序,将他们的排名号给边和点进行重新标号。

4、再次计算每条边的频度,计算完后,然后初始化每条边,并且进行此边的subMining()挖掘过程。
subMining的过程

1、根据graphCode重新恢复当前的子图

2、判断当前的编码是否为最小dfs编码,如果是加入到结果集中,继续在此基础上尝试添加可能的边,进行继续挖掘

3、如果不是最小编码,则此子图的挖掘过程结束。
DFS编码

gSpan算法对图的边进行编码,采用E(v0,v1,A,B,a)的方式,v0,v1代表的标识,你可以看做就是点的id,A,B可以作为点的标号,a为之间的边的标号,而一个图就是由这样的边构成的,G{e1, e2, e3,.....},而dfs编码的方式就是比里面的五元组的元素,我这里采用的规则是,从左往右依次比较大小,如果谁先小于另一方,谁就算小,图的比较算法同样如此,具体的规则可以见我后面代码中的注释。但是这个规则并不是完全一致的,至少在我看的相关论文中有不一样的描述存在。
生成subGraph

生成子图的进行下一次挖掘的过程也是gSpan算法中的一个难点,首先你要对原图进行编码,找到与挖掘子图一致的编码,找到之后,在图的最右路径上寻找可以扩展的边,在最右路径上扩展的情况分为2种,1种为在最右节点上进行扩展,1种为在最右路径的点上进行扩展。2种情况都需要做一定的判断。
算法的技巧

算法在实现时,用的技巧比较多,有些也很不好理解,比如在dfs编码或找子边的过程中,用到了图id对于Edge中的五元组id的映射,这个会一开始没想到,还有怎么去描述一个图通过一定的数据结构。
算法的实现

此算法是借鉴了网上其他版本的实现,我是在看懂了人家代码的基础上,自己对其中的某些部分作了修改之后的。由于代码比较多,下面给出核心代码,全部代码在这里。

GSpanTool.java:

package DataMining_GSpan;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.text.MessageFormat;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;

/**
 * gSpan频繁子图挖掘算法工具类
 *
 * @author lyq
 *
 */
public class GSpanTool {
    // 文件数据类型
    public final String INPUT_NEW_GRAPH = "t";
    public final String INPUT_VERTICE = "v";
    public final String INPUT_EDGE = "e";
    // Label标号的最大数量,包括点标号和边标号
    public final int LABEL_MAX = 100;

    // 测试数据文件地址
    private String filePath;
    // 最小支持度率
    private double minSupportRate;
    // 最小支持度数,通过图总数与最小支持度率的乘积计算所得
    private int minSupportCount;
    // 初始所有图的数据
    private ArrayList<GraphData> totalGraphDatas;
    // 所有的图结构数据
    private ArrayList<Graph> totalGraphs;
    // 挖掘出的频繁子图
    private ArrayList<Graph> resultGraphs;
    // 边的频度统计
    private EdgeFrequency ef;
    // 节点的频度
    private int[] freqNodeLabel;
    // 边的频度
    private int[] freqEdgeLabel;
    // 重新标号之后的点的标号数
    private int newNodeLabelNum = 0;
    // 重新标号后的边的标号数
    private int newEdgeLabelNum = 0;

    public GSpanTool(String filePath, double minSupportRate) {
        this.filePath = filePath;
        this.minSupportRate = minSupportRate;
        readDataFile();
    }

    /**
     * 从文件中读取数据
     */
    private void readDataFile() {
        File file = new File(filePath);
        ArrayList<String[]> dataArray = new ArrayList<String[]>();

        try {
            BufferedReader in = new BufferedReader(new FileReader(file));
            String str;
            String[] tempArray;
            while ((str = in.readLine()) != null) {
                tempArray = str.split(" ");
                dataArray.add(tempArray);
            }
            in.close();
        } catch (IOException e) {
            e.getStackTrace();
        }

        calFrequentAndRemove(dataArray);
    }

    /**
     * 统计边和点的频度,并移除不频繁的点边,以标号作为统计的变量
     *
     * @param dataArray
     *            原始数据
     */
    private void calFrequentAndRemove(ArrayList<String[]> dataArray) {
        int tempCount = 0;
        freqNodeLabel = new int[LABEL_MAX];
        freqEdgeLabel = new int[LABEL_MAX];

        // 做初始化操作
        for (int i = 0; i < LABEL_MAX; i++) {
            // 代表标号为i的节点目前的数量为0
            freqNodeLabel[i] = 0;
            freqEdgeLabel[i] = 0;
        }

        GraphData gd = null;
        totalGraphDatas = new ArrayList<>();
        for (String[] array : dataArray) {
            if (array[0].equals(INPUT_NEW_GRAPH)) {
                if (gd != null) {
                    totalGraphDatas.add(gd);
                }

                // 新建图
                gd = new GraphData();
            } else if (array[0].equals(INPUT_VERTICE)) {
                // 每个图中的每种图只统计一次
                if (!gd.getNodeLabels().contains(Integer.parseInt(array[2]))) {
                    tempCount = freqNodeLabel[Integer.parseInt(array[2])];
                    tempCount++;
                    freqNodeLabel[Integer.parseInt(array[2])] = tempCount;
                }

                gd.getNodeLabels().add(Integer.parseInt(array[2]));
                gd.getNodeVisibles().add(true);
            } else if (array[0].equals(INPUT_EDGE)) {
                // 每个图中的每种图只统计一次
                if (!gd.getEdgeLabels().contains(Integer.parseInt(array[3]))) {
                    tempCount = freqEdgeLabel[Integer.parseInt(array[3])];
                    tempCount++;
                    freqEdgeLabel[Integer.parseInt(array[3])] = tempCount;
                }

                int i = Integer.parseInt(array[1]);
                int j = Integer.parseInt(array[2]);

                gd.getEdgeLabels().add(Integer.parseInt(array[3]));
                gd.getEdgeX().add(i);
                gd.getEdgeY().add(j);
                gd.getEdgeVisibles().add(true);
            }
        }
        // 把最后一块gd数据加入
        totalGraphDatas.add(gd);
        minSupportCount = (int) (minSupportRate * totalGraphDatas.size());

        for (GraphData g : totalGraphDatas) {
            g.removeInFreqNodeAndEdge(freqNodeLabel, freqEdgeLabel,
                    minSupportCount);
        }
    }

    /**
     * 根据标号频繁度进行排序并且重新标号
     */
    private void sortAndReLabel() {
        int label1 = 0;
        int label2 = 0;
        int temp = 0;
        // 点排序名次
        int[] rankNodeLabels = new int[LABEL_MAX];
        // 边排序名次
        int[] rankEdgeLabels = new int[LABEL_MAX];
        // 标号对应排名
        int[] nodeLabel2Rank = new int[LABEL_MAX];
        int[] edgeLabel2Rank = new int[LABEL_MAX];

        for (int i = 0; i < LABEL_MAX; i++) {
            // 表示排名第i位的标号为i,[i]中的i表示排名
            rankNodeLabels[i] = i;
            rankEdgeLabels[i] = i;
        }

        for (int i = 0; i < freqNodeLabel.length - 1; i++) {
            int k = 0;
            label1 = rankNodeLabels[i];
            temp = label1;
            for (int j = i + 1; j < freqNodeLabel.length; j++) {
                label2 = rankNodeLabels[j];

                if (freqNodeLabel[temp] < freqNodeLabel[label2]) {
                    // 进行标号的互换
                    temp = label2;
                    k = j;
                }
            }

            if (temp != label1) {
                // 进行i,k排名下的标号对调
                temp = rankNodeLabels[k];
                rankNodeLabels[k] = rankNodeLabels[i];
                rankNodeLabels[i] = temp;
            }
        }

        // 对边同样进行排序
        for (int i = 0; i < freqEdgeLabel.length - 1; i++) {
            int k = 0;
            label1 = rankEdgeLabels[i];
            temp = label1;
            for (int j = i + 1; j < freqEdgeLabel.length; j++) {
                label2 = rankEdgeLabels[j];

                if (freqEdgeLabel[temp] < freqEdgeLabel[label2]) {
                    // 进行标号的互换
                    temp = label2;
                    k = j;
                }
            }

            if (temp != label1) {
                // 进行i,k排名下的标号对调
                temp = rankEdgeLabels[k];
                rankEdgeLabels[k] = rankEdgeLabels[i];
                rankEdgeLabels[i] = temp;
            }
        }

        // 将排名对标号转为标号对排名
        for (int i = 0; i < rankNodeLabels.length; i++) {
            nodeLabel2Rank[rankNodeLabels[i]] = i;
        }

        for (int i = 0; i < rankEdgeLabels.length; i++) {
            edgeLabel2Rank[rankEdgeLabels[i]] = i;
        }

        for (GraphData gd : totalGraphDatas) {
            gd.reLabelByRank(nodeLabel2Rank, edgeLabel2Rank);
        }

        // 根据排名找出小于支持度值的最大排名值
        for (int i = 0; i < rankNodeLabels.length; i++) {
            if (freqNodeLabel[rankNodeLabels[i]] > minSupportCount) {
                newNodeLabelNum = i;
            }
        }
        for (int i = 0; i < rankEdgeLabels.length; i++) {
            if (freqEdgeLabel[rankEdgeLabels[i]] > minSupportCount) {
                newEdgeLabelNum = i;
            }
        }
        //排名号比数量少1,所以要加回来
        newNodeLabelNum++;
        newEdgeLabelNum++;
    }

    /**
     * 进行频繁子图的挖掘
     */
    public void freqGraphMining() {
        long startTime =  System.currentTimeMillis();
        long endTime = 0;
        Graph g;
        sortAndReLabel();

        resultGraphs = new ArrayList<>();
        totalGraphs = new ArrayList<>();
        // 通过图数据构造图结构
        for (GraphData gd : totalGraphDatas) {
            g = new Graph();
            g = g.constructGraph(gd);
            totalGraphs.add(g);
        }

        // 根据新的点边的标号数初始化边频繁度对象
        ef = new EdgeFrequency(newNodeLabelNum, newEdgeLabelNum);
        for (int i = 0; i < newNodeLabelNum; i++) {
            for (int j = 0; j < newEdgeLabelNum; j++) {
                for (int k = 0; k < newNodeLabelNum; k++) {
                    for (Graph tempG : totalGraphs) {
                        if (tempG.hasEdge(i, j, k)) {
                            ef.edgeFreqCount[i][j][k]++;
                        }
                    }
                }
            }
        }

        Edge edge;
        GraphCode gc;
        for (int i = 0; i < newNodeLabelNum; i++) {
            for (int j = 0; j < newEdgeLabelNum; j++) {
                for (int k = 0; k < newNodeLabelNum; k++) {
                    if (ef.edgeFreqCount[i][j][k] >= minSupportCount) {
                        gc = new GraphCode();
                        edge = new Edge(0, 1, i, j, k);
                        gc.getEdgeSeq().add(edge);

                        // 将含有此边的图id加入到gc中
                        for (int y = 0; y < totalGraphs.size(); y++) {
                            if (totalGraphs.get(y).hasEdge(i, j, k)) {
                                gc.getGs().add(y);
                            }
                        }
                        // 对某条满足阈值的边进行挖掘
                        subMining(gc, 2);
                    }
                }
            }
        }
        
        endTime = System.currentTimeMillis();
        System.out.println("算法执行时间"+ (endTime-startTime) + "ms");
        printResultGraphInfo();
    }

    /**
     * 进行频繁子图的挖掘
     *
     * @param gc
     *            图编码
     * @param next
     *            图所含的点的个数
     */
    public void subMining(GraphCode gc, int next) {
        Edge e;
        Graph graph = new Graph();
        int id1;
        int id2;

        for(int i=0; i<next; i++){
            graph.nodeLabels.add(-1);
            graph.edgeLabels.add(new ArrayList<Integer>());
            graph.edgeNexts.add(new ArrayList<Integer>());
        }

        // 首先根据图编码中的边五元组构造图
        for (int i = 0; i < gc.getEdgeSeq().size(); i++) {
            e = gc.getEdgeSeq().get(i);
            id1 = e.ix;
            id2 = e.iy;

            graph.nodeLabels.set(id1, e.x);
            graph.nodeLabels.set(id2, e.y);
            graph.edgeLabels.get(id1).add(e.a);
            graph.edgeLabels.get(id2).add(e.a);
            graph.edgeNexts.get(id1).add(id2);
            graph.edgeNexts.get(id2).add(id1);
        }

        DFSCodeTraveler dTraveler = new DFSCodeTraveler(gc.getEdgeSeq(), graph);
        dTraveler.traveler();
        if (!dTraveler.isMin) {
            return;
        }

        // 如果当前是最小编码则将此图加入到结果集中
        resultGraphs.add(graph);
        Edge e1;
        ArrayList<Integer> gIds;
        SubChildTraveler sct;
        ArrayList<Edge> edgeArray;
        // 添加潜在的孩子边,每条孩子边所属的图id
        HashMap<Edge, ArrayList<Integer>> edge2GId = new HashMap<>();
        for (int i = 0; i < gc.gs.size(); i++) {
            int id = gc.gs.get(i);

            // 在此结构的条件下,在多加一条边构成子图继续挖掘
            sct = new SubChildTraveler(gc.edgeSeq, totalGraphs.get(id));
            sct.traveler();
            edgeArray = sct.getResultChildEdge();

            // 做边id的更新
            for (Edge e2 : edgeArray) {
                if (!edge2GId.containsKey(e2)) {
                    gIds = new ArrayList<>();
                } else {
                    gIds = edge2GId.get(e2);
                }

                gIds.add(id);
                edge2GId.put(e2, gIds);
            }
        }

        for (Map.Entry entry : edge2GId.entrySet()) {
            e1 = (Edge) entry.getKey();
            gIds = (ArrayList<Integer>) entry.getValue();

            // 如果此边的频度大于最小支持度值,则继续挖掘
            if (gIds.size() < minSupportCount) {
                continue;
            }

            GraphCode nGc = new GraphCode();
            nGc.edgeSeq.addAll(gc.edgeSeq);
            // 在当前图中新加入一条边,构成新的子图进行挖掘
            nGc.edgeSeq.add(e1);
            nGc.gs.addAll(gIds);

            if (e1.iy == next) {
                // 如果边的点id设置是为当前最大值的时候,则开始寻找下一个点
                subMining(nGc, next + 1);
            } else {
                // 如果此点已经存在,则next值不变
                subMining(nGc, next);
            }
        }
    }
    
    /**
     * 输出频繁子图结果信息
     */
    public void printResultGraphInfo(){
        System.out.println(MessageFormat.format("挖掘出的频繁子图的个数为:{0}个", resultGraphs.size()));
    }

}

这个算法在后来的实现时,渐渐的发现此算法的难度大大超出我预先的设想,不仅仅是其中的抽象性,还在于测试的复杂性,对于测试数据的捏造,如果用的是真实数据测的话,数据量太大,自己造数据拿捏的也不是很准确。我最后也只是自己伪造了一个图的数据,挖掘了其中的一条边的情况。大致的走了一个过程。代码并不算是完整的,仅供学习。

算法的缺点

在后来实现完算法之后,我对于其中的小的过程进行了分析,发现这个算法在2个深度优先遍历的过程中还存在问题,就是DFS判断是否最小编码和对原图进行寻找相应编码,的时候,都只是限于Edge中边是连续的情况,如果不连续了,会出现判断出错的情况,因为在最右路径上添加边,就是会出现在前面的点中多扩展一条边,就不会是连续的。而在上面的代码中是无法处理这样的情况的,个人的解决办法是用栈的方式,将节点压入栈中实现最好。
算法的体会

这个算法花了很多的时间,关关理解这个算法就已经不容易了,经常需要我在脑海中去刻画这样的图形和遍历的一些情况,带给我的挑战还是非常的大吧。
算法的特点

此算法与FP-Tree算法类似,在挖掘的过程中也是没有产生候选集的,采用深度优先的挖掘方式,一步一步进行挖掘。gSpan算法可以进行对于化学分子的结构挖掘。

作者:Androidlushangderen 发表于2015/2/24 9:37:11 原文链接
阅读:2688 评论:20 查看评论
RoughSets属性约简算法
2015年2月18日 9:24
参考资料:http://baike.baidu.com/link?url=vlCBGoGR0_97l9SQ-WNeRv7oWb-3j7c6oUnyMzQAU3PTo0fx0O5MVXxckgqUlP871xR2Le-puGfFcrA4-zIntq
更多挖掘算法:https://github.com/linyiqun/DataMiningAlgorithm
介绍

RoughSets算法是一种比较新颖的算法,粗糙集理论对于数据的挖掘方面提供了一个新的概念和研究方法。本篇文章我不会去介绍令人厌烦的学术概念,就是简单的聊聊RoughSets算法的作用,直观上做一个了解。此算法的应用场景是,面对一个庞大的数据库系统,如何从里面分析出有效的信息,如果一database中有几十个字段,有我们好受的了,但是一般的在某些情况下有些信息在某些情况下是无用的或者说是无效的,这时候我们假设在不影响最终决策分类结果的情况下,对此属性进行约简。这就是RoughSets所干的事情了。
算法原理

算法的原理其实很简单,所有属性分为2种属性1类为条件属性,1类为决策属性,我们姑且把决策属性设置在数据列的最后一列,算法的步骤依次判断条件属性是否能被约简,如果能被约简,此输出约简属性后的规则,规则的形式大体类似于IF---THEN的规则。下面举1个例子,此例子来自于百度百科上的粗糙集理论。

给定8条记录:

元素 颜色 形状 大小 稳定性
x1 红 三角 大 稳定
x2 红 三角 大 稳定
x3 黄 圆 小 不稳定
x4 黄 圆 小 不稳定
x5 蓝 方块 大 稳定
x6 红 圆 中 不稳定
x7 蓝 圆 小 不稳定
x8 蓝 方块 中 不稳定

在这里还是得介绍几个最基本的一些概念,这里的所有的记录的集合叫做论域,那么这个论域能表达出一些什么知识或者信息呢,比如说蓝色的或者中的积木={X5,X7,X8}U{X6,X8}={X5,X6,X7,X8},同理,通过论域集合内的记录进行交并运算能够表达出不同的信息。在这里总共有3个属性,就可以分成3x3=9个小属性分类,如下:
A/R1={X1,X2,X3}={{x1,x2,x6},{x3,x4},{x5,x7,x8}} (颜色分类)
A/R2={Y1,Y2,Y3}={{x1,x2},{x5,x8},{x3,x4,x6,x7}} (形状分类)
A/R3={Z1,Z2,Z3}={{x1,x2,x5},{x6,x8},{x3,x4,x7}} (大小分类)
我们定义一个知识系统A/R=R1∩R2∩R3,就是3x3x3总共27种可能,每行各取1个做计算组后的结果为

A/R={{x1,x2},{x3,x4},{x5},{x6},{x7},{x8}},所以这个知识系统所决定的知识就是A/R中所有的集合以此这些集合的并集。给定一个集合如何用知识系统中的集合进行表示呢,这就用到了又一对概念,上近似和下近似。比如说给定集合X={X2,X5X7},在知识库中就是下近似{X2.X5},上近似{X1,X2,X5,X7},上下近似的完整定义是下近似集是在那些所有的包含于X的知识库中的集合中求交得到的,而上近似则是将那些包含X的知识库中的集合求并得到的。在后面的例子中我也是以一个集合的上下近似集是否是等于他自身来对知识系统是否是允许的做一个判断。(这只是我自己的判断原则,并不是标准的)

下面是属性约简的过程,从颜色开始,这时知识系统变为了那么知识系统变成A/(R-R1)={{x1,x2},{x3,x4,x7},,,}以及这些子集的并集,此时稳定的集合{X1,X2,X5}的集合上下近似集还是他本身,所有没有改变,说明此属性是可以约简的,然后再此基础上在约简,直到上下近似集的改变。依次3种属性进行遍历。最后得到规则,我们以约简颜色属性为例,我们可以得出的规则是大三角的稳定,圆小的不稳定等等。大体原理就是如此,也许从某些方面来说还有欠妥的地方。
算法的代码实现

同样以上面的数据未例子,不过我把他转成了英文的形式,避免中文的编码问题:

Element Color Shape Size Stability
x1 Red Triangle Large Stable
x2 Red Triangle Large Stable
x3 Yellow Circle Small UnStable
x4 Yellow Circle Small UnStable
x5 Blue Rectangle Large Stable
x6 Red Circle Middle UnStable
x7 Blue Circle Small UnStable
x8 Blue Rectangle Middle UnStable

程序写的会有些复杂,里面很多都是集合的交并运算,之所以不采用直接的数组的运算,是为了更加突出集合的概念。
Record.java:

package DataMining_RoughSets;

import java.text.MessageFormat;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;

/**
 * 数据记录,包含这条记录所有属性
 *
 * @author lyq
 *
 */
public class Record {
    // 记录名称
    private String name;
    // 记录属性键值对
    private HashMap<String, String> attrValues;

    public Record(String name, HashMap<String, String> attrValues) {
        this.name = name;
        this.attrValues = attrValues;
    }

    public String getName() {
        return this.name;
    }

    /**
     * 此数据是否包含此属性值
     *
     * @param attr
     *            待判断属性值
     * @return
     */
    public boolean isContainedAttr(String attr) {
        boolean isContained = false;

        if (attrValues.containsValue(attr)) {
            isContained = true;
        }

        return isContained;
    }

    /**
     * 判断数据记录是否是同一条记录,根据数据名称来判断
     *
     * @param record
     *            目标比较对象
     * @return
     */
    public boolean isRecordSame(Record record) {
        boolean isSame = false;

        if (this.name.equals(record.name)) {
            isSame = true;
        }

        return isSame;
    }

    /**
     * 数据的决策属性分类
     *
     * @return
     */
    public String getRecordDecisionClass() {
        String value = null;

        value = attrValues.get(RoughSetsTool.DECISION_ATTR_NAME);

        return value;
    }

    /**
     * 根据约简属性输出决策规则
     *
     * @param reductAttr
     *            约简属性集合
     */
    public String getDecisionRule(ArrayList<String> reductAttr) {
        String ruleStr = "";
        String attrName = null;
        String value = null;
        String decisionValue;

        decisionValue = attrValues.get(RoughSetsTool.DECISION_ATTR_NAME);
        ruleStr += "属性";
        for (Map.Entry entry : this.attrValues.entrySet()) {
            attrName = (String) entry.getKey();
            value = (String) entry.getValue();

            if (attrName.equals(RoughSetsTool.DECISION_ATTR_NAME)
                    || reductAttr.contains(attrName) || value.equals(name)) {
                continue;
            }

            ruleStr += MessageFormat.format("{0}={1},", attrName, value);
        }
        ruleStr += "他的分类为" + decisionValue;
        
        return ruleStr;
    }
}

RecordCollection.java:

package DataMining_RoughSets;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;

/**
 * 数据记录集合,包含一些共同的属性
 *
 * @author lyq
 *
 */
public class RecordCollection {
    // 集合包含的属性
    private HashMap<String, String> attrValues;
    // 数据记录列表
    private ArrayList<Record> recordList;

    public RecordCollection() {
        this.attrValues = new HashMap<>();
        this.recordList = new ArrayList<>();
    }

    public RecordCollection(HashMap<String, String> attrValues,
            ArrayList<Record> recordList) {
        this.attrValues = attrValues;
        this.recordList = recordList;
    }

    public ArrayList<Record> getRecord() {
        return this.recordList;
    }

    /**
     * 返回集合的字符名称数组
     *
     * @return
     */
    public ArrayList<String> getRecordNames() {
        ArrayList<String> names = new ArrayList<>();

        for (int i = 0; i < recordList.size(); i++) {
            names.add(recordList.get(i).getName());
        }

        return names;
    }

    /**
     * 判断集合是否包含此属性名称对应的属性值
     *
     * @param attrName
     *            属性名
     * @return
     */
    public boolean isContainedAttrName(String attrName) {
        boolean isContained = false;

        if (this.attrValues.containsKey(attrName)) {
            isContained = true;
        }

        return isContained;
    }

    /**
     * 判断2个集合是否相等,比较包含的数据记录是否完全一致
     *
     * @param rc
     *            待比较集合
     * @return
     */
    public boolean isCollectionSame(RecordCollection rc) {
        boolean isSame = false;

        for (Record r : recordList) {
            isSame = false;

            for (Record r2 : rc.recordList) {
                if (r.isRecordSame(r2)) {
                    isSame = true;
                    break;
                }
            }

            // 如果有1个记录不包含,就算集合不相等
            if (!isSame) {
                break;
            }
        }

        return isSame;
    }

    /**
     * 集合之间的交运算
     *
     * @param rc
     *            交运算的参与运算的另外一集合
     * @return
     */
    public RecordCollection overlapCalculate(RecordCollection rc) {
        String key;
        String value;
        RecordCollection resultCollection = null;
        HashMap<String, String> resultAttrValues = new HashMap<>();
        ArrayList<Record> resultRecords = new ArrayList<>();

        // 进行集合的交运算,有相同的记录的则进行添加
        for (Record record : this.recordList) {
            for (Record record2 : rc.recordList) {
                if (record.isRecordSame(record2)) {
                    resultRecords.add(record);
                    break;
                }
            }
        }

        // 如果没有交集,则直接返回
        if (resultRecords.size() == 0) {
            return null;
        }

        // 将2个集合的属性进行合并
        for (Map.Entry entry : this.attrValues.entrySet()) {
            key = (String) entry.getKey();
            value = (String) entry.getValue();

            resultAttrValues.put(key, value);
        }

        for (Map.Entry entry : rc.attrValues.entrySet()) {
            key = (String) entry.getKey();
            value = (String) entry.getValue();

            resultAttrValues.put(key, value);
        }

        resultCollection = new RecordCollection(resultAttrValues, resultRecords);
        return resultCollection;
    }

    /**
     * 求集合的并集,各自保留各自的属性
     *
     * @param rc
     *            待合并的集合
     * @return
     */
    public RecordCollection unionCal(RecordCollection rc) {
        RecordCollection resultRc = null;
        ArrayList<Record> records = new ArrayList<>();

        for (Record r1 : this.recordList) {
            records.add(r1);
        }

        for (Record r2 : rc.recordList) {
            records.add(r2);
        }

        resultRc = new RecordCollection(null, records);
        return resultRc;
    }
    
    /**
     * 输出集合中包含的元素
     */
    public void printRc(){
        System.out.print("{");
        for (Record r : this.getRecord()) {
            System.out.print(r.getName() + ", ");
        }
        System.out.println("}");
    }
}

KnowledgeSystem.java:

package DataMining_RoughSets;

import java.util.ArrayList;
import java.util.HashMap;

/**
 * 知识系统
 *
 * @author lyq
 *
 */
public class KnowledgeSystem {
    // 知识系统内的集合
    ArrayList<RecordCollection> ksCollections;

    public KnowledgeSystem(ArrayList<RecordCollection> ksCollections) {
        this.ksCollections = ksCollections;
    }

    /**
     * 获取集合的上近似集合
     *
     * @param rc
     *            原始集合
     * @return
     */
    public RecordCollection getUpSimilarRC(RecordCollection rc) {
        RecordCollection resultRc = null;
        ArrayList<String> nameArray;
        ArrayList<String> targetArray;
        ArrayList<RecordCollection> copyRcs = new ArrayList<>();
        ArrayList<RecordCollection> deleteRcs = new ArrayList<>();
        targetArray = rc.getRecordNames();

        // 做一个集合拷贝
        for (RecordCollection recordCollection : ksCollections) {
            copyRcs.add(recordCollection);
        }

        for (RecordCollection recordCollection : copyRcs) {
            nameArray = recordCollection.getRecordNames();

            if (strIsContained(targetArray, nameArray)) {
                removeOverLaped(targetArray, nameArray);
                deleteRcs.add(recordCollection);

                if (resultRc == null) {
                    resultRc = recordCollection;
                } else {
                    // 进行并运算
                    resultRc = resultRc.unionCal(recordCollection);
                }

                if (targetArray.size() == 0) {
                    break;
                }
            }
        }
        //去除已经添加过的集合
        copyRcs.removeAll(deleteRcs);

        if (targetArray.size() > 0) {
            // 说明已经完全还未找全上近似的集合
            for (RecordCollection recordCollection : copyRcs) {
                nameArray = recordCollection.getRecordNames();

                if (strHasOverlap(targetArray, nameArray)) {
                    removeOverLaped(targetArray, nameArray);

                    if (resultRc == null) {
                        resultRc = recordCollection;
                    } else {
                        // 进行并运算
                        resultRc = resultRc.unionCal(recordCollection);
                    }

                    if (targetArray.size() == 0) {
                        break;
                    }
                }
            }
        }

        return resultRc;
    }

    /**
     * 获取集合的下近似集合
     *
     * @param rc
     *            原始集合
     * @return
     */
    public RecordCollection getDownSimilarRC(RecordCollection rc) {
        RecordCollection resultRc = null;
        ArrayList<String> nameArray;
        ArrayList<String> targetArray;
        targetArray = rc.getRecordNames();

        for (RecordCollection recordCollection : ksCollections) {
            nameArray = recordCollection.getRecordNames();

            if (strIsContained(targetArray, nameArray)) {
                removeOverLaped(targetArray, nameArray);

                if (resultRc == null) {
                    resultRc = recordCollection;
                } else {
                    // 进行并运算
                    resultRc = resultRc.unionCal(recordCollection);
                }

                if (targetArray.size() == 0) {
                    break;
                }
            }
        }

        return resultRc;
    }

    /**
     * 判断2个字符数组之间是否有交集
     *
     * @param str1
     *            字符列表1
     * @param str2
     *            字符列表2
     * @return
     */
    public boolean strHasOverlap(ArrayList<String> str1, ArrayList<String> str2) {
        boolean hasOverlap = false;

        for (String s1 : str1) {
            for (String s2 : str2) {
                if (s1.equals(s2)) {
                    hasOverlap = true;
                    break;
                }
            }

            if (hasOverlap) {
                break;
            }
        }

        return hasOverlap;
    }

    /**
     * 判断字符集str2是否完全包含于str1中
     *
     * @param str1
     * @param str2
     * @return
     */
    public boolean strIsContained(ArrayList<String> str1, ArrayList<String> str2) {
        boolean isContained = false;
        int count = 0;

        for (String s : str2) {
            if (str1.contains(s)) {
                count++;
            }
        }

        if (count == str2.size()) {
            isContained = true;
        }

        return isContained;
    }

    /**
     * 字符列表移除公共元素
     *
     * @param str1
     * @param str2
     */
    public void removeOverLaped(ArrayList<String> str1, ArrayList<String> str2) {
        ArrayList<String> deleteStrs = new ArrayList<>();

        for (String s1 : str1) {
            for (String s2 : str2) {
                if (s1.equals(s2)) {
                    deleteStrs.add(s1);
                    break;
                }
            }
        }

        // 进行公共元素的移除
        str1.removeAll(deleteStrs);
    }
}

RoughSetsTool.java:

package DataMining_RoughSets;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;

/**
 * 粗糙集属性约简算法工具类
 *
 * @author lyq
 *
 */
public class RoughSetsTool {
    // 决策属性名称
    public static String DECISION_ATTR_NAME;

    // 测试数据文件地址
    private String filePath;
    // 数据属性列名称
    private String[] attrNames;
    // 所有的数据
    private ArrayList<String[]> totalDatas;
    // 所有的数据记录,与上面的区别是记录的属性是可约简的,原始数据是不能变的
    private ArrayList<Record> totalRecords;
    // 条件属性图
    private HashMap<String, ArrayList<String>> conditionAttr;
    // 属性记录集合
    private ArrayList<RecordCollection> collectionList;

    public RoughSetsTool(String filePath) {
        this.filePath = filePath;
        readDataFile();
    }

    /**
     * 从文件中读取数据
     */
    private void readDataFile() {
        File file = new File(filePath);
        ArrayList<String[]> dataArray = new ArrayList<String[]>();

        try {
            BufferedReader in = new BufferedReader(new FileReader(file));
            String str;
            String[] tempArray;
            while ((str = in.readLine()) != null) {
                tempArray = str.split(" ");
                dataArray.add(tempArray);
            }
            in.close();
        } catch (IOException e) {
            e.getStackTrace();
        }

        String[] array;
        Record tempRecord;
        HashMap<String, String> attrMap;
        ArrayList<String> attrList;
        totalDatas = new ArrayList<>();
        totalRecords = new ArrayList<>();
        conditionAttr = new HashMap<>();
        // 赋值属性名称行
        attrNames = dataArray.get(0);
        DECISION_ATTR_NAME = attrNames[attrNames.length - 1];
        for (int j = 0; j < dataArray.size(); j++) {
            array = dataArray.get(j);
            totalDatas.add(array);
            if (j == 0) {
                // 过滤掉第一行列名称数据
                continue;
            }

            attrMap = new HashMap<>();
            for (int i = 0; i < attrNames.length; i++) {
                attrMap.put(attrNames[i], array[i]);

                // 寻找条件属性
                if (i > 0 && i < attrNames.length - 1) {
                    if (conditionAttr.containsKey(attrNames[i])) {
                        attrList = conditionAttr.get(attrNames[i]);
                        if (!attrList.contains(array[i])) {
                            attrList.add(array[i]);
                        }
                    } else {
                        attrList = new ArrayList<>();
                        attrList.add(array[i]);
                    }
                    conditionAttr.put(attrNames[i], attrList);
                }
            }
            tempRecord = new Record(array[0], attrMap);
            totalRecords.add(tempRecord);
        }
    }

    /**
     * 将数据记录根据属性分割到集合中
     */
    private void recordSpiltToCollection() {
        String attrName;
        ArrayList<String> attrList;
        ArrayList<Record> recordList;
        HashMap<String, String> collectionAttrValues;
        RecordCollection collection;
        collectionList = new ArrayList<>();

        for (Map.Entry entry : conditionAttr.entrySet()) {
            attrName = (String) entry.getKey();
            attrList = (ArrayList<String>) entry.getValue();

            for (String s : attrList) {
                recordList = new ArrayList<>();
                // 寻找属性为s的数据记录分入到集合中
                for (Record record : totalRecords) {
                    if (record.isContainedAttr(s)) {
                        recordList.add(record);
                    }
                }
                collectionAttrValues = new HashMap<>();
                collectionAttrValues.put(attrName, s);
                collection = new RecordCollection(collectionAttrValues,
                        recordList);

                collectionList.add(collection);
            }
        }
    }

    /**
     * 构造属性集合图
     *
     * @param reductAttr
     *            需要约简的属性
     * @return
     */
    private HashMap<String, ArrayList<RecordCollection>> constructCollectionMap(
            ArrayList<String> reductAttr) {
        String currentAtttrName;
        ArrayList<RecordCollection> cList;
        // 集合属性对应图
        HashMap<String, ArrayList<RecordCollection>> collectionMap = new HashMap<>();

        // 截取出条件属性部分
        for (int i = 1; i < attrNames.length - 1; i++) {
            currentAtttrName = attrNames[i];

            // 判断此属性列是否需要约简
            if (reductAttr != null && reductAttr.contains(currentAtttrName)) {
                continue;
            }

            cList = new ArrayList<>();

            for (RecordCollection c : collectionList) {
                if (c.isContainedAttrName(currentAtttrName)) {
                    cList.add(c);
                }
            }

            collectionMap.put(currentAtttrName, cList);
        }

        return collectionMap;
    }

    /**
     * 根据已有的分裂集合计算知识系统
     */
    private ArrayList<RecordCollection> computeKnowledgeSystem(
            HashMap<String, ArrayList<RecordCollection>> collectionMap) {
        String attrName = null;
        ArrayList<RecordCollection> cList = null;
        // 知识系统
        ArrayList<RecordCollection> ksCollections;

        ksCollections = new ArrayList<>();

        // 取出1项
        for (Map.Entry entry : collectionMap.entrySet()) {
            attrName = (String) entry.getKey();
            cList = (ArrayList<RecordCollection>) entry.getValue();
            break;
        }
        collectionMap.remove(attrName);

        for (RecordCollection rc : cList) {
            recurrenceComputeKS(ksCollections, collectionMap, rc);
        }

        return ksCollections;
    }

    /**
     * 递归计算所有的知识系统,通过计算所有集合的交集
     *
     * @param ksCollection
     *            已经求得知识系统的集合
     * @param map
     *            还未曾进行过交运算的集合
     * @param preCollection
     *            前个步骤中已经通过交运算计算出的集合
     */
    private void recurrenceComputeKS(ArrayList<RecordCollection> ksCollections,
            HashMap<String, ArrayList<RecordCollection>> map,
            RecordCollection preCollection) {
        String attrName = null;
        RecordCollection tempCollection;
        ArrayList<RecordCollection> cList = null;
        HashMap<String, ArrayList<RecordCollection>> mapCopy = new HashMap<>();
        
        //如果已经没有数据了,则直接添加
        if(map.size() == 0){
            ksCollections.add(preCollection);
            return;
        }

        for (Map.Entry entry : map.entrySet()) {
            cList = (ArrayList<RecordCollection>) entry.getValue();
            mapCopy.put((String) entry.getKey(), cList);
        }

        // 取出1项
        for (Map.Entry entry : map.entrySet()) {
            attrName = (String) entry.getKey();
            cList = (ArrayList<RecordCollection>) entry.getValue();
            break;
        }

        mapCopy.remove(attrName);
        for (RecordCollection rc : cList) {
            // 挑选此属性的一个集合进行交运算,然后再次递归
            tempCollection = preCollection.overlapCalculate(rc);

            if (tempCollection == null) {
                continue;
            }

            // 如果map中已经没有数据了,说明递归到头了
            if (mapCopy.size() == 0) {
                ksCollections.add(tempCollection);
            } else {
                recurrenceComputeKS(ksCollections, mapCopy, tempCollection);
            }
        }
    }

    /**
     * 进行粗糙集属性约简算法
     */
    public void findingReduct() {
        RecordCollection[] sameClassRcs;
        KnowledgeSystem ks;
        ArrayList<RecordCollection> ksCollections;
        // 待约简的属性
        ArrayList<String> reductAttr = null;
        ArrayList<String> attrNameList;
        // 最终可约简的属性组
        ArrayList<ArrayList<String>> canReductAttrs;
        HashMap<String, ArrayList<RecordCollection>> collectionMap;

        sameClassRcs = selectTheSameClassRC();
        // 这里讲数据按照各个分类的小属性划分了9个集合
        recordSpiltToCollection();

        collectionMap = constructCollectionMap(reductAttr);
        ksCollections = computeKnowledgeSystem(collectionMap);
        ks = new KnowledgeSystem(ksCollections);
        System.out.println("原始集合分类的上下近似集合");
        ks.getDownSimilarRC(sameClassRcs[0]).printRc();
        ks.getUpSimilarRC(sameClassRcs[0]).printRc();
        ks.getDownSimilarRC(sameClassRcs[1]).printRc();
        ks.getUpSimilarRC(sameClassRcs[1]).printRc();

        attrNameList = new ArrayList<>();
        for (int i = 1; i < attrNames.length - 1; i++) {
            attrNameList.add(attrNames[i]);
        }

        ArrayList<String> remainAttr;
        canReductAttrs = new ArrayList<>();
        reductAttr = new ArrayList<>();
        // 进行条件属性的递归约简
        for (String s : attrNameList) {
            remainAttr = (ArrayList<String>) attrNameList.clone();
            remainAttr.remove(s);
            reductAttr = new ArrayList<>();
            reductAttr.add(s);
            recurrenceFindingReduct(canReductAttrs, reductAttr, remainAttr,
                    sameClassRcs);
        }
        
        printRules(canReductAttrs);
    }

    /**
     * 递归进行属性约简
     *
     * @param resultAttr
     *            已经计算出的约简属性组
     * @param reductAttr
     *            将要约简的属性组
     * @param remainAttr
     *            剩余的属性
     * @param sameClassRc
     *            待计算上下近似集合的同类集合
     */
    private void recurrenceFindingReduct(
            ArrayList<ArrayList<String>> resultAttr,
            ArrayList<String> reductAttr, ArrayList<String> remainAttr,
            RecordCollection[] sameClassRc) {
        KnowledgeSystem ks;
        ArrayList<RecordCollection> ksCollections;
        ArrayList<String> copyRemainAttr;
        ArrayList<String> copyReductAttr;
        HashMap<String, ArrayList<RecordCollection>> collectionMap;
        RecordCollection upRc1;
        RecordCollection downRc1;
        RecordCollection upRc2;
        RecordCollection downRc2;

        collectionMap = constructCollectionMap(reductAttr);
        ksCollections = computeKnowledgeSystem(collectionMap);
        ks = new KnowledgeSystem(ksCollections);
        
        downRc1 = ks.getDownSimilarRC(sameClassRc[0]);
        upRc1 = ks.getUpSimilarRC(sameClassRc[0]);
        downRc2 = ks.getDownSimilarRC(sameClassRc[1]);
        upRc2 = ks.getUpSimilarRC(sameClassRc[1]);

        // 如果上下近似没有完全拟合原集合则认为属性不能被约简
        if (!upRc1.isCollectionSame(sameClassRc[0])
                || !downRc1.isCollectionSame(sameClassRc[0])) {
            return;
        }
        //正类和负类都需比较
        if (!upRc2.isCollectionSame(sameClassRc[1])
                || !downRc2.isCollectionSame(sameClassRc[1])) {
            return;
        }

        // 加入到结果集中
        resultAttr.add(reductAttr);
        //只剩下1个属性不能再约简
        if (remainAttr.size() == 1) {
            return;
        }

        for (String s : remainAttr) {
            copyRemainAttr = (ArrayList<String>) remainAttr.clone();
            copyReductAttr = (ArrayList<String>) reductAttr.clone();
            copyRemainAttr.remove(s);
            copyReductAttr.add(s);
            recurrenceFindingReduct(resultAttr, copyReductAttr, copyRemainAttr,
                    sameClassRc);
        }
    }

    /**
     * 选出决策属性一致的集合
     *
     * @return
     */
    private RecordCollection[] selectTheSameClassRC() {
        RecordCollection[] resultRc = new RecordCollection[2];
        resultRc[0] = new RecordCollection();
        resultRc[1] = new RecordCollection();
        String attrValue;

        // 找出第一个记录的决策属性作为一个分类
        attrValue = totalRecords.get(0).getRecordDecisionClass();
        for (Record r : totalRecords) {
            if (attrValue.equals(r.getRecordDecisionClass())) {
                resultRc[0].getRecord().add(r);
            }else{
                resultRc[1].getRecord().add(r);
            }
        }

        return resultRc;
    }
    
    /**
     * 输出决策规则
     * @param reductAttrArray
     * 约简属性组
     */
    public void printRules(ArrayList<ArrayList<String>> reductAttrArray){
        //用来保存已经描述过的规则,避免重复输出
        ArrayList<String> rulesArray;
        String rule;
        
        for(ArrayList<String> ra: reductAttrArray){
            rulesArray = new ArrayList<>();
            System.out.print("约简的属性:");
            for(String s: ra){
                System.out.print(s + ",");
            }
            System.out.println();
            
            for(Record r: totalRecords){
                rule = r.getDecisionRule(ra);
                if(!rulesArray.contains(rule)){
                    rulesArray.add(rule);
                    System.out.println(rule);
                }
            }
            System.out.println();
        }
    }

    /**
     * 输出记录集合
     *
     * @param rcList
     *            待输出记录集合
     */
    public void printRecordCollectionList(ArrayList<RecordCollection> rcList) {
        for (RecordCollection rc : rcList) {
            System.out.print("{");
            for (Record r : rc.getRecord()) {
                System.out.print(r.getName() + ", ");
            }
            System.out.println("}");
        }
    }
}

调用类Client.java:

package DataMining_RoughSets;

/**
 * 粗糙集约简算法
 * @author lyq
 *
 */
public class Client {
    public static void main(String[] args){
        String filePath = "C:\\Users\\lyq\\Desktop\\icon\\input.txt";
        
        RoughSetsTool tool = new RoughSetsTool(filePath);
        tool.findingReduct();
    }
}

结果输出:

原始集合分类的上下近似集合
{x1, x2, x5, }
{x1, x2, x5, }
{x3, x4, x7, x6, x8, }
{x3, x4, x7, x6, x8, }
约简的属性:Color,
属性Shape=Triangle,Size=Large,他的分类为Stable
属性Shape=Circle,Size=Small,他的分类为UnStable
属性Shape=Rectangle,Size=Large,他的分类为Stable
属性Shape=Circle,Size=Middle,他的分类为UnStable
属性Shape=Rectangle,Size=Middle,他的分类为UnStable

约简的属性:Color,Shape,
属性Size=Large,他的分类为Stable
属性Size=Small,他的分类为UnStable
属性Size=Middle,他的分类为UnStable

约简的属性:Shape,
属性Size=Large,Color=Red,他的分类为Stable
属性Size=Small,Color=Yellow,他的分类为UnStable
属性Size=Large,Color=Blue,他的分类为Stable
属性Size=Middle,Color=Red,他的分类为UnStable
属性Size=Small,Color=Blue,他的分类为UnStable
属性Size=Middle,Color=Blue,他的分类为UnStable

约简的属性:Shape,Color,
属性Size=Large,他的分类为Stable
属性Size=Small,他的分类为UnStable
属性Size=Middle,他的分类为UnStable

算法的小问题

我在算法实现时很大的问题到不是碰到很多,就是对于上下近似集的计算上自己做了一个修改,下近似集就是知识系统中的集合完全包括在目标集合的目标,而上近似则是在下近似集的基础上添加目标集合中还没有被包含进集合的元素的所属集合,跟题目原先设想的还是有一点点的不一样,但是算法整体思想还是呈现出来了。
我对算法的思考

粗糙集属性约简算法重在约简,至于用什么原则作为约简的标准,其实本身不止一种,当然你可以根本不需要用上下近似集的概念,这样确实使得验证变得非常的繁琐,你可以直接一条条的记录去约简属性,看会不会对分类的最终结果造成影响,然后做出判断,通过对决策影响的判断也仅仅是一种属性约简的情况。
算法的适用情况

RoughSets算法在属性集比较少的情况下能得到一个不错的分类的,也可以降低存储开销,但是属性集比较多的时候,可能准确率无法保证。
作者:Androidlushangderen 发表于2015/2/18 9:24:59 原文链接
阅读:1026 评论:0 查看评论
CBA算法---基于关联规则进行分类的算法
2015年2月14日 19:02

更多数据挖掘算法:https://github.com/linyiqun/DataMiningAlgorithm
介绍

CBA算法全称是Classification base of Association,就是基于关联规则进行分类的算法,说到关联规则,我们就会想到Apriori和FP-Tree算法都是关联规则挖掘算法,而CBA算法正是利用了Apriori挖掘出的关联规则,然后做分类判断,所以在某种程度上说,CBA算法也可以说是一种集成挖掘算法。
算法原理

CBA算法作为分类算法,他的分类情况也就是给定一些预先知道的属性,然后叫你判断出他的决策属性是哪个值。判断的依据就是Apriori算法挖掘出的频繁项,如果一个项集中包含预先知道的属性,同时也包含分类属性值,然后我们计算此频繁项能否导出已知属性值推出决策属性值的关联规则,如果满足规则的最小置信度的要求,那么可以把频繁项中的决策属性值作为最后的分类结果。具体的算法细节如下:

1、输入数据记录,就是一条条的属性值。

2、对属性值做数字的替换(按照列从上往下寻找属性值),就类似于Apriori中的一条条事务记录。

3、根据这个转化后的事务记录,进行Apriori算法计算,挖掘出频繁项集。

4、输入查询的属性值,找出符合条件的频繁项集(需要包含查询属性和分类决策属性),如果能够推导出这样的关联规则,就算分类成功,输出分类结果。

这里以之前我做的CART算法的测试数据为CBA算法的测试数据,如下:

Rid Age Income Student CreditRating BuysComputer
1 13 High No Fair CLassNo
2 11 High No Excellent CLassNo
3 25 High No Fair CLassYes
4 45 Medium No Fair CLassYes
5 50 Low Yes Fair CLassYes
6 51 Low Yes Excellent CLassNo
7 30 Low Yes Excellent CLassYes
8 13 Medium No Fair CLassNo
9 9 Low Yes Fair CLassYes
10 55 Medium Yes Fair CLassYes
11 14 Medium Yes Excellent CLassYes
12 33 Medium No Excellent CLassYes
13 33 High Yes Fair CLassYes
14 41 Medium No Excellent CLassNo

属性值对应的数字替换图:

Medium=5, CLassYes=12, Excellent=10, Low=6, Fair=9, CLassNo=11, Young=1, Middle_aged=2, Yes=8, No=7, High=4, Senior=3

体会之后的数据变为了下面的事务项:

Rid Age Income Student CreditRating BuysComputer
1 1 4 7 9 11
2 1 4 7 10 11
3 2 4 7 9 12
4 3 5 7 9 12
5 3 6 8 9 12
6 3 6 8 10 11
7 2 6 8 10 12
8 1 5 7 9 11
9 1 6 8 9 12
10 3 5 8 9 12
11 1 5 8 10 12
12 2 5 7 10 12
13 2 4 8 9 12
14 3 5 7 10 11

把每条记录看出事务项,就和Apriori算法的输入格式基本一样了,后面就是进行连接运算和剪枝步骤等Apriori算法的步骤了,在这里就不详细描述了,Apriori算法的实现可以点击这里进行了解。

算法的代码实现

测试数据就是上面的内容。

CBATool.java:

package DataMining_CBA;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import DataMining_CBA.AprioriTool.AprioriTool;
import DataMining_CBA.AprioriTool.FrequentItem;

/**
 * CBA算法(关联规则分类)工具类
 *
 * @author lyq
 *
 */
public class CBATool {
    // 年龄的类别划分
    public final String AGE = "Age";
    public final String AGE_YOUNG = "Young";
    public final String AGE_MIDDLE_AGED = "Middle_aged";
    public final String AGE_Senior = "Senior";

    // 测试数据地址
    private String filePath;
    // 最小支持度阈值率
    private double minSupportRate;
    // 最小置信度阈值,用来判断是否能够成为关联规则
    private double minConf;
    // 最小支持度
    private int minSupportCount;
    // 属性列名称
    private String[] attrNames;
    // 类别属性所代表的数字集合
    private ArrayList<Integer> classTypes;
    // 用二维数组保存测试数据
    private ArrayList<String[]> totalDatas;
    // Apriori算法工具类
    private AprioriTool aprioriTool;
    // 属性到数字的映射图
    private HashMap<String, Integer> attr2Num;
    private HashMap<Integer, String> num2Attr;

    public CBATool(String filePath, double minSupportRate, double minConf) {
        this.filePath = filePath;
        this.minConf = minConf;
        this.minSupportRate = minSupportRate;
        readDataFile();
    }

    /**
     * 从文件中读取数据
     */
    private void readDataFile() {
        File file = new File(filePath);
        ArrayList<String[]> dataArray = new ArrayList<String[]>();

        try {
            BufferedReader in = new BufferedReader(new FileReader(file));
            String str;
            String[] tempArray;
            while ((str = in.readLine()) != null) {
                tempArray = str.split(" ");
                dataArray.add(tempArray);
            }
            in.close();
        } catch (IOException e) {
            e.getStackTrace();
        }

        totalDatas = new ArrayList<>();
        for (String[] array : dataArray) {
            totalDatas.add(array);
        }
        attrNames = totalDatas.get(0);
        minSupportCount = (int) (minSupportRate * totalDatas.size());

        attributeReplace();
    }

    /**
     * 属性值的替换,替换成数字的形式,以便进行频繁项的挖掘
     */
    private void attributeReplace() {
        int currentValue = 1;
        int num = 0;
        String s;
        // 属性名到数字的映射图
        attr2Num = new HashMap<>();
        num2Attr = new HashMap<>();
        classTypes = new ArrayList<>();

        // 按照1列列的方式来,从左往右边扫描,跳过列名称行和id列
        for (int j = 1; j < attrNames.length; j++) {
            for (int i = 1; i < totalDatas.size(); i++) {
                s = totalDatas.get(i)[j];
                // 如果是数字形式的,这里只做年龄类别转换,其他的数字情况类似
                if (attrNames[j].equals(AGE)) {
                    num = Integer.parseInt(s);
                    if (num <= 20 && num > 0) {
                        totalDatas.get(i)[j] = AGE_YOUNG;
                    } else if (num > 20 && num <= 40) {
                        totalDatas.get(i)[j] = AGE_MIDDLE_AGED;
                    } else if (num > 40) {
                        totalDatas.get(i)[j] = AGE_Senior;
                    }
                }

                if (!attr2Num.containsKey(totalDatas.get(i)[j])) {
                    attr2Num.put(totalDatas.get(i)[j], currentValue);
                    num2Attr.put(currentValue, totalDatas.get(i)[j]);
                    if (j == attrNames.length - 1) {
                        // 如果是组后一列,说明是分类类别列,记录下来
                        classTypes.add(currentValue);
                    }

                    currentValue++;
                }
            }
        }

        // 对原始的数据作属性替换,每条记录变为类似于事务数据的形式
        for (int i = 1; i < totalDatas.size(); i++) {
            for (int j = 1; j < attrNames.length; j++) {
                s = totalDatas.get(i)[j];
                if (attr2Num.containsKey(s)) {
                    totalDatas.get(i)[j] = attr2Num.get(s) + "";
                }
            }
        }
    }

    /**
     * Apriori计算全部频繁项集
     * @return
     */
    private ArrayList<FrequentItem> aprioriCalculate() {
        String[] tempArray;
        ArrayList<FrequentItem> totalFrequentItems;
        ArrayList<String[]> copyData = (ArrayList<String[]>) totalDatas.clone();
        
        // 去除属性名称行
        copyData.remove(0);
        // 去除首列ID
        for (int i = 0; i < copyData.size(); i++) {
            String[] array = copyData.get(i);
            tempArray = new String[array.length - 1];
            System.arraycopy(array, 1, tempArray, 0, tempArray.length);
            copyData.set(i, tempArray);
        }
        aprioriTool = new AprioriTool(copyData, minSupportCount);
        aprioriTool.computeLink();
        totalFrequentItems = aprioriTool.getTotalFrequentItems();

        return totalFrequentItems;
    }

    /**
     * 基于关联规则的分类
     *
     * @param attrValues
     *            预先知道的一些属性
     * @return
     */
    public String CBAJudge(String attrValues) {
        int value = 0;
        // 最终分类类别
        String classType = null;
        String[] tempArray;
        // 已知的属性值
        ArrayList<String> attrValueList = new ArrayList<>();
        ArrayList<FrequentItem> totalFrequentItems;

        totalFrequentItems = aprioriCalculate();
        // 将查询条件进行逐一属性的分割
        String[] array = attrValues.split(",");
        for (String record : array) {
            tempArray = record.split("=");
            value = attr2Num.get(tempArray[1]);
            attrValueList.add(value + "");
        }

        // 在频繁项集中寻找符合条件的项
        for (FrequentItem item : totalFrequentItems) {
            // 过滤掉不满足个数频繁项
            if (item.getIdArray().length < (attrValueList.size() + 1)) {
                continue;
            }

            // 要保证查询的属性都包含在频繁项集中
            if (itemIsSatisfied(item, attrValueList)) {
                tempArray = item.getIdArray();
                classType = classificationBaseRules(tempArray);

                if (classType != null) {
                    // 作属性替换
                    classType = num2Attr.get(Integer.parseInt(classType));
                    break;
                }
            }
        }

        return classType;
    }

    /**
     * 基于关联规则进行分类
     *
     * @param items
     *            频繁项
     * @return
     */
    private String classificationBaseRules(String[] items) {
        String classType = null;
        String[] arrayTemp;
        int count1 = 0;
        int count2 = 0;
        // 置信度
        double confidenceRate;

        String[] noClassTypeItems = new String[items.length - 1];
        for (int i = 0, k = 0; i < items.length; i++) {
            if (!classTypes.contains(Integer.parseInt(items[i]))) {
                noClassTypeItems[k] = items[i];
                k++;
            } else {
                classType = items[i];
            }
        }

        for (String[] array : totalDatas) {
            // 去除ID数字号
            arrayTemp = new String[array.length - 1];
            System.arraycopy(array, 1, arrayTemp, 0, array.length - 1);
            if (isStrArrayContain(arrayTemp, noClassTypeItems)) {
                count1++;

                if (isStrArrayContain(arrayTemp, items)) {
                    count2++;
                }
            }
        }

        // 做置信度的计算
        confidenceRate = count1 * 1.0 / count2;
        if (confidenceRate >= minConf) {
            return classType;
        } else {
            // 如果不满足最小置信度要求,则此关联规则无效
            return null;
        }
    }

    /**
     * 判断单个字符是否包含在字符数组中
     *
     * @param array
     *            字符数组
     * @param s
     *            判断的单字符
     * @return
     */
    private boolean strIsContained(String[] array, String s) {
        boolean isContained = false;

        for (String str : array) {
            if (str.equals(s)) {
                isContained = true;
                break;
            }
        }

        return isContained;
    }

    /**
     * 数组array2是否包含于array1中,不需要完全一样
     *
     * @param array1
     * @param array2
     * @return
     */
    private boolean isStrArrayContain(String[] array1, String[] array2) {
        boolean isContain = true;
        for (String s2 : array2) {
            isContain = false;
            for (String s1 : array1) {
                // 只要s2字符存在于array1中,这个字符就算包含在array1中
                if (s2.equals(s1)) {
                    isContain = true;
                    break;
                }
            }

            // 一旦发现不包含的字符,则array2数组不包含于array1中
            if (!isContain) {
                break;
            }
        }

        return isContain;
    }

    /**
     * 判断频繁项集是否满足查询
     *
     * @param item
     *            待判断的频繁项集
     * @param attrValues
     *            查询的属性值列表
     * @return
     */
    private boolean itemIsSatisfied(FrequentItem item,
            ArrayList<String> attrValues) {
        boolean isContained = false;
        String[] array = item.getIdArray();

        for (String s : attrValues) {
            isContained = true;

            if (!strIsContained(array, s)) {
                isContained = false;
                break;
            }

            if (!isContained) {
                break;
            }
        }

        if (isContained) {
            isContained = false;

            // 还要验证是否频繁项集中是否包含分类属性
            for (Integer type : classTypes) {
                if (strIsContained(array, type + "")) {
                    isContained = true;
                    break;
                }
            }
        }

        return isContained;
    }

}

调用类Client.java:

package DataMining_CBA;

import java.text.MessageFormat;

/**
 * CBA算法--基于关联规则的分类算法
 * @author lyq
 *
 */
public class Client {
    public static void main(String[] args){
        String filePath = "C:\\Users\\lyq\\Desktop\\icon\\input.txt";
        String attrDesc = "Age=Senior,CreditRating=Fair";
        String classification = null;
        
        //最小支持度阈值率
        double minSupportRate = 0.2;
        //最小置信度阈值
        double minConf = 0.7;
        
        CBATool tool = new CBATool(filePath, minSupportRate, minConf);
        classification = tool.CBAJudge(attrDesc);
        System.out.println(MessageFormat.format("{0}的关联分类结果为{1}", attrDesc, classification));
    }
}

代码的结果为:

频繁1项集:
{1,},{2,},{3,},{4,},{5,},{6,},{7,},{8,},{9,},{10,},{11,},{12,},
频繁2项集:
{1,7,},{1,9,},{1,11,},{2,12,},{3,5,},{3,8,},{3,9,},{3,12,},{4,7,},{4,9,},{5,7,},{5,9,},{5,10,},{5,12,},{6,8,},{6,12,},{7,9,},{7,10,},{7,11,},{7,12,},{8,9,},{8,10,},{8,12,},{9,12,},{10,11,},{10,12,},
频繁3项集:
{1,7,11,},{3,9,12,},{6,8,12,},{8,9,12,},
频繁4项集:

频繁5项集:

频繁6项集:

频繁7项集:

频繁8项集:

频繁9项集:

频繁10项集:

频繁11项集:

Age=Senior,CreditRating=Fair的关联分类结果为CLassYes

上面的有些项集为空说明没有此项集。Apriori算法类可以在这里进行查阅,这里只展示了CBA算法的部分。

算法的分析

我在准备实现CBA算法的时候就预见到了这个算法就是对Apriori算法的一个包装,在于2点,输入数据的格式进行数字的转换,还有就是输出的时候做属性对数字的替换,核心还是在于Apriori算法的项集频繁挖掘。

程序实现时遇到的问题

在这期间遇到了一个bug就是频繁1项集在排序的时候出现了问题,后来发现原因是String.CompareTo(),原本应该是1,2,....11,12,用了前面这个方法后会变成1,10,2,。。就是10会比2小的情况,后来查了String.CompareTo()的比较规则,明白了他是一位位比较Ascall码值,因为10的1比2小,最后果断的改回了用Integer的比较方法了。这个问题别看是个小问题,1项集如果没有排好序,后面的连接操作可能会出现少情况的可能,这个之前吃过这样的亏了。
我对CBA算法的理解

CBA算法和巧妙的利用了关联规则进行类别的分类,有别与其他的分类算法。他的算法好坏又会依靠Apriori算法的执行好坏。
作者:Androidlushangderen 发表于2015/2/14 19:02:02 原文链接
阅读:1060 评论:0 查看评论
PrefixSpan序列模式挖掘算法
2015年2月12日 19:06

更多数据挖掘代码:https://github.com/linyiqun/DataMiningAlgorithm
介绍

与GSP一样,PrefixSpan算法也是序列模式分析算法的一种,不过与前者不同的是PrefixSpan算法不产生任何的侯选集,在这点上可以说已经比GSP好很多了。PrefixSpan算法可以挖掘出满足阈值的所有序列模式,可以说是非常经典的算法。序列的格式就是上文中提到过的类似于<a, b, (de)>这种的。
算法原理

PrefixSpan算法的原理是采用后缀序列转前缀序列的方式来构造频繁序列的。举个例子,


比如原始序列如上图所示,4条序列,1个序列中好几个项集,项集内有1个或多个元素,首先找出前缀为a的子序列,此时序列前缀为<a>,后缀就变为了:

 

"_"下标符代表前缀为a,说明是在项集中间匹配的。这就相当于从后缀序列中提取出1项加入到前缀序列中,变化的规则就是从左往右扫描,找到第1个此元素对应的项,然后做改变。然后根据此规则继续递归直到后续序列不满足最小支持度阈值的情况。所以此算法的难点就转变为了从后缀序列变为前缀序列的过程。在这个过程要分为2种情况,第1种是单个元素项的后缀提前,比如这里的a,对单个项的提前有分为几种情况,比如:

<b  a  c  ad>,就会变为<c  ad>,如果a是嵌套在项集中的情况<b  c  dad  r>,就会变为< _d   r>,_代表的就是a.如果a在一项的最末尾,此项也会被移除<b  c  dda  r>变为<r>。但是如果是这种情况<_da  d  d>a包含在下标符中,将会做处理,应该此时的a是在前缀序列所属的项集内的。

还有1个大类的分类就是对于组合项的后缀提取,可以分为2个情况,1个是从_X中寻找,一个从后面找出连续的项集,比如在这里<a>的条件下,找出前缀<(ab)>的后缀序列


第一种在_X中寻找还有没有X=a的情况,因为_已经代表1个a了,还有一个是判断_X != _a的情况,从后面的项集中找到包含有连续的aa的那个项集,然后做变换处理,与单个项集的变换规则一致。

算法的递归顺序

想要实现整个的序列挖掘,算法的递归顺序就显得非常重要了。在探索递归顺序的路上还是犯了一些错误的,刚刚开始的递归顺序是<a>---><a  a>----><a   a   a>,假设<a  a  a>找不到对应的后缀模式时,然后回溯到<a (aa)>进行递归,后来发现这样会漏掉情况,为什么呢,因为如果 <a a >没法进行到<a  a  a>,那么就不可能会有前缀<a  (aa)>,顶多会判断到<(aa)>,从<a a>处回调的。于是我发现了这个问题,就变为了下面这个样子,经测试是对的。:

加入所有的单个元素的类似为a-f,顺序为

<a>,---><a a>.同时<(aa)>,然后<ab>同时<(ab)>,就是在a添加a-f的元素的时候,检验a所属项集添加a-f元素的情况。这样就不会漏掉情况了,用了2个递归搞定了这个问题。这个算法的整体实现可以对照代码来看会理解很多。最后提醒一点,在每次做出改变之后都会判断一下是否满足最小支持度阈值的。
PrefixSpan实例

这里举1个真实一点的例子,下面是输入的初始序列:

挖掘出的所有的序列模式为,下面是一个表格的形式


在<b>的序列模式中少了1个序列模式。可以与后面程序算法测试的结果做对比。
算法的代码实现

代码实现同样以这个为例子,这样会显得更有说服性。

测试数据:

bd c b ac
bf ce b fg
ah bf a b f
be ce d
a bd b c b ade

Sequence.java:

package DataMining_PrefixSpan;

import java.util.ArrayList;

/**
 * 序列类
 *
 * @author lyq
 *
 */
public class Sequence {
    // 序列内的项集
    private ArrayList<ItemSet> itemSetList;

    public Sequence() {
        this.itemSetList = new ArrayList<>();
    }

    public ArrayList<ItemSet> getItemSetList() {
        return itemSetList;
    }

    public void setItemSetList(ArrayList<ItemSet> itemSetList) {
        this.itemSetList = itemSetList;
    }

    /**
     * 判断单一项是否包含于此序列
     *
     * @param c
     *            待判断项
     * @return
     */
    public boolean strIsContained(String c) {
        boolean isContained = false;

        for (ItemSet itemSet : itemSetList) {
            isContained = false;

            for (String s : itemSet.getItems()) {
                if (itemSet.getItems().contains("_")) {
                    continue;
                }

                if (s.equals(c)) {
                    isContained = true;
                    break;
                }
            }

            if (isContained) {
                // 如果已经检测出包含了,直接挑出循环
                break;
            }
        }

        return isContained;
    }

    /**
     * 判断组合项集是否包含于序列中
     *
     * @param itemSet
     *            组合的项集,元素超过1个
     * @return
     */
    public boolean compoentItemIsContain(ItemSet itemSet) {
        boolean isContained = false;
        ArrayList<String> tempItems;
        String lastItem = itemSet.getLastValue();

        for (int i = 0; i < this.itemSetList.size(); i++) {
            tempItems = this.itemSetList.get(i).getItems();
            // 分2种情况查找,第一种从_X中找出x等于项集最后的元素,因为_前缀已经为原本的元素
            if (tempItems.size() > 1 && tempItems.get(0).equals("_")
                    && tempItems.get(1).equals(lastItem)) {
                isContained = true;
                break;
            } else if (!tempItems.get(0).equals("_")) {
                // 从没有_前缀的项集开始寻找,第二种为从后面的后缀中找出直接找出连续字符为ab为同一项集的项集
                if (strArrayContains(tempItems, itemSet.getItems())) {
                    isContained = true;
                    break;
                }
            }

            if (isContained) {
                break;
            }
        }

        return isContained;
    }

    /**
     * 删除单个项
     *
     * @param s
     *            待删除项
     */
    public void deleteSingleItem(String s) {
        ArrayList<String> tempItems;
        ArrayList<String> deleteItems = new ArrayList<>();

        for (ItemSet itemSet : this.itemSetList) {
            tempItems = itemSet.getItems();
            deleteItems = new ArrayList<>();

            for (int i = 0; i < tempItems.size(); i++) {
                if (tempItems.get(i).equals(s)) {
                    deleteItems.add(tempItems.get(i));
                }
            }

            tempItems.removeAll(deleteItems);
        }
    }

    /**
     * 提取项s之后所得的序列
     *
     * @param s
     *            目标提取项s
     */
    public Sequence extractItem(String s) {
        Sequence extractSeq = this.copySeqence();
        ItemSet itemSet;
        ArrayList<String> items;
        ArrayList<ItemSet> deleteItemSets = new ArrayList<>();
        ArrayList<String> tempItems = new ArrayList<>();

        for (int k = 0; k < extractSeq.itemSetList.size(); k++) {
            itemSet = extractSeq.itemSetList.get(k);
            items = itemSet.getItems();
            if (items.size() == 1 && items.get(0).equals(s)) {
                //如果找到的是单项,则完全移除,跳出循环
                extractSeq.itemSetList.remove(k);
                break;
            } else if (items.size() > 1 && !items.get(0).equals("_")) {
                //在后续的多元素项中判断是否包含此元素
                if (items.contains(s)) {
                    //如果包含把s后面的元素加入到临时字符数组中
                    int index = items.indexOf(s);
                    for (int j = index; j < items.size(); j++) {
                        tempItems.add(items.get(j));
                    }
                    //将第一位的s变成下标符"_"
                    tempItems.set(0, "_");
                    if (tempItems.size() == 1) {
                        // 如果此匹配为在最末端,同样移除
                        deleteItemSets.add(itemSet);
                    } else {
                        //将变化后的项集替换原来的
                        extractSeq.itemSetList.set(k, new ItemSet(tempItems));
                    }
                    break;
                } else {
                    deleteItemSets.add(itemSet);
                }
            } else {
                // 不符合以上2项条件的统统移除
                deleteItemSets.add(itemSet);
            }
        }
        extractSeq.itemSetList.removeAll(deleteItemSets);

        return extractSeq;
    }

    /**
     * 提取组合项之后的序列
     *
     * @param array
     *            组合数组
     * @return
     */
    public Sequence extractCompoentItem(ArrayList<String> array) {
        // 找到目标项,是否立刻停止
        boolean stopExtract = false;
        Sequence seq = this.copySeqence();
        String lastItem = array.get(array.size() - 1);
        ArrayList<String> tempItems;
        ArrayList<ItemSet> deleteItems = new ArrayList<>();

        for (int i = 0; i < seq.itemSetList.size(); i++) {
            if (stopExtract) {
                break;
            }

            tempItems = seq.itemSetList.get(i).getItems();
            // 分2种情况查找,第一种从_X中找出x等于项集最后的元素,因为_前缀已经为原本的元素
            if (tempItems.size() > 1 && tempItems.get(0).equals("_")
                    && tempItems.get(1).equals(lastItem)) {
                if (tempItems.size() == 2) {
                    seq.itemSetList.remove(i);
                } else {
                    // 把1号位置变为下标符"_",往后移1个字符的位置
                    tempItems.set(1, "_");
                    // 移除第一个的"_"下划符
                    tempItems.remove(0);
                }
                stopExtract = true;
                break;
            } else if (!tempItems.get(0).equals("_")) {
                // 从没有_前缀的项集开始寻找,第二种为从后面的后缀中找出直接找出连续字符为ab为同一项集的项集
                if (strArrayContains(tempItems, array)) {
                    // 从左往右找出第一个给定字符的位置,把后面的部分截取出来
                    int index = tempItems.indexOf(lastItem);
                    ArrayList<String> array2 = new ArrayList<String>();

                    for (int j = index; j < tempItems.size(); j++) {
                        array2.add(tempItems.get(j));
                    }
                    array2.set(0, "_");

                    if (array2.size() == 1) {
                        //如果此项在末尾的位置,则移除该项,否则进行替换
                        deleteItems.add(seq.itemSetList.get(i));
                    } else {
                        seq.itemSetList.set(i, new ItemSet(array2));
                    }
                    stopExtract = true;
                    break;
                } else {
                    deleteItems.add(seq.itemSetList.get(i));
                }
            } else {
                // 这种情况是处理_X中X不等于最后一个元素的情况
                deleteItems.add(seq.itemSetList.get(i));
            }
        }
        
        seq.itemSetList.removeAll(deleteItems);

        return seq;
    }

    /**
     * 深拷贝一个序列
     *
     * @return
     */
    public Sequence copySeqence() {
        Sequence copySeq = new Sequence();
        ItemSet tempItemSet;
        ArrayList<String> items;

        for (ItemSet itemSet : this.itemSetList) {
            items = (ArrayList<String>) itemSet.getItems().clone();
            tempItemSet = new ItemSet(items);
            copySeq.getItemSetList().add(tempItemSet);
        }

        return copySeq;
    }

    /**
     * 获取序列中最后一个项集的最后1个元素
     *
     * @return
     */
    public String getLastItemSetValue() {
        int size = this.getItemSetList().size();
        ItemSet itemSet = this.getItemSetList().get(size - 1);
        size = itemSet.getItems().size();

        return itemSet.getItems().get(size - 1);
    }

    /**
     * 判断strList2是否是strList1的子序列
     *
     * @param strList1
     * @param strList2
     * @return
     */
    public boolean strArrayContains(ArrayList<String> strList1,
            ArrayList<String> strList2) {
        boolean isContained = false;

        for (int i = 0; i < strList1.size() - strList2.size() + 1; i++) {
            isContained = true;

            for (int j = 0, k = i; j < strList2.size(); j++, k++) {
                if (!strList1.get(k).equals(strList2.get(j))) {
                    isContained = false;
                    break;
                }
            }

            if (isContained) {
                break;
            }
        }

        return isContained;
    }
}

ItemSet.java:

package DataMining_PrefixSpan;

import java.util.ArrayList;

/**
 * 字符项集类
 *
 * @author lyq
 *
 */
public class ItemSet {
    // 项集内的字符
    private ArrayList<String> items;

    public ItemSet(String[] str) {
        items = new ArrayList<>();
        for (String s : str) {
            items.add(s);
        }
    }

    public ItemSet(ArrayList<String> itemsList) {
        this.items = itemsList;
    }

    public ItemSet(String s) {
        items = new ArrayList<>();
        for (int i = 0; i < s.length(); i++) {
            items.add(s.charAt(i) + "");
        }
    }

    public ArrayList<String> getItems() {
        return items;
    }

    public void setItems(ArrayList<String> items) {
        this.items = items;
    }

    /**
     * 获取项集最后1个元素
     *
     * @return
     */
    public String getLastValue() {
        int size = this.items.size();

        return this.items.get(size - 1);
    }
}

PrefixSpanTool.java:

package DataMining_PrefixSpan;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

/**
 * PrefixSpanTool序列模式分析算法工具类
 *
 * @author lyq
 *
 */
public class PrefixSpanTool {
    // 测试数据文件地址
    private String filePath;
    // 最小支持度阈值比例
    private double minSupportRate;
    // 最小支持度,通过序列总数乘以阈值比例计算
    private int minSupport;
    // 原始序列组
    private ArrayList<Sequence> totalSeqs;
    // 挖掘出的所有序列频繁模式
    private ArrayList<Sequence> totalFrequentSeqs;
    // 所有的单一项,用于递归枚举
    private ArrayList<String> singleItems;

    public PrefixSpanTool(String filePath, double minSupportRate) {
        this.filePath = filePath;
        this.minSupportRate = minSupportRate;
        readDataFile();
    }

    /**
     * 从文件中读取数据
     */
    private void readDataFile() {
        File file = new File(filePath);
        ArrayList<String[]> dataArray = new ArrayList<String[]>();

        try {
            BufferedReader in = new BufferedReader(new FileReader(file));
            String str;
            String[] tempArray;
            while ((str = in.readLine()) != null) {
                tempArray = str.split(" ");
                dataArray.add(tempArray);
            }
            in.close();
        } catch (IOException e) {
            e.getStackTrace();
        }

        minSupport = (int) (dataArray.size() * minSupportRate);
        totalSeqs = new ArrayList<>();
        totalFrequentSeqs = new ArrayList<>();
        Sequence tempSeq;
        ItemSet tempItemSet;
        for (String[] str : dataArray) {
            tempSeq = new Sequence();
            for (String s : str) {
                tempItemSet = new ItemSet(s);
                tempSeq.getItemSetList().add(tempItemSet);
            }
            totalSeqs.add(tempSeq);
        }

        System.out.println("原始序列数据:");
        outputSeqence(totalSeqs);
    }

    /**
     * 输出序列列表内容
     *
     * @param seqList
     *            待输出序列列表
     */
    private void outputSeqence(ArrayList<Sequence> seqList) {
        for (Sequence seq : seqList) {
            System.out.print("<");
            for (ItemSet itemSet : seq.getItemSetList()) {
                if (itemSet.getItems().size() > 1) {
                    System.out.print("(");
                }

                for (String s : itemSet.getItems()) {
                    System.out.print(s + " ");
                }

                if (itemSet.getItems().size() > 1) {
                    System.out.print(")");
                }
            }
            System.out.println(">");
        }
    }

    /**
     * 移除初始序列中不满足最小支持度阈值的单项
     */
    private void removeInitSeqsItem() {
        int count = 0;
        HashMap<String, Integer> itemMap = new HashMap<>();
        singleItems = new ArrayList<>();

        for (Sequence seq : totalSeqs) {
            for (ItemSet itemSet : seq.getItemSetList()) {
                for (String s : itemSet.getItems()) {
                    if (!itemMap.containsKey(s)) {
                        itemMap.put(s, 1);
                    }
                }
            }
        }

        String key;
        for (Map.Entry entry : itemMap.entrySet()) {
            count = 0;
            key = (String) entry.getKey();
            for (Sequence seq : totalSeqs) {
                if (seq.strIsContained(key)) {
                    count++;
                }
            }

            itemMap.put(key, count);

        }

        for (Map.Entry entry : itemMap.entrySet()) {
            key = (String) entry.getKey();
            count = (int) entry.getValue();

            if (count < minSupport) {
                // 如果支持度阈值小于所得的最小支持度阈值,则删除该项
                for (Sequence seq : totalSeqs) {
                    seq.deleteSingleItem(key);
                }
            } else {
                singleItems.add(key);
            }
        }

        Collections.sort(singleItems);
    }

    /**
     * 递归搜索满足条件的序列模式
     *
     * @param beforeSeq
     *            前缀序列
     * @param afterSeqList
     *            后缀序列列表
     */
    private void recursiveSearchSeqs(Sequence beforeSeq,
            ArrayList<Sequence> afterSeqList) {
        ItemSet tempItemSet;
        Sequence tempSeq2;
        Sequence tempSeq;
        ArrayList<Sequence> tempSeqList = new ArrayList<>();

        for (String s : singleItems) {
            // 分成2种形式递归,以<a>为起始项,第一种直接加入独立项集遍历<a,a>,<a,b> <a,c>..
            if (isLargerThanMinSupport(s, afterSeqList)) {
                tempSeq = beforeSeq.copySeqence();
                tempItemSet = new ItemSet(s);
                tempSeq.getItemSetList().add(tempItemSet);

                totalFrequentSeqs.add(tempSeq);

                tempSeqList = new ArrayList<>();
                for (Sequence seq : afterSeqList) {
                    if (seq.strIsContained(s)) {
                        tempSeq2 = seq.extractItem(s);
                        tempSeqList.add(tempSeq2);
                    }
                }

                recursiveSearchSeqs(tempSeq, tempSeqList);
            }

            // 第二种递归为以元素的身份加入最后的项集内以a为例<(aa)>,<(ab)>,<(ac)>...
            // a在这里可以理解为一个前缀序列,里面可能是单个元素或者已经是多元素的项集
            tempSeq = beforeSeq.copySeqence();
            int size = tempSeq.getItemSetList().size();
            tempItemSet = tempSeq.getItemSetList().get(size - 1);
            tempItemSet.getItems().add(s);

            if (isLargerThanMinSupport(tempItemSet, afterSeqList)) {
                tempSeqList = new ArrayList<>();
                for (Sequence seq : afterSeqList) {
                    if (seq.compoentItemIsContain(tempItemSet)) {
                        tempSeq2 = seq.extractCompoentItem(tempItemSet
                                .getItems());
                        tempSeqList.add(tempSeq2);
                    }
                }
                totalFrequentSeqs.add(tempSeq);

                recursiveSearchSeqs(tempSeq, tempSeqList);
            }
        }
    }

    /**
     * 所传入的项组合在所给定序列中的支持度是否超过阈值
     *
     * @param s
     *            所需匹配的项
     * @param seqList
     *            比较序列数据
     * @return
     */
    private boolean isLargerThanMinSupport(String s, ArrayList<Sequence> seqList) {
        boolean isLarge = false;
        int count = 0;

        for (Sequence seq : seqList) {
            if (seq.strIsContained(s)) {
                count++;
            }
        }

        if (count >= minSupport) {
            isLarge = true;
        }

        return isLarge;
    }

    /**
     * 所传入的组合项集在序列中的支持度是否大于阈值
     *
     * @param itemSet
     *            组合元素项集
     * @param seqList
     *            比较的序列列表
     * @return
     */
    private boolean isLargerThanMinSupport(ItemSet itemSet,
            ArrayList<Sequence> seqList) {
        boolean isLarge = false;
        int count = 0;

        if (seqList == null) {
            return false;
        }

        for (Sequence seq : seqList) {
            if (seq.compoentItemIsContain(itemSet)) {
                count++;
            }
        }

        if (count >= minSupport) {
            isLarge = true;
        }

        return isLarge;
    }

    /**
     * 序列模式分析计算
     */
    public void prefixSpanCalculate() {
        Sequence seq;
        Sequence tempSeq;
        ArrayList<Sequence> tempSeqList = new ArrayList<>();
        ItemSet itemSet;
        removeInitSeqsItem();

        for (String s : singleItems) {
            // 从最开始的a,b,d开始递归往下寻找频繁序列模式
            seq = new Sequence();
            itemSet = new ItemSet(s);
            seq.getItemSetList().add(itemSet);

            if (isLargerThanMinSupport(s, totalSeqs)) {
                tempSeqList = new ArrayList<>();
                for (Sequence s2 : totalSeqs) {
                    // 判断单一项是否包含于在序列中,包含才进行提取操作
                    if (s2.strIsContained(s)) {
                        tempSeq = s2.extractItem(s);
                        tempSeqList.add(tempSeq);
                    }
                }

                totalFrequentSeqs.add(seq);
                recursiveSearchSeqs(seq, tempSeqList);
            }
        }

        printTotalFreSeqs();
    }

    /**
     * 按模式类别输出频繁序列模式
     */
    private void printTotalFreSeqs() {
        System.out.println("序列模式挖掘结果:");
        
        ArrayList<Sequence> seqList;
        HashMap<String, ArrayList<Sequence>> seqMap = new HashMap<>();
        for (String s : singleItems) {
            seqList = new ArrayList<>();
            for (Sequence seq : totalFrequentSeqs) {
                if (seq.getItemSetList().get(0).getItems().get(0).equals(s)) {
                    seqList.add(seq);
                }
            }
            seqMap.put(s, seqList);
        }

        int count = 0;
        for (String s : singleItems) {
            count = 0;
            System.out.println();
            System.out.println();

            seqList = (ArrayList<Sequence>) seqMap.get(s);
            for (Sequence tempSeq : seqList) {
                count++;
                System.out.print("<");
                for (ItemSet itemSet : tempSeq.getItemSetList()) {
                    if (itemSet.getItems().size() > 1) {
                        System.out.print("(");
                    }

                    for (String str : itemSet.getItems()) {
                        System.out.print(str + " ");
                    }

                    if (itemSet.getItems().size() > 1) {
                        System.out.print(")");
                    }
                }
                System.out.print(">, ");

                // 每5个序列换一行
                if (count == 5) {
                    count = 0;
                    System.out.println();
                }
            }

        }
    }

}

调用类Client.java:

package DataMining_PrefixSpan;

/**
 * PrefixSpan序列模式挖掘算法
 * @author lyq
 *
 */
public class Client {
    public static void main(String[] agrs){
        String filePath = "C:\\Users\\lyq\\Desktop\\icon\\input.txt";
        //最小支持度阈值率
        double minSupportRate = 0.4;
        
        PrefixSpanTool tool = new PrefixSpanTool(filePath, minSupportRate);
        tool.prefixSpanCalculate();
    }
}

输出的结果:

原始序列数据:
<(b d )c b (a c )>
<(b f )(c e )b (f g )>
<(a h )(b f )a b f >
<(b e )(c e )d >
<a (b d )b c b (a d e )>
序列模式挖掘结果:


<a >, <a a >, <a b >, <a b a >, <a b b >,


<b >, <b a >, <b b >, <b b a >, <b b c >,
<b b f >, <b c >, <b c a >, <b c b >, <b c b a >,
<b c d >, <b (c e )>, <b d >, <(b d )>, <(b d )a >,
<(b d )b >, <(b d )b a >, <(b d )b c >, <(b d )c >, <(b d )c a >,
<(b d )c b >, <(b d )c b a >, <b e >, <b f >, <(b f )>,
<(b f )b >, <(b f )b f >, <(b f )f >,

<c >, <c a >, <c b >, <c b a >, <c d >,
<(c e )>,

<d >, <d a >, <d b >, <d b a >, <d b c >,
<d c >, <d c a >, <d c b >, <d c b a >,

<e >,

<f >, <f b >, <f b f >, <f f >,

经过比对,与上述表格中的结果完全一致,从结果中可以看出他的递归顺序正是刚刚我所想要的那种。

算法实现时的难点

我在实现这个算法时确实碰到了不少的问题,下面一一列举。

1、Sequence序列在判断或者提取单项和组合项的时候,情况少考虑了,还有考虑到了处理的方式又可能错了。

2、递归的顺序在最早的时候考虑错了,后来对递归的顺序进行了调整。

3、在算法的调试时遇到了,当发现某一项出现问题时,不能够立即调试,因为里面陷入的递归层次实在太深,只能自己先手算此情况下的前缀,后缀序列,然后自己模拟出1个Seq调试,在纠正extract方法时用的比较多。
我对PrefixSpan算法的理解

实现了这个算法之后,再回味这个算法,还是很奇妙的,一个序列,通过从左往右的扫描,通过各个项集的子集,能够组合出许许多多的的序列模式,然后进行挖掘,PrefixSpan通过递归的形式全部找出,而且效率非常高,的确是个很强大的算法。
PrefixSpan算法整体的特点

首先一点,他不会产生候选序列,在产生投影数据库的时候(也就是产生后缀子序列),他的规模是不断减小的。PrefixSpan采用分治法进行序列的挖掘,十分的高效。唯一比较会有影响的开销就是在构造后缀子序列的过程,专业上的名称叫做构造投影数据库的时候。
作者:Androidlushangderen 发表于2015/2/12 19:06:27 原文链接
阅读:1840 评论:0 查看评论
GSP序列模式分析算法
2015年2月10日 9:12

参考资料:http://blog.csdn.net/zone_programming/article/details/42032309

更多数据挖掘代码:https://github.com/linyiqun/DataMiningAlgorithm

介绍
GSP算法是序列模式挖掘算法的一种,他是一种类Apriori的一种,整个过程与Apriori算法比较类似,不过在细节上会略有不同,在下面的描述中,将会有所描述。GSP在原有的频繁模式定义的概念下,增加了3个的概念。

1、加入时间约束min_gap,max_gap,要求原来的连续变为只要满足在规定的min_gap到max_gap之间即可。

2、加入time_windows_size,只要在windows_size内的item,都可以被认为是同一ItemSet。

3、加入分类标准。

以上3点新的中的第一条特征将会在后面的算法中着重展现。
算法原理

1、根据所输入的序列,找出所有的单项集,即1频繁模式,这里会经过最小支持度阈值的判断。

2、根据1频繁模式进行连接运算,产生2频繁模式,这里会有进行最小阈值的判断。

3、根据2频繁模式连接产生3频繁模式,会经过最小支持度判断和剪枝操作,剪枝操作的原理在于判断他的所有子集是否也全是频繁模式。

4、3频繁模式不断的挖掘知道不能够产生出候选集为止。
连接操作的原理

2个序列,全部变为item列表的形式,如果a序列去掉第1个元素后,b序列去掉最后1个序列,2个序列的item完全一致,则代表可以连接,由b的最后一个元素加入到a中,至于是以独立项集的身份加入还是加入到a中最后1个项集中取决于b中的最后一个元素所属项集是否为单项项集。
时间约束计算

这个是用在支持度计数使用的,GSP算法的支持度计算不是那么简单,比如序列判断<2, <3, 4>>是否在序列<(1,5), 2 , <3, 4>, 2>,这就不能仅仅判断序列中是否只包含2,<3, 4>就行了,还要满足时间间隔约束,这就要把2,和<3,4>的所有出现时间都找出来,然后再里面找出一条满足时间约束的路径就算包含。时间的定义是从左往右起1.2,3...继续,以1个项集为单位,所有2的时间有2个分别为t=2和t=4,然后同理,因为<3,4>在序列中只有1次,所以时间为t=3,所以问题就变为了下面一个数组的问题

2  4

3

从时间数组的上往下,通过对多个时间的组合,找出1条满足时间约束的方案,这里的方案只有2-3,4-3,然后判断时间间隔,如果存在这样的方式,则代表此序列支持所给定序列,支持度值加1,这个算法在程序的实现中是比较复杂的。
算法的代码实现

测试数据输入(格式:事务ID item数 item1 item2.....):

1 2 1 5
1 1 2
1 1 3
1 1 4
2 1 1
2 1 3
2 1 4
2 2 3 5
3 1 1
3 1 2
3 1 3
3 1 4
3 1 5
4 1 1
4 1 3
4 1 5
5 1 4
5 1 5

最后组成的序列为:

<(1,5) 2 3 4>

<1 3 4 (3,5)>

<1 2 3 4 5>

<1 3 5>

<4 5>

也就是说同一序列都是同事务的。下面是关键的类

Sequence.java:

package DataMining_GSP;

import java.util.ArrayList;

/**
 * 序列,每个序列内部包含多组ItemSet项集
 *
 * @author lyq
 *
 */
public class Sequence implements Comparable<Sequence>, Cloneable {
    // 序列所属事务ID
    private int trsanctionID;
    // 项集列表
    private ArrayList<ItemSet> itemSetList;

    public Sequence(int trsanctionID) {
        this.trsanctionID = trsanctionID;
        this.itemSetList = new ArrayList<>();
    }

    public Sequence() {
        this.itemSetList = new ArrayList<>();
    }

    public int getTrsanctionID() {
        return trsanctionID;
    }

    public void setTrsanctionID(int trsanctionID) {
        this.trsanctionID = trsanctionID;
    }

    public ArrayList<ItemSet> getItemSetList() {
        return itemSetList;
    }

    public void setItemSetList(ArrayList<ItemSet> itemSetList) {
        this.itemSetList = itemSetList;
    }

    /**
     * 取出序列中第一个项集的第一个元素
     *
     * @return
     */
    public Integer getFirstItemSetNum() {
        return this.getItemSetList().get(0).getItems().get(0);
    }

    /**
     * 获取序列中最后一个项集
     *
     * @return
     */
    public ItemSet getLastItemSet() {
        return getItemSetList().get(getItemSetList().size() - 1);
    }

    /**
     * 获取序列中最后一个项集的最后一个一个元素
     *
     * @return
     */
    public Integer getLastItemSetNum() {
        ItemSet lastItemSet = getItemSetList().get(getItemSetList().size() - 1);
        int lastItemNum = lastItemSet.getItems().get(
                lastItemSet.getItems().size() - 1);

        return lastItemNum;
    }

    /**
     * 判断序列中最后一个项集是否为单一的值
     *
     * @return
     */
    public boolean isLastItemSetSingleNum() {
        ItemSet lastItemSet = getItemSetList().get(getItemSetList().size() - 1);
        int size = lastItemSet.getItems().size();

        return size == 1 ? true : false;
    }

    @Override
    public int compareTo(Sequence o) {
        // TODO Auto-generated method stub
        return this.getFirstItemSetNum().compareTo(o.getFirstItemSetNum());
    }

    @Override
    protected Object clone() throws CloneNotSupportedException {
        // TODO Auto-generated method stub
        return super.clone();
    }
    
    /**
     * 拷贝一份一模一样的序列
     */
    public Sequence copySeqence(){
        Sequence copySeq = new Sequence();
        for(ItemSet itemSet: this.itemSetList){
            copySeq.getItemSetList().add(new ItemSet(itemSet.copyItems()));
        }
        
        return copySeq;
    }

    /**
     * 比较2个序列是否相等,需要判断内部的每个项集是否完全一致
     *
     * @param seq
     *            比较的序列对象
     * @return
     */
    public boolean compareIsSame(Sequence seq) {
        boolean result = true;
        ArrayList<ItemSet> itemSetList2 = seq.getItemSetList();
        ItemSet tempItemSet1;
        ItemSet tempItemSet2;

        if (itemSetList2.size() != this.itemSetList.size()) {
            return false;
        }
        for (int i = 0; i < itemSetList2.size(); i++) {
            tempItemSet1 = this.itemSetList.get(i);
            tempItemSet2 = itemSetList2.get(i);

            if (!tempItemSet1.compareIsSame(tempItemSet2)) {
                // 只要不相等,直接退出函数
                result = false;
                break;
            }
        }

        return result;
    }

    /**
     * 生成此序列的所有子序列
     *
     * @return
     */
    public ArrayList<Sequence> createChildSeqs() {
        ArrayList<Sequence> childSeqs = new ArrayList<>();
        ArrayList<Integer> tempItems;
        Sequence tempSeq = null;
        ItemSet tempItemSet;

        for (int i = 0; i < this.itemSetList.size(); i++) {
            tempItemSet = itemSetList.get(i);
            if (tempItemSet.getItems().size() == 1) {
                tempSeq = this.copySeqence();
                
                // 如果只有项集中只有1个元素,则直接移除
                tempSeq.itemSetList.remove(i);
                childSeqs.add(tempSeq);
            } else {
                tempItems = tempItemSet.getItems();
                for (int j = 0; j < tempItems.size(); j++) {
                    tempSeq = this.copySeqence();

                    // 在拷贝的序列中移除一个数字
                    tempSeq.getItemSetList().get(i).getItems().remove(j);
                    childSeqs.add(tempSeq);
                }
            }
        }

        return childSeqs;
    }

}

ItemSet.java:

package DataMining_GSP;

import java.util.ArrayList;

/**
 * 序列中的子项集
 *
 * @author lyq
 *
 */
public class ItemSet {
    /**
     * 项集中保存的是数字项数组
     */
    private ArrayList<Integer> items;

    public ItemSet(String[] itemStr) {
        items = new ArrayList<>();
        for (String s : itemStr) {
            items.add(Integer.parseInt(s));
        }
    }

    public ItemSet(int[] itemNum) {
        items = new ArrayList<>();
        for (int num : itemNum) {
            items.add(num);
        }
    }
    
    public ItemSet(ArrayList<Integer> itemNum) {
        this.items = itemNum;
    }

    public ArrayList<Integer> getItems() {
        return items;
    }

    public void setItems(ArrayList<Integer> items) {
        this.items = items;
    }

    /**
     * 判断2个项集是否相等
     *
     * @param itemSet
     *            比较对象
     * @return
     */
    public boolean compareIsSame(ItemSet itemSet) {
        boolean result = true;

        if (this.items.size() != itemSet.items.size()) {
            return false;
        }

        for (int i = 0; i < itemSet.items.size(); i++) {
            if (this.items.get(i) != itemSet.items.get(i)) {
                // 只要有值不相等,直接算作不相等
                result = false;
                break;
            }
        }

        return result;
    }

    /**
     * 拷贝项集中同样的数据一份
     *
     * @return
     */
    public ArrayList<Integer> copyItems() {
        ArrayList<Integer> copyItems = new ArrayList<>();

        for (int num : this.items) {
            copyItems.add(num);
        }

        return copyItems;
    }
}

GSPTool.java(算法工具类):

package DataMining_GSP;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

/**
 * GSP序列模式分析算法
 *
 * @author lyq
 *
 */
public class GSPTool {
    // 测试数据文件地址
    private String filePath;
    // 最小支持度阈值
    private int minSupportCount;
    // 时间最小间隔
    private int min_gap;
    // 时间最大间隔
    private int max_gap;
    // 原始数据序列
    private ArrayList<Sequence> totalSequences;
    // GSP算法中产生的所有的频繁项集序列
    private ArrayList<Sequence> totalFrequencySeqs;
    // 序列项数字对时间的映射图容器
    private ArrayList<ArrayList<HashMap<Integer, Integer>>> itemNum2Time;

    public GSPTool(String filePath, int minSupportCount, int min_gap,
            int max_gap) {
        this.filePath = filePath;
        this.minSupportCount = minSupportCount;
        this.min_gap = min_gap;
        this.max_gap = max_gap;
        totalFrequencySeqs = new ArrayList<>();
        readDataFile();
    }

    /**
     * 从文件中读取数据
     */
    private void readDataFile() {
        File file = new File(filePath);
        ArrayList<String[]> dataArray = new ArrayList<String[]>();

        try {
            BufferedReader in = new BufferedReader(new FileReader(file));
            String str;
            String[] tempArray;
            while ((str = in.readLine()) != null) {
                tempArray = str.split(" ");
                dataArray.add(tempArray);
            }
            in.close();
        } catch (IOException e) {
            e.getStackTrace();
        }

        HashMap<Integer, Sequence> mapSeq = new HashMap<>();
        Sequence seq;
        ItemSet itemSet;
        int tID;
        String[] itemStr;
        for (String[] str : dataArray) {
            tID = Integer.parseInt(str[0]);
            itemStr = new String[Integer.parseInt(str[1])];
            System.arraycopy(str, 2, itemStr, 0, itemStr.length);
            itemSet = new ItemSet(itemStr);

            if (mapSeq.containsKey(tID)) {
                seq = mapSeq.get(tID);
            } else {
                seq = new Sequence(tID);
            }
            seq.getItemSetList().add(itemSet);
            mapSeq.put(tID, seq);
        }

        // 将序列图加入到序列List中
        totalSequences = new ArrayList<>();
        for (Map.Entry entry : mapSeq.entrySet()) {
            totalSequences.add((Sequence) entry.getValue());
        }
    }

    /**
     * 生成1频繁项集
     *
     * @return
     */
    private ArrayList<Sequence> generateOneFrequencyItem() {
        int count = 0;
        int currentTransanctionID = 0;
        Sequence tempSeq;
        ItemSet tempItemSet;
        HashMap<Integer, Integer> itemNumMap = new HashMap<>();
        ArrayList<Sequence> seqList = new ArrayList<>();

        for (Sequence seq : totalSequences) {
            for (ItemSet itemSet : seq.getItemSetList()) {
                for (int num : itemSet.getItems()) {
                    // 如果没有此种类型项,则进行添加操作
                    if (!itemNumMap.containsKey(num)) {
                        itemNumMap.put(num, 1);
                    }
                }
            }
        }
        
        boolean isContain = false;
        int number = 0;
        for (Map.Entry entry : itemNumMap.entrySet()) {
            count = 0;
            number = (int) entry.getKey();
            for (Sequence seq : totalSequences) {
                isContain = false;
                
                for (ItemSet itemSet : seq.getItemSetList()) {
                    for (int num : itemSet.getItems()) {
                        if (num == number) {
                            isContain = true;
                            break;
                        }
                    }
                    
                    if(isContain){
                        break;
                    }
                }
                
                if(isContain){
                    count++;
                }
            }
            
            itemNumMap.put(number, count);
        }
        

        for (Map.Entry entry : itemNumMap.entrySet()) {
            count = (int) entry.getValue();
            if (count >= minSupportCount) {
                tempSeq = new Sequence();
                tempItemSet = new ItemSet(new int[] { (int) entry.getKey() });

                tempSeq.getItemSetList().add(tempItemSet);
                seqList.add(tempSeq);
            }

        }
        // 将序列升序排列
        Collections.sort(seqList);
        // 将频繁1项集加入总频繁项集列表中
        totalFrequencySeqs.addAll(seqList);

        return seqList;
    }

    /**
     * 通过1频繁项集连接产生2频繁项集
     *
     * @param oneSeq
     *            1频繁项集序列
     * @return
     */
    private ArrayList<Sequence> generateTwoFrequencyItem(
            ArrayList<Sequence> oneSeq) {
        Sequence tempSeq;
        ArrayList<Sequence> resultSeq = new ArrayList<>();
        ItemSet tempItemSet;
        int num1;
        int num2;

        // 假如将<a>,<b>2个1频繁项集做连接组合,可以分为<a a>,<a b>,<b a>,<b b>4个序列模式
        // 注意此时的每个序列中包含2个独立项集
        for (int i = 0; i < oneSeq.size(); i++) {
            num1 = oneSeq.get(i).getFirstItemSetNum();
            for (int j = 0; j < oneSeq.size(); j++) {
                num2 = oneSeq.get(j).getFirstItemSetNum();

                tempSeq = new Sequence();
                tempItemSet = new ItemSet(new int[] { num1 });
                tempSeq.getItemSetList().add(tempItemSet);
                tempItemSet = new ItemSet(new int[] { num2 });
                tempSeq.getItemSetList().add(tempItemSet);

                if (countSupport(tempSeq) >= minSupportCount) {
                    resultSeq.add(tempSeq);
                }
            }
        }

        // 上面连接还有1种情况是每个序列中只包含有一个项集的情况,此时a,b的划分则是<(a,a)> <(a,b)> <(b,b)>
        for (int i = 0; i < oneSeq.size(); i++) {
            num1 = oneSeq.get(i).getFirstItemSetNum();
            for (int j = i; j < oneSeq.size(); j++) {
                num2 = oneSeq.get(j).getFirstItemSetNum();

                tempSeq = new Sequence();
                tempItemSet = new ItemSet(new int[] { num1, num2 });
                tempSeq.getItemSetList().add(tempItemSet);

                if (countSupport(tempSeq) >= minSupportCount) {
                    resultSeq.add(tempSeq);
                }
            }
        }
        // 同样将2频繁项集加入到总频繁项集中
        totalFrequencySeqs.addAll(resultSeq);

        return resultSeq;
    }

    /**
     * 根据上次的频繁集连接产生新的侯选集
     *
     * @param seqList
     *            上次产生的候选集
     * @return
     */
    private ArrayList<Sequence> generateCandidateItem(
            ArrayList<Sequence> seqList) {
        Sequence tempSeq;
        ArrayList<Integer> tempNumArray;
        ArrayList<Sequence> resultSeq = new ArrayList<>();
        // 序列数字项列表
        ArrayList<ArrayList<Integer>> seqNums = new ArrayList<>();

        for (int i = 0; i < seqList.size(); i++) {
            tempNumArray = new ArrayList<>();
            tempSeq = seqList.get(i);
            for (ItemSet itemSet : tempSeq.getItemSetList()) {
                tempNumArray.addAll(itemSet.copyItems());
            }
            seqNums.add(tempNumArray);
        }

        ArrayList<Integer> array1;
        ArrayList<Integer> array2;
        // 序列i,j的拷贝
        Sequence seqi = null;
        Sequence seqj = null;
        // 判断是否能够连接,默认能连接
        boolean canConnect = true;
        // 进行连接运算,包括自己与自己连接
        for (int i = 0; i < seqNums.size(); i++) {
            for (int j = 0; j < seqNums.size(); j++) {
                array1 = (ArrayList<Integer>) seqNums.get(i).clone();
                array2 = (ArrayList<Integer>) seqNums.get(j).clone();

                // 将第一个数字组去掉第一个,第二个数字组去掉最后一个,如果剩下的部分相等,则可以连接
                array1.remove(0);
                array2.remove(array2.size() - 1);

                canConnect = true;
                for (int k = 0; k < array1.size(); k++) {
                    if (array1.get(k) != array2.get(k)) {
                        canConnect = false;
                        break;
                    }
                }

                if (canConnect) {
                    seqi = seqList.get(i).copySeqence();
                    seqj = seqList.get(j).copySeqence();

                    int lastItemNum = seqj.getLastItemSetNum();
                    if (seqj.isLastItemSetSingleNum()) {
                        // 如果j序列的最后项集为单一值,则最后一个数字以独立项集加入i序列
                        ItemSet itemSet = new ItemSet(new int[] { lastItemNum });
                        seqi.getItemSetList().add(itemSet);
                    } else {
                        // 如果j序列的最后项集为非单一值,则最后一个数字加入i序列最后一个项集中
                        ItemSet itemSet = seqi.getLastItemSet();
                        itemSet.getItems().add(lastItemNum);
                    }

                    // 判断是否超过最小支持度阈值
                    if (isChildSeqContained(seqi)
                            && countSupport(seqi) >= minSupportCount) {
                        resultSeq.add(seqi);
                    }
                }
            }
        }

        totalFrequencySeqs.addAll(resultSeq);
        return resultSeq;
    }

    /**
     * 判断此序列的所有子序列是否也是频繁序列
     *
     * @param seq
     *            待比较序列
     * @return
     */
    private boolean isChildSeqContained(Sequence seq) {
        boolean isContained = false;
        ArrayList<Sequence> childSeqs;

        childSeqs = seq.createChildSeqs();
        for (Sequence tempSeq : childSeqs) {
            isContained = false;

            for (Sequence frequencySeq : totalFrequencySeqs) {
                if (tempSeq.compareIsSame(frequencySeq)) {
                    isContained = true;
                    break;
                }
            }

            if (!isContained) {
                break;
            }
        }

        return isContained;
    }

    /**
     * 候选集判断支持度的值
     *
     * @param seq
     *            待判断序列
     * @return
     */
    private int countSupport(Sequence seq) {
        int count = 0;
        int matchNum = 0;
        Sequence tempSeq;
        ItemSet tempItemSet;
        HashMap<Integer, Integer> timeMap;
        ArrayList<ItemSet> itemSetList;
        ArrayList<ArrayList<Integer>> numArray = new ArrayList<>();
        // 每项集对应的时间链表
        ArrayList<ArrayList<Integer>> timeArray = new ArrayList<>();

        for (ItemSet itemSet : seq.getItemSetList()) {
            numArray.add(itemSet.getItems());
        }

        for (int i = 0; i < totalSequences.size(); i++) {
            timeArray = new ArrayList<>();

            for (int s = 0; s < numArray.size(); s++) {
                ArrayList<Integer> childNum = numArray.get(s);
                ArrayList<Integer> localTime = new ArrayList<>();
                tempSeq = totalSequences.get(i);
                itemSetList = tempSeq.getItemSetList();

                for (int j = 0; j < itemSetList.size(); j++) {
                    tempItemSet = itemSetList.get(j);
                    matchNum = 0;
                    int t = 0;

                    if (tempItemSet.getItems().size() == childNum.size()) {
                        timeMap = itemNum2Time.get(i).get(j);
                        // 只有当项集长度匹配时才匹配
                        for (int k = 0; k < childNum.size(); k++) {
                            if (timeMap.containsKey(childNum.get(k))) {
                                matchNum++;
                                t = timeMap.get(childNum.get(k));
                            }
                        }

                        // 如果完全匹配,则记录时间
                        if (matchNum == childNum.size()) {
                            localTime.add(t);
                        }
                    }

                }

                if (localTime.size() > 0) {
                    timeArray.add(localTime);
                }
            }

            // 判断时间是否满足时间最大最小约束,如果满足,则此条事务包含候选事务
            if (timeArray.size() == numArray.size()
                    && judgeTimeInGap(timeArray)) {
                count++;
            }
        }

        return count;
    }

    /**
     * 判断事务是否满足时间约束
     *
     * @param timeArray
     *            时间数组,每行代表各项集的在事务中的发生时间链表
     * @return
     */
    private boolean judgeTimeInGap(ArrayList<ArrayList<Integer>> timeArray) {
        boolean result = false;
        int preTime = 0;
        ArrayList<Integer> firstTimes = timeArray.get(0);
        timeArray.remove(0);

        if (timeArray.size() == 0) {
            return false;
        }

        for (int i = 0; i < firstTimes.size(); i++) {
            preTime = firstTimes.get(i);

            if (dfsJudgeTime(preTime, timeArray)) {
                result = true;
                break;
            }
        }

        return result;
    }

    /**
     * 深度优先遍历时间,判断是否有符合条件的时间间隔
     *
     * @param preTime
     * @param timeArray
     * @return
     */
    private boolean dfsJudgeTime(int preTime,
            ArrayList<ArrayList<Integer>> timeArray) {
        boolean result = false;
        ArrayList<ArrayList<Integer>> timeArrayClone = (ArrayList<ArrayList<Integer>>) timeArray
                .clone();
        ArrayList<Integer> firstItemItem = timeArrayClone.get(0);

        for (int i = 0; i < firstItemItem.size(); i++) {
            if (firstItemItem.get(i) - preTime >= min_gap
                    && firstItemItem.get(i) - preTime <= max_gap) {
                // 如果此2项间隔时间满足时间约束,则继续往下递归
                preTime = firstItemItem.get(i);
                timeArrayClone.remove(0);

                if (timeArrayClone.size() == 0) {
                    return true;
                } else {
                    result = dfsJudgeTime(preTime, timeArrayClone);
                    if (result) {
                        return true;
                    }
                }
            }
        }

        return result;
    }

    /**
     * 初始化序列项到时间的序列图,为了后面的时间约束计算
     */
    private void initItemNumToTimeMap() {
        Sequence seq;
        itemNum2Time = new ArrayList<>();
        HashMap<Integer, Integer> tempMap;
        ArrayList<HashMap<Integer, Integer>> tempMapList;

        for (int i = 0; i < totalSequences.size(); i++) {
            seq = totalSequences.get(i);
            tempMapList = new ArrayList<>();

            for (int j = 0; j < seq.getItemSetList().size(); j++) {
                ItemSet itemSet = seq.getItemSetList().get(j);
                tempMap = new HashMap<>();
                for (int itemNum : itemSet.getItems()) {
                    tempMap.put(itemNum, j + 1);
                }

                tempMapList.add(tempMap);
            }

            itemNum2Time.add(tempMapList);
        }
    }

    /**
     * 进行GSP算法计算
     */
    public void gspCalculate() {
        ArrayList<Sequence> oneSeq;
        ArrayList<Sequence> twoSeq;
        ArrayList<Sequence> candidateSeq;

        initItemNumToTimeMap();
        oneSeq = generateOneFrequencyItem();
        twoSeq = generateTwoFrequencyItem(oneSeq);
        candidateSeq = twoSeq;

        // 不断连接生产候选集,直到没有产生出侯选集
        for (;;) {
            candidateSeq = generateCandidateItem(candidateSeq);

            if (candidateSeq.size() == 0) {
                break;
            }
        }

        outputSeqence(totalFrequencySeqs);

    }

    /**
     * 输出序列列表信息
     *
     * @param outputSeqList
     *            待输出序列列表
     */
    private void outputSeqence(ArrayList<Sequence> outputSeqList) {
        for (Sequence seq : outputSeqList) {
            System.out.print("<");
            for (ItemSet itemSet : seq.getItemSetList()) {
                System.out.print("(");
                for (int num : itemSet.getItems()) {
                    System.out.print(num + ",");
                }
                System.out.print("), ");
            }
            System.out.println(">");
        }
    }

}

调用类Client.java:

package DataMining_GSP;

/**
 * GSP序列模式分析算法
 * @author lyq
 *
 */
public class Client {
    public static void main(String[] args){
        String filePath = "C:\\Users\\lyq\\Desktop\\icon\\testInput.txt";
        //最小支持度阈值
        int minSupportCount = 2;
        //时间最小间隔
        int min_gap = 1;
        //施加最大间隔
        int max_gap = 5;
        
        GSPTool tool = new GSPTool(filePath, minSupportCount, min_gap, max_gap);
        tool.gspCalculate();
    }
}

算法的输出(挖掘出的所有频繁模式):

<(1,), >
<(2,), >
<(3,), >
<(4,), >
<(5,), >
<(1,), (3,), >
<(1,), (4,), >
<(1,), (5,), >
<(2,), (3,), >
<(2,), (4,), >
<(3,), (4,), >
<(3,), (5,), >
<(4,), (5,), >
<(1,), (3,), (4,), >
<(1,), (3,), (5,), >
<(2,), (3,), (4,), >

算法实现的难点

1、算法花费了几天的时间,难点首先在于对算法原理本身的理解,网上对于此算法的资料特别少,而且不同的人所表达的意思 都有少许的不同,讲的也不是很详细,于是就通过阅读别人的代码理解GSP算法的原理,我的代码实现也是参考了参考资料的C语言的实现。

2、在实现时间约束的支持度计数统计的时候,调试了一段时间,做时间统计容易出错,因为层级实在太多容易搞晕。

3、还有1个是Sequence和ItemSet的拷贝时的引用问题,在产生新的序列时一定要深拷贝1个否则导致同一引用会把原数据给改掉的。
GSP算法和Apriori算法的比较

我是都实现过了GSP算法和Apriori算法的,后者是被称为关联规则挖掘算法,偏向于挖掘关联规则的,2个算法在连接的操作上有不一样的地方,还有在数据的构成方式上,Apriori的数据会简单一点,都是单项单项构成的,而且在做支持度统计的时候只需判断存在与否即可。不需要考虑时间约束。Apriori算法给定K项集,连接到K-1项集算法就停止了,而GSP算法是直到不能够产生候选集为止。
作者:Androidlushangderen 发表于2015/2/10 9:12:01 原文链接
阅读:1044 评论:0 查看评论
AdaBoost装袋提升算法
2015年2月8日 9:22

参开资料:http://blog.csdn.net/haidao2009/article/details/7514787
更多挖掘算法:https://github.com/linyiqun/DataMiningAlgorithm
介绍

在介绍AdaBoost算法之前,需要了解一个类似的算法,装袋算法(bagging),bagging是一种提高分类准确率的算法,通过给定组合投票的方式,获得最优解。比如你生病了,去n个医院看了n个医生,每个医生给你开了药方,最后的结果中,哪个药方的出现的次数多,那就说明这个药方就越有可能性是最由解,这个很好理解。而bagging算法就是这个思想。
算法原理

而AdaBoost算法的核心思想还是基于bagging算法,但是他又一点点的改进,上面的每个医生的投票结果都是一样的,说明地位平等,如果在这里加上一个权重,大城市的医生权重高点,小县城的医生权重低,这样通过最终计算权重和的方式,会更加的合理,这就是AdaBoost算法。AdaBoost算法是一种迭代算法,只有最终分类误差率小于阈值算法才能停止,针对同一训练集数据训练不同的分类器,我们称弱分类器,最后按照权重和的形式组合起来,构成一个组合分类器,就是一个强分类器了。算法的只要过程:

1、对D训练集数据训练处一个分类器Ci

2、通过分类器Ci对数据进行分类,计算此时误差率

3、把上步骤中的分错的数据的权重提高,分对的权重降低,以此凸显了分错的数据。为什么这么做呢,后面会做出解释。

完整的adaboost算法如下


最后的sign函数是符号函数,如果最后的值为正,则分为+1类,否则即使-1类。

我们举个例子代入上面的过程,这样能够更好的理解。

adaboost的实现过程:

  图中,“+”和“-”分别表示两种类别,在这个过程中,我们使用水平或者垂直的直线作为分类器,来进行分类。

  第一步:

  根据分类的正确率,得到一个新的样本分布D2­,一个子分类器h1

  其中划圈的样本表示被分错的。在右边的途中,比较大的“+”表示对该样本做了加权。

算法最开始给了一个均匀分布 D 。所以h1 里的每个点的值是0.1。ok,当划分后,有三个点划分错了,根据算法误差表达式得到 误差为分错了的三个点的值之和,所以ɛ1=(0.1+0.1+0.1)=0.3,而ɑ1 根据表达式 的可以算出来为0.42. 然后就根据算法 把分错的点权值变大。如此迭代,最终完成adaboost算法。

  第二步:

  根据分类的正确率,得到一个新的样本分布D3,一个子分类器h2

  第三步:

  得到一个子分类器h3

  整合所有子分类器:

  因此可以得到整合的结果,从结果中看,及时简单的分类器,组合起来也能获得很好的分类效果,在例子中所有的。后面的代码实现时,举出的也是这个例子,可以做对比,这里有一点比较重要,就是点的权重经过大小变化之后,需要进行归一化,确保总和为1.0,这个容易遗忘。
算法的代码实现

输入测试数据,与上图的例子相对应(数据格式:x坐标 y坐标 已分类结果):

1 5 1
2 3 1
3 1 -1
4 5 -1
5 6 1
6 4 -1
6 7 1
7 6 1
8 7 -1
8 2 -1

Point.java

package DataMining_AdaBoost;

/**
 * 坐标点类
 *
 * @author lyq
 *
 */
public class Point {
    // 坐标点x坐标
    private int x;
    // 坐标点y坐标
    private int y;
    // 坐标点的分类类别
    private int classType;
    //如果此节点被划错,他的误差率,不能用个数除以总数,因为不同坐标点的权重不一定相等
    private double probably;
    
    public Point(int x, int y, int classType){
        this.x = x;
        this.y = y;
        this.classType = classType;
    }
    
    public Point(String x, String y, String classType){
        this.x = Integer.parseInt(x);
        this.y = Integer.parseInt(y);
        this.classType = Integer.parseInt(classType);
    }

    public int getX() {
        return x;
    }

    public void setX(int x) {
        this.x = x;
    }

    public int getY() {
        return y;
    }

    public void setY(int y) {
        this.y = y;
    }

    public int getClassType() {
        return classType;
    }

    public void setClassType(int classType) {
        this.classType = classType;
    }

    public double getProbably() {
        return probably;
    }

    public void setProbably(double probably) {
        this.probably = probably;
    }
}

AdaBoost.java

package DataMining_AdaBoost;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.text.MessageFormat;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;

/**
 * AdaBoost提升算法工具类
 *
 * @author lyq
 *
 */
public class AdaBoostTool {
    // 分类的类别,程序默认为正类1和负类-1
    public static final int CLASS_POSITIVE = 1;
    public static final int CLASS_NEGTIVE = -1;

    // 事先假设的3个分类器(理论上应该重新对数据集进行训练得到)
    public static final String CLASSIFICATION1 = "X=2.5";
    public static final String CLASSIFICATION2 = "X=7.5";
    public static final String CLASSIFICATION3 = "Y=5.5";

    // 分类器组
    public static final String[] ClASSIFICATION = new String[] {
            CLASSIFICATION1, CLASSIFICATION2, CLASSIFICATION3 };
    // 分类权重组
    private double[] CLASSIFICATION_WEIGHT;

    // 测试数据文件地址
    private String filePath;
    // 误差率阈值
    private double errorValue;
    // 所有的数据点
    private ArrayList<Point> totalPoint;

    public AdaBoostTool(String filePath, double errorValue) {
        this.filePath = filePath;
        this.errorValue = errorValue;
        readDataFile();
    }

    /**
     * 从文件中读取数据
     */
    private void readDataFile() {
        File file = new File(filePath);
        ArrayList<String[]> dataArray = new ArrayList<String[]>();

        try {
            BufferedReader in = new BufferedReader(new FileReader(file));
            String str;
            String[] tempArray;
            while ((str = in.readLine()) != null) {
                tempArray = str.split(" ");
                dataArray.add(tempArray);
            }
            in.close();
        } catch (IOException e) {
            e.getStackTrace();
        }

        Point temp;
        totalPoint = new ArrayList<>();
        for (String[] array : dataArray) {
            temp = new Point(array[0], array[1], array[2]);
            temp.setProbably(1.0 / dataArray.size());
            totalPoint.add(temp);
        }
    }

    /**
     * 根据当前的误差值算出所得的权重
     *
     * @param errorValue
     *            当前划分的坐标点误差率
     * @return
     */
    private double calculateWeight(double errorValue) {
        double alpha = 0;
        double temp = 0;

        temp = (1 - errorValue) / errorValue;
        alpha = 0.5 * Math.log(temp);

        return alpha;
    }

    /**
     * 计算当前划分的误差率
     *
     * @param pointMap
     *            划分之后的点集
     * @param weight
     *            本次划分得到的分类器权重
     * @return
     */
    private double calculateErrorValue(
            HashMap<Integer, ArrayList<Point>> pointMap) {
        double resultValue = 0;
        double temp = 0;
        double weight = 0;
        int tempClassType;
        ArrayList<Point> pList;
        for (Map.Entry entry : pointMap.entrySet()) {
            tempClassType = (int) entry.getKey();

            pList = (ArrayList<Point>) entry.getValue();
            for (Point p : pList) {
                temp = p.getProbably();
                // 如果划分类型不相等,代表划错了
                if (tempClassType != p.getClassType()) {
                    resultValue += temp;
                }
            }
        }

        weight = calculateWeight(resultValue);
        for (Map.Entry entry : pointMap.entrySet()) {
            tempClassType = (int) entry.getKey();

            pList = (ArrayList<Point>) entry.getValue();
            for (Point p : pList) {
                temp = p.getProbably();
                // 如果划分类型不相等,代表划错了
                if (tempClassType != p.getClassType()) {
                    // 划错的点的权重比例变大
                    temp *= Math.exp(weight);
                    p.setProbably(temp);
                } else {
                    // 划对的点的权重比减小
                    temp *= Math.exp(-weight);
                    p.setProbably(temp);
                }
            }
        }

        // 如果误差率没有小于阈值,继续处理
        dataNormalized();

        return resultValue;
    }

    /**
     * 概率做归一化处理
     */
    private void dataNormalized() {
        double sumProbably = 0;
        double temp = 0;

        for (Point p : totalPoint) {
            sumProbably += p.getProbably();
        }

        // 归一化处理
        for (Point p : totalPoint) {
            temp = p.getProbably();
            p.setProbably(temp / sumProbably);
        }
    }

    /**
     * 用AdaBoost算法得到的组合分类器对数据进行分类
     *
     */
    public void adaBoostClassify() {
        double value = 0;
        Point p;

        calculateWeightArray();
        for (int i = 0; i < ClASSIFICATION.length; i++) {
            System.out.println(MessageFormat.format("分类器{0}权重为:{1}", (i+1), CLASSIFICATION_WEIGHT[i]));
        }
        
        for (int j = 0; j < totalPoint.size(); j++) {
            p = totalPoint.get(j);
            value = 0;

            for (int i = 0; i < ClASSIFICATION.length; i++) {
                value += 1.0 * classifyData(ClASSIFICATION[i], p)
                        * CLASSIFICATION_WEIGHT[i];
            }
            
            //进行符号判断
            if (value > 0) {
                System.out
                        .println(MessageFormat.format(
                                "点({0}, {1})的组合分类结果为:1,该点的实际分类为{2}", p.getX(), p.getY(),
                                p.getClassType()));
            } else {
                System.out.println(MessageFormat.format(
                        "点({0}, {1})的组合分类结果为:-1,该点的实际分类为{2}", p.getX(), p.getY(),
                        p.getClassType()));
            }
        }
    }

    /**
     * 计算分类器权重数组
     */
    private void calculateWeightArray() {
        int tempClassType = 0;
        double errorValue = 0;
        ArrayList<Point> posPointList;
        ArrayList<Point> negPointList;
        HashMap<Integer, ArrayList<Point>> mapList;
        CLASSIFICATION_WEIGHT = new double[ClASSIFICATION.length];

        for (int i = 0; i < CLASSIFICATION_WEIGHT.length; i++) {
            mapList = new HashMap<>();
            posPointList = new ArrayList<>();
            negPointList = new ArrayList<>();

            for (Point p : totalPoint) {
                tempClassType = classifyData(ClASSIFICATION[i], p);

                if (tempClassType == CLASS_POSITIVE) {
                    posPointList.add(p);
                } else {
                    negPointList.add(p);
                }
            }

            mapList.put(CLASS_POSITIVE, posPointList);
            mapList.put(CLASS_NEGTIVE, negPointList);

            if (i == 0) {
                // 最开始的各个点的权重一样,所以传入0,使得e的0次方等于1
                errorValue = calculateErrorValue(mapList);
            } else {
                // 每次把上次计算所得的权重代入,进行概率的扩大或缩小
                errorValue = calculateErrorValue(mapList);
            }

            // 计算当前分类器的所得权重
            CLASSIFICATION_WEIGHT[i] = calculateWeight(errorValue);
        }
    }

    /**
     * 用各个子分类器进行分类
     *
     * @param classification
     *            分类器名称
     * @param p
     *            待划分坐标点
     * @return
     */
    private int classifyData(String classification, Point p) {
        // 分割线所属坐标轴
        String position;
        // 分割线的值
        double value = 0;
        double posProbably = 0;
        double negProbably = 0;
        // 划分是否是大于一边的划分
        boolean isLarger = false;
        String[] array;
        ArrayList<Point> pList = new ArrayList<>();

        array = classification.split("=");
        position = array[0];
        value = Double.parseDouble(array[1]);

        if (position.equals("X")) {
            if (p.getX() > value) {
                isLarger = true;
            }

            // 将训练数据中所有属于这边的点加入
            for (Point point : totalPoint) {
                if (isLarger && point.getX() > value) {
                    pList.add(point);
                } else if (!isLarger && point.getX() < value) {
                    pList.add(point);
                }
            }
        } else if (position.equals("Y")) {
            if (p.getY() > value) {
                isLarger = true;
            }

            // 将训练数据中所有属于这边的点加入
            for (Point point : totalPoint) {
                if (isLarger && point.getY() > value) {
                    pList.add(point);
                } else if (!isLarger && point.getY() < value) {
                    pList.add(point);
                }
            }
        }

        for (Point p2 : pList) {
            if (p2.getClassType() == CLASS_POSITIVE) {
                posProbably++;
            } else {
                negProbably++;
            }
        }
        
        //分类按正负类数量进行划分
        if (posProbably > negProbably) {
            return CLASS_POSITIVE;
        } else {
            return CLASS_NEGTIVE;
        }
    }

}

调用类Client.java:

/**
 * AdaBoost提升算法调用类
 * @author lyq
 *
 */
public class Client {
    public static void main(String[] agrs){
        String filePath = "C:\\Users\\lyq\\Desktop\\icon\\input.txt";
        //误差率阈值
        double errorValue = 0.2;
        
        AdaBoostTool tool = new AdaBoostTool(filePath, errorValue);
        tool.adaBoostClassify();
    }
}

输出结果:

分类器1权重为:0.424
分类器2权重为:0.65
分类器3权重为:0.923
点(1, 5)的组合分类结果为:1,该点的实际分类为1
点(2, 3)的组合分类结果为:1,该点的实际分类为1
点(3, 1)的组合分类结果为:-1,该点的实际分类为-1
点(4, 5)的组合分类结果为:-1,该点的实际分类为-1
点(5, 6)的组合分类结果为:1,该点的实际分类为1
点(6, 4)的组合分类结果为:-1,该点的实际分类为-1
点(6, 7)的组合分类结果为:1,该点的实际分类为1
点(7, 6)的组合分类结果为:1,该点的实际分类为1
点(8, 7)的组合分类结果为:-1,该点的实际分类为-1
点(8, 2)的组合分类结果为:-1,该点的实际分类为-1

我们可以看到,如果3个分类单独分类,都没有百分百分对,而尽管组合结果之后,全部分类正确。
我对AdaBoost算法的理解

到了算法的末尾,有必要解释一下每次分类自后需要把错的点的权重增大,正确的减少的理由了,加入上次分类之后,(1,5)已经分错了,如果这次又分错,由于上次的权重已经提升,所以误差率更大,则代入公式ln(1-误差率/误差率)所得的权重越小,也就是说,如果同个数据,你分类的次数越多,你的权重越小,所以这就造成整体好的分类器的权重会越大,内部就会同时有各种权重的分类器,形成了一种互补的结果,如果好的分类器结果分错 ,可以由若干弱一点的分类器进行弥补。
AdaBoost算法的应用

可以运用在诸如特征识别,二分类的一些应用上,与单个模型相比,组合的形式能显著的提高准确率。

作者:Androidlushangderen 发表于2015/2/8 9:22:10 原文链接
阅读:1285 评论:0 查看评论
BIRCH算法---使用聚类特征树的多阶段算法
2015年2月5日 18:58

更多数据挖掘代码:https://github.com/linyiqun/DataMiningAlgorithm
介绍

BIRCH算法本身上属于一种聚类算法,不过他克服了一些K-Means算法的缺点,比如说这个k的确定,因为这个算法事先本身就没有设定有多少个聚类。他是通过CF-Tree,(ClusterFeature-Tree)聚类特征树实现的。BIRCH的一个重要考虑是最小化I/O,通过扫描数据库,建立一棵存放于内存的初始CF-树,可以看做多数据的多层压缩。
算法原理
CF聚类特征

说到算法原理,首先就要先知道,什么是聚类特征,何为聚类特征,定义如下:

CF = <n, LS, SS>

聚类特征为一个3维向量,n为数据点总数,LS为n个点的线性和,SS为n个点的平方和。因此又可以得到

x0 = LS/n为簇的中心,以此计算簇与簇之间的距离。

簇内对象的平均距离簇直径,这个可以用阈值T限制,保证簇的一个整体的紧凑程度。簇和簇之间可以进行叠加,其实就是向量的叠加。
CF-Tree的构造过程

在介绍CF-Tree树,要先介绍3个变量,内部节点平衡因子B,叶节点平衡因子L,簇直径阈值T。B是用来限制非叶子节点的子节点数,L是用来限制叶子节点的子簇个数,T是用来限制簇的紧密程度的,比较的是D--簇内平均对象的距离。下面是主要的构造过程:

1、首先读入第一条数据,构造一个叶子节点和一个子簇,子簇包含在叶子节点中。

2、当读入后面的第2条,第3条,封装为一个簇,加入到一个叶子节点时,如果此时的待加入的簇C的簇直径已经大于T,则需要新建簇作为C的兄弟节点,如果作为兄弟节点,如果此时的叶子节点的孩子节点超过阈值L,则需对叶子节点进行分裂。分裂的规则是选出簇间距离最大的2个孩子,分别作为2个叶子,然后其他的孩子按照就近分配。非叶子节点的分裂规则同上。具体可以对照后面我写的代码。

3、最终的构造模样大致如此:


算法的优点:

1、算法只需扫描一遍就可以得到一个好的聚类效果,而且不需事先设定聚类个数。

2、聚类通过聚类特征树的形式,一定程度上保存了对数据的压缩。
算法的缺点:

1、该算法比较适合球形的簇,如果簇不是球形的,则聚簇的效果将不会很好。

算法的代码实现:

下面提供部分核心代码(如果想获取所有的代码,请点击我的数据挖掘代码):

数据的输入:

5.1     3.5     1.4     0.2
4.9     3.0     1.4     0.2
4.7     3.2     1.3     0.8
4.6     3.1     1.5     0.8
5.0     3.6     1.8     0.6
4.7     3.2     1.4     0.8

ClusteringFeature.java:

package DataMining_BIRCH;

import java.util.ArrayList;

/**
 * 聚类特征基本属性
 *
 * @author lyq
 *
 */
public abstract class ClusteringFeature {
    // 子类中节点的总数目
    protected int N;
    // 子类中N个节点的线性和
    protected double[] LS;
    // 子类中N个节点的平方和
    protected double[] SS;
    //节点深度,用于CF树的输出
    protected int level;

    public int getN() {
        return N;
    }

    public void setN(int n) {
        N = n;
    }

    public double[] getLS() {
        return LS;
    }

    public void setLS(double[] lS) {
        LS = lS;
    }

    public double[] getSS() {
        return SS;
    }

    public void setSS(double[] sS) {
        SS = sS;
    }

    protected void setN(ArrayList<double[]> dataRecords) {
        this.N = dataRecords.size();
    }
    
    public int getLevel() {
        return level;
    }

    public void setLevel(int level) {
        this.level = level;
    }

    /**
     * 根据节点数据计算线性和
     *
     * @param dataRecords
     *            节点数据记录
     */
    protected void setLS(ArrayList<double[]> dataRecords) {
        int num = dataRecords.get(0).length;
        double[] record;
        LS = new double[num];
        for (int j = 0; j < num; j++) {
            LS[j] = 0;
        }

        for (int i = 0; i < dataRecords.size(); i++) {
            record = dataRecords.get(i);
            for (int j = 0; j < record.length; j++) {
                LS[j] += record[j];
            }
        }
    }

    /**
     * 根据节点数据计算平方
     *
     * @param dataRecords
     *            节点数据
     */
    protected void setSS(ArrayList<double[]> dataRecords) {
        int num = dataRecords.get(0).length;
        double[] record;
        SS = new double[num];
        for (int j = 0; j < num; j++) {
            SS[j] = 0;
        }

        for (int i = 0; i < dataRecords.size(); i++) {
            record = dataRecords.get(i);
            for (int j = 0; j < record.length; j++) {
                SS[j] += record[j] * record[j];
            }
        }
    }

    /**
     * CF向量特征的叠加,无须考虑划分
     *
     * @param node
     */
    protected void directAddCluster(ClusteringFeature node) {
        int N = node.getN();
        double[] otherLS = node.getLS();
        double[] otherSS = node.getSS();
        
        if(LS == null){
            this.N = 0;
            LS = new double[otherLS.length];
            SS = new double[otherLS.length];
            
            for(int i=0; i<LS.length; i++){
                LS[i] = 0;
                SS[i] = 0;
            }
        }

        // 3个数量上进行叠加
        for (int i = 0; i < LS.length; i++) {
            LS[i] += otherLS[i];
            SS[i] += otherSS[i];
        }
        this.N += N;
    }

    /**
     * 计算簇与簇之间的距离即簇中心之间的距离
     *
     * @return
     */
    protected double computerClusterDistance(ClusteringFeature cluster) {
        double distance = 0;
        double[] otherLS = cluster.LS;
        int num = N;
        
        int otherNum = cluster.N;

        for (int i = 0; i < LS.length; i++) {
            distance += (LS[i] / num - otherLS[i] / otherNum)
                    * (LS[i] / num - otherLS[i] / otherNum);
        }
        distance = Math.sqrt(distance);

        return distance;
    }

    /**
     * 计算簇内对象的平均距离
     *
     * @param records
     *            簇内的数据记录
     * @return
     */
    protected double computerInClusterDistance(ArrayList<double[]> records) {
        double sumDistance = 0;
        double[] data1;
        double[] data2;
        // 数据总数
        int totalNum = records.size();

        for (int i = 0; i < totalNum - 1; i++) {
            data1 = records.get(i);
            for (int j = i + 1; j < totalNum; j++) {
                data2 = records.get(j);
                sumDistance += computeOuDistance(data1, data2);
            }
        }

        // 返回的值除以总对数,总对数应减半,会重复算一次
        return Math.sqrt(sumDistance / (totalNum * (totalNum - 1) / 2));
    }

    /**
     * 对给定的2个向量,计算欧式距离
     *
     * @param record1
     *            向量点1
     * @param record2
     *            向量点2
     */
    private double computeOuDistance(double[] record1, double[] record2) {
        double distance = 0;

        for (int i = 0; i < record1.length; i++) {
            distance += (record1[i] - record2[i]) * (record1[i] - record2[i]);
        }

        return distance;
    }

    /**
     * 聚类添加节点包括,超出阈值进行分裂的操作
     *
     * @param clusteringFeature
     *            待添加聚簇
     */
    public abstract void addingCluster(ClusteringFeature clusteringFeature);
}

BIRCHTool.java:

package DataMining_BIRCH;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.text.MessageFormat;
import java.util.ArrayList;
import java.util.LinkedList;

/**
 * BIRCH聚类算法工具类
 *
 * @author lyq
 *
 */
public class BIRCHTool {
    // 节点类型名称
    public static final String NON_LEAFNODE = "【NonLeafNode】";
    public static final String LEAFNODE = "【LeafNode】";
    public static final String CLUSTER = "【Cluster】";

    // 测试数据文件地址
    private String filePath;
    // 内部节点平衡因子B
    public static int B;
    // 叶子节点平衡因子L
    public static int L;
    // 簇直径阈值T
    public static double T;
    // 总的测试数据记录
    private ArrayList<String[]> totalDataRecords;

    public BIRCHTool(String filePath, int B, int L, double T) {
        this.filePath = filePath;
        this.B = B;
        this.L = L;
        this.T = T;
        readDataFile();
    }

    /**
     * 从文件中读取数据
     */
    private void readDataFile() {
        File file = new File(filePath);
        ArrayList<String[]> dataArray = new ArrayList<String[]>();

        try {
            BufferedReader in = new BufferedReader(new FileReader(file));
            String str;
            String[] tempArray;
            while ((str = in.readLine()) != null) {
                tempArray = str.split("     ");
                dataArray.add(tempArray);
            }
            in.close();
        } catch (IOException e) {
            e.getStackTrace();
        }

        totalDataRecords = new ArrayList<>();
        for (String[] array : dataArray) {
            totalDataRecords.add(array);
        }
    }

    /**
     * 构建CF聚类特征树
     *
     * @return
     */
    private ClusteringFeature buildCFTree() {
        NonLeafNode rootNode = null;
        LeafNode leafNode = null;
        Cluster cluster = null;

        for (String[] record : totalDataRecords) {
            cluster = new Cluster(record);

            if (rootNode == null) {
                // CF树只有1个节点的时候的情况
                if (leafNode == null) {
                    leafNode = new LeafNode();
                }
                leafNode.addingCluster(cluster);
                if (leafNode.getParentNode() != null) {
                    rootNode = leafNode.getParentNode();
                }
            } else {
                if (rootNode.getParentNode() != null) {
                    rootNode = rootNode.getParentNode();
                }

                // 从根节点开始,从上往下寻找到最近的添加目标叶子节点
                LeafNode temp = rootNode.findedClosestNode(cluster);
                temp.addingCluster(cluster);
            }
        }

        // 从下往上找出最上面的节点
        LeafNode node = cluster.getParentNode();
        NonLeafNode upNode = node.getParentNode();
        if (upNode == null) {
            return node;
        } else {
            while (upNode.getParentNode() != null) {
                upNode = upNode.getParentNode();
            }

            return upNode;
        }
    }

    /**
     * 开始构建CF聚类特征树
     */
    public void startBuilding() {
        // 树深度
        int level = 1;
        ClusteringFeature rootNode = buildCFTree();

        setTreeLevel(rootNode, level);
        showCFTree(rootNode);
    }

    /**
     * 设置节点深度
     *
     * @param clusteringFeature
     *            当前节点
     * @param level
     *            当前深度值
     */
    private void setTreeLevel(ClusteringFeature clusteringFeature, int level) {
        LeafNode leafNode = null;
        NonLeafNode nonLeafNode = null;

        if (clusteringFeature instanceof LeafNode) {
            leafNode = (LeafNode) clusteringFeature;
        } else if (clusteringFeature instanceof NonLeafNode) {
            nonLeafNode = (NonLeafNode) clusteringFeature;
        }

        if (nonLeafNode != null) {
            nonLeafNode.setLevel(level);
            level++;
            // 设置子节点
            if (nonLeafNode.getNonLeafChilds() != null) {
                for (NonLeafNode n1 : nonLeafNode.getNonLeafChilds()) {
                    setTreeLevel(n1, level);
                }
            } else {
                for (LeafNode n2 : nonLeafNode.getLeafChilds()) {
                    setTreeLevel(n2, level);
                }
            }
        } else {
            leafNode.setLevel(level);
            level++;
            // 设置子聚簇
            for (Cluster c : leafNode.getClusterChilds()) {
                c.setLevel(level);
            }
        }
    }

    /**
     * 显示CF聚类特征树
     *
     * @param rootNode
     *            CF树根节点
     */
    private void showCFTree(ClusteringFeature rootNode) {
        // 空格数,用于输出
        int blankNum = 5;
        // 当前树深度
        int currentLevel = 1;
        LinkedList<ClusteringFeature> nodeQueue = new LinkedList<>();
        ClusteringFeature cf;
        LeafNode leafNode;
        NonLeafNode nonLeafNode;
        ArrayList<Cluster> clusterList = new ArrayList<>();
        String typeName;

        nodeQueue.add(rootNode);
        while (nodeQueue.size() > 0) {
            cf = nodeQueue.poll();

            if (cf instanceof LeafNode) {
                leafNode = (LeafNode) cf;
                typeName = LEAFNODE;

                if (leafNode.getClusterChilds() != null) {
                    for (Cluster c : leafNode.getClusterChilds()) {
                        nodeQueue.add(c);
                    }
                }
            } else if (cf instanceof NonLeafNode) {
                nonLeafNode = (NonLeafNode) cf;
                typeName = NON_LEAFNODE;

                if (nonLeafNode.getNonLeafChilds() != null) {
                    for (NonLeafNode n1 : nonLeafNode.getNonLeafChilds()) {
                        nodeQueue.add(n1);
                    }
                } else {
                    for (LeafNode n2 : nonLeafNode.getLeafChilds()) {
                        nodeQueue.add(n2);
                    }
                }
            } else {
                clusterList.add((Cluster)cf);
                typeName = CLUSTER;
            }

            if (currentLevel != cf.getLevel()) {
                currentLevel = cf.getLevel();
                System.out.println();
                System.out.println("|");
                System.out.println("|");
            }else if(currentLevel == cf.getLevel() && currentLevel != 1){
                for (int i = 0; i < blankNum; i++) {
                    System.out.print("-");
                }
            }
            
            System.out.print(typeName);
            System.out.print("N:" + cf.getN() + ", LS:");
            System.out.print("[");
            for (double d : cf.getLS()) {
                System.out.print(MessageFormat.format("{0}, ",  d));
            }
            System.out.print("]");
        }
        
        System.out.println();
        System.out.println("*******最终分好的聚簇****");
        //显示已经分好类的聚簇点
        for(int i=0; i<clusterList.size(); i++){
            System.out.println("Cluster" + (i+1) + ":");
            for(double[] point: clusterList.get(i).getData()){
                System.out.print("[");
                for (double d : point) {
                    System.out.print(MessageFormat.format("{0}, ",  d));
                }
                System.out.println("]");
            }
        }
    }

}

由于代码量比较大,剩下的LeafNode.java,NonLeafNode.java, 和Cluster聚簇类可以在我的数据挖掘代码中查看。

结果输出:

【NonLeafNode】N:6, LS:[29, 19.6, 8.8, 3.4, ]
|
|
【LeafNode】N:3, LS:[14, 9.5, 4.2, 2.4, ]-----【LeafNode】N:3, LS:[15, 10.1, 4.6, 1, ]
|
|
【Cluster】N:3, LS:[14, 9.5, 4.2, 2.4, ]-----【Cluster】N:1, LS:[5, 3.6, 1.8, 0.6, ]-----【Cluster】N:2, LS:[10, 6.5, 2.8, 0.4, ]
*******最终分好的聚簇****
Cluster1:
[4.7, 3.2, 1.3, 0.8, ]
[4.6, 3.1, 1.5, 0.8, ]
[4.7, 3.2, 1.4, 0.8, ]
Cluster2:
[5, 3.6, 1.8, 0.6, ]
Cluster3:
[5.1, 3.5, 1.4, 0.2, ]
[4.9, 3, 1.4, 0.2, ]

算法实现时的难点

1、算簇间距离的时候,代了一下公式,发现不对劲,向量的运算不应该是这样的,于是就把他归与簇心之间的距离计算。还有簇内对象的平均距离也没有代入公式,网上的各种版本的向量计算,不知道哪种是对的,又按最原始的方式计算,一对对计算距离,求平均值。

2、算法在节点分裂的时候,如果父节点不为空,需要把自己从父亲中的孩子列表中移除,然后再添加分裂后的2个节点,这里的把自己移除掉容易忘记。

3、节点CF聚类特征值的更新,需要在每次节点的变化时,其所涉及的父类,父类的父类都需要更新,为此用了责任链模式,一个一个往上传,分裂的规则时也用了此模式,需要关注一下。

4、代码将CF聚类特征量进行抽象提取,定义了共有的方法,不过在实现时还是由于节点类型的不同,在实际的过程中需要转化。

5、最后的难点在与测试的复杂,因为程序经过千辛万苦的编写终于完成,但是如何测试时一个大问题,因为要把分裂的情况都测准,需要准确的把握T.,B.L,的设计,尤其是T簇直径,所以在捏造测试的时候自己也是经过很多的手动计算。
我对BIRCH算法的理解

在实现的整个完成的过程中 ,我对BIRCH算法的最大的感触就是通过聚类特征,一个新节点从根节点开始,从上往先寻找,离哪个簇近,就被分到哪个簇中,自发的形成了一个比较好的聚簇,这个过程是算法的神奇所在。
作者:Androidlushangderen 发表于2015/2/5 18:58:27 原文链接
阅读:1054 评论:0 查看评论
K-Means聚类算法
2015年2月1日 18:26

更多数据挖掘算法:https://github.com/linyiqun/DataMiningAlgorithm
算法介绍

K-Means又名为K均值算法,他是一个聚类算法,这里的K就是聚簇中心的个数,代表数据中存在多少数据簇。K-Means在聚类算法中算是非常简单的一个算法了。有点类似于KNN算法,都用到了距离矢量度量,用欧式距离作为小分类的标准。
算法步骤

(1)、设定数字k,从n个初始数据中随机的设置k个点为聚类中心点。

(2)、针对n个点的每个数据点,遍历计算到k个聚类中心点的距离,最后按照离哪个中心点最近,就划分到那个类别中。

(3)、对每个已经划分好类别的n个点,对同个类别的点求均值,作为此类别新的中心点。

(4)、循环(2),(3)直到最终中心点收敛。

以上的计算过程将会在下面我的程序实现中有所体现。
算法的代码实现

输入数据:

3 3
4 10
9 6
14 8
18 11
21 7

主实现类:

package DataMining_KMeans;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.text.MessageFormat;
import java.util.ArrayList;
import java.util.Collections;

/**
 * k均值算法工具类
 *
 * @author lyq
 *
 */
public class KMeansTool {
    // 输入数据文件地址
    private String filePath;
    // 分类类别个数
    private int classNum;
    // 类名称
    private ArrayList<String> classNames;
    // 聚类坐标点
    private ArrayList<Point> classPoints;
    // 所有的数据左边点
    private ArrayList<Point> totalPoints;

    public KMeansTool(String filePath, int classNum) {
        this.filePath = filePath;
        this.classNum = classNum;
        readDataFile();
    }

    /**
     * 从文件中读取数据
     */
    private void readDataFile() {
        File file = new File(filePath);
        ArrayList<String[]> dataArray = new ArrayList<String[]>();

        try {
            BufferedReader in = new BufferedReader(new FileReader(file));
            String str;
            String[] tempArray;
            while ((str = in.readLine()) != null) {
                tempArray = str.split(" ");
                dataArray.add(tempArray);
            }
            in.close();
        } catch (IOException e) {
            e.getStackTrace();
        }

        classPoints = new ArrayList<>();
        totalPoints = new ArrayList<>();
        classNames = new ArrayList<>();
        for (int i = 0, j = 1; i < dataArray.size(); i++) {
            if (j <= classNum) {
                classPoints.add(new Point(dataArray.get(i)[0],
                        dataArray.get(i)[1], j + ""));
                classNames.add(i + "");
                j++;
            }
            totalPoints
                    .add(new Point(dataArray.get(i)[0], dataArray.get(i)[1]));
        }
    }

    /**
     * K均值聚类算法实现
     */
    public void kMeansClustering() {
        double tempX = 0;
        double tempY = 0;
        int count = 0;
        double error = Integer.MAX_VALUE;
        Point temp;

        while (error > 0.01 * classNum) {
            for (Point p1 : totalPoints) {
                // 将所有的测试坐标点就近分类
                for (Point p2 : classPoints) {
                    p2.computerDistance(p1);
                }
                Collections.sort(classPoints);

                // 取出p1离类坐标点最近的那个点
                p1.setClassName(classPoints.get(0).getClassName());
            }

            error = 0;
            // 按照均值重新划分聚类中心点
            for (Point p1 : classPoints) {
                count = 0;
                tempX = 0;
                tempY = 0;
                for (Point p : totalPoints) {
                    if (p.getClassName().equals(p1.getClassName())) {
                        count++;
                        tempX += p.getX();
                        tempY += p.getY();
                    }
                }
                tempX /= count;
                tempY /= count;

                error += Math.abs((tempX - p1.getX()));
                error += Math.abs((tempY - p1.getY()));
                // 计算均值
                p1.setX(tempX);
                p1.setY(tempY);

            }
            
            for (int i = 0; i < classPoints.size(); i++) {
                temp = classPoints.get(i);
                System.out.println(MessageFormat.format("聚类中心点{0},x={1},y={2}",
                        (i + 1), temp.getX(), temp.getY()));
            }
            System.out.println("----------");
        }

        System.out.println("结果值收敛");
        for (int i = 0; i < classPoints.size(); i++) {
            temp = classPoints.get(i);
            System.out.println(MessageFormat.format("聚类中心点{0},x={1},y={2}",
                    (i + 1), temp.getX(), temp.getY()));
        }

    }

}

坐标点类:

package DataMining_KMeans;

/**
 * 坐标点类
 *
 * @author lyq
 *
 */
public class Point implements Comparable<Point>{
    // 坐标点横坐标
    private double x;
    // 坐标点纵坐标
    private double y;
    //以此点作为聚类中心的类的类名称
    private String className;
    // 坐标点之间的欧式距离
    private Double distance;

    public Point(double x, double y) {
        this.x = x;
        this.y = y;
    }
    
    public Point(String x, String y) {
        this.x = Double.parseDouble(x);
        this.y = Double.parseDouble(y);
    }
    
    public Point(String x, String y, String className) {
        this.x = Double.parseDouble(x);
        this.y = Double.parseDouble(y);
        this.className = className;
    }

    /**
     * 距离目标点p的欧几里得距离
     *
     * @param p
     */
    public void computerDistance(Point p) {
        if (p == null) {
            return;
        }

        this.distance = (this.x - p.x) * (this.x - p.x) + (this.y - p.y)
                * (this.y - p.y);
    }

    public double getX() {
        return x;
    }

    public void setX(double x) {
        this.x = x;
    }

    public double getY() {
        return y;
    }

    public void setY(double y) {
        this.y = y;
    }
    
    public String getClassName() {
        return className;
    }

    public void setClassName(String className) {
        this.className = className;
    }

    public double getDistance() {
        return distance;
    }

    public void setDistance(double distance) {
        this.distance = distance;
    }

    @Override
    public int compareTo(Point o) {
        // TODO Auto-generated method stub
        return this.distance.compareTo(o.distance);
    }
    
}

调用类:

/**
 * K-means(K均值)算法调用类
 * @author lyq
 *
 */
public class Client {
    public static void main(String[] args){
        String filePath = "C:\\Users\\lyq\\Desktop\\icon\\input.txt";
        //聚类中心数量设定
        int classNum = 3;
        
        KMeansTool tool = new KMeansTool(filePath, classNum);
        tool.kMeansClustering();
    }
}

测试输出结果:

聚类中心点1,x=15.5,y=8
聚类中心点2,x=4,y=10
聚类中心点3,x=3,y=3
----------
聚类中心点1,x=17.667,y=8.667
聚类中心点2,x=6.5,y=8
聚类中心点3,x=3,y=3
----------
聚类中心点1,x=17.667,y=8.667
聚类中心点2,x=6.5,y=8
聚类中心点3,x=3,y=3
----------
结果值收敛
聚类中心点1,x=17.667,y=8.667
聚类中心点2,x=6.5,y=8
聚类中心点3,x=3,y=3

K-Means算法的优缺点

1、首先优点当然是算法简单,快速,易懂,没有涉及到特别复杂的数据结构。

2、缺点1是最开始K的数量值以及K个聚类中心点的设置不好定,往往开始时不同的k个中心点的设置对后面迭代计算的走势会有比较大的影响,这时候可以考虑根据类的自动合并和分裂来确定这个k。

3、缺点2由于计算是迭代式的,而且计算距离的时候需要完全遍历一遍中心点,当数据规模比较大的时候,开销就显得比较大了。


Logo

CSDN联合极客时间,共同打造面向开发者的精品内容学习社区,助力成长!

更多推荐