Model Training Overview
Semantic Router relies on several specialized classification models to make intelligent routing decisions. This section provides a comprehensive overview of the training process, the datasets used, and the role each model plays in the routing pipeline.
Training Architecture Overview
Semantic Router takes a multi-task learning approach, using ModernBERT as the base model for its classification tasks. Each model is trained for a specific purpose in the routing pipeline.
Why ModernBERT?
Technical Advantages
ModernBERT represents the latest evolution of the BERT architecture, with several key improvements over classic BERT models:
1. Enhanced Architecture
- Rotary Position Embeddings (RoPE): better handling of positional information
- GeGLU activations: improved gradient flow and representational capacity
- Removal of attention biases: a leaner attention mechanism
- Modern layer normalization: better training stability
2. Training Improvements
- Longer context: trained on sequences of up to 8,192 tokens, versus 512 for BERT (see the sketch after this list)
- Better data: trained on higher-quality, more recent datasets
- Improved tokenization: a more efficient vocabulary and tokenizer
- Anti-overfitting techniques: built-in regularization improvements
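To make the longer-context claim concrete, here is a minimal tokenization sketch. The `answerdotai/ModernBERT-base` checkpoint name is an assumption; the document's own code uses the shorthand `modernbert-base`, so substitute whichever checkpoint you actually train from.

from transformers import AutoTokenizer

# Assumed public checkpoint; substitute the checkpoint used in your training setup
tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")

long_document = "word " * 6000  # well beyond BERT's 512-token window
encoded = tokenizer(long_document, truncation=True, max_length=8192)
print(len(encoded["input_ids"]))  # up to 8,192 tokens, vs. 512 for classic BERT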
3. Performance Advantages

# Performance comparison on classification tasks
model_performance = {
    "bert-base": {
        "accuracy": 89.2,
        "inference_speed": "100ms",
        "memory_usage": "400MB"
    },
    "modernbert-base": {
        "accuracy": 92.7,           # +3.5% accuracy
        "inference_speed": "85ms",  # 15% faster
        "memory_usage": "380MB"     # 5% less memory
    }
}
Why Not a GPT-Based Model?

| Aspect | ModernBERT | GPT-3.5/4 |
|---|---|---|
| Latency | ~20ms | ~200-500ms |
| Cost | $0.0001/query | $0.002-0.03/query |
| Specialization | Fine-tuned for classification | General-purpose |
| Consistency | Deterministic outputs | Variable outputs |
| Deployment | Self-hosted | API-dependent |
| Context understanding | Bidirectional | Left-to-right |
Training Methodology
Unified Fine-Tuning Framework
Our training approach uses a unified fine-tuning framework that applies a consistent methodology across all classification tasks.
Anti-Overfitting Strategy
# Adaptive training configuration based on dataset size
from dataclasses import dataclass
from typing import Optional

@dataclass
class TrainingConfig:
    # Assumed container for the hyperparameters used below
    epochs: int
    batch_size: int
    learning_rate: float
    weight_decay: float
    warmup_ratio: float
    eval_strategy: str
    early_stopping_patience: int
    eval_steps: Optional[int] = None

def get_training_config(dataset_size):
    # Small datasets: fewer epochs, stronger regularization, eager early stopping
    if dataset_size < 1000:
        return TrainingConfig(
            epochs=2,
            batch_size=4,
            learning_rate=1e-5,
            weight_decay=0.15,
            warmup_ratio=0.1,
            eval_strategy="epoch",
            early_stopping_patience=1
        )
    # Medium datasets: moderate settings with step-based evaluation
    elif dataset_size < 5000:
        return TrainingConfig(
            epochs=3,
            batch_size=8,
            learning_rate=2e-5,
            weight_decay=0.1,
            warmup_ratio=0.06,
            eval_strategy="steps",
            eval_steps=100,
            early_stopping_patience=2
        )
    # Large datasets: more epochs, lighter regularization
    else:
        return TrainingConfig(
            epochs=4,
            batch_size=16,
            learning_rate=3e-5,
            weight_decay=0.05,
            warmup_ratio=0.03,
            eval_strategy="steps",
            eval_steps=200,
            early_stopping_patience=3
        )
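As a quick usage check, a hypothetical 3,200-example dataset falls into the middle tier:

# Hypothetical dataset size chosen to illustrate tier selection
config = get_training_config(dataset_size=3200)
print(config.epochs, config.batch_size, config.learning_rate)  # 3 8 2e-05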
Training Pipeline Implementation

import numpy as np
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    EarlyStoppingCallback,
    Trainer,
    TrainingArguments,
)

class UnifiedBERTFinetuning:
    def __init__(self, model_name="modernbert-base", task_type="classification"):
        self.model_name = model_name
        self.task_type = task_type
        self.model = None
        # Load the tokenizer up front; the data collator below needs it
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

    def train_model(self, dataset, config):
        # 1. Load the pre-trained model
        self.model = AutoModelForSequenceClassification.from_pretrained(
            self.model_name,
            num_labels=len(dataset.label_names),
            problem_type="single_label_classification"
        )

        # 2. Set up training arguments with anti-overfitting measures
        training_args = TrainingArguments(
            output_dir=f"./models/{self.task_type}_classifier_{self.model_name}_model",
            num_train_epochs=config.epochs,
            per_device_train_batch_size=config.batch_size,
            per_device_eval_batch_size=config.batch_size,
            learning_rate=config.learning_rate,
            weight_decay=config.weight_decay,
            warmup_ratio=config.warmup_ratio,
            # Evaluation and early stopping (save and eval strategies must
            # match for load_best_model_at_end to work)
            evaluation_strategy=config.eval_strategy,
            eval_steps=getattr(config, "eval_steps", None),
            save_strategy=config.eval_strategy,
            save_steps=200,
            load_best_model_at_end=True,
            metric_for_best_model="f1",
            greater_is_better=True,
            # Regularization and efficiency
            fp16=True,  # mixed-precision training
            gradient_checkpointing=True,
            dataloader_drop_last=True,
            # Logging
            logging_dir=f"./logs/{self.task_type}_{self.model_name}",
            logging_steps=50,
            report_to="tensorboard"
        )

        # 3. Set up the trainer with custom metrics
        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=dataset.train_dataset,
            eval_dataset=dataset.eval_dataset,
            tokenizer=self.tokenizer,
            data_collator=DataCollatorWithPadding(self.tokenizer),
            compute_metrics=self.compute_metrics,
            callbacks=[EarlyStoppingCallback(early_stopping_patience=config.early_stopping_patience)]
        )

        # 4. Train the model
        trainer.train()

        # 5. Save the model and evaluation results
        self.save_trained_model(trainer)
        return trainer

    def compute_metrics(self, eval_pred):
        predictions, labels = eval_pred
        predictions = np.argmax(predictions, axis=1)
        return {
            'accuracy': accuracy_score(labels, predictions),
            'f1': f1_score(labels, predictions, average='weighted'),
            'precision': precision_score(labels, predictions, average='weighted'),
            'recall': recall_score(labels, predictions, average='weighted')
        }
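A usage sketch under stated assumptions: `load_category_dataset` is a hypothetical loader standing in for the project's data-loading code; it must return an object with the `train_dataset`, `eval_dataset`, and `label_names` attributes that `train_model` expects.

# `load_category_dataset` is hypothetical; it represents data-loading code
# not shown here, returning train_dataset, eval_dataset, and label_names.
dataset = load_category_dataset("mmlu-pro")
config = get_training_config(dataset_size=len(dataset.train_dataset))

finetuner = UnifiedBERTFinetuning(model_name="modernbert-base", task_type="category")
trainer = finetuner.train_model(dataset, config)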
Model Specifications
1. Category Classification Model
Purpose: route queries to specialized models based on their academic or professional domain.
Dataset: MMLU-Pro academic domains
# Dataset composition
mmlu_categories = {
    "mathematics": {
        "samples": 1547,
        "subcategories": ["algebra", "calculus", "geometry", "statistics"],
        "example": "Find the integral of x^2 from 0 to 1"
    },
    "physics": {
        "samples": 1231,
        "subcategories": ["mechanics", "thermodynamics", "electromagnetism"],
        "example": "Calculate the force required to accelerate a 10kg mass at 5m/s^2"
    },
    "computer_science": {
        "samples": 1156,
        "subcategories": ["algorithms", "data_structures", "programming"],
        "example": "Implement a binary search algorithm in Python"
    },
    "biology": {
        "samples": 1089,
        "subcategories": ["genetics", "ecology", "anatomy"],
        "example": "Explain the process of photosynthesis in plants"
    },
    "chemistry": {
        "samples": 1034,
        "subcategories": ["organic", "inorganic", "physical"],
        "example": "Balance the chemical equation: H2 + O2 → H2O"
    },
    # ... additional categories
}
Training Configuration

model_config:
  base_model: "modernbert-base"
  task_type: "sequence_classification"
  num_labels: 10

training_config:
  epochs: 3
  batch_size: 8
  learning_rate: 2e-5
  weight_decay: 0.1

evaluation_metrics:
  - accuracy: 94.2%
  - f1_weighted: 93.8%
  - per_category_precision: ">90% for all categories"
Model Performance

category_performance = {
    "overall_accuracy": 0.942,
    "per_category_results": {
        "mathematics": {"precision": 0.956, "recall": 0.943, "f1": 0.949},
        "physics": {"precision": 0.934, "recall": 0.928, "f1": 0.931},
        "computer_science": {"precision": 0.948, "recall": 0.952, "f1": 0.950},
        "biology": {"precision": 0.925, "recall": 0.918, "f1": 0.921},
        "chemistry": {"precision": 0.941, "recall": 0.935, "f1": 0.938}
    },
    "confusion_matrix_insights": {
        "most_confused": "physics <-> mathematics (12% cross-classification)",
        "best_separated": "biology <-> computer_science (2% cross-classification)"
    }
}
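Once trained, the category classifier can be served through a standard Hugging Face text-classification pipeline. The sketch below is illustrative only: the model path follows the `output_dir` pattern from the training code, and the routing table is an assumption, not the project's actual configuration.

from transformers import pipeline

# Path follows the training script's output_dir pattern; routing table is a
# made-up example of mapping predicted categories to downstream models.
classifier = pipeline(
    "text-classification",
    model="./models/category_classifier_modernbert-base_model"
)

result = classifier("Implement a binary search algorithm in Python")[0]
# e.g. {"label": "computer_science", "score": 0.98}

routing_table = {
    "computer_science": "code-specialized-llm",
    "mathematics": "math-specialized-llm",
}
target_model = routing_table.get(result["label"], "general-purpose-llm")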
2. PII Detection Model
Purpose: identify personally identifiable information (PII) to protect user privacy.
Dataset: Microsoft Presidio + custom synthetic data
# PII entity types and examples
pii_entities = {
    "PERSON": {
        "count": 15420,
        "examples": ["John Smith", "Dr. Sarah Johnson", "Ms. Emily Chen"],
        "patterns": ["First Last", "Title First Last", "First Middle Last"]
    },
    "EMAIL_ADDRESS": {
        "count": 8934,
        "examples": ["user@domain.com", "john.doe@company.org"],
        "patterns": ["Local@Domain", "FirstLast@Company"]
    },
    "PHONE_NUMBER": {
        "count": 7234,
        "examples": ["(555) 123-4567", "+1-800-555-0123", "555.123.4567"],
        "patterns": ["US format", "international format", "dot-separated format"]
    },
    "US_SSN": {
        "count": 5123,
        "examples": ["123-45-6789", "123456789"],
        "patterns": ["XXX-XX-XXXX", "XXXXXXXXX"]
    },
    "LOCATION": {
        "count": 6789,
        "examples": ["123 Main St, New York, NY", "San Francisco, CA"],
        "patterns": ["street address", "city, state", "geographic location"]
    },
    "NO_PII": {
        "count": 45678,
        "examples": ["The weather is nice today", "Please help me write some code"],
        "description": "Text containing no personal information"
    }
}
Training Approach: Token Classification

from transformers import AutoModelForTokenClassification, AutoTokenizer

class PIITokenClassifier:
    def __init__(self):
        self.tokenizer = AutoTokenizer.from_pretrained("modernbert-base")
        self.model = AutoModelForTokenClassification.from_pretrained(
            "modernbert-base",
            num_labels=len(pii_entities),  # 6 entity types
            id2label={i: label for i, label in enumerate(pii_entities.keys())},
            label2id={label: i for i, label in enumerate(pii_entities.keys())}
        )

    def preprocess_data(self, examples):
        # Convert PII annotations into per-token label IDs
        tokenized_inputs = self.tokenizer(
            examples["tokens"],
            truncation=True,
            is_split_into_words=True
        )
        # Align the labels with the tokenized inputs
        labels = []
        for i, label in enumerate(examples["ner_tags"]):
            word_ids = tokenized_inputs.word_ids(batch_index=i)
            label_ids = self.align_labels_with_tokens(label, word_ids)
            labels.append(label_ids)
        tokenized_inputs["labels"] = labels
        return tokenized_inputs

    def align_labels_with_tokens(self, labels, word_ids):
        # Label only the first sub-token of each word; mask the rest with -100
        # so the loss function ignores them
        aligned = []
        previous_word_id = None
        for word_id in word_ids:
            if word_id is None:
                aligned.append(-100)  # special tokens ([CLS], [SEP], padding)
            elif word_id != previous_word_id:
                aligned.append(labels[word_id])
            else:
                aligned.append(-100)  # continuation sub-tokens of the same word
            previous_word_id = word_id
        return aligned
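At inference time the token classifier can be wrapped in a token-classification pipeline, which merges sub-token predictions back into entity spans; a minimal sketch with an assumed model path:

from transformers import pipeline

# Assumed model path; "simple" aggregation merges sub-tokens into entity spans
pii_detector = pipeline(
    "token-classification",
    model="./models/pii_classifier_modernbert-base_model",
    aggregation_strategy="simple"
)

entities = pii_detector("Contact John Smith at john.doe@company.org")
# e.g. [{"entity_group": "PERSON", "word": "John Smith", ...},
#       {"entity_group": "EMAIL_ADDRESS", "word": "john.doe@company.org", ...}]
has_pii = any(e["entity_group"] != "NO_PII" for e in entities)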
Performance Metrics

pii_performance = {
    "overall_f1": 0.957,
    "entity_level_performance": {
        "PERSON": {"precision": 0.961, "recall": 0.954, "f1": 0.957},
        "EMAIL_ADDRESS": {"precision": 0.989, "recall": 0.985, "f1": 0.987},
        "PHONE_NUMBER": {"precision": 0.978, "recall": 0.972, "f1": 0.975},
        "US_SSN": {"precision": 0.995, "recall": 0.991, "f1": 0.993},
        "LOCATION": {"precision": 0.943, "recall": 0.938, "f1": 0.940},
        "NO_PII": {"precision": 0.967, "recall": 0.971, "f1": 0.969}
    },
    "false_positive_analysis": {
        "common_errors": "company names confused with person names",
        "mitigation": "post-processing with organization entity recognition"
    }
}
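The mitigation noted above can be as simple as a gazetteer check. The sketch below is hypothetical: the organization list and filtering rule are assumptions for illustration, not the project's actual post-processor.

# Hypothetical post-processing: drop PERSON spans that match a
# known-organization gazetteer, reducing company-name false positives.
KNOWN_ORGANIZATIONS = {"morgan stanley", "wells fargo", "baker hughes"}

def filter_person_false_positives(entities):
    return [
        entity for entity in entities
        if not (entity["entity_group"] == "PERSON"
                and entity["word"].lower() in KNOWN_ORGANIZATIONS)
    ]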
3. Jailbreak Detection Model
Purpose: identify and block attempts to bypass AI safety measures.