Supervised Learning

Decision Trees

Non-parametric models that learn simple decision rules inferred from data features. Highly interpretable but prone to overfitting.


Theory & Concept

A decision tree is a flowchart-like structure where each internal node represents a feature (or attribute), each branch represents a decision rule, and each leaf node represents an outcome. The topmost node in a decision tree is known as the root node.

Decision trees handle both categorical and numerical data, and they are one of the few machine learning algorithms that produce a model that is easy to understand and interpret (a "white box" model).

Key Components

  1. Root Node: The starting point of the tree.
  2. Splitting: Dividing a node into two or more sub-nodes.
  3. Decision Node: When a sub-node splits into further sub-nodes.
  4. Leaf/Terminal Node: Nodes that do not split (the final prediction).

How it Splits (Impurity Measures)

The tree learns by splitting data to maximize "purity" in the child nodes.

  • Gini Impurity (Classification): Measure of how often a randomly chosen element from the set would be incorrectly labeled.
  • Entropy/Information Gain (Classification): Measure of randomness/uncertainty.
  • Variance Reduction (Regression): Minimizes the sum of squared errors in child nodes.
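The two classification measures above can be computed directly from class probabilities. Below is a minimal sketch (the function names `gini` and `entropy` are just illustrative helpers, not part of any library):

```python
import numpy as np

def gini(p):
    """Gini impurity for a vector of class probabilities."""
    p = np.asarray(p, dtype=float)
    return 1.0 - np.sum(p ** 2)

def entropy(p):
    """Shannon entropy (base 2) for a vector of class probabilities."""
    p = np.asarray(p, dtype=float)
    p = p[p > 0]  # skip zero-probability classes to avoid log2(0)
    return -np.sum(p * np.log2(p))

# A pure node has impurity 0; a 50/50 node is maximally impure
print(gini([1.0, 0.0]))     # 0.0
print(gini([0.5, 0.5]))     # 0.5
print(entropy([0.5, 0.5]))  # 1.0
```

Both measures agree on the extremes (pure vs. maximally mixed); in practice they usually produce very similar trees, and Gini is slightly cheaper to compute since it avoids the logarithm.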

Mathematical Intuition (Entropy)

H(S) = - \sum_{i=1}^{c} p_i \log_2 p_i

where p_i is the probability of an element belonging to class i. The algorithm calculates the Information Gain for a split:

Gain(S, A) = H(S) - \sum_{v \in Values(A)} \frac{|S_v|}{|S|} H(S_v)

It chooses the split A that maximizes this gain (i.e., reduces entropy the most).
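A worked example of the gain formula, assuming a toy binary split (the helper names here are illustrative, not a library API):

```python
import numpy as np

def entropy(labels):
    """Shannon entropy (base 2) of a list of class labels."""
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return -np.sum(p * np.log2(p))

def information_gain(parent, children):
    """H(parent) minus the size-weighted entropy of the child nodes."""
    n = len(parent)
    weighted = sum(len(c) / n * entropy(c) for c in children)
    return entropy(parent) - weighted

# A perfectly mixed parent (entropy 1.0) split into two pure halves (entropy 0)
parent = [1, 1, 1, 1, 0, 0, 0, 0]
left, right = [1, 1, 1, 1], [0, 0, 0, 0]
print(information_gain(parent, [left, right]))  # 1.0
```

The gain of 1.0 is the maximum possible for a binary problem: the split removes all uncertainty. A split that left each child still 50/50 would score a gain of 0 and never be chosen.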


Quick Readiness Check

Is this method a fit for your use case?

Best For

Baseline models where explainability is critical (e.g., medical diagnosis, loan approval).

Prerequisites

No feature scaling required. Handling of missing values depends on the implementation.

Strengths

Highly interpretable, handles non-linear relationships, requires little data preprocessing.

Weaknesses

High Variance: Small changes in data can result in a different tree. Prone to overfitting.

Pro Tip

If asked 'How do you stop a tree from overfitting?', mention Pruning (pre-pruning with max_depth or post-pruning).
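Both pruning styles are available in scikit-learn. A minimal sketch using the Iris dataset as a stand-in (the mid-range choice of `ccp_alpha` is arbitrary here; in practice you would select it via cross-validation):

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

# Pre-pruning: cap the tree's depth before training
pre = DecisionTreeClassifier(max_depth=3, random_state=42).fit(X_train, y_train)

# Post-pruning: grow a full tree, then prune via cost-complexity (ccp_alpha)
path = DecisionTreeClassifier(random_state=42).cost_complexity_pruning_path(
    X_train, y_train
)
alpha = path.ccp_alphas[len(path.ccp_alphas) // 2]  # arbitrary mid-range pick
post = DecisionTreeClassifier(ccp_alpha=alpha, random_state=42).fit(X_train, y_train)

print("Pre-pruned depth: ", pre.get_depth())
print("Post-pruned depth:", post.get_depth())
```

Larger `ccp_alpha` values prune more aggressively; `cost_complexity_pruning_path` returns the full sequence of effective alphas so you can cross-validate over them.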


Code Snippet

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt

# 0. Load data (Iris used here as a stand-in for your own dataset)
iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=0.25, random_state=42
)

# 1. Train
# max_depth=3 prevents overfitting and keeps the tree interpretable
clf = DecisionTreeClassifier(criterion='gini', max_depth=3, random_state=42)
clf.fit(X_train, y_train)

# 2. Predict
preds = clf.predict(X_test)
print(f"Accuracy: {accuracy_score(y_test, preds):.4f}")

# 3. Visualize (the "white box" advantage)
plt.figure(figsize=(12, 8))
plot_tree(clf, filled=True, feature_names=iris.feature_names,
          class_names=iris.target_names)
plt.show()

Parameter Tuning Cheat Sheet

| Parameter | Options / Range | Effect & Best Practice |
|---|---|---|
| max_depth | None, int (3-10) | Most critical. Limits how deep the tree grows. Set to 3-5 to prevent overfitting. |
| min_samples_split | int (2-20) | Minimum samples required to split a node. Higher values = simpler model. |
| min_samples_leaf | int (1-10) | Minimum samples required at a leaf node. Smooths the model. |
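The ranges in the cheat sheet can be searched automatically with cross-validation. A sketch using `GridSearchCV`, again with Iris as a placeholder dataset:

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)

# Search the cheat-sheet ranges with 5-fold cross-validation
param_grid = {
    "max_depth": [3, 4, 5],
    "min_samples_split": [2, 10, 20],
    "min_samples_leaf": [1, 5, 10],
}
search = GridSearchCV(DecisionTreeClassifier(random_state=42), param_grid, cv=5)
search.fit(X, y)

print("Best params:", search.best_params_)
print(f"Best CV accuracy: {search.best_score_:.3f}")
```

Because all three parameters constrain tree growth in overlapping ways, tuning them jointly like this is usually more reliable than adjusting each one in isolation.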