Source code for braid.critic

"""Critic module for self-verification and feedback loops in BRAID.

This module implements the Terminal Verification Loops pattern from the
BRAID paper, allowing the model to verify its own answers and retry
if verification fails.
"""

import re
from typing import Dict, List, Optional, Any, Callable, Tuple
from dataclasses import dataclass, field
from enum import Enum

from braid.parser import GRDStructure, GRDNode, GRDEdge


class CriticType(Enum):
    """Categories of critic (self-checking) nodes in a GRD.

    Each member corresponds to a family of node labels:

    * ``VERIFICATION`` -- Check/Verify nodes
    * ``VALIDATION``   -- Validate nodes
    * ``REVIEW``       -- Review nodes
    * ``CONFIRMATION`` -- Confirm nodes
    """

    VERIFICATION = "verification"
    VALIDATION = "validation"
    REVIEW = "review"
    CONFIRMATION = "confirmation"


@dataclass
class CriticNode:
    """A GRD node recognised as a critic, together with its routing metadata.

    Attributes:
        node_id: Identifier of the critic node inside the GRD.
        critic_type: Category of critic this node belongs to.
        target_nodes: IDs of the nodes whose results this critic verifies.
        fallback_node: ID of the node to jump back to when verification
            fails, or ``None`` when no such node exists.
    """

    node_id: str
    critic_type: CriticType
    target_nodes: List[str]
    fallback_node: Optional[str] = field(default=None)


@dataclass
class CriticResult:
    """Verdict obtained by evaluating the output of a single critic node.

    Attributes:
        passed: Whether the critic judged the checked work to be acceptable.
        feedback: The critic's output text that the verdict was derived from.
        confidence: Confidence in the verdict (defaults to 1.0).
        suggested_action: Corrective action extracted from the feedback,
            if one could be found.
        retry_node: Node to resume from on a retry, if any (not populated by
            ``CriticEvaluator.evaluate``).
    """

    passed: bool
    feedback: str
    confidence: float = field(default=1.0)
    suggested_action: Optional[str] = field(default=None)
    retry_node: Optional[str] = field(default=None)


@dataclass
class FeedbackLoopResult:
    """Outcome of running a whole GRD with critic feedback loops.

    Attributes:
        final_passed: Overall verification verdict for the run.
        attempts: Number of node-execution steps taken.
        critic_results: Every critic verdict produced during the run, in order.
        final_output: Text chosen as the answer of the run.
    """

    final_passed: bool
    attempts: int
    critic_results: List[CriticResult]
    final_output: str


