K-means算法的Java实现 聚类分析681个三国武将 - Go语言中文社区

K-means算法的Java实现 聚类分析681个三国武将


一,k-means算法介绍:

k-means算法接受输入量 k ;然后将n个数据对象划分为 k个聚类以便使得所获得的聚类满足:同一聚类中的对象相似度较高;而不同聚类中的对象相似度较小。聚类相似度是利用各聚类中对象的均值所获得一个“中心对象”(引力中心)来进行计算的。k个聚类具有以下特点:各聚类本身尽可能的紧凑,而各聚类之间尽可能的分开。 k个聚类具有以下特点:各聚类本身尽可能的紧凑,而各聚类之间尽可能的分开。

k-means算法的工作过程说明如下:首先从n个数据对象任意选择 k 个对象作为初始聚类中心;而对于所剩下其它对象,则根据它们与这些聚类中心的相似度(距离),分别将它们分配给与其最相似的(聚类中心所代表的)聚类;然后再计算每个所获新聚类的聚类中心(该聚类中所有对象的均值);不断重复这一过程直到标准测度函数开始收敛为止。一般都采用均方差作为标准测度函数。k个聚类具有以下特点:各聚类本身尽可能的紧凑,而各聚类之间尽可能的分开。

二,k-means算法基本步骤:

(1) 从 n个数据对象任意选择 k 个对象作为初始聚类中心;

(2) 根据每个聚类对象的均值(中心对象),计算每个对象与这些中心对象的距离;并根据最小距离重新对相应对象进行划分;

(3) 重新计算每个(有变化)聚类的均值(中心对象);

(4) 计算标准测度函数,当满足一定条件,如函数收敛时,则算法终止;如果条件不满足则回到步骤(2),不断重复直到标准测度函数开始收敛为止。(一般都采用均方差作为标准测度函数。)

三,k-means算法的java实现:

一共有七个类,General.java代表武将对象, Distance.java距离类计算各个武将到中心武将之间的距离, Cluster.java聚类对象包含一个中心武将和该聚类中所有武将, Kmeans.java核心的聚类算法类, Tool.java工具类用于转换武将的星级为数字等操作, TestKmeans.java测试类即入口文件, DomParser.java用于读取xml中的681个武将。

具体思路:先从general.xml文件中读取681个武将,然后随机选取初始类中心,计算各个武将到中心武将的距离,根据最小的距离进行聚类,然后重新根据平均值新的聚类的类中心,重新计算各个武将到新的中心武将的距离,直到更新后的聚类与原来的聚类包含的武将不再改变,即收敛时结束。

具体代码如下:

