|
15 | 15 | Tests for utility functions |
16 | 16 | """ |
17 | 17 |
|
| 18 | +import matplotlib |
| 19 | +import matplotlib.pyplot as plt |
18 | 20 | import numpy as np |
19 | 21 | import pandas as pd |
20 | 22 | import pytest |
|
26 | 28 | check_convex_hull_violation, |
27 | 29 | extract_lift_for_mmm, |
28 | 30 | get_interaction_terms, |
| 31 | + plot_correlations, |
29 | 32 | round_num, |
30 | 33 | ) |
31 | 34 |
|
@@ -369,3 +372,87 @@ def test_extract_lift_for_mmm_raises_for_ols(): |
369 | 372 | x=0.0, |
370 | 373 | delta_x=1000, |
371 | 374 | ) |
| 375 | + |
| 376 | + |
| 377 | +# ============================================================================ |
| 378 | +# Tests for plot_correlations |
| 379 | +# ============================================================================ |
| 380 | + |
| 381 | + |
| 382 | +@pytest.fixture |
| 383 | +def panel_data(): |
| 384 | + """Simple wide-format panel data for correlation tests.""" |
| 385 | + rng = np.random.default_rng(0) |
| 386 | + n = 50 |
| 387 | + base = np.sin(np.linspace(0, 4 * np.pi, n)) |
| 388 | + return pd.DataFrame( |
| 389 | + { |
| 390 | + "A": base + rng.normal(0, 0.1, n), |
| 391 | + "B": base + rng.normal(0, 0.1, n), |
| 392 | + "C": -base + rng.normal(0, 0.1, n), |
| 393 | + } |
| 394 | + ) |
| 395 | + |
| 396 | + |
| 397 | +def test_plot_correlations_returns_matrix_and_axes(panel_data): |
| 398 | + corr, ax = plot_correlations(panel_data) |
| 399 | + assert isinstance(corr, pd.DataFrame) |
| 400 | + assert corr.shape == (3, 3) |
| 401 | + assert isinstance(ax, matplotlib.axes.Axes) |
| 402 | + plt.close("all") |
| 403 | + |
| 404 | + |
| 405 | +def test_plot_correlations_diagonal_is_one(panel_data): |
| 406 | + corr, _ = plot_correlations(panel_data) |
| 407 | + np.testing.assert_allclose(np.diag(corr.values), 1.0) |
| 408 | + plt.close("all") |
| 409 | + |
| 410 | + |
| 411 | +def test_plot_correlations_symmetric(panel_data): |
| 412 | + corr, _ = plot_correlations(panel_data) |
| 413 | + np.testing.assert_allclose(corr.values, corr.values.T) |
| 414 | + plt.close("all") |
| 415 | + |
| 416 | + |
| 417 | +def test_plot_correlations_column_subset(panel_data): |
| 418 | + corr, _ = plot_correlations(panel_data, columns=["A", "B"]) |
| 419 | + assert corr.shape == (2, 2) |
| 420 | + assert list(corr.columns) == ["A", "B"] |
| 421 | + plt.close("all") |
| 422 | + |
| 423 | + |
| 424 | +def test_plot_correlations_custom_ax(panel_data): |
| 425 | + fig, provided_ax = plt.subplots() |
| 426 | + _, returned_ax = plot_correlations(panel_data, ax=provided_ax) |
| 427 | + assert returned_ax is provided_ax |
| 428 | + plt.close("all") |
| 429 | + |
| 430 | + |
| 431 | +def test_plot_correlations_kwargs_forwarded(panel_data): |
| 432 | + corr, _ = plot_correlations(panel_data, annot=False, vmin=0) |
| 433 | + assert isinstance(corr, pd.DataFrame) |
| 434 | + plt.close("all") |
| 435 | + |
| 436 | + |
| 437 | +# ============================================================================ |
| 438 | +# Tests for SyntheticControl._pre_treatment_correlations |
| 439 | +# ============================================================================ |
| 440 | + |
| 441 | + |
| 442 | +def test_pre_treatment_correlations_single_unit(sc_result_single_unit): |
| 443 | + corrs = sc_result_single_unit._pre_treatment_correlations() |
| 444 | + assert "actual" in corrs |
| 445 | + assert 0 < corrs["actual"] <= 1.0 |
| 446 | + |
| 447 | + |
| 448 | +def test_pre_treatment_correlations_multi_unit(sc_result_multi_unit): |
| 449 | + corrs = sc_result_multi_unit._pre_treatment_correlations() |
| 450 | + assert set(corrs.keys()) == {"t1", "t2"} |
| 451 | + for r in corrs.values(): |
| 452 | + assert -1 <= r <= 1 |
| 453 | + |
| 454 | + |
| 455 | +def test_summary_prints_correlation(sc_result_single_unit, capsys): |
| 456 | + sc_result_single_unit.summary() |
| 457 | + captured = capsys.readouterr() |
| 458 | + assert "Pre-treatment correlation" in captured.out |
0 commit comments