class CriticDetector:
    """
    Detects and classifies critic nodes in GRDs.

    Critic nodes are special nodes that verify previous computations
    and can trigger retries if verification fails.
    """

    # Label prefixes that mark a node as a critic, grouped by critic type.
    # Patterns are ``^``-anchored; get_critic_type() strips leading whitespace
    # from the label first and returns the first type (in dict order) that
    # matches.
    CRITIC_PATTERNS = {
        CriticType.VERIFICATION: [
            re.compile(r"^Check[:\s]", re.IGNORECASE),
            re.compile(r"^Verify[:\s]", re.IGNORECASE),
            re.compile(r"^Double[- ]?check", re.IGNORECASE),
        ],
        CriticType.VALIDATION: [
            re.compile(r"^Validate[:\s]", re.IGNORECASE),
            re.compile(r"^Ensure[:\s]", re.IGNORECASE),
            re.compile(r"^Assert[:\s]", re.IGNORECASE),
        ],
        CriticType.REVIEW: [
            re.compile(r"^Review[:\s]", re.IGNORECASE),
            re.compile(r"^Examine[:\s]", re.IGNORECASE),
            re.compile(r"^Inspect[:\s]", re.IGNORECASE),
        ],
        CriticType.CONFIRMATION: [
            re.compile(r"^Confirm[:\s]", re.IGNORECASE),
            re.compile(r"^Make sure[:\s]", re.IGNORECASE),
            re.compile(r"^Is this correct", re.IGNORECASE),
        ],
    }

    # Edge-label patterns identifying the failure/retry branch out of a critic.
    FAILURE_EDGE_PATTERNS = [
        re.compile(r"fail", re.IGNORECASE),
        re.compile(r"error", re.IGNORECASE),
        re.compile(r"retry", re.IGNORECASE),
        re.compile(r"incorrect", re.IGNORECASE),
        re.compile(r"wrong", re.IGNORECASE),
        # A label ending in the standalone word "no" (e.g. "No", "yes/no: no").
        # The word boundary keeps labels such as "Casino" from matching.
        re.compile(r"\bno\s*$", re.IGNORECASE),
    ]

    def is_critic_node(self, node: GRDNode) -> bool:
        """Check if a node is a critic node.

        Args:
            node: Node whose label is inspected.

        Returns:
            True when the label matches any pattern in ``CRITIC_PATTERNS``.
        """
        return self.get_critic_type(node) is not None

    def get_critic_type(self, node: GRDNode) -> Optional[CriticType]:
        """Get the type of critic node.

        Args:
            node: Node whose label is inspected.

        Returns:
            The first CriticType whose pattern matches the label, or None
            when the node is not a critic.
        """
        # The patterns are anchored with ``^``; ignore leading whitespace so an
        # indented label like "  Verify: ..." is still recognised.
        label = (node.label or "").lstrip()
        for critic_type, patterns in self.CRITIC_PATTERNS.items():
            if any(pattern.search(label) for pattern in patterns):
                return critic_type
        return None

    def detect_critics(self, grd: GRDStructure) -> List[CriticNode]:
        """
        Detect all critic nodes in a GRD.

        Args:
            grd: GRDStructure to analyze

        Returns:
            List of detected CriticNodes with their metadata, in the order
            the nodes appear in ``grd.nodes``.
        """
        critics: List[CriticNode] = []
        for node in grd.nodes:
            critic_type = self.get_critic_type(node)
            if critic_type is None:
                continue
            # Target nodes are the sources of the edges leading into the critic.
            target_nodes = [edge.from_node for edge in grd.get_incoming_edges(node.id)]
            critics.append(
                CriticNode(
                    node_id=node.id,
                    critic_type=critic_type,
                    target_nodes=target_nodes,
                    fallback_node=self._find_fallback_node(grd, node.id),
                )
            )
        return critics

    def _is_failure_label(self, label: str) -> bool:
        """Return True when an edge label marks a failure/retry branch."""
        return any(pattern.search(label) for pattern in self.FAILURE_EDGE_PATTERNS)

    def _find_fallback_node(self, grd: GRDStructure, critic_node_id: str) -> Optional[str]:
        """Find the node to return to on critic failure.

        Prefers the target of an explicit failure/retry outgoing edge; otherwise
        falls back to the source of the critic's first incoming edge.

        Returns:
            The fallback node ID, or None when the critic has neither.
        """
        for edge in grd.get_outgoing_edges(critic_node_id):
            if edge.label and self._is_failure_label(edge.label):
                return edge.to_node

        incoming_edges = grd.get_incoming_edges(critic_node_id)
        if incoming_edges:
            # NOTE(review): assumes the first incoming edge is the step being
            # verified; depends on GRDStructure's edge ordering - confirm.
            return incoming_edges[0].from_node
        return None

    def get_feedback_loops(self, grd: GRDStructure) -> List[Tuple[CriticNode, List[str]]]:
        """
        Identify feedback loops in the GRD.

        A feedback loop is a path from a critic node back to its fallback
        node, following outgoing edges.

        Returns:
            List of (critic_node, loop_path) tuples; critics without a
            fallback node or without a path to it are omitted.
        """
        loops: List[Tuple[CriticNode, List[str]]] = []
        for critic in self.detect_critics(grd):
            if not critic.fallback_node:
                continue
            loop_path = self._find_path(grd, critic.node_id, critic.fallback_node)
            if loop_path:
                loops.append((critic, loop_path))
        return loops

    def _find_path(self, grd: GRDStructure, from_node: str, to_node: str) -> Optional[List[str]]:
        """Find a shortest path between two nodes via breadth-first search.

        Returns:
            The node IDs from ``from_node`` to ``to_node`` inclusive, or None
            when ``to_node`` is unreachable.
        """
        from collections import deque

        queue = deque([(from_node, [from_node])])
        visited = {from_node}

        while queue:
            current, path = queue.popleft()
            for edge in grd.get_outgoing_edges(current):
                if edge.to_node == to_node:
                    return path + [to_node]
                if edge.to_node not in visited:
                    visited.add(edge.to_node)
                    queue.append((edge.to_node, path + [edge.to_node]))

        return None
