Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 29 additions & 23 deletions src/llama_prompt_ops/interfaces/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
and optimization using YAML configuration files.
"""

import atexit
import importlib
import importlib.util
import json
Expand All @@ -28,6 +29,7 @@
from dotenv import load_dotenv

# Import template utilities
from llama_prompt_ops.core.utils.logging import LoggingManager, get_logger
from llama_prompt_ops.templates import get_template_content, get_template_path


Expand Down Expand Up @@ -769,8 +771,7 @@ def load_config(config_path):
type=click.Choice(
["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], case_sensitive=False
),
default="INFO",
help="Set the logging level",
help="Set the logging level (overrides config file; defaults to config or INFO if unset)",
)
def migrate(config, model, output_dir, save_yaml, api_key_env, dotenv_path, log_level):
"""
Expand All @@ -782,6 +783,22 @@ def migrate(config, model, output_dir, save_yaml, api_key_env, dotenv_path, log_
Example:
prompt-ops migrate --config configs/facility.yaml
"""
# Load configuration
try:
config_dict = load_config(config)
click.echo(f"Loaded configuration from {config}")
except ValueError as e:
click.echo(f"Error: {str(e)}", err=True)
sys.exit(1)

# Configure logging from file, if not overridden by CLI
export_path = None
if not log_level:
log_config = config_dict.get("logging", {})
# Fallback to INFO if not specified in config
log_level = log_config.get("level", "INFO")
export_path = log_config.get("export_path", None)

# Set up logging
numeric_level = getattr(logging, log_level.upper())
logging.basicConfig(
Expand Down Expand Up @@ -809,27 +826,16 @@ def migrate(config, model, output_dir, save_yaml, api_key_env, dotenv_path, log_
# Get API key using the extracted function
api_key = check_api_key(api_key_env, dotenv_path)

# Load configuration
try:
config_dict = load_config(config)
click.echo(f"Loaded configuration from {config}")
except ValueError as e:
click.echo(f"Error: {str(e)}", err=True)
sys.exit(1)

# Configure logging from file, if not overridden by CLI
if not log_level:
log_config = config_dict.get("logging", {})
level = log_config.get("level", "INFO")
logger.set_level(level)
export_path = log_config.get("export_path")
if export_path:
# Replace timestamp placeholder
if "${TIMESTAMP}" in export_path:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
export_path = export_path.replace("${TIMESTAMP}", timestamp)
atexit.register(logger.export_json, export_path)
logger.info(f"Will export logs to {export_path} on exit.")
# Export logs on exit if export_path is specified
if export_path:
logging_manager: LoggingManager = get_logger()
logging_manager.set_level(log_level)
# Replace timestamp placeholder
if "${TIMESTAMP}" in export_path:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
export_path = export_path.replace("${TIMESTAMP}", timestamp)
atexit.register(logging_manager.export_json, export_path)
logging.info(f"Will export logs to {export_path} on exit.")

# Set up models from config

Expand Down
64 changes: 64 additions & 0 deletions tests/integration/test_cli_integration.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
import os
import tempfile
from pathlib import Path
Expand Down Expand Up @@ -316,3 +317,66 @@ def test_end_to_end_cli_flow(self, mock_api_key_check, temp_config_file):
# Clean up the temporary output file
if os.path.exists(output_path):
os.unlink(output_path)

def test_migrate_log_level_from_config(self):
"""Test the logging level which is set from config file."""
# Use Click's test runner instead of directly calling cli()
from click.testing import CliRunner

runner = CliRunner()

# Create mock objects for the migrator and its methods
mock_migrator = MagicMock()
mock_dataset_adapter = MagicMock()
mock_optimized = MagicMock()
mock_optimized.signature.instructions = "Optimized prompt"

# Set up return values for the mocked methods
mock_migrator.load_dataset_with_adapter.return_value = ([], [], [])
mock_migrator.optimize.return_value = mock_optimized

# Set up multiple patches
with (
patch(
"llama_prompt_ops.interfaces.cli.PromptMigrator",
return_value=mock_migrator,
),
patch(
"llama_prompt_ops.interfaces.cli.get_dataset_adapter_from_config",
return_value=mock_dataset_adapter,
),
patch(
"llama_prompt_ops.interfaces.cli.get_models_from_config",
return_value=(None, None, "test_task_model", "test_prompt_model"),
),
patch(
"llama_prompt_ops.interfaces.cli.get_metric", return_value=MagicMock()
),
patch(
"llama_prompt_ops.interfaces.cli.get_strategy", return_value=MagicMock()
),
patch(
"llama_prompt_ops.interfaces.cli.load_config",
return_value={"logging": {"level": "DEBUG"}},
),
patch(
"llama_prompt_ops.interfaces.cli.validate_min_records_in_dataset",
return_value=None,
),
patch("logging.basicConfig") as mock_basic_config,
):

# Run the migrate command
result = runner.invoke(cli, ["migrate"])

# Print the output for debugging
if result.exit_code != 0:
print(f"Command failed with exit code {result.exit_code}")
print(f"Output: {result.output}")
if result.exception:
print(f"Exception: {result.exception}")

mock_basic_config.assert_called_once()
_, kwargs = mock_basic_config.call_args
# Assert the log level used in basic config
assert kwargs["level"] == logging.DEBUG