美文网首页
TabNet-神经网络处理表格数据实战

TabNet-神经网络处理表格数据实战

作者: 雪糕遇上夏天 | 来源:发表于2021-09-23 11:20 被阅读0次

我们知道神经网络在图片、信号等领域大放异彩。但在表格数据领域,基本还是树模型的主场。今天我们介绍下TabNet的使用方式,这是一个能够很好的处理tabular数据的神经网络模型。
下面我们介绍下TabNet的使用。

1. 安装

根据官方介绍,安装tabnet之前需要Tensorflow 2.0+版本和Tensorflow-dataset(非必须)。确保Tensorflow 2.0+正确安装之后,就可以安装TabNet了。

pip install tabnet[cpu]
pip install tabnet[gpu]

就像TensorFlow有cpu版和gpu版一样,TabNet也有cpu版和gpu版,可以按需选择。

2. 使用

tabnet包提供了TabNetClassifier和TabNetRegression分别用于处理分类任务和回归任务。以TabNetClassification为例,他是在TabNet模块的基础上加入了处理分类任务的全连接层(即:激活函数为softmax)。

from tabnet import TabNetClassifier

我们使用iris数据,做个简单的分类任务

import tensorflow_datasets as tfds
def transform(ds):
    features = tf.unstack(ds['features'])
    labels = ds['label']

    x = dict(zip(col_names, features))
    y = tf.one_hot(labels, 3)
    return x, y

col_names = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width']
ds_full = tfds.load(name="iris", split=tfds.Split.TRAIN)
ds_full = ds_full.shuffle(150, seed=0)

ds_train = ds_full.take(train_size)
ds_train = ds_train.map(transform)
ds_train = ds_train.batch(32)

ds_test = ds_full.skip(train_size)
ds_test = ds_test.map(transform)
ds_test = ds_test.batch(32)

需要注意的是要把特征数据转化成map类型,因为模型的第一个参数即为特征的参数名称。
iris共有150条数据,每个数据有4个特征。所以我们设置如下:

col_names = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width']
feature_columns = []
for col_name in col_names:
    feature_columns.append(tf.feature_column.numeric_column(col_name))
model = TabNetClassifier(feature_columns, num_classes=3, feature_dim=8, output_dim=4)

至此模型就创建好了,下面就是训练的部分:

lr = tf.keras.optimizers.schedules.ExponentialDecay(0.01, decay_steps=100, decay_rate=0.9)
optimizer = tf.keras.optimizers.Adam(lr)
model.compile(optimizer, loss='categorical_crossentropy', metrics=['accuracy'])
model.fit(ds_train, epochs=100, validation_data=ds_test, verbose=2)

相关文章

  • TabNet-神经网络处理表格数据实战

    我们知道神经网络在图片、信号等领域大放异彩。但在表格数据领域,基本还是树模型的主场。今天我们介绍下TabNet的使...

  • 深度学习笔记09_机器学习数据预处理

    数据预处理、特征工程 神经网络的数据预处理 预处理的原则:是使原始数据更适于用神经网络处理,主要包括:向量化,标准...

  • 数据预处理、特征工程和特征学习

    神经网络的数据预处理 数据预处理的目的是使原始数据更适合神经网络处理,包括向量化、标准化、处理缺失值和特征提取。 ...

  • html学习第三天

    表格: 表格使用在处理与显示数据,不适用于布局。 创建表格: ...

  • 无题

    工作技能方面,数据处理是重中之重,基础的数据处理没有难度,但是多份表格处理就尤为重要。达到多表格处理的程度,就基本...

  • 斯坦福cs231n学习笔记(10)------神经网络训练细节(

    神经网络训练细节系列笔记: 神经网络训练细节(激活函数) 神经网络训练细节(数据预处理、权重初始化) 神经网络训练...

  • HTML第二天

    表格 table(会使用) 表格不是用来布局,常见处理、显示表格式数据. 表格学习要求: 能手写表格结构,并且能合...

  • tensorflow2.3实战循环神经网络

    一:理论部分 embedding和变长输入处理 序列式问题 循环神经网络 LSTM模型原理 二:实战 keras实...

  • 第47周回顾,第48周计划

    上周回顾 论文的投稿:处理数据,生成图片和表格,修改文字部分完成了数据处理和图片表格部分,但是文字部分还没有完成。...

  • 处理excel表格数据

    1、将excel中的数据提取,处理数据后保存想要的格式我们的表格样式: 需求:将表格中extend列的数据单独提取...

网友评论

      本文标题:TabNet-神经网络处理表格数据实战

      本文链接:https://www.haomeiwen.com/subject/raingltx.html