"""BRAID-aware optimizer for DSPy."""
import re
from typing import Dict, List, Optional, Any, Callable
import dspy
from braid.parser import MermaidParser, GRDStructure
from braid.module import BraidReasoning, BraidResult
class GRDMetrics:
    """
    Metrics for evaluating GRD quality.

    Includes both structural metrics and BRAID protocol compliance metrics:

    - Structural validity
    - Completeness
    - Execution traceability
    - Atomicity (token density)
    - Masking compliance
    - Procedural scaffolding

    Every public method is a ``@staticmethod`` returning a float score in
    the range 0.0 to 1.0 (or, for ``detailed_quality_report``, a dict of
    such scores).
    """

    # Weight of each component in ``overall_quality``. Insertion order is
    # also the key order of ``detailed_quality_report`` (before "overall").
    _WEIGHTS = {
        "structural_validity": 0.15,
        "completeness": 0.10,
        "execution_traceability": 0.10,
        "atomicity": 0.25,
        "masking_compliance": 0.20,
        "procedural_scaffolding": 0.20,
    }

    # Simple tokenizer: a word, or a single non-space punctuation character.
    _TOKEN_PATTERN = re.compile(r"\b\w+\b|[^\w\s]")

    # Patterns indicating computed values (answer leakage).
    _LEAKAGE_PATTERNS = tuple(
        re.compile(pattern, re.IGNORECASE)
        for pattern in (
            r"=\s*\d+",  # "= 60"
            r"result\s*[:=]?\s*\d+",  # "result: 42"
            r"answer\s*[:=]?\s*\d+",  # "answer = 100"
            r"\d+\s*(?:km/h|mph|m/s)",  # "60 km/h" (computed speed)
            r"\d+\s*(?:dollars?|\$)",  # "100 dollars"
        )
    )

    # Action verbs that indicate good procedural scaffolding.
    _ACTION_PATTERN = re.compile(
        r"\b(?:"
        r"calculate|compute|find|determine"
        r"|identify|extract|locate|read"
        r"|compare|check|verify|validate"
        r"|apply|use|substitute|recall"
        r"|divide|multiply|add|subtract"
        r"|analyze|break|solve|derive"
        r")\b",
        re.IGNORECASE,
    )

    # A node declaration: an identifier followed by a bracketed label.
    # Besides ``id[label]`` this also accepts the ``id(label)`` and
    # ``id{label}`` shapes, so that decision and rounded nodes are counted.
    _NODE_PATTERN = re.compile(r"\w+\s*(?:\[[^\]]+\]|\([^)]+\)|\{[^}]+\})")

    @staticmethod
    def structural_validity(grd: str) -> float:
        """
        Evaluate structural validity of a GRD.

        Args:
            grd: Mermaid GRD code

        Returns:
            1.0 if ``MermaidParser.validate`` reports the GRD as valid,
            otherwise 0.0
        """
        is_valid, _ = MermaidParser().validate(grd)
        return 1.0 if is_valid else 0.0

    @staticmethod
    def completeness(grd_structure: GRDStructure) -> float:
        """
        Evaluate completeness of a GRD (has start, end, reasonable number of steps).

        Args:
            grd_structure: Parsed GRD structure

        Returns:
            Score between 0.0 and 1.0
        """
        score = 0.0
        if grd_structure.start_nodes:
            score += 0.3
        if grd_structure.end_nodes:
            score += 0.3

        # A reasonable graph has between 3 and 20 nodes; oversized graphs
        # get partial credit and graphs with fewer than 3 nodes get none.
        node_count = len(grd_structure.nodes)
        if 3 <= node_count <= 20:
            score += 0.2
        elif node_count > 20:
            score += 0.1

        # Has edges connecting nodes.
        if grd_structure.edges:
            score += 0.2
        return min(score, 1.0)

    @staticmethod
    def execution_traceability(grd_structure: GRDStructure) -> float:
        """
        Evaluate how traceable/executable the GRD is.

        The score is the mean of a reachability ratio and a cycle heuristic.

        Args:
            grd_structure: Parsed GRD structure

        Returns:
            Score between 0.0 and 1.0 (0.0 when there is no execution order
            or no nodes)
        """
        execution_order = grd_structure.get_execution_order()
        if not execution_order:
            return 0.0
        all_nodes = {node.id for node in grd_structure.nodes}
        if not all_nodes:
            return 0.0

        # Only ids that are actual nodes count as reached; otherwise ids
        # unknown to ``nodes`` would push the ratio (and the score) above 1.0.
        reachable_nodes = set(execution_order) & all_nodes
        reachability_score = len(reachable_nodes) / len(all_nodes)

        # Simple heuristic: if execution order length equals node count,
        # the graph likely has no cycles.
        if len(execution_order) == len(grd_structure.nodes):
            cycle_score = 1.0
        else:
            cycle_score = 0.5
        return (reachability_score + cycle_score) / 2.0

    @staticmethod
    def atomicity_score(grd_structure: GRDStructure, max_tokens: int = 15) -> float:
        """
        Evaluate node atomicity (token density) compliance.

        According to BRAID research, nano-scale models perform best when
        node labels contain fewer than 15 tokens.

        Args:
            grd_structure: Parsed GRD structure
            max_tokens: Maximum tokens allowed per node

        Returns:
            Score between 0.0 and 1.0 (1.0 = all nodes within limit, which
            is also returned for a structure without nodes)
        """
        nodes = grd_structure.nodes
        if not nodes:
            return 1.0
        violations = sum(
            1
            for node in nodes
            if len(GRDMetrics._TOKEN_PATTERN.findall(node.label)) > max_tokens
        )
        return max(0.0, 1.0 - violations / len(nodes))

    @staticmethod
    def masking_compliance(grd: str) -> float:
        """
        Evaluate compliance with numerical masking protocol.

        Detects potential answer leakage where computed values appear in
        the GRD text. Matches of different patterns that overlap (for
        example ``result = 42 dollars``) are counted as a single leak.

        Args:
            grd: Mermaid GRD code

        Returns:
            Score between 0.0 and 1.0 (1.0 = no leakage detected)
        """
        spans = sorted(
            match.span()
            for pattern in GRDMetrics._LEAKAGE_PATTERNS
            for match in pattern.finditer(grd)
        )

        # Sweep the sorted spans, merging any that overlap, so that one
        # leaked value matched by several patterns is penalized only once.
        leakage_count = 0
        covered_until = -1
        for start, end in spans:
            if start >= covered_until:
                leakage_count += 1
            covered_until = max(covered_until, end)

        # Penalize based on leakage count.
        if leakage_count == 0:
            return 1.0
        if leakage_count <= 2:
            return 0.7
        if leakage_count <= 5:
            return 0.4
        return 0.1

    @staticmethod
    def procedural_scaffolding_score(grd: str) -> float:
        """
        Evaluate adherence to procedural scaffolding rules.

        Good GRDs describe HOW to solve, not WHAT the answer is. The score
        is derived from the number of action verbs found anywhere in the
        GRD text divided by the number of node declarations.

        Args:
            grd: Mermaid GRD code

        Returns:
            Score between 0.0 and 1.0 (one of 1.0, 0.8, 0.6 or 0.4)
        """
        action_count = len(GRDMetrics._ACTION_PATTERN.findall(grd))
        # Normalize by node count; never divide by zero.
        node_count = max(len(GRDMetrics._NODE_PATTERN.findall(grd)), 1)
        action_density = action_count / node_count

        if action_density >= 0.8:
            return 1.0
        if action_density >= 0.5:
            return 0.8
        if action_density >= 0.3:
            return 0.6
        return 0.4

    @staticmethod
    def _component_scores(grd: str, grd_structure: Optional[GRDStructure] = None) -> tuple:
        """Return ``(scores, scorable)``: per-metric scores and whether the GRD could be scored."""
        scores = dict.fromkeys(GRDMetrics._WEIGHTS, 0.0)

        validity = GRDMetrics.structural_validity(grd)
        if validity == 0.0:
            return scores, False
        scores["structural_validity"] = validity

        # Parse if not provided; an unparsable GRD keeps its validity score
        # but is otherwise reported as all zeros.
        if grd_structure is None:
            try:
                grd_structure = MermaidParser().parse(grd)
            except Exception:
                return scores, False

        scores["completeness"] = GRDMetrics.completeness(grd_structure)
        scores["execution_traceability"] = GRDMetrics.execution_traceability(grd_structure)
        # BRAID protocol metrics.
        scores["atomicity"] = GRDMetrics.atomicity_score(grd_structure)
        scores["masking_compliance"] = GRDMetrics.masking_compliance(grd)
        scores["procedural_scaffolding"] = GRDMetrics.procedural_scaffolding_score(grd)
        return scores, True

    @staticmethod
    def _weighted_overall(scores: Dict[str, float]) -> float:
        """Combine per-metric scores into one value using ``_WEIGHTS``."""
        total = 0.0
        for name, weight in GRDMetrics._WEIGHTS.items():
            total += scores[name] * weight
        return total

    @staticmethod
    def overall_quality(grd: str, grd_structure: Optional[GRDStructure] = None) -> float:
        """
        Calculate overall GRD quality score.

        The result is a weighted average of all metrics in which the BRAID
        protocol metrics (atomicity, masking, scaffolding) carry more
        weight than the structural ones.

        Args:
            grd: Mermaid code string
            grd_structure: Optional pre-parsed structure

        Returns:
            Overall quality score between 0.0 and 1.0; 0.0 when the GRD is
            structurally invalid or cannot be parsed
        """
        scores, scorable = GRDMetrics._component_scores(grd, grd_structure)
        if not scorable:
            return 0.0
        return GRDMetrics._weighted_overall(scores)

    @staticmethod
    def detailed_quality_report(
        grd: str, grd_structure: Optional[GRDStructure] = None
    ) -> Dict[str, float]:
        """
        Get a detailed breakdown of all quality metrics.

        Args:
            grd: Mermaid code string
            grd_structure: Optional pre-parsed structure

        Returns:
            Dictionary with all individual metric scores plus ``"overall"``.
            For an invalid GRD every entry is 0.0; for a valid but
            unparsable GRD only ``"structural_validity"`` is non-zero.
        """
        scores, scorable = GRDMetrics._component_scores(grd, grd_structure)
        report = dict(scores)
        report["overall"] = GRDMetrics._weighted_overall(scores) if scorable else 0.0
        return report
