Source code for braid.module

"""Main BRAID reasoning module for DSPy."""

from typing import Dict, List, Optional, Any
import dspy
from dataclasses import dataclass

from braid.signatures import BraidPlanSignature, BraidStepSignature
from braid.generator import GRDGenerator
from braid.parser import MermaidParser, GRDStructure


@dataclass
class BraidResult:
    """Outcome of one ``BraidReasoning`` run.

    The same container is used for successful and failed runs; ``valid``
    tells them apart and ``error`` carries the failure message, if any.
    """

    # Problem statement that was handed to the module.
    problem: str
    # Guided Reasoning Diagram source text (Mermaid) that was used.
    grd: str
    # Structured form of ``grd``; ``None`` when no parsed structure is available.
    parsed_grd: Optional[GRDStructure]
    # One dict per reasoning step that completed.
    reasoning_steps: List[Dict[str, Any]]
    # Final answer text; empty when none was produced.
    answer: str
    # Chronological log of planning and execution events.
    execution_trace: List[Dict[str, Any]]
    # ``True`` when the run completed, ``False`` when it was aborted.
    valid: bool
    # Human-readable failure description; ``None`` when not set.
    error: Optional[str] = None
class BraidReasoning(dspy.Module):
    """
    BRAID reasoning module for DSPy.

    This module implements the BRAID (Bounded Reasoning for Autonomous
    Inference and Decisions) architecture:

    1. Planning Phase: Generate a Guided Reasoning Diagram (GRD) in Mermaid format
    2. Execution Phase: Execute the GRD step by step to solve the problem

    Example:
        >>> import dspy
        >>> from braid import BraidReasoning
        >>>
        >>> lm = dspy.OpenAI(model="gpt-4")
        >>> dspy.configure(lm=lm)
        >>>
        >>> braid = BraidReasoning()
        >>> result = braid(problem="If a train travels 120 km in 2 hours, what is its speed?")
        >>> print(result.answer)
        >>> print(result.grd)
    """

    # End-node results at or above this many characters are not treated as a
    # final answer; the extractor falls through to the last completed step.
    _MAX_ANSWER_LENGTH = 500

    def __init__(
        self,
        use_generator: bool = True,
        max_execution_steps: int = 20,
        validate_grd: bool = True,
    ):
        """
        Initialize the BRAID reasoning module.

        Args:
            use_generator: Whether to use GRDGenerator for planning (True)
                or a direct LLM call through ``BraidPlanSignature`` (False)
            max_execution_steps: Maximum number of entries of the execution
                order that are processed; later entries are ignored
            validate_grd: Whether to validate GRD syntax before parsing it
        """
        super().__init__()
        self.use_generator = use_generator
        self.max_execution_steps = max_execution_steps
        self.validate_grd = validate_grd

        # Initialize sub-modules
        self.plan = dspy.Predict(BraidPlanSignature)
        self.execute_step = dspy.Predict(BraidStepSignature)
        self.parser = MermaidParser()
        self.generator = GRDGenerator() if use_generator else None

    @staticmethod
    def _failed_result(
        problem: str,
        grd: str,
        parsed_grd: Optional[GRDStructure],
        execution_trace: List[Dict[str, Any]],
        error: Optional[str],
    ) -> BraidResult:
        """Build the ``BraidResult`` returned when a run is aborted."""
        return BraidResult(
            problem=problem,
            grd=grd,
            parsed_grd=parsed_grd,
            reasoning_steps=[],
            answer="",
            execution_trace=execution_trace,
            valid=False,
            error=error,
        )

    def forward(
        self,
        problem: str,
        grd: Optional[str] = None,
        problem_type: Optional[str] = None,
    ) -> BraidResult:
        """
        Execute BRAID reasoning on a problem.

        Args:
            problem: The problem to solve
            grd: Optional pre-generated GRD (if None, will be generated)
            problem_type: Optional problem type hint, forwarded to the
                generator when one is used

        Returns:
            BraidResult object containing GRD, reasoning steps, and answer.
            Planning, validation and parsing failures are reported through
            ``valid=False`` and ``error`` rather than by raising.
        """
        execution_trace: List[Dict[str, Any]] = []

        # Phase 1: Planning - use the provided GRD or produce one
        if grd is not None:
            execution_trace.append({"phase": "planning", "method": "provided", "grd": grd})
        elif self.use_generator and self.generator:
            generation_result = self.generator.generate(
                problem=problem, problem_type=problem_type
            )
            grd = generation_result.get("grd")
            if not grd:
                # ``or`` also covers an "error" key that is present but empty,
                # so an invalid result always carries a message.
                return self._failed_result(
                    problem,
                    "",
                    None,
                    execution_trace,
                    generation_result.get("error") or "Failed to generate GRD",
                )
            # Record the planning step like the other two planning paths do.
            execution_trace.append({"phase": "planning", "method": "generator", "grd": grd})
        else:
            # Use DSPy signature directly
            plan_result = self.plan(problem=problem)
            grd = plan_result.grd
            execution_trace.append(
                {"phase": "planning", "method": "dspy_signature", "grd": grd}
            )

        # Validate GRD if requested
        if self.validate_grd:
            is_valid, error_msg = self.parser.validate(grd)
            if not is_valid:
                return self._failed_result(
                    problem, grd, None, execution_trace, f"Invalid GRD: {error_msg}"
                )

        # Parse GRD structure; execution below needs it regardless of validation
        try:
            parsed_grd = self.parser.parse(grd)
        except Exception as e:
            return self._failed_result(
                problem, grd, None, execution_trace, f"Failed to parse GRD: {e}"
            )

        # Phase 2: Execution - Execute GRD step by step
        execution_order = parsed_grd.get_execution_order()
        if not execution_order:
            return self._failed_result(
                problem, grd, parsed_grd, execution_trace, "GRD has no execution order"
            )

        reasoning_steps: List[Dict[str, Any]] = []
        step_results: Dict[str, Any] = {}
        # One summary line per completed step, accumulated so the context text
        # does not have to be rebuilt from scratch on every iteration.
        previous_lines: List[str] = []

        for step_idx, node_id in enumerate(execution_order[: self.max_execution_steps]):
            node = parsed_grd.get_node_by_id(node_id)
            if not node:
                continue

            step_number = step_idx + 1
            previous_results = "\n".join(previous_lines)
            step_context = f"Problem: {problem}\n\nPrevious Steps:\n{previous_results}"

            try:
                step_result = self.execute_step(
                    step_description=node.label, context=step_context
                )
                step_output = step_result.step_output

                reasoning_steps.append(
                    {
                        "step_id": node_id,
                        "step_number": step_number,
                        "label": node.label,
                        "result": step_output,
                        "node_type": node.node_type.value,
                    }
                )
                step_results[node_id] = step_output
                # Lines are numbered by position among completed steps, so the
                # numbering stays contiguous when a node is skipped or fails.
                previous_lines.append(
                    f"Step {len(reasoning_steps)} ({node_id}): {step_output}"
                )
                execution_trace.append(
                    {
                        "phase": "execution",
                        "step": step_number,
                        "node_id": node_id,
                        "result": step_output,
                    }
                )
            except Exception as e:
                execution_trace.append(
                    {
                        "phase": "execution",
                        "step": step_number,
                        "node_id": node_id,
                        "error": str(e),
                    }
                )
                # Continue execution even if a step fails; the failure is kept
                # as a placeholder result for this node.
                step_results[node_id] = f"Error: {e}"

        # Extract final answer from end nodes
        answer = self._extract_answer(parsed_grd, step_results, reasoning_steps)

        return BraidResult(
            problem=problem,
            grd=grd,
            parsed_grd=parsed_grd,
            reasoning_steps=reasoning_steps,
            answer=answer,
            execution_trace=execution_trace,
            valid=True,
        )

    def _extract_answer(
        self,
        grd: GRDStructure,
        step_results: Dict[str, str],
        reasoning_steps: List[Dict[str, Any]],
    ) -> str:
        """
        Extract the final answer from execution results.

        Sources are tried in order: the first end node with a non-empty
        result shorter than ``_MAX_ANSWER_LENGTH`` characters, then the result
        of the last completed step, then a ``node_id: result`` listing of
        everything in ``step_results``.

        Args:
            grd: Parsed GRD structure
            step_results: Dictionary mapping node IDs to their results
                (including the ``"Error: ..."`` placeholders of failed steps)
            reasoning_steps: List of completed reasoning steps

        Returns:
            Final answer string, or ``""`` when nothing was produced
        """
        # Try to get answer from end nodes
        for end_node_id in grd.end_nodes or ():
            result = step_results.get(end_node_id)
            # Only accept results short enough to look like a final answer
            if result and len(result) < self._MAX_ANSWER_LENGTH:
                return result

        # If no end node result, use the last step result
        if reasoning_steps:
            return reasoning_steps[-1].get("result", "")

        # Fallback: construct answer from all steps
        if step_results:
            return "\n".join(
                f"{node_id}: {result}" for node_id, result in step_results.items()
            )

        return ""

    def __call__(self, problem: str, **kwargs) -> BraidResult:
        """Make the module callable; delegates directly to ``forward``."""
        # NOTE(review): this replaces dspy.Module.__call__, so anything the
        # base implementation does around forward() is skipped - confirm that
        # is intended for the DSPy version in use.
        return self.forward(problem, **kwargs)