Beam Search

Using the model's own predictions from earlier time steps to predict the next output requires exploring a search space

  • This is required when using any sequence-to-sequence model (image captioning, NMT, summarization, ...)
  • Greedy decoding commits to locally optimal decisions, which can lead to even more sub-optimal decisions later on
  • Beam Decoding: explore multiple alternatives in parallel
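The greedy failure mode above can be shown with a tiny two-step example. The probabilities below are hypothetical numbers chosen purely for illustration, not from any real model:

```python
# Toy two-step search over vocabulary {A, B} with hand-picked
# (hypothetical) probabilities that make greedy decoding fail.
# Step 1: p(A)=0.6, p(B)=0.4
# Step 2: given A: p(A)=0.55, p(B)=0.45 ; given B: p(A)=0.9, p(B)=0.1
step1 = {"A": 0.6, "B": 0.4}
step2 = {"A": {"A": 0.55, "B": 0.45}, "B": {"A": 0.9, "B": 0.1}}

# Greedy: commit to the best token at step 1, then the best continuation.
g1 = max(step1, key=step1.get)               # picks "A"
g2 = max(step2[g1], key=step2[g1].get)       # picks "A"
greedy_prob = step1[g1] * step2[g1][g2]      # 0.6 * 0.55 = 0.33

# Exhaustive search finds the true best sequence: "B" looks worse at
# step 1 but enables a much better continuation.
best_seq, best_prob = max(
    (((t1, t2), step1[t1] * step2[t1][t2])
     for t1 in step1 for t2 in step2[t1]),
    key=lambda x: x[1],
)
# ("B", "A") with probability 0.4 * 0.9 = 0.36 > 0.33
```

Beam decoding with beam size 2 would keep both "A" and "B" alive after step 1 and therefore recover the better sequence.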

Example: [Figure: beam decoding process]

  • When do we stop? Limit length to $\lfloor r \cdot m\rfloor$, where $m$ is the length of the source sentence and $r$ is some length ratio, e.g., $r=1.2$
  • All translation candidates $C(Z)$ are compared and the most likely one is returned
    • a hypothesis $\hat{Y}_{\leq t}$ is added to $C(Z)$ if $\hat{y}_{t}=\langle/\mathrm{s}\rangle$ (the end-of-sequence token) or $t=\lfloor r \cdot m\rfloor$
    • candidate score $s\left(\hat{Y}_{\leq t} \mid Z\right)=\frac{1}{t} \log p\left(\hat{Y}_{\leq t} \mid Z\right)$, normalized by length so that candidates of different lengths are comparable
    • return $\hat{Y}=\underset{\hat{Y}_{\leq t} \in C(Z)}{\operatorname{argmax}} s\left(\hat{Y}_{\leq t} \mid Z\right)$
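The stopping criterion, length-normalized scoring, and final argmax above can be sketched as follows. This is a minimal sketch: `step_fn` is a hypothetical stand-in for the model's next-token log-probability distribution (a real decoder would also condition on the source encoding $Z$):

```python
import math

def beam_decode(step_fn, bos, eos, m, beam_size=5, r=1.2):
    """Beam search sketch. step_fn(prefix) -> {token: log_prob} stands
    in for the model's next-token distribution (an assumption). m is
    the source length; decoding stops at t = floor(r * m)."""
    max_len = math.floor(r * m)
    beams = [((bos,), 0.0)]          # (prefix, cumulative log prob)
    candidates = []                  # finished hypotheses C(Z)
    for t in range(1, max_len + 1):
        # expand every active beam by every possible next token
        expansions = []
        for prefix, lp in beams:
            for tok, tok_lp in step_fn(prefix).items():
                expansions.append((prefix + (tok,), lp + tok_lp))
        expansions.sort(key=lambda x: x[1], reverse=True)
        beams = []
        for prefix, lp in expansions[:beam_size]:
            if prefix[-1] == eos or t == max_len:
                # length-normalized score s = (1/t) * log p
                candidates.append((prefix, lp / t))
            else:
                beams.append((prefix, lp))
        if not beams:
            break
    # argmax over all finished candidates
    return max(candidates, key=lambda x: x[1])
```

With beam size 1 this reduces to greedy decoding; larger beams keep more alternatives alive at each step.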

Fast training is important, but fast beam decoding is critical for actual applications

  • Computing $\log p\left(\hat{Y}_{\leq t+1} \mid Z\right)$ from scratch for each $t$ is inefficient
    • although it does not require access to intermediate network activations
  • Better: $\log p\left(\hat{Y}_{\leq t+1} \mid Z\right)=\log p\left(\hat{y}_{t+1} \mid \hat{Y}_{\leq t}, Z\right)+\log p\left(\hat{Y}_{\leq t} \mid Z\right)$
    • requires access to the relevant layer activations from $\mathrm{net}_{t}$ to compute $\mathrm{net}_{t+1}$
    • e.g., to compute $\mathrm{LSTM}_{t+1}$ the activations $\left(\mathbf{h}_{t}, \mathbf{c}_{t}\right)$ have to be exposed (not all packages directly expose LSTM cells)
  • Beam decoding is always slower than training because tokens are generated one at a time, with frequent interactions with the CPU (e.g., sorting and pruning hypotheses)
  • Beam decoding performs better than greedy decoding
    • beam sizes beyond 25 often lead to decreases in quality
    • commonly used beam sizes vary between 5 and 12
  • Beams can be represented as entries (rows) in a batch
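The chain-rule decomposition above can be demonstrated numerically. The conditional probabilities below are hypothetical; the point is that recomputing the full prefix score at every step costs $O(t)$ per step, while carrying the running log probability forward (together with the cached decoder state $(\mathbf{h}_{t}, \mathbf{c}_{t})$) costs $O(1)$:

```python
import math

# Hypothetical conditionals: p(y_1|Z), p(y_2|y_1,Z), p(y_3|Y_<=2,Z)
cond = [0.5, 0.8, 0.9]

# Inefficient: recompute log p(Y_<=t | Z) from scratch at each t.
recomputed = [sum(math.log(p) for p in cond[:t]) for t in range(1, 4)]

# Better: carry the running score forward, one addition per step:
# log p(Y_<=t+1 | Z) = log p(y_t+1 | Y_<=t, Z) + log p(Y_<=t | Z)
running, incremental = 0.0, []
for p in cond:
    running += math.log(p)
    incremental.append(running)

assert all(math.isclose(a, b) for a, b in zip(recomputed, incremental))
```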
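Representing the beams as rows of a batch lets one forward pass score all hypotheses at once. The shapes and the single-matrix "decoder" below are illustrative assumptions, not a real model:

```python
import numpy as np

# Sketch: K active beams share one forward pass by stacking their
# decoder states as rows of a batch. Shapes are illustrative.
K, H, V = 5, 16, 100          # beam size, hidden size, vocab size
rng = np.random.default_rng(0)
W = rng.standard_normal((H, V))        # stand-in output projection

h = rng.standard_normal((K, H))        # one decoder state per beam (row)
logits = h @ W                         # (K, V): one matmul scores all beams

# stable log-softmax per row
m = logits.max(axis=1, keepdims=True)
log_probs = logits - (m + np.log(np.exp(logits - m).sum(axis=1, keepdims=True)))

# pick the K best (beam, token) continuations across all rows
flat = log_probs.ravel()
top = np.argpartition(flat, -K)[-K:]
beam_idx, tok_idx = np.unravel_index(top, log_probs.shape)
```

The same batched layout carries over to the cached LSTM states: $(\mathbf{h}_{t}, \mathbf{c}_{t})$ are stored as $(K, H)$ arrays and re-indexed by `beam_idx` after each pruning step.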
