diff --git a/docs/changes/2921.api.rst b/docs/changes/2921.api.rst new file mode 100644 index 00000000000..e9be2f25386 --- /dev/null +++ b/docs/changes/2921.api.rst @@ -0,0 +1,7 @@ +As a consequence of fixing bug #2921, `ctapipe.core.ExpressionEngine` +converts all input tables to `astropy.table.QTable` internally, which has a +small side effect on what is allowed in expressions: all columns with units are +now of type `astropy.units.Quantity`, instead of `astropy.table.Column`. Before, +an expression like ``"some_column.quantity.to(u.m)"`` would work if a ``Table`` +was passed (but would fail for a ``QTable``). Now, that expression should be +``some_column.to(u.m)``. diff --git a/docs/changes/2921.bugfix.rst b/docs/changes/2921.bugfix.rst new file mode 100644 index 00000000000..a8b0e8c1d27 --- /dev/null +++ b/docs/changes/2921.bugfix.rst @@ -0,0 +1,4 @@ +Fixed bug where units were incorrect in the output table of a +`ctapipe.core.FeatureGenerator` if a table of class `astropy.table.Table` was +passed to the call method. This bug did not affect calls using an +`astropy.table.QTable`. 
diff --git a/src/ctapipe/core/__init__.py b/src/ctapipe/core/__init__.py index f83471c864f..f5d7dffb4c3 100644 --- a/src/ctapipe/core/__init__.py +++ b/src/ctapipe/core/__init__.py @@ -5,6 +5,7 @@ from .component import Component, non_abstract_children from .container import Container, DeprecatedField, Field, FieldValidationError, Map +from .expression_engine import ExpressionEngine from .feature_generator import FeatureGenerator from .provenance import Provenance, get_module_version from .qualityquery import QualityCriteriaError, QualityQuery @@ -28,4 +29,5 @@ "QualityQuery", "QualityCriteriaError", "FieldValidationError", + "ExpressionEngine", ] diff --git a/src/ctapipe/core/feature_generator.py b/src/ctapipe/core/feature_generator.py index 4d330b8fedc..fbb1f02fe81 100644 --- a/src/ctapipe/core/feature_generator.py +++ b/src/ctapipe/core/feature_generator.py @@ -3,6 +3,9 @@ """ from collections import ChainMap +from copy import deepcopy + +from astropy.table import QTable, Table from .component import Component from .expression_engine import ExpressionEngine @@ -11,19 +14,31 @@ __all__ = [ "FeatureGenerator", "FeatureGeneratorException", + "shallow_copy_table", ] -def _shallow_copy_table(table): +def shallow_copy_table( + table, output_cls: type[Table] | type[QTable] | None = None +) -> Table | QTable: """ Make a shallow copy of the table. - Data of the existing columns will be shared between shallow - copies, but adding / removing columns won't be seen in - the original table. + Data of the existing columns will be shared between shallow copies, but + adding / removing columns won't be seen in the original table. Metadata for + the new table will be a copy (not shallow) of the original metadata, so that + new metadata can be added without affecting the original table. + + Parameters + ---------- + output_cls: type[Table] | type[QTable] | None + type of the output table. 
If None, use the input table type """ - # automatically return Table or QTable depending on input - return table.__class__({col: table[col] for col in table.colnames}, copy=False) + output_cls = output_cls or table.__class__ + + new_table = output_cls({col: table[col] for col in table.colnames}, copy=False) + new_table.meta = deepcopy(table.meta) + return new_table class FeatureGeneratorException(TypeError): @@ -54,26 +69,43 @@ def __init__(self, config=None, parent=None, **kwargs): self.engine = ExpressionEngine(expressions=self.features) self._feature_names = [name for name, _ in self.features] - def __call__(self, table, **kwargs): + def __call__(self, table: Table | QTable, **kwargs) -> Table: """ Apply feature generation to the input table. This method returns a shallow copy of the input table with the new features added. Existing columns will share the underlying data, however the new columns won't be visible in the input table. + + Parameters + ---------- + table: QTable | Table + Input table. Internally a Table will be converted to a QTable so that + unit propagation works, so expressions should only rely on properties of QTables. + **kwargs: + Other objects that should be available in expressions. For example, + if you pass ``subarray=subarray``, expressions can use that + object. This can also be special functions like ``f=my_function``, + which would allow an expression like ``"f(col1)"``. + + Returns + ------- + QTable|Table: + A new table with the same columns as the input, but with new columns + for each feature. The returned class depends on what was passed in. 
""" - table = _shallow_copy_table(table) - lookup = ChainMap(table, kwargs) + table_copy = shallow_copy_table(table, output_cls=QTable) + lookup = ChainMap(table_copy, kwargs) for result, name in zip(self.engine(lookup), self._feature_names): - if name in table.colnames: + if name in table_copy.colnames: raise FeatureGeneratorException(f"{name} is already a column of table.") try: - table[name] = result + table_copy[name] = result except Exception as err: raise err - return table + return table.__class__(table_copy) # ensure the return type is what is expected def __len__(self): return len(self.features) diff --git a/src/ctapipe/core/tests/test_feature_generator.py b/src/ctapipe/core/tests/test_feature_generator.py index 3ba3fb0d153..65f88c0adcd 100644 --- a/src/ctapipe/core/tests/test_feature_generator.py +++ b/src/ctapipe/core/tests/test_feature_generator.py @@ -2,7 +2,7 @@ import numpy as np import pytest -from astropy.table import Table +from astropy.table import QTable, Table from ctapipe.core.expression_engine import ExpressionError from ctapipe.core.feature_generator import FeatureGenerator, FeatureGeneratorException @@ -60,14 +60,15 @@ def test_to_unit(): expressions = [ ("length_meter", "length.to(u.m)"), - ("log_length_meter", "log10(length.quantity.to_value(u.m))"), + ("log_length_meter", "log10(length.to_value(u.m))"), ] generator = FeatureGenerator(features=expressions) table = Table({"length": [1 * u.km]}) table = generator(table) - assert table["length_meter"] == 1000 + assert table["length_meter"] == 1000 * u.m assert table["log_length_meter"] == 3 + assert table["length_meter"].unit == u.m def test_multiplicity(subarray_prod5_paranal): @@ -102,3 +103,45 @@ def test_multiplicity(subarray_prod5_paranal): np.testing.assert_equal(table["n_lsts"], [1, 2]) np.testing.assert_equal(table["n_msts"], [2, 1]) np.testing.assert_equal(table["n_ssts"], [0, 1]) + + +@pytest.mark.parametrize("table_class", [QTable, Table]) +def test_unit_propagation(table_class): 
+ """ + Check that units propagate to features. + + If a column in the input table has a unit, and a feature does math on that + unit, the feature should have the appropriate unit. + """ + + import astropy.units as u + + table = table_class(dict(x=np.arange(11) * u.cm, E=np.linspace(-2, 2, 11) * u.TeV)) + features = [ + ("x2", "x**2"), + ("E_per_area", "E/x**2"), + ] + + feature_gen = FeatureGenerator(features=features) + new_table = feature_gen(table) + + assert new_table["x2"].unit.is_equivalent("cm2") + assert new_table["E_per_area"].unit.is_equivalent("TeV cm-2") + + +@pytest.mark.parametrize("table_class", [QTable, Table]) +def test_input_output_class(table_class): + """Ensure output table class is same as input.""" + + import astropy.units as u + + table = table_class(dict(x=np.arange(11) * u.cm, E=np.linspace(-2, 2, 11) * u.TeV)) + features = [ + ("x2", "x**2"), + ("E_per_area", "E/x**2"), + ] + + feature_gen = FeatureGenerator(features=features) + new_table = feature_gen(table) + + assert new_table.__class__ == table.__class__