1,General.java

  1. </pre><pre name="code" class="java"
  1. </pre><pre name="code" class="java">package kmeans;  
  2.  
  3. public class General {  
  4.       
  5.     private String name; // 姓名  
  6.     private int render; // 星级  
  7.     private int tongshai; // 统帅  
  8.     private int wuli; // 武力  
  9.     private int zhili; // 智力  
  10.     private int polic; // 政治  
  11.     private int qiangbin; // 枪兵  
  12.     private int jibin; // 戟兵  
  13.     private int nubin; // 弩兵  
  14.     private int qibin; // 骑兵  
  15.     private int binqi; // 兵器  
  16.     private int tongwu; // 统武  
  17.     private int tongzhi; // 统智  
  18.     private int tongwuzhi; // 统武智  
  19.     private int tongwuzhizheng; // 统武智政  
  20.     private int salary; // 50级工资  
  21.  
  22.     public General(int render, String name, int tongshai, int wuli, int zhili,  
  23.             int polic, int qiangbin, int jibin, int nubin, int qibin,  
  24.             int binqi, int tongwu, int tongzhi, int tongwuzhi,  
  25.             int tongwuzhizheng, int salary) {  
  26.         super();  
  27.         this.name = name;  
  28.         this.render = render;  
  29.         this.tongshai = tongshai;  
  30.         this.wuli = wuli;  
  31.         this.zhili = zhili;  
  32.         this.polic = polic;  
  33.         this.qiangbin = qiangbin;  
  34.         this.jibin = jibin;  
  35.         this.nubin = nubin;  
  36.         this.qibin = qibin;  
  37.         this.binqi = binqi;  
  38.         this.tongwu = tongwu;  
  39.         this.tongzhi = tongzhi;  
  40.         this.tongwuzhi = tongwuzhi;  
  41.         this.tongwuzhizheng = tongwuzhizheng;  
  42.         this.salary = salary;  
  43.     }  
  44.  
  45.     public General(int render, int tongshai, int wuli, int zhili, int polic,  
  46.             int qiangbin, int jibin, int nubin, int qibin, int binqi,  
  47.             int tongwu, int tongzhi, int tongwuzhi, int tongwuzhizheng,  
  48.             int salary) {  
  49.         super();  
  50.         this.name = "聚类中心";  
  51.         this.render = render;   
  52.         this.tongshai = tongshai;   
  53.         this.wuli = wuli;  
  54.         this.zhili = zhili;  
  55.         this.polic = polic;  
  56.         this.qiangbin = qiangbin;  
  57.         this.jibin = jibin;  
  58.         this.nubin = nubin;  
  59.         this.qibin = qibin;  
  60.         this.binqi = binqi;  
  61.         this.tongwu = tongwu;  
  62.         this.tongzhi = tongzhi;  
  63.         this.tongwuzhi = tongwuzhi;  
  64.         this.tongwuzhizheng = tongwuzhizheng;  
  65.         this.salary = salary;  
  66.     }  
  67.  
  68.     public General() {  
  69.     }  
  70.  
  71.     @Override 
  72.     public String toString() {  
  73.         return "武将 [name=" + name + ", render=" + Tool.dxingji(render)  
  74.                 + ", tongshai=" + tongshai + ", wuli=" + wuli + ", zhili=" 
  75.                 + zhili + ", polic=" + polic + ", qiangbin=" 
  76.                 + Tool.dchange(qiangbin) + ", jibin=" + Tool.dchange(jibin)  
  77.                 + ", nubin=" + Tool.dchange(nubin) + ", qibin=" 
  78.                 + Tool.dchange(qibin) + ", binqi=" + Tool.dchange(binqi)  
  79.                 + ", tongwu=" + tongwu + ", tongzhi=" + tongzhi  
  80.                 + ", tongwuzhi=" + tongwuzhi + ", tongwuzhizheng=" 
  81.                 + tongwuzhizheng + ", salary=" + salary + "]";  
  82.     }  
  83.  
  84.     public String getName() {  
  85.         return name;  
  86.     }  
  87.  
  88.     public void setName(String name) {  
  89.         this.name = name;  
  90.     }  
  91.  
  92.     public int getRender() {  
  93.         return render;  
  94.     }  
  95.  
  96.     public void setRender(int render) {  
  97.         this.render = render;  
  98.     }  
  99.  
  100.     public int getTongshai() {  
  101.         return tongshai;  
  102.     }  
  103.  
  104.     public void setTongshai(int tongshai) {  
  105.         this.tongshai = tongshai;  
  106.     }  
  107.  
  108.     public int getWuli() {  
  109.         return wuli;  
  110.     }  
  111.  
  112.     public void setWuli(int wuli) {  
  113.         this.wuli = wuli;  
  114.     }  
  115.  
  116.     public int getZhili() {  
  117.         return zhili;  
  118.     }  
  119.  
  120.     public void setZhili(int zhili) {  
  121.         this.zhili = zhili;  
  122.     }  
  123.  
  124.     public int getPolic() {  
  125.         return polic;  
  126.     }  
  127.  
  128.     public void setPolic(int polic) {  
  129.         this.polic = polic;  
  130.     }  
  131.  
  132.     public int getQiangbin() {  
  133.         return qiangbin;  
  134.     }  
  135.  
  136.     public void setQiangbin(int qiangbin) {  
  137.         this.qiangbin = qiangbin;  
  138.     }  
  139.  
  140.     public int getJibin() {  
  141.         return jibin;  
  142.     }  
  143.  
  144.     public void setJibin(int jibin) {  
  145.         this.jibin = jibin;  
  146.     }  
  147.  
  148.     public int getNubin() {  
  149.         return nubin;  
  150.     }  
  151.  
  152.     public void setNubin(int nubin) {  
  153.         this.nubin = nubin;  
  154.     }  
  155.  
  156.     public int getQibin() {  
  157.         return qibin;  
  158.     }  
  159.  
  160.     public void setQibin(int qibin) {  
  161.         this.qibin = qibin;  
  162.     }  
  163.  
  164.     public int getBinqi() {  
  165.         return binqi;  
  166.     }  
  167.  
  168.     public void setBinqi(int binqi) {  
  169.         this.binqi = binqi;  
  170.     }  
  171.  
  172.     public int getTongwu() {  
  173.         return tongwu;  
  174.     }  
  175.  
  176.     public void setTongwu(int tongwu) {  
  177.         this.tongwu = tongwu;  
  178.     }  
  179.  
  180.     public int getTongzhi() {  
  181.         return tongzhi;  
  182.     }  
  183.  
  184.     public void setTongzhi(int tongzhi) {  
  185.         this.tongzhi = tongzhi;  
  186.     }  
  187.  
  188.     public int getTongwuzhi() {  
  189.         return tongwuzhi;  
  190.     }  
  191.  
  192.     public void setTongwuzhi(int tongwuzhi) {  
  193.         this.tongwuzhi = tongwuzhi;  
  194.     }  
  195.  
  196.     public int getTongwuzhizheng() {  
  197.         return tongwuzhizheng;  
  198.     }  
  199.  
  200.     public void setTongwuzhizheng(int tongwuzhizheng) {  
  201.         this.tongwuzhizheng = tongwuzhizheng;  
  202.     }  
  203.  
  204.     public int getSalary() {  
  205.         return salary;  
  206.     }  
  207.  
  208.     public void setSalary(int salary) {  
  209.         this.salary = salary;  
  210.     }  
  211.  
  212. }  

