当人工智能与人类进行多轮连续对话时,驱动ChatGPT等聊天机器人的强大大型语言机器学习模型有时会崩溃,导致机器人的性能迅速下降。
麻省理工学院和其他地方的研究人员发现了这个问题的一个令人惊讶的原因,并开发了一个简单的解决方案,使聊天机器人能够保持不间断的对话而不崩溃或减慢速度。
他们的方法涉及对许多大型语言模型核心的键值缓存(类似于对话记忆)进行微调。在某些方法中,当此缓存需要保存的信息超过其容量时,首先的数据将被替换。这可能导致模型失败。
通过确保这些最初的几个数据点保留在内存中,研究人员的方法允许聊天机器人在对话时间多长都可以继续聊天。
这种名为StreamingLLM的方法使模型即使在对话超过400万个单词时仍能保持高效。与另一种通过不断重新计算过去对话的部分来避免崩溃的方法相比,StreamingLLM的性能提高了22倍以上。
这可以使聊天机器人在工作日内进行长时间对话而无需不断重新启动,为文案撰写、编辑或生成代码等任务提供高效的人工智能助手。
“现在,有了这种方法,我们可以持续部署这些大型语言模型。通过制作一个我们可以随时聊天并且可以根据我们最近的对话回应我们的聊天机器人,我们可以在一些新的应用中使用这些聊天机器人,”电气工程和计算机科学(EECS)研究生、StreamingLLM论文的第一作者肖光轩说。
肖光轩的合著者包括他的导师、EECS副教授、MIT-IBM Watson AI实验室成员和NVIDIA杰出科学家韩松,以及Meta AI的研究科学家田源东、卡内基梅隆大学助理教授陈北迪和Meta AI的研究科学家迈克·刘易斯。这项工作将在国际学习表示会议上进行展示。
一个令人困惑的现象
大型语言模型将数据(如用户查询中的单词)编码为称为标记的表示。许多模型使用所谓的注意机制,使用这些标记生成新的文本。
通常,AI聊天机器人根据刚刚看到的文本编写新的文本,因此它将最近的标记存储在内存中,称为KV缓存,以供以后使用。注意机制构建了一个包含缓存中所有标记的网格,即“注意力图”,它描述了每个标记(或单词)与其他标记之间的关系强度。
理解这些关系是使大型语言模型生成类似人类的文本的一个特征。
但是当缓存变得非常大时,注意力图可能变得更加庞大,从而减慢计算速度。
此外,如果编码内容所需的标记数量超过缓存的容量,模型的性能会下降。例如,一个流行的模型可以存储4096个标记,但学术论文中大约有10000个标记。
为了解决这些问题,研究人员使用“滑动缓存”,将最旧的标记替换为新的标记。然而,一旦第一个标记被驱逐,模型的性能往往会迅速下降,新生成的单词质量也会迅速降低。
在这篇新论文中,研究人员意识到,如果他们保留滑动缓存中的第一个标记,即使超过缓存大小,模型的性能也会保持不变。
但这没有任何意义。一本小说中的第一个词可能与最后一个词无关,那么为什么第一个词对于模型生成最新词语如此重要呢?
在他们的新论文中,研究人员还发现了这种现象的原因。
注意力汇聚
一些模型在其注意机制中使用Softmax操作,该操作为每个标记分配一个分数,表示它与其他标记的关联程度。Softmax操作要求所有注意分数总和为1。由于大多数标记之间关联不强,它们的注意分数非常低。模型将剩余的注意分数倾倒在第一个标记中。
研究人员将这个第一个标记称为“注意力汇聚点”。
“我们需要一个注意力汇聚点,模型决定使用第一个标记作为注意力汇聚点,因为它是全局可见的,每个其他标记都可以看到它。我们发现我们必须始终将注意力汇聚点保留在缓存中以保持模型的动态性,”韩松说。
在构建StreamingLLM时,研究人员发现在滑动缓存的开头有四个注意力汇聚点可以实现最佳性能。
他们还发现,每个标记的位置编码必须保持不变,即使添加了新的标记并且其他标记被替换。如果标记5被替换,标记6必须保持编码为6,即使它现在是缓存中的第五个标记。
通过结合这两个想法,他们使StreamingLLM能够在保持连续对话的同时超越使用重新计算的流行方法。
例如,当缓存有256个标记时,重新计算方法需要63毫秒来解码一个新的标记,而StreamingLLM只需要31毫秒。然而,如果缓存大小增长到4096个标记,重新计算需要1411毫秒来解码一个新的标记,而StreamingLLM只需要65毫秒。
“StreamingLLM的创新方法围绕着注意力汇聚机制,确保了稳定的内存使用和性能,即使处理长度为400万个标记的文本,”新加坡国立大学计算机科学的青年教授杨友说,“这种能力不仅令人印象深刻,而且具有变革性,使StreamingLLM能够应用于广泛的人工智能应用。StreamingLLM的性能和多功能性使其成为一项极具前景的技术,有望彻底改变我们对基于人工智能生成应用的方法。”
卡内基梅隆大学机器学习和计算机科学系助理教授陈天琪也对这项研究表示赞同,他说:“Streaming LLM使大型语言模型能够顺利扩展对话长度。我们一直在使用它在iPhone上成功部署Mistral模型。”
研究人员还探索了在模型训练过程中使用注意力汇聚点的方法,通过在所有训练样本中添加几个占位符标记。
他们发现,使用注意力汇聚点进行训练可以使模型在其缓存中只有一个注意力汇聚点的情况下保持性能,而通常需要四个注意力汇聚点来稳定预训练模型的性能。
但是,虽然StreamingLLM使模型能够进行连续对话,但模型无法记住不在缓存中的单词。未来,研究人员计划通过研究检索已被驱逐的标记的方法或使模型记住先前对话的方法来解决这个限制。
StreamingLLM已经被纳入NVIDIA的大型语言模型优化库TensorRT-LLM。
这项工作部分由MIT-IBM Watson AI实验室、MIT Science Hub和美国国家科学基金会资助。