美文网首页
ggplot 绘制 SVM 决策域

ggplot 绘制 SVM 决策域

作者: Norahd | 来源:发表于2023-01-08 13:43 被阅读0次

1. Prepare data

library(ggplot2);library(dplyr)
# Simulate four 2-D Gaussian clusters labelled "a"-"d" with different
# centres, spreads and sizes (50, 60, 90 and 90 points respectively).
# Seed the RNG so the data set -- and every figure built from it -- is
# reproducible between runs.
set.seed(2023)
dat <- rbind(
  data.frame(x = rnorm(50, 1, sd = 0.8), y = rnorm(50, 1, sd = 0.5),
             groups = "a"),
  data.frame(x = rnorm(60, 3, sd = 0.5), y = rnorm(60, 5, sd = 0.3),
             groups = "b"),
  data.frame(x = rnorm(90, 4, sd = 1.0), y = rnorm(90, 1, sd = 0.6),
             groups = "c"),
  data.frame(x = rnorm(90, 6, sd = 1.0), y = rnorm(90, 3, sd = 0.65),
             groups = "d")
)

2. Linear SVM discriminant

### SVM linear discriminant
# Fit a linear-kernel SVM on all four groups. `probability = TRUE` asks
# e1071 to also fit a class-probability model so predict() can return
# per-class probabilities.
svmfit <- e1071::svm(factor(groups) ~ ., data = dat, kernel = "linear",
                     cost = 20, scale = FALSE, probability = TRUE)
# Dense prediction grid covering the plotted window [-1, 8] x [-1, 8].
grid.mat <- expand.grid(x = seq(-1, 8, 0.01), y = seq(-1, 8, 0.01))
svm.predit <- predict(svmfit, grid.mat, probability = TRUE)
# cls:      predicted label for each grid cell.
# poster.p: largest class probability at that cell (row-wise max over the
#           "probabilities" attribute), used to fade the fill where the
#           model is less certain.
# NOTE(review): cls comes from the SVM vote while poster.p is the max over
# all classes, so near boundaries they may refer to different classes.
plot.data <- grid.mat %>%
  mutate(
    cls = as.vector(svm.predit),
    poster.p = do.call(pmax,
                       data.frame(attr(svm.predit, "probabilities")))
  )
### Prepare boundary data
# For every predicted class build a 0/1 indicator over the grid; the 0.5
# contour of each indicator outlines that class's region. `idx` records
# which class the indicator belongs to so contours can be grouped by it.
contour.dat <- lapply(unique(plot.data$cls), function(cls.label) {
  plot.data %>%
    select(x, y, cls) %>%
    mutate(idx = !!cls.label, is.cls = as.numeric(idx == cls))
}) %>%
  data.table::rbindlist()
### Plot
plot.data %>%
  ggplot() +
  geom_raster(aes(x = x, y = y, fill = cls, alpha = poster.p),
              show.legend = FALSE) +
  scale_fill_brewer(palette = "Dark2") +
  # One outline per class: the 0.5 level of that class's indicator.
  stat_contour(data = contour.dat,
               aes(x = x, y = y, z = as.numeric(is.cls), group = idx),
               color = "black", linewidth = 0.1, breaks = 0.5) +
  geom_point(data = dat, aes(x = x, y = y, color = groups),
             show.legend = FALSE) +
  # Group labels placed at each group's centroid.
  geom_label(data = dat %>%
               group_by(groups) %>%
               summarise(x = mean(x), y = mean(y)),
             aes(x = x, y = y, label = groups, fill = groups),
             fontface = "bold", colour = "white", size = 8,
             show.legend = FALSE) +
  coord_equal(xlim = c(-1, 8), ylim = c(-1, 8)) +
  theme_bw() +
  theme(panel.grid = element_blank())

3. Nonlinear SVM discriminant

# Keep only groups "a" and "b" for a two-class, radial-kernel example.
# NOTE: this overwrites `dat` with the two-group subset.
dat <- dat %>% filter(groups %in% c("a", "b"))
svmfit <- e1071::svm(factor(groups) ~ ., data = dat, kernel = "radial",
                     cost = 20, scale = FALSE, probability = TRUE)
# Dense prediction grid covering the plotted window [-1, 8] x [-1, 8].
grid.mat <- expand.grid(x = seq(-1, 8, 0.01), y = seq(-1, 8, 0.01))
svm.predit <- predict(svmfit, grid.mat, probability = TRUE,
                      decision.values = TRUE)
# cls:            predicted label per grid cell.
# decision.value: the SVM decision value per cell (a single pairwise value
#                 with two classes); its zero contour is drawn below as
#                 the decision boundary.
# poster.p:       largest class probability per cell, used as fill alpha.
plot.data <- grid.mat %>%
  mutate(
    cls = as.vector(svm.predit),
    decision.value = as.vector(attr(svm.predit, "decision.values")),
    poster.p = do.call(pmax,
                       data.frame(attr(svm.predit, "probabilities")))
  )
plot.data %>%
  ggplot() +
  geom_raster(aes(x = x, y = y, fill = cls, alpha = poster.p),
              show.legend = FALSE) +
  # Decision boundary: where the decision value changes sign.
  stat_contour(aes(x = x, y = y, z = decision.value),
               color = "black", linewidth = 0.1, breaks = 0) +
  geom_point(data = dat, aes(x = x, y = y, color = groups),
             show.legend = FALSE) +
  # Group labels placed at each group's centroid.
  geom_label(data = dat %>%
               group_by(groups) %>%
               summarise(x = mean(x), y = mean(y)),
             aes(x = x, y = y, label = groups, fill = groups),
             fontface = "bold", colour = "white", size = 8,
             show.legend = FALSE) +
  scale_fill_brewer(palette = "Dark2") +
  coord_equal(xlim = c(-1, 8), ylim = c(-1, 8)) +
  theme_bw() +
  theme(panel.grid = element_blank())

相关文章

网友评论

      本文标题:ggplot 绘制 SVM 决策域

      本文链接:https://www.haomeiwen.com/subject/aorucdtx.html