2,Distance.java

  1. </pre><pre name="code" class="java">package kmeans;  
  2. /**  
  3.  * 这个类用于计算距离的。。  
  4.  *  
  5.  */ 
  6. public class Distance {  
  7.     int dest;// 目的  
  8.     int source;// 源  
  9.     double dist;// 欧式距离  
  10.  
  11.     public int getDest() {  
  12.         return dest;  
  13.     }  
  14.  
  15.     public void setDest(int dest) {  
  16.         this.dest = dest;  
  17.     }  
  18.  
  19.     public int getSource() {  
  20.         return source;  
  21.     }  
  22.  
  23.     public void setSource(int source) {  
  24.         this.source = source;  
  25.     }  
  26.  
  27.     public double getDist() {  
  28.         return dist;  
  29.     }  
  30.  
  31.     public void setDist(double dist) {  
  32.         this.dist = dist;  
  33.     }  
  34.     /**  
  35.      * 计算源和目的的距离  
  36.      * @param dest 目的武将  
  37.      * @param source 源武将  
  38.      * @param dist 两者间的距离  
  39.      */ 
  40.     public Distance(int dest, int source, double dist) {  
  41.         this.dest = dest;  
  42.         this.source = source;  
  43.         this.dist = dist;  
  44.     }  
  45.  
  46.     public Distance() {  
  47.     }  
  48.  

3,Cluster.java

  1. </pre><pre name="code" class="java">package kmeans;  
  2.  
  3. import java.util.ArrayList;  
  4.  
  5. public class Cluster {  
  6.     private int center;// 聚类中心武将的id  
  7.     private ArrayList<General> ofCluster = new ArrayList<General>();// 属于这个聚类的武将的集合  
  8.  
  9.     public int getCenter() {  
  10.         return center;  
  11.     }  
  12.  
  13.     public void setCenter(int center) {  
  14.         this.center = center;  
  15.     }  
  16.  
  17.     public ArrayList<General> getOfCluster() {  
  18.         return ofCluster;  
  19.     }  
  20.  
  21.     public void setOfCluster(ArrayList<General> ofCluster) {  
  22.         this.ofCluster = ofCluster;  
  23.     }  
  24.  
  25.     public void addGeneral(General general) {  
  26.         if (!(this.ofCluster.contains(general)))  
  27.             this.ofCluster.add(general);  
  28.     }  
  29. }  

4,Kmeans.java

  1. </pre><pre name="code" class="java"
  1. package kmeans;  
  2.  
  3. import java.util.*;  
  4.  
  5. public class Kmeans {  
  6.     public ArrayList<General> allGenerals = null;  
  7.     public int totalNumber = 0;// 得到所有的武将数目  
  8.     public int K = 0;// 假设K=10  
  9.  
  10.     public Kmeans() {  
  11.         allGenerals = new DomParser().prepare();  
  12.         totalNumber = allGenerals.size();  
  13.         K = 3;  
  14.     }  
  15.  
  16.     // 第一次随机选取聚类中心  
  17.     public Set<Integer> firstRandom() {  
  18.         Set<Integer> center = new HashSet<Integer>();// 聚类中心的点的id,采用set保证不会有重复id  
  19.         Random ran = new Random();  
  20.         int roll = ran.nextInt(totalNumber);  
  21.         while (center.size() < K) {  
  22.             roll = ran.nextInt(totalNumber);  
  23.             center.add(roll);  
  24.         }  
  25.         return center;  
  26.     }  
  27.  
  28.     // 根据聚类中心初始化聚类信息  
  29.     public ArrayList<Cluster> init(Set<Integer> center) {  
  30.         ArrayList<Cluster> cluster = new ArrayList<Cluster>();// 聚类 的数组  
  31.         Iterator<Integer> it = center.iterator();  
  32.         while (it.hasNext()) {  
  33.             Cluster c = new Cluster();// 代表一个聚类  
  34.             c.setCenter(it.next());  
  35.             cluster.add(c);  
  36.         }  
  37.         return cluster;  
  38.     }  
  39.  
  40.     /**  
  41.      * 计算各个武将到各个聚类中心的距离,重新聚类  
  42.      *   
  43.      * @param cluster  
  44.      *            聚类数组,用来聚类的,根据最近原则把武将聚类  
  45.      * @param center  
  46.      *            中心点id,用于计算各个武将到中心点的距离 return cluster 聚类后的所有聚类组成的数组  
  47.      */ 
  48.     public ArrayList<Cluster> juLei(Set<Integer> center,  
  49.             ArrayList<Cluster> cluster) {  
  50.         ArrayList<Distance> distence = new ArrayList<Distance>();// 存放距离信息,表示每个点到各个中心点的距离组成的数组  
  51.         General source = null;  
  52.         General dest = null;  
  53.         int id = 0;// 目的节点id  
  54.         int id2 = 0;// 源节点id  
  55.         Object[] p = center.toArray();// p 为聚类中心点id数组  
  56.         boolean flag = false;  
  57.         // 分别计算各个点到各个中心点的距离,并将距离最小的加入到各个聚类中,进行聚类  
  58.         for (int i = 0; i < totalNumber; i++) {  
  59.             // 每个点计算完,并聚类到距离最小的聚类中就清空距离数组  
  60.             distence.clear();  
  61.             // 计算到j个类中心点的距离,便利各个中心点  
  62.             for (int j = 0; j < center.size(); j++) {  
  63.                 // 如果该点不在中心点内 则计算距离  
  64.                 if (!(center.contains(i))) {  
  65.                     flag = true;  
  66.                     // 计算距离  
  67.                     source = allGenerals.get(i);// 某个点  
  68.                     dest = allGenerals.get((Integer) p[j]);// 各个 中心点  
  69.                     // 计算距离并存入数组  
  70.                     distence.add(new Distance((Integer) p[j], i, Tool.juli(  
  71.                             source, dest)));  
  72.                 } else {  
  73.                     flag = false;  
  74.                 }  
  75.             }  
  76.             // 说明计算完某个武将到类中心的距离,开始比较  
  77.             if (flag == true) {  
  78.                 // 排序比较一个点到各个中心的距离的大小,找到距离最小的武将的 目的id,和源id,  
  79.                 // 目的id即类中心点id,这个就归到这个中心点所在聚类中  
  80.                 double min = distence.get(0).getDist();// 默认第一个distance距离是最小的  
  81.                 // 从1开始遍历distance数组  
  82.                 int minid = 0;  
  83.                 for (int k = 1; k < distence.size(); k++) {  
  84.                     if (min > distence.get(k).getDist()) {  
  85.                         min = distence.get(k).getDist();  
  86.                         id = distence.get(k).getDest();// 目的,即类中心点  
  87.                         id2 = distence.get(k).getSource();// 某个武将  
  88.                         minid = k;  
  89.                     } else {  
  90.                         id = distence.get(minid).getDest();  
  91.                         id2 = distence.get(minid).getSource();  
  92.                     }  
  93.                 }  
  94.                 // 遍历cluster聚类数组,找到类中心点id与最小距离目的武将id相同的聚类  
  95.                 for (int n = 0; n < cluster.size(); n++) {  
  96.                     // 如果和中心点的id相同 则setError  
  97.                     if (cluster.get(n).getCenter() == id) {  
  98.                         cluster.get(n).addGeneral(allGenerals.get(id2));// 将与该聚类中心距离最小的武将加入该聚类  
  99.                         break;  
  100.                     }  
  101.                 }  
  102.             }  
  103.         }  
  104.         return cluster;  
  105.     }  
  106.  
  107.     // 产生新的聚类中心点数组  
  108.     public Set<Integer> updateCenter() {  
  109.         Set<Integer> center = new HashSet<Integer>();  
  110.         for (int i = 0; i < K; i++) {  
  111.             center.add(i);  
  112.         }  
  113.         return center;  
  114.     }  
  115.  
  116.     // 更新聚类中心, 求平均值  
  117.     public ArrayList<Cluster> updateCluster(ArrayList<Cluster> cluster) {  
  118.         ArrayList<Cluster> result = new ArrayList<Cluster>();  
  119.         // 重新产生的新的聚类中心组成的数组  
  120.         // k个聚类进行更新聚类中心  
  121.         for (int j = 0; j < K; j++) {  
  122.             ArrayList<General> ps = cluster.get(j).getOfCluster();// 该聚类的所有 武将  
  123.                                                                     // 组成的数组  
  124.             ps.add(allGenerals.get(cluster.get(j).getCenter()));// 同时将该类中心对应的武将加入该武将数组  
  125.             int size = ps.size();// 该聚类的长度大小  
  126.             // 计算和,然后在计算平均值  
  127.             int sumrender = 0, sumtongshai = 0, sumwuli = 0, sumzhili = 0, sumjibin = 0, sumnubin = 0, sumqibin = 0, sumpolic = 0, sumqiangbin = 0, sumbinqi = 0, sumtongwu = 0, sumtongzhi = 0, sumtongwuzhi = 0, sumtongwuzhizheng = 0, sumsalary = 0;  
  128.             for (int k1 = 0; k1 < size; k1++) {  
  129.                 sumrender += ps.get(k1).getRender();  
  130.                 sumtongshai += ps.get(k1).getRender();  
  131.                 sumwuli += ps.get(k1).getWuli();  
  132.                 sumzhili += ps.get(k1).getZhili();  
  133.                 sumjibin += ps.get(k1).getJibin();  
  134.                 sumnubin += ps.get(k1).getNubin();  
  135.                 sumqibin += ps.get(k1).getQibin();  
  136.                 sumpolic += ps.get(k1).getPolic();  
  137.                 sumqiangbin += ps.get(k1).getQiangbin();  
  138.                 sumbinqi += ps.get(k1).getBinqi();  
  139.                 sumtongwu += ps.get(k1).getTongwu();  
  140.                 sumtongzhi += ps.get(k1).getTongzhi();  
  141.                 sumtongwuzhi += ps.get(k1).getTongwuzhi();  
  142.                 sumtongwuzhizheng += ps.get(k1).getTongwuzhizheng();  
  143.                 sumsalary += ps.get(k1).getSalary();  
  144.             }  
  145.             // 产生新的聚类,然后加入到聚类数组中  
  146.             Cluster newCluster = new Cluster();  
  147.             newCluster.setCenter(j);  
  148.             // 计算平均值并构造新的武将对象  
  149.             newCluster.addGeneral(new General(sumrender / size, sumtongshai  
  150.                     / size, sumwuli / size, sumzhili / size, sumjibin / size,  
  151.                     sumnubin / size, sumqibin / size, sumpolic = 0,  
  152.                     sumqiangbin = 0, sumbinqi / size, sumtongwu / size,  
  153.                     sumtongzhi / size, sumtongwuzhi / size, sumtongwuzhizheng  
  154.                             / size, sumsalary / size));  
  155.             result.add(newCluster);  
  156.         }  
  157.         return result;  
  158.  
  159.     }  
  160.  
  161.     /**  
  162.      * 计算各个武将到各个更新后的聚类中心的距离,重新聚类  
  163.      * @param update 更新后的聚类中心  
  164.      * @param cluster 要存储的聚类中心  
  165.      */ 
  166.     public ArrayList<Cluster> updateJuLei(ArrayList<Cluster> update,  
  167.             ArrayList<Cluster> cluster) {  
  168.         ArrayList<Distance> distence = new ArrayList<Distance>();// 存放距离信息,表示每个点到各个中心点的距离组成的数组  
  169.         General source = null;  
  170.         General dest = null;  
  171.         int id = 0;// 目的节点id  
  172.         int id2 = 0;// 源节点id  
  173.         //Object[] p = center.toArray();// p 为聚类中心点id数组  
  174.         boolean flag = false;  
  175.         // 分别计算各个点到各个中心点的距离,并将距离最小的加入到各个聚类中,进行聚类  
  176.         for (int i = 0; i < totalNumber; i++) {  
  177.             // 每个点计算完,并聚类到距离最小的聚类中就清空距离数组  
  178.             distence.clear();  
  179.             // 计算到j个类中心点的距离,便利各个中心点  
  180.             //for (int j = 0; j < center.size(); j++) {  
  181.             for (int j = 0; j < update.size(); j++) {  
  182.                 // 如果该点不在中心点内 则计算距离  
  183.                 //if (!(center.contains(i))) {  
  184.                     flag = true;  
  185.                     // 计算距离  
  186.                     source = allGenerals.get(i);// 某个点  
  187.                     // dest = allGenerals.get((Integer) p[j]);// 各个 中心点  
  188.                     dest = update.get(j).getOfCluster().get(0);// 各个 中心点  
  189.                     // 计算距离并存入数组  
  190.                     //distence.add(new Distance((Integer) p[j], i, Tool.juli(  
  191.                     distence.add(new Distance(update.get(j).getCenter(), i, Tool.juli(  
  192.                             source, dest)));  
  193.                     /*} else {  
  194.                     flag = false;  
  195.                 }*/ 
  196.             }  
  197.             // 说明计算完某个武将到类中心的距离,开始比较  
  198.             if (flag == true) {  
  199.                 // 排序比较一个点到各个中心的距离的大小,找到距离最小的武将的 目的id,和源id,  
  200.                 // 目的id即类中心点id,这个就归到这个中心点所在聚类中  
  201.                 double min = distence.get(0).getDist();// 默认第一个distance距离是最小的  
  202.                 // 从1开始遍历distance数组  
  203.                 int mid = 0;  
  204.                 for (int k = 1; k < distence.size(); k++) {  
  205.                     if (min > distence.get(k).getDist()) {  
  206.                         min = distence.get(k).getDist();  
  207.                         id = distence.get(k).getDest();// 目的,即类中心点  
  208.                         id2 = distence.get(k).getSource();// 某个武将  
  209.                         mid = k;  
  210.                     } else {  
  211.                         id = distence.get(mid).getDest();  
  212.                         id2 = distence.get(mid).getSource();  
  213.                     }  
  214.                 }  
  215.                 // 遍历cluster聚类数组,找到类中心点id与最小距离目的武将id相同的聚类  
  216.                 for (int n = 0; n < cluster.size(); n++) {  
  217.                     // 如果和中心点的id相同 则setError  
  218.                     if (cluster.get(n).getCenter() == id) {  
  219.                         cluster.get(n).addGeneral(allGenerals.get(id2));// 将与该聚类中心距离最小的武将加入该聚类  
  220.                     }  
  221.                 }  
  222.             }  
  223.         }  
  224.         return cluster;  
  225.     }  
  226.  
  227.     // 不断循环聚类直到各个聚类没有重新分配  
  228.     public ArrayList<Cluster> getResu
版权声明:本文来源51CTO,感谢博主原创文章,遵循 CC 4.0 by-sa 版权协议,转载请附上原文出处链接和本声明。
原文链接:http://developer.51cto.com/art/201205/334592.htm
站方申明:本站部分内容来自社区用户分享,若涉及侵权,请联系站方删除。
  • 发表于 2021-05-16 09:57:55
  • 阅读 ( 819 )
  • 分类:算法

0 条评论

请先 登录 后评论

官方社群

GO教程

猜你喜欢