diff --git a/configs/config.yaml b/configs/config.yaml index b105ac75..bef078a0 100644 --- a/configs/config.yaml +++ b/configs/config.yaml @@ -48,6 +48,7 @@ generation: overlap: 200 # Overlap between chunks to maintain context max_tokens: 4096 # Maximum tokens in LLM responses num_pairs: 25 # Default number of QA pairs to generate + num_pairs_per_chunk: null # Alternative: number of QA pairs per chunk (null = use num_pairs instead) num_cot_examples: 5 # Default number of Chain of Thought examples to generate num_cot_enhance_examples: null # Maximum number of conversations to enhance (null = enhance all) batch_size: 32 # Number of requests to batch together (for create) diff --git a/synthetic_data_kit/cli.py b/synthetic_data_kit/cli.py index e88dbb59..ec19552d 100644 --- a/synthetic_data_kit/cli.py +++ b/synthetic_data_kit/cli.py @@ -290,7 +290,10 @@ def create( None, "--model", "-m", help="Model to use" ), num_pairs: Optional[int] = typer.Option( - None, "--num-pairs", "-n", help="Target number of QA pairs or CoT examples to generate" + None, "--num-pairs", "-n", help="Target number of QA pairs or CoT examples to generate (total per document)" + ), + num_pairs_per_chunk: Optional[int] = typer.Option( + None, "--num-pairs-per-chunk", help="Number of QA pairs to generate per chunk (scales with document size, takes precedence over --num-pairs)" ), chunk_size: Optional[int] = typer.Option( None, "--chunk-size", help="Size of text chunks for processing large documents (default: 4000)" @@ -313,9 +316,11 @@ def create( - Directory: synthetic-data-kit create ./processed-text/ --type qa Content types: - - qa: Generate question-answer pairs from .txt files (use --num-pairs to specify how many) + - qa: Generate question-answer pairs from .txt files + Use --num-pairs for total pairs per document OR --num-pairs-per-chunk to scale with document size - summary: Generate summaries from .txt files - - cot: Generate Chain of Thought reasoning examples from .txt files (use --num-pairs to specify how many) + - cot: Generate Chain of Thought reasoning examples from .txt files + Use --num-pairs for total examples OR --num-pairs-per-chunk to scale with document size - multimodal-qa: Generate question-answer pairs from .lance files (use --num-pairs to specify how many) - cot-enhance: Enhance existing conversations with Chain of Thought reasoning from .json files (use --num-pairs to limit the number of conversations to enhance, default is to enhance all) @@ -411,6 +416,7 @@ def create( model=model, content_type=content_type, num_pairs=num_pairs, + num_pairs_per_chunk=num_pairs_per_chunk, verbose=verbose, provider=provider, chunk_size=chunk_size, @@ -441,7 +447,8 @@ def create( verbose, provider=provider, chunk_size=chunk_size, - chunk_overlap=chunk_overlap + chunk_overlap=chunk_overlap, + num_pairs_per_chunk=num_pairs_per_chunk ) if output_path: console.print(f"✅ Content saved to [bold]{output_path}[/bold]", style="green") diff --git a/synthetic_data_kit/config.yaml b/synthetic_data_kit/config.yaml index 2690697d..44a31529 100644 --- a/synthetic_data_kit/config.yaml +++ b/synthetic_data_kit/config.yaml @@ -56,7 +56,8 @@ generation: max_tokens: 4096 # Maximum tokens in LLM responses # Content generation targets - num_pairs: 25 # Default number of QA pairs to generate + num_pairs: 25 # Default number of QA pairs to generate (per document) + num_pairs_per_chunk: null # Alternative: number of QA pairs per chunk (null = use num_pairs instead) num_cot_examples: 5 # Default number of Chain of Thought examples to generate num_cot_enhance_examples: null # Maximum number of conversations to enhance (null = enhance all) diff --git a/synthetic_data_kit/core/create.py b/synthetic_data_kit/core/create.py index 9c2dd78b..00f13ac2 100644 --- a/synthetic_data_kit/core/create.py +++ b/synthetic_data_kit/core/create.py @@ -38,6 +38,7 @@ def process_file( chunk_size: Optional[int] = None, chunk_overlap: Optional[int] = None, rolling_summary: Optional[bool] = False, + num_pairs_per_chunk: Optional[int] = None, ) -> str: """Process a file to generate content @@ -48,8 +49,13 @@ def process_file( api_base: VLLM API base URL model: Model to use content_type: Type of content to generate (qa, summary, cot) - num_pairs: Target number of QA pairs to generate - threshold: Quality threshold for filtering (1-10) + num_pairs: Target number of QA pairs to generate (total per document) + num_pairs_per_chunk: Number of QA pairs per chunk (takes precedence over num_pairs) + verbose: Show detailed output + provider: LLM provider to use + chunk_size: Size of text chunks + chunk_overlap: Overlap between chunks + rolling_summary: Use rolling summary for long documents Returns: Path to the output file @@ -89,15 +95,17 @@ def process_file( generator = QAGenerator(client, config_path) # Get num_pairs from args or config - if num_pairs is None: + if num_pairs is None and num_pairs_per_chunk is None: config = client.config generation_config = get_generation_config(config) num_pairs = generation_config.get("num_pairs", 25) + num_pairs_per_chunk = generation_config.get("num_pairs_per_chunk") # Process document result = generator.process_documents( documents, - num_pairs=num_pairs, + num_pairs=num_pairs if num_pairs is not None else 25, + num_pairs_per_chunk=num_pairs_per_chunk, verbose=verbose, rolling_summary=rolling_summary ) diff --git a/synthetic_data_kit/generators/qa_generator.py b/synthetic_data_kit/generators/qa_generator.py index e892cdd2..36010cfc 100644 --- a/synthetic_data_kit/generators/qa_generator.py +++ b/synthetic_data_kit/generators/qa_generator.py @@ -84,8 +84,19 @@ def generate_summary(self, def generate_qa_pairs(self, document_text: str, summary: str, - num_pairs: int = 25) -> List[Dict[str, str]]: - """Generate QA pairs from the document using batched processing""" + num_pairs: int = 25, + num_pairs_per_chunk: Optional[int] = None) -> List[Dict[str, str]]: + """Generate QA pairs from the document using batched processing + + Args: + document_text: The text to generate QA pairs from + summary: Summary of the document + num_pairs: Total number of QA pairs to generate (used if num_pairs_per_chunk is None) + num_pairs_per_chunk: Number of QA pairs to generate per chunk (takes precedence over num_pairs) + + Returns: + List of QA pair dictionaries + """ verbose = os.environ.get('SDK_VERBOSE', 'false').lower() == 'true' # Get generation config @@ -101,13 +112,25 @@ def generate_qa_pairs(self, overlap=overlap ) + # Determine generation mode and calculate targets + if num_pairs_per_chunk is not None: + # Per-chunk mode: scale with document size + pairs_per_chunk = num_pairs_per_chunk + total_target = num_pairs_per_chunk * len(chunks) + mode = "per-chunk" + else: + # Total pairs mode: divide across chunks (original behavior) + pairs_per_chunk = max(1, round(num_pairs / len(chunks))) + total_target = num_pairs + mode = "total" + if verbose: print(f"Generating QA pairs...") print(f"Document split into {len(chunks)} chunks") + print(f"Mode: {mode} (pairs per chunk: {pairs_per_chunk}, target total: {total_target})") print(f"Using batch size of {batch_size}") all_qa_pairs = [] - pairs_per_chunk = max(1, round(num_pairs / len(chunks))) # Get QA generation prompt template qa_prompt_template = get_prompt(self.config, "qa_generation") @@ -151,9 +174,9 @@ def generate_qa_pairs(self, # Process in batches for batch_start in range(0, len(chunks), batch_size): # Check if we've already generated enough pairs - if len(all_qa_pairs) >= num_pairs: + if len(all_qa_pairs) >= total_target: if verbose: - print(f"Reached target of {num_pairs} pairs. Stopping processing.") + print(f"Reached target of {total_target} pairs. Stopping processing.") break batch_end = min(batch_start + batch_size, len(chunks)) @@ -180,25 +203,25 @@ def generate_qa_pairs(self, # Process each response in the batch for j, response in enumerate(batch_responses): # Check if we've reached the target before processing more - if len(all_qa_pairs) >= num_pairs: + if len(all_qa_pairs) >= total_target: if verbose: - print(f" Reached target of {num_pairs} pairs. Stopping batch processing.") + print(f" Reached target of {total_target} pairs. Stopping batch processing.") break chunk_index = batch_start + j chunk_pairs = parse_qa_pairs(response) # Only add pairs up to the target limit - remaining_pairs = num_pairs - len(all_qa_pairs) + remaining_pairs = total_target - len(all_qa_pairs) if remaining_pairs > 0: pairs_to_add = chunk_pairs[:remaining_pairs] all_qa_pairs.extend(pairs_to_add) if verbose: - print(f" Generated {len(pairs_to_add)} pairs from chunk {chunk_index+1} (total: {len(all_qa_pairs)}/{num_pairs})") + print(f" Generated {len(pairs_to_add)} pairs from chunk {chunk_index+1} (total: {len(all_qa_pairs)}/{total_target})") # Break if we've reached the target - if len(all_qa_pairs) >= num_pairs: + if len(all_qa_pairs) >= total_target: break # Update progress bar if in verbose mode @@ -206,7 +229,7 @@ def generate_qa_pairs(self, progress_ctx.update(generate_task, advance=current_batch_size) # Break outer loop if we've reached the target - if len(all_qa_pairs) >= num_pairs: + if len(all_qa_pairs) >= total_target: break except Exception as e: @@ -227,7 +250,7 @@ def generate_qa_pairs(self, print("Batch processing complete.") # Always print summary information, even in non-verbose mode - print(f"Generated {len(all_qa_pairs)} QA pairs total (requested: {num_pairs})") + print(f"Generated {len(all_qa_pairs)} QA pairs total (target: {total_target}, mode: {mode})") return all_qa_pairs def rate_qa_pairs(self, @@ -321,9 +344,21 @@ def rate_qa_pairs(self, def process_documents(self, documents: List[Dict[str, Any]], num_pairs: int = 25, + num_pairs_per_chunk: Optional[int] = None, verbose: bool = False, rolling_summary: Optional[bool] = False) -> Dict[str, Any]: - """Process a list of documents to generate QA pairs without rating""" + """Process a list of documents to generate QA pairs without rating + + Args: + documents: List of document dictionaries with 'text' field + num_pairs: Total number of QA pairs to generate (used if num_pairs_per_chunk is None) + num_pairs_per_chunk: Number of QA pairs per chunk (takes precedence over num_pairs) + verbose: Whether to show detailed output + rolling_summary: Whether to use rolling summary for long documents + + Returns: + Dictionary with summary and qa_pairs + """ # Set the verbose environment variable if verbose: os.environ['SDK_VERBOSE'] = 'true' @@ -337,7 +372,7 @@ def process_documents(self, summary = self.generate_summary(full_text, rolling_summary=rolling_summary) # Generate QA pairs - qa_pairs = self.generate_qa_pairs(full_text, summary, num_pairs=num_pairs) + qa_pairs = self.generate_qa_pairs(full_text, summary, num_pairs=num_pairs, num_pairs_per_chunk=num_pairs_per_chunk) all_qa_pairs.extend(qa_pairs) diff --git a/synthetic_data_kit/utils/directory_processor.py b/synthetic_data_kit/utils/directory_processor.py index e0a72534..8da89bf8 100644 --- a/synthetic_data_kit/utils/directory_processor.py +++ b/synthetic_data_kit/utils/directory_processor.py @@ -219,6 +219,7 @@ def process_directory_create( model: Optional[str] = None, content_type: str = "qa", num_pairs: Optional[int] = None, + num_pairs_per_chunk: Optional[int] = None, verbose: bool = False, provider: Optional[str] = None, chunk_size: Optional[int] = None, @@ -233,9 +234,12 @@ def process_directory_create( api_base: API base URL model: Model to use content_type: Type of content to generate (qa, summary, cot, cot-enhance) - num_pairs: Target number of QA pairs or examples + num_pairs: Target number of QA pairs or examples (total per document) + num_pairs_per_chunk: Number of QA pairs per chunk (takes precedence over num_pairs) verbose: Show detailed progress provider: LLM provider to use + chunk_size: Size of text chunks + chunk_overlap: Overlap between chunks Returns: Dictionary with processing results @@ -310,7 +314,8 @@ def process_directory_create( verbose, provider=provider, chunk_size=chunk_size, - chunk_overlap=chunk_overlap + chunk_overlap=chunk_overlap, + num_pairs_per_chunk=num_pairs_per_chunk ) # Record success diff --git a/tests/unit/test_qa_generator.py b/tests/unit/test_qa_generator.py index c294d0b6..11951ddf 100644 --- a/tests/unit/test_qa_generator.py +++ b/tests/unit/test_qa_generator.py @@ -177,3 +177,129 @@ def test_process_document(patch_config): assert "qa_pairs" in result assert result["summary"] == "This is a summary of the document." assert len(result["qa_pairs"]) == 2 + + +@pytest.mark.unit +def test_generate_qa_pairs_per_chunk(patch_config): + """Test generating QA pairs with num_pairs_per_chunk parameter.""" + # Create mock LLM client + mock_client = MagicMock() + # Mock batch_completion to return 2 pairs per chunk (simulating 2 chunks) + mock_client.batch_completion.return_value = [ + json.dumps( + [ + { + "question": "Question from chunk 1?", + "answer": "Answer from chunk 1.", + }, + { + "question": "Question 2 from chunk 1?", + "answer": "Answer 2 from chunk 1.", + } + ] + ), + json.dumps( + [ + { + "question": "Question from chunk 2?", + "answer": "Answer from chunk 2.", + }, + { + "question": "Question 2 from chunk 2?", + "answer": "Answer 2 from chunk 2.", + } + ] + ), + ] + + # Initialize generator + generator = QAGenerator(client=mock_client) + + # Generate QA pairs with num_pairs_per_chunk=2 + # This should generate 2 pairs per chunk, total depends on number of chunks + qa_pairs = generator.generate_qa_pairs( + document_text="This is a document to generate QA pairs from. " * 100, # Long enough to create multiple chunks + summary="This is a summary of the document.", + num_pairs_per_chunk=2, + ) + + # Check that QA pairs were generated (should be 2 per chunk) + assert len(qa_pairs) >= 2 + # Check that client was called + assert mock_client.batch_completion.called + + +@pytest.mark.unit +def test_num_pairs_per_chunk_priority(patch_config): + """Test that num_pairs_per_chunk takes precedence over num_pairs.""" + # Create mock LLM client + mock_client = MagicMock() + mock_client.batch_completion.return_value = [ + json.dumps( + [ + { + "question": "Question 1?", + "answer": "Answer 1.", + }, + { + "question": "Question 2?", + "answer": "Answer 2.", + } + ] + ), + ] + + # Initialize generator + generator = QAGenerator(client=mock_client) + + # Generate QA pairs with both parameters + # num_pairs_per_chunk should take precedence + qa_pairs = generator.generate_qa_pairs( + document_text="Short document.", + summary="Summary.", + num_pairs=100, # This should be ignored + num_pairs_per_chunk=2, # This should be used + ) + + # Should generate based on num_pairs_per_chunk, not num_pairs + assert len(qa_pairs) == 2 + assert mock_client.batch_completion.called + + +@pytest.mark.unit +def test_process_documents_with_num_pairs_per_chunk(patch_config): + """Test process_documents with num_pairs_per_chunk parameter.""" + # Create mock LLM client + mock_client = MagicMock() + mock_client.chat_completion.return_value = "This is a summary of the document." + mock_client.batch_completion.return_value = [ + json.dumps( + [ + { + "question": "Question 1?", + "answer": "Answer 1.", + }, + { + "question": "Question 2?", + "answer": "Answer 2.", + } + ] + ), + ] + + # Initialize generator + generator = QAGenerator(client=mock_client) + + # Process document with num_pairs_per_chunk + result = generator.process_documents( + documents=[{"text": "This is a document to process."}], + num_pairs_per_chunk=2, + verbose=False + ) + + # Check that the result contains summary and QA pairs + assert "summary" in result + assert "qa_pairs" in result + assert result["summary"] == "This is a summary of the document." + assert len(result["qa_pairs"]) >= 2 +