目录

  • 前言
  • 3 phase-3: RLHF Finetuning
    • 3.1 训练数据样例
      • 3.1 基本数据
      • 3.2 经验数据
    • 3.2 训练过程
    • 3.3 关键代码详解
      • 3.3.1 读取数据集及Datacollator
        • 3.3.1.1 prompt数据集的读取
        • 3.3.1.2 DataCollatorRLHF
        • 3.3.1.3 无监督数据集的读取
      • 3.3.2 初始化各相关模型
        • 3.3.2.1 模型初始化过程
        • 3.3.2.2 DeepSpeedHybridEngine
      • 3.3.3 根据prompt获取经验数据
        • 3.3.3.1 经验数据获取过程
        • 3.3.3.2 seq的生成
        • 3.3.3.3 奖励reward_score和价值估计values的获取
        • 3.3.3.4 策略模型logits的进一步处理
      • 3.3.4 PPO训练数据管理-MiniDataset
      • 3.3.5 PPO训练过程
        • 3.3.5.1 基本流程
        • 3.3.5.2 PPO训练
        • 3.3.5.3 无监督训练
        • 3.3.5.4 EMA
    • 3.4 实例测试
    • 3.5 相关拓展
      • 3.5.1 phase3的参数设置
      • 3.5.2 PPO-ptx训练的迭代数对齐
      • 3.5.3 LMFlow的RAFT
    • 3.6 版块相关问题

前言

  本篇为上中下三篇章的【下篇】,接续自【中篇】。主要针对整个DeepSpeed-Chat框架中最为复杂的第三阶段进行详解,其中涉及到部分InstructGPT所述相关原理的实践,基于其代码将更易于理解诸多原理中一笔带过的话题,如“用于经验采集的策略网络到底是SFT还是Actor”“Critic的迭代又是如何实现的”等等。
  尽管这是讨论DeepSpeed-Chat技术为主的文章,但还是不得不借用ColossalChat(另一个实现了RLHF Pipeline的开源项目,项目地址)绘制的流程示意图(下图),因为它所描绘的第三阶段训练流程,非常详细且与DeepSpeed-Chat的实现过程基本一致,而DeepSpeed-Chat本身给出的示意图实在太过简略(见【中篇】头图)。

  相信结合这张示意图来学习phase3效果更佳。

3 phase-3: RLHF Finetuning

3.1 训练数据样例

3.1 基本数据

数据格式名称 说明 样例
prompt 对当前情境的描述,为模型生成提供指令输入信息,可以理解为通俗含义上的“问句”,适用于phase3 "Human: Please tell me about Microsoft in a few sentence? Assistant: "

3.2 经验数据

数据格式名称 说明 样例
prompt 对当前情境的描述,为模型生成提供指令输入信息,可以理解为通俗含义上的“问句”,适用于phase3。 "Human: Please tell me about Microsoft in a few sentence? Assistant: "(举文本例子是为了便于理解,实际上此处为input_ids)
seq actor基于prompt输入生成的完整对话序列。 "Human: Please tell me about Microsoft in a few sentence? Assistant: Microsoft is a world-renowned company."举文本例子是为了便于理解,实际上此处为input_ids)
logprobs actor基于seq输出的logits/策略对数。 shape: 本应为(seq_bs, max_seq_len, vocab_size),经过gather处理后仅取实际label token的log_logit值,为(seq_bs, max_seq_len, 1)。
ref_logprobs reference/SFT基于seq输出的logits/策略对数。 shape: 本应为(seq_bs, max_seq_len, vocab_size),经过gather处理后仅取实际label token的log_logit值,为(seq_bs, max_seq_len, 1)。
value critic基于seq输出的对序列每个位置的价值评估。 shape: (seq_bs, max_seq_len)
reward reward/RM基于seq输出的对整个对话的(环境)奖励。 shape: (seq_bs,)
attention_mask 用于滤掉非有效元素。 shape: (seq_bs, max_seq_len)

  各个框架对于经验数据的定义不完全相同,例如ColossalChat定义的经验数据还比此处多了项“adv”和“reward”(此reward非彼reward,ColossalChat的reward指的是“经过KL散度修正后的KL_Reward”),但本质上都是同理的,只是框定的范围不同,因为adv(优势函数Adventage)和KL_Reward完全可以由已有项logprobs、ref_logprobs、reward、value计算得到。

  从代码效率的角度来考量,ColossalChat的经验数据定义相对更严谨些,因为优势以及KL惩罚奖励完全可以由基本经验数据计算得到,在生成经验的阶段一步到位计算即可;而DeepSpeed-Chat中将其安排在训练阶段来计算,每次PPO迭代才计算,优势和KL惩罚奖励是基于基本经验数据计算得到的,而基本经验数据在生成经验阶段已经确定了,所以即使是在不同的PPO迭代中,优势和KL惩罚奖励也是不变的,因此DeepSpeed-Chat对adv以及KL惩罚奖励进行了重复计算,这个环节的计算顺序后续(编辑日期2023.05.19)相关团队应该会做出调整。

3.2 训练过程

在此简单讲述UML时序图的元素含义:
- 箭头表示信息传递:实线表示调用,虚线表示返回;
- alt表示假设分支,其后方“[]”中的内容表示“条件”;
- loop表示循环;
- 淡蓝色区域即为高亮部分。
main3.py utils.py data_utils.py rlhf_engine.py ppo_trainer.py load_hf_tokenizer() 1 tokenizer 2 create_dataset() 3 create_prompt_dataset() 4 prompt_train_dataset 5 get_unsupervised_data() 6 unsupervised_train_dataset 7 alt [unsupervised_training_enabled] DataCollatorRLHF() 8 data_collator 9 train_dataloader 10 DeepSpeedRLHFEngine() 11 rlhf_engine 12 ppo_trainer() 13 trainer 14 MiniDataset() 15 exp_mini_dataset, unsup_mini_dataset 16 unsup_mini_dataset.add() 17 unsup_dataset 18 trainer.generate_experience() 19 out 20 exp_mini_dataset.add() 21 exp_dataset 22 trainer.train_rlhf() 23 actor_loss, critic_loss