# Source code for braid.training

"""Training utilities for BRAID Architect models.

This module provides tools for:
- Generating synthetic training data for Architect models
- Preparing datasets for fine-tuning
- Creating DSPy examples for BootstrapFewShot optimization
"""

import json
import random
from typing import Dict, List, Optional, Any, Tuple
from dataclasses import dataclass, field, asdict
from pathlib import Path

from braid.validators import AtomicityValidator, GRDValidator
from braid.masking import NumericalMasker


@dataclass
class TrainingSample:
    """One problem/GRD pair used for Architect model training.

    Only ``problem`` and ``grd`` are required; every other field has a
    default so that minimal records can be constructed and round-tripped.
    """

    # Natural-language problem statement.
    problem: str
    # Guided Reasoning Diagram text paired with the problem.
    grd: str
    # Reference answer for the problem, when one is known.
    expected_answer: Optional[str] = None
    # Category label for the problem.
    problem_type: str = "general"
    # Difficulty label: easy, medium, hard.
    difficulty: str = "medium"
    # Free-form extra data; each instance gets its own fresh dict.
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Return the sample's fields as a plain dictionary for serialization."""
        serialized = asdict(self)
        return serialized

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "TrainingSample":
        """Build a sample from a mapping of field names to values.

        The mapping is passed straight to the constructor as keyword
        arguments, so its keys must be field names of this class.
        """
        sample = cls(**data)
        return sample


@dataclass
class DatasetStats:
    """Summary figures describing a training dataset."""

    # Number of samples the statistics were computed over.
    total_samples: int
    # Sample count keyed by problem type.
    problem_types: Dict[str, int]
    # Sample count keyed by difficulty label.
    difficulty_distribution: Dict[str, int]
    # Mean number of GRD nodes per sample.
    avg_grd_nodes: float
    # Mean token count per GRD node.
    avg_tokens_per_node: float
    # Number of samples counted as passing.
    validation_passed: int
    # Number of samples counted as failing.
    validation_failed: int


def _flowchart(*edges: str) -> str:
    """Build a top-down Mermaid flowchart with one edge statement per line."""
    return "\n".join(["flowchart TD", *(f"    {edge}" for edge in edges)])


