To calculate attention for a query, this function takes the dot product of the query with the keys to obtain scores/weights for the values. Each score/weight measures the relevance between the query and the corresponding key. The values are then reweighted by these scores/weights, and the reweighted values are summed to produce the output.
This implementation is based on the code made available by Vivien Sainte Fare Garnot at https://github.com/VSainteuf/lightweight-temporal-attention-pytorch.
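The sketch below illustrates this dot-product attention step using the torch package for R. It is a minimal single-head version; the function name dot_product_attention and all tensor sizes are assumptions for illustration, not the package's internal implementation:

library(torch)

# Minimal single-head sketch: weights = softmax(Q K^T / sqrt(d_k)),
# output = weights V. Shapes assume (batch, sequence, features).
dot_product_attention <- function(query, keys, values) {
  d_k <- tail(dim(keys), 1)
  # dot product of the query with each key, scaled by sqrt(d_k)
  scores <- torch_matmul(query, keys$transpose(2, 3)) / sqrt(d_k)
  # softmax turns the scores into weights that sum to one over the keys
  weights <- nnf_softmax(scores, dim = 3)
  # reweight the values and sum them over the keys
  torch_matmul(weights, values)
}

# Toy example: one query attending over four keys/values of dimension 8.
q <- torch_randn(1, 1, 8)
k <- torch_randn(1, 4, 8)
v <- torch_randn(1, 4, 8)
out <- dot_product_attention(q, k, v) # shape (1, 1, 8)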
Usage:
  .torch_multi_head_attention(n_heads, d_k, d_in)

Arguments:
  n_heads  Number of attention heads.
  d_k      Dimension of the key tensor.
  d_in     Dimension of the input values.

Value:
  An output encoder tensor.
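A hypothetical construction call is sketched below; the argument values are illustrative only, not defaults of the package:

# Hypothetical example: attention over 256-dimensional input values,
# using 16 heads with 8-dimensional keys (illustrative values only).
attention <- .torch_multi_head_attention(
  n_heads = 16,
  d_k     = 8,
  d_in    = 256
)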
Authors:
  Charlotte Pelletier, charlotte.pelletier@univ-ubs.fr
  Gilberto Camara, gilberto.camara@inpe.br
  Rolf Simoes, rolf.simoes@inpe.br
  Felipe Souza, lipecaso@gmail.com

References:
  Vivien Sainte Fare Garnot and Loic Landrieu, "Lightweight Temporal
  Self-Attention for Classifying Satellite Image Time Series",
  https://arxiv.org/abs/2007.00586.