>>> a = torch.randn(4, 4)
>>> a
tensor([[-1.2360, -0.2942, -0.1222,  0.8475],
        [ 1.1949, -1.1127, -2.2379, -0.6702],
        [ 1.5717, -0.9207,  0.1297, -1.8768],
        [-0.6172,  1.0036, -0.6060, -0.2432]])
>>> torch.max(a, 1)
torch.return_types.max(values=tensor([0.8475, 1.1949, 1.5717, 1.0036]), indices=tensor([3, 0, 0, 1]))
