# Source code for braid.engine

"""Stateful execution engine for dynamic GRD traversal.

This module implements the execution logic for GRDs, supporting:
- Dynamic state management
- Conditional branching based on edge labels
- Cycle support for critic feedback loops
- Runtime condition evaluation
"""

import re
from typing import Dict, List, Optional, Any, Callable, Set, Tuple
from dataclasses import dataclass, field
from enum import Enum

from braid.parser import GRDStructure, GRDNode, GRDEdge


class ExecutionStatus(Enum):
    """Status of execution.

    Each member's value is its lowercase name; ``ExecutionResult.get_execution_trace``
    reports ``status.value`` for every recorded node result.
    """

    # PENDING and RUNNING are not assigned anywhere in the code visible in this
    # module; presumably reserved for callers that track in-flight work.
    PENDING = "pending"
    RUNNING = "running"
    # Recorded by the engine when the executor callback returns normally.
    COMPLETED = "completed"
    # Recorded by the engine when executing a node raises an exception.
    FAILED = "failed"
    # Recorded by the engine when a node hits its per-node iteration limit.
    SKIPPED = "skipped"


@dataclass
class NodeExecutionResult:
    """Result of executing a single node.

    One instance is appended per execution attempt, so a node that is visited
    several times (e.g. inside a feedback loop) yields several results.
    """

    # ID of the node this result belongs to.
    node_id: str
    # Outcome of the attempt (COMPLETED, FAILED, SKIPPED, ...).
    status: ExecutionStatus
    # Text returned by the executor; the engine stores "" for failed/skipped attempts.
    output: str
    # Wall-clock duration in milliseconds. The engine in this module never
    # sets it, so it stays at the 0.0 default unless a caller fills it in.
    execution_time_ms: float = 0.0
    # Error description for failed/skipped attempts, otherwise None.
    error: Optional[str] = None
    # Free-form extra data; default_factory gives every instance its own dict.
    metadata: Dict[str, Any] = field(default_factory=dict)


@dataclass
class ExecutionState:
    """Maintains the state of GRD execution.

    Attributes:
        current_node: ID of the node to execute next, or None when there is
            nothing left to run.
        completed_nodes: IDs of nodes recorded through ``mark_completed``.
        skipped_nodes: IDs of skipped nodes (not written by any method of
            this class).
        step_results: Latest result per node ID; overwritten when a node is
            executed again.
        context: Shared, free-form execution context.
        iteration_count: Number of ``mark_completed`` calls per node ID.
        execution_path: Node IDs in completion order; an ID may appear more
            than once when a cycle is traversed.
    """

    current_node: Optional[str] = None
    completed_nodes: Set[str] = field(default_factory=set)
    skipped_nodes: Set[str] = field(default_factory=set)
    step_results: Dict[str, str] = field(default_factory=dict)
    context: Dict[str, Any] = field(default_factory=dict)
    iteration_count: Dict[str, int] = field(default_factory=dict)
    execution_path: List[str] = field(default_factory=list)

    def mark_completed(self, node_id: str, result: str) -> None:
        """Mark a node as completed with its result.

        Args:
            node_id: ID of the node that just finished.
            result: Output produced by the node; replaces any earlier result
                stored for the same node.
        """
        self.completed_nodes.add(node_id)
        self.step_results[node_id] = result
        self.execution_path.append(node_id)
        self.iteration_count[node_id] = self.iteration_count.get(node_id, 0) + 1

    def get_iteration_count(self, node_id: str) -> int:
        """Get how many times a node has been executed (0 if never)."""
        return self.iteration_count.get(node_id, 0)

    def reset_for_retry(self, from_node: str) -> None:
        """Reset state for retrying from a specific node.

        The execution path is truncated just before the first occurrence of
        ``from_node``. Nodes that only ran at or after that point are removed
        from ``completed_nodes``; nodes that also ran earlier stay completed
        because they are still part of the retained path. ``step_results``
        and ``iteration_count`` are kept for potential reference.

        Args:
            from_node: ID of the node to retry from. If it never ran, the
                state is left untouched.
        """
        try:
            idx = self.execution_path.index(from_node)
        except ValueError:
            return

        retained_path = self.execution_path[:idx]
        # A node may appear on both sides of the cut when a cycle was taken;
        # it must stay "completed" as long as the retained prefix contains it.
        still_completed = set(retained_path)
        for node in self.execution_path[idx:]:
            if node not in still_completed:
                self.completed_nodes.discard(node)
        self.execution_path = retained_path