class CriticEvaluator:
    """
    Evaluates critic node outputs to determine pass/fail.

    Interprets the output of critic nodes and decides whether verification
    passed and what action to take. The verdict is a keyword vote: success
    indicators are weighed against failure indicators. Negated phrases
    ("no errors", "not correct") are resolved first, so they vote for the
    opposite side of their surface keyword.
    """

    # Patterns indicating success
    SUCCESS_PATTERNS = [
        re.compile(
            r"\b(?:correct|right|valid|verified|confirmed|passed|yes|true)\b", re.IGNORECASE
        ),
        re.compile(r"\b(?:looks good|seems correct|is accurate)\b", re.IGNORECASE),
        re.compile(r"✓|✅|👍", re.IGNORECASE),
    ]

    # Patterns indicating failure ("errors?"/"mistakes?" also catch plurals).
    FAILURE_PATTERNS = [
        re.compile(
            r"\b(?:incorrect|wrong|invalid|failed|no|false|errors?|mistakes?)\b",
            re.IGNORECASE,
        ),
        re.compile(r"\b(?:doesn't match|doesn't look right|is not accurate)\b", re.IGNORECASE),
        re.compile(r"\b(?:try again|redo|recalculate)\b", re.IGNORECASE),
        re.compile(r"✗|❌|👎", re.IGNORECASE),
    ]

    # Negated failure words count as success. Without this, "no errors found"
    # would cast failure votes for both "no" and "errors".
    NEGATED_FAILURE_PATTERNS = [
        re.compile(r"\bno\s+(?:errors?|issues?|mistakes?|problems?)\b", re.IGNORECASE),
        re.compile(r"\bnothing\s+(?:wrong|incorrect)\b", re.IGNORECASE),
    ]

    # Negated success words count as failure. Without this, "not correct"
    # would cast a success vote for "correct".
    NEGATED_SUCCESS_PATTERNS = [
        re.compile(
            r"\b(?:not|isn['’]t|wasn['’]t)\s+"
            r"(?:correct|right|valid|accurate|verified|confirmed)\b",
            re.IGNORECASE,
        ),
    ]

    # Phrases introducing a corrective action. Group 1 captures the action up
    # to the next period or end of line. The word boundaries stop e.g.
    # "incorrect" from triggering the "correct" pattern.
    ACTION_PATTERNS = [
        re.compile(r"\b(?:should|need to|must)\s+(.+?)(?:\.|$)", re.IGNORECASE | re.MULTILINE),
        re.compile(r"\b(?:try|redo|recalculate)\s+(.+?)(?:\.|$)", re.IGNORECASE | re.MULTILINE),
        re.compile(r"\b(?:fix|correct)\s+(.+?)(?:\.|$)", re.IGNORECASE | re.MULTILINE),
    ]

    def evaluate(self, critic_output: str, context: Dict[str, Any]) -> CriticResult:
        """
        Evaluate the output from a critic node.

        Args:
            critic_output: The text output from executing the critic node
            context: Current execution context (currently unused; kept for
                interface compatibility)

        Returns:
            CriticResult indicating pass/fail, the vote-based confidence and,
            on failure, a suggested corrective action when one is found.
        """
        success_count, failure_count = self._count_indicators(critic_output)

        # At least one failure indicator that is not outvoted by success
        # indicators (ties included) means the critic failed.
        if failure_count > 0 and failure_count >= success_count:
            passed = False
            action = self._extract_action(critic_output)
        else:
            passed = True
            action = None

        # Confidence is the share of indicators on the majority side; with no
        # indicators at all, the default pass verdict gets 0.5.
        total_indicators = success_count + failure_count
        if total_indicators > 0:
            confidence = max(success_count, failure_count) / total_indicators
        else:
            confidence = 0.5

        return CriticResult(
            passed=passed,
            feedback=critic_output,
            confidence=confidence,
            suggested_action=action,
        )

    def _count_indicators(self, text: str) -> Tuple[int, int]:
        """Return ``(success_count, failure_count)`` keyword votes for *text*."""
        success_count = 0
        failure_count = 0

        # Resolve negated phrases first. Each one is blanked out of the text so
        # its inner keyword is not counted a second time below.
        for pattern in self.NEGATED_FAILURE_PATTERNS:
            text, hits = pattern.subn(" ", text)
            success_count += hits
        for pattern in self.NEGATED_SUCCESS_PATTERNS:
            text, hits = pattern.subn(" ", text)
            failure_count += hits

        success_count += sum(len(pattern.findall(text)) for pattern in self.SUCCESS_PATTERNS)
        failure_count += sum(len(pattern.findall(text)) for pattern in self.FAILURE_PATTERNS)
        return success_count, failure_count

    def _extract_action(self, output: str) -> Optional[str]:
        """Extract a suggested corrective action from critic output.

        Returns:
            The text following the first matching action phrase (checked in
            ``ACTION_PATTERNS`` order), stripped, or None if none matches.
        """
        for pattern in self.ACTION_PATTERNS:
            match = pattern.search(output)
            if match:
                return match.group(1).strip()
        return None
