Blog

  • Fine-tuning DeepSeek-R1 Qwen 1.5B on a MacBook M1 Max laptop: a hands-on guide

    1. Environment

    !pip install "transformers==4.48.2"
    !pip install pip3-autoremove
    !pip-autoremove torch torchvision torchaudio -y

    !pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl
    !pip install "trl==0.14.0"
    !pip install "fsspec==2024.10.0"

    2. Imports

    from datasets import load_dataset
    from trl import SFTConfig, SFTTrainer
    from peft import LoraConfig
    import json
    from transformers import AutoTokenizer, AutoModelForCausalLM
    import transformers
    import torch

    print(transformers.__version__)
    print(torch.__version__)

    3. Load the data

    dataset = load_dataset('json', data_files='/model/FreedomIntelligence:medical-o1-reasoning-SFT/medical_o1_sft_Chinese.json', split='train[0:1000]')
    print(type(dataset))
    print(dataset)
    print(dataset[0].keys())
    print(dataset[0]['Question'])
    print(dataset[0]['Complex_CoT'])
    print(dataset[0]['Response'])

    # Merge question, chain of thought, and answer into one training string
    def preprocess_data(example):
        example['text'] = example['Question'] + " " + example['Complex_CoT'] + " " + example['Response']
        return example

    dataset = dataset.map(preprocess_data)

    # Print a sample to confirm the merge
    print(dataset[0].keys())  # should now include a 'text' field
    print(dataset[0]['text'])  # the merged text

    train_test_dataset = dataset.train_test_split(test_size=0.3)
    print(train_test_dataset['test'][0])
    print(train_test_dataset['train'][0])

    4. Load the model

    import time

    device = 'mps'  # Apple Silicon GPU backend
    time_start = time.time()
    model_path = "model/DeepSeek-R1-Distill-Qwen-1.5B"
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, padding_side='right')
    model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True).to(device)
    tokenizer.pad_token = tokenizer.eos_token
    time_end = time.time()
    print(f"Load time: {time_end - time_start} seconds")

    5. Evaluation callback

    from transformers import TrainerCallback, TrainingArguments, TrainerState, TrainerControl
    import os
    import matplotlib.pyplot as plt
    import datetime
    import torch

    class EvaluationCallback(TrainerCallback):
        def __init__(self, test_dataset, tokenizer):
            self.test_dataset = test_dataset
            self.tokenizer = tokenizer
            self.epoch = 0
            self.eval_losses = []
            self.train_losses = []
            self.epochs = []

        def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
            model = kwargs['model']  # the Trainer passes the current model via kwargs
            print(f"\nEvaluating model after epoch {self.epoch}")
            if state.log_history:
                latest_loss = state.log_history[-1].get('loss')
                if latest_loss is not None:
                    self.train_losses.append(latest_loss)
                    self.epochs.append(self.epoch)
            model.eval()
            total_eval_loss = 0
            num_eval_samples = 0

            with torch.no_grad():
                # Evaluate only the first test sample, to keep each epoch cheap
                for i in range(min(1, len(self.test_dataset))):
                    pmt = self.test_dataset[i]['Question']  # prompt with the question text only
                    message = [{"role": "user", "content": pmt}]
                    chat_template_text = self.tokenizer.apply_chat_template(message,
                                                                            tokenize=False,
                                                                            add_generation_prompt=True)
                    print(f"\nchat_template_text: {chat_template_text}")
                    model_input = self.tokenizer([chat_template_text], return_tensors='pt').to(model.device)

                    generated_ids = model.generate(**model_input,
                                                   max_new_tokens=100,
                                                   do_sample=False,  # greedy decoding
                                                   pad_token_id=self.tokenizer.pad_token_id,
                                                   eos_token_id=self.tokenizer.eos_token_id,
                                                   repetition_penalty=1.2,
                                                   no_repeat_ngram_size=3)
                    # Loss over the prompt tokens, used as a rough progress signal
                    outputs = model(**model_input, labels=model_input.input_ids)
                    loss = outputs.loss.item()
                    total_eval_loss += loss
                    num_eval_samples += 1

                    # Strip the prompt tokens and decode only the newly generated part
                    generated_text = self.tokenizer.batch_decode(
                        [output_ids[len(input_ids):] for input_ids, output_ids in zip(model_input.input_ids, generated_ids)],
                        skip_special_tokens=True
                    )[0]
                    print(f"\nTest sample {i+1}:")
                    print(f"input: {pmt}")
                    print(f"output: {generated_text}")
                    print(f"loss: {loss}")
                    print("-" * 50)

                avg_eval_loss = total_eval_loss / num_eval_samples if num_eval_samples > 0 else 0
                self.eval_losses.append(avg_eval_loss)

                metrics = {
                    'epochs': self.epochs,
                    'train_loss': self.train_losses,
                    'eval_loss': self.eval_losses,
                    'current_epoch': self.epoch,
                    'timestamp': datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
                }

                os.makedirs('losses', exist_ok=True)

                # Save metrics to JSON
                with open('losses/training_metrics.json', 'w') as f:
                    json.dump(metrics, f, indent=2)

                plt.figure(figsize=(10, 6))
                plt.plot(self.epochs, self.train_losses, 'b-', label='Training Loss')
                plt.plot(range(len(self.eval_losses)), self.eval_losses, 'r-', label='Evaluation Loss')
                plt.xlabel('Epoch')
                plt.ylabel('Loss')
                plt.title('Training and Evaluation Loss Over Time')
                plt.grid(True)
                plt.legend()
                plt.savefig('losses/training_progress.png')
                plt.close()

                self.epoch += 1
                model.train()

    6. Hyperparameter configuration

    peft_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0,
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ],
        modules_to_save=["lm_head", "embed_tokens"],
        task_type="CAUSAL_LM"
    )
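
    When peft_config is passed to SFTTrainer below, the trainer applies these LoRA adapters itself. If you want to inspect the adapted model beforehand, a minimal sketch (assuming the model loaded in section 4; note that wrapping modifies the model in place, so only do this if you then hand the wrapped model to the trainer):

    from peft import get_peft_model

    # Wrap the base model with the LoRA adapters described by peft_config
    peft_model = get_peft_model(model, peft_config)
    peft_model.print_trainable_parameters()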

    7. Set up the trainer

    import shutil
    import os

    output_dir = "./finetuned_model/deepseek_1.5b_wenan"
    print(output_dir)

    # Remove the output directory if it exists
    if os.path.exists(output_dir):
        shutil.rmtree(output_dir)
        print(f"Removed existing output directory: {output_dir}")

    # Create a fresh output directory
    os.makedirs(output_dir)
    print(f"\n dataset is {train_test_dataset['train'][0]}")
    print(f"\n train_test_dataset is {train_test_dataset['test'][0]}")

    trainer = SFTTrainer(
        model=model,
        processing_class=tokenizer,  # trl 0.14 renamed the tokenizer argument to processing_class
        train_dataset=train_test_dataset['train'],
        eval_dataset=train_test_dataset['test'],
        args=SFTConfig(
            output_dir=output_dir,
            num_train_epochs=5,
            per_device_train_batch_size=2,  # small batch size for limited unified memory
            gradient_accumulation_steps=4,  # effective batch size: 2 * 4 = 8
            learning_rate=1e-4,             # explicit learning rate
            weight_decay=0.01,              # weight decay for regularization
            logging_steps=1,                # frequent logging for a small dataset
            eval_strategy="epoch",          # named evaluation_strategy in older transformers
            save_strategy="epoch",
            load_best_model_at_end=True,
            metric_for_best_model="loss",
            greater_is_better=False,
            warmup_steps=10,
        ),
        peft_config=peft_config,
        callbacks=[EvaluationCallback(train_test_dataset['test'], tokenizer)]
    )

    8. Show the trainable parameter count

    trainable_params = 0
    all_params = 0

    for _, param in trainer.model.named_parameters():
        all_params += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()

    print(f"Trainable parameters: {trainable_params:,}")
    print(f"All parameters: {all_params:,}")
    print(f"Percentage of parameters being trained: {100 * trainable_params / all_params:.2f}%")

    9. Train

    train_output = trainer.train()

    10. Show the fine-tuning results

    def generate_response(model, tokenizer, user_input, system_prompt):
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_input}
        ]

        # add_generation_prompt=True appends the assistant-turn marker, so the
        # model writes a reply instead of continuing the user message
        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

        print(f"\ntext: {text}")

        model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
        generated_ids = model.generate(
            **model_inputs,
            max_new_tokens=100,
            do_sample=True,
            temperature=0.1,
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            repetition_penalty=1.2,
            no_repeat_ngram_size=3
        )

        # Keep only the newly generated tokens, dropping the prompt
        generated_ids = [
            output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
        ]

        response = tokenizer.batch_decode(generated_ids, skip_special_tokens=False)[0]

        return response

    11. Print the test results

    print("\nTesting with examples from test dataset:")
    for i in range(min(1, len(train_test_dataset['test']))):
        system_prompt = "你是一个文案助手,你的任务是帮用户生成文案"
        test_input = train_test_dataset['test'][i]
        print(f"\nTest input {i+1}: {test_input}")
        response = generate_response(trainer.model, trainer.processing_class, test_input, system_prompt)
        print(f"Model response: {response}")
        print("-" * 80)

    12. Save the model

    new_model_local = "DeepSeek-R1-Medical-COT-zh"
    trainer.model.save_pretrained(new_model_local)  # saves the LoRA adapter weights
    tokenizer.save_pretrained(new_model_local)

    # save_pretrained_merged is an Unsloth API and is not available on a plain
    # transformers/PEFT model; the equivalent here is to merge the LoRA weights
    # into the base model and save the merged result
    merged_model = trainer.model.merge_and_unload()
    merged_model.save_pretrained(new_model_local)

    13. Upload to HF
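
    Pushing to the Hub requires an authenticated session; a minimal sketch, assuming you have a write-enabled access token:

    from huggingface_hub import login

    login()  # prompts for a token; alternatively pass token="hf_..." directly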

    new_model_online = "tain198127/DeepSeek-R1-Medical-COT-zh"
    trainer.model.push_to_hub(new_model_online)  # pushes the LoRA adapter
    tokenizer.push_to_hub(new_model_online)

    # push_to_hub_merged is likewise Unsloth-specific; push the merged model
    # from section 12 instead
    merged_model.push_to_hub(new_model_online)
  • How to use JVM Xlog

    Xlog usage notes

    Note: everything below has been verified on OpenJDK 11.

    1. Format

    -Xlog:{tag=level,tag=level,...}:{output}:{decorator,decorator,...}:{output-options}

    Tags

    Also called selections, or tags. To list all available selections/tags, run

    jcmd {pid} VM.log list

    which prints every selection/tag you can choose from.

    The full set of selections/tags is:

    add, age, alloc, annotation, aot, arguments, attach, barrier, biasedlocking, blocks, bot, breakpoint, bytecode, cds, census, class, classhisto, cleanup, codecache, compaction, compilation, constantpool, constraints, container, coops, cpu, cset, data, datacreation, dcmd, decoder, defaultmethods, director, dump, ergo, event, exceptions, exit, fingerprint, free, freelist, gc, handshake, hashtables, heap, humongous, ihop, iklass, init, inlining, interpreter, itables, jfr, jit, jni, jvmti, liveness, load, loader, logging, malloc, mark, marking, membername, memops, metadata, metaspace, methodcomparator, mirror, mmu, module, monitorinflation, monitormismatch, nestmates, nmethod, normalize, objecttagging, obsolete, oldobject, oom, oopmap, oops, oopstorage, os, pagesize, parser, patch, path, perf, phases, plab, preorder, preview, promotion, protectiondomain, purge, redefine, ref, refine, region, reloc, remset, resolve, safepoint, sampling, scavenge, setting, smr, stackmap, stacktrace, stackwalk, start, startuptime, state, stats, stringdedup, stringtable, subclass, survivor, sweep, system, table, task, thread, time, timer, tlab, tracking, unload, unshareable, update, verification, verify, vmoperation, vmthread, vtables, vtablestubs, workgang

    Levels

    The available levels are

    • off
    • trace
    • debug
    • info
    • warning
    • error

    Tags and levels can be combined with one another freely, for example:
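
    A hypothetical combination, logging gc at debug level and safepoint at info level to stdout:

    java -Xlog:gc=debug,safepoint=info -version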

    Output

    There are three output targets:

    stdout

    stderr

    file=filename.%p.%l.%t

    Note: to include the process PID, timestamp, and similar information in the log file name, add %p, %l, %t, etc. to the filename. See the examples in Section 2.

    Decorators

    When decorators are specified, the corresponding information is printed into the gc log.

    Decorator             Description
    time or t             Timestamp with the current date and time, in ISO-8601 format. Note: can be embedded in the log file name.
    utctime or utc        Timestamp in UTC format. Note: can be embedded in the log file name.
    uptime or u           Time elapsed since JVM start, in seconds and milliseconds, e.g. 6.567s.
    timemillis or tm      Equivalent to System.currentTimeMillis().
    uptimemillis or um    Milliseconds elapsed since JVM start.
    timenanos or tn       Equivalent to the value of System.nanoTime().
    uptimenanos or un     Nanoseconds elapsed since JVM start.
    hostname or hn        Host name. Note: can be embedded in the log file name.
    pid or p              Process ID. Note: can be embedded in the log file name.
    tid or ti             Thread ID.
    level or l            Log level.
    tags or tg            Tags.
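
    For example, to print gc messages to stdout decorated with uptime, level, and tags:

    java -Xlog:gc=info:stdout:uptime,level,tags -version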

    Output options

    Output options only take effect when the output is a file:

    filecount=5

    filesize=1K/1M/1G

    Note: if filesize is given without a unit, the value defaults to bytes.

    2. Examples

    This prints gc- and class-related log messages at the info level:

    java -Xlog:gc,class=info -version

    This prints info-level gc and heap messages to a debug log file, with the process ID and timestamp embedded in the file name:

    java -Xlog:gc,heap=info:file=/home/debug.%p.%t.log:t,um,hn,p,ti,l,tg -version

    In addition to the above, this also rolls the log files, keeping only 5 files of 1024 bytes each:

    java -Xlog:gc,heap=info:file=/home/debug.%p.%t.log:t,um,hn,p,ti,l,tg:filecount=5,filesize=1024 -version

    This prints os information (at info level) and heap information (at trace level):

    java -Xlog:os=info,heap=trace:file=/home/debug.%p.%t.log:t,um,hn,p,ti,l,tg:filecount=5,filesize=1024 -version

    This prints gc information decorated with time, uptime, log level, and tags, and rotates the gc log files at 1 KB each, keeping at most 10; once the limit is exceeded, the oldest log is overwritten:

    java -Xlog:gc*,gc+ref=debug,gc+heap=debug,gc+age=trace,gc+ergo*=trace:file=gc-%p-%t.log:tags,uptime,time,level:filecount=10,filesize=1k -jar xxx.jar

    3. Hot-modifying the GC log configuration

    • Note: when modifying the configuration at runtime, you must supply the output argument, otherwise the command fails. Once output is given, what, decorators, and output_options can be combined freely.
    • The what content gets merged: for example, if the gc configuration at startup is gc+init=debug and you hot-modify what to gc*=debug, the two are merged into gc*=debug.
    • After hot-modifying the GC log configuration, the new log file is created immediately, but its content may not have been written yet; the gc information is still in memory and is only flushed at the next GC. The previous log file stops receiving output, even if it was unfinished.
    • You can use %p and %t in the new log file name to embed the process ID and timestamp.

    1. First, get the process pid with jps, for example:

    jps
    // suppose it shows a process with ID 22

    {pid} below refers to the Java process whose gc logging we want to modify.

    The hot-modification examples that follow all use the startup command below as their baseline:

    java -Xlog:gc*,gc+ref=debug,gc+heap=debug,gc+age=trace,gc+ergo*=trace:file=gc-%p-%t.log:tags,uptime,time,level:filecount=10,filesize=10m -jar xxx.jar

    Manually rotate the logs

    // rotate the log files
    jcmd {pid} VM.log rotate

    Change what is captured

    jcmd {pid} VM.log output=change-gc.%p.%t.log what="gc*=debug"
    // dynamically change which messages the GC log captures

    Adjust the decorators

    jcmd {pid} VM.log output=change-gc.%p.%t.log what="gc*=debug" decorators="tags,uptime,time"
    // dynamically adjust which decorations are recorded

    Adjust the log rotation settings

    jcmd {pid} VM.log output=change-gc.%p.%t.log what="gc*=debug" decorators="tags,uptime,time" output_options="filecount=50,filesize=100M"
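
    Turn logging off entirely

    jcmd {pid} VM.log disable
    // turns off all logging and clears the log configuration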