使用TensorFlow官方提供了一個例子,基於MNIST數據集,實現一個圖片分類的應用,本文是基於TensorFlow 2.0.0版本來學習和試驗的。

MNIST數據集是一個非常出名的手寫體數字識別數據集,它包含了60000張圖片作爲訓練集,10000張圖片作爲測試集,每張圖片中的手寫體數字是0~9中的一個,圖片是28×28像素大小,並且每個數字都是位於圖片的正中間的。

使用TensorFlow對MNIST數據集進行分類,整個實現對應的完整的Python代碼,如下所示:

from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf


# 下載 MNIST 數據集
mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

# 創建 tf.keras.Sequential 模型
model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# 訓練模型
model.fit(x_train, y_train, epochs=5)

# 驗證模型
model.evaluate(x_test,  y_test, verbose=2)

訓練集與測試集

上面,x_train是訓練集,它的大小是60000,其中,裏面包含的每一個圖片是28×28像素,由一個28×28的二維數組表示。x_train數據的結構如下所示:

array([[[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]],

       [[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]],

       ...,

       [[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]],

       [[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]]], dtype=uint8)

下面,我們從x_train中拿出一個元素,即一個圖片對應的二維數組x_train[0],如下所示:

array([[  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   3,  18,  18,  18, 126, 136, 175,  26, 166, 255, 247, 127,  0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,  30,  36,  94, 154, 170, 253, 253, 253, 253, 253, 225, 172, 253, 242, 195,  64,   0,   0,  0,   0],
       [  0,   0,   0,   0,   0,   0,   0,  49, 238, 253, 253, 253, 253, 253, 253, 253, 253, 251,  93,  82,  82,  56,  39,   0,   0,   0,  0,   0],
       [  0,   0,   0,   0,   0,   0,   0,  18, 219, 253, 253, 253, 253, 253, 198, 182, 247, 241,   0,   0,   0,   0,   0,   0,   0,   0,  0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,  80, 156, 107, 253, 253, 205,  11,   0,  43, 154,   0,   0,   0,   0,   0,   0,   0,   0,  0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,  14,   1, 154, 253, 90,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0, 139, 253, 190,   2,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  11, 190, 253,  70,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  35, 241, 225, 160, 108,   1,   0,   0,   0,   0,   0,   0,   0,   0,  0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0, 81, 240, 253, 253, 119,  25,   0,   0,   0,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  0,  45, 186, 253, 253, 150,  27,   0,   0,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  0,   0,  16,  93, 252, 253, 187,   0,   0,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  0,   0,   0,   0, 249, 253, 249,  64,   0,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  0,  46, 130, 183, 253, 253, 207,   2,   0,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  39, 148, 229, 253, 253, 253, 250, 182,   0,   0,   0,   0,   0,   0,  0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  24, 114, 221, 253, 253, 253, 253, 201,  78,   0,   0,   0,   0,   0,   0,   0,  0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,  23,  66, 213, 253, 253, 253, 253, 198,  81,   2,   0,   0,   0,   0,   0,   0,   0,   0,  0,   0],
       [  0,   0,   0,   0,   0,   0,  18, 171, 219, 253, 253, 253, 253, 195,  80,   9,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  0,   0],
       [  0,   0,   0,   0,  55, 172, 226, 253, 253, 253, 253, 244, 133,  11,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  0,   0],
       [  0,   0,   0,   0, 136, 253, 253, 253, 212, 135, 132,  16,   0,  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0]], dtype=uint8)

由上面的矩陣可以看到,矩陣是非常稀疏的。從視覺上看,上面由非零的值組成的形狀,恰好像手寫數字5,其實它對應的分類標籤(Label)就是5,可以看到y_train[0]=5。

另外,測試集的數據格式,也與訓練集相同,它有10000個樣本。

上面x_train, x_test = x_train / 255.0, x_test / 255.0表示,對訓練接和測試集數據進行縮放,由整數歸一化轉換到0~1之間的浮點數。

模型創建與配置

tf.keras是Keras API的TensorFlow實現,它是一個用來構建和訓練模型的High-Level API,能夠快速上手並方便實現原型的設計。如果使用TensorFlow的Low-Level API實現,會非常複雜,使用起來沒有Keras API靈活方便。

Keras有兩種模型:順序模型(Sequential Model)和通用模型(Model),使用順序模型非常簡單,只需要創建並來配置好神經網絡各個Layer的實例,然後組裝起來就表徵並實現了一個模型,後續可以直接對其執行訓練和驗證的操作。例如,上述我們創建的順序模型:

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.Dense(10, activation='softmax')
])

上面設計的神經網絡模型,一共包含了4層:1個是輸入層,1個是隱藏層,1個是Dropout層,1個是Softmax輸出層。

已經組裝好神經網絡模型,接下來我們需要爲定義模型進行配置,以便訓練模型使用這些配置:

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

上面,優化器參數值設置爲Adam,它是實現了Adam(適應性動量估計法)算法,能夠對不同參數計算適應性學習率,經驗表明,Adam在實踐中表現很好。另外,還有其他的優化器可以選擇:

  • Adadelta
  • Adagrad
  • Adamax
  • Ftrl
  • Nadam
  • RMSprop
  • SGD

loss表示目標函數,需要輸入的是目標函數的名稱,這裏使用了sparse_categorical_crossentropy函數,它是一個多類別交叉熵損失函數,對輸入的格式要求是數字編碼的, 而不是one-hot編碼格式。sparse_categorical_crossentropy函數代碼如下所示:

@keras_export('keras.metrics.sparse_categorical_crossentropy',
              'keras.losses.sparse_categorical_crossentropy')
def sparse_categorical_crossentropy(y_true, y_pred, from_logits=False, axis=-1):
  return K.sparse_categorical_crossentropy(
      y_true, y_pred, from_logits=from_logits, axis=axis)

對於compile的最後一個參數,metrics配置爲accuracy,表示普通的準確度評估方法。

訓練和驗證

上面代碼中,運行到模型訓練model.fit(x_train, y_train, epochs=5),生成結果如下所示:

Train on 60000 samples
Epoch 1/5
60000/60000 [==============================] - 8s 126us/sample - loss: 0.2923 - accuracy: 0.9154
Epoch 2/5
60000/60000 [==============================] - 7s 118us/sample - loss: 0.1432 - accuracy: 0.9571
Epoch 3/5
60000/60000 [==============================] - 7s 114us/sample - loss: 0.1062 - accuracy: 0.9681
Epoch 4/5
60000/60000 [==============================] - 8s 133us/sample - loss: 0.0857 - accuracy: 0.9733
Epoch 5/5
60000/60000 [==============================] - 7s 123us/sample - loss: 0.0748 - accuracy: 0.9766
<tensorflow.python.keras.callbacks.History object at 0x118c7d7f0>

最後,根據測試集對模型進行驗證,執行model.evaluate(x_test, y_test, verbose=2),結果如下所示:

10000/1 - 1s - loss: 0.0396 - accuracy: 0.9759
[0.07745860494738445, 0.9759]

可見,分類器識別的準確度爲97.59%。

參考鏈接

本文基於 署名-非商業性使用-相同方式共享 4.0 許可協議發佈,歡迎轉載、使用、重新發布,但務必保留文章署名時延軍(包含鏈接:http://shiyanjun.cn),不得用於商業目的,基於本文修改後的作品務必以相同的許可發佈。如有任何疑問,請與我聯繫。

相關文章