79 lines
2.5 KiB
Python
79 lines
2.5 KiB
Python
"""
|
||
医疗数据生成器
|
||
用于模拟构建医疗测试数据:
|
||
- id: 唯一标识符
|
||
- treatment: 是否吃药(0 或 1)
|
||
- health: 病人健康状态(0~1 浮点数,越高越好)
|
||
"""
|
||
|
||
import pandas as pd
|
||
import numpy as np
|
||
import os
|
||
|
||
|
||
def _mean_or_nan(values) -> float:
    """Return the mean of *values*, or NaN for an empty array (no NumPy warning)."""
    return float(values.mean()) if values.size else float("nan")


def generate_medical_data(
    n_samples: int = 500,
    output_path: str = "examples/medical/data.xlsx",
) -> str:
    """Generate simulated medical test data and save it to an Excel file.

    The table has three columns: ``id`` (1..n_samples), ``treatment``
    (0 or 1, drawn with probability 0.4 of being 1) and ``health`` (a float
    clipped to [0, 1] and rounded to 4 decimals). Health is a Beta(2, 2) base
    value plus 0.2 for treated rows plus Normal(0, 0.1) noise, so the treated
    group is healthier on average.

    A summary (preview of the first 10 rows and group statistics) is printed
    to stdout.

    Args:
        n_samples: Number of rows to generate; must be non-negative.
        output_path: Destination path of the Excel file. Missing parent
            directories are created.

    Returns:
        The path the file was written to (``output_path`` unchanged).

    Raises:
        ValueError: If ``n_samples`` is negative.
    """
    if n_samples < 0:
        raise ValueError(f"n_samples must be non-negative, got {n_samples}")

    # A private, fixed-seed generator keeps the output reproducible without
    # reseeding the process-wide NumPy RNG. RandomState(42) yields the same
    # stream that np.random.seed(42) + the legacy np.random.* functions did,
    # so previously generated data is reproduced exactly.
    rng = np.random.RandomState(42)

    ids = list(range(1, n_samples + 1))

    # The draw order (binomial -> beta -> normal) determines the values;
    # keep it stable.
    treatment = rng.binomial(1, 0.4, n_samples)
    base_health = rng.beta(2, 2, n_samples)
    treatment_effect = treatment * 0.2
    noise = rng.normal(0, 0.1, n_samples)
    health = np.clip(base_health + treatment_effect + noise, 0, 1)

    df = pd.DataFrame({
        'id': ids,
        'treatment': treatment,
        'health': np.round(health, 4),
    })

    # A bare filename has an empty dirname, which makedirs would reject.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    df.to_excel(output_path, index=False)

    # Empty groups (e.g. n_samples == 0) are reported as nan instead of
    # triggering a "mean of empty slice" warning.
    treated_share = _mean_or_nan(treatment) * 100
    treated_mean = _mean_or_nan(health[treatment == 1])
    untreated_mean = _mean_or_nan(health[treatment == 0])

    print(f"成功生成 {n_samples} 条医疗数据,已保存到:{output_path}")
    print(f"文件已创建:{output_path}")
    print("数据预览:")
    print(df.head(10))
    print("\n统计信息:")
    print(f"吃药人数:{treatment.sum()} ({treated_share:.1f}%)")
    print(f"健康状态均值:{_mean_or_nan(health):.4f}")
    print(f"吃药组健康均值:{treated_mean:.4f}")
    print(f"未吃药组健康均值:{untreated_mean:.4f}")

    return output_path
|
||
|
||
|
||
if __name__ == "__main__":
    # Script entry point: write the default 500-row dataset to data.xlsx in
    # the current working directory, then echo the path the generator returned.
    created_path = generate_medical_data(500, "data.xlsx")
    print(f"\n文件已创建:{created_path}")
|