美文网首页
基于对话数据集进行微调

基于对话数据集进行微调

作者: 国服最坑开发 | 来源:发表于2025-08-11 10:21 被阅读0次

0x00 TLDR;

实践一下:基于自定义的数据集进行模型微调,以达到模型回复内容符合预期

0x01 训练

准备参数文件: my_chat_dataset.jsonl

{"input_text": "what's your name?", "output_text": "I'm Ruby"}
{"input_text": "how old are you?", "output_text": "20 years old"}
{"input_text": "where are you?", "output_text": "Istanbul"}

编写代码:

from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
from datasets import load_dataset
import os

os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
os.environ["TRANSFORMERS_OFFLINE"] = "1"


def poc_load_dataset():
    dataset = load_dataset("json", data_files="my_chat_dataset.jsonl")
    train_dataset = dataset["train"]
    return train_dataset


# 加载模型
model_name = "microsoft/DialoGPT-small"
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(model_name)
model.config.pad_token_id = model.config.eos_token_id


# 数据预处理
def preprocess(example):
    inputs = tokenizer(example["input_text"], truncation=True, padding="max_length", max_length=64)
    outputs = tokenizer(example["output_text"], truncation=True, padding="max_length", max_length=64)
    inputs["labels"] = outputs["input_ids"]
    return inputs


# 训练
def poc_train():
    train_dataset = poc_load_dataset()
    tokenized_dataset = train_dataset.map(preprocess, batched=True)

    # 训练参数
    training_args = TrainingArguments(
        output_dir="./dialoGPT_finetuned",
        per_device_train_batch_size=4,
        num_train_epochs=5,
        logging_dir="./logs",
        save_strategy="epoch",
        fp16=True,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset,
    )

    # 开始训练
    trainer.train()

    # 保存
    trainer.save_model("./dialoGPT_finetuned")
    tokenizer.save_pretrained("./dialoGPT_finetuned")

    print("train finished")

if __name__ == '__main__':
    # 训练
    poc_train()


0x02 验证

from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
from datasets import load_dataset
import os

os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
os.environ["TRANSFORMERS_OFFLINE"] = "1"

# 测试一下训练结果
import torch

def poc_test():
    _tokenizer = AutoTokenizer.from_pretrained("./dialoGPT_finetuned")
    _model = AutoModelForCausalLM.from_pretrained("./dialoGPT_finetuned")

    def chat_with_bot(prompt):
        inputs = _tokenizer.encode(prompt + _tokenizer.eos_token, return_tensors="pt")
        # 生成对应的 attention_mask
        attention_mask = torch.ones_like(inputs)

        outputs = _model.generate(
            inputs=inputs,
            attention_mask=attention_mask,
            max_length=100,
            pad_token_id=_tokenizer.eos_token_id,
            temperature=0.7,
            do_sample=True,
            top_k=50,
            top_p=0.9,
            no_repeat_ngram_size=3,
        )

        # 截取生成部分,跳过输入长度
        generated = outputs[:, inputs.shape[-1]:]
        return _tokenizer.decode(generated[0], skip_special_tokens=True)

    while True:
        user_input = input("You: ")
        if user_input.lower() in ["q", "exit", "quit"]:
            break

        reply = chat_with_bot(user_input)
        print("Bot: ", reply)


if __name__ == '__main__':
    #  验证
    poc_test()

看起来,没有达到预期,这里涉及训练参数的配置,以及参数量。

image.png

0x03 小结

虽然没有达到预期,但大致了解了微调的过程,以及生成的文件产物。

相关文章

  • Pytorch Fine-tuning

    pytorch 使用预训练过的ResNet 进行微调,训练新的数据集CIFAR100

  • 修改 MySQL 的默认编码为 UTF-8

    在Linux上安装完MySQL数据库后要对数据库进行字符集和校对集的修改(校对集是基于字符集)可以通过SHOW V...

  • 【mongoDB】MongoDB 3.x压缩选项

    数据压缩 基于 WiredTiger,MongoDB 支持对所有的数据集和索引进行压缩。压缩可以占用一点 CPU ...

  • MuTual: A Dataset for Multi-Turn

    多轮对话数据集分为两类: 无原因推理型数据集 Ubuntu Dialogue Corpus: 原因推理数据集 Mu...

  • 基于Hadoop的数据仓库Hive 基础知识

    Hive是基于Hadoop的数据仓库工具,可对存储在HDFS上的文件中的数据集进行数据整理、特殊查询和分析处理,提...

  • 基于Hadoop的数据仓库Hive知识

    Hive是基于Hadoop的数据仓库工具,可对存储在HDFS上的文件中的数据集进行数据整理、特殊查询和分析处理,提...

  • 一定要弄懂的Hive基础知识

    Hive是基于Hadoop的数据仓库工具,可对存储在HDFS上的文件中的数据集进行数据整理、特殊查询和分析处理,提...

  • hive基础知识

    Hive是基于Hadoop的数据仓库工具,可对存储在HDFS上的文件中的数据集进行数据整理、特殊查询和分析处理,提...

  • Flink简介

    Flink是一个分布式处理引擎,对无界数据流和有界数据流进行计算。 流数据(双十一)传统数据架构是基于有限数据集的...

  • 基于单细胞数据进行Bulk定量之MuSiC

    前言 Immugent最近在读文献时一直遇到各种利用deconvolution基于单细胞参考数据集进行bulk细胞...

网友评论

      本文标题:基于对话数据集进行微调

      本文链接:https://www.haomeiwen.com/subject/kdrmojtx.html