class SyntheticDataGenerator:
    """
    Generates synthetic training data for Architect models.

    This generator creates problem-GRD pairs following BRAID protocol:
    - Procedural scaffolding (describe HOW, not WHAT)
    - Atomic nodes (≤15 tokens per node)
    - No answer leakage
    """

    # Problem templates by category. Each template is a dict with:
    #   "template"     - problem text containing ``str.format`` placeholders
    #   "grd_template" - Mermaid flowchart paired with every filled problem
    #   "variables"    - candidate values for each placeholder
    #   "answer_fn"    - callable mapping the chosen values to the answer text
    MATH_TEMPLATES = [
        {
            "template": "If a {vehicle} travels {distance} km in {time} hours, what is its speed?",
            "grd_template": _flowchart(
                "Start[Read and analyze problem] --> Extract[Extract: distance and time values]",
                "Extract --> Identify[Identify: need to find speed]",
                "Identify --> Formula[Recall speed formula]",
                "Formula --> Apply[Apply: divide distance by time]",
                "Apply --> Units[Verify units are correct]",
                "Units --> Answer[State the final speed]",
            ),
            "variables": {
                "vehicle": ["car", "train", "bus", "bicycle", "plane"],
                "distance": [60, 120, 180, 240, 300, 450, 600],
                "time": [1, 2, 3, 4, 5, 6],
            },
            "answer_fn": lambda v: f"{v['distance'] / v['time']} km/h",
        },
        {
            "template": "Solve: {a}x + {b} = {c}",
            "grd_template": _flowchart(
                "Start[Analyze the equation] --> Goal[Goal: isolate x]",
                "Goal --> Subtract[Subtract constant from both sides]",
                "Subtract --> Simplify1[Simplify right side]",
                "Simplify1 --> Divide[Divide by coefficient]",
                "Divide --> Simplify2[Calculate x value]",
                "Simplify2 --> Check[Verify by substitution]",
                "Check --> Answer[State solution]",
            ),
            "variables": {
                "a": [2, 3, 4, 5, 6],
                "b": [1, 2, 3, 4, 5, 7, 8, 10],
                "c": [10, 12, 14, 15, 18, 20, 22, 25],
            },
            "answer_fn": lambda v: f"x = {(v['c'] - v['b']) / v['a']}",
        },
        {
            "template": (
                "A store sells {item} at ${price} each. "
                "If {name} buys {quantity}, how much does {pronoun} pay?"
            ),
            "grd_template": _flowchart(
                "Start[Understand the scenario] --> Values[Identify: unit price and quantity]",
                "Values --> Operation[Determine operation needed]",
                "Operation --> Calculate[Multiply price by quantity]",
                "Calculate --> Format[Format as currency]",
                "Format --> Answer[State total cost]",
            ),
            "variables": {
                "item": ["apples", "oranges", "books", "pens", "notebooks"],
                "price": [2, 3, 5, 8, 10, 15],
                "name": ["John", "Maria", "Alex", "Sarah"],
                "quantity": [3, 4, 5, 6, 7, 8, 10],
                "pronoun": ["he", "she", "they"],
            },
            "answer_fn": lambda v: f"${v['price'] * v['quantity']}",
        },
    ]

    LOGIC_TEMPLATES = [
        {
            "template": (
                "If all {category_a} are {category_b}, and {item} is a {category_a}, "
                "what can we conclude?"
            ),
            "grd_template": _flowchart(
                "Start[Identify premises] --> P1[Premise 1: All A are B]",
                "P1 --> P2[Premise 2: X is A]",
                "P2 --> Apply[Apply syllogistic reasoning]",
                "Apply --> Deduce[Deduce: X must be B]",
                "Deduce --> Answer[State conclusion]",
            ),
            "variables": {
                "category_a": ["dogs", "cats", "birds", "mammals"],
                "category_b": ["animals", "living things", "creatures"],
                "item": ["Rex", "Fluffy", "Tweety", "Max"],
            },
            "answer_fn": lambda v: f"{v['item']} is a {v['category_b'].rstrip('s')}",
        },
    ]

    REASONING_TEMPLATES = [
        {
            "template": (
                "{person} has {count} {items}. {person2} gives {person} {more} more. "
                "How many does {person} have now?"
            ),
            "grd_template": _flowchart(
                "Start[Understand the situation] --> Initial[Identify initial count]",
                "Initial --> Change[Identify the change]",
                "Change --> Operation[Determine: addition needed]",
                "Operation --> Calculate[Add the quantities]",
                "Calculate --> Answer[State final count]",
            ),
            "variables": {
                "person": ["Alice", "Bob", "Charlie", "Diana"],
                "person2": ["Bob", "Carol", "David", "Eve"],
                "count": [3, 5, 7, 10, 12],
                "items": ["apples", "books", "coins", "marbles"],
                "more": [2, 3, 4, 5],
            },
            "answer_fn": lambda v: f"{v['count'] + v['more']} {v['items']}",
        },
    ]

    def __init__(
        self,
        validate_output: bool = True,
        max_tokens_per_node: int = 15,
    ):
        """
        Initialize the synthetic data generator.

        Args:
            validate_output: Whether to validate generated samples. The flag
                is stored on the instance; the methods of this class do not
                consult it themselves.
            max_tokens_per_node: Maximum tokens per node, passed to both
                validators.
        """
        self.validate_output = validate_output
        self.validator = GRDValidator(max_tokens_per_node=max_tokens_per_node)
        self.masker = NumericalMasker()
        self.atomicity_validator = AtomicityValidator(max_tokens_per_node=max_tokens_per_node)

    def _fill_template(self, template: Dict[str, Any]) -> Tuple[str, str, str]:
        """
        Fill a template with randomly chosen variable values.

        Args:
            template: One entry of the ``*_TEMPLATES`` lists

        Returns:
            Tuple of (problem text, GRD flowchart, expected answer)
        """
        chosen = {
            name: random.choice(options)
            for name, options in template["variables"].items()
        }
        problem = template["template"].format(**chosen)
        answer = template["answer_fn"](chosen)
        # The GRD is procedural, so it is shared by every instance of the template.
        return problem, template["grd_template"], answer

    def _generate_samples(
        self,
        templates: List[Dict[str, Any]],
        count: int,
        problem_type: str,
        difficulties: List[str],
    ) -> List["TrainingSample"]:
        """Build ``count`` samples from ``templates`` with the given labels."""
        samples = []
        for _ in range(count):
            problem, grd, answer = self._fill_template(random.choice(templates))
            # A single-entry pool is a fixed difficulty; skipping the draw keeps
            # the sequence of random numbers consumed unchanged for that case.
            if len(difficulties) == 1:
                difficulty = difficulties[0]
            else:
                difficulty = random.choice(difficulties)
            samples.append(
                TrainingSample(
                    problem=problem,
                    grd=f"```mermaid\n{grd}\n```",
                    expected_answer=answer,
                    problem_type=problem_type,
                    difficulty=difficulty,
                )
            )
        return samples

    def generate_math_samples(self, count: int) -> List["TrainingSample"]:
        """
        Generate math problem samples.

        Args:
            count: Number of samples to generate

        Returns:
            List of TrainingSample objects labelled "math", with difficulty
            drawn from easy/medium
        """
        return self._generate_samples(self.MATH_TEMPLATES, count, "math", ["easy", "medium"])

    def generate_logic_samples(self, count: int) -> List["TrainingSample"]:
        """
        Generate logic problem samples.

        Args:
            count: Number of samples to generate

        Returns:
            List of TrainingSample objects labelled "logic", with difficulty
            drawn from medium/hard
        """
        return self._generate_samples(self.LOGIC_TEMPLATES, count, "logic", ["medium", "hard"])

    def generate_reasoning_samples(self, count: int) -> List["TrainingSample"]:
        """
        Generate general reasoning samples.

        Args:
            count: Number of samples to generate

        Returns:
            List of TrainingSample objects labelled "reasoning", all "easy"
        """
        return self._generate_samples(self.REASONING_TEMPLATES, count, "reasoning", ["easy"])

    @staticmethod
    def _split_counts(
        count: int,
        math_ratio: float,
        logic_ratio: float,
        reasoning_ratio: float,
    ) -> Tuple[int, int, int]:
        """Split ``count`` into (math, logic, reasoning) sample counts."""
        ratios = (math_ratio, logic_ratio, reasoning_ratio)
        if any(ratio < 0 for ratio in ratios):
            raise ValueError("Sample ratios must be non-negative")
        total = sum(ratios)
        if total <= 0:
            raise ValueError("At least one sample ratio must be positive")
        # Only rescale when the ratios genuinely do not sum to 1, so the common
        # case is not perturbed by floating-point division.
        if abs(total - 1.0) > 1e-9:
            math_ratio, logic_ratio = math_ratio / total, logic_ratio / total
        math_count = int(count * math_ratio)
        logic_count = int(count * logic_ratio)
        # Reasoning absorbs the rounding remainder so the parts sum to count.
        return math_count, logic_count, count - math_count - logic_count

    def generate_mixed_samples(
        self,
        count: int,
        math_ratio: float = 0.4,
        logic_ratio: float = 0.3,
        reasoning_ratio: float = 0.3,
    ) -> List["TrainingSample"]:
        """
        Generate a mixed, shuffled dataset of samples.

        Ratios that do not sum to 1 are normalised by their total, so all
        three are honoured. Reasoning samples take up any rounding remainder.

        Args:
            count: Total number of samples
            math_ratio: Proportion of math problems
            logic_ratio: Proportion of logic problems
            reasoning_ratio: Proportion of reasoning problems

        Returns:
            List of TrainingSample objects

        Raises:
            ValueError: If a ratio is negative or all ratios are zero
        """
        math_count, logic_count, reasoning_count = self._split_counts(
            count, math_ratio, logic_ratio, reasoning_ratio
        )

        samples = [
            *self.generate_math_samples(math_count),
            *self.generate_logic_samples(logic_count),
            *self.generate_reasoning_samples(reasoning_count),
        ]
        random.shuffle(samples)
        return samples

    def validate_samples(
        self, samples: List["TrainingSample"]
    ) -> Tuple[List["TrainingSample"], List["TrainingSample"]]:
        """
        Validate samples against BRAID protocol rules.

        Rejected samples are annotated in place: the reasons are stored under
        ``metadata["validation_issues"]``.

        Args:
            samples: Samples to validate

        Returns:
            Tuple of (valid_samples, invalid_samples)
        """
        from braid.parser import MermaidParser
        from braid.utils import extract_mermaid_code

        parser = MermaidParser()
        valid = []
        invalid = []

        for sample in samples:
            # ``issues`` stays None only when the GRD parsed and validated cleanly.
            try:
                mermaid_code = extract_mermaid_code(sample.grd)
                if mermaid_code:
                    result = self.validator.validate(parser.parse(mermaid_code))
                    issues = None if result.valid else [str(issue) for issue in result.issues]
                else:
                    issues = ["Could not extract Mermaid code"]
            except Exception as exc:
                # Best effort: one malformed sample is recorded as invalid
                # instead of aborting validation of the whole batch.
                issues = [str(exc)]

            if issues is None:
                valid.append(sample)
            else:
                sample.metadata["validation_issues"] = issues
                invalid.append(sample)

        return valid, invalid
