Submitted by wangyi_fudan t3_y2w87i in MachineLearning
The proof is simple:
attention = softmax(QK^t)V
          = softmax(XWq(XWk)^t)XWv
          = softmax(XWqWk^tX^t)XWv
let Wk' = WkWq^t, then
attention = softmax(X(XWk')^t)XWv
          = softmax(XK'^t)V,  where K' = XWk' and V = XWv
Now we see that Q = XWq has been replaced by X itself: Wq is gone, which removes 1/4 of the parameters in the attention module (one of its four projection matrices).
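A quick numerical check of the identity (my own sketch, with made-up sizes; the 1/sqrt(d) scaling is omitted, as in the derivation above):

```python
import torch

torch.manual_seed(0)
n, d = 8, 16
X = torch.randn(n, d, dtype=torch.float64)
Wq, Wk, Wv = (torch.randn(d, d, dtype=torch.float64) for _ in range(3))

# original single-head attention: softmax(X Wq (X Wk)^t) X Wv
attn_orig = torch.softmax(X @ Wq @ (X @ Wk).T, dim=-1) @ (X @ Wv)

# merged form: fold Wq into the key projection, Wk' = Wk Wq^t
Wk_merged = Wk @ Wq.T
attn_merged = torch.softmax(X @ (X @ Wk_merged).T, dim=-1) @ (X @ Wv)

print(torch.allclose(attn_orig, attn_merged))  # True up to floating-point error
```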
I ran a real experiment and found that with 3/4 of the original attention parameters, the difference in loss is about 0.01 throughout training and does not grow. So Wq is not strictly necessary, though the version with the extra 1/4 of parameters does seem slightly better.
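For reference, a minimal PyTorch sketch of the reduced block as I read it (my own illustration, not OP's code; the class name and the 1/sqrt(d) scaling are my additions):

```python
import math
import torch
import torch.nn as nn

class SingleHeadAttentionNoQ(nn.Module):
    """Single-head attention with Wq folded into the key projection (Wk' = Wk Wq^t)."""
    def __init__(self, d_model: int):
        super().__init__()
        self.wk = nn.Linear(d_model, d_model, bias=False)  # learns the merged Wk'
        self.wv = nn.Linear(d_model, d_model, bias=False)
        self.scale = 1.0 / math.sqrt(d_model)

    def forward(self, x):                                  # x: (batch, seq, d_model)
        k = self.wk(x)                                     # keys, Wq already absorbed
        scores = x @ k.transpose(-2, -1) * self.scale      # X (X Wk')^t / sqrt(d)
        return torch.softmax(scores, dim=-1) @ self.wv(x)
```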
But in multi-head attention, Wq is still necessary: each head's Wq_i and Wk_i are thin d x (d/h) matrices, and merging them into one full d x d matrix per head would cost more parameters than keeping the two factors (see the count below). However, research has shown that stacking many small single-head attention modules into a very deep model works better than a wider multi-head model (single head is enough).
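One way to see why the merge does not pay off with multiple heads is a parameter count (illustrative numbers, my own sketch):

```python
# per head, Wq_i and Wk_i are thin (d x d/h) matrices; their product Wk_i Wq_i^t
# is a full d x d matrix, so storing it merged costs more than the two factors
d, h = 512, 8               # hypothetical model width and head count
d_h = d // h

factored = h * 2 * d * d_h  # keep Wq_i and Wk_i per head: 2 * d^2 in total
merged   = h * d * d        # store each Wk_i Wq_i^t directly: h * d^2 in total

print(factored, merged)     # 524288 vs 2097152 -> merging is h/2 = 4x larger
```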
maizeq t1_is57j81 wrote
Transformers aren’t my field of expertise so I don’t know if this has been done before but hah, neat derivation!
Though I would expect there to be no difference in loss in that case. Was the difference positive or negative? And do you think the difference can be chalked up to numerical precision errors that accumulate due to the double vs. single matrix multiplication? An easy test of this would be to compare K'^t = (XWk')^t with Wq(XWk)^t and see how close they stay throughout training for a particular sample.
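A rough sketch of the precision half of that test (my own, with made-up sizes; the during-training comparison of the learned K' would need the actual models):

```python
import torch

torch.manual_seed(0)
n, d = 128, 512
X64 = torch.randn(n, d, dtype=torch.float64)
Wq64 = torch.randn(d, d, dtype=torch.float64)
Wk64 = torch.randn(d, d, dtype=torch.float64)

ref = X64 @ Wq64 @ (X64 @ Wk64).T              # float64 reference logits

X, Wq, Wk = X64.float(), Wq64.float(), Wk64.float()
two_matmuls = X @ Wq @ (X @ Wk).T              # original Q K^t path
one_matmul  = X @ (X @ (Wk @ Wq.T)).T          # merged Wk' path

# how far each float32 path drifts from the float64 reference
print((two_matmuls.double() - ref).abs().max().item())
print((one_matmul.double() - ref).abs().max().item())
```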