vLLM

vLLM is a fast and easy-to-use library for LLM inference and serving.

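Beam search requires a recent version of vLLM (V6.0.1.post1 is used here).
For installing the environment for the latest vLLM (0.6.4.post1), see my other article "vLLM Environment Installation and Running Example [Latest Version (0.6.4.post1)]".

For the sampling approach, refer to my other article "vLLM Tutorial [V5.0.4]". This article builds on it by adding beam search and focuses on the beam search approach.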

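vLLM Parameter Descriptions

Inference is tested with Meta-Llama-3.1-8B-Instruct.

1. LLM: Loading the Model

In vLLM V6.0.1.post1, LLM accepts a few more parameters, but usage is essentially the same. For details, see the section "LLM: Loading the Model" in my other article "vLLM Tutorial [V5.0.4]".

Note that max_model_len has to be changed from 43200 to 39456.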

from vllm import LLM
llm = LLM(
    model="Meta-Llama-3.1-8B-Instruct",
    max_model_len=39456,
)
print("model load success!")
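max_model_len caps prompt tokens plus generated tokens, so a max_tokens value larger than max_model_len minus the prompt length cannot be honored anyway. The demo later in this article defaults max_tokens to 43200, which exceeds 39456. The helper `safe_max_tokens` below is a hypothetical sketch, not part of vLLM. It takes the prompt length as a plain integer.

```python
def safe_max_tokens(requested: int, prompt_len: int, max_model_len: int = 39456) -> int:
    """Hypothetical helper: clamp max_tokens so prompt + generated tokens fit in max_model_len."""
    budget = max_model_len - prompt_len
    if budget <= 0:
        raise ValueError(
            f"prompt has {prompt_len} tokens, which already fills max_model_len={max_model_len}"
        )
    return min(requested, budget)


print(safe_max_tokens(43200, prompt_len=20))  # 39436
print(safe_max_tokens(16, prompt_len=20))     # 16
```

In practice the prompt length could come from the model's tokenizer, for example the one returned by llm.get_tokenizer().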

2. BeamSearchParams

Parameters as defined in the source code:

class BeamSearchParams(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        # required for @cached_property.
        dict=True):  # type: ignore[call-arg]
    """Beam search parameters for text generation."""
    beam_width: int
    max_tokens: int
    ignore_eos: bool = False
    temperature: float = 0.0
    length_penalty: float = 1.0
    include_stop_str_in_output: bool = False

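There is no official description of these parameters.
From how beam search works, you need to specify beam_width, which is the number of candidate sequences kept at each step. max_tokens has no default either, so it must also be given.
Usage: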

from vllm.sampling_params import BeamSearchParams
beam_params = BeamSearchParams(beam_width=5, max_tokens=16)

Note:

  1. max_tokens must be set, and set to a sensible value; otherwise generation takes a very long time.
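This simplified toy beam search shows what beam_width and max_tokens control. It uses a hard-coded next-token table instead of a model. `NEXT_PROBS` and `toy_beam_search` are illustrative names, not vLLM code. Ranking here uses the raw cumulative log probability, with no length penalty.

At each step, every kept beam is extended by one token. Beams that emit EOS move to the finished list, and only the beam_width best of the remaining beams are kept. The loop runs at most max_tokens times, which is why a large max_tokens makes generation slow.

```python
import math
from typing import Dict, List, Tuple

EOS = "<eos>"

# Hypothetical next-token table: last token -> {next token: probability}.
NEXT_PROBS: Dict[str, Dict[str, float]] = {
    "<bos>": {"I": 0.6, "You": 0.4},
    "I": {"am": 0.7, EOS: 0.3},
    "You": {"are": 0.9, EOS: 0.1},
    "am": {"fine": 0.6, EOS: 0.4},
    "are": {"great": 0.8, EOS: 0.2},
    "fine": {EOS: 1.0},
    "great": {EOS: 1.0},
}

Beam = Tuple[List[str], float]  # (tokens, cumulative log probability)


def toy_beam_search(beam_width: int, max_tokens: int) -> List[Beam]:
    beams: List[Beam] = [(["<bos>"], 0.0)]
    completed: List[Beam] = []
    for _ in range(max_tokens):  # one new token per step, at most max_tokens steps
        candidates: List[Beam] = []
        for tokens, cum_logprob in beams:
            for tok, prob in NEXT_PROBS[tokens[-1]].items():
                new_beam = (tokens + [tok], cum_logprob + math.log(prob))
                if tok == EOS:
                    completed.append(new_beam)  # finished beams stop growing
                else:
                    candidates.append(new_beam)
        # Keep only the beam_width best unfinished candidates for the next step.
        candidates.sort(key=lambda b: b[1], reverse=True)
        beams = candidates[:beam_width]
        if not beams:
            break
    # When max_tokens runs out, unfinished beams are returned as well.
    completed.extend(beams)
    completed.sort(key=lambda b: b[1], reverse=True)
    return completed[:beam_width]


for tokens, cum_logprob in toy_beam_search(beam_width=2, max_tokens=4):
    print(round(cum_logprob, 3), tokens)
```

The results depend on the settings:

* With beam_width=2, the best result is "You are great" (probability 0.288).
* With beam_width=1, the search commits to the locally more likely "I" at the first step and ends at "I am fine" (0.252).
* With max_tokens=2, the search is cut off and the unfinished beams are returned.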

3. beam_search

Official code (signature and docstring):

def beam_search(
        self,
        prompts: List[Union[str, List[int]]],
        params: BeamSearchParams,
) -> List[BeamSearchOutput]:
    """
    Generate sequences using beam search.

    Args:
        prompts: A list of prompts. Each prompt can be a string or a list
            of token IDs.
        params: The beam search parameters.

    TODO: how does beam search work together with length penalty, frequency
    penalty, and stopping criteria, etc.?
    """

The TODO shows that many features are not supported yet.
beam_search returns a list of BeamSearchOutput objects, one per prompt.
Full test:

prompts = ["你好", "你是谁?"]
outputs = llm.beam_search(prompts, beam_params)

for output in outputs:
    print("output:")
    for sequence in output.sequences:
        print(sequence.cum_logprob, [sequence.text])

4. BeamSearchOutput: Output Structure

4.1 BeamSearchOutput

class BeamSearchOutput:
    """The output of beam search.
    It contains the list of the best beam search sequences.
    The length of the list is equal to the beam width.
    """
    sequences: List[BeamSearchSequence]

sequences is a list of BeamSearchSequence objects. Its length equals beam_width in BeamSearchParams.
Test of the number of returned sequences:

prompts = ["你好", "你是谁?"]
outputs = llm.beam_search(prompts, beam_params)

for output in outputs:
    print("output:")
    for sequence in output.sequences:
        print(sequence.cum_logprob, [sequence.text])

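Test results:
[Figure: test of the number of returned sequences]
As the results show, [0] has the highest probability, so taking [0] as the final result is enough.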
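This section explains why [0] is the one to keep. As far as I can tell from the vLLM source, the final beams are sorted best-first by a length-normalized score: cum_logprob / seq_len ** length_penalty. This convention comes from Hugging Face transformers. Here seq_len counts the prompt tokens and ignores a trailing EOS.

The sketch below re-implements that ordering on plain lists so it runs without a model. `beam_score`, `rank_beams`, the EOS id and the example beams are all hypothetical, and the real implementation may differ in details.

```python
from typing import List, Tuple

EOS_ID = 2  # hypothetical EOS token id for this example


def beam_score(tokens: List[int], cum_logprob: float, length_penalty: float = 1.0) -> float:
    """Length-normalized score: cum_logprob / seq_len ** length_penalty.

    seq_len counts every token (prompt included) except a trailing EOS.
    """
    seq_len = len(tokens)
    if tokens and tokens[-1] == EOS_ID:
        seq_len -= 1
    return cum_logprob / (seq_len ** length_penalty)


def rank_beams(beams: List[Tuple[List[int], float]], length_penalty: float = 1.0):
    """Sort (tokens, cum_logprob) pairs best-first, so index 0 is the one to keep."""
    return sorted(beams, key=lambda b: beam_score(b[0], b[1], length_penalty), reverse=True)


# Two hypothetical beams sharing the 3-token prompt [10, 11, 12].
short = ([10, 11, 12, 5, EOS_ID], -2.0)          # 4 scored tokens -> score -0.5
long = ([10, 11, 12, 5, 6, 7, 8, EOS_ID], -2.8)  # 7 scored tokens -> score -0.4

print(rank_beams([short, long])[0] == long)                      # True
print(rank_beams([short, long], length_penalty=0.0)[0] == short)  # True
```

With the default length_penalty=1.0, the longer beam wins despite its lower cum_logprob. With length_penalty=0.0, the score is just cum_logprob and the shorter beam wins. So the printed cum_logprob values are not necessarily in strictly descending order.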

4.2 BeamSearchSequence

class BeamSearchSequence:
    """A sequence for beam search.
    It keeps track of the tokens and the log probability of the sequence.
    The text field is optional and will only be filled when the sequence is
    about to be returned to the user.
    """
    # The tokens includes the prompt.
    tokens: List[int]
    logprobs: List[Dict[int, Logprob]]
    cum_logprob: float = 0.0
    text: Optional[str] = None
    finish_reason: Optional[str] = None
    stop_reason: Union[int, str, None] = None
    multi_modal_data: Optional["MultiModalDataDict"] = None
    mm_processor_kwargs: Optional[Dict[str, Any]] = None

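Attributes you will need:

No. | Attribute | Description
--- | --- | ---
1 | tokens | All tokens, including the prompt
2 | logprobs | Log-probability information for each output position (one dict per generated token); its length equals the number of output tokens
3 | cum_logprob | Cumulative log probability of the generated output
4 | text | The full decoded string, including the prompt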
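tokens and text both include the prompt, so the generated part has to be sliced off. logprobs has one entry per generated token. That makes the number of output tokens len(logprobs), and the prompt length len(tokens) minus that. The demo below relies on exactly this.

The sketch uses a stand-in dataclass, `FakeSequence`, so it runs without a GPU. It is hypothetical: it has only the four attributes from the table, and it uses floats where vLLM stores Logprob objects. The helpers `split_sequence` and `strip_prefix` are also illustrative.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple


@dataclass
class FakeSequence:
    """Hypothetical stand-in holding only the BeamSearchSequence attributes in the table."""
    tokens: List[int]
    logprobs: List[Dict[int, float]] = field(default_factory=list)
    cum_logprob: float = 0.0
    text: Optional[str] = None


def split_sequence(seq: FakeSequence) -> Tuple[List[int], List[int]]:
    """Return (prompt token ids, generated token ids)."""
    output_len = len(seq.logprobs)  # one logprobs entry per generated token
    input_len = len(seq.tokens) - output_len
    return seq.tokens[:input_len], seq.tokens[input_len:]


def strip_prefix(text: str, prefix: str) -> str:
    """Drop prefix only if text really starts with it (safer than slicing a fixed width)."""
    return text[len(prefix):] if text.startswith(prefix) else text


# Made-up token ids: two prompt tokens followed by two generated tokens.
seq = FakeSequence(
    tokens=[1, 100, 101, 102],
    logprobs=[{101: -0.1}, {102: -0.3}],
    cum_logprob=-0.4,
    text="<|begin_of_text|>Hello, world",
)
prompt_ids, output_ids = split_sequence(seq)
answer = strip_prefix(strip_prefix(seq.text, "<|begin_of_text|>"), "Hello")
print(prompt_ids, output_ids, repr(answer))
```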

vLLM Demo

This builds on the "vLLM Demo" section of "vLLM Tutorial [V5.0.4]" and adds a beam search inference method:

from vllm import LLM, SamplingParams
from vllm.sampling_params import BeamSearchParams
import argparse
import time
import logging
from tqdm import tqdm

class Generate:
    def __init__(self, model_path, inference_type, temperature=0.6, top_p=0.9, beam_width=5, max_tokens=43200, debug=False):
        self.logger = logging.getLogger("vLLM")
        if debug:
            self.logger.setLevel(logging.DEBUG)
        else:
            self.logger.setLevel(logging.INFO)
        stream_handler = logging.StreamHandler()
        formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        stream_handler.setFormatter(formatter)
        self.logger.addHandler(stream_handler)

        start_time = time.time()
        self.llm = LLM(
            model=model_path,
            max_model_len=39456
        )
        end_time = time.time()
        self.logger.info("load {} model use {:.2f}s".format(model_path, end_time - start_time))

        self.inference_type = inference_type
        self.params = None
        if self.inference_type == "sampling":
            self.params = SamplingParams(temperature=temperature, top_p=top_p, max_tokens=max_tokens)
        elif self.inference_type == "beam":
            self.params = BeamSearchParams(beam_width=beam_width, max_tokens=max_tokens)

        self.all_num = 0
        self.all_input_tokens_num = 0
        self.all_output_tokens_num = 0
        self.all_time = 0

    def sampling_generate_sentences(self, prompts):
        assert self.params is not None
        start_time = time.time()
        outputs = self.llm.generate(prompts, self.params, use_tqdm=False)
        end_time = time.time()
        self.all_time = self.all_time + (end_time - start_time)
        result_list = []
        for output in outputs:
            self.all_num += 1
            self.all_input_tokens_num += len(output.prompt_token_ids)
            self.all_output_tokens_num += len(output.outputs[0].token_ids)
            result_list.append(output.outputs[0].text)
        return result_list
    
    def beam_generate_sentences(self, prompts):
        assert self.params is not None
        start_time = time.time()
        outputs = self.llm.beam_search(prompts, self.params)
        end_time = time.time()
        self.all_time = self.all_time + (end_time - start_time)
        result_list = []
        for i in range(len(outputs)):
            output = outputs[i]

            self.all_num += 1
            sequence = output.sequences[0]
            output_len = len(sequence.logprobs)
            input_len = len(sequence.tokens) - output_len

            # strip the leading <|begin_of_text|> (17 characters, specific to Llama 3)
            text = sequence.text[17:]
            prompt = prompts[i]
            # strip the prompt, which sequence.text also contains
            text = text[len(prompt):]
            
            self.all_input_tokens_num += input_len
            self.all_output_tokens_num += output_len
            result_list.append(text)
        return result_list

    def generate_all_data(self, input_list, batch):
        self.all_num = 0
        self.all_input_tokens_num = 0
        self.all_output_tokens_num = 0
        self.all_time = 0
        output_list = []
        for i in tqdm(range(0, len(input_list), batch)):
            if self.inference_type == "sampling":
                output = self.sampling_generate_sentences(input_list[i : i + batch])
            elif self.inference_type == "beam":
                output = self.beam_generate_sentences(input_list[i : i + batch])
            output_list.extend(output)
        
        self.logger.info("-" * 20)
        self.logger.info("process {} sentences use {:.2f}s".format(self.all_num, self.all_time))
        self.logger.info("average input token num: {:.2f}, average output token num: {:.2f}".format(self.all_input_tokens_num / self.all_num, self.all_output_tokens_num / self.all_num))
        self.logger.info(
            "speed: all:{:.2f} token/s, input:{:.2f} token/s, output:{:.2f} token/s".format(
                (self.all_input_tokens_num + self.all_output_tokens_num) / self.all_time,
                self.all_input_tokens_num / self.all_time, 
                self.all_output_tokens_num / self.all_time
            )
        )
        self.logger.info("-" * 20 + "\n\n")

        return output_list
    
    def generate_file(self, input_file_path, output_file_path, batch):
        input_list = []
        with open(input_file_path, "r", encoding="utf8") as fin:
            for line in fin.readlines():
                # drop the trailing \n at the end of each line read from the file
                if line[-1] == "\n":
                    data = line[:-1]
                else:
                    data = line
                data = data.replace("\\n", "\n")  # turn a literal \n into a real newline
                input_list.append(data)

        output_list = self.generate_all_data(input_list, batch)

        with open(output_file_path, "w", encoding="utf8") as fout:
            for line in output_list:
                fout.write(line + "\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="vLLM Llama Demo")

    parser.add_argument('-m', '--model', type=str, required=True, help='model path')
    parser.add_argument('-i', '--input', type=str, required=True, help='input file path')
    parser.add_argument('-o', '--output', type=str, required=True, help='output file path')
    parser.add_argument('-b', '--batch', type=int, default=1)
    parser.add_argument('-t', '--type', type=str, required=True, choices=["sampling", "beam"], help='inference type')
    parser.add_argument('--temperature', type=float, default=0.6, help='temperature is only valid in sampling')
    parser.add_argument('--top_p', type=float, default=0.9, help='top_p is only valid in sampling')
    parser.add_argument('--beam', type=int, default=5, help='beam width, only valid in beam')
    parser.add_argument('--max_token', type=int, default=43200)
    parser.add_argument('--debug', action="store_true", help="whether to show debug information")

    args = parser.parse_args()

    generate = Generate(args.model, args.type, args.temperature, args.top_p, args.beam, args.max_token, args.debug)

    generate.generate_file(args.input, args.output, args.batch)
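The statistics logged at the end of generate_all_data come down to a few divisions. Pulling them out as a pure function makes them easy to sanity-check. `summarize` is a hypothetical helper, not part of vLLM, and the example numbers are made up. Note that the overall speed should count input and output tokens together.

```python
from typing import Dict


def summarize(num_sentences: int, input_tokens: int, output_tokens: int, seconds: float) -> Dict[str, float]:
    """Average token counts and token/s figures, as logged at the end of generate_all_data."""
    if num_sentences <= 0 or seconds <= 0:
        raise ValueError("need at least one sentence and a positive elapsed time")
    return {
        "avg_input_tokens": input_tokens / num_sentences,
        "avg_output_tokens": output_tokens / num_sentences,
        # overall speed counts input AND output tokens
        "all_tokens_per_s": (input_tokens + output_tokens) / seconds,
        "input_tokens_per_s": input_tokens / seconds,
        "output_tokens_per_s": output_tokens / seconds,
    }


print(summarize(num_sentences=2, input_tokens=30, output_tokens=32, seconds=4.0))
```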

Command:

 CUDA_VISIBLE_DEVICES=0 python3 test.py -m ./Meta-Llama-3.1-8B-Instruct/ -i input.txt -o output.txt -b 2 --max_token 16 -t beam
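generate_file expects input.txt to hold one prompt per line. Any newline inside a prompt is written as the two characters \n, which the code converts back. Below is a minimal sketch of writing and re-reading such a file with the same conversion. The helper names, file name and prompts are only examples.

```python
import os
import tempfile
from typing import List


def write_prompts(path: str, prompts: List[str]) -> None:
    """Write one prompt per line, turning real newlines into the two characters \\n."""
    with open(path, "w", encoding="utf8") as fout:
        for prompt in prompts:
            fout.write(prompt.replace("\n", "\\n") + "\n")


def read_prompts(path: str) -> List[str]:
    """Read prompts back with the same rules as the demo's generate_file."""
    prompts = []
    with open(path, "r", encoding="utf8") as fin:
        for line in fin:
            data = line[:-1] if line.endswith("\n") else line  # drop the trailing newline
            prompts.append(data.replace("\\n", "\n"))  # restore real newlines
    return prompts


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "input.txt")
    write_prompts(path, ["你好", "你是谁?", "Translate:\nHello"])
    print(read_prompts(path))
```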

Output:
[Figure: output]