class CriticExecutor:
    """
    Executes GRDs with critic feedback loops.

    This executor handles the complete cycle of:
    1. Executing normal nodes
    2. Executing critic nodes
    3. Processing critic feedback
    4. Retrying on failure (up to ``max_retries`` retries per critic node)

    When a critic fails and cannot retry (its retry budget is exhausted or it
    has no fallback node), execution continues along the critic's success
    branch instead of stopping, so downstream nodes still run.
    """

    DEFAULT_MAX_RETRIES = 2

    # Upper bound on node executions in a single run; guarantees termination
    # even when the GRD contains cycles.
    MAX_TOTAL_STEPS = 100

    # Edge labels (case-insensitive, surrounding whitespace ignored) that
    # explicitly mark the branch to follow once a critic passes.
    SUCCESS_EDGE_LABELS = ("yes", "success", "pass", "continue")

    def __init__(
        self,
        grd: GRDStructure,
        max_retries: int = DEFAULT_MAX_RETRIES,
    ):
        """
        Initialize the CriticExecutor.

        Args:
            grd: The GRD structure to execute
            max_retries: Maximum number of retries per critic node
        """
        self.grd = grd
        self.max_retries = max_retries
        self.detector = CriticDetector()
        self.evaluator = CriticEvaluator()
        self.critics = self.detector.detect_critics(grd)

    def is_critic_node(self, node_id: str) -> bool:
        """Check if a node ID is a critic node."""
        return self.get_critic(node_id) is not None

    def get_critic(self, node_id: str) -> Optional[CriticNode]:
        """Get critic node by ID, or None when the node is not a critic."""
        return next((critic for critic in self.critics if critic.node_id == node_id), None)

    def process_critic_output(
        self,
        critic: CriticNode,
        output: str,
        context: Dict[str, Any],
        retry_count: int,
    ) -> Tuple[bool, Optional[str], Dict[str, Any]]:
        """
        Process the output from a critic node.

        On a pass, or on a failure that cannot be retried, the next node is
        the critic's success branch. On a retryable failure, the next node is
        the critic's fallback node and ``context`` receives
        ``critic_feedback``, ``suggested_correction`` (when available) and the
        incremented ``retry_count``.

        Args:
            critic: The critic node that was executed
            output: Output from critic execution
            context: Current execution context (mutated in place)
            retry_count: Number of retries already attempted

        Returns:
            Tuple of (should_continue, next_node_id, updated_context)
        """
        result = self.evaluator.evaluate(output, context)
        should_continue, next_node, context, _ = self._route_critic_result(
            critic, result, context, retry_count
        )
        return (should_continue, next_node, context)

    def _route_critic_result(
        self,
        critic: CriticNode,
        result: CriticResult,
        context: Dict[str, Any],
        retry_count: int,
    ) -> Tuple[bool, Optional[str], Dict[str, Any], bool]:
        """Route execution after a critic verdict.

        Returns:
            ``(should_continue, next_node_id, context, retried)`` where
            ``retried`` is True when execution is sent back to the fallback.
        """
        if result.passed:
            return (True, self._get_success_node(critic.node_id), context, False)

        if retry_count >= self.max_retries:
            # Retry budget exhausted: record it and carry on with the
            # success branch rather than halting the whole run.
            context["critic_exceeded_retries"] = True
            return (True, self._get_success_node(critic.node_id), context, False)

        if critic.fallback_node:
            # Expose the critique to the nodes that are about to be re-run.
            context["critic_feedback"] = result.feedback
            if result.suggested_action:
                context["suggested_correction"] = result.suggested_action
            context["retry_count"] = retry_count + 1
            return (True, critic.fallback_node, context, True)

        # Nothing to retry from - continue along the success branch.
        return (True, self._get_success_node(critic.node_id), context, False)

    def _is_failure_edge(self, edge: GRDEdge) -> bool:
        """Return True when the edge label marks a failure/retry branch."""
        label = edge.label or ""
        return any(pattern.search(label) for pattern in self.detector.FAILURE_EDGE_PATTERNS)

    def _get_success_node(self, critic_node_id: str) -> Optional[str]:
        """Pick the node to follow after a critic's "pass" (or give-up) path.

        Prefers an unlabeled edge or one labeled with ``SUCCESS_EDGE_LABELS``;
        otherwise takes the first edge not labeled as a failure/retry branch.

        Returns:
            The target node ID, or None when no such edge exists.
        """
        outgoing = self.grd.get_outgoing_edges(critic_node_id)
        for edge in outgoing:
            label = (edge.label or "").strip().lower()
            if not label or label in self.SUCCESS_EDGE_LABELS:
                return edge.to_node
        for edge in outgoing:
            # Never route a passing critic back into its retry loop.
            if not self._is_failure_edge(edge):
                return edge.to_node
        return None

    def execute_with_feedback(
        self,
        problem: str,
        executor: Callable[[GRDNode, Dict[str, Any]], str],
        initial_context: Optional[Dict[str, Any]] = None,
    ) -> FeedbackLoopResult:
        """
        Execute the GRD with critic feedback loops.

        Starting from the first start node, each node is run with ``executor``
        and execution follows the graph until no next node exists, a node ID
        cannot be resolved, or ``MAX_TOTAL_STEPS`` nodes have been executed.

        Args:
            problem: The problem being solved
            executor: Function to execute a single node; exceptions it raises
                are recorded as "Error: ..." results and execution moves on
            initial_context: Optional initial context (copied, not mutated)

        Returns:
            FeedbackLoopResult with complete execution details. ``final_passed``
            reflects the latest verdict of every critic that ran, so a
            critic that failed but passed on retry counts as passed.
        """
        # Shallow copy so bookkeeping keys never leak into the caller's dict.
        context: Dict[str, Any] = dict(initial_context) if initial_context else {}
        context["problem"] = problem

        critic_results: List[CriticResult] = []
        # Latest verdict per critic node; later verdicts override earlier ones.
        latest_verdicts: Dict[str, bool] = {}
        # Retries consumed per critic node (max_retries is a per-critic budget).
        retry_counts: Dict[str, int] = {}
        attempts = 0

        completed: set = set()
        results: Dict[str, str] = {}
        # Output of the most recently executed node; dict order in ``results``
        # reflects first execution, not the latest one, after retries.
        last_output = ""

        current_node = self.grd.start_nodes[0] if self.grd.start_nodes else None

        while current_node and attempts < self.MAX_TOTAL_STEPS:
            attempts += 1
            node = self.grd.get_node_by_id(current_node)
            if not node:
                break

            exec_context = dict(context)
            exec_context["previous_results"] = results
            exec_context["completed_nodes"] = list(completed)

            try:
                output = executor(node, exec_context)
            except Exception as e:
                # Best effort: a failing node is recorded and execution moves on.
                output = f"Error: {str(e)}"
                results[current_node] = output
                last_output = output
                current_node = self._get_next_node(current_node)
                continue

            results[current_node] = output
            last_output = output
            completed.add(current_node)

            critic = self.get_critic(current_node)
            if critic is None:
                current_node = self._get_next_node(current_node)
                continue

            # Evaluate once and reuse the verdict for both reporting and routing.
            crit_result = self.evaluator.evaluate(output, context)
            critic_results.append(crit_result)
            latest_verdicts[critic.node_id] = crit_result.passed

            retry_count = retry_counts.get(critic.node_id, 0)
            should_continue, next_node, context, retried = self._route_critic_result(
                critic, crit_result, context, retry_count
            )
            if retried:
                retry_counts[critic.node_id] = retry_count + 1
            if not should_continue:
                break
            current_node = next_node

        # Prefer the output of the first end node that was reached.
        final_output = ""
        for end_node in self.grd.end_nodes:
            if end_node in results:
                final_output = results[end_node]
                break
        if not final_output:
            final_output = last_output

        return FeedbackLoopResult(
            final_passed=all(latest_verdicts.values()),
            attempts=attempts,
            critic_results=critic_results,
            final_output=final_output,
        )

    def _get_next_node(self, current: str) -> Optional[str]:
        """Get the target of the node's first outgoing edge, or None."""
        outgoing = self.grd.get_outgoing_edges(current)
        return outgoing[0].to_node if outgoing else None