Playing with the Fashion-MNIST Dataset


Author: 圣_狒司机 | Published 2019-07-19 00:00

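Fashion-MNIST is an upgraded, drop-in replacement for MNIST. It keeps the 28x28 grayscale format and the 10 classes, but the images show clothing, shoes, and bags instead of handwritten digits, which makes them harder to classify.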

The data looks like this: 70,000 grayscale images spread over 10 categories:

[Figure: sample images from the dataset]

The task is to classify these images, sorting the clothes, shoes, and bags into their proper categories.
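The ten categories have official names in the Fashion-MNIST release. The small sketch below maps an integer label to its name. It assumes that the folder names 0-9, which the loading code reads as labels, follow the official label order; the article does not state this. `CLASS_NAMES` and `label_name` are illustrative names, not part of any library.

```python
# Official Fashion-MNIST class names, indexed by integer label 0-9.
# Assumption: the dataset folders named 0-9 use this same order.
CLASS_NAMES = [
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",
]


def label_name(label) -> str:
    """Map an integer label, or a folder name such as "9", to its class name."""
    return CLASS_NAMES[int(label)]


print(label_name("9"))                     # Ankle boot
print([label_name(i) for i in (0, 5, 8)])  # ['T-shirt/top', 'Sandal', 'Bag']
```

This is handy for titling the sample grid or for reading a confusion matrix in words rather than digits.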

Loading and labeling the data:

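from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPool2D, Flatten, Dense
import pandas as pd

# Expected layout: mnist_fashion/<label>/<image file>, where <label> is 0-9
path = Path.cwd() / "mnist_fashion"
paths = []

def to_the_end(path):
    """Recursively walk path and append every file found to paths."""
    if path.is_file():
        paths.append(path)
    else:
        for i in path.iterdir():
            to_the_end(i)

to_the_end(path)

def show_data(row, col, x_train):
    """Show the first row*col images of x_train in a row x col grid."""
    for index in range(1, row * col + 1):
        ax = plt.subplot(row, col, index)
        # Subplot indices start at 1, array indices at 0;
        # squeeze drops a trailing channel axis of size 1 before plotting
        ax.imshow(np.squeeze(x_train[index - 1]), cmap="gray")
        ax.axis("off")

# Read every image as a 28x28 array (plt.imread returns PNGs as floats in [0, 1])
X = np.array([plt.imread(str(i)) for i in paths])
# The label of each image is the name of the folder it sits in
Y = np.array([int(p.parent.name) for p in paths])

# Default split: 75% training, 25% test
X_train, X_test, y_train, y_test = train_test_split(X, Y)
# Add a channel axis, (N, 28, 28) -> (N, 28, 28, 1), as Conv2D expects
X_train, X_test = np.expand_dims(X_train, -1), np.expand_dims(X_test, -1)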
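The bookkeeping part of this loading code can be sketched with the standard library alone: `Path.rglob` replaces the hand-written recursion, the parent folder name supplies the label, and a seeded shuffle holds out a quarter of the samples. The helper names `collect_files`, `labels_from_folders`, and `split_indices` are hypothetical, and the split is a simplified stand-in for `train_test_split` rather than its actual implementation; it only reproduces the same 75/25 proportions, which give the 52500/17500 counts seen in the training log.

```python
from __future__ import annotations

import math
import random
import tempfile
from pathlib import Path


def collect_files(root: Path) -> list[Path]:
    """Recursively collect every regular file under root, sorted for a stable order."""
    return sorted(p for p in root.rglob("*") if p.is_file())


def labels_from_folders(paths: list[Path]) -> list[int]:
    """Use each file's parent folder name (.../<label>/<file>) as its integer label."""
    return [int(p.parent.name) for p in paths]


def split_indices(n: int, test_size: float = 0.25, seed: int = 0) -> tuple[list[int], list[int]]:
    """Shuffle 0..n-1 and hold out ceil(test_size * n) indices as the test set."""
    n_test = math.ceil(test_size * n)
    order = list(range(n))
    random.Random(seed).shuffle(order)
    return order[n_test:], order[:n_test]


# Demo on a throwaway tree laid out like mnist_fashion/<label>/<file>
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp) / "mnist_fashion"
    for label, name in [(0, "a.png"), (0, "b.png"), (3, "c.png"), (9, "d.png")]:
        (root / str(label)).mkdir(parents=True, exist_ok=True)
        (root / str(label) / name).write_bytes(b"")  # empty placeholder, not a real image
    files = collect_files(root)
    labels = labels_from_folders(files)

print([p.name for p in files], labels)  # ['a.png', 'b.png', 'c.png', 'd.png'] [0, 0, 3, 9]

train_idx, test_idx = split_indices(len(files))
print(len(train_idx), len(test_idx))    # 3 1

train_idx, test_idx = split_indices(70000)
print(len(train_idx), len(test_idx))    # 52500 17500
```

If the dataset folder also holds non-image files, a narrower pattern such as `rglob("*.png")` would keep them out of the list (assuming the images are PNGs).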

Classifying the dataset with a CNN:

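model = Sequential()
# Three Conv2D + MaxPool2D stages: 28x28x1 -> 13x13x32 -> 5x5x64 -> 1x1x128.
# No activation is passed to Conv2D, so these layers are linear (Keras default activation=None).
model.add(Conv2D(32, (3, 3), input_shape=(28, 28, 1)))
model.add(MaxPool2D())
model.add(Conv2D(64, (3, 3)))
model.add(MaxPool2D())
model.add(Conv2D(128, (3, 3)))
model.add(MaxPool2D())
model.add(Flatten())
model.add(Dense(128, activation="relu"))
model.add(Dense(10, activation="softmax"))  # one probability per class

model.summary()

# sparse_categorical_crossentropy takes the integer labels directly, no one-hot encoding needed
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
# validation_data is passed as a tuple (x_val, y_val)
history = model.fit(X_train, y_train, batch_size=500, epochs=10, validation_data=(X_test, y_test))
# Per-epoch training/validation loss and accuracy as a DataFrame, plotted as curves
history = pd.DataFrame(history.history)
history.plot()
plt.show()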
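The output shapes and parameter counts that `model.summary()` reports follow from a few formulas. The pure-Python check below assumes the Keras defaults for these layers (stride 1 and "valid" padding for `Conv2D`; pool size 2, stride equal to the pool size, and "valid" padding for `MaxPool2D`). The function names are illustrative only.

```python
def conv_out(size: int, kernel: int = 3, stride: int = 1) -> int:
    """Spatial size after a 'valid' (unpadded) convolution."""
    return (size - kernel) // stride + 1


def pool_out(size: int, pool: int = 2) -> int:
    """Spatial size after 'valid' max pooling whose stride equals the pool size."""
    return (size - pool) // pool + 1


def conv_params(in_ch: int, out_ch: int, kernel: int = 3) -> int:
    """One kernel x kernel x in_ch filter per output channel, plus one bias each."""
    return kernel * kernel * in_ch * out_ch + out_ch


def dense_params(n_in: int, n_out: int) -> int:
    """Weight matrix plus bias vector."""
    return n_in * n_out + n_out


size, channels, total = 28, 1, 0
shapes = []
for filters in (32, 64, 128):  # the three Conv2D + MaxPool2D stages
    total += conv_params(channels, filters)
    size = pool_out(conv_out(size))
    channels = filters
    shapes.append((size, size, channels))

flat = size * size * channels     # Flatten: 1 * 1 * 128 = 128
total += dense_params(flat, 128)  # Dense(128)
total += dense_params(128, 10)    # Dense(10)

print(shapes)  # [(13, 13, 32), (5, 5, 64), (1, 1, 128)]
print(total)   # 110474
```

The per-layer counts are 320, 18496, 73856, 16512, and 1290, matching the summary. It also shows why three stages are the limit here: by the last pooling layer the feature map has already shrunk to 1x1.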

Training results:

Model: "sequential_33"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d_83 (Conv2D)           (None, 26, 26, 32)        320       
_________________________________________________________________
max_pooling2d_41 (MaxPooling (None, 13, 13, 32)        0         
_________________________________________________________________
conv2d_84 (Conv2D)           (None, 11, 11, 64)        18496     
_________________________________________________________________
max_pooling2d_42 (MaxPooling (None, 5, 5, 64)          0         
_________________________________________________________________
conv2d_85 (Conv2D)           (None, 3, 3, 128)         73856     
_________________________________________________________________
max_pooling2d_43 (MaxPooling (None, 1, 1, 128)         0         
_________________________________________________________________
flatten_10 (Flatten)         (None, 128)               0         
_________________________________________________________________
dense_11 (Dense)             (None, 128)               16512     
_________________________________________________________________
dense_12 (Dense)             (None, 10)                1290      
=================================================================
Total params: 110,474
Trainable params: 110,474
Non-trainable params: 0
_________________________________________________________________
Train on 52500 samples, validate on 17500 samples
Epoch 1/10
52500/52500 [==============================] - 39s 745us/sample - loss: 0.8906 - accuracy: 0.6817 - val_loss: 0.6245 - val_accuracy: 0.7737
Epoch 2/10
52500/52500 [==============================] - 39s 741us/sample - loss: 0.5823 - accuracy: 0.7905 - val_loss: 0.5329 - val_accuracy: 0.8085
Epoch 3/10
52500/52500 [==============================] - 39s 746us/sample - loss: 0.5072 - accuracy: 0.8181 - val_loss: 0.4819 - val_accuracy: 0.8294
Epoch 4/10
52500/52500 [==============================] - 43s 815us/sample - loss: 0.4567 - accuracy: 0.8355 - val_loss: 0.4415 - val_accuracy: 0.8447
Epoch 5/10
52500/52500 [==============================] - 48s 919us/sample - loss: 0.4260 - accuracy: 0.8475 - val_loss: 0.4359 - val_accuracy: 0.8456
Epoch 6/10
52500/52500 [==============================] - 44s 836us/sample - loss: 0.3940 - accuracy: 0.8569 - val_loss: 0.4029 - val_accuracy: 0.8545
Epoch 7/10
52500/52500 [==============================] - 39s 744us/sample - loss: 0.3785 - accuracy: 0.8630 - val_loss: 0.4127 - val_accuracy: 0.8489
Epoch 8/10
52500/52500 [==============================] - 39s 741us/sample - loss: 0.3580 - accuracy: 0.8705 - val_loss: 0.3710 - val_accuracy: 0.8696
Epoch 9/10
52500/52500 [==============================] - 40s 753us/sample - loss: 0.3443 - accuracy: 0.8749 - val_loss: 0.3732 - val_accuracy: 0.8636
Epoch 10/10
52500/52500 [==============================] - 39s 745us/sample - loss: 0.3316 - accuracy: 0.8800 - val_loss: 0.3634 - val_accuracy: 0.8716
[Figure: training and validation loss/accuracy curves over the 10 epochs, from history.plot()]
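The `loss` and `val_loss` numbers in the log are sparse categorical cross-entropy: for integer labels, the mean of -log of the probability the softmax assigns to each sample's true class. The NumPy sketch below implements that definition. The `eps` clip is an illustrative guard against log(0), not a claim about Keras internals. The baseline is useful: always guessing uniformly over 10 classes gives ln 10 ≈ 2.303, so even the first epoch's 0.8906 is far better than chance.

```python
import numpy as np


def sparse_categorical_crossentropy(y_true, probs, eps=1e-7):
    """Mean of -log(probability of the true class), written as log(1 / p).

    y_true: integer labels, shape (N,); probs: per-class probabilities, shape (N, C).
    """
    probs = np.clip(np.asarray(probs, dtype=np.float64), eps, 1.0)  # guard against log(0)
    picked = probs[np.arange(len(y_true)), np.asarray(y_true)]
    return float(np.mean(np.log(1.0 / picked)))


labels = np.array([0, 3, 7, 9])
uniform = np.full((4, 10), 0.1)  # a model that always guesses uniformly
perfect = np.eye(10)[labels]     # probability 1 on every true class

print(round(sparse_categorical_crossentropy(labels, uniform), 4))  # 2.3026 (= ln 10)
print(sparse_categorical_crossentropy(labels, perfect))            # 0.0

# A mean loss L corresponds to a geometric-mean true-class probability of exp(-L)
print(round(float(np.exp(-0.3316)), 3))  # 0.718 for the final training loss in the log
```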

Validation accuracy reaches about 87% (0.8716 after epoch 10). Pretty good, right?
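A quick way to check which epoch was best is to scan the per-epoch validation metrics. The sketch below works on the values copied from the training log above, so it runs without retraining. The variable names are illustrative.

```python
# Per-epoch validation metrics copied from the training log
val_accuracy = [0.7737, 0.8085, 0.8294, 0.8447, 0.8456,
                0.8545, 0.8489, 0.8696, 0.8636, 0.8716]
val_loss = [0.6245, 0.5329, 0.4819, 0.4415, 0.4359,
            0.4029, 0.4127, 0.3710, 0.3732, 0.3634]
final_train_accuracy = 0.8800

# Epochs are 1-based in the log, list indices are 0-based
best_acc_epoch = max(range(len(val_accuracy)), key=lambda i: val_accuracy[i]) + 1
best_loss_epoch = min(range(len(val_loss)), key=lambda i: val_loss[i]) + 1
gap = final_train_accuracy - val_accuracy[-1]  # train/validation gap at the last epoch

print(best_acc_epoch, val_accuracy[best_acc_epoch - 1])  # 10 0.8716
print(best_loss_epoch)                                   # 10
print(round(gap, 4))                                     # 0.0084
```

With the `history` DataFrame from the training code, the same lookup is `history["val_accuracy"].idxmax() + 1`, since that DataFrame's default index starts at 0. Both validation metrics are best at the final epoch and the train/validation gap is under one percentage point, which suggests little overfitting so far and that a few more epochs might still help.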


Article link: https://www.haomeiwen.com/subject/awddlctx.html