@dataclass
class ExecutionResult:
    """Complete result of GRD execution.

    Bundles the overall outcome, the final answer text, the final execution
    state and the per-attempt node results of one run.
    """

    success: bool
    final_answer: str
    state: ExecutionState
    node_results: List[NodeExecutionResult]
    total_iterations: int
    error: Optional[str] = None

    def get_execution_trace(self) -> List[Dict[str, Any]]:
        """Get a detailed execution trace.

        Returns:
            One dict per entry of ``node_results``, in order, with the keys
            ``step`` (1-based position), ``node_id``, ``status`` (the enum's
            string value), ``output`` and ``error``.
        """
        trace: List[Dict[str, Any]] = []
        for step_number, node_result in enumerate(self.node_results, start=1):
            entry = {
                "step": step_number,
                "node_id": node_result.node_id,
                "status": node_result.status.value,
                "output": node_result.output,
                "error": node_result.error,
            }
            trace.append(entry)
        return trace


class ConditionEvaluator:
    """
    Evaluates conditions on edge labels for branching decisions.

    Supports conditions like:
    - "success" / "failure" / "error"
    - "yes" / "no"
    - "value > 100" / "value < 50" (negative thresholds are accepted)
    - "contains X" / "not contains X"

    Anything else is treated as a plain case-insensitive substring that must
    occur in the previous node's result.
    """

    # Pattern for comparison conditions: optional "value"/"result" prefix,
    # an operator, and a (possibly negative) numeric threshold.
    COMPARISON_PATTERN = re.compile(
        r"(?:value|result)?\s*([<>=!]+)\s*(-?\d+\.?\d*)", re.IGNORECASE
    )

    # Pattern for "contains X" / "not contains X". Group 1 is set only when
    # the word "not" directly precedes "contains"; group 2 is the search term.
    CONTAINS_PATTERN = re.compile(r"(?:\b(not)\s+)?\bcontains\b\s*(.*)", re.DOTALL)

    # Keywords for boolean conditions
    SUCCESS_KEYWORDS = {"success", "yes", "true", "valid", "correct", "pass", "ok"}
    FAILURE_KEYWORDS = {"failure", "no", "false", "invalid", "incorrect", "fail", "error"}

    # Substrings whose presence in a result marks it as unsuccessful.
    FAILURE_INDICATORS = ("error", "failed", "invalid", "incorrect", "exception")

    # Supported comparison operators; anything else evaluates to False.
    _COMPARATORS = {
        ">": lambda value, threshold: value > threshold,
        ">=": lambda value, threshold: value >= threshold,
        "<": lambda value, threshold: value < threshold,
        "<=": lambda value, threshold: value <= threshold,
        "==": lambda value, threshold: value == threshold,
        "=": lambda value, threshold: value == threshold,
        "!=": lambda value, threshold: value != threshold,
        "<>": lambda value, threshold: value != threshold,
    }

    def evaluate(self, condition: str, context: Dict[str, Any], last_result: str) -> bool:
        """
        Evaluate a condition against the current context.

        Checks are tried in this order: success keyword, failure keyword,
        numeric comparison, "contains" phrase, plain substring match.

        Args:
            condition: Condition string from edge label
            context: Current execution context (accepted for interface
                compatibility; not consulted by the current rules)
            last_result: Result from the previous node

        Returns:
            True if condition is satisfied, False otherwise
        """
        condition_lower = condition.lower().strip()

        # Check for success keywords
        if condition_lower in self.SUCCESS_KEYWORDS:
            return self._is_success_result(last_result)

        # Check for failure keywords
        if condition_lower in self.FAILURE_KEYWORDS:
            return not self._is_success_result(last_result)

        # Check for comparison conditions
        comparison = self.COMPARISON_PATTERN.search(condition)
        if comparison:
            operator = comparison.group(1)
            threshold = float(comparison.group(2))
            return self._evaluate_comparison(last_result, operator, threshold)

        result_lower = last_result.lower()

        # Check for contains conditions. The search term is only the text
        # after "contains", so a leading "not" negates the test instead of
        # leaking into the term, and words such as "notes" do not negate it.
        contains = self.CONTAINS_PATTERN.search(condition_lower)
        if contains:
            negate = contains.group(1) is not None
            search_term = contains.group(2).strip()
            found = search_term in result_lower
            return not found if negate else found

        # Default: assume condition matches if it's mentioned in result
        return condition_lower in result_lower

    def _is_success_result(self, result: str) -> bool:
        """Check if a result indicates success.

        A result counts as successful unless it contains (case-insensitively)
        one of ``FAILURE_INDICATORS``.
        """
        result_lower = result.lower()
        return not any(indicator in result_lower for indicator in self.FAILURE_INDICATORS)

    def _evaluate_comparison(self, result: str, operator: str, threshold: float) -> bool:
        """Evaluate a numeric comparison.

        Args:
            result: Text to extract the compared value from; the last number
                found in it is used.
            operator: One of ``>``, ``>=``, ``<``, ``<=``, ``==``, ``=``,
                ``!=``, ``<>``.
            threshold: Right-hand side of the comparison.

        Returns:
            Outcome of the comparison; False when ``result`` holds no number
            or the operator is not recognised.
        """
        # Extract numeric value from result
        numbers = re.findall(r"-?\d+\.?\d*", result)
        if not numbers:
            return False

        try:
            value = float(numbers[-1])  # Use last number found
        except ValueError:
            return False

        compare = self._COMPARATORS.get(operator)
        if compare is None:
            return False
        return compare(value, threshold)


