Customer Attrition with R and SQL
By Ira Sharenow
June 9, 2014
Version 1
In this paper, I take customer attrition data from the R library C50. I start out by doing some exploratory data analysis using SQL (sqldf).
Then I apply several different models to the data. In some cases I redo the analysis using the caret wrapper package. Typically, I fit a model with a training set and then I measure performance on a test set. I create confusion matrices in order to evaluate performance for both the customers that were retained and those that were not retained.
The Summary section at the end collects the test-set results for all models.
Table of Contents
- Logistic regression
- Trees
- Bagging
- Random Forests
- Boosting
- SVM
- Summary of results
- References
Take a first look at the churn data. The goal is to predict churn from the other variables.
library(C50)
data(churn)
str(churnTrain)
## 'data.frame': 3333 obs. of 20 variables:
## $ state : Factor w/ 51 levels "AK","AL","AR",..: 17 36 32 36 37 2 20 25 19 50 ...
## $ account_length : int 128 107 137 84 75 118 121 147 117 141 ...
## $ area_code : Factor w/ 3 levels "area_code_408",..: 2 2 2 1 2 3 3 2 1 2 ...
## $ international_plan : Factor w/ 2 levels "no","yes": 1 1 1 2 2 2 1 2 1 2 ...
## $ voice_mail_plan : Factor w/ 2 levels "no","yes": 2 2 1 1 1 1 2 1 1 2 ...
## $ number_vmail_messages : int 25 26 0 0 0 0 24 0 0 37 ...
## $ total_day_minutes : num 265 162 243 299 167 ...
## $ total_day_calls : int 110 123 114 71 113 98 88 79 97 84 ...
## $ total_day_charge : num 45.1 27.5 41.4 50.9 28.3 ...
## $ total_eve_minutes : num 197.4 195.5 121.2 61.9 148.3 ...
## $ total_eve_calls : int 99 103 110 88 122 101 108 94 80 111 ...
## $ total_eve_charge : num 16.78 16.62 10.3 5.26 12.61 ...
## $ total_night_minutes : num 245 254 163 197 187 ...
## $ total_night_calls : int 91 103 104 89 121 118 118 96 90 97 ...
## $ total_night_charge : num 11.01 11.45 7.32 8.86 8.41 ...
## $ total_intl_minutes : num 10 13.7 12.2 6.6 10.1 6.3 7.5 7.1 8.7 11.2 ...
## $ total_intl_calls : int 3 3 5 7 3 6 7 6 4 5 ...
## $ total_intl_charge : num 2.7 3.7 3.29 1.78 2.73 1.7 2.03 1.92 2.35 3.02 ...
## $ number_customer_service_calls: int 1 1 0 2 3 0 3 0 1 0 ...
## $ churn : Factor w/ 2 levels "yes","no": 2 2 2 2 2 2 2 2 2 2 ...
str(churnTest)
## 'data.frame': 1667 obs. of 20 variables:
## $ state : Factor w/ 51 levels "AK","AL","AR",..: 12 27 36 33 41 13 29 19 25 44 ...
## $ account_length : int 101 137 103 99 108 117 63 94 138 128 ...
## $ area_code : Factor w/ 3 levels "area_code_408",..: 3 3 1 2 2 2 2 1 3 2 ...
## $ international_plan : Factor w/ 2 levels "no","yes": 1 1 1 1 1 1 1 1 1 1 ...
## $ voice_mail_plan : Factor w/ 2 levels "no","yes": 1 1 2 1 1 1 2 1 1 2 ...
## $ number_vmail_messages : int 0 0 29 0 0 0 32 0 0 43 ...
## $ total_day_minutes : num 70.9 223.6 294.7 216.8 197.4 ...
## $ total_day_calls : int 123 86 95 123 78 85 124 97 117 100 ...
## $ total_day_charge : num 12.1 38 50.1 36.9 33.6 ...
## $ total_eve_minutes : num 212 245 237 126 124 ...
## $ total_eve_calls : int 73 139 105 88 101 68 125 112 46 89 ...
## $ total_eve_charge : num 18 20.8 20.2 10.7 10.5 ...
## $ total_night_minutes : num 236 94.2 300.3 220.6 204.5 ...
## $ total_night_calls : int 73 81 127 82 107 90 120 106 71 92 ...
## $ total_night_charge : num 10.62 4.24 13.51 9.93 9.2 ...
## $ total_intl_minutes : num 10.6 9.5 13.7 15.7 7.7 6.9 12.9 11.1 9.9 11.9 ...
## $ total_intl_calls : int 3 7 6 2 4 5 3 6 4 1 ...
## $ total_intl_charge : num 2.86 2.57 3.7 4.24 2.08 1.86 3.48 3 2.67 3.21 ...
## $ number_customer_service_calls: int 3 0 1 1 2 1 1 0 2 0 ...
## $ churn : Factor w/ 2 levels "yes","no": 2 2 2 2 2 2 2 2 2 2 ...
head(churnTrain)
## state account_length area_code international_plan voice_mail_plan
## 1 KS 128 area_code_415 no yes
## 2 OH 107 area_code_415 no yes
## 3 NJ 137 area_code_415 no no
## 4 OH 84 area_code_408 yes no
## 5 OK 75 area_code_415 yes no
## 6 AL 118 area_code_510 yes no
## number_vmail_messages total_day_minutes total_day_calls total_day_charge
## 1 25 265.1 110 45.07
## 2 26 161.6 123 27.47
## 3 0 243.4 114 41.38
## 4 0 299.4 71 50.90
## 5 0 166.7 113 28.34
## 6 0 223.4 98 37.98
## total_eve_minutes total_eve_calls total_eve_charge total_night_minutes
## 1 197.4 99 16.78 244.7
## 2 195.5 103 16.62 254.4
## 3 121.2 110 10.30 162.6
## 4 61.9 88 5.26 196.9
## 5 148.3 122 12.61 186.9
## 6 220.6 101 18.75 203.9
## total_night_calls total_night_charge total_intl_minutes total_intl_calls
## 1 91 11.01 10.0 3
## 2 103 11.45 13.7 3
## 3 104 7.32 12.2 5
## 4 89 8.86 6.6 7
## 5 121 8.41 10.1 3
## 6 118 9.18 6.3 6
## total_intl_charge number_customer_service_calls churn
## 1 2.70 1 no
## 2 3.70 1 no
## 3 3.29 0 no
## 4 1.78 2 no
## 5 2.73 3 no
## 6 1.70 0 no
names(churnTrain)
## [1] "state" "account_length"
## [3] "area_code" "international_plan"
## [5] "voice_mail_plan" "number_vmail_messages"
## [7] "total_day_minutes" "total_day_calls"
## [9] "total_day_charge" "total_eve_minutes"
## [11] "total_eve_calls" "total_eve_charge"
## [13] "total_night_minutes" "total_night_calls"
## [15] "total_night_charge" "total_intl_minutes"
## [17] "total_intl_calls" "total_intl_charge"
## [19] "number_customer_service_calls" "churn"
table(churnTrain$churn)
##
## yes no
## 483 2850
table(churnTest$churn)
##
## yes no
## 224 1443
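The class counts above show that churners are a small minority, which matters when reading the confusion matrices later on. As a minimal sketch (the counts are copied from the `table()` output; the variable names are illustrative and not part of the original analysis), the churn rates and the accuracy of a trivial "always predict no" rule can be computed as follows:

```r
# Class counts copied from the table() output above
train_counts <- c(yes = 483, no = 2850)
test_counts  <- c(yes = 224, no = 1443)

# Proportion of customers who churned in each set
train_churn_rate <- train_counts[["yes"]] / sum(train_counts)
test_churn_rate  <- test_counts[["yes"]] / sum(test_counts)

# Accuracy of a trivial classifier that always predicts "no" on the test set
baseline_accuracy <- test_counts[["no"]] / sum(test_counts)

round(c(train = train_churn_rate, test = test_churn_rate,
        baseline = baseline_accuracy), 4)
```

Any model that does not clearly beat the baseline accuracy is mostly reproducing the majority class, so the number of churners identified is the more informative figure.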
SQL queries on the train dataset
library(sqldf)
## Warning: package 'sqldf' was built under R version 3.0.3
## Loading required package: gsubfn
## Loading required package: proto
## Loading required namespace: tcltk
## Loading required package: RSQLite
## Loading required package: DBI
## Loading required package: RSQLite.extfuns
# Number of rows
q1 = "SELECT COUNT(*) AS NumberRows
FROM churnTrain;"
query1 =sqldf(q1)
## Loading required package: tcltk
query1
## NumberRows
## 1 3333
# Number of rows by churn value yes/no
q2 = "SELECT churn, COUNT(*) AS NumberRows
FROM churnTrain
GROUP BY churn;"
query2 =sqldf(q2)
query2
## churn NumberRows
## 1 no 2850
## 2 yes 483
# Churn value by state
q3 = "SELECT state, churn, COUNT(*) AS NumberRows
FROM churnTrain
GROUP BY state, churn
ORDER BY state;"
query3 =sqldf(q3)
query3
## state churn NumberRows
## 1 AK no 49
## 2 AK yes 3
## 3 AL no 72
## 4 AL yes 8
## 5 AR no 44
## 6 AR yes 11
## 7 AZ no 60
## 8 AZ yes 4
## 9 CA no 25
## 10 CA yes 9
## 11 CO no 57
## 12 CO yes 9
## 13 CT no 62
## 14 CT yes 12
## 15 DC no 49
## 16 DC yes 5
## 17 DE no 52
## 18 DE yes 9
## 19 FL no 55
## 20 FL yes 8
## 21 GA no 46
## 22 GA yes 8
## 23 HI no 50
## 24 HI yes 3
## 25 IA no 41
## 26 IA yes 3
## 27 ID no 64
## 28 ID yes 9
## 29 IL no 53
## 30 IL yes 5
## 31 IN no 62
## 32 IN yes 9
## 33 KS no 57
## 34 KS yes 13
## 35 KY no 51
## 36 KY yes 8
## 37 LA no 47
## 38 LA yes 4
## 39 MA no 54
## 40 MA yes 11
## 41 MD no 53
## 42 MD yes 17
## 43 ME no 49
## 44 ME yes 13
## 45 MI no 57
## 46 MI yes 16
## 47 MN no 69
## 48 MN yes 15
## 49 MO no 56
## 50 MO yes 7
## 51 MS no 51
## 52 MS yes 14
## 53 MT no 54
## 54 MT yes 14
## 55 NC no 57
## 56 NC yes 11
## 57 ND no 56
## 58 ND yes 6
## 59 NE no 56
## 60 NE yes 5
## 61 NH no 47
## 62 NH yes 9
## 63 NJ no 50
## 64 NJ yes 18
## 65 NM no 56
## 66 NM yes 6
## 67 NV no 52
## 68 NV yes 14
## 69 NY no 68
## 70 NY yes 15
## 71 OH no 68
## 72 OH yes 10
## 73 OK no 52
## 74 OK yes 9
## 75 OR no 67
## 76 OR yes 11
## 77 PA no 37
## 78 PA yes 8
## 79 RI no 59
## 80 RI yes 6
## 81 SC no 46
## 82 SC yes 14
## 83 SD no 52
## 84 SD yes 8
## 85 TN no 48
## 86 TN yes 5
## 87 TX no 54
## 88 TX yes 18
## 89 UT no 62
## 90 UT yes 10
## 91 VA no 72
## 92 VA yes 5
## 93 VT no 65
## 94 VT yes 8
## 95 WA no 52
## 96 WA yes 14
## 97 WI no 71
## 98 WI yes 7
## 99 WV no 96
## 100 WV yes 10
## 101 WY no 68
## 102 WY yes 9
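The per-state counts above are easier to compare as rates. The following is a sketch only: it uses a small hypothetical data frame named `toy` in place of churnTrain, and the column alias `ChurnRate` is made up for illustration. The SQL averages a 0/1 indicator within each state, and a base R cross-check computes the same quantity.

```r
# Toy stand-in for churnTrain (hypothetical data, only two columns)
toy <- data.frame(
  state = c("AK", "AK", "AK", "AL", "AL", "CA", "CA", "CA", "CA"),
  churn = c("no", "no", "yes", "no", "no", "yes", "yes", "no", "no"),
  stringsAsFactors = FALSE
)

# SQL: share of 'yes' rows per state, highest churn rate first
q_rate <- "SELECT state,
                  COUNT(*) AS NumberRows,
                  AVG(CASE WHEN churn = 'yes' THEN 1.0 ELSE 0.0 END) AS ChurnRate
           FROM toy
           GROUP BY state
           ORDER BY ChurnRate DESC;"

# Run the query only if sqldf is installed
if (requireNamespace("sqldf", quietly = TRUE)) {
  library(sqldf)
  print(sqldf(q_rate))
}

# Base R cross-check of the same quantity
rate_by_state <- tapply(toy$churn == "yes", toy$state, mean)
rate_by_state
```

Replacing `toy` with `churnTrain` in the query should give one row per state, which is more compact than the 102-row listing above.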
Model 1
Logistic regression
Logistic regression all variables
# Basic information on the amount of churn in the training and test sets
table(churnTrain$churn)
##
## yes no
## 483 2850
table(churnTest$churn)
##
## yes no
## 224 1443
log.fit =glm(churn ~., data =churnTrain, family =binomial)
summary(log.fit)
##
## Call:
## glm(formula = churn ~ ., family = binomial, data = churnTrain)
##
## Deviance Residuals:
## Min 1Q Median 3Q Max
## -3.043 0.166 0.312 0.499 1.949
##
## Coefficients:
## Estimate Std. Error z value Pr(>|z|)
## (Intercept) 9.69e+00 9.80e-01 9.89 < 2e-16 ***
## stateAL -3.39e-01 7.63e-01 -0.44 0.65727
## stateAR -9.11e-01 7.52e-01 -1.21 0.22588
## stateAZ -8.97e-02 8.45e-01 -0.11 0.91545
## stateCA -1.82e+00 7.82e-01 -2.32 0.02024 *
## stateCO -6.45e-01 7.63e-01 -0.84 0.39834
## stateCT -1.02e+00 7.25e-01 -1.41 0.15917
## stateDC -6.88e-01 8.08e-01 -0.85 0.39458
## stateDE -7.46e-01 7.49e-01 -1.00 0.31923
## stateFL -5.92e-01 7.61e-01 -0.78 0.43696
## stateGA -6.60e-01 7.78e-01 -0.85 0.39607
## stateHI 2.30e-01 8.96e-01 0.26 0.79747
## stateIA -2.08e-01 9.02e-01 -0.23 0.81741
## stateID -8.71e-01 7.47e-01 -1.16 0.24410
## stateIL 2.38e-01 8.34e-01 0.29 0.77517
## stateIN -4.41e-01 7.53e-01 -0.59 0.55792
## stateKS -1.06e+00 7.30e-01 -1.46 0.14566
## stateKY -7.89e-01 7.66e-01 -1.03 0.30293
## stateLA -5.55e-01 8.35e-01 -0.66 0.50672
## stateMA -1.16e+00 7.43e-01 -1.56 0.11826
## stateMD -1.14e+00 7.17e-01 -1.60 0.11043
## stateME -1.33e+00 7.28e-01 -1.82 0.06832 .
## stateMI -1.39e+00 7.14e-01 -1.95 0.05140 .
## stateMN -1.16e+00 7.15e-01 -1.62 0.10471
## stateMO -5.98e-01 7.74e-01 -0.77 0.43991
## stateMS -1.36e+00 7.28e-01 -1.86 0.06260 .
## stateMT -1.87e+00 7.17e-01 -2.60 0.00924 **
## stateNC -5.77e-01 7.55e-01 -0.76 0.44482
## stateND -1.27e-01 7.97e-01 -0.16 0.87300
## stateNE -2.95e-01 8.06e-01 -0.37 0.71398
## stateNH -1.16e+00 7.69e-01 -1.51 0.13137
## stateNJ -1.57e+00 7.10e-01 -2.22 0.02676 *
## stateNM -4.59e-01 7.87e-01 -0.58 0.55960
## stateNV -1.25e+00 7.25e-01 -1.73 0.08420 .
## stateNY -1.16e+00 7.19e-01 -1.61 0.10650
## stateOH -6.73e-01 7.46e-01 -0.90 0.36751
## stateOK -8.66e-01 7.56e-01 -1.15 0.25181
## stateOR -7.68e-01 7.35e-01 -1.04 0.29613
## statePA -1.14e+00 7.79e-01 -1.46 0.14312
## stateRI 1.10e-01 8.20e-01 0.13 0.89334
## stateSC -1.75e+00 7.37e-01 -2.37 0.01778 *
## stateSD -8.23e-01 7.61e-01 -1.08 0.27951
## stateTN -2.60e-01 8.21e-01 -0.32 0.75107
## stateTX -1.64e+00 7.08e-01 -2.31 0.02075 *
## stateUT -1.05e+00 7.44e-01 -1.41 0.15906
## stateVA 4.42e-01 8.22e-01 0.54 0.59034
## stateVT -8.39e-02 7.80e-01 -0.11 0.91433
## stateWA -1.40e+00 7.24e-01 -1.93 0.05308 .
## stateWI -2.84e-01 7.80e-01 -0.36 0.71611
## stateWV -5.73e-01 7.33e-01 -0.78 0.43414
## stateWY -2.95e-01 7.54e-01 -0.39 0.69545
## account_length -9.65e-04 1.43e-03 -0.67 0.50121
## area_codearea_code_415 7.88e-02 1.42e-01 0.56 0.57857
## area_codearea_code_510 1.02e-01 1.63e-01 0.62 0.53362
## international_planyes -2.19e+00 1.53e-01 -14.29 < 2e-16 ***
## voice_mail_planyes 2.13e+00 5.94e-01 3.59 0.00034 ***
## number_vmail_messages -3.83e-02 1.86e-02 -2.06 0.03987 *
## total_day_minutes 3.82e-01 3.38e+00 0.11 0.90994
## total_day_calls -4.04e-03 2.86e-03 -1.41 0.15748
## total_day_charge -2.33e+00 1.99e+01 -0.12 0.90687
## total_eve_minutes -8.93e-01 1.70e+00 -0.53 0.59951
## total_eve_calls -1.02e-03 2.89e-03 -0.35 0.72464
## total_eve_charge 1.04e+01 2.00e+01 0.52 0.60270
## total_night_minutes 2.23e-01 9.04e-01 0.25 0.80540
## total_night_calls -1.81e-04 2.93e-03 -0.06 0.95072
## total_night_charge -5.04e+00 2.01e+01 -0.25 0.80204
## total_intl_minutes 4.15e+00 5.49e+00 0.76 0.45019
## total_intl_calls 9.06e-02 2.58e-02 3.52 0.00044 ***
## total_intl_charge -1.57e+01 2.03e+01 -0.77 0.44112
## number_customer_service_calls -5.37e-01 4.10e-02 -13.09 < 2e-16 ***
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## (Dispersion parameter for binomial family taken to be 1)
##
## Null deviance: 2758.3 on 3332 degrees of freedom
## Residual deviance: 2070.8 on 3263 degrees of freedom
## AIC: 2211
##
## Number of Fisher Scoring iterations: 6
#I looked ahead to the random forests model and chose the most important variables to create
#a smaller logistic regression model. I used the mean decrease in the Gini Index
log.fit2 =glm(churn ~total_day_charge +total_day_minutes +number_customer_service_calls +total_eve_charge,
data =churnTrain, family =binomial)
summary(log.fit2)
##
## Call:
## glm(formula = churn ~ total_day_charge + total_day_minutes +
## number_customer_service_calls + total_eve_charge, family = binomial,
## data = churnTrain)
##
## Deviance Residuals:
## Min 1Q Median 3Q Max
## -2.997 0.292 0.423 0.562 1.804
##
## Coefficients:
## Estimate Std. Error z value Pr(>|z|)
## (Intercept) 6.1450 0.3325 18.48 < 2e-16 ***
## total_day_charge -0.5841 18.1434 -0.03 0.97
## total_day_minutes 0.0871 3.0844 0.03 0.98
## number_customer_service_calls -0.4391 0.0364 -12.05 < 2e-16 ***
## total_eve_charge -0.0709 0.0125 -5.67 1.4e-08 ***
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## (Dispersion parameter for binomial family taken to be 1)
##
## Null deviance: 2758.3 on 3332 degrees of freedom
## Residual deviance: 2437.8 on 3328 degrees of freedom
## AIC: 2448
##
## Number of Fisher Scoring iterations: 5
# Take a look at the predictions and compare
log.probs =predict(log.fit, type ="response")
log.probs[1:10]
## 1 2 3 4 5 6 7 8 9 10
## 0.8942 0.9697 0.8963 0.5487 0.4452 0.7640 0.7535 0.9298 0.8385 0.5713
log.probs2 =predict(log.fit2, type ="response")
log.probs2[1:10]
## 1 2 3 4 5 6 7 8 9 10
## 0.7820 0.9280 0.9199 0.7750 0.8696 0.8896 0.5166 0.9736 0.7911 0.8390
# Determine the coding
contrasts(churnTrain$churn)
## no
## yes 0
## no 1
log.pred =ifelse(log.probs >=0.5, "no", "yes")
log.pred2 =ifelse(log.probs2 >=0.5, "no", "yes")
table(log.pred, log.pred2)
## log.pred2
## log.pred no yes
## no 3097 20
## yes 179 37
agrees =(3097+37)/(3097+37+179+20)
# Determine how the predictions did
table(log.pred, churnTrain$churn)
##
## log.pred yes no
## no 357 2760
## yes 126 90
table(log.pred2, churnTrain$churn)
##
## log.pred2 yes no
## no 457 2819
## yes 26 31
# Note that the model does an okay job of predicting "no" when the true value is "no"
# but does not do a good job when the true value is "yes".
# And this was all on the training data!
# Now use the model to predict on the test set
log.fit =glm(churn ~., data =churnTrain, family =binomial)
log.probs =predict(log.fit, churnTest, type ="response")
log.pred =ifelse(log.probs >=0.5, "no", "yes")
table(log.pred, churnTest$churn)
##
## log.pred yes no
## no 170 1395
## yes 54 48
# The full logistic regression identified 54 out of 224 customer attritions.
# The full logistic regression also identified 1395 out of 1443 retained customers.
log.fit2 =glm(churn ~total_day_charge +total_day_minutes +number_customer_service_calls +total_eve_charge,
data =churnTrain, family =binomial)
log.probs2 =predict(log.fit2, churnTest, type ="response")
log.pred2 =ifelse(log.probs2 >=0.5, "no", "yes")
table(log.pred2, churnTest$churn)
##
## log.pred2 yes no
## no 208 1430
## yes 16 13
# The reduced logistic regression identified only 16 out of 224 customer attritions, a very poor result.
# The reduced logistic regression also identified 1430 out of 1443 retained customers.
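The `contrasts()` check in this section matters because `glm()` with a factor response treats the first level as failure and models the probability of the remaining level. With levels ordered "yes", "no", the fitted values are therefore P(churn = "no"), which is why the threshold maps values at or above 0.5 to "no". A minimal sketch on hypothetical toy data (the data frame `toy` and the predictor `x` are invented for illustration):

```r
# Toy data (hypothetical): larger x makes "no" more likely
x <- 1:20
y <- ifelse(x > 10, "no", "yes")
y[c(9, 12)] <- c("no", "yes")   # two flips so the classes overlap
toy <- data.frame(x = x,
                  churn = factor(y, levels = c("yes", "no")))  # same level order as churnTrain

fit <- glm(churn ~ x, data = toy, family = binomial)

# glm treats the first level ("yes") as failure, so fitted values are P(churn == "no")
p_no <- predict(fit, type = "response")

# The threshold is therefore applied to P("no"), as in the ifelse() calls above
pred <- ifelse(p_no >= 0.5, "no", "yes")
table(pred, toy$churn)
```

If the factor levels were ordered "no", "yes" instead, the same code would return P(churn = "yes") and the labels in `ifelse()` would have to be swapped.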
Model 2
Trees
This section explores the basic tree method.
This method is quite intuitive but often lacks predictive accuracy.
The more accurate tree methods of bagging, random forests, and boosting follow.
library(tree)
## Warning: package 'tree' was built under R version 3.0.3
tree.churn =tree(churn ~. -state, churnTrain) # factors can have at most 32 levels
summary(tree.churn)
##
## Classification tree:
## tree(formula = churn ~ . - state, data = churnTrain)
## Variables actually used in tree construction:
## [1] "total_day_minutes" "number_customer_service_calls"
## [3] "international_plan" "total_eve_minutes"
## [5] "voice_mail_plan" "total_intl_calls"
## [7] "total_intl_minutes"
## Number of terminal nodes: 12
## Residual mean deviance: 0.377 = 1250 / 3320
## Misclassification error rate: 0.0591 = 197 / 3333
plot(tree.churn)
text(tree.churn, pretty =0)
# Now evaluate the fitted tree on the test set
tree.pred =predict(tree.churn, newdata =churnTest, type ="class")
table(tree.pred, churnTest$churn)
##
## tree.pred yes no
## yes 139 10
## no 85 1433
# The basic tree identified 139 out of 224 customer attritions.
# The basic tree also identified 1433 out of 1443 retained customers.
# Now try pruning
set.seed(2015)
cv.churn =cv.tree(tree.churn, FUN =prune.misclass) # named cv.churn to avoid masking the cv.tree() function
names(cv.churn)
## [1] "size" "dev" "k" "method"
cv.churn
## $size
## [1] 12 11 8 5 3 2 1
##
## $dev
## [1] 216 233 242 349 467 483 483
##
## $k
## [1] -Inf 7.000 8.333 31.333 38.000 41.000 43.000
##
## $method
## [1] "misclass"
##
## attr(,"class")
## [1] "prune" "tree.sequence"
prune.churn =prune.misclass(tree.churn, best =12)
plot(prune.churn)
text(prune.churn, pretty =0)
tree.pred =predict(prune.churn, churnTest, type ="class")
table(tree.pred, churnTest$churn )
##
## tree.pred yes no
## yes 139 10
## no 85 1433
# The "pruned" tree identified 139 out of 224 customer attritions.
# The "pruned" tree also identified 1433 out of 1443 retained customers.
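The choice of `best =12` follows from the cross-validation output: the tree size with the smallest cross-validated error count is the full 12-leaf tree, so pruning leaves the tree unchanged. A short sketch of that selection step, with the `size` and `dev` values copied from the output above (the variable names are illustrative):

```r
# Values copied from the cv.tree() output above
cv_size <- c(12, 11, 8, 5, 3, 2, 1)
cv_dev  <- c(216, 233, 242, 349, 467, 483, 483)   # CV misclassification counts

# Tree size with the fewest cross-validated errors
best_size <- cv_size[which.min(cv_dev)]
best_size

# Estimated CV error rate at that size (3333 training rows)
cv_dev[which.min(cv_dev)] / 3333
```

Because the selected size equals the size of the unpruned tree, the "pruned" confusion matrix is identical to the one for the original tree.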
Model 2B
Trees
caret
This section uses the caret wrapper together with the rpart library.
library(caret)
## Warning: package 'caret' was built under R version 3.0.3
## Loading required package: lattice
## Warning: package 'lattice' was built under R version 3.0.3
## Loading required package: ggplot2
## Warning: package 'ggplot2' was built under R version 3.0.3
# combine the data
set.seed(2014)
allData =rbind(churnTrain, churnTest)
# split the data into a training set and a test set
inTrain =createDataPartition(y=allData$churn,
p =3333/5000, list=FALSE)
training =allData[inTrain,]
testing =allData[-inTrain,]
table(training$churn)
##
## yes no
## 472 2862
table(testing$churn)
##
## yes no
## 235 1431
dim(training); dim(testing)
## [1] 1666 20
# the rpart method
modFit =train(churn ~., method="rpart", data =training)
## Loading required package: rpart
## Warning: package 'rpart' was built under R version 3.0.3
## Warning: package 'e1071' was built under R version 3.0.3
print(modFit$finalModel)
## n= 3334
##
## node), split, n, loss, yval, (yprob)
## * denotes terminal node
##
## 1) root 3334 472 no (0.1416 0.8584)
## 2) total_day_minutes>=265.4 198 72 yes (0.6364 0.3636)
## 4) voice_mail_planyes< 0.5 151 30 yes (0.8013 0.1987) *
## 5) voice_mail_planyes>=0.5 47 5 no (0.1064 0.8936) *
## 3) total_day_minutes< 265.4 3136 346 no (0.1103 0.8897) *
# Plot the dendrogram
plot(modFit$finalModel, uniform=TRUE,
main="Classification Tree")
text(modFit$finalModel, use.n=TRUE, all=TRUE, cex=.8)
#nicer plot using the rattle library
library(rattle)
## Warning: package 'rattle' was built under R version 3.0.3
## Rattle: A free graphical interface for data mining with R.
## Version 3.0.2 r169 Copyright (c) 2006-2013 Togaware Pty Ltd.
## Type 'rattle()' to shake, rattle, and roll your data.
library(rpart.plot)
## Warning: package 'rpart.plot' was built under R version 3.0.3
fancyRpartPlot(modFit$finalModel)
## Loading required package: RColorBrewer
#Now predict
tree.pred =predict(modFit,newdata =testing)
table(tree.pred, testing$churn)
##
## tree.pred yes no
## yes 56 27
## no 179 1404
#The basic rpart tree identified 56 out of 235 customer attritions.
#The basic rpart tree also identified 1404 out of 1431 retained customers.
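The split in this section is stratified: `createDataPartition()` samples within each class so that the training and test sets keep roughly the same churn rate. The base R sketch below mimics that idea under stated assumptions: the outcome vector is a stand-in with the same class counts as the combined data (707 "yes", 4293 "no"), and rounding up within each class is a choice made here, so caret's own implementation may differ in detail.

```r
set.seed(2014)

# Stand-in outcome with the same class counts as allData (707 "yes", 4293 "no")
churn <- factor(rep(c("yes", "no"), times = c(707, 4293)), levels = c("yes", "no"))
p <- 3333 / 5000

# Sample a fraction p of the row indices within each class separately
idx_by_class <- split(seq_along(churn), churn)
in_train <- unlist(lapply(idx_by_class, function(idx) {
  sample(idx, size = ceiling(p * length(idx)))
}))

table(churn[in_train])    # training class counts
table(churn[-in_train])   # testing class counts
```

With these class counts the sketch selects 472 "yes" and 2862 "no" rows for training, the same class sizes as the `table(training$churn)` output above, although the specific rows chosen will differ.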
Model 3
Bagging
library(randomForest)
## randomForest 4.6-7
## Type rfNews() to see new features/changes/bug fixes.
set.seed(2014)
bag.churn =randomForest(churn ~. -state, data =churnTrain,
mtry =18, importance =TRUE)
yhat.bag =predict(bag.churn, newdata =churnTest)
importance(bag.churn)
## yes no MeanDecreaseAccuracy
## account_length -1.81241 -1.3323 -1.8601
## area_code -0.08345 -0.3447 -0.3792
## international_plan 136.37351 101.6036 157.3882
## voice_mail_plan 0.00000 0.0000 0.0000
## number_vmail_messages 65.72534 51.4718 78.2445
## total_day_minutes 31.39185 31.3221 40.0184
## total_day_calls -2.15176 0.1585 -0.4989
## total_day_charge 29.84802 31.2109 39.5807
## total_eve_minutes 31.66539 18.7697 23.7269
## total_eve_calls -3.11234 -4.1131 -4.7318
## total_eve_charge 29.44055 18.6515 23.4258
## total_night_minutes 9.73547 14.8594 15.9018
## total_night_calls -1.91594 -0.6723 -1.1883
## total_night_charge 5.81169 14.5901 15.3802
## total_intl_minutes 20.76618 19.7400 23.1596
## total_intl_calls 82.87735 57.1482 93.2050
## total_intl_charge 20.47850 18.8906 22.4300
## number_customer_service_calls 138.34404 91.7643 143.9246
## MeanDecreaseGini
## account_length 20.660
## area_code 3.629
## international_plan 57.976
## voice_mail_plan 0.000
## number_vmail_messages 54.858
## total_day_minutes 112.046
## total_day_calls 19.967
## total_day_charge 111.173
## total_eve_minutes 63.105
## total_eve_calls 18.989
## total_eve_charge 59.150
## total_night_minutes 25.960
## total_night_calls 19.452
## total_night_charge 24.247
## total_intl_minutes 40.160
## total_intl_calls 66.541
## total_intl_charge 39.812
## number_customer_service_calls 89.933
varImpPlot(bag.churn)
table(yhat.bag, churnTest$churn)
##
## yhat.bag yes no
## yes 161 8
## no 63 1435
#The bagging ensemble of trees identified 161 out of 224 customer attritions.
#The bagging ensemble of trees also identified 1435 out of 1443 retained customers.
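The call above counts as bagging because `mtry =18` equals the number of available predictors, so every predictor is a candidate at every split. A small sketch of that arithmetic (the variable names are illustrative; the default shown is the one documented by randomForest for classification):

```r
# names(churnTrain) lists 20 columns in total
n_columns <- 20

# Drop the response (churn) and the excluded predictor (state)
n_predictors <- n_columns - 2

# Bagging: every predictor is a candidate at every split
mtry_bagging <- n_predictors

# randomForest's documented default for classification: floor(sqrt(p))
mtry_default <- floor(sqrt(n_predictors))

c(predictors = n_predictors, bagging = mtry_bagging, rf_default = mtry_default)
```

Any `mtry` below 18 turns the same call into a random forest, since each split then considers only a random subset of the predictors.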
Model 4
Random Forests
library(randomForest)
set.seed(2014)
bag.churn =randomForest(churn ~. -state, data =churnTrain,
mtry =6, importance =TRUE) # instead of using all predictors, try just 6 at each split
yhat.bag =predict(bag.churn, newdata =churnTest)
names(churnTest)
## [1] "state" "account_length"
## [3] "area_code" "international_plan"
## [5] "voice_mail_plan" "number_vmail_messages"
## [7] "total_day_minutes" "total_day_calls"
## [9] "total_day_charge" "total_eve_minutes"
## [11] "total_eve_calls" "total_eve_charge"
## [13] "total_night_minutes" "total_night_calls"
## [15] "total_night_charge" "total_intl_minutes"
## [17] "total_intl_calls" "total_intl_charge"
## [19] "number_customer_service_calls" "churn"
importance(bag.churn)
## yes no MeanDecreaseAccuracy
## account_length -2.1841 -0.6488 -1.5123
## area_code 0.2848 0.5177 0.5745
## international_plan 102.3016 81.4677 109.8346
## voice_mail_plan 19.5078 18.6062 20.5956
## number_vmail_messages 28.4460 22.4501 28.1643
## total_day_minutes 33.7893 31.7377 41.9491
## total_day_calls -1.6954 0.3175 -0.3122
## total_day_charge 32.2732 30.9521 39.7184
## total_eve_minutes 27.5756 21.1907 25.6121
## total_eve_calls -3.5126 -1.4000 -2.6462
## total_eve_charge 26.8221 20.4884 24.8205
## total_night_minutes 5.7027 15.9625 17.1729
## total_night_calls -0.6658 2.5282 2.0304
## total_night_charge 4.9138 15.8682 16.8575
## total_intl_minutes 18.7201 20.8445 24.4427
## total_intl_calls 54.9384 44.1862 62.7611
## total_intl_charge 19.9616 22.5940 27.3727
## number_customer_service_calls 114.8657 83.7426 122.3994
## MeanDecreaseGini
## account_length 22.012
## area_code 4.104
## international_plan 68.599
## voice_mail_plan 17.000
## number_vmail_messages 31.596
## total_day_minutes 112.586
## total_day_calls 22.156
## total_day_charge 115.699
## total_eve_minutes 57.297
## total_eve_calls 19.752
## total_eve_charge 58.220
## total_night_minutes 28.285
## total_night_calls 20.860
## total_night_charge 27.988
## total_intl_minutes 33.696
## total_intl_calls 51.167
## total_intl_charge 34.816
## number_customer_service_calls 98.533
varImpPlot(bag.churn)
table(yhat.bag, churnTest$churn)
##
## yhat.bag yes no
## yes 165 5
## no 59 1438
#The random forests ensemble of trees identified 165 out of 224 customer attritions.
#The random forests ensemble of trees also identified 1438 out of 1443 retained customers.
#So random forests improved slightly over bagging
Model 5
Boosting
I commented out the boosting code because it takes so long to run.
See the summary for the boosting results.
# library(caret)
# library(gbm)
# library(pROC)
#combine the data
# set.seed(2014)
# allData = rbind(churnTrain, churnTest)
#
# #split the data into a training set and a test set
# inTrain = createDataPartition(y=allData$churn,
# p = 3333/5000, list=FALSE)
# training = allData[inTrain,]
# testing = allData[-inTrain,]
# table(training$churn)
# table(testing$churn)
#
# dim(training); dim(testing)
#
# forGBM = churnTrain
# forGBM$churn = ifelse(forGBM$churn == "yes", 1, 0)
#
# #I guessed on the parameters: n.trees, interaction depth, shrinkage
# gbmFit = gbm(formula = churn ~ ., # Use all predictors
# distribution = "bernoulli", # For classification
# data = forGBM,
# n.trees = 2000, # 2000 boosting iterations
# interaction.depth = 4, # How many splits in each tree
# shrinkage = 0.01, # learning rate
# verbose = FALSE)
#
# #Now I will use more of the caret functionality to tune the parameters via cross-validation
# ctrl = trainControl(method = "repeatedcv", repeats = 5, classProbs = TRUE,
# summaryFunction = twoClassSummary) # 10-fold cross-validation; 5 repeats
#
# grid = expand.grid(interaction.depth = seq(2, 6, by = 2),
# n.trees = seq(100, 500, by = 50), #1000
# shrinkage = c(0.01, 0.1))
#
# set.seed(2014)
# The following step is very time consuming
# gbmTune = train(churn ~ ., data = churnTrain,
# method = "gbm",
# metric = "ROC",
# tuneGrid = grid,
# verbose = FALSE, # Avoid massive output
# trControl = ctrl)
#
# library(ggplot2)
# ggplot(gbmTune) + theme(legend.position = "top")
#
# gbmPred <- predict(gbmTune, churnTest)
# str(gbmPred)
#
# gbmProbs <- predict(gbmTune, churnTest, type = "prob")
# str(gbmProbs)
#
# confusionMatrix(gbmPred, churnTest$churn)
# # The boosting method identified 158 out of 224 customer attritions.
# # The boosting method also identified 1433 out of 1443 retained customers.
# # So boosting did slightly worse than both bagging and random forests
#
#
# rocCurve <- roc(response = churnTest$churn,
# predictor = gbmProbs[, "yes"],
# levels = rev(levels(churnTest$churn)))
#
# rocCurve
# # plot(rocCurve)
#
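Part of the reason the tuning step is slow is the size of the search. The sketch below rebuilds the same grid as the commented-out code and counts the candidate settings and resamples. The product is only an upper bound: to my understanding caret can reuse one gbm fit for several `n.trees` values, so the actual number of fits may be smaller.

```r
# Same tuning grid as in the commented-out code above
grid <- expand.grid(interaction.depth = seq(2, 6, by = 2),
                    n.trees = seq(100, 500, by = 50),
                    shrinkage = c(0.01, 0.1))

n_candidates <- nrow(grid)   # parameter combinations to evaluate
n_resamples  <- 10 * 5       # 10-fold cross-validation repeated 5 times

# Upper bound on resampled fits if every combination were fit separately
n_candidates * n_resamples
```

Reducing `repeats` or coarsening the grid shrinks this count proportionally, at the cost of a noisier or less thorough search.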
Model 6
SVM
library(e1071)
dat =churnTrain
out =svm(churn ~., data =dat, kernel ="linear", cost =10)
# Now on the test set
dat.test =churnTest
pred.test =predict(out, newdata =dat.test)
table(pred.test, dat.test$churn)
##
## pred.test yes no
## yes 0 0
## no 224 1443
#The SVM identified 0 out of 224 customer attritions.
#The SVM also identified 1443 out of 1443 retained customers.
#SVM using a linear kernel predicted everything as a "no"!
#tune.out = tune(svm, churn ~ ., data = dat, kernel = "linear",
#ranges = list( cost=c(0.001, 0.01, 0.1, 1, 5 , 10 , 100)))
#summary(tune.out)
# examine the best model
#bestmod = tune.out$best.model
#summary(bestmod)
# Test the model using the best tuning parameter
#ypred = predict(bestmod, dat.test)
#table(predict = ypred, truth = dat.test$churn)
# Now try a radial kernel
dat =churnTrain
svmfit =svm(churn ~., data =dat, kernel ="radial", gamma =1, cost =1)
summary(svmfit)
##
## Call:
## svm(formula = churn ~ ., data = dat, kernel = "radial", gamma = 1,
## cost = 1)
##
##
## Parameters:
## SVM-Type: C-classification
## SVM-Kernel: radial
## cost: 1
## gamma: 1
##
## Number of Support Vectors: 3333
##
## ( 2850 483 )
##
##
## Number of Classes: 2
##
## Levels:
## yes no
# Now tune
set.seed(2014)
tune.out =tune(svm, churn ~., data =dat, kernel ="radial",
ranges =list( cost =c(0.1, 1, 10, 100, 1000),
gamma =c(0.5, 1, 2, 3, 4)))
summary(tune.out)
##
## Parameter tuning of 'svm':
##
## - sampling method: 10-fold cross validation
##
## - best parameters:
## cost gamma
## 10 0.5
##
## - best performance: 0.1446
##
## - Detailed performance results:
## cost gamma error dispersion
## 1 1e-01 0.5 0.1449 0.01773
## 2 1e+00 0.5 0.1449 0.01773
## 3 1e+01 0.5 0.1446 0.01761
## 4 1e+02 0.5 0.1446 0.01761
## 5 1e+03 0.5 0.1446 0.01761
## 6 1e-01 1.0 0.1449 0.01773
## 7 1e+00 1.0 0.1449 0.01773
## 8 1e+01 1.0 0.1449 0.01773
## 9 1e+02 1.0 0.1449 0.01773
## 10 1e+03 1.0 0.1449 0.01773
## 11 1e-01 2.0 0.1449 0.01773
## 12 1e+00 2.0 0.1449 0.01773
## 13 1e+01 2.0 0.1449 0.01773
## 14 1e+02 2.0 0.1449 0.01773
## 15 1e+03 2.0 0.1449 0.01773
## 16 1e-01 3.0 0.1449 0.01773
## 17 1e+00 3.0 0.1449 0.01773
## 18 1e+01 3.0 0.1449 0.01773
## 19 1e+02 3.0 0.1449 0.01773
## 20 1e+03 3.0 0.1449 0.01773
## 21 1e-01 4.0 0.1449 0.01773
## 22 1e+00 4.0 0.1449 0.01773
## 23 1e+01 4.0 0.1449 0.01773
## 24 1e+02 4.0 0.1449 0.01773
## 25 1e+03 4.0 0.1449 0.01773
# - best parameters:
# cost gamma
# 10 0.5
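dat =churnTrain
# Note: this model is fit to churnTest itself, so the perfect table below reflects
# in-sample fit on the test rows, not out-of-sample performance
svmfit =svm(churn ~., data =churnTest, kernel ="radial", gamma =0.5, cost =10)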
summary(svmfit)
##
## Call:
## svm(formula = churn ~ ., data = churnTest, kernel = "radial",
## gamma = 0.5, cost = 10)
##
##
## Parameters:
## SVM-Type: C-classification
## SVM-Kernel: radial
## cost: 10
## gamma: 0.5
##
## Number of Support Vectors: 1664
##
## ( 1440 224 )
##
##
## Number of Classes: 2
##
## Levels:
## yes no
table(svmfit$fitted, churnTest$churn)
##
## yes no
## yes 224 0
## no 0 1443
length(churnTest$churn)
## [1] 1667
table(true =churnTest$churn, pred =predict(tune.out$best.model, newdata =churnTest))
## pred
## true yes no
## yes 4 220
## no 0 1443
#The tuned radial SVM (trained on churnTrain) identified 4 out of 224 customer attritions.
#It also identified 1443 out of 1443 retained customers.
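Two numbers in the tuning output are worth relating to the data. The cross-validated error of about 0.1449 equals the training churn rate 483/3333, which is what a model predicting "no" for essentially every customer would score. In addition, every gamma in the grid is far above e1071's documented default of 1/(data dimension). The sketch below derives that default under the assumption that `svm()` builds the same dummy-coded design matrix as `glm()`. The grids at the end are hypothetical starting points that the author did not run, and nothing here shows that they improve the results.

```r
# Number of model-matrix columns, inferred from the logistic regression output:
# 3332 null df - 3263 residual df = 69 non-intercept coefficients
n_features <- 3332 - 3263

# e1071::svm documents its default gamma as 1 / (data dimension)
default_gamma <- 1 / n_features

# Hypothetical log-spaced grid centred near that default (not run by the author)
gamma_grid <- default_gamma * 10^(-1:1)
cost_grid  <- c(0.1, 1, 10, 100)

signif(gamma_grid, 3)

# Training churn rate, for comparison with the tuning error above
round(483 / 3333, 4)
```

These grids could be passed to `tune()` through `ranges = list(cost = cost_grid, gamma = gamma_grid)` in the same way as the grid used above.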
Summary
Results For test data
Technique / Yes Correct / Yes Wrong / No Correct / No Wrong
The test data / 224 / 0 / 1443 / 0
Test data (caret) / 235 / 0 / 1431 / 0
Logistic Regression / 54 / 170 / 1395 / 48
Logist. Regr (4 param) / 16 / 208 / 1430 / 13
Basic Tree (prune) / 139 / 85 / 1433 / 10
Basic Tree (caret) / 56 / 179 / 1404 / 27
Bagging / 161 / 63 / 1435 / 8
Random Forest / 165 / 59 / 1438 / 5
Boosting / 158 / 66 / 1433 / 10
SVM (linear) / 0 / 224 / 1443 / 0
SVM (radial) / 4 / 220 / 1443 / 0
So Random Forest edged out bagging, which edged out boosting. Boosting ran very slowly. When attrition was yes, RF got 165 correct and just 59 wrong. When customers were not lost (attrition was no), RF got 1438 correct and just 5 wrong. The basic tree, even when pruned, could not compete with the aforementioned ensemble methods.
For reference, the test data used when caret was not involved had 224 yes and 1443 no. The caret test data had a very similar composition.
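The counts in the summary table can be turned into rates for easier comparison. The sketch below covers only the models evaluated on the original 1667-row test set. The counts are copied from the table, while the data frame and the column names are illustrative.

```r
# Test-set counts copied from the summary table (original 1667-row test set only)
results <- data.frame(
  technique   = c("Logistic regression", "Logistic regression (4 predictors)",
                  "Basic tree (pruned)", "Bagging", "Random forest",
                  "Boosting", "SVM (linear)", "SVM (radial)"),
  yes_correct = c(54, 16, 139, 161, 165, 158, 0, 4),
  yes_wrong   = c(170, 208, 85, 63, 59, 66, 224, 220),
  no_correct  = c(1395, 1430, 1433, 1435, 1438, 1433, 1443, 1443),
  no_wrong    = c(48, 13, 10, 8, 5, 10, 0, 0),
  stringsAsFactors = FALSE
)

# Sensitivity: share of actual churners identified
# Specificity: share of retained customers identified
results$sensitivity <- with(results, yes_correct / (yes_correct + yes_wrong))
results$specificity <- with(results, no_correct / (no_correct + no_wrong))
results$accuracy    <- with(results, (yes_correct + no_correct) / 1667)

# Sort by sensitivity, the metric on which the models differ most
results[order(-results$sensitivity),
        c("technique", "sensitivity", "specificity", "accuracy")]
```

Specificity is above 0.96 for every model, so the ranking is driven almost entirely by how many churners each model finds.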
References
- Kuhn and Johnson: Applied Predictive Modeling
- James, Witten, Hastie, Tibshirani: An Introduction to Statistical Learning with Applications in R
- Hastie, Tibshirani, Friedman: The Elements of Statistical Learning: Data Mining, Inference, and Prediction