class DatasetExporter:
    """Exports and loads training datasets in JSON and JSONL formats."""

    @staticmethod
    def to_jsonl(samples: List["TrainingSample"], path: str) -> None:
        """
        Export samples to JSONL format (one JSON object per line).

        Args:
            samples: Samples to export
            path: Output file path; an existing file is overwritten
        """
        with open(path, "w", encoding="utf-8") as f:
            f.writelines(
                json.dumps(sample.to_dict(), ensure_ascii=False) + "\n"
                for sample in samples
            )

    @staticmethod
    def to_json(samples: List["TrainingSample"], path: str) -> None:
        """
        Export samples to JSON format (a single indented array).

        Args:
            samples: Samples to export
            path: Output file path; an existing file is overwritten
        """
        records = [sample.to_dict() for sample in samples]
        with open(path, "w", encoding="utf-8") as f:
            json.dump(records, f, ensure_ascii=False, indent=2)

    @staticmethod
    def _read_jsonl_records(path: str) -> List[Dict[str, Any]]:
        """Parse a JSONL file into raw records, skipping blank lines."""
        records = []
        with open(path, "r", encoding="utf-8") as f:
            for line in f:
                text = line.strip()
                # Blank lines (typically a trailing newline at end of file)
                # carry no record and would otherwise fail to decode.
                if text:
                    records.append(json.loads(text))
        return records

    @staticmethod
    def from_jsonl(path: str) -> List["TrainingSample"]:
        """
        Load samples from JSONL format.

        Blank or whitespace-only lines are ignored.

        Args:
            path: Input file path

        Returns:
            List of TrainingSample objects

        Raises:
            json.JSONDecodeError: If a non-blank line is not valid JSON
        """
        return [
            TrainingSample.from_dict(record)
            for record in DatasetExporter._read_jsonl_records(path)
        ]

    @staticmethod
    def from_json(path: str) -> List["TrainingSample"]:
        """
        Load samples from JSON format.

        Args:
            path: Input file path

        Returns:
            List of TrainingSample objects
        """
        with open(path, "r", encoding="utf-8") as f:
            records = json.load(f)
        return [TrainingSample.from_dict(record) for record in records]
