机器学习实战

k-近邻算法

Posted by TkiChus on January 2, 2019

机器学习实战

1.K-近邻分类算法

  • 官方解释: KNN是通过测量不同特征值之间的距离进行分类。它的思路是:如果一个样本在特征空间中的k个最相似(即特征空间中最邻近)的样本中的大多数属于某一个类别,则该样本也属于这个类别,其中K通常是不大于20的整数。KNN算法中,所选择的邻居都是已经正确分类的对象。该方法在定类决策上只依据最邻近的一个或者几个样本的类别来决定待分样本所属的类别。

  • 自行理解:通过与将要进行预测的事件和数据集进行对比,然后根据事件和数据集中的相差最小的事件的分类来预测该事件的分类

    例如

​ 如上图:所有样本可以使用一个二维向量表征。图中,蓝色方形样本和红色三角形样本为已知分类样本。若使用KNN对图中绿色圆形未知分类样本进行分类,当K=3时,其三近邻中有2个红色三角形样本和1个蓝色方形样本,因此预测该待分类样本为红色三角形样本;当K=5时,其三近邻中有3个红色三角形样本和2个蓝色方形样本,因此预测该待分类样本为蓝色方形样本; 综上,KNN算法有以下几个要素:

1.数据集

2.样本的向量表示

3.样本间距离的计算方法

4.k值的选取

  • k-近邻算法的一般流程(伪代码)

    1.计算已知类别数据集中的每个点与当前点之间的距离

    计算利用的是欧式公式:

    欧式公式

    2.按照距离递增次序排序;

    3.选取与当前点距离最小的K个点;

    4.确定前K个点所在类别的出现频率;

    5.返回前K个点出现频率最高的类当作当前点的预测分类;

  • 代码实现

    
    def classify0(inX, dataSet, labels, k):
        dataSetSize = dataSet.shape[0]
        diffMat = tile(inX, (dataSetSize, 1)) - dataSet
        sqDiffMat = diffMat ** 2
        sqDistances = sqDiffMat.sum(axis=1)
        distances = sqDistances ** 0.5
        sortedDistIndicies = distances.argsort()
        classCount = {}
        for i in range(k):
            voteIlabel = labels[sortedDistIndicies[i]]
            classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1
        sortedClassCount = sorted(classCount.items(), # 在python3.5之后itersitems变成items
                                  key=operator.itemgetter(1), reverse=True)
        return sortedClassCount[0][0]
    
  • 实例:手写数字的识别

    代码:

    def handWritingClassTest():
        hwLabels = []
        traningFileList = listdir('trainingDigits')   #训练集
        m = len(traningFileList)     #训练集的长度
        traningMat = np.zeros((m, 1024))
        for i in range(m):
            fileNameStr = traningFileList[i]
            fileStr = fileNameStr.split('.')[0]
            classNumStr = int(fileStr.split('_')[0])
            hwLabels.append(classNumStr)
            traningMat[i,:] = img2vector('trainingDigits/%s' % fileNameStr)
        testFileList = listdir('testDigits')
        errorCount = 0.0
        mTest = len(testFileList)
        for i in range(mTest):
            fileNameStr = testFileList[i]
            fileStr = fileNameStr.split(',')[0]
            classNumStr = int(fileStr.split('_')[0])
            vectorUnderTest = img2vector('testDigits/%s' % fileNameStr)
            classifierResult =classify0(vectorUnderTest, \
                                        traningMat,hwLabels,3)
            print("the classifier came back with : %d, the real answer is :%d"\
                  %(classifierResult,classNumStr))
            if(classifierResult != classNumStr):errorCount +=1.0
        print("\nthe total number of errors is:%d" % errorCount)
        print("\nthe total error rate is %f" %(errorCount/float(mTest)))
    
    

    测试结果:

1547540400563

注:使用到的数据集可以在:机器学习实战网站上找到,更多学习也可以查阅此网站!!!