Finding the proper drug for a new patient using Decision Tree Classification
We will use the Decision Tree classification algorithm to build a model from historical data on patients and their responses to different medications. Then we'll use the trained decision tree to predict the class of an unknown patient, that is, to find a suitable drug for a new patient.
Imagine that you are a medical researcher compiling data for a study. You have collected data about a set of patients, all of whom suffered from the same illness. During their course of treatment, each patient responded to one of five medications: Drug A, Drug B, Drug C, Drug X, and Drug Y.
Part of your job is to build a model to find out which drug might be appropriate for a future patient with the same illness. The features in this dataset are the patients' Age, Sex, Blood Pressure (BP), Cholesterol, and Na_to_K (sodium-to-potassium ratio), and the target is the drug each patient responded to.
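First, load the data with pandas and build the feature matrix X from every column except the target, Drug.
import numpy as np
import pandas as pd
from sklearn.tree import DecisionTreeClassifier
df = pd.read_csv('drug200.csv', delimiter=',')
df.head()  # preview the first five rows
df.shape   # (number of rows, number of columns)
# Feature matrix as a NumPy array (object dtype, since some columns hold strings)
X = df[['Age', 'Sex', 'BP', 'Cholesterol', 'Na_to_K']].values
X[0:5]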
Some features in this dataset are categorical: Sex, BP, and Cholesterol. scikit-learn's DecisionTreeClassifier works on numeric feature values, so we need to convert these features to numbers. One option is pandas.get_dummies(), which converts categorical variables into dummy/indicator variables; here we use scikit-learn's LabelEncoder instead, which maps each category to an integer code.
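For comparison, here is a minimal sketch of both encodings on a made-up three-row frame (the values are illustrative, not taken from drug200.csv). get_dummies adds one 0/1 column per category, while LabelEncoder keeps a single column of integer codes. One trade-off worth knowing: a decision tree treats label-encoded codes as ordered numbers, so a single threshold split can only separate categories that are adjacent in the code order.

```python
import pandas as pd
from sklearn.preprocessing import LabelEncoder

# Hypothetical three-patient sample, for illustration only
sample = pd.DataFrame({"Sex": ["F", "M", "F"], "BP": ["HIGH", "LOW", "NORMAL"]})

# One-hot encoding: one 0/1 column per BP category; Sex is left untouched
dummies = pd.get_dummies(sample, columns=["BP"], dtype=int)
print(dummies)

# Label encoding: one integer column, codes assigned in sorted (alphabetical) order
le = LabelEncoder()
sample["BP_code"] = le.fit_transform(sample["BP"])
print(sample)
```

With three BP levels, one-hot encoding grows the table by three columns, whereas label encoding adds a single column whose codes here are HIGH=0, LOW=1, NORMAL=2.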
from sklearn import preprocessing
# LabelEncoder assigns integer codes in sorted order of the labels
# Sex: F -> 0, M -> 1
le_sex = preprocessing.LabelEncoder()
le_sex.fit(['F', 'M'])
X[:, 1] = le_sex.transform(X[:, 1])
# BP: HIGH -> 0, LOW -> 1, NORMAL -> 2
le_BP = preprocessing.LabelEncoder()
le_BP.fit(['LOW', 'NORMAL', 'HIGH'])
X[:, 2] = le_BP.transform(X[:, 2])
# Cholesterol: HIGH -> 0, NORMAL -> 1
le_Chol = preprocessing.LabelEncoder()
le_Chol.fit(['NORMAL', 'HIGH'])
X[:, 3] = le_Chol.transform(X[:, 3])
X[0:5]
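LabelEncoder sorts the categories before assigning codes, so the order in which labels are passed to fit() does not determine the codes. A short sketch for checking the mapping, and reversing it, using only the encoder's classes_ attribute and its inverse_transform() method:

```python
from sklearn import preprocessing

le_BP = preprocessing.LabelEncoder()
le_BP.fit(['LOW', 'NORMAL', 'HIGH'])

# classes_ is sorted, so each label's code is its index in this array
print(le_BP.classes_)  # ['HIGH' 'LOW' 'NORMAL']
mapping = {str(label): code for code, label in enumerate(le_BP.classes_)}
print(mapping)         # {'HIGH': 0, 'LOW': 1, 'NORMAL': 2}

# inverse_transform maps integer codes back to the original strings
decoded = le_BP.inverse_transform([0, 2])
print(decoded)         # ['HIGH' 'NORMAL']
```

The same inverse_transform() call is handy later for reading encoded values off the tree's split thresholds.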
Now, define the target variable y: the drug each patient responded to.
y = df['Drug']
y[0:5]
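Next, split the data into a training set and a test set.
from sklearn.model_selection import train_test_split
# Hold out 30% of the rows for testing and train on the remaining 70%;
# random_state=3 makes the split reproducible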
X_trainset, X_testset, y_trainset, y_testset = train_test_split(X, y, test_size=0.3, random_state=3)
# dimensions of the sets
print('The shape of the train set predictors', X_trainset.shape)
print('The shape of the train set target', y_trainset.shape)
print('The shape of the test set predictors', X_testset.shape)
print('The shape of the test set target', y_testset.shape)
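With test_size=0.3, train_test_split rounds the test size up (ceil(0.3 * n_rows)), so a 200-row table yields 60 test rows and 140 training rows. The sketch below confirms this on random stand-in data of that size; the feature values and labels are placeholders, not the real dataset. It also shows the optional stratify argument, which this post does not use: it keeps each class's share the same in both sets, which can matter when some drugs are rare.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Placeholder data: 200 rows x 5 features, two hypothetical balanced labels
rng = np.random.default_rng(0)
X_demo = rng.random((200, 5))
y_demo = np.array(["drugA", "drugB"] * 100)

X_tr, X_te, y_tr, y_te = train_test_split(X_demo, y_demo, test_size=0.3, random_state=3)
print(X_tr.shape, X_te.shape)  # (140, 5) (60, 5)

# Optional: stratify keeps the class proportions equal in train and test
X_tr_s, X_te_s, y_tr_s, y_te_s = train_test_split(
    X_demo, y_demo, test_size=0.3, random_state=3, stratify=y_demo)
labels, counts = np.unique(y_te_s, return_counts=True)
print(dict(zip(labels.tolist(), counts.tolist())))  # {'drugA': 30, 'drugB': 30}
```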
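# criterion='entropy' scores splits by information gain; max_depth=4 limits the tree to a depth of 4
drugTree = DecisionTreeClassifier(criterion='entropy', max_depth=4)
drugTree  # in a notebook, displays the estimator and the parameters set on it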
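With criterion='entropy', the tree scores each candidate split by information gain: the entropy of a node, H(S) = sum over classes of -p_i * log2(p_i) where p_i is the share of class i in the node, minus the size-weighted entropy of the child nodes the split produces. The sketch below computes both quantities by hand for illustration only; scikit-learn does this internally, and entropy() and information_gain() are hypothetical helpers, not part of its API.

```python
import numpy as np

def entropy(labels):
    """Shannon entropy, in bits, of a sequence of class labels."""
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    # p * log2(1/p) equals -p * log2(p); every p > 0 here, since np.unique counts are positive
    return float(np.sum(p * np.log2(1 / p)))

def information_gain(parent, children):
    """Parent entropy minus the size-weighted entropy of the child nodes."""
    n = len(parent)
    weighted = sum(len(child) / n * entropy(child) for child in children)
    return entropy(parent) - weighted

# Hypothetical node of four patients, split two different ways
parent = ["drugA", "drugA", "drugB", "drugB"]
print(entropy(parent))                                    # 1.0
print(information_gain(parent, [["drugA", "drugA"], ["drugB", "drugB"]]))  # 1.0
print(information_gain(parent, [["drugA", "drugA", "drugB"], ["drugB"]]))  # ~0.311
```

The perfectly separating split gains a full bit, while the uneven one gains about 0.31 bits, so a tree using the entropy criterion would prefer the first.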
Next, fit the model with the training feature matrix X_trainset and the training response vector y_trainset, then predict the drug for each patient in the test set.
drugTree.fit(X_trainset, y_trainset)
predTree = drugTree.predict(X_testset)
Print predTree and y_testset if you want to compare the predictions with the actual values visually.
print(predTree[0:5])
print(y_testset[0:5])
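The goal stated at the start is to recommend a drug for a new patient. A new record has to pass through the same LabelEncoders as the training data before predict() can use it. The sketch below illustrates this end to end on six made-up patients (hypothetical values, not rows from drug200.csv) with a hypothetical encode() helper; with the real data you would reuse le_sex, le_BP, le_Chol, and drugTree from above instead. random_state=0 is an addition for reproducibility.

```python
import numpy as np
from sklearn import preprocessing
from sklearn.tree import DecisionTreeClassifier

# Hypothetical training rows: Age, Sex, BP, Cholesterol, Na_to_K (not real data)
X_raw = np.array([
    [23, "F", "HIGH", "HIGH", 25.4],
    [47, "M", "LOW", "HIGH", 13.1],
    [35, "M", "LOW", "NORMAL", 11.0],
    [28, "F", "NORMAL", "HIGH", 7.8],
    [61, "F", "LOW", "HIGH", 18.0],
    [50, "M", "NORMAL", "NORMAL", 9.5],
], dtype=object)
y = np.array(["drugY", "drugC", "drugC", "drugX", "drugY", "drugX"])

# Encoders fitted on the full category lists, as in the post
le_sex = preprocessing.LabelEncoder().fit(["F", "M"])
le_BP = preprocessing.LabelEncoder().fit(["LOW", "NORMAL", "HIGH"])
le_Chol = preprocessing.LabelEncoder().fit(["NORMAL", "HIGH"])

def encode(rows):
    """Hypothetical helper: label-encode the categorical columns and return floats."""
    out = np.array(rows, dtype=object)
    out[:, 1] = le_sex.transform(out[:, 1])
    out[:, 2] = le_BP.transform(out[:, 2])
    out[:, 3] = le_Chol.transform(out[:, 3])
    return out.astype(float)

model = DecisionTreeClassifier(criterion="entropy", max_depth=4, random_state=0)
model.fit(encode(X_raw), y)

# A new patient goes through the same encoders before prediction
new_patient = [[30, "M", "HIGH", "NORMAL", 20.0]]
prediction = model.predict(encode(new_patient))
print(prediction)
```

Encoding the new record with freshly fitted encoders would only be safe if they saw the same category lists; reusing the fitted ones guarantees identical codes.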
from sklearn import metrics
import matplotlib.pyplot as plt  # used later to display the tree image
print("DecisionTree's Accuracy: ", metrics.accuracy_score(y_testset, predTree))
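Our Decision Tree scores a high accuracy on this test set. Keep in mind that the score comes from a single split (fixed by random_state=3) with a small number of test patients, so a different split can give a somewhat different number.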
metrics.accuracy_score returns the fraction of samples whose predicted label exactly matches the true label. In a single-label problem like this one, that is simply the share of test patients assigned the correct drug.
In multilabel classification, the function returns the subset accuracy: a sample scores 1.0 only if its entire set of predicted labels strictly matches the true set of labels, and 0.0 otherwise.
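A small sketch of both behaviors with made-up labels: for single-label predictions, accuracy_score equals the mean of element-wise matches; for a multilabel indicator matrix, a row counts only if all of its labels match.

```python
import numpy as np
from sklearn import metrics

# Single-label (multiclass) case, as in this post; labels are hypothetical
y_true = np.array(["drugA", "drugX", "drugY", "drugY"])
y_pred = np.array(["drugA", "drugX", "drugX", "drugY"])
acc = metrics.accuracy_score(y_true, y_pred)
print(acc)                       # 0.75
print(np.mean(y_true == y_pred)) # 0.75, the same quantity computed by hand

# Multilabel indicator matrices: the first row differs in one label, so it scores 0
Y_true = np.array([[0, 1], [1, 1]])
Y_pred = np.ones((2, 2))
subset_acc = metrics.accuracy_score(Y_true, Y_pred)
print(subset_acc)                # 0.5
```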
Finally, we can visualize the trained tree (this needs the pydotplus package and the Graphviz system binaries).
from io import StringIO  # sklearn.externals.six no longer exists in recent scikit-learn
import pydotplus
import matplotlib.image as mpimg
from sklearn import tree
%matplotlib inline
dot_data = StringIO()
filename = "drugtree.png"
featureNames = list(df.columns[0:5])  # Age, Sex, BP, Cholesterol, Na_to_K
# class_names must follow the order of drugTree.classes_ (the sorted drug labels)
tree.export_graphviz(drugTree, feature_names=featureNames, out_file=dot_data,
                     class_names=list(drugTree.classes_), filled=True,
                     special_characters=True, rotate=False)
graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
graph.write_png(filename)
img = mpimg.imread(filename)
plt.figure(figsize=(20, 40))
plt.imshow(img, interpolation='nearest')
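If Graphviz or pydotplus is unavailable, scikit-learn's own sklearn.tree.export_text prints the same splits as plain text, and sklearn.tree.plot_tree draws them with matplotlib. This is an alternative, not the route the post uses. A minimal sketch of export_text on hypothetical, already-encoded rows (not the real dataset):

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier, export_text

# Hypothetical encoded rows: Age, Sex (F=0, M=1), BP (HIGH=0, LOW=1, NORMAL=2),
# Cholesterol (HIGH=0, NORMAL=1), Na_to_K
X = np.array([
    [23, 0, 0, 0, 25.4],
    [47, 1, 1, 0, 13.1],
    [35, 1, 1, 1, 11.0],
    [28, 0, 2, 0, 7.8],
    [61, 0, 1, 0, 18.0],
    [50, 1, 2, 1, 9.5],
])
y = np.array(["drugY", "drugC", "drugC", "drugX", "drugY", "drugX"])
feature_names = ["Age", "Sex", "BP", "Cholesterol", "Na_to_K"]

clf = DecisionTreeClassifier(criterion="entropy", max_depth=4, random_state=0).fit(X, y)

# Plain-text rendering of the fitted tree: one line per split or leaf
report = export_text(clf, feature_names=feature_names)
print(report)
```

The text form is easy to paste into a post or diff against another run, while plot_tree gives a picture closer to the Graphviz output.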