"""
Decision Tree Plotter

Reference:
1. http://www.cnblogs.com/fantasy01/p/4595902.html
2. http://whatbeg.com/2016/04/23/matplotlib-desiciontree.html
"""
import matplotlib.font_manager as font_manager import os import matplotlib.pyplot as plt
# Chinese font setting font_path = os.path.abspath('../resources/font/msyh.ttf') prop = font_manager.FontProperties(fname=font_path)
# Axes the plotting helpers draw on; assigned by create_plot().
gAxis = None

# Box styles for decision nodes and leaf nodes, and the arrow style used to
# connect a node to its parent (all consumed by plot_node / plot_tree).
gDecison_node = {'boxstyle': 'sawtooth', 'fc': '0.8'}
gLeaf_node = {'boxstyle': 'round4', 'fc': "0.8"}
gArrow_args = {'arrowstyle': '<-'}

# Layout state: tree totals set by create_plot(), and the running x/y
# cursor that create_plot() initialises and plot_tree() advances.
gNum_leaves = gTree_depth = 0
gX_offset = gY_offset = 0
def get_num_leafs(tree):
    """
    Count the leaves of a decision tree stored as nested dicts.

    The tree has the form ``{feature: {branch_value: child}}`` where each
    ``child`` is either another tree of the same form (a dict) or a leaf
    (any non-dict value). Only the first key of ``tree`` is used as the
    root feature; any further top-level keys are ignored.

    Args:
        tree: nested-dict decision tree.

    Returns:
        The number of non-dict values reachable from the root, or 0 when
        ``tree`` is empty.
    """
    if not tree:
        # An empty tree has no root feature and therefore no leaves.
        return 0
    branches = tree[next(iter(tree))]
    return sum(
        get_num_leafs(child) if isinstance(child, dict) else 1
        for child in branches.values()
    )
def get_tree_depth(tree):
    """
    Compute the depth of a decision tree stored as nested dicts.

    The tree has the form ``{feature: {branch_value: child}}`` where each
    ``child`` is either another tree of the same form (a dict) or a leaf
    (any non-dict value). A branch ending in a leaf contributes depth 1;
    a branch leading to a subtree contributes 1 plus that subtree's depth.
    Only the first key of ``tree`` is used as the root feature.

    Args:
        tree: nested-dict decision tree.

    Returns:
        The largest depth over all branches of the root, or 0 when the
        tree is empty or its root has no branches.
    """
    if not tree:
        # An empty tree has no levels at all.
        return 0
    branches = tree[next(iter(tree))]
    return max(
        (
            1 + get_tree_depth(child) if isinstance(child, dict) else 1
            for child in branches.values()
        ),
        default=0,  # root feature with no branches
    )
def plot_node(node_text, center_point, parent_point, node_type):
    """
    Draw one boxed node and the arrow connecting it to its parent.

    The annotation is added to the module-level ``gAxis``. The text box is
    placed at ``center_point`` and the arrow is anchored at
    ``parent_point``; both points are interpreted as axes-fraction
    coordinates. ``node_type`` is used as the box style (``bbox``) of the
    annotation, and the arrow style comes from ``gArrow_args``.
    """
    annotation_style = {
        'xycoords': 'axes fraction',
        'textcoords': 'axes fraction',
        'va': 'center',
        'ha': 'center',
        'bbox': node_type,
        'arrowprops': gArrow_args,
        'fontproperties': prop,
    }
    gAxis.annotate(node_text, xy=parent_point, xytext=center_point,
                   **annotation_style)
def plot_mid_text(current_point, parent_point, text_str):
    """
    Plot text halfway between a child node and its parent.

    Args:
        current_point: ``(x, y)`` of the child node, in axes-fraction
            coordinates.
        parent_point: ``(x, y)`` of the parent node, in axes-fraction
            coordinates.
        text_str: label to draw at the midpoint of the two points.
    """
    x_mid = (parent_point[0] - current_point[0]) / 2.0 + current_point[0]
    y_mid = (parent_point[1] - current_point[1]) / 2.0 + current_point[1]
    # plot_node positions nodes in axes-fraction coordinates, so the label
    # must use the same coordinate system. Without an explicit transform,
    # Axes.text interprets (x, y) as data coordinates, which only coincide
    # with axes fractions while the axes limits are exactly (0, 1).
    gAxis.text(x_mid, y_mid, text_str, transform=gAxis.transAxes,
               fontproperties=prop)
def plot_tree(tree, parent_point, node_text):
    """
    Recursively draw ``tree`` below ``parent_point``.

    The root feature of ``tree`` is drawn as a decision node, with
    ``node_text`` written on the edge leading to it. Each branch is then
    drawn one level lower: dict children recurse, every other child is
    drawn as a leaf. The module-level cursor ``gX_offset`` / ``gY_offset``
    is advanced as leaves are placed, and ``gY_offset`` is restored to its
    entry value before returning.
    """
    global gX_offset, gY_offset

    leaf_count = get_num_leafs(tree)
    root_label = list(tree)[0]
    # Centre the decision node horizontally over the leaves of its subtree,
    # which will occupy the next `leaf_count` slots to the right of the
    # current x cursor.
    node_point = (
        gX_offset + (1.0 + float(leaf_count)) / (2.0 * gNum_leaves),
        gY_offset,
    )
    plot_mid_text(node_point, parent_point, node_text)
    plot_node(root_label, node_point, parent_point, gDecison_node)
    branches = tree[root_label]

    level_height = 1.0 / gTree_depth
    gY_offset -= level_height  # step down one level for the children
    for branch_value, child in branches.items():
        if not isinstance(child, dict):
            # Leaf: take the next horizontal slot and draw it there.
            gX_offset += 1.0 / gNum_leaves
            leaf_point = (gX_offset, gY_offset)
            plot_node(child, leaf_point, node_point, gLeaf_node)
            plot_mid_text(leaf_point, node_point, str(branch_value))
            continue
        plot_tree(child, node_point, str(branch_value))
    gY_offset += level_height  # step back up to this node's level
def create_plot(tree):
    """
    Draw ``tree`` in matplotlib figure 1 and show it.

    Clears figure 1, creates a frameless, tick-less axes stored in the
    module-level ``gAxis``, initialises the layout globals from the tree's
    leaf count and depth, draws the tree from the top centre of the axes,
    and finally calls ``plt.show()``.
    """
    global gAxis, gNum_leaves, gTree_depth, gX_offset, gY_offset

    figure = plt.figure(1, facecolor='white')
    figure.clf()
    gAxis = plt.subplot(111, frameon=False, xticks=[], yticks=[])

    gNum_leaves = float(get_num_leafs(tree))
    gTree_depth = float(get_tree_depth(tree))
    # Start half a leaf slot to the left of the axes so that the first
    # leaf, placed one full slot further right, is centred in its slot.
    gX_offset = -1 / (2.0 * gNum_leaves)
    gY_offset = 1.0

    root_anchor = (0.5, 1.0)
    plot_tree(tree, root_anchor, '')
    plt.show()
|