# Source code for braid.generator

"""GRD Generator module for creating Guided Reasoning Diagrams using LLMs."""

from typing import Dict, List, Optional, Any
import dspy
from braid.parser import MermaidParser
from braid.utils import extract_mermaid_code
from braid.signatures import BraidPlanSignature


class GRDGenerator:
    """Generator for Guided Reasoning Diagrams in Mermaid format.

    A GRD is a Mermaid ``flowchart TD`` whose nodes describe the *steps* for
    solving a problem rather than its answer. Generation goes either through
    ``dspy.Predict(BraidPlanSignature)`` or, as a fallback, through the
    configured DSPy language model driven by the few-shot prompt built in
    ``_build_prompt``. Extracted Mermaid code is validated and parsed with
    ``MermaidParser``.
    """

    # Few-shot examples for GRD generation.
    # BRAID Protocol: examples use PROCEDURAL SCAFFOLDING (describe HOW, not WHAT);
    # each node describes an ACTION, not a computed VALUE.
    # Treated as read-only: __init__ gives every instance its own copy so that
    # add_example() never mutates this shared class attribute.
    DEFAULT_EXAMPLES = [
        {
            "problem": "If a train travels 120 km in 2 hours, what is its speed?",
            "grd": """```mermaid
flowchart TD
    Start[Read and understand problem] --> Extract[Extract given values]
    Extract --> Identify[Identify what to find]
    Identify --> Formula[Recall speed formula]
    Formula --> Apply[Apply: divide distance by time]
    Apply --> Calculate[Perform the division]
    Calculate --> Verify[Verify units are correct]
    Verify --> Answer[State the final speed]
```""",
        },
        {
            "problem": "Solve: 3x + 5 = 14",
            "grd": """```mermaid
flowchart TD
    Start[Analyze the equation] --> Goal[Goal: isolate x]
    Goal --> Subtract[Subtract constant from both sides]
    Subtract --> Simplify1[Simplify the right side]
    Simplify1 --> Divide[Divide both sides by coefficient]
    Divide --> Simplify2[Simplify to get x value]
    Simplify2 --> Check[Verify: substitute back]
    Check --> Answer[State the solution]
```""",
        },
        {
            "problem": "A store sells apples at $2 each. If John buys 5 apples, how much does he pay?",
            "grd": """```mermaid
flowchart TD
    Start[Understand the scenario] --> Values[Identify: price and quantity]
    Values --> Operation[Determine operation: multiplication]
    Operation --> Calculate[Multiply price by quantity]
    Calculate --> Answer[State total cost]
```""",
        },
    ]

    # Generic GRD skeletons returned by get_template(), keyed by lower-case type.
    _TEMPLATES: Dict[str, str] = {
        "math": """```mermaid
flowchart TD
    Start[Read Problem] --> Identify[Identify Given Values]
    Identify --> Formula[Recall Relevant Formula]
    Formula --> Substitute[Substitute Values]
    Substitute --> Calculate[Perform Calculation]
    Calculate --> Verify[Verify Answer]
    Verify --> Answer[Final Answer]
```""",
        "logic": """```mermaid
flowchart TD
    Start[Problem Analysis] --> Premises[Identify Premises]
    Premises --> Rules[Apply Logical Rules]
    Rules --> Deduce[Deduce Conclusion]
    Deduce --> Answer[Final Conclusion]
```""",
        "reasoning": """```mermaid
flowchart TD
    Start[Understand Problem] --> Break[Break into Sub-problems]
    Break --> Solve[Solve Each Sub-problem]
    Solve --> Combine[Combine Solutions]
    Combine --> Answer[Final Answer]
```""",
    }

    def __init__(
        self,
        examples: Optional[List[Dict[str, str]]] = None,
        max_retries: int = 3,
        temperature: float = 0.3,
        use_dspy_predict: bool = True,
    ):
        """
        Initialize the GRD Generator.

        Args:
            examples: Few-shot examples for GRD generation, each a dict with
                "problem" and "grd" keys. ``None`` or an empty list selects
                DEFAULT_EXAMPLES.
            max_retries: Maximum number of generation attempts.
            temperature: Temperature for LLM generation (lower = more
                deterministic). Only applied on the direct-LM fallback path.
            use_dspy_predict: Whether to use DSPy's Predict API (recommended).
        """
        # Copy so that add_example() mutates neither the shared class-level
        # DEFAULT_EXAMPLES nor a list owned by the caller. An empty list still
        # falls back to the defaults, as before.
        self.examples = list(examples or self.DEFAULT_EXAMPLES)
        self.max_retries = max_retries
        self.temperature = temperature
        self.use_dspy_predict = use_dspy_predict
        self.parser = MermaidParser()
        if use_dspy_predict:
            self.predictor = dspy.Predict(BraidPlanSignature)

    def generate(
        self,
        problem: str,
        problem_type: Optional[str] = None,
        custom_instructions: Optional[str] = None,
    ) -> Dict[str, Any]:
        """
        Generate a GRD for a given problem.

        Each attempt requests a response, extracts Mermaid code, validates it
        and parses it. Extraction, validation, parse failures and exceptions
        all trigger another attempt until ``max_retries`` is exhausted.

        Note: on the ``use_dspy_predict`` path only ``problem`` is forwarded to
        the predictor; ``problem_type``, ``custom_instructions`` and
        ``temperature`` affect only the direct-LM fallback prompt.

        Args:
            problem: The problem to solve
            problem_type: Optional type hint (e.g., "math", "logic", "reasoning")
            custom_instructions: Optional custom instructions for generation

        Returns:
            Dictionary containing:
                - grd: Mermaid code string (None if none could be extracted)
                - raw_response: Raw LLM response (None if generation raised)
                - parsed_structure: Parsed GRDStructure object, or None
                - valid: Whether the GRD is valid
                - error: Error message, or None when valid
        """
        prompt = self._build_prompt(problem, problem_type, custom_instructions)

        for attempt in range(self.max_retries):
            is_last_attempt = attempt == self.max_retries - 1
            try:
                raw_response = self._request_grd(problem, prompt)

                # Extract Mermaid code
                mermaid_code = extract_mermaid_code(raw_response)
                if not mermaid_code:
                    if not is_last_attempt:
                        continue
                    return {
                        "grd": None,
                        "raw_response": raw_response,
                        "parsed_structure": None,
                        "valid": False,
                        "error": "Could not extract Mermaid code from response",
                    }

                # Validate syntax, then parse only if validation passed.
                is_valid, error_msg = self.parser.validate(mermaid_code)
                parsed_structure = None
                if is_valid:
                    try:
                        parsed_structure = self.parser.parse(mermaid_code)
                    except Exception as e:
                        error_msg = f"Parsing error: {str(e)}"
                        is_valid = False

                # Retry on validation *or* parse failure while attempts remain.
                if not is_valid and not is_last_attempt:
                    continue

                return {
                    "grd": mermaid_code,
                    "raw_response": raw_response,
                    "parsed_structure": parsed_structure,
                    "valid": is_valid,
                    "error": error_msg if not is_valid else None,
                }
            except Exception as e:
                if is_last_attempt:
                    return {
                        "grd": None,
                        "raw_response": None,
                        "parsed_structure": None,
                        "valid": False,
                        "error": f"Generation failed after {self.max_retries} attempts: {str(e)}",
                    }
                continue

        # Only reached when max_retries < 1 (the loop body never ran).
        return {
            "grd": None,
            "raw_response": None,
            "parsed_structure": None,
            "valid": False,
            "error": "Generation failed after maximum retries",
        }

    def _request_grd(self, problem: str, prompt: str) -> str:
        """Obtain one raw GRD response via dspy.Predict or the configured LM."""
        if self.use_dspy_predict:
            # Use DSPy's Predict API (recommended)
            result = self.predictor(problem=problem)
            return result.grd

        # Direct LM access (fallback)
        lm = dspy.settings.lm if hasattr(dspy, "settings") else None
        if not lm:
            # Try alternative way to get LM
            try:
                lm = dspy.LM()
            except Exception as exc:
                raise ValueError(
                    "DSPy language model not configured. Call dspy.configure(lm=...) first."
                ) from exc
        return self._call_lm(lm, prompt)

    def _call_lm(self, lm: Any, prompt: str) -> str:
        """Call ``lm`` with ``self.temperature`` temporarily applied; return text.

        The temperature override is applied only if ``lm`` exposes a
        ``temperature`` attribute. It is always restored afterwards, including
        when the original value was ``None`` or the call raised.
        NOTE(review): ``dspy.LM`` may keep sampling settings in ``lm.kwargs``
        rather than a ``temperature`` attribute. Confirm the override takes
        effect for the LM classes in use.
        """
        has_temperature = hasattr(lm, "temperature")
        original_temp = getattr(lm, "temperature", None)
        if has_temperature:
            lm.temperature = self.temperature
        try:
            return self._response_text(lm(prompt))
        finally:
            if has_temperature:
                lm.temperature = original_temp

    @staticmethod
    def _response_text(response: Any) -> str:
        """Coerce an LM response into plain text.

        ``dspy.LM`` returns a list of completions, so the first one is used.
        Strings pass through unchanged. Other objects use their ``text``
        attribute, falling back to ``str(response)``.
        """
        if isinstance(response, (list, tuple)) and response:
            response = response[0]
        if isinstance(response, str):
            return response
        return getattr(response, "text", str(response))

    def _build_prompt(
        self,
        problem: str,
        problem_type: Optional[str] = None,
        custom_instructions: Optional[str] = None,
    ) -> str:
        """Build the prompt for GRD generation with BRAID protocol rules.

        The prompt contains the BRAID rules, an output-format sample, the
        optional type hint and instructions, every entry of ``self.examples``,
        and finally the target problem. Parts are joined with newlines.
        """
        prompt_parts = [
            "You are an expert at creating structured reasoning diagrams.",
            "Your task is to create a Guided Reasoning Diagram (GRD) in Mermaid flowchart format",
            "that maps out the solution steps for a given problem.",
            "",
            "=== CRITICAL BRAID PROTOCOL RULES ===",
            "",
            "1. PROCEDURAL SCAFFOLDING: Describe HOW to solve, never WHAT the answer is.",
            "   - WRONG: 'Calculate[Speed = 60 km/h]' (contains the answer!)",
            "   - RIGHT: 'Calculate[Divide distance by time]' (describes the action)",
            "",
            "2. NO ANSWER LEAKAGE: Never include computed numerical values.",
            "   - WRONG: 'Result[The sum is 42]'",
            "   - RIGHT: 'Result[State the computed sum]'",
            "",
            "3. ATOMIC NODES: Keep each node label UNDER 15 tokens.",
            "   - Break complex steps into multiple simple nodes.",
            "   - Each node = ONE focused action.",
            "",
            "4. ACTION-ORIENTED: Each node should describe an executable action.",
            "   - Use verbs: Calculate, Identify, Extract, Compare, Verify, Apply, etc.",
            "",
            "=== FORMAT ===",
            "",
            "```mermaid",
            "flowchart TD",
            "    Start[Analyze the problem] --> Step1[Action description]",
            "    Step1 --> Step2[Next action]",
            "    Step2 --> Answer[State the final result]",
            "```",
        ]

        if problem_type:
            prompt_parts.append(f"\nProblem Type: {problem_type}")

        if custom_instructions:
            prompt_parts.append(f"\nAdditional Instructions: {custom_instructions}")

        prompt_parts.append("\n=== EXAMPLES (following BRAID protocol) ===")
        for i, example in enumerate(self.examples, 1):
            prompt_parts.append(f"\nExample {i}:")
            prompt_parts.append(f"Problem: {example['problem']}")
            prompt_parts.append(f"GRD:\n{example['grd']}")

        prompt_parts.append("\n\n=== YOUR TASK ===")
        prompt_parts.append(f"Problem: {problem}")
        prompt_parts.append(
            "\nGenerate a BRAID-compliant Mermaid flowchart (remember: describe HOW, not WHAT):"
        )

        return "\n".join(prompt_parts)

    def add_example(self, problem: str, grd: str):
        """
        Add a custom example to this generator's few-shot examples.

        Only this instance's example list is affected; DEFAULT_EXAMPLES is
        left untouched because __init__ copies it.

        Args:
            problem: Example problem
            grd: Example GRD in Mermaid format
        """
        self.examples.append({"problem": problem, "grd": grd})

    def get_template(self, problem_type: str) -> Optional[str]:
        """
        Get a template GRD for a specific problem type.

        Args:
            problem_type: Type of problem (e.g., "math", "logic", "reasoning");
                matched case-insensitively.

        Returns:
            Template Mermaid code, or None for an unknown or empty type.
        """
        if not problem_type:
            return None
        return self._TEMPLATES.get(problem_type.lower())