Historically, DNN training has happened in the cloud because of its huge memory cost. Edge platforms used to perform only inference, which makes it difficult for them to adapt to new sensory data. Can we train on the edge so that a device continually improves its predictions? In this work, we enable on-device training under 256KB of SRAM and 1MB of Flash, using less than 1/1000 of the memory of PyTorch while matching its accuracy on the Visual Wake Words application. This lets the model adapt to newly collected sensor data, so users can enjoy customized services without uploading their data to the cloud, which protects their privacy. Details below:
Website: https://tinytraining.mit.edu/
Paper: https://arxiv.org/abs/2206.15472
Demo: https://youtu.be/XaDCO8YtmBw
Code: https://github.com/mit-han-lab/tiny-training
Background
On-device training allows a pre-trained model to adapt to new environments after deployment. By training and adapting locally on the edge, the model can learn to improve its predictions and perform user customization. For example, fine-tuning a language model lets it continually learn from users' typing, and adapting a vision model lets it recognize new objects seen by a mobile camera. Bringing training closer to the sensors also helps protect user privacy when handling sensitive data (e.g., healthcare data, input history).
However, on-device training on tiny edge devices is extremely challenging and fundamentally different from cloud training. Firstly, tiny IoT devices (microcontrollers) typically have limited SRAM, such as 256KB, which is barely enough for inference, let alone training. Secondly, although there are low-cost, efficient transfer learning algorithms, such as training only the final classifier layer or updating only the biases, they suffer a significant accuracy drop, and existing training systems cannot turn their theoretical savings into measured savings. Finally, modern deep learning training frameworks (PyTorch, TensorFlow) are usually designed for cloud servers and require a large memory footprint even when training a small model (MobileNetV2-w0.35) with batch size 1. Therefore, we need to jointly design the algorithm and the system to enable tiny on-device training.
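To make the SRAM constraint concrete, here is a back-of-the-envelope sketch. The layer shape (a 128×128 activation with 16 channels) is a hypothetical example, not a figure from the paper. An inference engine can often release an activation once the next layer has consumed it, whereas back-propagation must keep the input of every updated layer alive until the backward pass computes that layer's weight gradient.

```python
SRAM_BYTES = 256 * 1024  # 256KB SRAM budget from the text


def tensor_bytes(h, w, c, bytes_per_elem):
    """Size in bytes of one H x W x C activation tensor."""
    return h * w * c * bytes_per_elem


# Hypothetical early-layer activation: 128x128 resolution, 16 channels
int8_act = tensor_bytes(128, 128, 16, 1)  # int8, as in a quantized deployment
fp32_act = tensor_bytes(128, 128, 16, 4)  # fp32, as a cloud framework would store it

print(int8_act, int8_act / SRAM_BYTES)  # 262144 1.0
print(fp32_act, fp32_act / SRAM_BYTES)  # 1048576 4.0
```

Under these assumptions, a single int8 activation of this size already fills the whole 256KB budget, and the fp32 copy is four times larger; training has to hold such tensors for every layer it updates.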
Methods & Results
We investigate tiny on-device training and identify two unique challenges. (1) The model is quantized on edge devices, and a real quantized graph (shown below) is difficult to optimize because of its low-precision tensors and its lack of Batch Normalization layers. (2) The limited hardware resources (memory and computation) of tiny devices do not allow full back-propagation, whose memory usage can easily exceed a microcontroller's SRAM by more than an order of magnitude, yet updating only the last layer leads to poor accuracy.
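One common reason a deployed graph has no Batch Normalization layers is that BN is folded into the preceding convolution before quantization and deployment. The sketch below shows that standard folding identity for a toy 1×1 convolution written as a per-channel dot product; the function names and numbers are hypothetical, and this is not claimed to be the exact graph transformation used in this work. Once the BN statistics are merged into the weights and bias, the deployed graph no longer has a BN layer to normalize activations during training.

```python
import math


def conv1x1(weight, bias, x):
    """Toy 1x1 conv on a single pixel: one dot product per output channel."""
    return [sum(w * xi for w, xi in zip(weight[c], x)) + bias[c]
            for c in range(len(weight))]


def batchnorm(y, gamma, beta, mean, var, eps=1e-5):
    """Inference-mode BN using stored running statistics."""
    return [gamma[c] * (y[c] - mean[c]) / math.sqrt(var[c] + eps) + beta[c]
            for c in range(len(y))]


def fold_bn(weight, bias, gamma, beta, mean, var, eps=1e-5):
    """Merge BN into the conv: W' = W * k, b' = (b - mean) * k + beta,
    where k = gamma / sqrt(var + eps) per output channel."""
    new_w, new_b = [], []
    for c in range(len(weight)):
        k = gamma[c] / math.sqrt(var[c] + eps)
        new_w.append([w * k for w in weight[c]])
        new_b.append((bias[c] - mean[c]) * k + beta[c])
    return new_w, new_b


# Hypothetical layer: 3 input channels, 2 output channels
W = [[0.2, -0.1, 0.4], [0.3, 0.5, -0.2]]
b = [0.1, -0.05]
gamma, beta = [1.5, 0.8], [0.0, 0.2]
mean, var = [0.05, -0.1], [0.25, 0.09]
x = [1.0, 2.0, -0.5]

ref = batchnorm(conv1x1(W, b, x), gamma, beta, mean, var)
Wf, bf = fold_bn(W, b, gamma, beta, mean, var)
folded = conv1x1(Wf, bf, x)
# The folded conv reproduces conv + BN up to floating-point error
assert all(abs(r - f) < 1e-9 for r, f in zip(ref, folded))
```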
To cope with the optimization difficulty, we propose Quantization-Aware Scaling (QAS), which automatically scales the gradients of tensors with different bit-precisions (shown on the left below). This keeps the weight-to-gradient norm ratio consistent with floating-point training and stabilizes optimization. Notably, QAS is hyper-parameter free, and it matches the accuracy of the floating-point counterpart (right below).
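The scale mismatch behind QAS can be sketched with per-tensor linear quantization. If W_bar ≈ W / s, the chain rule gives G_Wbar = s · G_W, so the norm ratio ||W_bar|| / ||G_Wbar|| is roughly s^-2 times its floating-point value; rescaling the gradient as G_tilde = G_Wbar · s^-2 restores it. The code below checks this numerically on hypothetical weights and gradients. It covers weight tensors only (bias tensors are omitted), and the helper names are illustrative, not the TTE API.

```python
import math


def l2(v):
    """Euclidean norm of a list of numbers."""
    return math.sqrt(sum(x * x for x in v))


def quantize(w, scale):
    """Symmetric int8 quantization: W_bar = clamp(round(W / s), -127, 127)."""
    return [max(-127, min(127, round(x / scale))) for x in w]


def qas_grad(grad_q, scale):
    """QAS-style rescaling of a weight gradient: G_tilde = G_Wbar * s^-2."""
    return [g / (scale ** 2) for g in grad_q]


# Hypothetical fp32 weights and their gradient (illustrative values only)
w_fp = [0.5, -0.25, 0.1, 0.8]
g_fp = [0.01, 0.02, -0.005, 0.004]

s = max(abs(x) for x in w_fp) / 127  # per-tensor scale
w_q = quantize(w_fp, s)
g_q = [s * g for g in g_fp]          # chain rule: W = s * W_bar => dL/dW_bar = s * dL/dW

ratio_fp = l2(w_fp) / l2(g_fp)
ratio_q = l2(w_q) / l2(g_q)                     # roughly s^-2 * ratio_fp
ratio_qas = l2(w_q) / l2(qas_grad(g_q, s))      # roughly ratio_fp again

print(w_q, ratio_q / ratio_fp, ratio_qas / ratio_fp)
```

Seen from the floating-point side, an SGD step W_bar - lr * G_tilde maps back (multiplying by s) to W - lr * G_W, which is why this toy version needs no extra hyper-parameter.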
To reduce the memory footprint of the full backward computation, we propose Sparse Update, which skips the gradient computation of less important layers and sub-tensors. We developed an automated method based on contribution analysis to find the best update scheme. We compare our searched sparse update schemes with bias-only and last-k-layers updates: sparse update achieves 4.5× to 7.5× memory savings with even higher average accuracy on the 8 downstream datasets.
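The selection step can be framed as a budgeted subset choice: each candidate (updating some biases, or the weights of one layer) has an estimated accuracy contribution and a memory cost, and the goal is the most accurate scheme that fits the budget. The sketch below assumes contributions simply add up and uses made-up profile numbers; it brute-forces all subsets, which is only viable for a handful of options, and stands in for, rather than reproduces, the automated method described above.

```python
from itertools import combinations

# Hypothetical profile: (option name, estimated accuracy gain in %, extra memory in KB)
OPTIONS = [
    ("bias_last_6", 2.0, 10),
    ("weight_layer_a", 1.5, 40),
    ("weight_layer_b", 2.5, 70),
    ("weight_layer_c", 1.0, 20),
    ("weight_layer_d", 3.0, 110),
]


def best_scheme(options, budget_kb):
    """Return (names, gain) of the subset with the largest summed gain
    whose summed memory fits the budget. Assumes additive contributions."""
    best, best_gain = (), 0.0
    for k in range(1, len(options) + 1):
        for subset in combinations(options, k):
            mem = sum(o[2] for o in subset)
            gain = sum(o[1] for o in subset)
            if mem <= budget_kb and gain > best_gain:
                best, best_gain = subset, gain
    return [o[0] for o in best], best_gain


print(best_scheme(OPTIONS, 60))   # tight budget
print(best_scheme(OPTIONS, 150))  # looser budget
```

With these numbers, the tight budget picks the bias update plus one cheap layer, while the looser budget prefers three mid-sized layers over the single most "valuable" but expensive one.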
Finally, we implement these ideas in the Tiny Training Engine (TTE), which offloads automatic differentiation to compile time and uses code generation to minimize runtime overhead. It also supports graph pruning and operator reordering for sparse updates, turning theoretical savings into measured memory savings and speedups. Sparse update reduces peak memory by 7-9× compared with full update, and operator reordering reduces it further, for a total memory saving of 20-21×. Together, the optimized kernels and sparse update speed up training by 23-25× on microcontrollers.
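Operator reordering can be illustrated with a toy accounting of gradient buffers. If the backward pass computes every gradient first and only then runs the optimizer, all gradient buffers are alive at the same time; if each update is applied as soon as its gradient is ready, the buffer can be freed before moving on to an earlier layer. The sizes below are hypothetical, and this model ignores activations and workspace memory.

```python
# Hypothetical gradient sizes (KB) of the updated layers, in backward order (last layer first)
GRAD_KB = [20, 40, 80, 60]


def peak_grad_memory(grad_sizes, reorder):
    """Peak KB held by gradient buffers during one backward pass."""
    live, peak = 0, 0
    for size in grad_sizes:
        live += size          # this layer's gradient is materialized
        peak = max(peak, live)
        if reorder:
            live -= size      # apply the update immediately and free the buffer
    return peak


print(peak_grad_memory(GRAD_KB, reorder=False))  # 200
print(peak_grad_memory(GRAD_KB, reorder=True))   # 80
```

With these numbers, the peak gradient memory becomes the largest single gradient rather than the sum of all of them.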
Conclusions
In this paper, we propose the first solution to enable tiny on-device training on microcontrollers under a tight memory budget of 256KB of SRAM and 1MB of Flash without auxiliary memory. Our algorithm-system co-design significantly reduces training memory (more than 1000× compared with PyTorch and TensorFlow) and per-iteration latency (more than 20× speedup over TensorFlow Lite Micro), allowing us to obtain higher downstream accuracy. Our study suggests that tiny IoT devices can not only perform inference but also continuously adapt to a world that is dynamic rather than static!