class ArchitectTrainer:
    """
    Utilities for training/fine-tuning Architect models.

    Supports:
    - Creating DSPy examples for BootstrapFewShot
    - Preparing fine-tuning datasets
    - Calculating dataset statistics
    """

    # System prompt used by prepare_openai_finetune_dataset when none is given.
    _DEFAULT_SYSTEM_PROMPT = (
        "You are an expert at creating Guided Reasoning Diagrams (GRDs) in Mermaid format. "
        "BRAID Protocol Rules: "
        "1. PROCEDURAL SCAFFOLDING: Describe HOW to solve, never WHAT the answer is "
        "2. NO ANSWER LEAKAGE: Never include computed values in node labels "
        "3. ATOMIC NODES: Keep each node under 15 tokens "
        "4. ACTION-ORIENTED: Each node describes an executable action"
    )

    # Output formats accepted by generate_training_dataset.
    _SUPPORTED_FORMATS = ("jsonl", "json")

    def __init__(self):
        """Initialize the trainer with a default SyntheticDataGenerator."""
        self.generator = SyntheticDataGenerator()

    def create_dspy_examples(self, samples: List["TrainingSample"]) -> List[Any]:
        """
        Create DSPy Example objects from training samples.

        Each example carries the ``problem`` and ``grd`` fields, with
        ``problem`` marked as the input.

        Args:
            samples: Training samples to convert

        Returns:
            List of dspy.Example objects

        Raises:
            ImportError: If DSPy is not installed
        """
        # Guard only the import, so an ImportError raised while building the
        # examples is not misreported as a missing DSPy installation.
        try:
            import dspy
        except ImportError as err:
            raise ImportError(
                "DSPy is required for creating examples. Install with: pip install dspy-ai"
            ) from err

        return [
            dspy.Example(problem=sample.problem, grd=sample.grd).with_inputs("problem")
            for sample in samples
        ]

    def prepare_openai_finetune_dataset(
        self,
        samples: List["TrainingSample"],
        system_prompt: Optional[str] = None,
    ) -> List[Dict[str, Any]]:
        """
        Prepare dataset in OpenAI fine-tuning format.

        Each sample becomes one conversation: the system prompt, a user
        message asking for a GRD for the problem, and the sample's GRD as the
        assistant reply.

        Args:
            samples: Training samples
            system_prompt: Optional system prompt; a built-in BRAID prompt is
                used when None

        Returns:
            List of conversation dictionaries
        """
        if system_prompt is None:
            system_prompt = self._DEFAULT_SYSTEM_PROMPT

        return [
            {
                "messages": [
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": f"Create a GRD for this problem: {sample.problem}"},
                    {"role": "assistant", "content": sample.grd},
                ]
            }
            for sample in samples
        ]

    def calculate_dataset_stats(self, samples: List["TrainingSample"]) -> "DatasetStats":
        """
        Calculate statistics for a dataset.

        A sample counts as passed when Mermaid code can be extracted from its
        GRD, parsed, and token-counted without error; otherwise it counts as
        failed. Only passed samples contribute to the averages.

        Args:
            samples: Training samples

        Returns:
            DatasetStats object
        """
        from collections import Counter

        from braid.parser import MermaidParser
        from braid.utils import extract_mermaid_code

        parser = MermaidParser()
        validator = AtomicityValidator()

        problem_types = dict(Counter(sample.problem_type for sample in samples))
        difficulty_dist = dict(Counter(sample.difficulty for sample in samples))

        node_counts: List[int] = []
        token_counts: List[int] = []
        valid_count = 0
        invalid_count = 0

        for sample in samples:
            # Gather this sample's figures first and commit them only on
            # success, so a sample that fails midway cannot skew the averages.
            try:
                mermaid_code = extract_mermaid_code(sample.grd)
                if not mermaid_code:
                    invalid_count += 1
                    continue
                parsed = parser.parse(mermaid_code)
                sample_tokens = [validator.count_tokens(node.label) for node in parsed.nodes]
                sample_nodes = len(parsed.nodes)
            except Exception:
                # Best effort: an unparsable sample is tallied as failed.
                invalid_count += 1
                continue

            node_counts.append(sample_nodes)
            token_counts.extend(sample_tokens)
            valid_count += 1

        return DatasetStats(
            total_samples=len(samples),
            problem_types=problem_types,
            difficulty_distribution=difficulty_dist,
            avg_grd_nodes=sum(node_counts) / len(node_counts) if node_counts else 0.0,
            avg_tokens_per_node=sum(token_counts) / len(token_counts) if token_counts else 0.0,
            validation_passed=valid_count,
            validation_failed=invalid_count,
        )

    def generate_training_dataset(
        self,
        size: int = 100,
        output_path: Optional[str] = None,
        format: str = "jsonl",
    ) -> List["TrainingSample"]:
        """
        Generate and optionally save a training dataset.

        Only samples that pass validation are returned and saved.

        Args:
            size: Number of samples to generate before validation
            output_path: Optional path to save the dataset
            format: Output format ("jsonl" or "json", case-insensitive);
                only consulted when output_path is given

        Returns:
            List of generated samples that passed validation

        Raises:
            ValueError: If output_path is given and format is not supported
        """
        # Reject an unknown format up front, before any generation work, rather
        # than silently writing JSON for anything that is not "jsonl".
        file_format = format.lower() if output_path else None
        if output_path and file_format not in self._SUPPORTED_FORMATS:
            raise ValueError(
                f"Unsupported format {format!r}; expected one of {self._SUPPORTED_FORMATS}"
            )

        samples = self.generator.generate_mixed_samples(size)
        valid_samples, _ = self.generator.validate_samples(samples)

        if output_path:
            if file_format == "jsonl":
                DatasetExporter.to_jsonl(valid_samples, output_path)
            else:
                DatasetExporter.to_json(valid_samples, output_path)

        return valid_samples