Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions docs/changes/2921.api.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
As a consequence of fixing the bug #2921, `ctapipe.core.ExpressionEngine`
converts all input tables to `astropy.table.QTable` internally, which has a
small side effect on what is allowed in expressions: all columns with units are
now of type `astropy.units.Quantity`, instead of `astropy.table.Column`. Before,
an expression like ``"some_column.quantity.to(u.m)"`` would work if a ``Table``
was passed (but would fail for a ``QTable``). Now, that expression should be
``some_column.to(u.m)``
4 changes: 4 additions & 0 deletions docs/changes/2921.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
Fixed bug where units were incorrect in the output table of an
`ctapipe.core.FeatureGenerator` if a table of class `astropy.table.Table` was
passed to the call method. This bug did not affect calls using an
`astropy.table.QTable`.
2 changes: 2 additions & 0 deletions src/ctapipe/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from .component import Component, non_abstract_children
from .container import Container, DeprecatedField, Field, FieldValidationError, Map
from .expression_engine import ExpressionEngine
from .feature_generator import FeatureGenerator
from .provenance import Provenance, get_module_version
from .qualityquery import QualityCriteriaError, QualityQuery
Expand All @@ -28,4 +29,5 @@
"QualityQuery",
"QualityCriteriaError",
"FieldValidationError",
"ExpressionEngine",
]
56 changes: 44 additions & 12 deletions src/ctapipe/core/feature_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
"""

from collections import ChainMap
from copy import deepcopy

from astropy.table import QTable, Table

from .component import Component
from .expression_engine import ExpressionEngine
Expand All @@ -11,19 +14,31 @@
__all__ = [
"FeatureGenerator",
"FeatureGeneratorException",
"shallow_copy_table",
]


def _shallow_copy_table(table):
def shallow_copy_table(
table, output_cls: type[Table] | type[QTable] | None = None
) -> Table | QTable:
"""
Make a shallow copy of the table.

Data of the existing columns will be shared between shallow
copies, but adding / removing columns won't be seen in
the original table.
Data of the existing columns will be shared between shallow copies, but
adding / removing columns won't be seen in the original table. Metadata for
the new table will be a copy (not shallow) of the original metadata, so that
new metadata can be added without affecting the original table.

Parameters
----------
output_cls: type[Table] | type[QTable] | None
type of the output table. If None, use the input table type
"""
# automatically return Table or QTable depending on input
return table.__class__({col: table[col] for col in table.colnames}, copy=False)
output_cls = output_cls or table.__class__

new_table = output_cls({col: table[col] for col in table.colnames}, copy=False)
new_table.meta = deepcopy(table.meta)
return new_table


class FeatureGeneratorException(TypeError):
Expand Down Expand Up @@ -54,26 +69,43 @@ def __init__(self, config=None, parent=None, **kwargs):
self.engine = ExpressionEngine(expressions=self.features)
self._feature_names = [name for name, _ in self.features]

def __call__(self, table, **kwargs):
def __call__(self, table: Table | QTable, **kwargs) -> Table:
"""
Apply feature generation to the input table.

This method returns a shallow copy of the input table with the
new features added. Existing columns will share the underlying data,
however the new columns won't be visible in the input table.

Parameters
----------
table: QTable | Table
Input table. Internally a Table will be converted to a QTable so that
unit propagation works, so expressions should only rely on properties of QTables.
**kwargs:
Other objects that should be available in expressions. For example,
if a you pass ``subarray=subarray``, expressions can use that
object. This can also be special functions like ``f=my_function``,
which would allow an expression like ``"f(col1)"``.

Returns
-------
QTable|Table:
A new table with the same columns as the input, but with new columns
for each feature. The returned class depends on what was passed in.
"""
table = _shallow_copy_table(table)
lookup = ChainMap(table, kwargs)
table_copy = shallow_copy_table(table, output_cls=QTable)
lookup = ChainMap(table_copy, kwargs)

for result, name in zip(self.engine(lookup), self._feature_names):
if name in table.colnames:
if name in table_copy.colnames:
raise FeatureGeneratorException(f"{name} is already a column of table.")
try:
table[name] = result
table_copy[name] = result
except Exception as err:
raise err

return table
return table.__class__(table_copy) # ensure the return type is what is expected

def __len__(self):
return len(self.features)
49 changes: 46 additions & 3 deletions src/ctapipe/core/tests/test_feature_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import numpy as np
import pytest
from astropy.table import Table
from astropy.table import QTable, Table

from ctapipe.core.expression_engine import ExpressionError
from ctapipe.core.feature_generator import FeatureGenerator, FeatureGeneratorException
Expand Down Expand Up @@ -60,14 +60,15 @@ def test_to_unit():

expressions = [
("length_meter", "length.to(u.m)"),
("log_length_meter", "log10(length.quantity.to_value(u.m))"),
("log_length_meter", "log10(length.to_value(u.m))"),
]
generator = FeatureGenerator(features=expressions)
table = Table({"length": [1 * u.km]})

table = generator(table)
assert table["length_meter"] == 1000
assert table["length_meter"] == 1000 * u.m
assert table["log_length_meter"] == 3
assert table["length_meter"].unit == u.m


def test_multiplicity(subarray_prod5_paranal):
Expand Down Expand Up @@ -102,3 +103,45 @@ def test_multiplicity(subarray_prod5_paranal):
np.testing.assert_equal(table["n_lsts"], [1, 2])
np.testing.assert_equal(table["n_msts"], [2, 1])
np.testing.assert_equal(table["n_ssts"], [0, 1])


@pytest.mark.parametrize("table_class", [QTable, Table])
def test_unit_propagation(table_class):
"""
Check that units propagate to features.

If a column in the input table has a unit, and a feature does math on that
unit, the feature should have the appropriate unit.
"""

import astropy.units as u

table = table_class(dict(x=np.arange(11) * u.cm, E=np.linspace(-2, 2, 11) * u.TeV))
features = [
("x2", "x**2"),
("E_per_area", "E/x**2"),
]

feature_gen = FeatureGenerator(features=features)
new_table = feature_gen(table)

assert new_table["x2"].unit.is_equivalent("cm2")
assert new_table["E_per_area"].unit.is_equivalent("TeV cm-2")


@pytest.mark.parametrize("table_class", [QTable, Table])
def test_input_output_class(table_class):
"""Ensure output table class is same as input."""

import astropy.units as u

table = table_class(dict(x=np.arange(11) * u.cm, E=np.linspace(-2, 2, 11) * u.TeV))
features = [
("x2", "x**2"),
("E_per_area", "E/x**2"),
]

feature_gen = FeatureGenerator(features=features)
new_table = feature_gen(table)

assert new_table.__class__ == table.__class__