What is argmax in machine learning?

Argmax is not a machine learning concept in itself. It is a mathematical operation that is used frequently in applied machine learning: a function that returns the index of the maximum element of a collection along a particular axis.

For example,

Let’s say that we have an array –

Temporary = [10, 30, 5, 6, 1, 20, 15]

Array indices start at 0.

Here,

  • 10 is positioned at index 0.
  • 30 is positioned at index 1.
  • 5 is positioned at index 2.
  • 6 is positioned at index 3.
  • 1 is positioned at index 4.
  • 20 is positioned at index 5.
  • 15 is positioned at index 6.

If we pass this array to an argmax function, the output will be 1, because the maximum element of the array, 30, is at index 1.
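The lookup above can be sketched in plain Python. The helper name `argmax` here is an illustrative assumption, not a built-in function; numerical libraries such as NumPy offer an equivalent (`numpy.argmax`).

```python
def argmax(values):
    """Return the index of the largest element (the first one wins on ties)."""
    if not values:
        raise ValueError("argmax of an empty sequence is undefined")
    # Compare the candidate indices by the value stored at each index.
    return max(range(len(values)), key=lambda i: values[i])


temporary = [10, 30, 5, 6, 1, 20, 15]
index = argmax(temporary)
print(index)             # 1
print(temporary[index])  # 30
```

Note that the function returns the position of the maximum, not the maximum itself; the value is recovered by indexing back into the array.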

Argmax is most commonly used in classification problems in machine learning. Many classification algorithms return a vector of probabilities, one for each class label.

For example,

Consider a multi-class classification problem: predicting whether an animal is a cat, a dog, or a rabbit. For any input, the model returns a vector of probabilities as output.

Assume that according to our data encoding –

  • Class 0 is Cat
  • Class 1 is Dog
  • Class 2 is Rabbit

Suppose the model returns the vector [0.3, 0.1, 0.6]. From this vector, we can read that:

  • The probability of the animal being a cat is 30% (Class 0)
  • The probability of the animal being a dog is 10% (Class 1)
  • The probability of the animal being a rabbit is 60% (Class 2)

If we pass the above output vector to an argmax function, it returns 2, the index with the maximum probability. So, the model's prediction is Rabbit.
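A minimal sketch of this prediction step with NumPy, assuming the model's output is the vector [0.3, 0.1, 0.6] implied by the probabilities listed above. The variable names `class_names` and `probabilities` are illustrative and not part of any particular model's API.

```python
import numpy as np

# Class encoding from the example: index 0 = Cat, 1 = Dog, 2 = Rabbit.
class_names = ["Cat", "Dog", "Rabbit"]

# Assumed model output: one probability per class.
probabilities = np.array([0.3, 0.1, 0.6])

# argmax gives the index of the highest probability ...
predicted_class = int(np.argmax(probabilities))
# ... and the encoding maps that index back to a label.
predicted_label = class_names[predicted_class]

print(predicted_class, predicted_label)  # 2 Rabbit
```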

Here, the prediction does not depend on the exact values in the vector or on the differences among them. All that is needed for the prediction is the class label (index) with the highest probability. In such scenarios, the argmax function is very useful.
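Because only the position of the largest value matters, a confident output and a barely decided one can lead to the same prediction. The sketch below shows this for a small batch of outputs, applying argmax along one axis; the batch values are made up for illustration.

```python
import numpy as np

# Hypothetical model outputs: one row per input, one column per class
# (0 = Cat, 1 = Dog, 2 = Rabbit).
batch = np.array([
    [0.30, 0.10, 0.60],  # clearly Rabbit
    [0.33, 0.33, 0.34],  # barely Rabbit, but the argmax is still 2
    [0.80, 0.15, 0.05],  # Cat
])

# axis=1 takes the argmax across the class columns of each row,
# giving one predicted class index per input.
predictions = np.argmax(batch, axis=1)
print(predictions.tolist())  # [2, 2, 0]
```

With `axis=0` the same call would instead return, for each class column, the row that scores highest, which is why the axis has to match the layout of the data.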