Tensor Product Attention Is All You Need

Scaling language models to handle longer input sequences typically requires large key-value (KV) caches, resulting in substantial memory overhead during inference. In this paper, we propose Tensor Product Attention (TPA), a novel attention mechanism that uses tensor decompositions to represent queries, keys, and values compactly, significantly shrinking the KV cache size at inference time. By factorizing these representations into contextual low-rank components (contextual factorization) and integrating seamlessly with RoPE, TPA improves model quality while also improving memory efficiency. Based on TPA, we introduce the Tensor ProducT ATTenTion Transformer (T6), a new model architecture for sequence modeling. Through extensive empirical evaluation on language modeling tasks, we demonstrate that T6 exceeds the performance of standard Transformer baselines, including MHA, MQA, GQA, and MLA, across various metrics, including perplexity and a range of established evaluation benchmarks. Notably, TPA's memory efficiency enables the processing of significantly longer sequences under fixed resource constraints, addressing a critical scalability challenge in modern language models.
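To make the idea of contextual factorization more concrete, the following is a minimal NumPy sketch of one plausible reading of the abstract. It is not the authors' implementation. The abstract does not specify the factor shapes, the rank, the scaling, or how the projections are parameterized, so everything here is an assumption made for illustration: the names `contextual_factors`, `reconstruct`, `W_a`, and `W_b`, the averaging over rank, and the specific sizes. RoPE is omitted entirely. The sketch builds each token's per-head keys as a sum of a few outer products of token-dependent factors and compares the number of values that would need to be cached per token.

```python
import numpy as np

def contextual_factors(x, W_a, W_b, rank, n_heads, d_head):
    """Project each token into `rank` pairs of factors (assumed layout).

    a has shape (T, rank, n_heads): one factor over the head axis.
    b has shape (T, rank, d_head):  one factor over the feature axis.
    Both depend on the token, which is what makes them "contextual".
    """
    T = x.shape[0]
    a = (x @ W_a).reshape(T, rank, n_heads)
    b = (x @ W_b).reshape(T, rank, d_head)
    return a, b

def reconstruct(a, b):
    """Average the rank-many outer products into (T, n_heads, d_head)."""
    rank = a.shape[1]
    return np.einsum("trh,trd->thd", a, b) / rank

# Hypothetical sizes chosen only for the example.
rng = np.random.default_rng(0)
T, d_model, n_heads, d_head, rank = 16, 64, 8, 32, 2

x = rng.standard_normal((T, d_model))
W_a = rng.standard_normal((d_model, rank * n_heads)) / np.sqrt(d_model)
W_b = rng.standard_normal((d_model, rank * d_head)) / np.sqrt(d_model)

a_k, b_k = contextual_factors(x, W_a, W_b, rank, n_heads, d_head)
K = reconstruct(a_k, b_k)  # full keys, materialized only when needed

# Values stored per token for keys alone (values would be analogous).
full_cache_per_token = n_heads * d_head               # 8 * 32 = 256
factored_cache_per_token = rank * (n_heads + d_head)  # 2 * (8 + 32) = 80
```

Under these assumed sizes, caching the factors instead of the full keys stores 80 values per token rather than 256. The saving grows with the number of heads and the head dimension as long as the rank stays small, which is the kind of reduction the abstract refers to.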