Knowledge distillation in deep learning: A mathematical perspective

Problems such as object detection and speech recognition have applications across many industries, and a great deal of work has gone into improving accuracy on them.

Large models have been built to reach higher accuracy, but they require heavy computation and add latency, which makes them hard to deploy on mobile devices. This motivates the use of small models such as MobileNet and EfficientNet for on-device deployment. On their own, however, such small models usually fall short of the accuracy of the large ones.

Knowledge distillation transfers the knowledge of a cumbersome model, either an ensemble of models or a single very large model, to a small model that is suitable for deployment.

One way to transfer the knowledge is to use the class probabilities produced by the large model as “soft targets” for training the small model. When the soft targets have high entropy, they provide much more information per training case than hard targets.

In the MNIST dataset, one version of a 6 may be given a probability of 10^-6 of being a 3 and 10^-9 of being a 7, whereas for another version it may be the other way around. These relative probabilities are valuable because they indicate which digits resemble each other, but they have very little influence on the cross-entropy cost function during the transfer stage because the probabilities are so close to zero. Caruana and his collaborators [2] circumvent this problem by using the logits (the inputs to the final softmax) rather than the probabilities produced by the softmax as the targets for training the small model: they minimize the squared difference between the logits produced by the cumbersome model and the logits produced by the small model.

Distillation is a more general approach: raising the temperature of the softmax function softens the output distribution, so that the small probabilities of the incorrect classes become large enough to influence training. We will see mathematically that the approach used by Caruana and his collaborators is a special case of distillation.
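To make the effect of temperature concrete, here is a minimal sketch in plain Python. It assumes the usual temperature softmax, q_i = exp(z_i / T) / sum_j exp(z_j / T); the function names and the teacher logits are made up for illustration.

```python
import math

def softmax_with_temperature(logits, T=1.0):
    """Return softmax(logits / T) as a list of probabilities."""
    scaled = [z / T for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def entropy(probs):
    """Shannon entropy in nats."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)

# Hypothetical teacher logits for one input and three classes.
teacher_logits = [10.0, 2.0, -3.0]

hard_like = softmax_with_temperature(teacher_logits, T=1.0)  # nearly one-hot
soft = softmax_with_temperature(teacher_logits, T=5.0)       # softened targets

print("T=1:", hard_like, "entropy:", entropy(hard_like))
print("T=5:", soft, "entropy:", entropy(soft))
```

At T=1 almost all of the mass sits on the first class, while at T=5 the two incorrect classes receive visibly larger probabilities and the entropy rises. The ranking of the classes is unchanged.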

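Class probabilities are generated by the softmax function, which converts the logits z_i into class probabilities q_i:

$$q_i = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)}$$

Here T is the temperature, which is set to 1 by default. To create the soft targets, we raise the temperature of the softmax function. Each case in the transfer set contributes a cross-entropy gradient with respect to each logit z_i of the distilled model. If the cumbersome model has logits v_i which produce soft target probabilities p_i and the transfer training is done at a temperature of T, this gradient is given by:

$$\frac{\partial C}{\partial z_i} = \frac{1}{T}\left(q_i - p_i\right) = \frac{1}{T}\left(\frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)} - \frac{\exp(v_i / T)}{\sum_j \exp(v_j / T)}\right)$$

If the temperature is high compared with the magnitude of the logits, we can approximate exp(x / T) ≈ 1 + x / T, which gives:

$$\frac{\partial C}{\partial z_i} \approx \frac{1}{T}\left(\frac{1 + z_i / T}{N + \sum_j z_j / T} - \frac{1 + v_i / T}{N + \sum_j v_j / T}\right)$$

where N is the number of classes. If we further assume that the logits have been zero-meaned separately for each transfer case, so that \sum_j z_j = \sum_j v_j = 0, this simplifies to:

$$\frac{\partial C}{\partial z_i} \approx \frac{1}{N T^2}\left(z_i - v_i\right)$$

So in the high-temperature limit, and provided the logits are zero-meaned, distillation is equivalent to minimizing the squared difference (1/2)(z_i - v_i)^2 between the logits. The logit matching used by Caruana and his collaborators is therefore a special case of distillation.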
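The high-temperature claim can be checked numerically. The sketch below, in plain Python, compares the exact soft-target gradient (q_i - p_i) / T with the logit-matching form (z_i - v_i) / (N T^2) at several temperatures. The function names and the logits are hypothetical; the logits are chosen so that each set sums to zero, which is the assumption the approximation relies on.

```python
import math

def softmax_with_temperature(logits, T):
    """Return softmax(logits / T) as a list of probabilities."""
    scaled = [z / T for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def exact_gradient(student_logits, teacher_logits, T):
    """Exact gradient of the soft-target cross-entropy: (q_i - p_i) / T."""
    q = softmax_with_temperature(student_logits, T)
    p = softmax_with_temperature(teacher_logits, T)
    return [(qi - pi) / T for qi, pi in zip(q, p)]

def high_temperature_gradient(student_logits, teacher_logits, T):
    """Approximation for zero-mean logits: (z_i - v_i) / (N * T^2)."""
    n = len(student_logits)
    return [(z - v) / (n * T * T)
            for z, v in zip(student_logits, teacher_logits)]

def max_abs_gap(student_logits, teacher_logits, T):
    """Largest absolute difference between the two gradients."""
    exact = exact_gradient(student_logits, teacher_logits, T)
    approx = high_temperature_gradient(student_logits, teacher_logits, T)
    return max(abs(e - a) for e, a in zip(exact, approx))

# Hypothetical zero-mean logits (each list sums to 0).
student_logits = [1.0, -0.5, -0.5]
teacher_logits = [2.0, -1.5, -0.5]

for T in (1.0, 10.0, 100.0):
    print("T =", T, "max gap =", max_abs_gap(student_logits, teacher_logits, T))
```

The gap between the two gradients shrinks as T grows relative to the size of the logits, which is what the derivation predicts. At low temperature the two objectives differ noticeably.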

Knowledge is transferred to the distilled model by training it on a transfer set and using a soft target distribution for each case in the transfer set that is produced by using the cumbersome model with a high temperature in its softmax. The same high temperature is used when training the distilled model, but after it has been trained it uses a temperature of 1.

While training the smaller model, we combine two cost functions. The first is the cross-entropy with the soft targets, computed at the same high temperature that was used to generate them. The second is the cross-entropy between the small model's predictions at T=1 and the ground-truth labels. The second cost function is given a considerably lower weight than the first.
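The two cost functions above can be combined as in the following sketch, written in plain Python for a single training case. The names `distillation_loss` and `alpha`, the default values T=4.0 and alpha=0.9, and the example logits are all assumptions made for illustration. The T**2 factor on the soft term follows the recommendation in [1]: soft-target gradients scale as 1/T^2, so multiplying by T^2 keeps the relative contribution of the two terms roughly stable when the temperature changes.

```python
import math

def softmax_with_temperature(logits, T=1.0):
    """Return softmax(logits / T) as a list of probabilities."""
    scaled = [z / T for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(target_probs, predicted_probs):
    """Cross-entropy H(target, predicted) in nats."""
    return -sum(t * math.log(p)
                for t, p in zip(target_probs, predicted_probs) if t > 0.0)

def distillation_loss(student_logits, teacher_logits, true_label,
                      T=4.0, alpha=0.9):
    """Weighted sum of the soft-target and hard-target cross-entropies.

    T and alpha are hypothetical hyperparameters; alpha close to 1 gives
    the hard-target term the lower weight.
    """
    # Soft term: teacher and student both use the same high temperature.
    soft_targets = softmax_with_temperature(teacher_logits, T)
    student_soft = softmax_with_temperature(student_logits, T)
    soft_loss = cross_entropy(soft_targets, student_soft)

    # Hard term: student at T=1 against the ground-truth class index.
    student_hard = softmax_with_temperature(student_logits, 1.0)
    hard_loss = -math.log(student_hard[true_label])

    return alpha * T * T * soft_loss + (1.0 - alpha) * hard_loss

# Hypothetical logits for one example whose true class is 0.
teacher = [5.0, 1.0, -2.0]
good_student = [5.0, 1.0, -2.0]
bad_student = [-2.0, 1.0, 5.0]

print("good student loss:", distillation_loss(good_student, teacher, 0))
print("bad student loss:", distillation_loss(bad_student, teacher, 0))
```

A student whose logits agree with the teacher and the label gets a lower loss than one that ranks the classes in the opposite order. After training, the student is used with T=1.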

References:

  1. G. Hinton, O. Vinyals, and J. Dean. Distilling the Knowledge in a Neural Network.
  2. C. Buciluǎ, R. Caruana, and A. Niculescu-Mizil. Model compression.

Senior ML Scientist at Hike and PhD Researcher at IIT Delhi.
