class: center, middle, inverse, title-slide

# Prediction and overfitting
## Introduction to Data Science
### introds.org
### Dr. Mine Çetinkaya-Rundel

---

layout: true

<div class="my-footer">
<span>
<a href="https://introds.org" target="_blank">introds.org</a>
</span>
</div>

---

class: middle

# Prediction

---

## Goal: Building a spam filter

- Data: A set of emails where, for each email, we know whether it is spam along with other features
- Use logistic regression to predict the probability that an incoming email is spam
- Use model selection to pick the model with the best predictive performance

--

- Building a model to predict the probability that an email is spam is only half of the battle! We also need a decision rule about which emails get flagged as spam (e.g., what probability should we use as our cutoff?)

--

- A simple approach: choose a single threshold probability and flag any email whose predicted probability exceeds it as spam

---

## A multiple logistic regression approach

.panelset[
.panel[.panel-name[Output]
.small[
```
## # A tibble: 22 x 5
##    term         estimate std.error statistic  p.value
##    <chr>           <dbl>     <dbl>     <dbl>    <dbl>
##  1 (Intercept)  -9.09e+1   9.80e+3  -0.00928 9.93e- 1
##  2 to_multiple1 -2.68e+0   3.27e-1  -8.21    2.25e-16
##  3 from1        -2.19e+1   9.80e+3  -0.00224 9.98e- 1
##  4 cc            1.88e-2   2.20e-2   0.855   3.93e- 1
##  5 sent_email1  -2.07e+1   3.87e+2  -0.0536  9.57e- 1
##  6 time          8.48e-8   2.85e-8   2.98    2.92e- 3
##  7 image        -1.78e+0   5.95e-1  -3.00    2.73e- 3
##  8 attach        7.35e-1   1.44e-1   5.09    3.61e- 7
##  9 dollar       -6.85e-2   2.64e-2  -2.59    9.64e- 3
## 10 winneryes     2.07e+0   3.65e-1   5.67    1.41e- 8
## 11 inherit       3.15e-1   1.56e-1   2.02    4.32e- 2
## 12 viagra        2.84e+0   2.22e+3   0.00128 9.99e- 1
## 13 password     -8.54e-1   2.97e-1  -2.88    4.03e- 3
## 14 num_char      5.06e-2   2.38e-2   2.13    3.35e- 2
## 15 line_breaks  -5.49e-3   1.35e-3  -4.06    4.91e- 5
## 16 format1      -6.14e-1   1.49e-1  -4.14    3.53e- 5
## 17 re_subj1     -1.64e+0   3.86e-1  -4.25    2.16e- 5
## 18 exclaim_subj  1.42e-1   2.43e-1   0.585   5.58e- 1
## 19 urgent_subj1  3.88e+0   1.32e+0   2.95    3.18e- 3
## 20 exclaim_mess  1.08e-2   1.81e-3   5.98    2.23e- 9
## 21 numbersmall  -1.19e+0   1.54e-1  -7.74    9.62e-15
## 22 numberbig    -2.95e-1   2.20e-1  -1.34    1.79e- 1
```
]
]
.panel[.panel-name[Code]
```r
library(tidymodels) # logistic_reg(), set_engine(), fit(), tidy()
library(openintro)  # email data

# The glm engine already uses family = binomial for logistic_reg()
logistic_reg() %>%
  set_engine("glm") %>%
  fit(spam ~ ., data = email) %>%
  tidy() %>%
  print(n = 22)
```

```
## Warning: glm.fit: fitted probabilities numerically 0 or 1 occurred
```
]
]

---

## Prediction

- The mechanics of prediction are **easy**:
  - Plug the values of the predictors into the model equation
  - Calculate the predicted value of the response variable, `\(\hat{y}\)`

--

- Getting it right is **hard**!
  - There is no guarantee the model estimates you have are correct
  - Or that your model will perform as well with new data as it did with your sample data

---

## Underfitting and overfitting

<img src="w9-d03-prediction-overfitting_files/figure-html/unnamed-chunk-3-1.png" width="70%" style="display: block; margin: auto;" />

---

## Spending our data

- Creating a useful model involves several steps: parameter estimation, model selection, performance assessment, etc.
- Doing all of this on the entire data we have available can lead to **overfitting**
- Instead, allocate specific subsets of the data to different tasks, rather than using as much of it as possible for model parameter estimation alone (which is what we've done so far)

---

class: middle

# Splitting data

---

## Splitting data

- **Training set:**
  - Sandbox for model building
  - Spend most of your time using the training set to develop the model
  - Majority of the data (usually 80%)
- **Testing set:**
  - Held in reserve to determine the efficacy of one or two chosen models
  - Critical to look at it only once, otherwise it becomes part of the modeling process
  - Remainder of the data (usually 20%)

---

## Performing the split

```r
# Fix random numbers by setting the seed
# Enables analysis to be reproducible when random numbers are used
set.seed(1116)

# Put 80% of the data into the training set
email_split <- initial_split(email, prop = 0.80)

# Create data frames for the two sets:
train_data <- training(email_split)
test_data  <- testing(email_split)
```

---

## Peek at the split

.small[
.pull-left[
```r
glimpse(train_data)
```

```
## Rows: 3,137
## Columns: 21
## $ spam         <fct> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ to_multiple  <fct> 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0,…
## $ from         <fct> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,…
## $ cc           <int> 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 2, 1, 0,…
## $ sent_email   <fct> 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0,…
## $ time         <dttm> 2012-01-01 06:16:41, 2012-01-01 07:03:59…
## $ image        <dbl> 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0,…
## $ attach       <dbl> 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0,…
## $ dollar       <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 5,…
## $ winner       <fct> no, no, no, no, no, no, no, no, no, no, n…
## $ inherit      <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ viagra       <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ password     <dbl> 0, 0, 0, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 1,…
## $ num_char     <dbl> 11.370, 10.504, 13.256, 1.231, 1.091, 4.8…
## $ line_breaks  <int> 202, 202, 255, 29, 25, 193, 237, 69, 79, …
## $ format       <fct> 1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1,…
## $ re_subj      <fct> 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0,…
## $ exclaim_subj <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,…
## $ urgent_subj  <fct> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ exclaim_mess <dbl> 0, 1, 48, 1, 1, 1, 18, 1, 1, 0, 10, 4, 10…
## $ number       <fct> big, small, small, none, none, big, small…
```
]
.pull-right[
```r
glimpse(test_data)
```

```
## Rows: 784
## Columns: 21
## $ spam         <fct> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ to_multiple  <fct> 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0,…
## $ from         <fct> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,…
## $ cc           <int> 0, 0, 0, 2, 0, 0, 0, 0, 2, 0, 0, 7, 0, 0,…
## $ sent_email   <fct> 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0,…
## $ time         <dttm> 2012-01-01 16:00:32, 2012-01-01 18:12:00…
## $ image        <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ attach       <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ dollar       <dbl> 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 8,…
## $ winner       <fct> no, no, no, no, no, no, no, no, no, no, n…
## $ inherit      <dbl> 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ viagra       <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ password     <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ num_char     <dbl> 7.773, 2.643, 0.869, 13.890, 4.560, 2.192…
## $ line_breaks  <int> 192, 68, 25, 225, 64, 85, 10, 57, 97, 39,…
## $ format       <fct> 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 1, 1,…
## $ re_subj      <fct> 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0,…
## $ exclaim_subj <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,…
## $ urgent_subj  <fct> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
## $ exclaim_mess <dbl> 6, 0, 2, 0, 0, 3, 0, 5, 1, 3, 3, 0, 4, 32…
## $ number       <fct> small, small, small, small, none, big, sm…
```
]
]

---

class: middle

# Modeling workflow

---

## Fit a model to the training dataset

```r
email_fit <- logistic_reg() %>%
  set_engine("glm") %>%
  fit(spam ~ ., data = train_data)
```

```
## Warning: glm.fit: fitted probabilities numerically 0 or 1 occurred
```

---

## Categorical predictors

<img src="w9-d03-prediction-overfitting_files/figure-html/unnamed-chunk-8-1.png" width="75%" style="display: block; margin: auto;" />

---

## `from` and `sent_email`

.pull-left[
- `from`: Whether the message was listed as from anyone (this is usually set by default for regular outgoing email)

```r
train_data %>% count(spam, from)
```

```
## # A tibble: 3 x 3
##   spam  from      n
##   <fct> <fct> <int>
## 1 0     1      2847
## 2 1     0         3
## 3 1     1       287
```
]
.pull-right[
- `sent_email`: Indicator for whether the sender had been sent an email in the last 30 days

```r
train_data %>% count(spam, sent_email)
```

```
## # A tibble: 3 x 3
##   spam  sent_email     n
##   <fct> <fct>      <int>
## 1 0     0           1962
## 2 0     1            885
## 3 1     0            290
```
]

---

## Numerical predictors

.small[
```
## 
## ── Variable type: numeric ──────────────────────────────────────────────────────────────────────────
##    skim_variable spam  n_missing complete_rate    mean     sd    p0   p25   p50   p75  p100
##  1 cc            0             0             1 0.416   2.77   0     0      0     0       68
##  2 cc            1             0             1 0.345   2.02   0     0      0     0       23
##  3 image         0             0             1 0.0562  0.510  0     0      0     0       20
##  4 image         1             0             1 0.00690 0.0829 0     0      0     0        1
##  5 attach        0             0             1 0.128   0.765  0     0      0     0       21
##  6 attach        1             0             1 0.193   0.574  0     0      0     0        2
##  7 dollar        0             0             1 1.54    5.19   0     0      0     0       64
##  8 dollar        1             0             1 0.655   2.63   0     0      0     0       36
##  9 inherit       0             0             1 0.0351  0.237  0     0      0     0        6
## 10 inherit       1             0             1 0.0690  0.560  0     0      0     0        9
*## 11 viagra        0             0             1 0       0      0     0      0     0        0
*## 12 viagra        1             0             1 0       0      0     0      0     0        0
## 13 password      0             0             1 0.126   1.09   0     0      0     0       28
## 14 password      1             0             1 0.0138  0.143  0     0      0     0        2
## 15 num_char      0             0             1 11.2    14.3   0.003 1.86   6.83 15.4    165.
## 16 num_char      1             0             1 4.60    13.0   0.001 0.503  1.08  3.20   174.
## 17 line_breaks   0             0             1 244.    317.   2     42     136   320.   3589
## 18 line_breaks   1             0             1 88.8    265.   1     14     22.5  63.8   3729
## 19 exclaim_subj  0             0             1 0.0780  0.268  0     0      0     0        1
## 20 exclaim_subj  1             0             1 0.0862  0.281  0     0      0     0        1
## 21 exclaim_mess  0             0             1 6.02    41.2   0     0      1     5     1203
## 22 exclaim_mess  1             0             1 5.65    71.1   0     0      0     1     1209
```
]

---

## Fit a model to the training dataset

In the training data, `from` = 0 occurs only among spam emails, `sent_email` = 1 occurs only among non-spam emails, and `viagra` is 0 for every email, so these three predictors are removed:

```r
email_fit <- logistic_reg() %>%
  set_engine("glm") %>%
*  fit(spam ~ . - from - sent_email - viagra, data = train_data)
```

.small[
```r
email_fit
```

```
## parsnip model object
## 
## Fit time:  39ms 
## 
## Call:  stats::glm(formula = spam ~ . - from - sent_email - viagra, family = stats::binomial, 
##     data = data)
## 
## Coefficients:
##  (Intercept)  to_multiple1            cc          time         image        attach        dollar  
##   -8.251e+01    -3.114e+00     2.130e-02     6.173e-08    -1.412e+00     3.871e-01    -7.115e-02  
##    winneryes       inherit      password      num_char   line_breaks       format1      re_subj1  
##    2.134e+00     3.569e-01    -9.737e-01     5.793e-02    -6.367e-03    -7.715e-01    -3.050e+00  
## exclaim_subj  urgent_subj1  exclaim_mess   numbersmall     numberbig  
##    2.350e-01     3.866e+00     1.200e-02    -6.915e-01     1.174e-01  
## 
## Degrees of Freedom: 3136 Total (i.e. Null);  3118 Residual
## Null Deviance:      1933 
## Residual Deviance: 1402   AIC: 1440
```
]

---

## Predict outcome on the testing dataset

```r
predict(email_fit, test_data)
```

```
## # A tibble: 784 x 1
##   .pred_class
##   <fct>
## 1 0
## 2 0
## 3 0
## 4 0
## 5 0
## 6 0
## # … with 778 more rows
```

---

## Predict probabilities on the testing dataset

```r
email_pred <- predict(email_fit, test_data, type = "prob") %>%
  bind_cols(test_data %>% select(spam, time))

email_pred
```

```
## # A tibble: 784 x 4
##   .pred_0 .pred_1 spam  time
##     <dbl>   <dbl> <fct> <dttm>
## 1   0.942 0.0581  0     2012-01-01 16:00:32
## 2   0.920 0.0804  0     2012-01-01 18:12:00
## 3   0.904 0.0960  0     2012-01-01 18:23:44
## 4   0.997 0.00304 0     2012-01-02 00:54:46
## 5   0.833 0.167   0     2012-01-02 01:58:14
## 6   0.849 0.151   0     2012-01-02 02:05:45
## # … with 778 more rows
```

---

## A closer look at predictions

.pull-left-wide[
```r
email_pred %>%
  arrange(desc(.pred_1)) %>%
  print(n = 10)
```

```
## # A tibble: 784 x 4
##    .pred_0 .pred_1 spam  time
##      <dbl>   <dbl> <fct> <dttm>
##  1  0.0381   0.962 1     2012-03-27 06:17:01
##  2  0.205    0.795 1     2012-02-21 08:34:56
*##  3  0.408    0.592 0     2012-02-03 13:25:39
##  4  0.412    0.588 1     2012-03-10 22:43:58
##  5  0.448    0.552 1     2012-02-14 19:45:19
##  6  0.462    0.538 1     2012-02-04 15:54:23
*##  7  0.469    0.531 0     2012-01-12 02:00:16
##  8  0.472    0.528 1     2012-01-25 16:17:54
##  9  0.477    0.523 1     2012-03-21 02:00:30
## 10  0.486    0.514 1     2012-03-16 21:39:28
## # … with 774 more rows
```
]

---

## Evaluate the performance

**Receiver operating characteristic (ROC) curve**<sup>+</sup>, which plots the true positive rate (sensitivity) against the false positive rate (1 - specificity)

.pull-left[
```r
email_pred %>%
  roc_curve(
    truth = spam,
    .pred_1,
    event_level = "second" # spam = "1" is the second factor level
  ) %>%
  autoplot()
```
]
.pull-right[
<img src="w9-d03-prediction-overfitting_files/figure-html/unnamed-chunk-17-1.png" width="100%" style="display: block; margin: auto;" />
]
.footnote[
.small[
<sup>+</sup>Originally developed for operators of military radar receivers, hence the name.
]
]

---

## Evaluate the performance

Find the area under the curve:

.pull-left[
```r
email_pred %>%
  roc_auc(
    truth = spam,
    .pred_1,
    event_level = "second"
  )
```

```
## # A tibble: 1 x 3
##   .metric .estimator .estimate
##   <chr>   <chr>          <dbl>
## 1 roc_auc binary         0.828
```
]
.pull-right[
<img src="w9-d03-prediction-overfitting_files/figure-html/unnamed-chunk-19-1.png" width="100%" style="display: block; margin: auto;" />
]
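---

## Sketch: a single-threshold decision rule

The spam-filter goal at the start of these slides calls for a decision rule on top of the predicted probabilities. The base R sketch below illustrates the single-threshold approach; the helper `flag_spam()`, the five probabilities, and the cutoffs 0.5 and 0.75 are hypothetical choices for illustration, not values from the fitted model.

```r
# Hypothetical helper: flag an email as spam ("1") when its predicted
# probability of being spam exceeds the cutoff, otherwise not spam ("0")
flag_spam <- function(prob_spam, cutoff = 0.5) {
  stopifnot(is.numeric(prob_spam), cutoff >= 0, cutoff <= 1)
  factor(ifelse(prob_spam > cutoff, "1", "0"), levels = c("0", "1"))
}

# Made-up predicted probabilities for five incoming emails
prob_spam <- c(0.90, 0.60, 0.52, 0.20, 0.01)

flag_spam(prob_spam, cutoff = 0.5)   # flags the first three emails
flag_spam(prob_spam, cutoff = 0.75)  # a stricter cutoff flags only the first
```

Raising the cutoff flags fewer emails: fewer legitimate emails end up in the spam folder, at the cost of letting more spam through.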
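---

## Sketch: the mechanics of prediction

To make "plug the values of the predictors into the model equation" concrete, the sketch below fits a logistic regression with base R's `glm()` on simulated data (the data, the seed, and the true coefficients -1 and 2 are assumptions for illustration), computes predicted probabilities by hand, and checks them against `predict()`.

```r
set.seed(1116)

# Simulated data (hypothetical): one numeric predictor x, binary outcome y
n <- 200
x <- rnorm(n)
y <- rbinom(n, size = 1, prob = plogis(-1 + 2 * x))
fit <- glm(y ~ x, family = binomial)

# New observations to predict for
new_data <- data.frame(x = c(-1, 0, 1.5))

# Step 1: plug the predictor values into the model equation (log-odds scale)
b <- coef(fit)
log_odds <- b[["(Intercept)"]] + b[["x"]] * new_data$x

# Step 2: convert log-odds to probabilities with the inverse logit
p_hat <- plogis(log_odds)

# predict() on the response scale returns the same probabilities
all.equal(p_hat, unname(predict(fit, new_data, type = "response")))
```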
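---

## Sketch: underfitting and overfitting

The trade-off in the underfitting and overfitting figure can also be seen numerically. This sketch simulates a noisy curve (the sine signal, noise level, seed, and polynomial degrees are all assumptions), fits polynomials of increasing degree on a training set only, and compares root mean squared error (RMSE) on the training and testing sets.

```r
set.seed(1116)

# Simulated data (hypothetical): a smooth curve plus noise
n <- 60
x <- runif(n, min = 0, max = 10)
y <- sin(x) + rnorm(n, sd = 0.3)
dat <- data.frame(x, y)

# 80% training / 20% testing split
train_idx <- sample(n, size = 0.8 * n)
train <- dat[train_idx, ]
test  <- dat[-train_idx, ]

rmse <- function(obs, pred) sqrt(mean((obs - pred)^2))

# Fit polynomials of increasing degree on the training set only
degrees <- c(1, 3, 6, 12)
results <- t(sapply(degrees, function(d) {
  fit <- lm(y ~ poly(x, d), data = train)
  c(degree     = d,
    train_rmse = rmse(train$y, predict(fit, newdata = train)),
    test_rmse  = rmse(test$y,  predict(fit, newdata = test)))
}))
results
```

Training RMSE never increases as the degree grows, since each model contains the previous one. Testing RMSE is the more honest measure of performance on new data, and it typically rises again once the polynomial starts chasing noise.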
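---

## Sketch: an 80/20 split in base R

`initial_split()` takes care of splitting the email data for us. To show the idea behind a random 80/20 split, here is a base R sketch under simple assumptions (a made-up 100-row data frame, no stratification); it is an illustration, not how rsample implements `initial_split()`.

```r
set.seed(1116)

# Hypothetical stand-in data frame with 100 rows
dat <- data.frame(id = 1:100, x = rnorm(100))

# Randomly choose 80% of the row indices for the training set
n_train   <- floor(0.80 * nrow(dat))
train_idx <- sample(nrow(dat), size = n_train)

# Every row lands in exactly one of the two sets
train_sketch <- dat[train_idx, , drop = FALSE]
test_sketch  <- dat[-train_idx, , drop = FALSE]

c(train = nrow(train_sketch), test = nrow(test_sketch))
```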
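---

## Sketch: ROC curve and AUC by hand

`roc_curve()` and `roc_auc()` do this work for us. The base R sketch below shows what they summarize, using ten hypothetical emails (labels and probabilities are made up): each cutoff gives one (false positive rate, true positive rate) point, sweeping the cutoff traces the curve, and the area under it can be computed with the trapezoidal rule or, equivalently, as the probability that a randomly chosen spam email gets a higher predicted probability than a randomly chosen non-spam email.

```r
# Hypothetical true labels (1 = spam) and predicted spam probabilities
truth <- c(1, 1, 0, 1, 1, 0, 0, 0, 0, 0)
prob  <- c(0.95, 0.80, 0.60, 0.55, 0.40, 0.35, 0.20, 0.10, 0.05, 0.02)

# True positive rate and false positive rate when flagging prob >= cutoff
roc_point <- function(cutoff) {
  pred <- as.integer(prob >= cutoff)
  c(cutoff = cutoff,
    tpr = sum(pred == 1 & truth == 1) / sum(truth == 1),
    fpr = sum(pred == 1 & truth == 0) / sum(truth == 0))
}

# Sweep the cutoff from "flag nothing" (Inf) down to "flag everything"
cutoffs <- c(Inf, sort(unique(prob), decreasing = TRUE))
roc <- t(sapply(cutoffs, roc_point))

# AUC with the trapezoidal rule over consecutive (fpr, tpr) points
auc_trap <- sum(diff(roc[, "fpr"]) *
                (head(roc[, "tpr"], -1) + tail(roc[, "tpr"], -1)) / 2)

# Rank-based AUC: share of (spam, non-spam) pairs ordered correctly,
# with ties counting one half
pos <- prob[truth == 1]
neg <- prob[truth == 0]
auc_rank <- mean(outer(pos, neg, ">") + 0.5 * outer(pos, neg, "=="))

c(trapezoid = auc_trap, rank = auc_rank)  # both equal 11/12
```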