Keras Callbacks – trợ thủ đắc lực khi train models

Hello tuần mới anh em Mì AI. Hôm nay chúng ta cùng nhau đi tìm hiểu Keras Callbacks – trợ thủ đắc lực khi train models nhé.

Trước khi bắt đầu mình cũng xin nói trước là đây là series dành cho các bạn newbie mới bất đầu thôi. Còn với các bạn làm lâu rồi thì có lẽ khái niệm này đã khá quen thuộc rồi nên các bạn góp ý thêm cho mình nhé.

keras callbacks
Nguồn: Tại đây

Rồi bây giờ tiếp tục nha!

Phần 1 – Keras Callback là cái chi chi?

Nhiều khi train model bị overfit, model kém chất lượng mà chúng ta không biết được chính xác khi nào cần dừng train, khi nào nên lưu lại weights… Đó là lúc chúng ta cần đến các callback function. Tìm hiểu nhé!

Theo như tài liệu chính chủ của Keras thì:

Như vậy giải nghĩa ra thì Callback là một số hàm (function), chúng ta sẽ chạy các hàm này ở bước nhất định của quá trình training để xem, lưu lại các trạng thái và thông số của model trong quá trình training.

Chắc vấn trừu tượng vãi phải không các bạn? Thôi giải thích theo kiểu Mì AI nhé. Các bạn dùng callbacks để can thiệp vào quá trình training của models, ví dụ

  • Stop quá trình training khi model đạt được một kết quả accuracy/loss nhât định nào đó, gọi là Early Stop (Dừng sớm)
  • Lưu lại weights( hoặc model) của model tại một thời điểm nào đó (gọi là checkpoint) khi bạn cho đó là một bộ weights tốt.
  • Điều chỉnh learning rate (LR) của model theo số epochs đã train. Ví dụ lúc đầu có thể để LR cao, sau đó giảm dần….

Còn nhiều hàm callbacks khác nữa, nhưng mình chỉ đi 03 hàm thông dụng bên trên, các bạn có thể tự tìm hiểu thêm ở tài liệu gốc của Keras nhé.

Phần 2 – Tìm hiểu Early Stop Callback

Nói một cách ngắn gọn là Early Stop Callback dùng để dừng quá trình train sớm, sau khi mà thông số chúng ta quan sát (có thể là validation accuracy hay validaiton loss) không “khá lên” sau một vài epochs.

Đây cũng là một kỹ thuật hay được dùng để tránh Overfit (OF). Không phải cứ train với nhiều epochs là tốt, chúng ta train nhiều quá mà không thấy val_loss giảm nữa hoặc val_acc tăng nữa mà vẫn tiếp tục train thì sẽ dẫn đến có khả năng OF.

Bạn nào chưa biết OF là gì thì có thể google thêm. Đại khái là khi đó model quá khớp vào dữ liệu train, tạo ra train acc và train loss rất tốt nhưng kết quả validation thì ngược lại, ngày càng tệ (và cũng là không “khá lên”).

Đó là khi chúng ta nên áp dụng Early Stop Callback.

Giải thích các tham số của EarlyStopping:

  • monitor: Chúng ta sẽ quan sát tham số gì? Như ở đây là val_loss, chúng ta có thể đổi thành : loss, val_accuracy, ….
  • min_delta: mức thay đổi tối thiểu để tham số quan sát ở trên được xem xet là đã “khá lên”. Nếu để min_delta=0 thì chỉ cần có thay đổi tích cực dù nhỏ nhất cũng được gọi là “khá lên” rồi. Còn nếu để min_delta=0.5 chẳng hạn thì tham số phải thay đổi tích cực một lượng 0.5 thì mới được coi là có “khá lên”. Tham số này mặc định là 0 nhé.
  • patience (mặc định = 0): dây là số nguyên (gọi là N). Model sẽ dừng train nếu tham số quan sát không “khá lên” sau N epochs.
  • mode: là chế độ xem xét tham số quan sát. Nếu để là “min” thì tham số quan sát càng giảm càng “khá” (ví dụ loss), nếu để là “max” thì tham số quan sát càng tăng càng “khá” (ví dụ accuracy). Nếu để “auto” là ngon lành nhất, mode sẽ tự chuyển thành min hoặc max tuỳ vào tham số của chúng ta là loss hay accuracy.
  • restore_best_weights: Tham số này quy định rặng model có restore lại weights tốt nhất khi stop hay không? Ví dụ, khi train đến epoch 50 thì model nhận thấy val_loss thấp nhất, và sau khi train 3 epoch 51,52,53 thì thấy val_loss bắt đầu giảm (model kém đi) nên EarlyStopping Callback ra lệnh “Dừng anh em ơi!”. Lúc này nếu ta để restore_best_weights=True thì model sẽ restore lại weights ở epoch 50 xong mới stop, còn ngược lại thì nó sẽ giữ nguyên weights ở bước 53.

Ví dụ chúng ta train một model và muốn rằng sau 5 epoch mà loss không giảm nữa thì stop và lấy weights tốt nhất để lưu ra file model. Ta sẽ viết source như sau:

Và chúng ta được kết quả:

Chỉ có 06 epoch được train và model đã stop vì loss không giảm nữa. Thực ra bài này loss tăng vê lờ luôn anh em. Thử nghiệm để anh em hiểu như nào là Early Stop thôi mà.