class StatefulExecutionEngine:
    """
    Stateful execution engine for GRDs.

    Unlike simple topological sorting, this engine:
    - Maintains execution state across steps
    - Supports conditional branching
    - Handles cycles for critic/verification loops
    - Provides runtime condition evaluation
    """

    DEFAULT_MAX_ITERATIONS = 3  # Per node, to prevent infinite loops
    DEFAULT_MAX_TOTAL_STEPS = 50

    def __init__(
        self,
        grd: GRDStructure,
        max_iterations_per_node: int = DEFAULT_MAX_ITERATIONS,
        max_total_steps: int = DEFAULT_MAX_TOTAL_STEPS,
    ):
        """
        Initialize the execution engine.

        Args:
            grd: The GRD structure to execute
            max_iterations_per_node: Max times a single node can be executed
            max_total_steps: Maximum total execution steps
        """
        self.grd = grd
        self.max_iterations_per_node = max_iterations_per_node
        self.max_total_steps = max_total_steps
        self.condition_evaluator = ConditionEvaluator()
        self.state = ExecutionState()

    def reset(self) -> None:
        """Reset the execution state."""
        self.state = ExecutionState()

    def execute(
        self,
        problem: str,
        executor: Callable[[GRDNode, Dict[str, Any]], str],
        initial_context: Optional[Dict[str, Any]] = None,
    ) -> ExecutionResult:
        """
        Execute the GRD step by step.

        Execution starts at the first start node and follows outgoing edges
        until no next node exists, a node ID cannot be resolved, a node hits
        its iteration limit, or the total step limit is reached.

        Args:
            problem: The problem being solved
            executor: Function that executes a single node
            initial_context: Optional initial context. It is copied, so the
                caller's dict is not modified.

        Returns:
            ExecutionResult with complete execution details. ``success`` is
            True only when the last recorded node result is COMPLETED.
        """
        self.reset()

        # Initialize context from a copy so that adding "problem" does not
        # leak into the dict owned by the caller.
        self.state.context = dict(initial_context) if initial_context else {}
        self.state.context["problem"] = problem

        node_results: List[NodeExecutionResult] = []
        total_iterations = 0

        # Find start node(s)
        current_nodes = list(self.grd.start_nodes)
        if not current_nodes:
            # Fallback: find nodes with no incoming edges
            current_nodes = self._find_start_nodes()

        if not current_nodes:
            return ExecutionResult(
                success=False,
                final_answer="",
                state=self.state,
                node_results=[],
                total_iterations=0,
                error="No start nodes found in GRD",
            )

        # Execute starting from first start node
        self.state.current_node = current_nodes[0]

        while self.state.current_node and total_iterations < self.max_total_steps:
            current_id = self.state.current_node
            node = self.grd.get_node_by_id(current_id)

            if not node:
                # An edge points at a node the GRD does not know; stop here.
                break

            # Check iteration limit
            if self.state.get_iteration_count(current_id) >= self.max_iterations_per_node:
                node_results.append(
                    NodeExecutionResult(
                        node_id=current_id,
                        status=ExecutionStatus.SKIPPED,
                        output="",
                        error=f"Max iterations ({self.max_iterations_per_node}) reached",
                    )
                )
                break

            # Execute the node
            total_iterations += 1

            # Best effort by design: any failure while running the node or
            # choosing its successor is recorded and the error path is taken.
            try:
                # Build execution context
                exec_context = self._build_execution_context(node)
                output = executor(node, exec_context)

                self.state.mark_completed(current_id, output)
                node_results.append(
                    NodeExecutionResult(
                        node_id=current_id,
                        status=ExecutionStatus.COMPLETED,
                        output=output,
                    )
                )

                # Determine next node
                self.state.current_node = self._get_next_node(current_id, output)
            except Exception as e:
                node_results.append(
                    NodeExecutionResult(
                        node_id=current_id,
                        status=ExecutionStatus.FAILED,
                        output="",
                        error=str(e),
                    )
                )
                # Try to continue on error
                self.state.current_node = self._get_next_node_on_error(current_id)

        # Determine final answer
        final_answer = self._extract_final_answer()

        return ExecutionResult(
            success=len(node_results) > 0
            and node_results[-1].status == ExecutionStatus.COMPLETED,
            final_answer=final_answer,
            state=self.state,
            node_results=node_results,
            total_iterations=total_iterations,
        )

    def _find_start_nodes(self) -> List[str]:
        """Find nodes with no incoming edges."""
        has_incoming = {edge.to_node for edge in self.grd.edges}
        return [node.id for node in self.grd.nodes if node.id not in has_incoming]

    def _build_execution_context(self, node: GRDNode) -> Dict[str, Any]:
        """Build context for node execution.

        Returns a shallow copy of the shared context extended with the keys
        ``current_node``, ``current_label``, ``previous_results``,
        ``execution_path`` and ``previous_steps_formatted``.
        """
        context = dict(self.state.context)
        context["current_node"] = node.id
        context["current_label"] = node.label
        context["previous_results"] = dict(self.state.step_results)
        context["execution_path"] = list(self.state.execution_path)

        # Build formatted previous steps ("<label>: <result>" per path entry)
        previous_steps = []
        for prev_id in self.state.execution_path:
            prev_node = self.grd.get_node_by_id(prev_id)
            if prev_node and prev_id in self.state.step_results:
                previous_steps.append(f"{prev_node.label}: {self.state.step_results[prev_id]}")

        context["previous_steps_formatted"] = "\n".join(previous_steps)

        return context

    def _get_next_node(self, current_id: str, output: str) -> Optional[str]:
        """
        Determine the next node based on output and edge conditions.

        Args:
            current_id: Current node ID
            output: Output from current node execution

        Returns:
            ID of next node, or None if execution should end
        """
        outgoing_edges = self.grd.get_outgoing_edges(current_id)

        if not outgoing_edges:
            return None  # End of execution

        # If only one edge, follow it
        if len(outgoing_edges) == 1:
            return outgoing_edges[0].to_node

        # Multiple edges - evaluate conditions. The label takes precedence;
        # the explicit condition field is used only when there is no label.
        for edge in outgoing_edges:
            guard = edge.label or edge.condition
            if guard and self.condition_evaluator.evaluate(guard, self.state.context, output):
                return edge.to_node

        # Default: follow first edge without condition
        for edge in outgoing_edges:
            if not edge.label and not edge.condition:
                return edge.to_node

        # Fallback: first edge
        return outgoing_edges[0].to_node

    def _get_next_node_on_error(self, current_id: str) -> Optional[str]:
        """Get next node when current node execution failed."""
        outgoing_edges = self.grd.get_outgoing_edges(current_id)

        # Look for error/failure edges
        for edge in outgoing_edges:
            if edge.label and edge.label.lower() in ("error", "failure", "fail"):
                return edge.to_node

        # Continue with default path if no error edge
        if outgoing_edges:
            return outgoing_edges[0].to_node

        return None

    def _extract_final_answer(self) -> str:
        """Extract the final answer from execution results."""
        # Try end nodes first
        for end_node_id in self.grd.end_nodes:
            if end_node_id in self.state.step_results:
                return self.state.step_results[end_node_id]

        # Use last executed node's result
        if self.state.execution_path:
            last_node = self.state.execution_path[-1]
            if last_node in self.state.step_results:
                return self.state.step_results[last_node]

        return ""

    def can_reach(self, from_node: str, to_node: str) -> bool:
        """Check if to_node is reachable from from_node.

        A node is always considered reachable from itself.
        """
        visited: Set[str] = set()
        # Visiting order does not affect reachability, so a stack avoids the
        # O(n) cost of popping from the front of a list.
        pending = [from_node]

        while pending:
            current = pending.pop()
            if current == to_node:
                return True

            if current in visited:
                continue

            visited.add(current)

            for edge in self.grd.get_outgoing_edges(current):
                if edge.to_node not in visited:
                    pending.append(edge.to_node)

        return False

    def has_cycles(self) -> bool:
        """Check if the GRD contains cycles.

        Uses an iterative depth-first search so deep graphs cannot exceed the
        interpreter's recursion limit.
        """
        visited: Set[str] = set()
        on_stack: Set[str] = set()  # nodes on the current DFS branch

        for root in self.grd.nodes:
            if root.id in visited:
                continue

            visited.add(root.id)
            on_stack.add(root.id)
            stack = [(root.id, iter(self.grd.get_outgoing_edges(root.id)))]

            while stack:
                node_id, edges = stack[-1]
                for edge in edges:
                    target = edge.to_node
                    if target in on_stack:
                        # Edge back into the active branch closes a cycle.
                        return True
                    if target not in visited:
                        visited.add(target)
                        on_stack.add(target)
                        stack.append((target, iter(self.grd.get_outgoing_edges(target))))
                        break
                else:
                    # All edges of node_id explored: leave the branch.
                    on_stack.discard(node_id)
                    stack.pop()

        return False

    def detect_cycles(self) -> List[List[str]]:
        """Detect all cycles in the GRD.

        Returns:
            One list per distinct cycle, e.g. ``["A", "B", "A"]`` (the first
            node is repeated at the end). A cycle is reported once, in the
            rotation in which it was first discovered, no matter from how
            many nodes it can be entered.

        Note:
            The search enumerates simple paths, so it can be slow on large,
            densely connected graphs.
        """
        cycles: List[List[str]] = []
        seen: Set[Tuple[str, ...]] = set()

        def record(members: List[str]) -> None:
            # Rotate so the smallest ID comes first: every rotation of the
            # same cycle then maps to the same key.
            pivot = members.index(min(members))
            key = tuple(members[pivot:] + members[:pivot])
            if key not in seen:
                seen.add(key)
                cycles.append(members + [members[0]])

        def walk(node_id: str, path: List[str]) -> None:
            if node_id in path:
                # Found a cycle
                record(path[path.index(node_id):])
                return

            extended = path + [node_id]
            for edge in self.grd.get_outgoing_edges(node_id):
                walk(edge.to_node, extended)

        for node in self.grd.nodes:
            walk(node.id, [])

        return cycles