Tensor indexing

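1-D indexing (along a single axis): tf.gather

The sessions below use the TensorFlow 1.x API: .eval() needs a default session, for example one opened with tf.InteractiveSession().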

In [31]: tf.gather(tf.constant([0,1,2,3,4]), tf.constant(2)).eval()
Out[31]: 2
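As a rough mental model, tf.gather with the default axis=0 replaces every integer in indices with the matching entry of params. The result therefore has shape indices.shape + params.shape[1:]. The helper below is a hypothetical pure-Python sketch of that rule on nested lists. It is not TensorFlow's implementation.

```python
def gather_axis0(params, indices):
    """Hypothetical sketch of tf.gather(params, indices) with axis=0.

    An int index selects one entry of params; a (possibly nested) list of
    ints selects several, so the output shape is
    indices.shape + params.shape[1:].
    """
    if isinstance(indices, int):
        return params[indices]
    return [gather_axis0(params, i) for i in indices]


print(gather_axis0([0, 1, 2, 3, 4], 2))                 # 2
print(gather_axis0([0, 1, 2, 3, 4], [4, 0]))            # [4, 0]
print(gather_axis0([0, 1, 2, 3, 4], [[1, 2], [3, 3]]))  # [[1, 2], [3, 3]]
```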


In [37]: x = tf.constant([[2, 3], [0, 1], [3, 6]])
In [38]: tf.gather(x, tf.constant(1)).eval()
Out[38]: array([0, 1], dtype=int32)

In [39]: tf.gather(x, tf.constant([1, 2])).eval()
Out[39]:
array([[0, 1],
       [3, 6]], dtype=int32)
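On a 2-D tensor, tf.gather still indexes only axis 0, so each index picks a whole row. The pure-Python sketch below uses a hypothetical helper, gather_axis0, to mirror the session above. The last call highlights a common pitfall: a vector index such as [1, 0] returns rows 1 and 0. It does not return the single element x[1, 0].

```python
def gather_axis0(params, indices):
    # Hypothetical helper mimicking tf.gather(params, indices) with axis=0:
    # every int in `indices` is replaced by the corresponding row of `params`.
    if isinstance(indices, int):
        return params[indices]
    return [gather_axis0(params, i) for i in indices]


x = [[2, 3], [0, 1], [3, 6]]     # shape (3, 2)

print(gather_axis0(x, 1))        # [0, 1]            shape (2,)
print(gather_axis0(x, [1, 2]))   # [[0, 1], [3, 6]]  shape (2, 2)
# Rows 1 and 0, not the element x[1][0]:
print(gather_axis0(x, [1, 0]))   # [[0, 1], [2, 3]]  shape (2, 2)
```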

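Multi-dimensional indexing: tf.gather_nd

tf.gather_nd treats the last dimension of indices as a coordinate into params. A full coordinate such as [1, 0] selects a single element. A partial coordinate such as [1] selects a whole row.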

x = tf.constant([[2, 3], [0, 1], [3, 6]])
tf.gather_nd(x, tf.constant([1, 0])).eval()      # x[1, 0] = 0
tf.gather_nd(x, tf.constant([[1], [2]])).eval()  # [x[1], x[2]] = [[0, 1], [3, 6]]
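The rule can be sketched in pure Python as follows. The innermost list of indices is walked as a (possibly partial) coordinate into params, and every outer level of indices is treated as a batch dimension. For TensorFlow this gives an output shape of indices.shape[:-1] + params.shape[indices.shape[-1]:]. The gather_nd function below is a hypothetical illustration on nested lists, not the library code.

```python
def gather_nd(params, indices):
    """Hypothetical sketch of tf.gather_nd on nested lists.

    The innermost list of `indices` is a coordinate into `params`
    (one int per axis, possibly fewer than params' rank); all outer
    levels of `indices` are batch dimensions.
    """
    if isinstance(indices[0], int):
        # Innermost level: follow the coordinate one axis at a time.
        out = params
        for i in indices:
            out = out[i]
        return out
    return [gather_nd(params, idx) for idx in indices]


x = [[2, 3], [0, 1], [3, 6]]

print(gather_nd(x, [1, 0]))            # 0                 (x[1][0])
print(gather_nd(x, [[1], [2]]))        # [[0, 1], [3, 6]]  (rows 1 and 2)
print(gather_nd(x, [[0, 1], [2, 0]]))  # [3, 3]            (x[0][1], x[2][0])
```

Compared with the tf.gather sketch, the only difference is how far each index reaches. tf.gather stops after axis 0, while tf.gather_nd follows as many axes as the innermost coordinate provides.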