"""Validators for ensuring GRD quality and BRAID protocol compliance.
This module implements validation rules from the BRAID paper, including:
- Node atomicity (≤15 tokens per node)
- Structural validity
- Procedural scaffolding compliance
"""
import re
from typing import List, Optional, Set
from dataclasses import dataclass, field
from enum import Enum
from braid.parser import GRDNode, GRDStructure
class ValidationSeverity(Enum):
"""Severity levels for validation issues."""
ERROR = "error" # Must be fixed
WARNING = "warning" # Should be fixed
INFO = "info" # Suggestion
@dataclass
class ValidationIssue:
"""A single validation issue."""
severity: ValidationSeverity
code: str
message: str
node_id: Optional[str] = None
suggestion: Optional[str] = None
@dataclass
class ValidationResult:
"""Result of validation operation."""
valid: bool
issues: List[ValidationIssue] = field(default_factory=list)
score: float = 1.0 # 0.0 to 1.0
def has_errors(self) -> bool:
"""Check if there are any error-level issues."""
return any(issue.severity == ValidationSeverity.ERROR for issue in self.issues)
def has_warnings(self) -> bool:
"""Check if there are any warning-level issues."""
return any(issue.severity == ValidationSeverity.WARNING for issue in self.issues)
def get_errors(self) -> List[ValidationIssue]:
"""Get all error-level issues."""
return [i for i in self.issues if i.severity == ValidationSeverity.ERROR]
def get_warnings(self) -> List[ValidationIssue]:
"""Get all warning-level issues."""
return [i for i in self.issues if i.severity == ValidationSeverity.WARNING]
def summary(self) -> str:
"""Get a summary of the validation result."""
errors = len(self.get_errors())
warnings = len(self.get_warnings())
return (
f"Valid: {self.valid}, Errors: {errors}, Warnings: {warnings}, Score: {self.score:.2f}"
)
[docs]
class AtomicityValidator:
"""
Validates node atomicity in GRDs.
According to BRAID research, nano-scale models achieve highest accuracy
when node labels contain fewer than 15 tokens. This validator checks
and enforces this constraint.
Example:
>>> validator = AtomicityValidator()
>>> result = validator.validate_node(node)
>>> if not result.valid:
... print(result.issues[0].suggestion)
"""
DEFAULT_MAX_TOKENS = 15
[docs]
def __init__(
self,
max_tokens_per_node: int = DEFAULT_MAX_TOKENS,
strict_mode: bool = False,
):
"""
Initialize the AtomicityValidator.
Args:
max_tokens_per_node: Maximum allowed tokens per node label
strict_mode: If True, treat violations as errors; otherwise warnings
"""
self.max_tokens_per_node = max_tokens_per_node
self.strict_mode = strict_mode
[docs]
def count_tokens(self, text: str) -> int:
"""
Count tokens in a text string.
Uses a simple whitespace + punctuation tokenization that approximates
what most LLMs would produce. For more accurate counting, consider
using the tiktoken library with a specific model's tokenizer.
Args:
text: Text to tokenize
Returns:
Number of tokens
"""
if not text:
return 0
# Simple tokenization: split on whitespace and punctuation
# This approximates GPT-style tokenization for English text
tokens = re.findall(r"\b\w+\b|[^\w\s]", text)
return len(tokens)
[docs]
def validate_node(self, node: GRDNode) -> ValidationResult:
"""
Validate a single node's atomicity.
Args:
node: GRDNode to validate
Returns:
ValidationResult with any issues found
"""
issues: List[ValidationIssue] = []
token_count = self.count_tokens(node.label)
if token_count > self.max_tokens_per_node:
severity = ValidationSeverity.ERROR if self.strict_mode else ValidationSeverity.WARNING
# Generate suggestion for fixing
suggestion = self._generate_split_suggestion(node.label, token_count)
issues.append(
ValidationIssue(
severity=severity,
code="ATOMICITY_VIOLATION",
message=f"Node '{node.id}' has {token_count} tokens (max: {self.max_tokens_per_node})",
node_id=node.id,
suggestion=suggestion,
)
)
# Calculate score (1.0 if within limit, decreasing as it exceeds)
if token_count <= self.max_tokens_per_node:
score = 1.0
else:
# Penalty increases with excess tokens
excess_ratio = token_count / self.max_tokens_per_node
score = max(0.0, 1.0 - (excess_ratio - 1.0) * 0.5)
return ValidationResult(
valid=len(issues) == 0 or not self.strict_mode,
issues=issues,
score=score,
)
[docs]
def validate_grd(self, grd: GRDStructure) -> ValidationResult:
"""
Validate all nodes in a GRD for atomicity.
Args:
grd: GRDStructure to validate
Returns:
ValidationResult with all issues found
"""
all_issues: List[ValidationIssue] = []
total_score = 0.0
for node in grd.nodes:
result = self.validate_node(node)
all_issues.extend(result.issues)
total_score += result.score
# Average score across all nodes
avg_score = total_score / len(grd.nodes) if grd.nodes else 1.0
# Valid if no errors (in strict mode) or no issues at all
has_errors = any(issue.severity == ValidationSeverity.ERROR for issue in all_issues)
return ValidationResult(
valid=not has_errors,
issues=all_issues,
score=avg_score,
)
def _generate_split_suggestion(self, label: str, token_count: int) -> str:
"""Generate a suggestion for splitting a long node."""
if token_count <= self.max_tokens_per_node * 2:
return f"Consider splitting into 2 nodes with ~{token_count // 2} tokens each"
else:
num_splits = (token_count // self.max_tokens_per_node) + 1
return f"Consider splitting into {num_splits} sequential nodes"
[docs]
class ProceduralScaffoldingValidator:
"""
Validates that GRDs follow procedural scaffolding rules.
The BRAID protocol requires that GRDs describe HOW to solve a problem,
not WHAT the answer is. This validator detects answer leakage and
ensures nodes describe actions rather than computed values.
"""
# Patterns that indicate answer leakage
LEAKAGE_PATTERNS = [
(r"=\s*\d+", "EQUALS_VALUE", "Avoid computed values in node labels"),
(
r"(?:answer|result|solution)\s*[:=]?\s*\d+",
"LABELED_ANSWER",
"Don't include answers in scaffolding",
),
(
r"\d+\s*(?:km/h|mph|m/s|kg|lb)",
"UNIT_VALUE",
"Use placeholders instead of computed values with units",
),
(
r"(?:total|sum|difference|product)\s*[:=]?\s*\d+",
"COMPUTED_AGGREGATE",
"Describe the computation, not the result",
),
]
# Patterns that indicate good procedural scaffolding
SCAFFOLDING_PATTERNS = [
r"(?:calculate|compute|find|determine|solve)\s+",
r"(?:divide|multiply|add|subtract)\s+",
r"(?:compare|check|verify|validate)\s+",
r"(?:extract|identify|locate)\s+",
r"(?:apply|use|utilize)\s+",
]
[docs]
def __init__(self, strict_mode: bool = False):
"""
Initialize the ProceduralScaffoldingValidator.
Args:
strict_mode: If True, treat leakage as errors
"""
self.strict_mode = strict_mode
self._leakage_compiled = [
(re.compile(pattern, re.IGNORECASE), code, msg)
for pattern, code, msg in self.LEAKAGE_PATTERNS
]
self._scaffolding_compiled = [
re.compile(pattern, re.IGNORECASE) for pattern in self.SCAFFOLDING_PATTERNS
]
[docs]
def validate_node(self, node: GRDNode) -> ValidationResult:
"""
Validate a single node for procedural scaffolding compliance.
Args:
node: GRDNode to validate
Returns:
ValidationResult with any issues found
"""
issues: List[ValidationIssue] = []
label = node.label
# Check for leakage patterns
for pattern, code, message in self._leakage_compiled:
if pattern.search(label):
severity = (
ValidationSeverity.ERROR if self.strict_mode else ValidationSeverity.WARNING
)
issues.append(
ValidationIssue(
severity=severity,
code=code,
message=message,
node_id=node.id,
suggestion="Describe the action to take, not the computed value",
)
)
# Check for good scaffolding patterns (informational)
has_scaffolding = any(pattern.search(label) for pattern in self._scaffolding_compiled)
if not has_scaffolding and not issues:
issues.append(
ValidationIssue(
severity=ValidationSeverity.INFO,
code="WEAK_SCAFFOLDING",
message="Node could be more action-oriented",
node_id=node.id,
suggestion="Start with action verbs like 'Calculate', 'Find', 'Determine'",
)
)
# Calculate score
if issues:
error_count = sum(1 for i in issues if i.severity == ValidationSeverity.ERROR)
warning_count = sum(1 for i in issues if i.severity == ValidationSeverity.WARNING)
score = max(0.0, 1.0 - (error_count * 0.3) - (warning_count * 0.1))
else:
score = 1.0
return ValidationResult(
valid=not any(i.severity == ValidationSeverity.ERROR for i in issues),
issues=issues,
score=score,
)
[docs]
def validate_grd(self, grd: GRDStructure) -> ValidationResult:
"""
Validate all nodes in a GRD for procedural scaffolding.
Args:
grd: GRDStructure to validate
Returns:
ValidationResult with all issues found
"""
all_issues: List[ValidationIssue] = []
total_score = 0.0
for node in grd.nodes:
result = self.validate_node(node)
# Only include errors and warnings, not info
all_issues.extend(i for i in result.issues if i.severity != ValidationSeverity.INFO)
total_score += result.score
avg_score = total_score / len(grd.nodes) if grd.nodes else 1.0
return ValidationResult(
valid=not any(i.severity == ValidationSeverity.ERROR for i in all_issues),
issues=all_issues,
score=avg_score,
)
class StructuralValidator:
"""
Validates the structural integrity of GRDs.
Checks for:
- Connectivity (no orphan nodes)
- Proper start/end nodes
- Reasonable topology
"""
def __init__(
self,
min_nodes: int = 2,
max_nodes: int = 20,
require_single_start: bool = True,
require_single_end: bool = False,
):
"""
Initialize the StructuralValidator.
Args:
min_nodes: Minimum required nodes
max_nodes: Maximum allowed nodes
require_single_start: Require exactly one start node
require_single_end: Require exactly one end node
"""
self.min_nodes = min_nodes
self.max_nodes = max_nodes
self.require_single_start = require_single_start
self.require_single_end = require_single_end
def validate(self, grd: GRDStructure) -> ValidationResult:
"""
Validate the structural integrity of a GRD.
Args:
grd: GRDStructure to validate
Returns:
ValidationResult with any issues found
"""
issues: List[ValidationIssue] = []
# Check node count
node_count = len(grd.nodes)
if node_count < self.min_nodes:
issues.append(
ValidationIssue(
severity=ValidationSeverity.ERROR,
code="TOO_FEW_NODES",
message=f"GRD has {node_count} nodes (min: {self.min_nodes})",
suggestion="Add more reasoning steps",
)
)
elif node_count > self.max_nodes:
issues.append(
ValidationIssue(
severity=ValidationSeverity.WARNING,
code="TOO_MANY_NODES",
message=f"GRD has {node_count} nodes (max recommended: {self.max_nodes})",
suggestion="Consider simplifying or combining steps",
)
)
# Check start nodes
if self.require_single_start and len(grd.start_nodes) != 1:
issues.append(
ValidationIssue(
severity=ValidationSeverity.WARNING,
code="INVALID_START_NODES",
message=f"GRD has {len(grd.start_nodes)} start nodes (expected: 1)",
suggestion="Ensure there is a single entry point",
)
)
# Check end nodes
if self.require_single_end and len(grd.end_nodes) != 1:
issues.append(
ValidationIssue(
severity=ValidationSeverity.WARNING,
code="INVALID_END_NODES",
message=f"GRD has {len(grd.end_nodes)} end nodes (expected: 1)",
suggestion="Ensure there is a single conclusion node",
)
)
# Check for orphan nodes (nodes with no connections)
connected_nodes: Set[str] = set()
for edge in grd.edges:
connected_nodes.add(edge.from_node)
connected_nodes.add(edge.to_node)
orphans = [n.id for n in grd.nodes if n.id not in connected_nodes]
if orphans and len(grd.nodes) > 1:
issues.append(
ValidationIssue(
severity=ValidationSeverity.ERROR,
code="ORPHAN_NODES",
message=f"Disconnected nodes found: {', '.join(orphans)}",
suggestion="Connect all nodes in the reasoning flow",
)
)
# Check for edges referencing non-existent nodes
node_ids = {n.id for n in grd.nodes}
invalid_edges = [
(e.from_node, e.to_node)
for e in grd.edges
if e.from_node not in node_ids or e.to_node not in node_ids
]
if invalid_edges:
issues.append(
ValidationIssue(
severity=ValidationSeverity.ERROR,
code="INVALID_EDGES",
message=f"Edges reference non-existent nodes: {invalid_edges}",
suggestion="Ensure all edge endpoints are valid nodes",
)
)
# Calculate score
error_count = sum(1 for i in issues if i.severity == ValidationSeverity.ERROR)
warning_count = sum(1 for i in issues if i.severity == ValidationSeverity.WARNING)
score = max(0.0, 1.0 - (error_count * 0.25) - (warning_count * 0.1))
return ValidationResult(
valid=error_count == 0,
issues=issues,
score=score,
)
[docs]
class GRDValidator:
"""
Comprehensive GRD validator combining all validation rules.
This is the main entry point for validating GRDs according to
BRAID protocol requirements.
"""
[docs]
def __init__(
self,
max_tokens_per_node: int = 15,
strict_atomicity: bool = False,
strict_scaffolding: bool = False,
):
"""
Initialize the GRDValidator.
Args:
max_tokens_per_node: Maximum tokens allowed per node
strict_atomicity: Treat atomicity violations as errors
strict_scaffolding: Treat scaffolding violations as errors
"""
self.atomicity_validator = AtomicityValidator(
max_tokens_per_node=max_tokens_per_node,
strict_mode=strict_atomicity,
)
self.scaffolding_validator = ProceduralScaffoldingValidator(
strict_mode=strict_scaffolding,
)
self.structural_validator = StructuralValidator()
[docs]
def validate(self, grd: GRDStructure) -> ValidationResult:
"""
Perform comprehensive validation on a GRD.
Args:
grd: GRDStructure to validate
Returns:
Combined ValidationResult from all validators
"""
results = [
self.atomicity_validator.validate_grd(grd),
self.scaffolding_validator.validate_grd(grd),
self.structural_validator.validate(grd),
]
# Combine all issues
all_issues: List[ValidationIssue] = []
for result in results:
all_issues.extend(result.issues)
# Calculate combined score (weighted average)
weights = [0.4, 0.4, 0.2] # atomicity, scaffolding, structural
combined_score = sum(r.score * w for r, w in zip(results, weights))
# Valid only if all validations pass
valid = all(r.valid for r in results)
return ValidationResult(
valid=valid,
issues=all_issues,
score=combined_score,
)
[docs]
def validate_and_report(self, grd: GRDStructure) -> str:
"""
Validate a GRD and return a formatted report.
Args:
grd: GRDStructure to validate
Returns:
Markdown-formatted validation report
"""
result = self.validate(grd)
lines = ["# GRD Validation Report\n"]
lines.append(f"**Status:** {'✅ Valid' if result.valid else '❌ Invalid'}")
lines.append(f"**Score:** {result.score:.2f}/1.00\n")
if result.issues:
lines.append("## Issues\n")
errors = result.get_errors()
if errors:
lines.append("### Errors\n")
for issue in errors:
lines.append(f"- **{issue.code}**: {issue.message}")
if issue.node_id:
lines.append(f" - Node: `{issue.node_id}`")
if issue.suggestion:
lines.append(f" - Suggestion: {issue.suggestion}")
lines.append("")
warnings = result.get_warnings()
if warnings:
lines.append("### Warnings\n")
for issue in warnings:
lines.append(f"- **{issue.code}**: {issue.message}")
if issue.node_id:
lines.append(f" - Node: `{issue.node_id}`")
if issue.suggestion:
lines.append(f" - Suggestion: {issue.suggestion}")
else:
lines.append("No issues found! ✨")
return "\n".join(lines)