关键字:
Iterative RPO 推理能力 大型语言模型 链式推理(CoT) 直接偏好优化(DPO) 负对数似然(NLL) 迭代优化 性能提升 Llama-2模型 AI研究 Meta AI团队 纽约大学
2024年5月1日,纽约 —— Meta的人工智能研究团队(FAIR)与纽约大学合作,宣布了一项突破性的研究进展:Iterative Reasoning Preference Optimization(Iterative RPO),这是一种新开发的算法,旨在显著提高大型语言模型在复杂推理任务上的性能。
在最新的论文《Iterative Reasoning Preference Optimization》中,研究团队详细介绍了Iterative RPO算法,并展示了其在多个标准数据集上的卓越性能。该算法通过迭代优化过程中的偏好选择,特别是针对推理步骤的优化,使得模型能够更准确地生成导致正确答案的推理链(Chain-of-Thought, CoT)。

核心计算:

这个算式结合了两种损失函数,用于训练语言模型以优化推理过程。
1.负对数似然损失 (Negative Log-Likelihood Loss, NLL):

这个部分计算了模型生成正确答案对(赢家对)的负对数似然。这里,Mθ 表示模型θ 下的序列概率,xi 是输入问题,cwi 是赢家(正确答案对应的)推理链,ywi 是赢家答案。序列的对数似然用于评估模型生成特定输出序列的概率,负号表示我们希望最大化这个概率。分母是对序列长度的归一化,确保损失不会因为序列长度不同而产生偏差。
2.直接偏好优化损失 (Direct Preference Optimization Loss, DPO):

这个部分基于赢家对和输家对(赢家答案比输家答案更受偏好)的比较来计算损失。Mt 是前一次迭代中的模型,作为参考模型。σ 是sigmoid函数,用于将输出压缩到0到1之间,表示概率。α 是一个超参数,用于平衡NLL损失和DPO损失之间的权重。β 是用于稳定对数的系数。
整个损失函数L(DPO)+NLL的目的是训练模型,使其更倾向于生成高质量的推理链和正确答案。通过结合NLL损失和DPO损失,模型能够在迭代过程中不断改进其推理能力,最终达到更好的性能。
实验结果:在GSM8K、MATH和ARC-Challenge等数据集上,Iterative RPO在经过多次迭代训练后,准确率分别从55.6%、12.5%和77.8%提升至81.6%、20.8%和86.7%,这一成就在无需额外数据集支持的情况下,超越了其他基于Llama-2模型的方法。
该研究的核心在于使用修改后的直接偏好优化(DPO)损失函数,加上一个关键的负对数似然(NLL)项,这一创新的训练方法使得模型能够更有效地从每次迭代中学习并提升性能。


主要技术元素:
Iterative RPO算法:一种用于提升语言模型推理任务性能的迭代方法。
Chain-of-Thought (CoT):推理过程中生成的逻辑步骤,用于引导至正确答案。
修改后的DPO损失函数:结合了额外的NLL项,对模型进行优化。
负对数似然(NLL):损失函数的一部分,对性能提升至关重要。
迭代训练:通过重复迭代生成新的偏好对并训练模型,直至性能饱和。
性能提升:在GSM8K、MATH和ARC-Challenge数据集上的显著准确率提升。