40  K-最邻近算法

本章使用一个模拟的default违约数据集,介绍三种应用最广泛的分类方法:逻辑斯谛回归、线性判别分析、K近邻算法。重点聚焦KNN算法。

关于分类模型评估的指标,具体可参看 判别问题的评判指标
library(tidymodels)
library(tidyverse)
library(ISLR) # For the Smarket data set
library(ISLR2) # For the Bikeshare data set
library(discrim)
library(poissonreg)
library(rio)

40.1 Knn算法

KNN算法是一种基于实例的学习方法,它的核心思想是:如果一个样本在特征空间中的k个最相似(即特征空间中最邻近)的样本中的大多数属于某一个类别,则该样本也属于这个类别。简单的说,就是找K个距离待遇测样本最近的点,然后根据这几个点的类别来确定新样本的类别。

40.2 逻辑斯谛回归

  1. 逻辑斯谛回归模型的输出结果与线性回归的输出结果类似,均可以通过p值的大小判断是否接受零假设。
  2. 逻辑斯谛回归 可通过设定哑变量的方式分析定性预测变量。
  3. 与线性回归类似,逻辑斯谛回归中如果预测变量之间存在多重共线性,会导致模型的不稳定性。

40.3 判别分析

40.3.1 线性判别分析

  1. 线性判别分析是一种监督学习方法,用于分类问题。
  2. 线性判别分析的目标是找到一个线性组合,使得不同类别的数据在这个线性组合上的投影尽可能远离,同一类别的数据尽可能接近。
  3. 除了分类外,线性判别分析还可用于数据的降维。

40.3.2 二次判别分析

  1. 拟合非线性关系时,可以用到二次判别分析,它时线性判别分析的一个变体。
  • 分类问题常用灵敏度和特异度作为模型能力的测试指标。
    • 灵敏度(sensitivity):指的是模型对正例的识别能力。
    • 特异度(specificity):指的是模型对负例的识别能力。
  • 分类问题中,如何设定合理的分类阈值非常重要。
    • 通常情况下,分类阈值设定为0.5。
    • 但是在实际应用中,根据业务需求,可以根据灵敏度和特异度的需求,调整分类阈值。
  • ROC曲线可以很好的展示分类模型的性能。
    • ROC曲线的横轴是假正例的比例,即真实值为被错误判断的比例;纵轴是真正例的比例,即真实值被正确判断的比例。
    • ROC曲线下的面积AUC越大,说明模型的性能越好。通常情况下,我们认为一个分类模型的AUC至少应大于0.5,

40.4 tidymodels进行knn分析-实战案例

  • recipe可以查询tidymodels包中所有的预处理函数。
  • parsnip可以查询。tidymodels包中所有的模型函数。
  • 许多分类学习模型算法要求结果变量为factor类型。

选择WineData数据集进行实战案例,使用数据集中的总硫含量(total_sulfur_dioxide)和酸度(acidity)两个变量对红酒颜色(wine_color)进行分类预测。

40.4.1 数据准备

data_wine <- read_rds(
  "D:/Document/0.Study R/0.R4DS/data/practice-data/WineData.rds"
) |>
  janitor::clean_names() |>
  select(wine_color, acidity, sulrfur = total_sulfur_dioxide) |>
  mutate(wine_color = as.factor(wine_color))
head(data_wine)
# A tibble: 6 × 3
  wine_color acidity sulrfur
  <fct>        <dbl>   <dbl>
1 red           10.8      37
2 white          6.4     213
3 white          9.4     139
4 white          8.2      90
5 white          6.4     183
6 red            6.7      38
# 数据划分
set.seed(876)
split7030 <- initial_split(data_wine, prop = 0.7, strata = wine_color)
data_train <- training(split7030)
data_test <- testing(split7030)

40.4.2 数据预处理-recipe

recipe_knn_wine <- recipe(wine_color ~ ., data = data_train) |>
  step_naomit() |> # 标准化
  step_normalize(all_predictors()) # 缺失值处理
recipe_knn_wine

40.4.3 如何选择何时该dplyrrecipes

  1. 一般情况下,tidymodels建议使用recipes进行数据预处理,因为recipes是可重复的,可以在其他数据集上使用。
  2. 在使用recipes前,最好可以使用select()函数选择需要的变量,以避免不必要的变量干扰模型,也可以使用.符号选择所有变量。
  3. 如果需要对结果变量进行转换,建议在recipes外转换。例如上述代码中,将wine_color转换为factor类型。

40.4.4 模型建立-parsnip

spec_knn_wine <- nearest_neighbor(
  neighbors = 4,
  # weight_func = "rectangular"
) |>
  set_engine("kknn") |>
  set_mode("classification")
spec_knn_wine
K-Nearest Neighbor Model Specification (classification)

Main Arguments:
  neighbors = 4

Computational engine: kknn 

40.4.5 建立工作流-workflow

# 建立工作流
wf_knn_wine <- workflow() |>
  add_recipe(recipe_knn_wine) |>
  add_model(spec_knn_wine)
wf_knn_wine
══ Workflow ════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: nearest_neighbor()

── Preprocessor ────────────────────────────────────────────────────────────────
2 Recipe Steps

• step_naomit()
• step_normalize()

── Model ───────────────────────────────────────────────────────────────────────
K-Nearest Neighbor Model Specification (classification)

Main Arguments:
  neighbors = 4

Computational engine: kknn 
# 训练/拟合模型
fit_knn_wine <- wf_knn_wine |>
  fit(data_train)
fit_knn_wine
══ Workflow [trained] ══════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: nearest_neighbor()

── Preprocessor ────────────────────────────────────────────────────────────────
2 Recipe Steps

• step_naomit()
• step_normalize()

── Model ───────────────────────────────────────────────────────────────────────

Call:
kknn::train.kknn(formula = ..y ~ ., data = data, ks = min_rows(4,     data, 5))

Type of response variable: nominal
Minimal misclassification: 0.09115282
Best kernel: optimal
Best k: 4

40.4.6 预测-predict

knn_wine_pred <- predict(fit_knn_wine, new_data = data_test)
knn_wine_pred
# A tibble: 960 × 1
   .pred_class
   <fct>      
 1 white      
 2 red        
 3 white      
 4 white      
 5 white      
 6 red        
 7 white      
 8 white      
 9 red        
10 red        
# ℹ 950 more rows