【LLM】ms-Swift大模型训练框架源码分析
ms-Swift整体调用链SFT流程:swift sft → sft_main() → SwiftSft(args).main() → run() → train(trainer) → trainer.train(…)swift 可执行程序(console-script)→ swift/cli/sft.py(CLI 子命令入口,基本就把参数解析后转给 sft_main)→ swift/llm/tr
note
- 上一篇【LLM】基于ms-Swift大模型SFT和RL的训练实践 有介绍通过ms-swift进行SFT和RL训练的流程,比如使用
swift sft
、swift rlhf
等进行训练 - 本文偏向ms-swift训练框架的源码分析
一、ms-Swift整体调用链
SFT流程:swift sft → sft_main() → SwiftSft(args).main() → run() → train(trainer) → trainer.train(…)
swift
可执行程序(console-script)
→ swift/cli/sft.py
(CLI 子命令入口,基本就把参数解析后转给 sft_main)
→ swift/llm/train/sft.py
(核心业务:构造 SwiftSft/Trainer、加载模型&数据、开训)
→ swift/trainers/...
(对 HF Transformers Trainer 的封装与混入)
→ 期间会调用 swift/llm/utils/model.py
(取模型&tokenizer)、swift/llm/template/...
(prompt 模板/打包)、swift/llm/dataset/...
(数据读取与预处理)等。
在多条 GitHub issue 的回溯里能看到这些真实文件路径与函数名,例如:
- CLI 直接就是
swift/cli/sft.py
里sft_main()
:报错堆栈出现“/site-packages/swift/cli/sft.py
line 7: sft_main()”。(GitHub) - 训练主函数是
swift/llm/train/sft.py
,内部return SwiftSft(args).main()
,再进入self.train(trainer)
与trainer.train(...)
;回溯同时出现swift/trainers/mixin.py
(对 HF Trainer 的混入封装)。(GitHub)
如微调sft的代码,ms-swift/swift/llm/train/sft.py
,大家可以具体去看class SwiftSft(SwiftPipeline, TunerMixin)
类的源码,其实就是处理训练数据、走trainer训练模型等的流程,比如其中的run
成员函数:
def run(self):
args = self.args
train_dataset, val_dataset = self._prepare_dataset()
if args.task_type == 'seq_cls':
args.problem_type = args.problem_type or getattr(self.model.config, 'problem_type', None)
logger.info(f'args.problem_type: {args.problem_type}')
args.save_args()
data_collator = self._get_data_collator()
# Some tuners require train_dataset and data_collator for preparation: LoRA-GA
self.model = self.prepare_model(self.args, self.model, template=self.template, train_dataset=train_dataset)
logger.info(f'model: {self.model}')
model_parameter_info = get_model_parameter_info(self.model)
self.train_msg['model_parameter_info'] = model_parameter_info
logger.info(f'model_parameter_info: {model_parameter_info}')
trainer_cls = TrainerFactory.get_trainer_cls(args)
trainer = trainer_cls(
model=self.model,
args=self.args.training_args,
data_collator=data_collator,
train_dataset=train_dataset,
eval_dataset=val_dataset,
callbacks=self.callbacks,
template=self.template,
**self._get_trainer_kwargs(),
)
return self.train(trainer)
如果想看这里的trainer
在框架中的完整选项可以参考ms-swift/swift/trainers/trainer_factory.py
,不过像Seq2SeqTrainer
这种其实用的huggingface-transformers库的,有的trainer是继承基类后做了一些函数重写(多态)/重载:
class TrainerFactory:
TRAINER_MAPPING = {
'causal_lm': 'swift.trainers.Seq2SeqTrainer',
'seq_cls': 'swift.trainers.Trainer',
'embedding': 'swift.trainers.EmbeddingTrainer',
'reranker': 'swift.trainers.RerankerTrainer',
'generative_reranker': 'swift.trainers.RerankerTrainer',
'dpo': 'swift.trainers.DPOTrainer',
'orpo': 'swift.trainers.ORPOTrainer',
'kto': 'swift.trainers.KTOTrainer',
'cpo': 'swift.trainers.CPOTrainer',
'rm': 'swift.trainers.RewardTrainer',
'ppo': 'swift.trainers.PPOTrainer',
'grpo': 'swift.trainers.GRPOTrainer',
'gkd': 'swift.trainers.GKDTrainer',
}
TRAINING_ARGS_MAPPING = {
'causal_lm': 'swift.trainers.Seq2SeqTrainingArguments',
'seq_cls': 'swift.trainers.TrainingArguments',
'embedding': 'swift.trainers.TrainingArguments',
'reranker': 'swift.trainers.TrainingArguments',
'generative_reranker': 'swift.trainers.TrainingArguments',
'dpo': 'swift.trainers.DPOConfig',
'orpo': 'swift.trainers.ORPOConfig',
'kto': 'swift.trainers.KTOConfig',
'cpo': 'swift.trainers.CPOConfig',
'rm': 'swift.trainers.RewardConfig',
'ppo': 'swift.trainers.PPOConfig',
'grpo': 'swift.trainers.GRPOConfig',
'gkd': 'swift.trainers.GKDConfig',
}
二、代码级简化的执行流程图
-
命令入口:
swift
(console_scripts)把sft
子命令路由到swift/cli/sft.py
,该文件基本只负责解析/组合参数 → 调用sft_main(...)
。这点从多条堆栈能直接看到。(GitHub) -
主函数:
swift/llm/train/sft.py:sft_main
-
模型/分布式/LoRA 等准备:
swift.llm.utils.model.get_model_tokenizer(...)
:下载/加载 base 模型、按--train_type
注入 LoRA/QLoRA/DoRA 等、设置 dtype、并行策略(FSDP/DeepSpeed/Megatron 等)、注意力内核(flash/SDPA)等。官方 CLI 参数文档也说明了这些开关。(Swift 文档)
-
数据与模板:
swift.llm.template.template.*
:按--template
或模型默认模板把多轮对话打包(packing、特殊 token、role 映射等)。swift.llm.dataset.*
:按--dataset
或本地 JSONL/Parquet 加载并过滤、tokenize、collate。
(packing 报错的多条 issue 也能侧面印证这个模块链路。)(GitHub)
-
训练执行:
swift.trainers.*
:在 HFTrainer
基础上加了混入(如swift/trainers/mixin.py
),封装了 resume/eval/logging/callback 等。回溯能看到trainers/mixin.py
参与。(GitHub)
-
保存/合并/导出:
- 训练结束保存 adapter/checkpoint;按需
--merge_lora true
合并权重;导出给 vLLM/SGLang/LmDeploy 推理。
- 训练结束保存 adapter/checkpoint;按需
三、训练命令的背后逻辑
- 1)console-scripts 映射:在打包配置里会把
swift
指到某个swift.cli.main:main
或直接到各子命令模块(不同版本可能略有差异)。虽然我们此刻网页没展开到setup.py/cfg
的entry_points
行,但从回溯能确定swift/cli/sft.py
就是子命令入口文件。(GitHub, 参考setup.py
文件) - 2)函数名定位:
sft_main()
/SwiftSft(args).main()
/Trainer.train()
这些名字在回溯里都有,按上面inspect
的单行脚本就能把源码片段打印出来。(GitHub) - 3)参数文档对照:命令行参数如何被“集成参数”吸纳,哪些是基础/原子/集成,官方“Command Line Parameters”页是你读代码时的最佳对照表。(Swift 文档)
更多推荐
所有评论(0)