Anh em có thể tham khảo rõ hơn qua đồ thị sau:

keras callbacks
Nguồn: Tại đây

Phần 3 – Checkpoint Keras Callbacks

Với Checkpoint callbacks thì khá đơn giản, nó chỉ đơn giản làm nhiệm vụ lưu lại bộ weights tốt nhất cho chúng ta. Chúng ta quan tâm 02 thứ ở đây:

  • Thứ nhất, hư nào là weights tốt? Câu trả lời là tuỳ ta chọn tham số nào để quan sát (val_loss, val_accuracy…) như ở EarlyStop. Loss thì càng thấp càng tốt, accuracy thì càng cao càng tốt.
  • Thứ hai, checkpoint callback sẽ được gọi thực thi khi nào? Câu trả lời là nó được gọi thực thi sau mỗi epoch. Khi một epoch kết thúc nó sẽ kiểm tra xem bộ weights hiện tại có được goi là “tốt nhất” không? Nếu có nó sẽ được lưu lại.
Nguồn: Tại đây

Nào, cùng xem cái hàm Checkpoint mặt mũi nó ra làm sao nào:

Các tham số của keras callbacks này như sau:

  • filepath: Đường dẫn lưu file weights/model.
  • monitor: tham số quan sát, tương tự như ở trên Early Stop.
  • save_best_only: Nếu để là true thì model chỉ lưu lại một checkpoint tốt nhất và ngược lại. Cái này chúng ta nên để là true.
  • save_weights_only: Nếu để là True thì callback sẽ chỉ lưu lại weights, không lưu lại cấu trúc model, còn ngược lại thì sẽ lưu cả kiến trúc model trong filepath. Các bạn tuỳ ý sử dụng nhé.
  • mode: Tương tự như ở EarlyStop, các bạn kéo lên xem nha.
  • save_freq: Nếu để là “epoch” thì model sẽ kiểm tra tham số quan sát sau mỗi epoch để từ đó đánh giá xem weights hiẹn tại có “tốt nhất” (để lưu lại) hay không? Còn nếu là số nguyên dương N thì model sẽ kiểm tra sau N batches.

Ví dụ một chút cho dễ hiểu nhé. Bây giờ chúng ta có bài toán dự đoán xem một người có bị tiểu đường hay không dựa vào các tham số đo được của người đó. Mình dùng dữ liệu từ Kaggle (https://www.kaggle.com/uciml/pima-indians-diabetes-database) nhé.

Và sau khi train 100 epochs thì chúng ta sẽ thấy xuất hiện một file checkpoint.hdf5 trong thư mục hiện tại. Đó chính là file lưu lại weights tốt nhất của chúng ta. Và nếu quan sát log in ra, các bạn sẽ thấy:

Như vậy val_accuracy tốt nhất là 0.74675 và đã được lưu lại! Chúng ta cũng có thể nhét thêm các thông số như: epoch, giá trị loss, accuracy trong tên file check point nhé. Ví dụ:

Rồi, như vậy có thể train 100 hay 1000 epochs đi chăng nữa thì weights tốt nhất luôn được lưu lại cho chúng ta! Yên tâm nhé!

Phần 4 – LearningRateScheduler keras callbacks

Cái món keras callbacks này thì đơn giản lắm luôn, nó chỉ có một nhiệm vụ điều chỉnh Learning Rate (LR) cho quá trình train. Bạn nào chưa biết LR là gì thì đọc đây nhé!

keras callbacks learning rate
Nguồn: Tại đây

Cấu trúc của món này khá đơn giản:

Trong đó schedule là tên hàm sẽ trả về LR.

Ví dụ chính chủ của Keras luôn:

Và kết quả đây:

Đó, chúng ta thấy rõ sự thay đổi của LR qua quá trình train. Ý đố ở đây là lúc đầu mới train thì ta cần LR lớn để nhanh chóng tìm đến điểm cực tiểu, nhưng khi đã gần điểm cực tiểu rồi thì ta cần giảm LR để nhanh chóng dừng lại tránh “đi quá” điểm cực tiểu cần tìm.

Rồi vậy là mình đã cùng các bạn đi qua vài keras callbacks function thông dụng trong train model. Các bạn hãy vận dụng vào thực tế để train model tốt hơn và nhanh hơn nhé.

Các bạn có thể tham khảo source full tại đây: https://github.com/thangnch/MiAI_Keras_Callback

Chúc các bạn thành công!

#MìAI

Fanpage: http://facebook.com/miaiblog
Group trao đổi, chia sẻ: https://www.facebook.com/groups/miaigroup
Website: https://miai.vn
Youtube: http://bit.ly/miaiyoutube

Bài viết tham khảo: Tại đây

Nguyễn Chiến Thắng

Một người đam mê những điều mới mẻ và công nghệ hiện đại. Uớc mơ cháy bỏng dùng AI, ML để làm cho cuộc sống tốt đẹp hơn! Liên hệ: thangnch@gmail.com hoặc facebook.com/thangnch

Related Post

2 Replies to “Keras Callbacks – trợ thủ đắc lực khi train models”

  1. Hay anh ơi, em mới tìm hiểu xong nhưng anh lại đăng bài này nên em được củng cố thêm

Leave a Reply

Your email address will not be published. Required fields are marked *