Keras Callbacks – trợ thủ đắc lực khi train models

Hello tuần mới anh em Mì AI. Hôm nay chúng ta cùng nhau đi tìm hiểu Keras Callbacks – trợ thủ đắc lực khi train models nhé.

Trước khi bắt đầu mình cũng xin nói trước là đây là series dành cho các bạn newbie mới bất đầu thôi. Còn với các bạn làm lâu rồi thì có lẽ khái niệm này đã khá quen thuộc rồi nên các bạn góp ý thêm cho mình nhé.

keras callbacks
Nguồn: Tại đây

Rồi bây giờ tiếp tục nha!

Phần 1 – Keras Callback là cái chi chi?

Nhiều khi train model bị overfit, model kém chất lượng mà chúng ta không biết được chính xác khi nào cần dừng train, khi nào nên lưu lại weights… Đó là lúc chúng ta cần đến các callback function. Tìm hiểu nhé!

Theo như tài liệu chính chủ của Keras thì:

A callback is a set of functions to be applied at given stages of the training procedure. You can use callbacks to get a view on internal states and statistics of the model during training.
Code language: JavaScript (javascript)

Như vậy giải nghĩa ra thì Callback là một số hàm (function), chúng ta sẽ chạy các hàm này ở bước nhất định của quá trình training để xem, lưu lại các trạng thái và thông số của model trong quá trình training.

Chắc vấn trừu tượng vãi phải không các bạn? Thôi giải thích theo kiểu Mì AI nhé. Các bạn dùng callbacks để can thiệp vào quá trình training của models, ví dụ

  • Stop quá trình training khi model đạt được một kết quả accuracy/loss nhât định nào đó, gọi là Early Stop (Dừng sớm)
  • Lưu lại weights( hoặc model) của model tại một thời điểm nào đó (gọi là checkpoint) khi bạn cho đó là một bộ weights tốt.
  • Điều chỉnh learning rate (LR) của model theo số epochs đã train. Ví dụ lúc đầu có thể để LR cao, sau đó giảm dần….

Còn nhiều hàm callbacks khác nữa, nhưng mình chỉ đi 03 hàm thông dụng bên trên, các bạn có thể tự tìm hiểu thêm ở tài liệu gốc của Keras nhé.

Phần 2 – Tìm hiểu Early Stop Callback

Nói một cách ngắn gọn là Early Stop Callback dùng để dừng quá trình train sớm, sau khi mà thông số chúng ta quan sát (có thể là validation accuracy hay validaiton loss) không “khá lên” sau một vài epochs.

Đây cũng là một kỹ thuật hay được dùng để tránh Overfit (OF). Không phải cứ train với nhiều epochs là tốt, chúng ta train nhiều quá mà không thấy val_loss giảm nữa hoặc val_acc tăng nữa mà vẫn tiếp tục train thì sẽ dẫn đến có khả năng OF.

Bạn nào chưa biết OF là gì thì có thể google thêm. Đại khái là khi đó model quá khớp vào dữ liệu train, tạo ra train acc và train loss rất tốt nhưng kết quả validation thì ngược lại, ngày càng tệ (và cũng là không “khá lên”).

Đó là khi chúng ta nên áp dụng Early Stop Callback.

tf.keras.callbacks.EarlyStopping( monitor="val_loss", min_delta=0, mode="auto", restore_best_weights=False, )
Code language: PHP (php)

