美文网首页
seaborn.scatterplot 绘制完整的散点图

seaborn.scatterplot 绘制完整的散点图

作者: ab02f58fd803 | 来源:发表于2020-08-17 20:02 被阅读0次

散点图能够直观地看出预测值与真实值之间的关系,同时绘制完整散点图非常重要。一般散点图包含下列数据。

  1. 显示的数值,比如回归预测的R^2RMSE,还有样本大小samples等。
  2. 显示比例线,一般是1:1预测与真实之间拟合线以及对应的拟合方程。
  3. 标题X,Y轴的含义,单位等重要的量。

注意:
python和库的版本,我的版本是

python 3.6
seaborn 0.10.0
pandas 1.0.1
numpy 1.19.1
matplotlib 2.2.5
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np


sns.set(style="white", font_scale=1.5,color_codes=True)

### 注意这个是读取绘制的文件名
for file in file_name:

    dataset = pd.read_csv(file)
    
    ### 1:1比例线
    pred_min, pred_max = dataset['y_pred'].min(),dataset['y_pred'].max()
    true_min, true_max = dataset['y_test'].min(),dataset['y_test'].max()

    y_pred  = dataset['y_test'].to_numpy()
    y_test =  dataset['y_pred'].to_numpy()
    
    xy_mse = np.sum((y_pred-y_test)**2);
    xy_mean = np.mean(y_pred);
    xx_mean = np.sum((y_pred - xy_mean)**2);
    R2 = 1 - xy_mse/xx_mean

    RMSE = np.sqrt(np.mean((y_pred - y_test)**2))


    p = np.polyfit(y_pred, y_test, 1)

    formatSpec = 'y = %.4fx+ %.4f'%(p[0], p[1])
    formatXy = 'y = x'

    x1 = np.linspace(pred_min,pred_max);  
    y1 = np.polyval(p, x1 )
    
    str_R2 = '$R^2$ = %.4f\nRMSE = %.2f \nsamples = %d'%(R2, RMSE, dataset.shape[0])

    
    f, ax= plt.subplots(figsize = (14, 10))
    
    #plt.title('%s'%file[:-4])
    plt.title('The demo of scatter map')

        ####   set the colorbar font size
        #### https://stackoverflow.com/questions/34706845/change-xticklabels-fontsize-of-seaborn-heatmap 
    
        ####   set the x y labels font size 
        ####  https://www.cnblogs.com/lemonbit/p/7419851.html
    ax.tick_params(labelsize = 16) #
        #  # ax.set_ylabel('the Number of Models',fontsize=15, color='r')
        ## cmap='BrBG'  'RdBu'
    scatter = sns.scatterplot(x= 'y_pred', y='y_test', data = dataset, alpha = 0.8, color = 'b')

    
    ax.plot(np.arange(pred_min, pred_max,0.1), np.arange(pred_min, pred_max,0.1), color='r', linewidth=3, alpha=0.6, label = formatXy )
    ax.plot(x1, y1, color='k', linewidth=3, alpha=0.6, label = formatSpec)
    ax.legend(loc = 'lower right', fontsize = 20)
    x_pos1 = int(pred_min) 
    y_pos1 = int(0.9 * true_max)
    ax.text(x_pos1,y_pos1 ,str_R2, fontsize = 20)
    file = 'temp'
    f.savefig('%s.png'%file, dpi=300, bbox_inches='tight')
    plt.xlabel( 'X axis')
    plt.ylabel( 'Y axis')


    plt.show()
temp.png

相关文章

网友评论

      本文标题:seaborn.scatterplot 绘制完整的散点图

      本文链接:https://www.haomeiwen.com/subject/hwxcjktx.html