Attention implementation details

The basic principle of attention

The query $q$ is scored against every key $k_i$ with a dot product; the softmax of these scores gives the weights $w_i$, which are then used to pool the keys (here the keys also act as the values). Writing $K$ for the matrix that stacks the keys $k_i$ as rows:

$$
w_i = \mathrm{softmax}(Kq)_i \\
\mathrm{out} = \sum_i w_i k_i
$$
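
A minimal NumPy sketch of the two formulas above, using made-up toy values (seq_len = 3, emb_len = 2), just to make the computation concrete:

import numpy as np

keys = np.array([[1.0, 0.0],
                 [0.0, 1.0],
                 [1.0, 1.0]])        # k_i stacked as rows, shape [seq_len, emb_len]
query = np.array([2.0, 0.0])         # q, shape [emb_len]

score = keys @ query                               # k_i . q for each i -> [2.0, 0.0, 2.0]
weight = np.exp(score) / np.exp(score).sum()       # w_i = softmax(score)_i
out = (weight[:, None] * keys).sum(axis=0)         # out = sum_i w_i * k_i
print(weight, out)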

TensorFlow implementation details

import tensorflow as tf

def attention(keys, keys_mask, query):
    """
    keys      [None, seq_len, emb_len]
    keys_mask [None, seq_len], 1.0 for valid positions, 0.0 for padding
    query     [None, emb_len]
    """
    # [None, 1, emb_len] so the query broadcasts against every key along seq_len
    query_3d = tf.expand_dims(query, 1)
    # dot-product score of the query against each key: [None, seq_len]
    score = tf.reduce_sum(keys * query_3d, axis=-1)
    # push padded positions towards -inf so their softmax weight is ~0
    score -= 1.0e9 * tf.cast(keys_mask < 0.5, tf.float32)
    # attention weights over the sequence: [None, seq_len]
    weight = tf.nn.softmax(score)
    # weighted sum of the keys (used as the values here): [None, emb_len]
    weight = tf.expand_dims(weight, -1)
    return tf.reduce_sum(keys * weight, axis=1)
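
A quick sanity check of the function above; this sketch assumes TensorFlow 2.x with eager execution, and the tensor values are made up:

# Toy batch: 2 samples, seq_len 3, emb_len 4.
keys = tf.random.normal([2, 3, 4])
keys_mask = tf.constant([[1.0, 1.0, 0.0],   # last position of sample 0 is padding
                         [1.0, 0.0, 0.0]])  # only the first position of sample 1 is valid
query = tf.random.normal([2, 4])

out = attention(keys, keys_mask, query)
print(out.shape)  # (2, 4): one pooled embedding per sample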