Giải thích các tham số của EarlyStopping:

  • monitor: Chúng ta sẽ quan sát tham số gì? Như ở đây là val_loss, chúng ta có thể đổi thành : loss, val_accuracy, ….
  • min_delta: mức thay đổi tối thiểu để tham số quan sát ở trên được xem xet là đã “khá lên”. Nếu để min_delta=0 thì chỉ cần có thay đổi tích cực dù nhỏ nhất cũng được gọi là “khá lên” rồi. Còn nếu để min_delta=0.5 chẳng hạn thì tham số phải thay đổi tích cực một lượng 0.5 thì mới được coi là có “khá lên”. Tham số này mặc định là 0 nhé.
  • patience (mặc định = 0): dây là số nguyên (gọi là N). Model sẽ dừng train nếu tham số quan sát không “khá lên” sau N epochs.
  • mode: là chế độ xem xét tham số quan sát. Nếu để là “min” thì tham số quan sát càng giảm càng “khá” (ví dụ loss), nếu để là “max” thì tham số quan sát càng tăng càng “khá” (ví dụ accuracy). Nếu để “auto” là ngon lành nhất, mode sẽ tự chuyển thành min hoặc max tuỳ vào tham số của chúng ta là loss hay accuracy.
  • restore_best_weights: Tham số này quy định rặng model có restore lại weights tốt nhất khi stop hay không? Ví dụ, khi train đến epoch 50 thì model nhận thấy val_loss thấp nhất, và sau khi train 3 epoch 51,52,53 thì thấy val_loss bắt đầu giảm (model kém đi) nên EarlyStopping Callback ra lệnh “Dừng anh em ơi!”. Lúc này nếu ta để restore_best_weights=True thì model sẽ restore lại weights ở epoch 50 xong mới stop, còn ngược lại thì nó sẽ giữ nguyên weights ở bước 53.

Ví dụ chúng ta train một model và muốn rằng sau 5 epoch mà loss không giảm nữa thì stop và lấy weights tốt nhất để lưu ra file model. Ta sẽ viết source như sau:

import tensorflow as tf import numpy as np # Thiết lập hàm call back Early Stopping callback = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=5, restore_best_weights=True) # Tạo lập model model = tf.keras.models.Sequential([tf.keras.layers.Dense(512)]) model.compile(tf.keras.optimizers.SGD(), loss='mse') # Sinh dữ liệu fake để train thử X = np.arange(100).reshape(5, 20).astype(float) y = np.zeros(5) # Train model 10 epochs, batch_sizes = 1 history = model.fit( X, y , epochs=10, batch_size=1, callbacks=[callback]) # In history ra xem train mấy epochs print("Số epoch đã train = ", len(history.history['loss']))
Code language: PHP (php)

Và chúng ta được kết quả:

5/5 [==============================] - 0s 37ms/sample - loss: 543.4534 Epoch 2/10 5/5 [==============================] - 0s 766us/sample - loss: 52498.2139 Epoch 3/10 5/5 [==============================] - 0s 982us/sample - loss: 1131009.8172 Epoch 4/10 5/5 [==============================] - 0s 934us/sample - loss: 1418184.8703 Epoch 5/10 5/5 [==============================] - 0s 697us/sample - loss: 129574012.1000 Epoch 6/10 5/5 [==============================] - 0s 3ms/sample - loss: 179616147.2000 Số epoch đã train = 6

Chỉ có 06 epoch được train và model đã stop vì loss không giảm nữa. Thực ra bài này loss tăng vê lờ luôn anh em. Thử nghiệm để anh em hiểu như nào là Early Stop thôi mà.

Anh em có thể tham khảo rõ hơn qua đồ thị sau:

keras callbacks
Nguồn: Tại đây

Phần 3 – Checkpoint Keras Callbacks

Với Checkpoint callbacks thì khá đơn giản, nó chỉ đơn giản làm nhiệm vụ lưu lại bộ weights tốt nhất cho chúng ta. Chúng ta quan tâm 02 thứ ở đây:

  • Thứ nhất, hư nào là weights tốt? Câu trả lời là tuỳ ta chọn tham số nào để quan sát (val_loss, val_accuracy…) như ở EarlyStop. Loss thì càng thấp càng tốt, accuracy thì càng cao càng tốt.
  • Thứ hai, checkpoint callback sẽ được gọi thực thi khi nào? Câu trả lời là nó được gọi thực thi sau mỗi epoch. Khi một epoch kết thúc nó sẽ kiểm tra xem bộ weights hiện tại có được goi là “tốt nhất” không? Nếu có nó sẽ được lưu lại.
Nguồn: Tại đây

Nào, cùng xem cái hàm Checkpoint mặt mũi nó ra làm sao nào:

tf.keras.callbacks.ModelCheckpoint( filepath, monitor="val_loss", verbose=0, save_best_only=False, save_weights_only=False, mode="auto", save_freq="epoch" )
Code language: PHP (php)

