Model Training Overview
The Semantic Router relies on multiple specialized classification models to make intelligent routing decisions. This section provides a comprehensive overview of the training process, datasets used, and the purpose of each model in the routing pipeline.
Training Architecture Overview
The Semantic Router employs a multi-task learning approach, using ModernBERT as the foundation model for its classification tasks. Each model is trained for a specific purpose in the routing pipeline, as sketched below.
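The same base checkpoint is fine-tuned separately for each task. In the sketch below, the Hugging Face checkpoint id and the non-category task names are illustrative assumptions, not the project's confirmed task list; only the 10-label category classifier is documented later in this section.

```python
from transformers import AutoModelForSequenceClassification, AutoTokenizer

BASE_CHECKPOINT = "answerdotai/ModernBERT-base"  # assumed HF checkpoint id

# Illustrative task list mapping each routing task to its label count;
# "safety" and "quality" are placeholders for the other pipeline tasks.
ROUTING_TASKS = {
    "category": 10,
    "safety": 2,
    "quality": 3,
}

tokenizer = AutoTokenizer.from_pretrained(BASE_CHECKPOINT)
classifiers = {
    task: AutoModelForSequenceClassification.from_pretrained(
        BASE_CHECKPOINT, num_labels=num_labels
    )
    for task, num_labels in ROUTING_TASKS.items()
}
```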
Why ModernBERT?
Technical Advantages
ModernBERT is a recent evolution of the BERT architecture, with several key improvements over the original BERT models:
1. Enhanced Architecture
- Rotary Position Embedding (RoPE): Better handling of positional information (see the sketch after this list)
- GeGLU Activation: Improved gradient flow and representation capacity
- Attention Bias Removal: Cleaner attention mechanisms
- Modern Layer Normalization: Better training stability
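To make the RoPE idea concrete, here is a minimal NumPy sketch of the "rotate-half" formulation. It is a didactic toy, not ModernBERT's actual implementation:

```python
import numpy as np

def apply_rope(x, base=10000.0):
    """Rotate each position's vector pair-wise by a position-dependent
    angle; x has shape (seq_len, dim) with an even dim."""
    seq_len, dim = x.shape
    half = dim // 2
    # One frequency per dimension pair: theta_i = base^(-2i/dim)
    freqs = base ** (-2.0 * np.arange(half) / dim)   # (half,)
    angles = np.outer(np.arange(seq_len), freqs)     # (seq_len, half)
    cos, sin = np.cos(angles), np.sin(angles)
    x1, x2 = x[:, :half], x[:, half:]
    # After rotating queries and keys this way, their dot product
    # depends only on the relative distance between positions.
    return np.concatenate([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)
```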
2. Training Improvements
- Longer Context: Trained on sequences up to 8,192 tokens vs. BERT's 512 (illustrated in the snippet after this list)
- Better Data: Trained on higher-quality, more recent datasets
- Improved Tokenization: More efficient vocabulary and tokenization
- Anti-overfitting Techniques: Built-in regularization improvements
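The longer context window noted above can be checked directly from the tokenizer configurations. A small sketch, assuming the `bert-base-uncased` and `answerdotai/ModernBERT-base` checkpoints on Hugging Face:

```python
from transformers import AutoTokenizer

# Compare the advertised maximum sequence length of each tokenizer.
for checkpoint in ("bert-base-uncased", "answerdotai/ModernBERT-base"):
    tok = AutoTokenizer.from_pretrained(checkpoint)
    print(checkpoint, "->", tok.model_max_length)
# Expected: 512 for BERT, 8192 for ModernBERT.
```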
3. Performance Benefits
```python
# Performance comparison on classification tasks
model_performance = {
    "bert-base": {
        "accuracy": 89.2,
        "inference_speed": "100ms",
        "memory_usage": "400MB",
    },
    "modernbert-base": {
        "accuracy": 92.7,            # +3.5% improvement
        "inference_speed": "85ms",   # 15% faster
        "memory_usage": "380MB",     # 5% less memory
    },
}
```
Why Not GPT-based Models?
| Aspect | ModernBERT | GPT-3.5/4 |
|---|---|---|
| Latency | ~20ms | ~200-500ms |
| Cost | $0.0001/query | $0.002-0.03/query |
| Specialization | Fine-tuned for classification | General purpose |
| Consistency | Deterministic outputs | Variable outputs |
| Deployment | Self-hosted | API dependency |
| Context Understanding | Bidirectional | Left-to-right |
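A back-of-the-envelope calculation makes the cost gap tangible. The per-query figures come from the table above; the query volume is an illustrative assumption:

```python
# Daily cost at an assumed 1,000,000 queries/day.
queries_per_day = 1_000_000
modernbert = queries_per_day * 0.0001   # -> $100/day
gpt_low = queries_per_day * 0.002       # -> $2,000/day
gpt_high = queries_per_day * 0.03       # -> $30,000/day
print(f"ModernBERT: ${modernbert:,.0f}/day")
print(f"GPT-3.5/4:  ${gpt_low:,.0f}-${gpt_high:,.0f}/day")
```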
Training Methodology
Unified Fine-tuning Framework
Our training approach uses a unified fine-tuning framework that applies a consistent methodology across all classification tasks.
Anti-Overfitting Strategy
```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class TrainingConfig:
    epochs: int
    batch_size: int
    learning_rate: float
    weight_decay: float
    warmup_ratio: float
    eval_strategy: str
    early_stopping_patience: int
    eval_steps: Optional[int] = None

# Adaptive training configuration based on dataset size: smaller datasets
# get fewer epochs, smaller batches, and stronger regularization to
# reduce the risk of overfitting.
def get_training_config(dataset_size):
    if dataset_size < 1000:
        return TrainingConfig(
            epochs=2,
            batch_size=4,
            learning_rate=1e-5,
            weight_decay=0.15,
            warmup_ratio=0.1,
            eval_strategy="epoch",
            early_stopping_patience=1,
        )
    elif dataset_size < 5000:
        return TrainingConfig(
            epochs=3,
            batch_size=8,
            learning_rate=2e-5,
            weight_decay=0.1,
            warmup_ratio=0.06,
            eval_strategy="steps",
            eval_steps=100,
            early_stopping_patience=2,
        )
    else:
        return TrainingConfig(
            epochs=4,
            batch_size=16,
            learning_rate=3e-5,
            weight_decay=0.05,
            warmup_ratio=0.03,
            eval_strategy="steps",
            eval_steps=200,
            early_stopping_patience=3,
        )
```
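For example, a 2,500-example dataset falls into the middle tier:

```python
config = get_training_config(2500)
print(config.epochs, config.learning_rate, config.eval_strategy)
# -> 3 2e-05 steps
```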
Training Pipeline Implementation
```python
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    EarlyStoppingCallback,
    Trainer,
    TrainingArguments,
)

class UnifiedBERTFinetuning:
    def __init__(self, model_name="modernbert-base", task_type="classification"):
        self.model_name = model_name
        self.task_type = task_type
        self.model = None
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

    def train_model(self, dataset, config):
        # 1. Load the pre-trained model with a fresh classification head
        self.model = AutoModelForSequenceClassification.from_pretrained(
            self.model_name,
            num_labels=len(dataset.label_names),
            problem_type="single_label_classification",
        )

        # 2. Set up training arguments with anti-overfitting measures
        training_args = TrainingArguments(
            output_dir=f"./models/{self.task_type}_classifier_{self.model_name}_model",
            num_train_epochs=config.epochs,
            per_device_train_batch_size=config.batch_size,
            per_device_eval_batch_size=config.batch_size,
            learning_rate=config.learning_rate,
            weight_decay=config.weight_decay,
            warmup_ratio=config.warmup_ratio,
            # Evaluation and early stopping (the save strategy must match
            # the evaluation strategy when load_best_model_at_end is set)
            evaluation_strategy=config.eval_strategy,
            eval_steps=config.eval_steps,
            save_strategy=config.eval_strategy,
            save_steps=config.eval_steps or 200,
            load_best_model_at_end=True,
            metric_for_best_model="f1",
            greater_is_better=True,
            # Regularization and memory savings
            fp16=True,  # mixed-precision training
            gradient_checkpointing=True,
            dataloader_drop_last=True,
            # Logging
            logging_dir=f"./logs/{self.task_type}_{self.model_name}",
            logging_steps=50,
            report_to="tensorboard",
        )

        # 3. Set up the trainer with custom metrics and early stopping
        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=dataset.train_dataset,
            eval_dataset=dataset.eval_dataset,
            tokenizer=self.tokenizer,
            data_collator=DataCollatorWithPadding(self.tokenizer),
            compute_metrics=self.compute_metrics,
            callbacks=[EarlyStoppingCallback(
                early_stopping_patience=config.early_stopping_patience
            )],
        )

        # 4. Train the model
        trainer.train()

        # 5. Save the best model and tokenizer
        self.save_trained_model(trainer)
        return trainer

    def save_trained_model(self, trainer):
        trainer.save_model()  # writes the best checkpoint to output_dir
        self.tokenizer.save_pretrained(trainer.args.output_dir)

    def compute_metrics(self, eval_pred):
        predictions, labels = eval_pred
        predictions = np.argmax(predictions, axis=1)
        return {
            "accuracy": accuracy_score(labels, predictions),
            "f1": f1_score(labels, predictions, average="weighted"),
            "precision": precision_score(labels, predictions, average="weighted"),
            "recall": recall_score(labels, predictions, average="weighted"),
        }
```
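A usage sketch under stated assumptions: `load_category_dataset` is a hypothetical helper standing in for whatever code produces an object with `label_names`, `train_dataset`, and `eval_dataset`, and the checkpoint id is assumed:

```python
dataset = load_category_dataset()  # hypothetical data-preparation helper
config = get_training_config(len(dataset.train_dataset))

finetuner = UnifiedBERTFinetuning(
    model_name="answerdotai/ModernBERT-base",  # assumed HF checkpoint id
    task_type="category",
)
trainer = finetuner.train_model(dataset, config)
print(trainer.evaluate())
```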
Model Specifications
1. Category Classification Model
Purpose: Route queries to specialized models based on academic/professional domains.
Dataset: MMLU-Pro Academic Domains
```python
# Dataset composition
mmlu_categories = {
    "mathematics": {
        "samples": 1547,
        "subcategories": ["algebra", "calculus", "geometry", "statistics"],
        "example": "Solve the integral of x^2 from 0 to 1",
    },
    "physics": {
        "samples": 1231,
        "subcategories": ["mechanics", "thermodynamics", "electromagnetism"],
        "example": "Calculate the force needed to accelerate a 10kg mass at 5m/s^2",
    },
    "computer_science": {
        "samples": 1156,
        "subcategories": ["algorithms", "data_structures", "programming"],
        "example": "Implement a binary search algorithm in Python",
    },
    "biology": {
        "samples": 1089,
        "subcategories": ["genetics", "ecology", "anatomy"],
        "example": "Explain the process of photosynthesis in plants",
    },
    "chemistry": {
        "samples": 1034,
        "subcategories": ["organic", "inorganic", "physical"],
        "example": "Balance the chemical equation: H2 + O2 → H2O",
    },
    # ... additional categories
}
```
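To reproduce a per-category sample count like the one above, one option is the `datasets` library. The dataset id and the `category` column are assumptions about the MMLU-Pro release on Hugging Face; adjust them to the actual source used:

```python
from collections import Counter
from datasets import load_dataset

# Assumed dataset id and column name.
mmlu_pro = load_dataset("TIGER-Lab/MMLU-Pro", split="test")
counts = Counter(row["category"] for row in mmlu_pro)
for category, n in counts.most_common():
    print(f"{category}: {n} samples")
```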
Training Configuration
```yaml
model_config:
  base_model: "modernbert-base"
  task_type: "sequence_classification"
  num_labels: 10

training_config:
  epochs: 3
  batch_size: 8
  learning_rate: 2e-5
  weight_decay: 0.1

evaluation_metrics:
  - accuracy: 94.2%
  - f1_weighted: 93.8%
  - per_category_precision: ">90% for all categories"
```