ChatGLM模型(服务器部署+微调)
chatglm在Linux上的模型部署与微调经验
·
一、ChatGLM-6B在Linux上部署运行
ChatGLM官方源码github地址
模型huggface地址
1.部署
1.1 部署前要先安装anaconda,我是在ubuntu上安装的,具体可参考这篇博客。
1.2 下面就开始ChatGLM-6B的正式部署了。
# 下载模型代码
git clone https://github.com/THUDM/ChatGLM-6B
# 切换到项目文件夹下
cd ChatGLM-6B
# 新建chatglm环境
conda create -n chatglm python==3.8
# 激活chatglm环境
conda activate chatglm
# 安装运行依赖
pip install -r requirement.txt
#命令行 Demo
python3 cli_demo.py
2.微调
# 除 ChatGLM-6B 的依赖之外,还需要安装以下依赖,在上面创建的chatglm环境下安装
!pip install rouge_chinese nltk jieba datasets -i https://mirror.sjtu.edu.cn/pypi/web/simple
# 切换到ChatGLM-6B 的ptuning文件夹下
cd /ChatGLM-6B/ptuning
数据集下载,这里给出官方提供的数据集,从 Google Drive 或者 Tsinghua Cloud 下载处理好的 ADGEN 数据集,将解压后的 AdvertiseGen 目录放到本目录(ptuning)下。
# 执行train脚本开始微调
bash train.sh
这里我出现了如此下错误:
解决方案为,在main.py文件中添加这两行:
def main():
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
# If we pass only one argument to the script and it's the path to a json file,
# let's parse it to get our arguments.
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]),local_rank=-1)
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
# 这两行
training_args.local_rank = -1
print(training_args.local_rank)
# Setup logging
3.推理
bash evaluate.sh
4. 利用微调后的模型进行验证
1.新建infer_base.py文件
import os
import torch
from transformers import AutoConfig, AutoModel, AutoTokenizer
# 原始glm模型存储地址
MODEL_PATH = "./model/chatglm-6b"
# 微调后你的模型存储地址
CHECKPOINT_PATH = "./output/adgen-chatglm-6b-pt-128-2e-2/checkpoint-1000"
# 载入Tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
config = AutoConfig.from_pretrained(MODEL_PATH, trust_remote_code=True, pre_seq_len=128)
model = AutoModel.from_pretrained(MODEL_PATH, config=config, trust_remote_code=True).cuda()
prefix_state_dict = torch.load(os.path.join(CHECKPOINT_PATH, "pytorch_model.bin"))
new_prefix_state_dict = {}
for k, v in prefix_state_dict.items():
if k.startswith("transformer.prefix_encoder."):
new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
print(f"Quantized to 4 bit")
model = model.quantize(4)
model = model.half().cuda()
model.transformer.prefix_encoder.float()
model = model.eval()
print("用户:你好\n")
response, history = model.chat(tokenizer, "你好", history=[])
print("ChatGLM-6B:\n",response)
print("\n------------------------------------------------\n用户:")
line = input()
while line:
response, history = model.chat(tokenizer, line, history=history)
print("ChatGLM-6B:\n", response)
print("\n------------------------------------------------\n用户:")
line = input()
2.运行infer文件
python3 infer.py
3.运行时出现两个bug如下:
bug1:
解决方法:
在代码的第一列加上如下代码:
#coding=gbk
bug2:
解决方法:我的是因为这个文件下载时发生错误,重新下载就可以了。
4.微调后效果
5.原模型效果
6. 这里运行原模型时突然显示内存溢出错误,重新进入服务器后,问题解决。
5.参考博客
更多推荐
已为社区贡献1条内容
所有评论(0)