
StellaAthena wrote

The proof is even simpler: (xW_q)(xW_k)^T = x(W_q W_k^T)x^T = xWx^T, where W = W_q W_k^T.
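A quick numerical check of the identity, as a sketch with NumPy. The toy sizes and random weights are assumptions for illustration only. It also shows that the merged W is d_model x d_model but has rank at most d_head.

```python
import numpy as np

# Hypothetical toy sizes, chosen only for illustration.
n_tokens, d_model, d_head = 5, 16, 4
rng = np.random.default_rng(0)

x = rng.standard_normal((n_tokens, d_model))   # one row per token
W_q = rng.standard_normal((d_model, d_head))   # query projection for one head
W_k = rng.standard_normal((d_model, d_head))   # key projection for one head

# Factored form: project to queries and keys first, then take dot products.
scores_factored = (x @ W_q) @ (x @ W_k).T

# Merged form: fold both projections into a single d_model x d_model matrix.
W = W_q @ W_k.T
scores_merged = x @ W @ x.T

print(scores_factored.shape)                         # (5, 5)
print(np.allclose(scores_factored, scores_merged))   # True
print(W.shape, np.linalg.matrix_rank(W))             # (16, 16) 4
```

The two forms give the same attention scores. W is a full d_model x d_model matrix, yet it only has rank d_head, so storing it explicitly gains nothing.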

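The problem is that W_q and W_k are not square matrices. Each is d_model by d_head, so their product W is d_model x d_model. In practice d_model >> d_head (e.g., they're 4096 and 256 respectively in GPT-J). Doing it your way means storing a d_model x d_model matrix per head instead of two d_model x d_head ones (about 16.8M vs. 2.1M parameters per head at GPT-J sizes, 8x more), and multiplying by it costs more compute as well.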
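A rough cost count for one head, as a sketch. The sequence length of 2048 is an assumed value. The count covers only the score matrix xW_q(xW_k)^T (or xWx^T) in multiply-adds, ignoring softmax, batching, and the value/output projections.

```python
def qk_cost(n_tokens: int, d_model: int, d_head: int) -> dict:
    """Parameter and multiply-add counts for one head's attention score matrix."""
    return {
        # Storing W_q and W_k separately vs. storing their product W.
        "params_factored": 2 * d_model * d_head,
        "params_merged": d_model * d_model,
        # Factored: x W_q and x W_k (2*n*d_model*d_head), then Q K^T (n*n*d_head).
        "macs_factored": 2 * n_tokens * d_model * d_head + n_tokens * n_tokens * d_head,
        # Merged: x W (n*d_model*d_model), then (x W) x^T (n*n*d_model).
        "macs_merged": n_tokens * d_model * d_model + n_tokens * n_tokens * d_model,
    }

# GPT-J sizes from the comment; n_tokens=2048 is an assumed sequence length.
cost = qk_cost(n_tokens=2048, d_model=4096, d_head=256)
print(cost["params_merged"] / cost["params_factored"])   # 8.0
print(cost["macs_merged"] / cost["macs_factored"])       # 9.6
```

Under these assumptions the merged form stores 8x more parameters per head and does about 9.6x more multiply-adds for the scores. The gap grows with d_model / d_head.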
