数据挖掘决策树java 联系客服

发布时间 : 星期四 文章数据挖掘决策树java更新完毕开始阅读49b2deb5bb68a98270fefa0b

import java.util.HashMap; import java.util.HashSet;

import java.util.LinkedHashSet; import java.util.Iterator;

// Adapted from a CSDN blog post.

/**
 * A node of the decision tree.
 *
 * <p>A node stores the name of the data attribute associated with it, the value that the
 * previous splitting attribute takes at this node, and its child nodes. All three are
 * {@code null} until explicitly assigned (only {@link #TreeNode(String)} presets the value).
 */
class TreeNode {

    /** Name of the data attribute held by this node. */
    String element;

    /** Value of the previous splitting attribute at this node. */
    String value;

    /**
     * Child nodes, kept in an insertion-ordered linked hash set.
     * Remains {@code null} until {@link #setChilds(LinkedHashSet)} is called.
     */
    LinkedHashSet<TreeNode> childs;

    /** Creates an empty node: no attribute name, no value, no children. */
    public TreeNode() {
        this(null);
    }

    /**
     * Creates a node reached through the given value of the previous splitting attribute.
     *
     * @param value value of the previous splitting attribute at this node; may be {@code null}
     */
    public TreeNode(String value) {
        this.element = null;
        this.value = value;
        this.childs = null;
    }

    /** @return the attribute name of this node, or {@code null} if none has been set */
    public String getElement() {
        return this.element;
    }

    /** @param e the attribute name to store on this node */
    public void setElement(String e) {
        this.element = e;
    }

    /** @return the value of the previous splitting attribute at this node, or {@code null} */
    public String getValue() {
        return this.value;
    }

    /** @param v the value of the previous splitting attribute to store on this node */
    public void setValue(String v) {
        this.value = v;
    }

    /** @return the child set exactly as stored (not copied), or {@code null} if never set */
    public LinkedHashSet<TreeNode> getChilds() {
        return this.childs;
    }

    /**
     * Replaces the child set. The reference is stored as-is, without copying.
     *
     * @param childs the new child nodes; may be {@code null}
     */
    public void setChilds(LinkedHashSet<TreeNode> childs) {
        this.childs = childs;
    }
}

//决策树类

class DecisionTree { TreeNode root; //决策树的树根结点 public DecisionTree() { root = new TreeNode(); } public DecisionTree(TreeNode root) { this.root = root; } public TreeNode getRoot() { return root; } public void setRoot(TreeNode root) { this.root = root; } public String selectAtrribute(TreeNode node,String[][] deData, boolean flags[], LinkedHashSet atrributes, HashMap attrIndexMap) {

//Gain数组存放当前结点未分类属性的Gain值 double Gain[] = new double[atrributes.size()];

//每条数据中归类的下标,为每条数据的最后一个值 int class_index = deData[0].length - 1; //属性名,该结点在该属性上进行分类 String return_atrribute = null;

//计算每个未分类属性的 Gain值 int count = 0; //计算到第几个属性 for(String atrribute:atrributes) { //该属性有多少个值,该属性有多少个分类 int values_count, class_count; //属性值对应的下标 int index = attrIndexMap.get(atrribute); //存放属性的各个值和分类值 LinkedHashSet values = new LinkedHashSet(); LinkedHashSet classes = new LinkedHashSet();

for(int i = 0; i < deData.length; i++) { if(flags[i] == true) { values.add(deData[i][index]); classes.add(deData[i][class_index]); } }

values_count = values.size(); class_count = classes.size();

int values_vector[] = new int[values_count * class_count]; int class_vector[] = new int[class_count];

for(int i = 0; i < deData.length; i++) { if(flags[i] == true) { int j = 0; for(String v:values) { if(deData[i][index].equals(v)) { break; } else { j++; } } int k = 0; for(String c:classes) { if(deData[i][class_index].equals(c)) { break; } else { k++; } } values_vector[j*class_count+k]++; class_vector[k]++; } }

/* //输出各项统计值

for(int i = 0; i < values_count * class_count; i++) { System.out.print(values_vector[i] + " "); }

System.out.println();

for(int i = 0; i < class_count; i++) {

System.out.print(class_vector[i] + " "); }

System.out.println(); */

//计算InforD double InfoD = 0.0; double class_total = 0.0; for(int i = 0; i < class_vector.length; i++){ class_total += class_vector[i]; } for(int i = 0; i < class_vector.length; i++){ if(class_vector[i] == 0) { continue; } else { double d = Math.log(class_vector[i]/class_total) / Math.log(2.0) * class_vector[i] / class_total; InfoD = InfoD - d; } } //计算InfoA double InfoA = 0.0; for(int i = 0; i < values_count; i++) { double middle = 0.0; double attr_count = 0.0; for(int j = 0; j < class_count; j++) { attr_count += values_vector[i*class_count+j]; } for(int j = 0; j < class_count; j++) { if(values_vector[i*class_count+j] != 0) { double k = values_vector[i*class_count+j]; middle = middle - Math.log(k/attr_count) / Math.log(2.0) * k / attr_count; } } InfoA += middle * attr_count / class_total; } Gain[count] = InfoD - InfoA; count++;