From 7ebc6b87359962c853fc10791aea326431e6c11c Mon Sep 17 00:00:00 2001 From: Molly Xu Date: Fri, 13 Mar 2026 06:36:13 -0700 Subject: [PATCH] poc --- src/torchcodec/_core/CMakeLists.txt | 1 + src/torchcodec/_core/CpuDeviceInterface.cpp | 7 + src/torchcodec/_core/StreamOptions.h | 5 + src/torchcodec/_core/ToneMap.cpp | 420 ++++++++++++++++++++ src/torchcodec/_core/ToneMap.h | 25 ++ src/torchcodec/_core/_decoder_utils.py | 2 + src/torchcodec/_core/custom_ops.cpp | 24 +- src/torchcodec/_core/ops.py | 2 + src/torchcodec/decoders/_video_decoder.py | 2 + 9 files changed, 483 insertions(+), 5 deletions(-) create mode 100644 src/torchcodec/_core/ToneMap.cpp create mode 100644 src/torchcodec/_core/ToneMap.h diff --git a/src/torchcodec/_core/CMakeLists.txt b/src/torchcodec/_core/CMakeLists.txt index 8bf4f21f1..23190c0ec 100644 --- a/src/torchcodec/_core/CMakeLists.txt +++ b/src/torchcodec/_core/CMakeLists.txt @@ -137,6 +137,7 @@ function(make_torchcodec_libraries Transform.cpp Metadata.cpp SwScale.cpp + ToneMap.cpp NVDECCacheConfig.cpp ) diff --git a/src/torchcodec/_core/CpuDeviceInterface.cpp b/src/torchcodec/_core/CpuDeviceInterface.cpp index 1049cd500..193de7460 100644 --- a/src/torchcodec/_core/CpuDeviceInterface.cpp +++ b/src/torchcodec/_core/CpuDeviceInterface.cpp @@ -5,6 +5,7 @@ // LICENSE file in the root directory of this source tree. #include "CpuDeviceInterface.h" +#include "ToneMap.h" namespace facebook::torchcodec { namespace { @@ -176,6 +177,12 @@ void CpuDeviceInterface::convertVideoAVFrameToFrameOutput( // Both cases cause problems for our batch APIs, as we allocate // FrameBatchOutputs based on the the stream metadata. But single-frame APIs // can still work in such situations, so they should. + // If tone mapping is enabled and the frame is HDR, convert to SDR RGB24. 
+ if (videoStreamOptions_.toneMapping.has_value() && + isHDRFrame(avFrame.get())) { + avFrame = toneMapHDRFrame(avFrame); + } + auto outputDims = resizedOutputDims_.value_or(FrameDims(avFrame->height, avFrame->width)); diff --git a/src/torchcodec/_core/StreamOptions.h b/src/torchcodec/_core/StreamOptions.h index 6cab3c8e8..4ff8fcb80 100644 --- a/src/torchcodec/_core/StreamOptions.h +++ b/src/torchcodec/_core/StreamOptions.h @@ -47,6 +47,11 @@ struct VideoStreamOptions { // Device variant (e.g., "ffmpeg", "beta", etc.) std::string_view deviceVariant = "ffmpeg"; + // Tone mapping algorithm for HDR→SDR conversion. + // If set, HDR frames (PQ/HLG) will be tone-mapped to SDR. + // Supported values: "hable". + std::optional toneMapping; + // Encoding options std::optional codec; // Optional pixel format for video encoding (e.g., "yuv420p", "yuv444p") diff --git a/src/torchcodec/_core/ToneMap.cpp b/src/torchcodec/_core/ToneMap.cpp new file mode 100644 index 000000000..08236c917 --- /dev/null +++ b/src/torchcodec/_core/ToneMap.cpp @@ -0,0 +1,420 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
+ +#include "ToneMap.h" + +#include + +#include +#include +#include +#include + +extern "C" { +#include +#include +#include +#include +#include +#include +} + +namespace facebook::torchcodec { + +namespace { + +// --------------------------------------------------------------------------- +// PQ (SMPTE ST 2084) constants and EOTF +// --------------------------------------------------------------------------- +constexpr double PQ_M1 = 0.1593017578125; // = 2610/16384 +constexpr double PQ_M2 = 78.84375; // = 2523*128/4096 +constexpr double PQ_C1 = 0.8359375; // = 3424/4096 +constexpr double PQ_C2 = 18.8515625; // = 2413*32/4096 +constexpr double PQ_C3 = 18.6875; // = 2392*32/4096 + +// PQ EOTF: signal E in [0,1] → linear luminance in nits [0, 10000] +inline double pqEOTF(double E) { + double Em = std::pow(E, 1.0 / PQ_M2); + double num = std::max(Em - PQ_C1, 0.0); + double den = PQ_C2 - PQ_C3 * Em; + if (den <= 0.0) { + return 0.0; + } + return 10000.0 * std::pow(num / den, 1.0 / PQ_M1); +} + +// --------------------------------------------------------------------------- +// HLG (ARIB STD-B67) inverse OETF + OOTF +// --------------------------------------------------------------------------- +constexpr double HLG_A = 0.17883277; +constexpr double HLG_B = 0.28466892; +constexpr double HLG_C = 0.55991073; + +// HLG inverse OETF: signal E in [0,1] → scene-linear [0,1] +inline double hlgInverseOETF(double E) { + if (E <= 0.0) { + return 0.0; + } + if (E <= 0.5) { + return E * E / 3.0; + } + return (std::exp((E - HLG_C) / HLG_A) + HLG_B) / 12.0; +} + +// --------------------------------------------------------------------------- +// BT.2020 NCL YCbCr → R'G'B' (non-linear signal, [0, 1]) +// --------------------------------------------------------------------------- +// BT.2020 coefficients: Kr=0.2627, Kb=0.0593 +constexpr double BT2020_KR = 0.2627; +constexpr double BT2020_KB = 0.0593; +constexpr double BT2020_KG = 1.0 - BT2020_KR - BT2020_KB; + +// 
+// ---------------------------------------------------------------------------
+// BT.2020 → BT.709 gamut mapping matrix (3x3, on linear RGB)
+//
+// Derived from the chromaticity coordinates of BT.2020 and BT.709 primaries
+// with the D65 white point. M = M_XYZ_to_709 * M_2020_to_XYZ.
+// Every row sums to ~1.0, so neutral input (R == G == B) stays neutral.
+// Saturated BT.2020 colors can produce negative BT.709 components; callers
+// are expected to clamp the result.
+// ---------------------------------------------------------------------------
+// clang-format off
+constexpr double GAMUT_MAP[3][3] = {
+    { 1.6605, -0.5877, -0.0728},
+    {-0.1246,  1.1330, -0.0084},
+    {-0.0182, -0.1006,  1.1187}
+};
+// clang-format on
+
+// ---------------------------------------------------------------------------
+// Hable (Uncharted 2) tone mapping curve
+// ---------------------------------------------------------------------------
+constexpr double HABLE_A = 0.15;
+constexpr double HABLE_B = 0.50;
+constexpr double HABLE_C = 0.10;
+constexpr double HABLE_D = 0.20;
+constexpr double HABLE_E = 0.02;
+constexpr double HABLE_F = 0.30;
+
+// Evaluates the Hable curve at x (expected x >= 0).
+// The trailing "- E/F" term shifts the curve so that hable(0) == 0.
+// The result is not normalized; divide by hable(peak) to map `peak` to 1.0.
+inline double hable(double x) {
+  return ((x * (HABLE_A * x + HABLE_C * HABLE_B) + HABLE_D * HABLE_E) /
+          (x * (HABLE_A * x + HABLE_B) + HABLE_D * HABLE_F)) -
+      HABLE_E / HABLE_F;
+}
+
+// ---------------------------------------------------------------------------
+// BT.709 OETF: linear L [0,1] → non-linear V [0,1]
+// ---------------------------------------------------------------------------
+inline double bt709OETF(double L) {
+  // Linear segment near black, power-law segment everywhere else.
+  if (L < 0.018) {
+    return 4.5 * L;
+  }
+  return 1.099 * std::pow(L, 0.45) - 0.099;
+}
+
+// ---------------------------------------------------------------------------
+// Signal peak detection from frame side data
+//
+// Returns the content peak in nits, taken from (in priority order):
+//   1. MaxCLL in the content light level side data, if > 0
+//   2. max_luminance in the mastering display side data, if valid and > 0
+//   3. a per-transfer default (10000 for PQ, 1000 otherwise)
+// Currently unused by the tone mapper, hence [[maybe_unused]].
+// ---------------------------------------------------------------------------
+[[maybe_unused]] double getSignalPeakNits(const AVFrame* frame) {
+  // Try to get MaxCLL from content light level metadata.
+  const AVFrameSideData* cllSD =
+      av_frame_get_side_data(frame, AV_FRAME_DATA_CONTENT_LIGHT_LEVEL);
+  if (cllSD) {
+    auto* lightLevel =
+        reinterpret_cast<const AVContentLightMetadata*>(cllSD->data);
+    if (lightLevel->MaxCLL > 0) {
+      return static_cast<double>(lightLevel->MaxCLL);
+    }
+  }
+
+  // Try mastering display metadata.
+  const AVFrameSideData* mdSD =
+      av_frame_get_side_data(frame, AV_FRAME_DATA_MASTERING_DISPLAY_METADATA);
+  if (mdSD) {
+    auto* mastering =
+        reinterpret_cast<const AVMasteringDisplayMetadata*>(mdSD->data);
+    // A zero denominator would make av_q2d() return inf/NaN; treat such a
+    // rational as "not provided" and fall through to the defaults below.
+    if (mastering->has_luminance && mastering->max_luminance.den != 0) {
+      const double maxLuminance = av_q2d(mastering->max_luminance);
+      if (maxLuminance > 0) {
+        return maxLuminance;
+      }
+    }
+  }
+
+  // Default: PQ system peak is 10000 nits, HLG nominal peak is 1000 nits.
+  if (frame->color_trc == AVCOL_TRC_SMPTE2084) {
+    return 10000.0;
+  }
+  return 1000.0;
+}
+
+// ---------------------------------------------------------------------------
+// Read a 10-bit sample from a YUV plane.
+// Handles both yuv420p10le (planar, 2 bytes per sample little-endian) and
+// p010le (semi-planar, 2 bytes per sample with upper 10 bits).
+// `data` must point at the first (low) byte of the 2-byte sample.
+// ---------------------------------------------------------------------------
+inline uint16_t read10BitSample(const uint8_t* data, bool isP010) {
+  // Assemble the little-endian 16-bit word byte by byte, so the read does not
+  // depend on host endianness or on `data` being 2-byte aligned.
+  uint16_t val = data[0] | (static_cast<uint16_t>(data[1]) << 8);
+  if (isP010) {
+    // P010 stores 10-bit values in the upper 10 bits of a 16-bit word.
+    val >>= 6;
+  }
+  return val;
+}
+
+// ---------------------------------------------------------------------------
+// PQ linearization LUT for 10-bit input (1024 entries).
+// Maps normalized signal [0/1023 .. 1023/1023] → linear nits [0, 10000].
+// Built once, lazily.
+// ---------------------------------------------------------------------------
+class PQ_LUT {
+ public:
+  // Returns the process-wide table, building it on first use.
+  static const PQ_LUT& instance() {
+    static PQ_LUT lut;
+    return lut;
+  }
+
+  // idx must be in [0, 1023]; no bounds checking is performed.
+  double operator[](int idx) const {
+    return table_[idx];
+  }
+
+ private:
+  PQ_LUT() {
+    for (int i = 0; i < 1024; ++i) {
+      double signal = static_cast<double>(i) / 1023.0;
+      table_[i] = pqEOTF(signal);
+    }
+  }
+
+  double table_[1024];
+};
+
+// ---------------------------------------------------------------------------
+// HLG inverse OETF LUT for 10-bit input.
+// Maps signal → scene-linear.
+// ---------------------------------------------------------------------------
+class HLG_LUT {
+ public:
+  // Returns the process-wide table, building it on first use.
+  static const HLG_LUT& instance() {
+    static HLG_LUT lut;
+    return lut;
+  }
+
+  // idx must be in [0, 1023]; no bounds checking is performed.
+  double operator[](int idx) const {
+    return table_[idx];
+  }
+
+ private:
+  HLG_LUT() {
+    for (int i = 0; i < 1024; ++i) {
+      double signal = static_cast<double>(i) / 1023.0;
+      table_[i] = hlgInverseOETF(signal);
+    }
+  }
+
+  double table_[1024];
+};
+
+} // namespace
+
+// A frame counts as HDR when its transfer characteristic is PQ (SMPTE ST 2084)
+// or HLG (ARIB STD-B67). Primaries and matrix coefficients are not inspected.
+bool isHDRFrame(const AVFrame* frame) {
+  return frame->color_trc == AVCOL_TRC_SMPTE2084 ||
+      frame->color_trc == AVCOL_TRC_ARIB_STD_B67;
+}
+
+// Converts a PQ or HLG frame in yuv420p10le / p010le to a newly allocated
+// full-range BT.709 RGB24 frame of the same dimensions.
+//
+// The source is interpreted as BT.2020 NCL YCbCr regardless of its
+// colorspace / color_primaries tags. Chroma is upsampled by nearest-neighbour
+// replication. Only pixel data and the color tags are set on the result; no
+// other frame properties (pts, side data, ...) are carried over.
+//
+// Fails via STD_TORCH_CHECK when the frame is null, when the transfer
+// characteristic or pixel format is unsupported, or when allocation fails.
+UniqueAVFrame toneMapHDRFrame(const UniqueAVFrame& src) {
+  const AVFrame* frame = src.get();
+  STD_TORCH_CHECK(frame != nullptr, "toneMapHDRFrame: input frame is null");
+
+  const bool isPQ = (frame->color_trc == AVCOL_TRC_SMPTE2084);
+  const bool isHLG = (frame->color_trc == AVCOL_TRC_ARIB_STD_B67);
+  STD_TORCH_CHECK(
+      isPQ || isHLG,
+      "toneMapHDRFrame: unsupported transfer characteristic: ",
+      static_cast<int>(frame->color_trc));
+
+  const AVPixelFormat pixFmt = static_cast<AVPixelFormat>(frame->format);
+  const bool isP010 = (pixFmt == AV_PIX_FMT_P010LE);
+  const bool isYUV420P10 = (pixFmt == AV_PIX_FMT_YUV420P10LE);
+  // av_get_pix_fmt_name() returns NULL for formats it does not know; never
+  // stream a null char pointer into the error message.
+  const char* pixFmtName = av_get_pix_fmt_name(pixFmt);
+  STD_TORCH_CHECK(
+      isP010 || isYUV420P10,
+      "toneMapHDRFrame: unsupported pixel format: ",
+      pixFmtName != nullptr ? pixFmtName : "unknown",
+      ". Expected yuv420p10le or p010le.");
+
+  const int width = frame->width;
+  const int height = frame->height;
+  // Anything not explicitly tagged full range (including "unspecified") is
+  // treated as limited range.
+  const bool isLimitedRange = (frame->color_range != AVCOL_RANGE_JPEG);
+
+  // Nominal peak luminance for normalization (matches npl=300 convention).
+  constexpr double SDR_WHITE = 300.0;
+  // Hardcoded peak in SDR-relative units (matches peak=4 convention).
+  // 4.0 × 300 = 1200 nits effective content peak.
+  constexpr double peakSDR = 4.0;
+  const double hablePeakInv = 1.0 / hable(peakSDR);
+
+  // For HLG, compute system gamma and OOTF parameters.
+  // Assume a 1000 nit display by default.
+  constexpr double HLG_DISPLAY_LW = 1000.0;
+  const double hlgGamma =
+      std::max(1.0, 1.2 + 0.42 * std::log10(HLG_DISPLAY_LW / 1000.0));
+
+  // Allocate output frame in RGB24.
+  UniqueAVFrame dst(av_frame_alloc());
+  STD_TORCH_CHECK(dst != nullptr, "Failed to allocate output AVFrame");
+  dst->format = AV_PIX_FMT_RGB24;
+  dst->width = width;
+  dst->height = height;
+  int ret = av_frame_get_buffer(dst.get(), 0);
+  STD_TORCH_CHECK(ret >= 0, "Failed to allocate output frame buffer");
+
+  // Tag the output as BT.709 SDR.
+  dst->colorspace = AVCOL_SPC_BT709;
+  dst->color_primaries = AVCOL_PRI_BT709;
+  dst->color_trc = AVCOL_TRC_BT709;
+  dst->color_range = AVCOL_RANGE_JPEG; // full range RGB
+
+  // Pointers to source planes.
+  const uint8_t* srcY = frame->data[0];
+  const int srcYStride = frame->linesize[0];
+
+  // For planar (yuv420p10le): U is data[1], V is data[2]
+  // For semi-planar (p010le): UV interleaved in data[1]
+  const uint8_t* srcU = frame->data[1];
+  const int srcUStride = frame->linesize[1];
+  const uint8_t* srcV = isP010 ? nullptr : frame->data[2];
+  const int srcVStride = isP010 ? 0 : frame->linesize[2];
+
+  uint8_t* dstData = dst->data[0];
+  const int dstStride = dst->linesize[0];
+
+  // Limited range 10-bit: Y [64, 940], UV [64, 960]
+  // Full range 10-bit:    Y [0, 1023], UV [0, 1023]
+  const double yMin = isLimitedRange ? 64.0 : 0.0;
+  const double yRange = isLimitedRange ? (940.0 - 64.0) : 1023.0;
+  // Zero chroma is code 512 (2^(bitDepth-1)) in BOTH ranges; only the
+  // excursion differs. Centering full range on 1023/2 = 511.5 would give
+  // neutral pixels a small non-zero chroma.
+  constexpr double uvZero = 512.0;
+  const double uvRange = isLimitedRange ? (960.0 - 64.0) : 1023.0;
+
+  // Pre-derive YCbCr → RGB coefficients from BT.2020 NCL
+  //   R' = Y' + (2 - 2*Kr) * Cr
+  //   B' = Y' + (2 - 2*Kb) * Cb
+  //   G' = (Y' - Kr*R' - Kb*B') / Kg
+  const double crToR = 2.0 * (1.0 - BT2020_KR);
+  const double cbToB = 2.0 * (1.0 - BT2020_KB);
+  const double crToG = -2.0 * BT2020_KR * (1.0 - BT2020_KR) / BT2020_KG;
+  const double cbToG = -2.0 * BT2020_KB * (1.0 - BT2020_KB) / BT2020_KG;
+
+  const PQ_LUT& pqLut = PQ_LUT::instance();
+  const HLG_LUT& hlgLut = HLG_LUT::instance();
+
+  // Non-linear signal in [0, 1] → index into the 1024-entry EOTF tables.
+  const auto toLutIndex = [](double signal) {
+    return std::clamp(static_cast<int>(std::round(signal * 1023.0)), 0, 1023);
+  };
+
+  // Linear BT.709 value → gamma-encoded 8-bit code (round to nearest).
+  const auto encodeToUint8 = [](double linear) {
+    const double encoded = bt709OETF(std::clamp(linear, 0.0, 1.0));
+    return static_cast<uint8_t>(
+        static_cast<int>(std::clamp(encoded * 255.0 + 0.5, 0.0, 255.0)));
+  };
+
+  for (int y = 0; y < height; ++y) {
+    const uint8_t* yRow = srcY + y * srcYStride;
+    // Chroma is subsampled 2x vertically and horizontally (420).
+    const int chromaY = y / 2;
+    const uint8_t* uRow = srcU + chromaY * srcUStride;
+    const uint8_t* vRow = isP010 ? nullptr : (srcV + chromaY * srcVStride);
+
+    uint8_t* outRow = dstData + y * dstStride;
+
+    for (int x = 0; x < width; ++x) {
+      // Read 10-bit Y sample.
+      const uint16_t yVal = read10BitSample(yRow + x * 2, isP010);
+
+      // Read 10-bit chroma samples (subsampled).
+      const int chromaX = x / 2;
+      uint16_t uVal, vVal;
+      if (isP010) {
+        // P010: UV interleaved as U0 V0 U1 V1 ...
+        uVal = read10BitSample(uRow + chromaX * 4, true);
+        vVal = read10BitSample(uRow + chromaX * 4 + 2, true);
+      } else {
+        // Planar: separate U and V planes.
+        uVal = read10BitSample(uRow + chromaX * 2, false);
+        vVal = read10BitSample(vRow + chromaX * 2, false);
+      }
+
+      // Normalize to [0, 1] (Y') and [-0.5, 0.5] (Cb, Cr).
+      const double yNorm =
+          std::clamp((static_cast<double>(yVal) - yMin) / yRange, 0.0, 1.0);
+      const double cb = (static_cast<double>(uVal) - uvZero) / uvRange;
+      const double cr = (static_cast<double>(vVal) - uvZero) / uvRange;
+
+      // YCbCr → R'G'B' (non-linear signal, [0, 1])
+      const double rSignal = std::clamp(yNorm + crToR * cr, 0.0, 1.0);
+      const double gSignal =
+          std::clamp(yNorm + crToG * cr + cbToG * cb, 0.0, 1.0);
+      const double bSignal = std::clamp(yNorm + cbToB * cb, 0.0, 1.0);
+
+      // Both transfer functions are looked up with the same 10-bit index.
+      const int rIdx = toLutIndex(rSignal);
+      const int gIdx = toLutIndex(gSignal);
+      const int bIdx = toLutIndex(bSignal);
+
+      // Linearize using EOTF.
+      double rLin, gLin, bLin;
+      if (isPQ) {
+        // PQ EOTF returns nits; normalize to SDR-relative.
+        rLin = pqLut[rIdx] / SDR_WHITE;
+        gLin = pqLut[gIdx] / SDR_WHITE;
+        bLin = pqLut[bIdx] / SDR_WHITE;
+      } else {
+        // HLG: inverse OETF gives scene-linear [0, 1]
+        rLin = hlgLut[rIdx];
+        gLin = hlgLut[gIdx];
+        bLin = hlgLut[bIdx];
+
+        // Apply HLG OOTF: display_linear = Lw * scene_linear * Y^(gamma-1)
+        const double luma =
+            BT2020_KR * rLin + BT2020_KG * gLin + BT2020_KB * bLin;
+        const double ootfScale =
+            HLG_DISPLAY_LW * std::pow(std::max(luma, 0.0), hlgGamma - 1.0);
+        rLin = rLin * ootfScale / SDR_WHITE;
+        gLin = gLin * ootfScale / SDR_WHITE;
+        bLin = bLin * ootfScale / SDR_WHITE;
+      }
+
+      // BT.2020 → BT.709 gamut mapping (3x3 matrix on linear RGB).
+      double r709 = GAMUT_MAP[0][0] * rLin + GAMUT_MAP[0][1] * gLin +
+          GAMUT_MAP[0][2] * bLin;
+      double g709 = GAMUT_MAP[1][0] * rLin + GAMUT_MAP[1][1] * gLin +
+          GAMUT_MAP[1][2] * bLin;
+      double b709 = GAMUT_MAP[2][0] * rLin + GAMUT_MAP[2][1] * gLin +
+          GAMUT_MAP[2][2] * bLin;
+
+      // Clamp negatives (out-of-gamut colors).
+      r709 = std::max(r709, 0.0);
+      g709 = std::max(g709, 0.0);
+      b709 = std::max(b709, 0.0);
+
+      // Hable tone mapping.
+      // The curve is evaluated once on max(R, G, B) and the resulting gain is
+      // applied to all three channels, which keeps their ratios (hue) intact;
+      // this is similar to FFmpeg's vf_tonemap.
+      const double sig = std::max({r709, g709, b709});
+      if (sig > 0.0) {
+        const double mappedSig = hable(sig) * hablePeakInv;
+        const double scale = mappedSig / sig;
+        r709 *= scale;
+        g709 *= scale;
+        b709 *= scale;
+      }
+
+      // BT.709 OETF (gamma encode) + quantize to uint8.
+      outRow[x * 3 + 0] = encodeToUint8(r709);
+      outRow[x * 3 + 1] = encodeToUint8(g709);
+      outRow[x * 3 + 2] = encodeToUint8(b709);
+    }
+  }
+
+  return dst;
+}
+
+} // namespace facebook::torchcodec
diff --git a/src/torchcodec/_core/ToneMap.h b/src/torchcodec/_core/ToneMap.h
new file mode 100644
index 000000000..a0b9c9250
--- /dev/null
+++ b/src/torchcodec/_core/ToneMap.h
@@ -0,0 +1,25 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#pragma once
+
+#include "FFMPEGCommon.h"
+
+namespace facebook::torchcodec {
+
+// Returns true if the AVFrame has HDR transfer characteristics (PQ or HLG).
+bool isHDRFrame(const AVFrame* frame); + +// Converts an HDR AVFrame (PQ or HLG, BT.2020) to an SDR AVFrame in RGB24 +// (BT.709). The full pipeline is: +// 1. YUV → RGB using BT.2020 NCL matrix +// 2. PQ EOTF or HLG EOTF (linearization) +// 3. BT.2020 → BT.709 gamut mapping +// 4. Hable tone mapping +// 5. BT.709 OETF + quantize to uint8 RGB24 +UniqueAVFrame toneMapHDRFrame(const UniqueAVFrame& src); + +} // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/_decoder_utils.py b/src/torchcodec/_core/_decoder_utils.py index 4204cf735..5639c4818 100644 --- a/src/torchcodec/_core/_decoder_utils.py +++ b/src/torchcodec/_core/_decoder_utils.py @@ -153,6 +153,7 @@ def create_video_decoder( device_variant: str = "ffmpeg", transforms: Sequence[DecoderTransform | nn.Module] | None = None, custom_frame_mappings: tuple[Tensor, Tensor, Tensor] | None = None, + tone_mapping: str | None = None, ) -> tuple[Tensor, int, VideoStreamMetadata]: decoder = create_decoder(source=source, seek_mode=seek_mode) @@ -178,6 +179,7 @@ def create_video_decoder( device_variant=device_variant, transform_specs=transform_specs, custom_frame_mappings=custom_frame_mappings, + tone_mapping=tone_mapping, ) return (decoder, stream_index, metadata) diff --git a/src/torchcodec/_core/custom_ops.cpp b/src/torchcodec/_core/custom_ops.cpp index e521f5372..cea045c7b 100644 --- a/src/torchcodec/_core/custom_ops.cpp +++ b/src/torchcodec/_core/custom_ops.cpp @@ -51,9 +51,9 @@ STABLE_TORCH_LIBRARY(torchcodec_ns, m) { m.def( "_create_from_file_like(int file_like_context, str? seek_mode=None) -> Tensor"); m.def( - "_add_video_stream(Tensor(a!) decoder, *, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str device=\"cpu\", str device_variant=\"ffmpeg\", str transform_specs=\"\", Tensor? custom_frame_mappings_pts=None, Tensor? custom_frame_mappings_duration=None, Tensor? custom_frame_mappings_keyframe_indices=None, str? 
color_conversion_library=None) -> ()"); + "_add_video_stream(Tensor(a!) decoder, *, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str device=\"cpu\", str device_variant=\"ffmpeg\", str transform_specs=\"\", Tensor? custom_frame_mappings_pts=None, Tensor? custom_frame_mappings_duration=None, Tensor? custom_frame_mappings_keyframe_indices=None, str? color_conversion_library=None, str? tone_mapping=None) -> ()"); m.def( - "add_video_stream(Tensor(a!) decoder, *, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str device=\"cpu\", str device_variant=\"ffmpeg\", str transform_specs=\"\", Tensor? custom_frame_mappings_pts=None, Tensor? custom_frame_mappings_duration=None, Tensor? custom_frame_mappings_keyframe_indices=None) -> ()"); + "add_video_stream(Tensor(a!) decoder, *, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str device=\"cpu\", str device_variant=\"ffmpeg\", str transform_specs=\"\", Tensor? custom_frame_mappings_pts=None, Tensor? custom_frame_mappings_duration=None, Tensor? custom_frame_mappings_keyframe_indices=None, str? tone_mapping=None) -> ()"); m.def( "add_audio_stream(Tensor(a!) decoder, *, int? stream_index=None, int? sample_rate=None, int? num_channels=None) -> ()"); m.def("seek_to_pts(Tensor(a!) decoder, float seconds) -> ()"); @@ -437,7 +437,8 @@ void _add_video_stream( std::nullopt, std::optional custom_frame_mappings_keyframe_indices = std::nullopt, - std::optional color_conversion_library = std::nullopt) { + std::optional color_conversion_library = std::nullopt, + std::optional tone_mapping = std::nullopt) { VideoStreamOptions videoStreamOptions; videoStreamOptions.ffmpegThreadCount = num_threads; @@ -465,6 +466,16 @@ void _add_video_stream( } } + if (tone_mapping.has_value()) { + const std::string& tm = tone_mapping.value(); + STD_TORCH_CHECK( + tm == "hable", + "Invalid tone_mapping=", + tm, + ". 
Supported values: \"hable\"."); + videoStreamOptions.toneMapping = tm; + } + validateDeviceInterface(device, device_variant); videoStreamOptions.device = StableDevice(std::move(device)); @@ -510,7 +521,8 @@ void add_video_stream( std::optional custom_frame_mappings_duration = std::nullopt, std::optional - custom_frame_mappings_keyframe_indices = std::nullopt) { + custom_frame_mappings_keyframe_indices = std::nullopt, + std::optional tone_mapping = std::nullopt) { _add_video_stream( decoder, num_threads, @@ -521,7 +533,9 @@ void add_video_stream( std::move(transform_specs), std::move(custom_frame_mappings_pts), std::move(custom_frame_mappings_duration), - std::move(custom_frame_mappings_keyframe_indices)); + std::move(custom_frame_mappings_keyframe_indices), + /*color_conversion_library=*/std::nullopt, + std::move(tone_mapping)); } void add_audio_stream( diff --git a/src/torchcodec/_core/ops.py b/src/torchcodec/_core/ops.py index e8b0be548..f20938b07 100644 --- a/src/torchcodec/_core/ops.py +++ b/src/torchcodec/_core/ops.py @@ -84,6 +84,7 @@ def add_video_stream( custom_frame_mappings: ( tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None ) = None, + tone_mapping: str | None = None, ) -> None: custom_frame_mappings_pts: torch.Tensor | None = None custom_frame_mappings_keyframe_indices: torch.Tensor | None = None @@ -105,6 +106,7 @@ def add_video_stream( custom_frame_mappings_pts=custom_frame_mappings_pts, custom_frame_mappings_duration=custom_frame_mappings_duration, custom_frame_mappings_keyframe_indices=custom_frame_mappings_keyframe_indices, + tone_mapping=tone_mapping, ) diff --git a/src/torchcodec/decoders/_video_decoder.py b/src/torchcodec/decoders/_video_decoder.py index 19b415c2c..09124b16c 100644 --- a/src/torchcodec/decoders/_video_decoder.py +++ b/src/torchcodec/decoders/_video_decoder.py @@ -169,6 +169,7 @@ def __init__( custom_frame_mappings: ( str | bytes | io.RawIOBase | io.BufferedReader | None ) = None, + tone_mapping: str | None = None, ): 
torch._C._log_api_usage_once("torchcodec.decoders.VideoDecoder") allowed_seek_modes = ("exact", "approximate") @@ -222,6 +223,7 @@ def __init__( device_variant=device_variant, transforms=transforms, custom_frame_mappings=custom_frame_mappings_data, + tone_mapping=tone_mapping, ) assert self.metadata.begin_stream_seconds is not None # mypy.