目录
- 前言
- 0 基本概念与数据集设计
-
- 0.1 InstructGPT提出的训练三段式
- 0.2 DeepSpeed-Chat的数据集设计
-
- 0.2.1 数据格式基本概念
- 0.2.2 DeepSpeed-Chat的数据读取流
- 0.2.3 关键代码详解
-
- 0.2.3.1 自定义PromptRawDataset类
- 0.2.3.2 阶段数据集处理过程
- 0.3 版块相关问题
- 后续
前言
早些时候微软发布了遵从InstructGPT训练逻辑的训练框架DeepSpeed-Chat,旨在通过良好的DeepSpeed生态降低类ChatGPT模型昂贵的训练成本,为了能更直接地理解有关技术原理,我对其中实现训练相关的代码进行了详细剖析,考虑到目前还没有太多相关文章对此进行过深入介绍,因此我将在本博客中探讨这个框架的实现细节,以帮助有需要的人能更好地理解和使用它。另外,我也非常欢迎大家在评论区分享出自己对这个框架的看法以及使用经验,或是提出对本文的建议。
框架源码地址:https://github/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat
如果你只是想要知道如何使用自定义的数据集快速上手DeepSpeed-Chat训练,该系列内容对你来说可能过于繁杂,那么完全可以期待一下我后续将要更新的快速上手引导(已经在新建文件夹了哈哈)。
如果你只对具体的训练细节感兴趣,该篇主要讲述的数据集内容可能不是你所想要了解的,请直接跳至【中篇】进行阅读。
本文将根据DeepSpeed-Chat的数据集设计以及三个训练阶段(可分别简称phase1、phase2、phase3)共四个部分,将主要内容大体划分为四个版块、三个篇章,而每个版块都会以动态的时序图视角展示一套完整的工作流,然后深入详解其中与InstructGPT所述理论相关的代码实现,最后还将针对相关阶段涉及的一些具体问题加以阐述,以此完成对一个阶段源码的解析。在阅读过程中如果发现某些环节使你产生困惑,不妨跳转至【版块相关问题】,或许可以从中获得启发,如果你无法通过该部分找到所需答案,请随时留下你的评论,以便进行共同交流。
此外,本文的重点在于源码解析,其中所涉及的ChatGPT背景知识、原理等将不再做过多推导式叙述。倘若你有原理基础,那么这篇文章肯定能够让你对各种相关原理如具体的RM结构、RM的训练方式、PPO迭代的具体顺序等实现细节拥有更加深刻的理解,并获得一定的实践启发;但假如你刚开始了解这项技术,也不必担心,我会使用尽可能简练的描述、尽可能直球的例子来对相应的部分进行说明。
本篇为上中下三篇章中的【上篇】,主要针对DeepSpeed的数据集管理进行介绍。DeepSpeed提供了良好的数据流管道对数据集进行了规范化处理和标准化操作,用户在了解其中的细节后可以更加高效地实现模型训练。
0 基本概念与数据集设计
现有的训练框架多数都是基于InstructGPT论文中所介绍的pipeline来实现,但一些更具体的细节,比如数据集的处理、奖励取值设计等,论文中没有进一步阐述,故而不同框架在某些细节的实现上会存在些许差异,因此在开始尝试使用DeepSpeed-Chat前我认为还是有必要了解一些框架内部既定的“范式”,这对后续理解某些具体细节将会有所帮助。
0.1 InstructGPT提出的训练三段式
InstructGPT提出了大型问答模型的训练范式,分别为有监督微调训练、基于人类偏好的奖励模型训练、基于人类偏好的强化学习训练,最终模型将因此具备“根据用户输入,生成符合人类偏好答复”的能力。
阶段 | 相关模型 | 赋能 |
---|---|---|
0 | 具备基本生成能力的基座模型(通常为CausalLM) | - |
1 | 有监督微调模型(SFT) | 使用“prompt-response”数据(通俗含义上的“问答对”)对基座进行训练,基座将获得“根据指令生成出对应响应”的能力。 |
2 | 奖励模型(RM) | 使用具有偏好评价的“prompt-response”数据(问答对)以及排序目标对预选模型进行训练,将得到具备“为指令数据做出人类偏好分值评估”能力的奖励模型。 |
3 | SFT、Actor、RM、Critic | 使用“prompt”数据(通俗含义上的“问句”),以第1、第2阶段训练得到的模型作为基线进行强化学习训练,最终得到具备“根据用户输入,生成符合人类偏好答复”能力的Actor模型。 |
0.2 DeepSpeed-Chat的数据集设计
上述各阶段训练的一个关键就是数据集的设计,每个阶段需要使用的数据格式都不尽相同,DeepSpeed-Chat根据其中存在的共性,将数据集设计成了相对统一的形式,由Dataset类进行统一管理,然后再根据不同的训练阶段细化处理出相应的数据格式。
0.2.1 数据格式基本概念
数据格式名称 | 说明 | 样例 |
---|---|---|
prompt | 对当前情境的描述,为模型生成提供指令输入信息,可以理解为通俗含义上的“问句”,适用于phase3。 | "Human: Please tell me about Microsoft in a few sentence? Assistant: " |
response/answer | 对上下文prompt的响应、回答、应答,可以理解为通俗含义上的“答句”。 | “Microsoft is a software company that develops, licenses, and supports software products,including Windows, Office, and Windows Phone. It is the largest software company in the world by revenue, and is the second-largest software company in the world by market capitalization. Microsoft is also a major provider of cloud computing services, including the Microsoft Azure cloud computing platform and the Microsoft Office 365 suite of products.” |
chosen | 应答中的一种,人类所偏好的应答。 | “Microsoft is a software company that develops, licenses, and supports software products,including Windows, Office, and Windows Phone. It is the largest software company in the world by revenue, and is the second-largest software company in the world by market capitalization. Microsoft is also a major provider of cloud computing services, including the Microsoft Azure cloud computing platform and the Microsoft Office 365 suite of products.” |
rejected | 应答中的一种,人类所排斥的应答。 | “I’m not sure what you mean.” |
conversation | 完整对话,由prompt衔接应答response得到。 | “Human: Please tell me about Microsoft in a few sentence? Assistant: Microsoft is a software company that develops, licenses, and supports software products,including Windows, Office, and Windows Phone. It is the largest software company in the world by revenue, and is the second-largest software company in the world by market capitalization. Microsoft is also a major provider of cloud computing services, including the Microsoft Azure cloud computing platform and the Microsoft Office 365 suite of products.” |
chosen_sentence | 人类偏好的完整对话,由prompt衔接偏好应答chosen得到,适用于phase1和phase2。 | “Human: Please tell me about Microsoft in a few sentence? Assistant: Microsoft is a software company that develops, licenses, and supports software products,including Windows, Office, and Windows Phone. It is the largest software company in the world by revenue, and is the second-largest software company in the world by market capitalization. Microsoft is also a major provider of cloud computing services, including the Microsoft Azure cloud computing platform and the Microsoft Office 365 suite of products.” |
reject_sentence | 人类排斥的完整对话,由prompt衔接排斥应答rejected得到,适用于phase2。 | “Human: Please tell me about Microsoft in a few sentence? Assistant: I’m not sure what you mean.” |
unsup | 无监督语料,符合自然语言要素的文本,适用于自回归语言模型的无监督训练。 | "Wikipedia is a free online encyclopedia that is maintained and edited collaboratively by volunteers from around the world. It contains articles on a wide range of topics, from history and science to popular culture and current events. Anyone can create, edit, or contribute to Wikipedia articles, making it an open and decentralized platform for knowledge sharing and dissemination. One of the key features of Wikipedia is its commitment to neutrality, with contributors striving to present information in an objective and unbiased manner. " |
DeepSpeed-Chat设计的数据格式是直接服务于阶段训练的:
- phase1:采用chosen_sentence作为训练数据,进行自回归语言建模训练。chosen_sentence在通俗含义上代表“有效的问答数据”,有助于模型学习到理解指令并做出正确响应的能力。而reject_sentence作为“相对无效的问答数据”,其响应部分往往是“反人类”的,并不利于模型进行学习,因此在这个阶段采用了chosen_sentence作为训练数据。
- phase2:采用chosen_sentence和reject_sentence作为训练数据,进行成对排序训练(pairwise ranking loss),chosen_sentence和reject_sentence将分别作为成对数据中的较好者和较差者被送入模型中,模型将学习到其中的排序思路,从而给出更为合理的奖励评分。这部分其实与InstructGPT中所述有些差别,InstructGPT是针对同个prompt构造了更多的conversations(如4至7个),通过排列组合的方式,这些conversations将两两组成更多的成对数据被送入模型中进行训练。总的来说,DeepSpeed-Chat与InstructGPT的训练思想是一致的。
- phase3:采用prompt作为基本数据,调用中间模型(Actor、SFT、Critic、RM)根据基本数据构造出经验数据,使用强化学习中的PPO算法进行训练。更具体的内容可见相关代码解析。
- 无监督训练:采用无监督语料数据,进行无监督的自回归语言建模训练。InstructGPT提出,进行phase3的RLHF训练时,为使得模型在学习人类偏好的过程中仍能保有预训练模型解决任务的性能,引入了传统的自回归语言建模进行联合训练。
并且假如你确定要进行完整的三阶段训练,DeepSpeed-Chat鼓励使用可以同时适应三阶段的数据集,即同时具备prompt、chosen、rejected的数据集,他们认为在不同阶段中,使用不同的数据集则面临着数据分布差异问题,尤其是第二、第三阶段使用与第一阶段不同的数据集,这将可能导致第二、第三阶段训练出的模型质量变差。
One thing to note is that: If you plan to only do step 1 SFT, adding more single-response datasets is definitely beneficial. However, if you do plan to do steps 2 and 3, then adding too many single-response datasets during SFT could backfire: these data could be different from the data used for steps 2/3, generating different distributions which could cause training instability/worse model quality during step 2/3. That is part of the reason why we focused on trying the datasets with two responses and the preference, and always split a dataset into all 3 steps.
基于上述情况(三阶段使用同一数据集)考虑,DeepSpeed-Chat在构建数据集时提供了一个叫做“data_split”的传参,当你使用一个适用三阶段的数据集时,通过对该传参进行列表赋值,可以对三阶段数据比例进行设置,如“[6,2,2]”,这表示对于当前数据集,将分配全量数据的 6 6 + 2 + 2 \frac{6}{6+2+2} 6+2+26提供给第一阶段用于训练、验证,同理,分配全量的 2 6 + 2 + 2 \frac{2}{6+2+2} 6+2+22提供给第二阶段用于训练、验证,分配全量的 2 6 + 2 + 2 \frac{2}{6+2+2} 6+2+22提供给第三阶段用于训练、验证。
0.2.2 DeepSpeed-Chat的数据读取流
在此简单讲述UML时序图的元素含义:
- 箭头表示信息传递:实线表示调用,虚线表示返回;
- alt表示假设分支,其后方“[]”中的内容表示“条件”;
- loop表示循环;
- 淡蓝色区域即为高亮部分。
发布评论