Các tham số của keras callbacks này như sau:

  • filepath: Đường dẫn lưu file weights/model.
  • monitor: tham số quan sát, tương tự như ở trên Early Stop.
  • save_best_only: Nếu để là true thì model chỉ lưu lại một checkpoint tốt nhất và ngược lại. Cái này chúng ta nên để là true.
  • save_weights_only: Nếu để là True thì callback sẽ chỉ lưu lại weights, không lưu lại cấu trúc model, còn ngược lại thì sẽ lưu cả kiến trúc model trong filepath. Các bạn tuỳ ý sử dụng nhé.
  • mode: Tương tự như ở EarlyStop, các bạn kéo lên xem nha.
  • save_freq: Nếu để là “epoch” thì model sẽ kiểm tra tham số quan sát sau mỗi epoch để từ đó đánh giá xem weights hiẹn tại có “tốt nhất” (để lưu lại) hay không? Còn nếu là số nguyên dương N thì model sẽ kiểm tra sau N batches.

Ví dụ một chút cho dễ hiểu nhé. Bây giờ chúng ta có bài toán dự đoán xem một người có bị tiểu đường hay không dựa vào các tham số đo được của người đó. Mình dùng dữ liệu từ Kaggle (https://www.kaggle.com/uciml/pima-indians-diabetes-database) nhé.

from keras.models import Sequential from keras.layers import Dense from keras.callbacks import ModelCheckpoint import numpy as np # Load dữ liệu dataset = np.loadtxt("pima-indians-diabetes.data.txt", delimiter=",") # Chia ra input X và output y X = dataset[:,0:8] Y = dataset[:,8] # Tạo model model = Sequential() model.add(Dense(32, input_dim=8, activation='relu')) model.add(Dense(16, activation='relu')) model.add(Dense(1, activation='sigmoid')) model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy']) # Tạo callback filepath="checkpoint.hdf5" callback = ModelCheckpoint(filepath, monitor='val_accuracy', verbose=1, save_best_only=True, mode='auto') # Train model model.fit(X, Y, validation_split=0.2, epochs=100, batch_size=8, callbacks=[callback])
Code language: PHP (php)

Và sau khi train 100 epochs thì chúng ta sẽ thấy xuất hiện một file checkpoint.hdf5 trong thư mục hiện tại. Đó chính là file lưu lại weights tốt nhất của chúng ta. Và nếu quan sát log in ra, các bạn sẽ thấy:

Epoch 00096: val_accuracy did not improve from 0.74675 Epoch 97/100 614/614 [==============================] - 0s 112us/step - loss: 0.4925 - accuracy: 0.7655 - val_loss: 0.6825 - val_accuracy: 0.6948 Epoch 00097: val_accuracy did not improve from 0.74675 Epoch 98/100 614/614 [==============================] - 0s 107us/step - loss: 0.5555 - accuracy: 0.7508 - val_loss: 0.8247 - val_accuracy: 0.6558 Epoch 00098: val_accuracy did not improve from 0.74675 Epoch 99/100 614/614 [==============================] - 0s 112us/step - loss: 0.5101 - accuracy: 0.7573 - val_loss: 0.7197 - val_accuracy: 0.6948 Epoch 00099: val_accuracy did not improve from 0.74675 Epoch 100/100 614/614 [==============================] - 0s 110us/step - loss: 0.4815 - accuracy: 0.7687 - val_loss: 0.6195 - val_accuracy: 0.7273 Epoch 00100: val_accuracy did not improve from 0.74675
Code language: JavaScript (javascript)

Như vậy val_accuracy tốt nhất là 0.74675 và đã được lưu lại! Chúng ta cũng có thể nhét thêm các thông số như: epoch, giá trị loss, accuracy trong tên file check point nhé. Ví dụ:

filepath="checkpoint-{epoch}-{val_accuracy}.hdf5"
Code language: JavaScript (javascript)

Rồi, như vậy có thể train 100 hay 1000 epochs đi chăng nữa thì weights tốt nhất luôn được lưu lại cho chúng ta! Yên tâm nhé!

Phần 4 – LearningRateScheduler keras callbacks

Cái món keras callbacks này thì đơn giản lắm luôn, nó chỉ có một nhiệm vụ điều chỉnh Learning Rate (LR) cho quá trình train. Bạn nào chưa biết LR là gì thì đọc đây nhé!

keras callbacks learning rate
Nguồn: Tại đây

Cấu trúc của món này khá đơn giản:

tf.keras.callbacks.LearningRateScheduler(schedule, verbose=0)

Trong đó schedule là tên hàm sẽ trả về LR.

Ví dụ chính chủ của Keras luôn:

import tensorflow as tf import numpy as np tf.enable_eager_execution() # Định nghĩa hàm trả về LR def scheduler(epoch, lr): # Nếu dưới 5 epoch if epoch < 5: # Trả về lr return float(lr) else: # Còn không thì trả về return float(lr * tf.math.exp(-0.1)) # Định nghĩa model model = tf.keras.models.Sequential([tf.keras.layers.Dense(10)]) model.compile(tf.keras.optimizers.SGD(), loss='mse') print("Learning rate ban đầu = ", round(model.optimizer.lr.numpy(), 5)) # Train model và xem learning rate callback = tf.keras.callbacks.LearningRateScheduler(scheduler) X = np.arange(100).reshape(5, 20).astype(float) y = np.zeros(5) history = model.fit( X, y , epochs=8, callbacks=[callback], verbose=1) print("Learning rate sau khi train xong 8 epochs = ", round(model.optimizer.lr.numpy(), 5))
Code language: PHP (php)

Và kết quả đây:

Learning rate ban đầu = 0.01 Train on 5 samples Epoch 1/8 5/5 [==============================] - 1s 186ms/sample - loss: 3805.0991 Epoch 2/8 5/5 [==============================] - 0s 208us/sample - loss: 64118804.0000 Epoch 3/8 5/5 [==============================] - 0s 243us/sample - loss: 1083885420544.0000 Epoch 4/8 5/5 [==============================] - 0s 197us/sample - loss: 18322360549507072.0000 Epoch 5/8 5/5 [==============================] - 0s 188us/sample - loss: 309727253441802141696.0000 Epoch 6/8 5/5 [==============================] - 0s 175us/sample - loss: 5235732452911954394087424.0000 Epoch 7/8 5/5 [==============================] - 0s 211us/sample - loss: 72345870604726808312717246464.0000 Epoch 8/8 5/5 [==============================] - 0s 215us/sample - loss: 816984177027815951967181258883072.0000 Learning rate sau khi train xong 8 epochs = 0.00741

Đó, chúng ta thấy rõ sự thay đổi của LR qua quá trình train. Ý đố ở đây là lúc đầu mới train thì ta cần LR lớn để nhanh chóng tìm đến điểm cực tiểu, nhưng khi đã gần điểm cực tiểu rồi thì ta cần giảm LR để nhanh chóng dừng lại tránh “đi quá” điểm cực tiểu cần tìm.

Rồi vậy là mình đã cùng các bạn đi qua vài keras callbacks function thông dụng trong train model. Các bạn hãy vận dụng vào thực tế để train model tốt hơn và nhanh hơn nhé.

Các bạn có thể tham khảo source full tại đây: https://github.com/thangnch/MiAI_Keras_Callback

Chúc các bạn thành công!

#MìAI

Fanpage: http://facebook.com/miaiblog
Group trao đổi, chia sẻ: https://www.facebook.com/groups/miaigroup
Website: https://miai.vn
Youtube: http://bit.ly/miaiyoutube

Bài viết tham khảo: Tại đây

Related Post

5 Replies to “Keras Callbacks – trợ thủ đắc lực khi train models”

  1. Hay anh ơi, em mới tìm hiểu xong nhưng anh lại đăng bài này nên em được củng cố thêm

  2. a ơi cho e hỏi e dùng Checkpoint Keras Callbacks thì khi muốn sử dụng lại weight thì gọi như nào ạ

Leave a Reply

Your email address will not be published. Required fields are marked *