class BraidOptimizer(dspy.Module):
    """
    BRAID-aware optimizer for DSPy.

    This optimizer extends DSPy's optimization capabilities by:

    1. Optimizing GRD generation quality
    2. Optimizing step-by-step execution
    3. Providing GRD-specific metrics
    """

    def __init__(
        self,
        base_optimizer: Optional[dspy.Module] = None,
        grd_quality_weight: float = 0.5,
        execution_quality_weight: float = 0.5,
    ):
        """
        Initialize the BRAID optimizer.

        Args:
            base_optimizer: Base DSPy optimizer to use (e.g., MIPROv2). Only
                its ``compile(student=..., trainset=...)`` method is called.
            grd_quality_weight: Weight for GRD quality in optimization
            execution_quality_weight: Weight for execution quality in optimization
        """
        super().__init__()
        self.base_optimizer = base_optimizer
        self.grd_quality_weight = grd_quality_weight
        self.execution_quality_weight = execution_quality_weight
        self.metrics = GRDMetrics()

    def optimize(
        self,
        module: BraidReasoning,
        trainset: List[Dict[str, Any]],
        metric: Optional[Callable] = None,
        num_threads: int = 1,
    ) -> BraidReasoning:
        """
        Optimize a BraidReasoning module.

        With a base optimizer the planning phase is optimized first and the
        execution phase second; without one the module is returned by the
        (currently no-op) simple optimization path.

        Args:
            module: The BraidReasoning module to optimize
            trainset: Training examples with 'problem' and optionally 'answer' keys
            metric: Optional custom metric function called as
                ``metric(result, expected_answer)``; defaults to the
                built-in metric
            num_threads: Number of threads for parallel optimization
                (accepted for API compatibility; not used by this method)

        Returns:
            Optimized BraidReasoning module (the same object, modified in place)
        """
        if metric is None:
            metric = self._default_metric

        if not self.base_optimizer:
            # Simple optimization: improve prompts based on metrics.
            return self._simple_optimize(module, trainset, metric)

        # Optimize the planning phase (GRD generation), then execution.
        optimized_module = self._optimize_planning(module, trainset, metric)
        return self._optimize_execution(optimized_module, trainset, metric)

    def _optimize_planning(
        self, module: BraidReasoning, trainset: List[Dict[str, Any]], metric: Callable
    ) -> BraidReasoning:
        """Optimize the planning (GRD generation) phase."""
        # Collect one generated GRD (and its quality score) per problem.
        grd_scores = []
        for example in trainset:
            problem = example.get("problem", "")
            if not problem:
                continue

            if module.use_generator and module.generator:
                gen_result = module.generator.generate(problem=problem)
                grd = gen_result.get("grd", "")
                grd_structure = gen_result.get("parsed_structure")
            else:
                plan_result = module.plan(problem=problem)
                grd = plan_result.grd
                try:
                    grd_structure = module.parser.parse(grd)
                except Exception:
                    # overall_quality parses again itself and scores an
                    # unparsable GRD as 0.0.
                    grd_structure = None

            # The quality is recorded alongside the GRD; it is not used to
            # filter the training examples.
            quality = self.metrics.overall_quality(grd, grd_structure)
            grd_scores.append({"problem": problem, "grd": grd, "quality": quality})

        # Skip compilation when nothing was collected, mirroring the guard
        # in _optimize_execution.
        if self.base_optimizer and grd_scores:
            # DSPy optimizers need to know which fields are inputs.
            plan_trainset = [
                dspy.Example(problem=ex["problem"], grd=ex["grd"]).with_inputs("problem")
                for ex in grd_scores
            ]
            module.plan = self.base_optimizer.compile(student=module.plan, trainset=plan_trainset)
        return module

    def _optimize_execution(
        self, module: BraidReasoning, trainset: List[Dict[str, Any]], metric: Callable
    ) -> BraidReasoning:
        """Optimize the execution phase."""
        # Run the full module on every problem and score each result.
        execution_results = []
        for example in trainset:
            problem = example.get("problem", "")
            expected_answer = example.get("answer", "")
            if not problem:
                continue
            result = module(problem=problem)
            execution_score = metric(result, expected_answer)
            execution_results.append(
                {"problem": problem, "result": result, "score": execution_score}
            )

        if self.base_optimizer:
            # Turn every executed reasoning step into one training example
            # for the execute-step predictor.
            step_trainset = []
            for ex in execution_results:
                # Results without reasoning steps contribute no examples.
                for step in ex["result"].reasoning_steps or []:
                    step_trainset.append(
                        dspy.Example(
                            step_description=step["label"],
                            context=f"Problem: {ex['problem']}",
                            step_output=step["result"],
                        # NOTE(review): assumes step_output is the only
                        # output field of execute_step -- confirm.
                        ).with_inputs("step_description", "context")
                    )
            if step_trainset:
                module.execute_step = self.base_optimizer.compile(
                    student=module.execute_step, trainset=step_trainset
                )
        return module

    def _simple_optimize(
        self, module: BraidReasoning, trainset: List[Dict[str, Any]], metric: Callable
    ) -> BraidReasoning:
        """Simple optimization without base optimizer (currently returns the module unchanged)."""
        # In a real implementation, this could use few-shot learning
        # or prompt engineering improvements.
        return module

    def _default_metric(self, result: BraidResult, expected_answer: Optional[str] = None) -> float:
        """
        Default metric for evaluating BRAID results.

        Combines weighted GRD quality, weighted execution quality and a
        fixed 0.2 bonus when the answer matches the expected answer.

        Args:
            result: BraidResult object
            expected_answer: Optional expected answer for comparison; values
                that are not strings are compared by their ``str()`` form

        Returns:
            Score between 0.0 and 1.0
        """
        score = 0.0

        # GRD quality.
        if result.parsed_grd:
            grd_quality = self.metrics.overall_quality(result.grd, result.parsed_grd)
            score += grd_quality * self.grd_quality_weight

        # Execution quality.
        if result.reasoning_steps:
            # Check if we have a reasonable number of steps.
            step_count = len(result.reasoning_steps)
            if 2 <= step_count <= 15:
                step_score = 1.0
            elif step_count > 15:
                step_score = 0.5
            else:
                step_score = 0.0
            # Check if an answer is present.
            answer_score = 1.0 if result.answer else 0.0
            execution_score = (step_score + answer_score) / 2.0
            score += execution_score * self.execution_quality_weight

        # Answer correctness (if expected answer provided).
        if expected_answer is not None and result.answer is not None:
            # Simple containment check (could be improved with semantic
            # similarity). Both sides must be non-empty after stripping:
            # an empty string is contained in every string and would
            # otherwise always earn the bonus.
            expected_norm = str(expected_answer).lower().strip()
            answer_norm = str(result.answer).lower().strip()
            if (
                expected_norm
                and answer_norm
                and (expected_norm in answer_norm or answer_norm in expected_norm)
            ):
                score += 0.2
        return min(score, 1.0)

    @staticmethod
    def _mean(values: List[float]) -> float:
        """Return the arithmetic mean of ``values``, or 0.0 for an empty list."""
        return sum(values) / len(values) if values else 0.0

    def evaluate(
        self,
        module: BraidReasoning,
        testset: List[Dict[str, Any]],
        metric: Optional[Callable] = None,
    ) -> Dict[str, float]:
        """
        Evaluate a BraidReasoning module on a test set.

        Args:
            module: The BraidReasoning module to evaluate
            testset: Test examples with 'problem' and optionally 'answer' keys
            metric: Optional custom metric function called as
                ``metric(result, expected_answer)``; defaults to the
                built-in metric

        Returns:
            Dictionary of evaluation metrics: the three averages (each 0.0
            when no values were collected), ``total_examples`` (size of
            ``testset``) and ``valid_results`` (examples actually scored,
            i.e. those with a non-empty 'problem')
        """
        if metric is None:
            metric = self._default_metric

        scores = []
        grd_qualities = []
        execution_scores = []
        for example in testset:
            problem = example.get("problem", "")
            expected_answer = example.get("answer")
            if not problem:
                continue

            result = module(problem=problem)
            scores.append(metric(result, expected_answer))

            if result.parsed_grd:
                grd_qualities.append(
                    self.metrics.overall_quality(result.grd, result.parsed_grd)
                )
            if result.reasoning_steps:
                # Normalize the step count so 10 or more steps score 1.0.
                exec_score = len(result.reasoning_steps) / 10.0
                execution_scores.append(min(exec_score, 1.0))

        return {
            "average_score": self._mean(scores),
            "average_grd_quality": self._mean(grd_qualities),
            "average_execution_score": self._mean(execution_scores),
            "total_examples": len(testset),
            "valid_results": len(scores),
        }