K-Means

Given a set of $N$ data points $\left\{\boldsymbol{x}_{1}, \boldsymbol{x}_{2}, \ldots, \boldsymbol{x}_{N}\right\}$, partition the data points into $K$ clusters. We define $\boldsymbol{\mu}_k$ as the prototype (here also the mean) of cluster $k$, and minimize the sum of squared distances of each data point to its closest prototype $\boldsymbol{\mu}_k$:

$$ J=\sum_{n=1}^N \sum_{k=1}^K z_{n k}\left\|\boldsymbol{x}_n-\boldsymbol{\mu}_k\right\|^2 $$

where $\boldsymbol{z}_n$ is a one-hot vector with $z_{n k}=1$ if $k$ is the closest cluster to $\boldsymbol{x}_n$, and $z_{n k}=0$ otherwise.
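
As a concrete reference, here is a minimal NumPy sketch of the objective $J$ (the function name and array layout are illustrative choices, not taken from the text):

```python
import numpy as np

def kmeans_objective(X, mu, z):
    """J = sum over n, k of z_nk * ||x_n - mu_k||^2.

    X  : (N, D) data points
    mu : (K, D) cluster prototypes (means)
    z  : (N, K) one-hot assignment matrix
    """
    # Squared distance of every point to every prototype, shape (N, K)
    sq_dists = ((X[:, None, :] - mu[None, :, :]) ** 2).sum(axis=-1)
    # Only the assigned terms (z_nk = 1) contribute to J
    return float((z * sq_dists).sum())
```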

Optimization algorithm (expectation-maximization (EM) algorithm):
First, the means $\boldsymbol{\mu}_k \in \mathbb{R}^D$ are initialized randomly.
Then repeat until convergence ($\boldsymbol{\mu}_k$ and $z_{n k}$ do not change for any $n$ and $k$):
1. Expectation step: With the means fixed, assign every data point to its closest cluster (this minimizes $J$ with respect to the assignments, since $J$ is linear in $z_{n k}$):

$$ z_{n k}= \begin{cases}1 & \text { if } k=\arg \min _j\left\|\boldsymbol{x}_n-\boldsymbol{\mu}_j\right\|^2 \\ 0 & \text { otherwise }\end{cases} $$

2. Maximization step: With the assignments fixed, recompute each cluster mean:
$$ \frac{\partial J}{\partial \boldsymbol{\mu}_k}=0 \Rightarrow \boldsymbol{\mu}_k=\frac{\sum_n z_{n k} \boldsymbol{x}_n}{\sum_n z_{n k}} $$
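
Putting the two steps together, a minimal NumPy sketch of one K-means run (initializing the means with randomly chosen data points and keeping the old mean for empty clusters are my own choices, not from the text):

```python
import numpy as np

def kmeans(X, K, max_iter=100, rng=None):
    """Minimal K-means loop: random init, then alternate E- and M-steps."""
    rng = np.random.default_rng(rng)
    N, D = X.shape
    # Initialize the K means with K distinct randomly chosen data points
    mu = X[rng.choice(N, size=K, replace=False)]
    z = np.zeros((N, K))
    for _ in range(max_iter):
        # E-step: assign each point to its closest prototype (one-hot z)
        sq_dists = ((X[:, None, :] - mu[None, :, :]) ** 2).sum(axis=-1)  # (N, K)
        z_new = np.zeros_like(z)
        z_new[np.arange(N), sq_dists.argmin(axis=1)] = 1
        # M-step: each prototype becomes the mean of its assigned points
        counts = z_new.sum(axis=0)
        mu = np.where(counts[:, None] > 0,
                      (z_new.T @ X) / np.maximum(counts, 1)[:, None],
                      mu)  # keep the old mean if a cluster is left empty
        # Converged once the assignments stop changing
        if np.array_equal(z_new, z):
            break
        z = z_new
    return mu, z
```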

Every update step in the K-means algorithm either decreases the objective $J$ or leaves it unchanged, and since there are only finitely many possible assignments, the algorithm converges. However, it may converge to a local rather than the global minimum; in practice, perform multiple random restarts and keep the best minimum found (see the sketch below).
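
One possible way to follow the restart advice, reusing the sketches above on some toy data:

```python
import numpy as np

# Toy data; kmeans() and kmeans_objective() are the sketches above
X = np.random.default_rng(0).normal(size=(200, 2))

best_J, best = np.inf, None
for seed in range(10):                      # 10 random restarts
    mu, z = kmeans(X, K=3, rng=seed)
    J = kmeans_objective(X, mu, z)
    if J < best_J:                          # keep the run with the lowest J
        best_J, best = J, (mu, z)
mu_best, z_best = best
```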

Pros:

  • Finds cluster centers that minimize within-cluster variance (a good summary of the data).
  • Simple to implement, widespread applications.

Cons:

  • Assumes every cluster is spherical (isotropic, i.e. the same spread in all directions). The Gaussian Mixture Model generalizes K-means and addresses this issue.
  • Hard membership/assignment (i.e. 1 or 0 membership). Soft variants are available.
  • Prone to local minima.
  • Need to choose $K$ in advance. Heuristics can be used (see Clustering > Determining optimal number of clusters, and the sketch after this list).
  • Can be very slow: each iteration costs $\mathrm{O}(KND)$ for $N$ data points in $D$ dimensions.
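
The notes only point to another section for the choice of $K$, so as an assumed example, one common heuristic is the elbow method, sketched here with the functions from above:

```python
# Assumed elbow heuristic: run K-means for several K and look for the point
# where the objective J stops dropping sharply ("the knee of the curve").
Js = []
for K in range(1, 11):
    mu, z = kmeans(X, K=K, rng=0)
    Js.append(kmeans_objective(X, mu, z))
# J always decreases as K grows; the K after which the decrease flattens out
# is a reasonable choice for the number of clusters.
```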