diff --git a/synthetic_data_kit/config.yaml b/synthetic_data_kit/config.yaml index 2690697d..57e8842f 100644 --- a/synthetic_data_kit/config.yaml +++ b/synthetic_data_kit/config.yaml @@ -15,15 +15,30 @@ paths: # LLM Provider configuration llm: # Provider selection: "vllm" or "api-endpoint" - provider: "api-endpoint" - + provider: "api-endpoint" # Provider selection: "vllm", "api-endpoint", or "ollama" # VLLM server configuration vllm: api_base: "http://localhost:8000/v1" # Base URL for VLLM API port: 8000 # Port for VLLM server model: "meta-llama/Llama-3.3-70B-Instruct" # Default model to use max_retries: 3 # Number of retries for API calls - retry_delay: 1.0 # Initial delay between retries (seconds) + retry_delay: 1.0 + +# Ollama configuration +ollama: + api_base: "http://localhost:11434/api" # Base URL for Ollama API + model: "hf.co/unsloth/Qwen3-30B-A3B-Instruct-2507-GGUF:Q8_K_XL" # Default model to use (e.g., llama3, codellama, mistral, etc.) + max_retries: 3 # Number of retries for API calls + retry_delay: 1.0 # Initial delay between retries (seconds) + timeout: 120 # Request timeout in seconds + keep_alive: "5m" # How long to keep model loaded in memory (e.g., "5m", "1h") + num_ctx: 4096 # Size of context window + temperature: 0.7 # Model temperature (0.0-1.0) + top_p: 0.9 # Nucleus sampling parameter + # Ollama-specific settings + host: "localhost" # Ollama server host + port: 11434 # Ollama server port + pull_model: true # Whether to automatically pull the model if not available # Initial delay between retries (seconds) # API endpoint configuration api-endpoint: diff --git a/synthetic_data_kit/utils/text.py b/synthetic_data_kit/utils/text.py index 05886cdf..9695d754 100644 --- a/synthetic_data_kit/utils/text.py +++ b/synthetic_data_kit/utils/text.py @@ -9,30 +9,95 @@ from typing import List, Dict, Any def split_into_chunks(text: str, chunk_size: int = 4000, overlap: int = 200) -> List[str]: - """Split text into chunks with optional overlap""" - paragraphs = text.split("\n\n") + """Split text into chunks with optional overlap using hierarchical approach""" + + # Ensure overlap is not larger than chunk_size + overlap = min(overlap, chunk_size // 2) + + if "\n\n" in text: + # Paragraph-based splitting + segments = text.split("\n\n") + join_str = "\n\n" + overlap_split_str = '. ' + elif "\n" in text: + # Line-based splitting + segments = text.split("\n") + join_str = "\n" + overlap_split_str = '. ' + else: + # Sentence-based splitting - try to find sentences first + sentences = re.split(r'([.!?])\s+(?=[A-Z])', text) + segments = [] + for i in range(0, len(sentences) - 1, 2): + if i + 1 < len(sentences): + sentence = sentences[i] + sentences[i + 1] + segments.append(sentence.strip()) + + # Add remaining part if exists + if len(sentences) % 2 == 1: + segments.append(sentences[-1].strip()) + + # If no proper sentences found, fall back to word splitting + if len(segments) <= 1 and len(text) > chunk_size: + segments = text.split(' ') + join_str = " " + overlap_split_str = ' ' + else: + join_str = " " + overlap_split_str = '. ' + chunks = [] current_chunk = "" - for para in paragraphs: - if len(current_chunk) + len(para) > chunk_size and current_chunk: - chunks.append(current_chunk) - # Keep some overlap for context - sentences = current_chunk.split('. ') - if len(sentences) > 3: - current_chunk = '. '.join(sentences[-3:]) + "\n\n" + para + for segment in segments: + potential_length = len(current_chunk) + (len(join_str) if current_chunk else 0) + len(segment) + + if potential_length > chunk_size and current_chunk: + chunks.append(current_chunk.strip()) + + # Create overlap for next chunk + if overlap > 0: + overlap_parts = current_chunk.split(overlap_split_str) + if len(overlap_parts) > 1: + # Keep overlap amount of characters from the end + overlap_text = current_chunk[-overlap:] if len(current_chunk) > overlap else current_chunk + space_pos = overlap_text.find(' ') + if space_pos > 0: + overlap_text = overlap_text[space_pos + 1:] + current_chunk = overlap_text + join_str + segment + else: + current_chunk = segment else: - current_chunk = para + current_chunk = segment else: if current_chunk: - current_chunk += "\n\n" + para + current_chunk += join_str + segment else: - current_chunk = para + current_chunk = segment + # Add final chunk if it exists if current_chunk: - chunks.append(current_chunk) + chunks.append(current_chunk.strip()) + + # Fallback: if only one chunk and text is longer than chunk_size, force character-based splitting + if len(chunks) == 1 and len(text) > chunk_size: + chunks = [] + step_size = max(1, chunk_size - overlap) + + for i in range(0, len(text), step_size): + chunk_end = min(i + chunk_size, len(text)) + chunk = text[i:chunk_end] + + # Try to end at word boundary if not at end + if chunk_end < len(text) and ' ' in chunk: + last_space = chunk.rfind(' ') + if last_space > len(chunk) * 0.7: # Don't lose too much content + chunk = chunk[:last_space] + + if chunk.strip(): + chunks.append(chunk.strip()) - return chunks + return [chunk for chunk in chunks if chunk.strip()] def extract_json_from_text(text: str) -> Dict[str, Any]: """Extract JSON from text that might contain markdown or other content""" diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 80446b1d..1e5e4dbb 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -7,7 +7,6 @@ from synthetic_data_kit.utils import config, text -@pytest.mark.unit def test_split_into_chunks(): """Test splitting text into chunks.""" # Create multi-paragraph text @@ -33,7 +32,21 @@ def test_split_into_chunks(): # Empty text should produce an empty list, not a list with an empty string assert empty_chunks == [] + non_paragraph_tests = [ + "This is a sample example of inputs without paragraphs.\n" *16, + "This is a sample example of inputs without sentences." *16, + ] + # Using a small chunk size to ensure splitting + non_paragraph_chunks = text.split_into_chunks(non_paragraph_tests[0], chunk_size=20, overlap=10) + assert len(non_paragraph_chunks) > 1 + assert len(non_paragraph_chunks) >= 16 + + non_sentence_chunks = text.split_into_chunks(non_paragraph_tests[1], chunk_size=20, overlap=5) + assert len(non_sentence_chunks) > 1 + assert len(non_sentence_chunks) >=20 + +test_split_into_chunks() @pytest.mark.unit def test_extract_json_from_text(): """Test extracting JSON from text."""