# encoding: utf-8
'''
#!/usr/bin/env python
@author: yudian
@contact: hhuyudian@163.com
@file: k_means.py
@time: 2018/12/25 0:16
@desc:    JUST WAITING!!!
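    Purpose: animated demonstration of the k-means clustering process.

    Tunable parameters: k (module-level global): number of cluster centers.
               number (argument of get_datas()): how many points to generate; ideally a multiple of 7.
               bigCircleTimes: how many times the whole clustering procedure is restarted.
               smallCircleTimes: maximum number of iterations within a single clustering run.        The last two are local variables of k_mean_cluster().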

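    Functions:     get_datas(): build the data points to be clustered. Every point is one of the centers [0,0],[1,1],[0,5],[3,5],[2,3],[8,10],[12,12] plus a uniform random offset in [0, 1) on each axis.
                                 Input:  number (optional): the number of points to generate.
                                 Output: dataArray: a number x 2 matrix of type numpy.ndarray.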

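                    myShow():    display data; what is drawn depends on which arguments are passed.
                                data only: show the initial, unclustered points.
                                data, centerArray, indexMatrix: show the final clustering result.
                                data, centerArray, indexMatrix, initArray: replay the best clustering run as an animation.
                                Input: data: the points to display.
                                       centerArray: coordinates of the k cluster centers.
                                       indexMatrix: for each point in data, the index of the cluster center it belongs to.
                                       initArray: the initial cluster centers of the best clustering run.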
'''
import numpy as np
import time
import matplotlib.pyplot as plt

def get_datas(number = 490):
    """Generate 2-D sample points scattered around seven fixed centers.

    Point ``i`` is ``center[i % 7]`` plus an offset drawn uniformly from
    ``[0, 1)`` on each axis, i.e. the centers are used round-robin.

    Args:
        number: How many points to generate. Any non-negative integer is
            accepted; a multiple of 7 gives every center the same share.

    Returns:
        numpy.ndarray of shape ``(number, 2)``.
    """
    center = np.array([[0,0],[1,1],[0,5],[3,5],[2,3],[8,10],[12,12]])
    centerNumber = center.shape[0]
    # Pick a center for *every* row. Filling only number // 7 complete groups
    # of an np.empty() buffer would leave the remaining rows uninitialised
    # whenever ``number`` is not a multiple of 7.
    assignedCenter = center[np.arange(number) % centerNumber]
    # A single (number, 2) draw reads the global NumPy random stream in the
    # same order as successive (7, 2) draws, so seeded results for multiples
    # of 7 are the same as with a group-by-group fill.
    return assignedCenter + np.random.random_sample((number, 2))

def myShow(data, centerArray=None, indexMatrix = None, initArray = None):
    """Plot the data set; what is drawn depends on the arguments supplied.

    * ``data`` only: scatter the raw points in black and call ``plt.show()``.
    * ``data, centerArray, indexMatrix``: draw one frame with the centers as
      red stars and the points of each cluster in that cluster's colour,
      then pause for one second.
    * ``initArray`` given as well: replay k-means starting from
      ``initArray``, drawing a frame before and after every center update.
      In this mode the ``centerArray`` and ``indexMatrix`` passed in are not
      used; both are recomputed from ``initArray``.

    Args:
        data: Points to plot; columns 0 and 1 are used as x and y.
        centerArray: Cluster centers; reshaped to ``(n_centers, 2)``.
        indexMatrix: For each point, the index of the cluster it belongs to.
            If ``None`` in single-frame mode, only points and centers are
            drawn (no per-cluster colouring).
        initArray: Initial cluster centers for the replay.

    Returns:
        None.
    """
    if centerArray is None:
        # Mode 1: raw, unclustered points.
        print(type(centerArray))
        plt.scatter(data[:, 0], data[:, 1], marker='o', s = 5, c = 'black')
        plt.show()
        return

    if initArray is None:
        # Mode 2: a single frame showing one clustering state.
        print(centerArray)
        # Named matplotlib colours; reference list:
        # https://www.cnblogs.com/darkknightzh/p/6117528.html
        colorList = ['red', 'blue', 'green', 'black', 'aliceblue', 'coral', 'firebrick', 'ivory', 'linen', 'mintcream']
        plt.ion()
        plt.scatter(data[:, 0], data[:, 1], marker='o', s = 5, c = 'black')
        centerArray = centerArray.reshape(centerArray.shape[0], 2)
        plt.scatter(centerArray[:, 0], centerArray[:, 1], marker = '*', s = 100, c = 'red')
        if indexMatrix is not None:
            for i in range(centerArray.shape[0]):
                # For a 1-D indexMatrix, argwhere yields an (m, 1) index
                # array, so the selection has shape (m, 1, 2).
                centerMatrix = data[np.argwhere(indexMatrix == i)]
                print(centerMatrix.shape)
                # Wrap around the palette so more clusters than colours does
                # not raise IndexError.
                plt.scatter(centerMatrix[:, 0, 0], centerMatrix[:, 0, 1], marker = 'o', s = 5,
                            c = colorList[i % len(colorList)])
        plt.pause(1)
        return

    # Mode 3: animated replay of the clustering run that starts at initArray.
    plt.ion()
    print(initArray)
    centerArray = initArray
    # Take the cluster count from the centers being replayed rather than from
    # a module-level global, so the replay stays correct for any k.
    clusterCount = initArray.shape[0]
    for _ in range(30):
        oldCenterArray = centerArray.copy()
        distMatrix = get_distance_matrix(data, centerArray)
        indexMatrix = cluster(distMatrix)
        # Frame 1: current centers with the fresh assignment.
        myShow(data, centerArray, indexMatrix)
        plt.cla()
        centerArray = find_next_center(data, indexMatrix, clusterCount)
        # Frame 2: same assignment with the centers moved.
        myShow(data, centerArray, indexMatrix)
        plt.pause(1)
        if (centerArray == oldCenterArray).all():
            # Converged: hold the final frame a little longer, then stop.
            plt.pause(3)
            break
        plt.cla()


def k_mean_cluster(data,k):
    """Run k-means from several random starts and replay the best run.

    Each restart picks ``k`` data points at random as initial centers and
    iterates assignment / center update until the centers stop changing or
    the iteration cap is reached. Restarts are compared by the k-means
    objective: the sum, over all points, of the squared distance to the
    nearest center. The lowest value is printed and the corresponding run is
    replayed through ``myShow``.

    Args:
        data: Points to cluster, one per row (columns 0 and 1 are the
            coordinates read by ``get_distance_matrix``).
        k: Number of clusters.

    Returns:
        None.
    """
    bigCircleTimes = 50     # number of random restarts
    smallCircleTimes = 30   # iteration cap within one restart
    totalDist = np.inf
    number = data.shape[0]
    for _ in range(bigCircleTimes):
        centerArray = data[np.random.randint(0, number, size = k)]
        initArray = centerArray.copy()
        for _ in range(smallCircleTimes):
            oldCenterArray = centerArray.copy()
            distMatrix = get_distance_matrix(data, centerArray)
            indexMatrix = cluster(distMatrix)
            centerArray = find_next_center(data, indexMatrix, k)
            if((oldCenterArray == centerArray).all()):
                print('total have done.')
                break
        # Score the run against its *final* centers. If the cap was hit
        # without converging, the matrix from the loop still belongs to the
        # previous centers, so recompute distances and assignment here.
        distMatrix = get_distance_matrix(data, centerArray)
        indexMatrix = cluster(distMatrix)
        # Only the distance to the nearest (assigned) center counts; summing
        # the whole matrix would also add distances to unrelated centers.
        runDist = np.sum(np.min(distMatrix, axis = 1))
        if(totalDist > runDist):
            totalDist = runDist
            holdOnCenterArray = centerArray
            holdOnIndexMatrix = indexMatrix
            holdOnInitArray = initArray
    print(totalDist)
    myShow(data, holdOnCenterArray, holdOnIndexMatrix, holdOnInitArray)


def find_next_center(data, indexMatrix, k):
    """Recompute every cluster center as the mean of its assigned points.

    Args:
        data: 2-D array with one point per row; any number of columns.
        indexMatrix: For each row of ``data``, the index of the cluster it is
            assigned to. Must have as many entries as ``data`` has rows.
        k: Number of clusters.

    Returns:
        numpy.ndarray of shape ``(k, data.shape[1])``. A cluster with no
        assigned points gets a randomly chosen data point as its new center
        (and a notice is printed).
    """
    assert(data.shape[0] == indexMatrix.shape[0])
    # Width follows the data instead of being fixed at 2, so the same update
    # also works for points with more than two coordinates.
    newCenterArray = np.empty(shape=(k, data.shape[1]))
    for i in range(k):
        memberRows = np.argwhere(indexMatrix == i)[:, 0]
        if memberRows.size == 0:
            # Empty cluster: re-seed it from a random data point so that the
            # number of centers stays at k.
            print("have center no data.Need change center.")
            newCenterArray[i] = data[np.random.randint(0, data.shape[0])]
        else:
            members = data[memberRows]
            newCenterArray[i] = np.sum(members, axis = 0)/members.shape[0]
    return newCenterArray

def get_distance_matrix(data, initCenter):
    """Return squared Euclidean distances between points and centers.

    Only columns 0 and 1 of both inputs are used.

    Args:
        data: 2-D array with one point per row.
        initCenter: 2-D array with one center per row.

    Returns:
        Array of shape ``(n_points, n_centers)``; entry ``[i, j]`` is the
        squared distance from point ``i`` to center ``j``.
    """
    squaredParts = []
    for axis in (0, 1):
        # Centers laid out as a row and points as a column, so broadcasting
        # the subtraction produces every point/center pair.
        centerRow = initCenter[:, axis].reshape(1, -1)
        dataColumn = data[:, axis].reshape(-1, 1)
        squaredParts.append(np.power((centerRow - dataColumn), 2))
    return squaredParts[0] + squaredParts[1]

def cluster(distMatrix):
    """Assign every point to its nearest center.

    Args:
        distMatrix: Array of shape ``(n_points, n_centers)`` holding the
            distance from each point to each center.

    Returns:
        float64 array of shape ``(n_points,)``; entry ``i`` is the column
        index of the smallest value in row ``i``. Ties go to the lowest
        index. The values are whole numbers stored as floats.

    Raises:
        ValueError: If ``distMatrix`` contains NaN, because no nearest center
            can be determined for that row.
    """
    distMatrix = np.asarray(distMatrix)
    if np.isnan(distMatrix).any():
        raise ValueError(
            "distance matrix contains NaN; cannot determine the nearest center"
        )
    # argmin returns the first occurrence of the row minimum, which is the
    # same choice a row-by-row "first index equal to the minimum" scan makes.
    return np.argmin(distMatrix, axis = 1).astype(np.float64)


# Number of cluster centers used by the demo run below.
# (A ``global`` statement has no effect at module level: a name assigned
# here is already a module global.)
k = 3

if __name__ == "__main__":
    # Generate the sample points, then cluster them and animate the result.
    data = get_datas()
    k_mean_cluster(data, k)

 

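# NOTE: what followed here was page-footer residue from the website this file
# was copied from (a "Logo" caption, a CSDN / Geek Time promotional banner and
# a "More recommendations" heading). It is not part of the program.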