Data Generation Template

2024-09-02

A Python template script for batch-generating data through an OpenAI-compatible chat completion API.
import os
import json
import uuid
from typing import Any
from concurrent.futures import ProcessPoolExecutor

import pandas as pd
from tqdm import tqdm
from loguru import logger
from openai import OpenAI

# 用户可配置的参数
CONFIG: dict[str, str] = {
    "INPUT_FILE": "请替换为输入文件路径",
    "OUTPUT_FILE": "请替换为输出文件路径",
    "PROCESSED_DIR": "请替换为中间处理文件夹路径",
    "API_KEY": "请替换为API密钥",  # 如果是vllm等自己部署的模型,则为"EMPTY"
    "BASE_URL": "请替换为API基础URL",
    "MODEL_NAME": "请替换为模型名称",
    "MAX_WORKERS": "替换成一个整数,如4",
}

# 初始化OpenAI客户端
client = OpenAI(api_key=CONFIG["API_KEY"], base_url=CONFIG["BASE_URL"])

# 系统消息模板
SYSTEM_MESSAGE = """
请在此处定义系统消息,用于指导模型的行为和输出格式。
"""

# 用户输入模板
USER_INPUT_TEMPLATE = """
请在此处定义用户输入的模板,可以使用{placeholder}作为占位符。
"""


def process_row(row: dict[str, Any]) -> None:
    """Send one input row to the chat model and persist the response.

    The row's fields fill ``USER_INPUT_TEMPLATE``; the model reply is handed
    to :func:`post_process`, which writes the augmented row to
    ``PROCESSED_DIR``. Failures are logged and the row is skipped so the
    rest of the batch keeps running (best-effort semantics).
    """
    try:
        user_input = USER_INPUT_TEMPLATE.format(**row)
        messages = [
            {"role": "system", "content": SYSTEM_MESSAGE},
            {"role": "user", "content": user_input},
        ]
        completion = client.chat.completions.create(
            model=CONFIG["MODEL_NAME"],
            messages=messages,  # type: ignore
        )
        response = completion.choices[0].message.content
        post_process(row, response)
    except Exception as e:
        # Fix: errors were logged at INFO level, hiding failures (and their
        # tracebacks) under normal log filtering. logger.exception records
        # the message at ERROR level together with the full traceback.
        logger.exception(f"处理数据时出错: {e}")
        logger.warning(f"跳过数据: {row.get('id', 'unknown')}")


def post_process(row: dict[str, Any], response: str | None) -> None:
    """Attach the model response to *row* and write it as one JSON chunk file.

    Customize the handling here — e.g. ``json.loads(response)`` if the model
    is instructed to return JSON.

    The row (with the response added under ``"model_response"``) is dumped to
    ``PROCESSED_DIR/<uuid>.json``.
    """
    # Example: store the raw response on the row.
    row["model_response"] = response

    # A random UUID filename lets concurrent worker processes write
    # without coordinating or clashing.
    unique_id = str(uuid.uuid4())
    # Fix: use os.path.join instead of a hard-coded "/" so the path is
    # built correctly on every platform.
    filename = os.path.join(CONFIG["PROCESSED_DIR"], f"{unique_id}.json")
    with open(filename, "w", encoding="utf-8") as f:
        json.dump(row, f, ensure_ascii=False, indent=4)


def main() -> None:
    """Run the batch: read input JSONL, fan out to workers, merge results.

    Each worker writes its own chunk file into ``PROCESSED_DIR`` (so
    already-processed rows survive a crash and rerun — note stale chunks
    from a previous run are also merged), then all chunks are concatenated
    into ``OUTPUT_FILE`` as JSON lines.
    """
    # Read the input data (one JSON object per line).
    df = pd.read_json(CONFIG["INPUT_FILE"], lines=True)

    # Create the intermediate chunk directory.
    os.makedirs(CONFIG["PROCESSED_DIR"], exist_ok=True)

    # Process rows in parallel; list() drains the lazy executor.map so
    # tqdm can show progress for the whole batch.
    with ProcessPoolExecutor(max_workers=int(CONFIG["MAX_WORKERS"])) as executor:
        list(
            tqdm(
                executor.map(process_row, df.to_dict(orient="records")),
                total=len(df),
            )
        )

    # Merge the per-row chunk files into one DataFrame.
    # Fix: the original opened files inside a list comprehension and never
    # closed them (handle leak); it also loaded every directory entry,
    # which breaks on any non-JSON stray file.
    data = []
    for chunk in os.listdir(CONFIG["PROCESSED_DIR"]):
        if not chunk.endswith(".json"):
            continue
        chunk_path = os.path.join(CONFIG["PROCESSED_DIR"], chunk)
        with open(chunk_path, encoding="utf-8") as f:
            data.append(json.load(f))
    final_data = pd.DataFrame(data)

    # Save the merged result as JSON lines.
    final_data.to_json(
        CONFIG["OUTPUT_FILE"], force_ascii=False, lines=True, orient="records"
    )


# Entry-point guard: required so worker processes spawned by
# ProcessPoolExecutor do not rerun main() when they re-import this module.
if __name__ == "__main__":
    main()