# Develop a Decision Tree in Python From Scratch

Learn to develop a decision tree in Python using a class-based method.


## Data preparation and visualisation

The task is to write my own code to learn a decision tree that uses two features (the source clusters and the destination clusters) to predict the classification field. Therefore, the first step is to read the cluster dataset with its classification labels. Some samples of the dataset are shown below:

```python
cluster_data = pd.read_csv('cluster_data.csv')
print(
    'Cluster dataset generated:\n',
    cluster_data.head()
)
```

```
Cluster dataset generated:
    sourceIP cluster  destIP cluster           class
0                  0               0   Misc activity
1                  3               0   Misc activity
2                  3               0   Misc activity
3                  2               0   Misc activity
4                  3               0   Misc activity
```


Before learning the decision tree, a size-encoded scatter plot is generated to show which classes the points (different kinds of communications) belong to. In the scatter plot, points of the same class are drawn in the same colour, as shown in the figure.

```python
classes = cluster_data['class']        # Extract the class column
unique_classes = np.unique(classes)    # Unique classes

# Replace the string names with their indices in the unique classes array
cluster_data_digit_cls = cluster_data.copy(deep=True)
for i, label in enumerate(unique_classes):
    cluster_data_digit_cls = cluster_data_digit_cls.replace(label, i)
print(
    'Cluster dataset with indices as class names generated:\n',
    cluster_data_digit_cls.head()
)

# Generate triples with indices of sourceIP cluster, destIP cluster and class
cluster_triples = [(cluster_data_digit_cls.iloc[i][0],
                    cluster_data_digit_cls.iloc[i][1],
                    cluster_data_digit_cls.iloc[i][2])
                   for i in cluster_data_digit_cls.index]

# Use the Counter method
counter_relation = Counter(cluster_triples)

# Generate a numpy array of shape (n, 4) where n is the number of distinct
# triples and the fourth column contains the number of records of the
# corresponding triple. This step may cost about 10 seconds.
relation = np.concatenate(
    (np.asarray(list(counter_relation.keys())),
     np.asarray(list(counter_relation.values())).reshape(-1, 1)),
    axis=1
)

# Save the dataset with counts
# pd.DataFrame(relation, columns=['sourceIP cluster', 'destIP cluster',
#                                 'class', 'counts']).to_csv('relation.csv')

# Generate data for the size-encoded scatter plot
x = relation[:, 0]                            # Source IP cluster indices
y = relation[:, 1]                            # Destination IP cluster indices
area = (relation[:, 3])**2 / 10000            # Marker size from the real number of records
log_area = (np.log(relation[:, 3]))**2 * 15   # Constrained size in logspace
colors = relation[:, 2]                       # Colours defined by classes

# Create a new subplots figure
fig, axes = plt.subplots(1, 2, figsize=(20, 10))
fig.suptitle('Cluster Connections with Classifications', fontsize=20)
plt.setp(axes.flat, xlabel='sourceIP Clusters', ylabel='destIP Clusters')

# Scatter plot: use alpha to increase transparency
scatter = axes[0].scatter(x, y, s=area, c=colors, alpha=0.8, cmap='Paired')
axes[0].set_title('Real size encoding records')

# Legend of classes
handles, _ = scatter.legend_elements(prop='colors', alpha=0.6)
lgd1 = axes[0].legend(handles, unique_classes, loc="best", title="Classes")

# Scatter plot in logspace
scatter = axes[1].scatter(x, y, s=log_area, c=colors, alpha=0.8, cmap='Paired')
axes[1].set_title('Logspace size encoding records')

# Legend of sizes
kw = dict(prop="sizes", num=5, color=scatter.cmap(0.7), fmt="{x:.0f}",
          func=lambda s: s)
handles, labels = scatter.legend_elements(**kw)
lgd2 = axes[1].legend(handles, labels, loc='best',
                      title='Sizes = \n$(log(num\\_records))^2*15$',
                      labelspacing=2.5)

plt.savefig('Q4-relation-scatter.pdf')
plt.savefig('Q4-relation-scatter.jpg')
plt.show()
```

```
Cluster dataset with indices as class names generated:
    sourceIP cluster  destIP cluster  class
0                  0               0      1
1                  3               0      1
2                  3               0      1
3                  2               0      1
4                  3               0      1
```


## Implementation of decision trees

With the dataset that contains the indices of source clusters, the indices of destination clusters and the classifications as strings, the decision tree should be able to split the data into branches repeatedly until every node can be labelled, i.e. each node satisfies some stopping criterion from pre-pruning or post-pruning.

The approach proposed in this report is a class-based implementation. Following the structure of a binary tree, I first built a class called Node, whose instances hold attributes such as data, depth, classification and prev_condition (the condition that brings the data to this node). Nodes in two adjacent layers are connected by specifying the left son node and the right son node, since the decision tree in this implementation is a binary tree. In addition, backuperror and mcp (misclassification probability) are defined to help the algorithm perform post-pruning. Apart from the constructor, the class has only one method, set_splits, which is simply a quick way of assigning the attributes that describe how the node was derived from its parent node.

Then a class called DecisionTree is created. This class defines how the decision tree takes in training data, how it learns to split the data, how to classify a node, how to visualise the tree and how to predict the classifications of input test data. Significant instance attributes include root (the root of the decision tree, which will be assigned a Node instance) and criterion (the measure used to calculate impurity: entropy, Gini index or misclassification error). The other attributes mostly concern the configuration of pre-pruning and post-pruning. Details are given in the following sections.

### Computation of impurity

A general decision tree performs its branching by finding the optimal split, i.e. the one with the maximised information gain or, equivalently, the minimised degree of impurity. Three measures are used to calculate the impurity: entropy, Gini index and misclassification error. Their equations are listed below:

\begin{aligned}
\text{Entropy} &= -\sum_i P_i \log_2 P_i \\
\text{Gini index} &= 1 - \sum_i P_i^2 \\
\text{Misclassification error} &= 1 - \max_i P_i
\end{aligned}

A function to calculate the Laplace-based misclassification probability is also provided; it gives results similar to the plain misclassification error. I implemented this method in order to reproduce the post-pruning given in the course learning materials.
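As a quick sanity check of the three formulas, they can be evaluated directly on a toy label array (the second class name here is made up for illustration):

```python
import numpy as np

# Toy label column: three records of one class, one of another (illustrative names)
labels = np.array(['Misc activity', 'Misc activity',
                   'Misc activity', 'Other activity'])
_, counts = np.unique(labels, return_counts=True)
probs = counts / counts.sum()                 # [0.25, 0.75]

entropy = -np.sum(probs * np.log2(probs))     # ≈ 0.8113
gini = 1 - np.sum(probs ** 2)                 # 0.375
mce = 1 - probs.max()                         # 0.25
```

The full implementations below follow the same arithmetic, with docstrings and edge-case notes added.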

```python
def calculate_entropy(data):
    """Calculate the entropy of the input data.

    Parameters:
    ------
    data : numpy array
        Should be the data whose last column contains the class labels.

    Returns:
    ------
    entropy : float
        The entropy of the data.

    N.B.
    ------
    If the data is an empty array, entropy will be 0.
    """
    labels = data[:, -1]
    _, counts = np.unique(labels, return_counts=True)
    probs = counts / counts.sum()
    entropy = sum(-probs * np.log2(probs))
    return entropy


def calculate_overall_entropy(data1, data2):
    """Calculate the overall entropy of the two input datasets.

    Parameters:
    ------
    data1, data2 : numpy array
        Should be the datasets whose last column contains the class labels.

    Returns:
    ------
    overall_entropy : float

    N.B.
    ------
    If both datasets are empty arrays, ZeroDivisionError will be raised.
    """
    total_num = len(data1) + len(data2)
    prob_data1 = len(data1) / total_num
    prob_data2 = len(data2) / total_num
    overall_entropy = (prob_data1 * calculate_entropy(data1)
                       + prob_data2 * calculate_entropy(data2))
    return overall_entropy


def calculate_gini(data):
    """Calculate the Gini index of the input data.

    Parameters:
    ------
    data : numpy array
        Should be the data whose last column contains the class labels.

    Returns:
    ------
    gini : float
        The Gini index of the data.

    N.B.
    ------
    If the data is an empty array, gini will be 1.
    """
    labels = data[:, -1]
    _, counts = np.unique(labels, return_counts=True)
    probs = counts / counts.sum()
    gini = 1 - sum(np.square(probs))
    return gini


def calculate_overall_gini(data1, data2):
    """Calculate the overall Gini index of the two input datasets.

    Parameters:
    ------
    data1, data2 : numpy array
        Should be the datasets whose last column contains the class labels.

    Returns:
    ------
    overall_gini : float

    N.B.
    ------
    If both datasets are empty arrays, ZeroDivisionError will be raised.
    """
    total_num = len(data1) + len(data2)
    prob_data1 = len(data1) / total_num
    prob_data2 = len(data2) / total_num
    overall_gini = (prob_data1 * calculate_gini(data1)
                    + prob_data2 * calculate_gini(data2))
    return overall_gini


def calculate_mce(data):
    """Calculate the misclassification error of the input data.

    Parameters:
    ------
    data : numpy array
        Should be the data whose last column contains the class labels.

    Returns:
    ------
    mce : float
        The misclassification error of the data.

    N.B.
    ------
    If the data is an empty array, ValueError will be raised.
    """
    labels = data[:, -1]
    _, counts = np.unique(labels, return_counts=True)
    probs = counts / counts.sum()
    mce = 1 - np.max(probs)
    return mce


def calculate_overall_mce(data1, data2):
    """Calculate the overall misclassification error of the two input datasets.

    Parameters:
    ------
    data1, data2 : numpy array
        Should be the datasets whose last column contains the class labels.

    Returns:
    ------
    overall_mce : float

    N.B.
    ------
    If both datasets are empty arrays, ZeroDivisionError will be raised.
    """
    total_num = len(data1) + len(data2)
    prob_data1 = len(data1) / total_num
    prob_data2 = len(data2) / total_num
    overall_mce = (prob_data1 * calculate_mce(data1)
                   + prob_data2 * calculate_mce(data2))
    return overall_mce


def calculate_overall_impurity(data1, data2, method):
    """Calculate the overall impurity.

    Parameters:
    ------
    data1, data2 : numpy array
        Should be the datasets whose last column contains the class labels.
    ---
    method : string -> 'entropy', 'gini', 'mce'
        Impurity computing method.

    Returns:
    ------
    The value of the impurity; raises ValueError if given wrong input.
    """
    if method == 'entropy':
        return calculate_overall_entropy(data1, data2)
    elif method == 'gini':
        return calculate_overall_gini(data1, data2)
    elif method == 'mce':
        return calculate_overall_mce(data1, data2)
    else:
        raise ValueError


def calculate_laplace_mcp(data):
    """Calculate the misclassification probability of the input data
    using Laplace's Law.

    Parameters:
    ------
    data : numpy array
        Should be the data whose last column contains the class labels.

    Returns:
    ------
    mcp : float
        The misclassification probability.
        mcp = (k-c+1)/(k+2), where k is the total number of samples and
        c is the number of samples of the majority class.

    N.B.
    ------
    If the data is an empty array, ValueError will be raised.
    """
    labels = data[:, -1]
    _, counts = np.unique(labels, return_counts=True)
    c = np.max(counts)
    k = counts.sum()
    mcp = (k - c + 1) / (k + 2)
    return mcp
```

### Check the purity

If the data of a node contains only one class, the node is pure and ready to be classified. If not, further branching may be required, depending on the pruning configuration.

```python
def check_purity(data):
    """Check the purity of the input data.

    Parameters:
    ------
    data : numpy array
        Should be the data whose last column contains the class labels.

    Returns:
    ------
    bool
        True: the data is pure.
        False: the data is not pure.

    N.B.
    ------
    If the data is an empty array, False will also be returned.
    """
    labels = data[:, -1]
    unique_classes = np.unique(labels)
    if len(unique_classes) == 1:
        return True
    else:
        return False
```

### Classify the node

When the node is pure (holding only one class) as introduced above, it is necessary to classify the node with the class it has. However, in some cases, the node should be classified even if purity is not satisfied. For example, pre-pruning in my method defines a minimum number of samples of a node, indicating that even if multiple classes exist in the node, classification is required since it has reached the lower limit of sample amount. The way of classifying is to assign the class with the largest number of records to the node.

```python
def classify_data(data):
    """Classify the input data.

    Parameters:
    ------
    data : numpy array
        Should be the data whose last column contains the class labels.

    Returns:
    ------
    classification : type of the label column
        The label in the label column with the highest count.

    N.B.
    ------
    If the data is an empty array, ValueError will be raised.
    """
    labels = data[:, -1]
    unique_classes, count_unique_classes = np.unique(labels, return_counts=True)
    index = count_unique_classes.argmax()
    classification = unique_classes[index]
    return classification
```

### Data splitting

While branching is the most crucial operation of a decision tree, data splitting is the most significant job, as it prepares the data subsets for the son nodes at the deeper level. The data set of a node has several columns: all columns except the last are features to be differentiated, and the last contains the class labels. The algorithm iterates over the feature columns and their field values to find the best feature and the best threshold by which to split the data. The steps taken to find the optimal feature column and threshold are as follows:

• Get all possible splits. Perform iterations over all feature columns and extract the averages of the adjacent entries as the thresholds of the related feature.

• Try to split the data. The algorithm iterates over all features and all thresholds, splitting the data into two subsets.

• Find the best method of splitting. Compute the overall impurity of the two data subsets and keep the split with the lowest degree of impurity.
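The first step is easy to demonstrate in isolation: the candidate thresholds for a feature column are simply the midpoints of its adjacent unique values. A minimal stand-alone sketch of that step:

```python
import numpy as np

values = np.array([1, 0, 3, 1, 0])      # One feature column
unique_values = np.unique(values)        # array([0, 1, 3])

# Candidate thresholds are the averages of adjacent unique values
thresholds = [(unique_values[i - 1] + unique_values[i]) / 2
              for i in range(1, len(unique_values))]
# thresholds == [0.5, 2.0]
```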

```python
def get_splits(data):
    """Get all potential splits the data may have.

    Parameters:
    ------
    data : numpy array
        The last column should be a column of labels.

    Returns:
    ------
    splits : dictionary
        keys : column indices
        values : a list of [split thresholds]
    """
    splits = {}
    n_cols = data.shape[1]  # Number of columns
    for i_col in range(n_cols - 1):  # Disregard the last label column
        splits[i_col] = []
        values = data[:, i_col]
        unique_values = np.unique(values)  # All possible values
        for i_thresh in range(1, len(unique_values)):
            prev_value = unique_values[i_thresh - 1]
            curr_value = unique_values[i_thresh]
            # Store the average of two neighbouring values
            splits[i_col].append((prev_value + curr_value) / 2)
    return splits


def split_data(data, split_index, split_thresh):
    """Split the data based on split_thresh among the values in column split_index.

    Parameters:
    ------
    data : numpy array
        Input data that needs to be split.
    split_index : int
        The index of the column where the splitting is implemented.
    split_thresh : type of the numpy array entries
        The threshold that splits the column values.

    Returns:
    ------
    data_below, data_above : numpy array
        Split data. Below goes to the left son node, above goes to the
        right son node.
    """
    split_column_values = data[:, split_index]
    data_below = data[split_column_values <= split_thresh]
    data_above = data[split_column_values > split_thresh]
    return data_below, data_above


def find_best_split(data, splits, method):
    """Find the best split from all splits for the input data.

    Parameters:
    ------
    data : numpy array
        The last column should be a column of labels.
    ---
    splits : dictionary
        keys : int, column indices
        values : a list of [split thresholds]
    ---
    method : string -> 'entropy', 'gini', 'mce'
        Impurity computing method.

    Returns:
    ------
    best_index : int
        The best column index of the data to split.
    ---
    best_thresh : float
        The best threshold of the data to split.
    ---
    """
    best_index = None
    best_thresh = None
    min_overall_impurity = float('inf')  # Smallest overall impurity found so far
    for index in splits.keys():
        for split_thresh in splits[index]:
            data_true, data_false = split_data(data=data, split_index=index,
                                               split_thresh=split_thresh)
            overall_impurity = calculate_overall_impurity(data_true, data_false, method)
            if overall_impurity <= min_overall_impurity:  # Found a new minimum
                min_overall_impurity = overall_impurity
                best_index = index
                best_thresh = split_thresh
    return best_index, best_thresh
```

### Pruning

Pruning is a method to constrain the branching of the decision tree. If no pruning is performed, nodes are divided until every son node holds only one class in its data set. The configurations of pre-pruning and post-pruning are described below.

#### Pre-pruning

Pre-pruning takes effect during every branching step, unlike post-pruning, which runs after the tree is built. Three criteria are defined for pre-pruning:

• Purity. If the data set has only one class, the node is classified.

• Lower limit of sample amount. If the number of samples in the data set falls below a specified threshold, this node should not be split any further.

• Upper limit of the decision tree depth. If the number of levels of the decision tree reaches the upper limit, the tree should stop growing.
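These three criteria can be summarised as a single stopping test. The sketch below is a simplified, stand-alone version of the check performed inside the tree-fitting code later in this section; `should_stop` is a hypothetical helper name, not part of the actual implementation:

```python
import numpy as np

def should_stop(data, depth, min_samples=2, max_depth=5):
    """Combine the three pre-pruning criteria into one stopping test."""
    labels = data[:, -1]
    is_pure = len(np.unique(labels)) == 1    # Purity
    too_small = len(data) < min_samples      # Lower limit of sample amount
    too_deep = depth >= max_depth            # Upper limit of tree depth
    return is_pure or too_small or too_deep
```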

#### Post-pruning

Post-pruning is based on the backward calculation of errors. After the tree has been learned, the algorithm computes the backup error from the bottom of the tree and propagates it up to the root. In my implementation, this becomes a recursive procedure that starts from the root node and returns the backup error of the two son nodes; recursively, each node is assigned the weighted backup error of its sons.

Dynamic programming turns out to be quite useful and effective in my implementation, but there is one more thing to do: keeping all visited nodes in memory. The way I implemented this in code is to build a Last-In-First-Out (LIFO) stack that holds the nodes the recursive process is visiting. After the backuperror is calculated for one node, that node is popped from the stack so that the remaining nodes can be processed. The combination of dynamic programming and stack iteration is also used to merge son nodes with the same class and to visualise the decision tree.
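The stack pattern itself is easy to demonstrate on a toy tree: each node is popped, processed, and replaced on the stack by its sons. Here `TinyNode` is a made-up stand-in for the Node class used in the real implementation:

```python
class TinyNode:
    """A hypothetical stand-in for the Node class, for illustration only."""
    def __init__(self, name, left=None, right=None):
        self.name, self.left, self.right = name, left, right

# A root with a leaf on the left and an internal node on the right
root = TinyNode('root',
                TinyNode('L'),
                TinyNode('R', TinyNode('RL'), TinyNode('RR')))

visited = []
stack = [root]                   # LIFO stack of nodes still to process
while stack:
    node = stack.pop()           # Pop the node being visited
    visited.append(node.name)
    if node.left:                # Push right first so the left son is popped first
        stack.append(node.right)
        stack.append(node.left)
# visited == ['root', 'L', 'R', 'RL', 'RR']
```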

```python
# Node class
class Node:
    def __init__(self, data_df, depth=0):
        """Initialise the node.

        Parameters:
        ------
        data_df : pandas DataFrame
            Its last column should be labels.
        ---
        depth : int, default=0
            The current depth of the node.
        ---
        """
        self.left = None            # Left son node
        self.right = None           # Right son node
        self.data = data_df         # Data of the node
        self.depth = depth          # The depth level of the node in the tree
        self.classification = None  # The class of the node
        self.prev_condition = None  # Condition that brings the data to the node
        self.prev_feature = None    # The splitting feature
        self.prev_thresh = None     # The splitting threshold
        self.backuperror = None     # Backup error for post-pruning
        self.mcp = None             # Misclassification probability

    def set_splits(self, prev_condition, prev_feature, prev_thresh):
        """Assign the configuration of the splitting method.

        Parameters:
        ------
        prev_condition : string
            The condition, in a form like 'sourceIP cluster < 2.5'.
        ---
        prev_feature : string
            The feature name.
        ---
        prev_thresh : float
            The splitting threshold.
        ---
        """
        self.prev_condition = prev_condition
        self.prev_feature = prev_feature
        self.prev_thresh = prev_thresh
```
  1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410  from tabulate import tabulate class DesicionTree: def __init__(self, criterion='entropy', post_prune=False, min_samples=2, max_depth=5): """Initialise a decision tree. Parameters: ------ root : Node Instance of class Node. --- criterion : string - 'criterion' (default): Entropy = -sum(Pi*log2Pi) - 'gini': Gini index = 1-sum(Pi^2) - 'mce': Misclassification Error = 1-max(Pi) The criterion based on which the data is splitted. 
For example, it criterion is 'entroy', then the best split method should have the lowest overall entropy. --- post_prune : bool Whether the decision tree should be post-pruned. --- min_samples : int, default = 2 The minimum number of samples a node should contain. --- max_depth : int, default = 5 The maximum number of depth the tree can have. --- features : DataFrames.columns The attributes of the root data. --- """ self.root = None self.criterion = criterion self.post_prune = post_prune self.min_samples = min_samples self.max_depth = max_depth self.features = None def feed(self, data_df): """Feed the decision tree with data. Parameters: ------ data_df : pandas DataFrame """ self.root = Node(data_df, 0) self._fit(self.root) def _fit(self, node): """Fit the data, check impurity and make splits. Parameters: ------ node : Node instance """ # Prepare data data = node.data # pandas DataFrame depth = node.depth if depth is 0: self.features = data.columns data = data.values # numpy array # Pre-pruning if (check_purity(data)) or (len(data) < self.min_samples) or (depth is self.max_depth): # Stop splitting? 
classification = classify_data(data) node.classification = classification # Recursive else: # Keep splitting # Splitting splits = get_splits(data) split_index, split_thresh = find_best_split(data, splits, self.criterion) data_left, data_right = split_data(data, split_index, split_thresh) # Pre-pruning: Prevent empty split if (data_left.size is 0) or (data_right.size is 0): classification = classify_data(data) node.classification = classification else: depth += 1 # Deeper depth # Transform the numpy array into pandas DataFrame for the node data_left_df = pd.DataFrame(data_left,columns=list(self.features)) data_right_df = pd.DataFrame(data_right,columns=list(self.features)) # Get condition description feature_name = self.features[split_index] true_condition = "{} <= {}".format(feature_name, split_thresh) false_condition = "{} > {}".format(feature_name, split_thresh) # Set values of the node node.left = Node(data_left_df,depth=depth) node.right = Node(data_right_df, depth=depth) node.left.set_splits(true_condition, feature_name, split_thresh) node.right.set_splits(false_condition, feature_name, split_thresh) # Recursive process self._fit(node.left) self._fit(node.right) self._merge() # Merge the son nodes with the same class if self.post_prune: # Post-pruning self._post_prune() def _merge(self): """Merge the son nodes if they are both classifified as the same class. 
""" # First the root stack = [] # LIFO, Build a stack to store the Nodes stack.append(self.root) while True: if len(stack): pop_node = stack.pop() if pop_node.left: if pop_node.left.classification: # Already classified if pop_node.left.classification == pop_node.right.classification: # Same classification pop_node.classification = pop_node.left.classification pop_node.left = None pop_node.right = None else: # Different classifications stack.append(pop_node.right) stack.append(pop_node.left) else: # Not classified stack.append(pop_node.right) stack.append(pop_node.left) else: break def _calculate_error(self, node): # Misclassification probability using Laplace's Law if node.left: # There are son nodes, the backuperror of this node is the weighted sum of the backuperrors of sons backuperror_left = self._calculate_error(node.left) backuperror_right = self._calculate_error(node.right) node.backuperror = len(node.left.data)/len(node.data)*backuperror_left + len(node.right.data)/len(node.data)*backuperror_right node.mcp = calculate_laplace_mcp(node.data.to_numpy()) # And we still need mcp else: # No son nodes, backuperror = mcp node.backuperror = node.mcp = calculate_laplace_mcp(node.data.to_numpy()) return node.backuperror def _post_prune(self): """Post pruning. """ self._calculate_error(self.root) # LIFO processing stack = [] stack.append(self.root) while True: if len(stack): pop_node = stack.pop() if pop_node.left: # We only prune nodes with sons if pop_node.backuperror > pop_node.mcp: node = None else: stack.append(pop_node.right) stack.append(pop_node.left) else: break def view(self, method, saveflag=False, savename='Decision Tree'): """Visulise the decision tree. Parameters: ------ method : string - 'text', 't' or 0: Print the tree in text. - 'graph', 'g' or 1: Print the tree graphically. --- saveflag : bool Whether or not to save the visualisation. --- savename : string, default: 'Decision Tree' The saved file name if saveflag is True. 
--- """ # Object type check and analysis to avoid invalid input if isinstance(method, str) is True: if method is 'text' or method is 't': method = 0 elif method is 'graph' or method is 'g': method = 1 else: raise ValueError elif isinstance(method, int) is True: if method is 0 or method is 1: pass else: raise ValueError else: raise TypeError # Visualise by calling specific functions if method is 0: print('Visulising the decision tree in {}.'.format('text')) self._view_text(saveflag, savename) else: print('Visulising the decision tree {}.'.format('graphically')) self._view_graph(saveflag, savename) def _get_prefix(self, depth): """Get the prefix of the node description string. Parameters: ------ depth : int The depth of the node. --- For example, if depth is 1, the prefix is '|---' """ default_prefix = '|---' depth_prefix = '|\t' prefix = depth_prefix * (depth - 1) + default_prefix return prefix def _view_node_text(self, node, fw): """Print the desription of a node. Parameters: ------ node : Node instance. --- fw : the file that has been opened. --- """ if node.prev_condition: # If there is a condition rather than None line = self._get_prefix(node.depth) + node.prev_condition # save to .txt if fw: fw.write(line+'\n') print(line) if node.classification: # If there is a classification rather than None line = self._get_prefix(node.depth+1) + node.classification if fw: fw.write(line+'\n') print(line) def _view_text(self, saveflag=False, savename='Decision Tree'): """View the tree in text. Parameters: ------ saveflag : bool Whether or not to save the visualisation. --- savename : string, default: 'Decision Tree' The saved file name if saveflag is True. 
```python
        ---
        """
        # First the root
        stack = []  # LIFO, build a stack to store the nodes
        stack.append(self.root)
        fw = None
        # Open file
        if saveflag:
            fw = open(savename + '.txt', 'w')
        while True:
            if len(stack):
                pop_node = stack.pop()  # Pop out the visiting node
                self._view_node_text(pop_node, fw)
                # Recursive process
                if pop_node.left:
                    stack.append(pop_node.right)
                    stack.append(pop_node.left)
            else:
                break
        if fw:
            fw.close()

    def _view_node_graph(self, node, coords):
        """Visualise a node in the graph.

        Parameters:
        ------
        node : Node instance.
        ---
        coords : tuple of floats
            (x, y) where the node is plotted in the graph.
        ---
        """
        data_df = node.data
        # Condition
        str_condition = node.prev_condition + '\n' if node.prev_condition else ''
        # Impurity
        str_method = self.criterion
        if str_method == 'entropy':
            impurity = calculate_entropy(data_df.values)
        elif str_method == 'gini':
            impurity = calculate_gini(data_df.values)
        elif str_method == 'mce':
            impurity = calculate_mce(data_df.values)
        else:
            raise ValueError
        # Number of samples
        str_samples = str(len(data_df))
        # Classes
        str_predicted_class = node.classification + '\n' if node.classification else ''
        np_classes = np.unique(data_df[data_df.columns[-1]].to_numpy())
        str_actual_classes = ',\n'.join(list(np.unique(np_classes)))
        # Plot the text with bound
        (x, y) = coords
        node_text = (str_condition + str_method + ' = ' + str(round(impurity, 4))
                     + '\n' + 'samples = ' + str_samples + '\n'
                     + 'class = ' + str_predicted_class
                     + 'Actual classes = ' + str_actual_classes)
        plt.text(x, y, node_text, color='black', ha='center', va='center')
        # If there are son nodes
        x_offset = 0.5
        y_offset = 0.1
        line_y_offset = 0.015
        if node.left:
            coords_left = (x - x_offset, y - y_offset)   # Coordinates of the left son node
            coords_right = (x + x_offset, y - y_offset)  # Coordinates of the right son node
            line_to_sons = ([x - x_offset, x, x + x_offset],
                            [y - y_offset + line_y_offset, y - line_y_offset,
                             y - y_offset + line_y_offset])
            # Plot connection lines
            plt.plot(line_to_sons[0], line_to_sons[1], color='black', linewidth=0.5)
            # Recursive part
            self._view_node_graph(node.left, coords_left)
            self._view_node_graph(node.right, coords_right)

    def _view_graph(self, saveflag=False, savename='Decision Tree'):
        """View the tree graphically.

        Parameters:
        ------
        saveflag : bool
            Whether or not to save the visualisation.
        ---
        savename : string, default: 'Decision Tree'
            The saved file name if saveflag is True.
        ---
        """
        plt.figure()
        self._view_node_graph(self.root, (0, 0))  # Plot from the root at (0, 0)
        plt.axis('off')
        if saveflag:
            plt.savefig(savename + '.pdf', bbox_inches='tight')
            plt.savefig(savename + '.jpg', bbox_inches='tight')
        plt.show()

    def print_info(self):
        """Print the information of the decision tree."""
        print(
            tabulate(
                [
                    ['Data head', self.root.data.head() if self.root else None],
                    ['Criterion', self.criterion],
                    ['Minimum size of the node data', self.min_samples],
                    ['Maximum depth of the tree', self.max_depth],
                    ['Post_pruning', self.post_prune],
                    ['Features', [feature for feature in self.features]],
                    ['All classes', list(np.unique(self.root.data[self.root.data.columns[-1]].to_numpy()))]
                ],
                headers=['Attributes', 'Values'],
                tablefmt='fancy_grid'
            )
        )

    def predict(self, test_data_df):
        """Predict the classification of the input DataFrame.

        Parameters:
        ------
        test_data_df : pandas DataFrame
            Should be in the same format as the training dataset.
        ---
        """
        # Only one row of sample
        if len(test_data_df) == 1:
            class_name = self._predict_example(test_data_df, self.root)
            return class_name
        else:
            # Multiple rows
            predicted_classes = []
            # Iterate over all samples and store the classes in a list
            for i_row in range(len(test_data_df)):
                test_data_example = test_data_df[i_row:i_row + 1]
                predicted_classes.append(self._predict_example(test_data_example, self.root))
            return predicted_classes

    def _predict_example(self, data_df, node):
        """Predict the class of a single sample.

        Parameters:
        ------
        data_df : pandas DataFrame
            One-row DataFrame.
        ---
        node : Node instance
            This is for a recursive procedure of deciding the classification of the
            expandable node, i.e. the deepest node the data will reach.
        """
        # If there are son nodes for further expanding
        if node.left:
            # Yes
            feature_name = node.left.prev_feature
            split_thresh = node.left.prev_thresh
            # Recursive part
            if data_df.iloc[0][feature_name] <= split_thresh:
                # Go to left son
                return self._predict_example(data_df, node.left)
            else:
                # Go to right son
                return self._predict_example(data_df, node.right)
        else:
            # No expanding
            return node.classification
```
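The explicit stack in the text view replaces recursion with an iterative pre-order walk: pop a node, visit it, then push its right son before its left so the left subtree is visited first. A minimal standalone sketch of that traversal, using a made-up `Node` class (a stand-in for the tree's node type, assuming every internal node has both sons):

```python
class Node:
    """Minimal binary node; a stand-in for the tree's Node class (hypothetical)."""
    def __init__(self, label, left=None, right=None):
        self.label = label
        self.left = left
        self.right = right

def preorder(root):
    """Visit nodes in pre-order (node, left subtree, right subtree) with an explicit stack."""
    visited = []
    stack = [root]  # LIFO: the last node pushed is visited first
    while stack:
        node = stack.pop()
        visited.append(node.label)
        if node.left:  # Push right first so the left son is popped first
            stack.append(node.right)
            stack.append(node.left)
    return visited

tree = Node('root', Node('left'), Node('right'))
print(preorder(tree))  # ['root', 'left', 'right']
```

This matches the order the recursive version would produce, but without risking Python's recursion limit on deep trees.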

## Test with independent data

By default, the entropy criterion is selected to initialise the decision tree and the post-pruning flag is set to True. The cluster dataset generated before is first split randomly into a training set and a test set. The training set is then fed to the decision tree, which is learned automatically. The final decision tree is shown in text below, and it can also be illustrated in the Figure.

```python
import random

def train_test_split(df, test_size):
    """Split the data into train and test parts randomly.

    Parameters:
        df : pd.DataFrame, input data
        test_size : either a percentage or the number of the test samples
    Returns:
        train_df : pd.DataFrame, training data
        test_df : pd.DataFrame, test data
    """
    if isinstance(test_size, float):
        test_size = round(test_size * len(df))
    indices = df.index.tolist()
    test_indices = random.sample(population=indices, k=test_size)
    test_df = df.loc[test_indices]
    train_df = df.drop(test_indices)
    return train_df, test_df
```
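A quick sanity check of the splitter on a made-up 10-row frame (the function body is repeated here so the snippet runs on its own): a float `test_size` of 0.2 should yield 2 test rows and 8 training rows, with no index in both parts.

```python
import random
import pandas as pd

def train_test_split(df, test_size):
    """Same splitter as above: float test_size is a fraction, int is a count."""
    if isinstance(test_size, float):
        test_size = round(test_size * len(df))
    indices = df.index.tolist()
    test_indices = random.sample(population=indices, k=test_size)
    test_df = df.loc[test_indices]
    train_df = df.drop(test_indices)
    return train_df, test_df

random.seed(0)  # Any seed works; only the sizes are checked here
toy = pd.DataFrame({'x': range(10), 'y': range(10)})  # Made-up data
train_df, test_df = train_test_split(toy, test_size=0.2)

print(len(train_df), len(test_df))                 # 8 2
print(set(train_df.index) & set(test_df.index))    # set() -- the parts are disjoint
```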
```python
import random

random.seed(1)  # For reproduction
train_data, test_data = train_test_split(cluster_data, test_size=0.1)

dt = DesicionTree(post_prune=True)
dt.feed(train_data)
dt.view(method='t', saveflag=True)  # View in text
dt.view('g', True, savename='q5-decision-tree')  # View in graph
```
Visualising the decision tree in text:

```
|--- destIP cluster <= 2.5
|    |--- Misc activity
|--- destIP cluster > 2.5
|    |--- Generic Protocol Command Decode
```

Visualising the decision tree graphically.


It is noticeable that although three classes (Generic Protocol Command Decode, Misc activity, Potential Corporate Privacy Violation) exist in the original training data, the two leaf nodes predict only two of them (Generic Protocol Command Decode, Misc activity). The decision tree can therefore give a fairly certain answer in only two cases.
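Because the learned tree contains a single split, its whole prediction logic collapses to one threshold rule on the destIP cluster feature. A hand-written equivalent of that rule (an illustration, not the tree's own `predict` method):

```python
def predict_from_learned_tree(dest_ip_cluster):
    """Equivalent of the single-split tree shown above (hypothetical helper)."""
    if dest_ip_cluster <= 2.5:   # Left son of the root
        return 'Misc activity'
    return 'Generic Protocol Command Decode'  # Right son of the root

print(predict_from_learned_tree(0))  # Misc activity
print(predict_from_learned_tree(3))  # Generic Protocol Command Decode
```

No input value can ever reach the class Potential Corporate Privacy Violation, which is exactly why that class scores zero in the evaluation below.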

This situation can be revealed by printing the confusion matrix while testing the decision tree. The test set is passed to the predict function, which produces a list of predicted classes. Both the ground-truth classes and the predicted classes are then used to produce the confusion matrix and to print the precision and recall of the classification, shown below. Clearly, all samples of class Potential Corporate Privacy Violation are unseen data.

```python
extended_test_data = test_data.copy()  # Deep copy to avoid shared reference
predicted_classes = dt.predict(extended_test_data)  # Predict
extended_test_data['predicted'] = predicted_classes  # Add a column of predictions

from sklearn import metrics

y_true = extended_test_data['class'].to_numpy()
y_predicted = extended_test_data['predicted'].to_numpy()

# Classification report
print(
    'Classification report:\n',
    metrics.classification_report(y_true, y_predicted)
)

# Confusion matrix
print(
    'Confusion matrix:\n',
    metrics.confusion_matrix(y_true=y_true, y_pred=y_predicted)
)
```
```
Classification report:
                                        precision    recall  f1-score   support

      Generic Protocol Command Decode       0.97      1.00      0.99      1301
                        Misc activity       1.00      1.00      1.00       507
Potential Corporate Privacy Violation       0.00      0.00      0.00        35

                             accuracy                           0.98      1843
                            macro avg       0.66      0.67      0.66      1843
                         weighted avg       0.96      0.98      0.97      1843

Confusion matrix:
[[1301    0    0]
 [   0  507    0]
 [  35    0    0]]
```
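The per-class scores in the report can be recomputed directly from the confusion matrix. With rows as true classes and columns as predicted classes (in the order Generic Protocol Command Decode, Misc activity, Potential Corporate Privacy Violation), the third column is all zeros: the class is never predicted, so both its precision and its recall are zero.

```python
import numpy as np

# Confusion matrix from the output above
cm = np.array([[1301,   0, 0],
               [   0, 507, 0],
               [  35,   0, 0]])

tp = np.diag(cm).astype(float)   # True positives per class
col_sums = cm.sum(axis=0)        # Total predictions per class
row_sums = cm.sum(axis=1)        # Total true samples per class

# Guard against division by zero for a class that is never predicted
precision = np.divide(tp, col_sums, out=np.zeros_like(tp), where=col_sums > 0)
recall = np.divide(tp, row_sums, out=np.zeros_like(tp), where=row_sums > 0)

print(np.round(precision, 2))  # [0.97 1.   0.  ]
print(np.round(recall, 2))     # [1. 1. 0.]
```

The 35 Potential Corporate Privacy Violation samples all land in the first column, which is also what drags the first class's precision down to 1301/1336 ≈ 0.97.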
