2017年10月12日,AWS與微軟合作發布了Gluon開源項目,該項目旨在幫助開發者更加簡單快速的構建機器學習模型,同時保留了較好的性能。根據Gluon項目官方Github頁面上的描述,Gluon API支持任意一種深度學習框架,其相關規范已經在Apache MXNet項目中實施,開發者只需安裝最新版本的MXNet(master)即可體驗。AWS用戶可以創建一個AWS Deep Learning AMI進行體驗。該頁面提供了一段簡易使用說明,摘錄如下:本教程以一個兩層神經網絡的構建和訓練為例,我們將它稱呼為多層感知機(multilayer perceptron)。(本示范建議使用Python 3.3或以上,并且使用Jupyter notebook來運行。詳細教程可參考這個頁面。)首先,進行如下引用聲明:import mxnet as mx
from mxnet import gluon, autograd, ndarray
import numpy as np
然后,使用gluon.data.DataLoader承載訓練數據和測試數據。這個DataLoader是一個iterator對象類,非常適合處理規模較大的數據集。
train_data = mx.gluon.data.DataLoader(mx.gluon.data.vision.MNIST(train=True, transform=lambda data, label: (data.astype(np.float32)/255, label)), batch_size=32, shuffle=True)test_data = mx.gluon.data.DataLoader(mx.gluon.data.vision.MNIST(train=False, transform=lambda data, label: (data.astype(np.float32)/255, label)), batch_size=32, shuffle=False)接下來,定義神經網絡:
# 先把模型做個初始化net = gluon.nn.Sequential()# 然后定義模型架構with net.name_scope(): net.add(gluon.nn.Dense(128, activation="relu")) # 第一層設置128個節點 net.add(gluon.nn.Dense(64, activation="relu")) # 第二層設置64個節點 net.add(gluon.nn.Dense(10)) # 輸出層然后把模型的參數設置一下:
# 先隨機設置模型參數# 數值從一個標準差為0.05正態分布曲線里面取net.collect_params().initialize(mx.init.Normal(sigma=0.05))# 使用softmax cross entropy loss算法 # 計算模型的預測能力softmax_cross_entropy = gluon.loss.SoftmaxCrossEntropyLoss()# 使用隨機梯度下降算法(sgd)進行訓練# 并且將學習率的超參數設置為 .1trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': .1})之后就可以開始跑訓練了,一共分四個步驟。一、把數據放進去;二、在神經網絡模型算出輸出之后,比較其與實際結果的差距;三、用Gluon的autograd計算模型各參數對此差距的影響;四、用Gluon的trainer方法優化這些參數以降低差距。以下我們先讓它跑10輪的訓練:
epochs = 10for e in range(epochs): for i, (data, label) in enumerate(train_data): data = data.as_in_context(mx.cpu()).reshape((-1, 784)) label = label.as_in_context(mx.cpu()) with autograd.record(): # Start recording the derivatives output = net(data) # the forward iteration loss = softmax_cross_entropy(output, label) loss.backward() trainer.step(data.shape[0]) # Provide stats on the improvement of the model over each epoch curr_loss = ndarray.mean(loss).asscalar() print("Epoch {}. Current Loss: {}.".format(e, curr_loss))若想了解更多Gluon說明與用法,可以查看gluon.mxnet.io這個網站。