diff --git a/docs/source/api_ref_torchcodec.rst b/docs/source/api_ref_torchcodec.rst index f6d3fef36..c1ab12b51 100644 --- a/docs/source/api_ref_torchcodec.rst +++ b/docs/source/api_ref_torchcodec.rst @@ -14,4 +14,5 @@ torchcodec Frame FrameBatch + MotionVectorBatch AudioSamples diff --git a/examples/decoding/motion_vectors.py b/examples/decoding/motion_vectors.py new file mode 100644 index 000000000..aac05d4b8 --- /dev/null +++ b/examples/decoding/motion_vectors.py @@ -0,0 +1,179 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +================================ +Extracting motion vectors (CPU) +================================ + +This example shows how to export compressed-domain motion vectors using +``VideoDecoder``. Motion vectors are returned in a padded tensor along with +per-frame metadata and counts of valid vectors. +""" + +# %% +# Download a sample video (same source as the basic decoding example). +import tempfile + +import requests +import torch + +from torchcodec.decoders import VideoDecoder +from torchcodec.encoders import VideoEncoder + +url = "https://videos.pexels.com/video-files/854132/854132-sd_640_360_25fps.mp4" +response = requests.get(url, headers={"User-Agent": ""}) +if response.status_code != 200: + raise RuntimeError(f"Failed to download video. {response.status_code = }.") + +raw_video_bytes = response.content + +# %% +# Create a decoder with motion vector export enabled (CPU only). +decoder = VideoDecoder(raw_video_bytes, device="cpu", export_mvs=True) +mvs = decoder.get_motion_vectors_at([0, 1, 2]) + +print(mvs) +print(f"{mvs.data.shape = }") +print(f"{mvs.counts = }") +print(f"{mvs.frame_types = }") + +# %% +# Motion vector fields in each 10-element row. +MV_FIELDS = [ + "source", + "w", + "h", + "src_x", + "src_y", + "dst_x", + "dst_y", + "motion_x", + "motion_y", + "motion_scale", +] + +# %% +# Use counts to slice valid vectors per frame. +frame_index = 0 +count = int(mvs.counts[frame_index]) +valid = mvs.data[frame_index, :count] +print(f"{count = }") +print(f"{valid.shape = }") + +if count > 0: + first_mv = valid[0].tolist() + print(dict(zip(MV_FIELDS, first_mv))) + +# %% +# Frame types are ASCII codes (e.g., 'I', 'P', 'B'). +frame_type_chars = [chr(int(x)) for x in mvs.frame_types] +print(f"{frame_type_chars = }") + +# %% +# Optional: visualize motion vectors over a frame. +# Note: this uses integer rounding for coordinates. For sub-pixel precision, +# scale coordinates or use a rendering backend that supports fixed-point shifts. +try: + import matplotlib.pyplot as plt + from torchvision.transforms.v2.functional import to_pil_image +except ImportError: + print("Cannot plot, please run `pip install torchvision matplotlib`") +else: + plot_index = int(torch.argmax(mvs.counts).item()) + if int(mvs.counts[plot_index]) == 0: + print("No motion vectors available to plot.") + else: + frame = decoder.get_frame_at(plot_index).data + fig, ax = plt.subplots() + ax.imshow(to_pil_image(frame)) + + valid = mvs.data[plot_index, : int(mvs.counts[plot_index])] + for mv in valid: + dst_x, dst_y = int(mv[5]), int(mv[6]) + motion_scale = int(mv[9]) + if motion_scale == 0: + continue + src_x = int(dst_x + mv[7].item() / motion_scale) + src_y = int(dst_y + mv[8].item() / motion_scale) + ax.arrow( + src_x, + src_y, + dst_x - src_x, + dst_y - src_y, + color="red", + width=0.5, + head_width=2.0, + length_includes_head=True, + ) + ax.scatter([dst_x], [dst_y], s=5, c="blue") + + ax.set(xticks=[], yticks=[], title=f"Motion vectors (frame {plot_index})") + plt.tight_layout() + +# %% +# Optional: encode a short video with MV overlays using VideoEncoder. +# This overlay is a simple visualization (integer coordinates, no arrowheads). +def _draw_line(image: torch.Tensor, x0: int, y0: int, x1: int, y1: int): + h, w = image.shape[1], image.shape[2] + x0 = max(0, min(w - 1, x0)) + x1 = max(0, min(w - 1, x1)) + y0 = max(0, min(h - 1, y0)) + y1 = max(0, min(h - 1, y1)) + + dx = abs(x1 - x0) + dy = -abs(y1 - y0) + sx = 1 if x0 < x1 else -1 + sy = 1 if y0 < y1 else -1 + err = dx + dy + + color = torch.tensor([255, 0, 0], dtype=image.dtype) + while True: + image[:, y0, x0] = color + if x0 == x1 and y0 == y1: + break + e2 = 2 * err + if e2 >= dy: + err += dy + x0 += sx + if e2 <= dx: + err += dx + y0 += sy + + +num_overlay_frames = 10 +overlay_frames = decoder.get_frames_in_range(0, num_overlay_frames).data.clone() +overlay_mvs = decoder.get_motion_vectors_at(list(range(num_overlay_frames))) + +max_draw_per_frame = None +for i in range(num_overlay_frames): + count = int(overlay_mvs.counts[i]) + if count == 0: + continue + if max_draw_per_frame is None or count <= max_draw_per_frame: + sample_indices = torch.arange(count) + else: + sample_indices = torch.linspace( + 0, count - 1, steps=max_draw_per_frame + ).round().to(torch.int64) + valid = overlay_mvs.data[i, sample_indices] + for mv in valid: + dst_x, dst_y = int(mv[5]), int(mv[6]) + motion_scale = int(mv[9]) + if motion_scale == 0: + continue + src_x = float(dst_x) + float(mv[7].item()) / motion_scale + src_y = float(dst_y) + float(mv[8].item()) / motion_scale + src_x = int(round(src_x)) + src_y = int(round(src_y)) + _draw_line(overlay_frames[i], src_x, src_y, dst_x, dst_y) + +encoder = VideoEncoder(frames=overlay_frames, frame_rate=decoder.metadata.average_fps) +overlay_path = tempfile.NamedTemporaryFile( + suffix=".mp4", prefix="motion_vectors_overlay_", delete=False +).name +encoder.to_file(overlay_path) +print(f"Wrote {overlay_path}") diff --git a/src/torchcodec/__init__.py b/src/torchcodec/__init__.py index 144d3a67f..539d88e21 100644 --- a/src/torchcodec/__init__.py +++ b/src/torchcodec/__init__.py @@ -8,7 +8,12 @@ # Note: usort wants to put Frame and FrameBatch after decoders and samplers, # but that results in circular import. -from ._frame import AudioSamples, Frame, FrameBatch # usort:skip # noqa +from ._frame import ( + AudioSamples, + Frame, + FrameBatch, + MotionVectorBatch, +) # usort:skip # noqa from . import decoders, encoders, samplers, transforms # noqa try: diff --git a/src/torchcodec/_core/Frame.h b/src/torchcodec/_core/Frame.h index 6a034676f..b1b8eb290 100644 --- a/src/torchcodec/_core/Frame.h +++ b/src/torchcodec/_core/Frame.h @@ -49,6 +49,16 @@ struct FrameBatchOutput { const torch::Device& device); }; +constexpr int kMotionVectorNumFields = 10; + +struct MotionVectorsBatchOutput { + torch::Tensor data; // 3D: of shape (N, max_mvs, kMotionVectorNumFields) + torch::Tensor counts; // 1D of shape (N,) + torch::Tensor ptsSeconds; // 1D of shape (N,) + torch::Tensor durationSeconds; // 1D of shape (N,) + torch::Tensor frameTypes; // 1D of shape (N,) +}; + struct AudioFramesOutput { torch::Tensor data; // shape is (numChannels, numSamples) double ptsSeconds; diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp index fa5803d34..e9f2d41cf 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.cpp +++ b/src/torchcodec/_core/SingleStreamDecoder.cpp @@ -5,6 +5,8 @@ // LICENSE file in the root directory of this source tree. #include "SingleStreamDecoder.h" +#include +#include #include #include #include @@ -29,6 +31,41 @@ int64_t getPtsOrDts(const UniqueAVFrame& avFrame) { return avFrame->pts == INT64_MIN ? avFrame->pkt_dts : avFrame->pts; } +std::vector extractMotionVectors(const UniqueAVFrame& avFrame) { + AVFrameSideData* sd = + av_frame_get_side_data(avFrame.get(), AV_FRAME_DATA_MOTION_VECTORS); + if (!sd || sd->size == 0) { + return {}; + } + + TORCH_CHECK( + sd->size % sizeof(AVMotionVector) == 0, + "Unexpected motion vectors side data size. Expected a multiple of ", + sizeof(AVMotionVector), + " bytes but got ", + sd->size, + "."); + + const AVMotionVector* mvs = reinterpret_cast(sd->data); + const int32_t numMvs = static_cast(sd->size / sizeof(*mvs)); + std::vector output; + output.reserve(static_cast(numMvs) * kMotionVectorNumFields); + + for (int32_t i = 0; i < numMvs; ++i) { + output.push_back(static_cast(mvs[i].source)); + output.push_back(static_cast(mvs[i].w)); + output.push_back(static_cast(mvs[i].h)); + output.push_back(static_cast(mvs[i].src_x)); + output.push_back(static_cast(mvs[i].src_y)); + output.push_back(static_cast(mvs[i].dst_x)); + output.push_back(static_cast(mvs[i].dst_y)); + output.push_back(static_cast(mvs[i].motion_x)); + output.push_back(static_cast(mvs[i].motion_y)); + output.push_back(static_cast(mvs[i].motion_scale)); + } + return output; +} + } // namespace // -------------------------------------------------------------------------- @@ -415,7 +452,8 @@ void SingleStreamDecoder::addStream( AVMediaType mediaType, const torch::Device& device, const std::string_view deviceVariant, - std::optional ffmpegThreadCount) { + std::optional ffmpegThreadCount, + bool exportMotionVectors) { TORCH_CHECK( activeStreamIndex_ == NO_ACTIVE_STREAM, "Can only add one single stream."); @@ -474,6 +512,10 @@ void SingleStreamDecoder::addStream( streamInfo.codecContext->thread_count = ffmpegThreadCount.value_or(0); streamInfo.codecContext->pkt_timebase = streamInfo.stream->time_base; + if (mediaType == AVMEDIA_TYPE_VIDEO && exportMotionVectors) { + streamInfo.codecContext->flags2 |= AV_CODEC_FLAG2_EXPORT_MVS; + } + // Note that we must make sure to register the harware device context // with the codec context before calling avcodec_open2(). Otherwise, decoding // will happen on the CPU and not the hardware device. @@ -516,7 +558,8 @@ void SingleStreamDecoder::addVideoStream( AVMEDIA_TYPE_VIDEO, videoStreamOptions.device, videoStreamOptions.deviceVariant, - videoStreamOptions.ffmpegThreadCount); + videoStreamOptions.ffmpegThreadCount, + videoStreamOptions.exportMotionVectors); auto& streamMetadata = containerMetadata_.allStreamMetadata[activeStreamIndex_]; @@ -644,6 +687,32 @@ FrameOutput SingleStreamDecoder::getFrameAtIndexInternal( return result; } +UniqueAVFrame SingleStreamDecoder::getAVFrameAtIndexInternal( + int64_t frameIndex) { + validateActiveStream(AVMEDIA_TYPE_VIDEO); + + const auto& streamInfo = streamInfos_[activeStreamIndex_]; + const auto& streamMetadata = + containerMetadata_.allStreamMetadata[activeStreamIndex_]; + + std::optional numFrames = streamMetadata.getNumFrames(seekMode_); + if (numFrames.has_value()) { + frameIndex = frameIndex >= 0 ? frameIndex : frameIndex + numFrames.value(); + } + validateFrameIndex(streamMetadata, frameIndex); + + if (frameIndex != lastDecodedFrameIndex_ + 1) { + int64_t pts = getPts(frameIndex); + setCursorPtsInSeconds(ptsToSeconds(pts, streamInfo.timeBase)); + } + + UniqueAVFrame avFrame = decodeAVFrame([this](const UniqueAVFrame& avFrame) { + return getPtsOrDts(avFrame) >= cursor_; + }); + lastDecodedFrameIndex_ = frameIndex; + return avFrame; +} + FrameBatchOutput SingleStreamDecoder::getFramesAtIndices( const torch::Tensor& frameIndices) { validateActiveStream(AVMEDIA_TYPE_VIDEO); @@ -710,6 +779,108 @@ FrameBatchOutput SingleStreamDecoder::getFramesAtIndices( return frameBatchOutput; } +MotionVectorsBatchOutput SingleStreamDecoder::getMotionVectorsAtIndices( + const torch::Tensor& frameIndices) { + validateActiveStream(AVMEDIA_TYPE_VIDEO); + + const auto& streamInfo = streamInfos_[activeStreamIndex_]; + TORCH_CHECK( + streamInfo.videoStreamOptions.exportMotionVectors, + "Motion vector extraction requires export_mvs=True when creating the decoder."); + + auto frameIndicesAccessor = frameIndices.accessor(); + + bool indicesAreSorted = true; + for (int64_t i = 1; i < frameIndices.numel(); ++i) { + if (frameIndicesAccessor[i] < frameIndicesAccessor[i - 1]) { + indicesAreSorted = false; + break; + } + } + + std::vector argsort; + if (!indicesAreSorted) { + argsort.resize(frameIndices.numel()); + for (size_t i = 0; i < argsort.size(); ++i) { + argsort[i] = i; + } + std::sort( + argsort.begin(), + argsort.end(), + [&frameIndicesAccessor](size_t a, size_t b) { + return frameIndicesAccessor[a] < frameIndicesAccessor[b]; + }); + } + + const int64_t numFrames = frameIndices.numel(); + + MotionVectorsBatchOutput output; + output.counts = torch::zeros( + {numFrames}, torch::dtype(torch::kInt32).device(torch::kCPU)); + output.ptsSeconds = torch::empty( + {numFrames}, torch::dtype(torch::kFloat64).device(torch::kCPU)); + output.durationSeconds = torch::empty( + {numFrames}, torch::dtype(torch::kFloat64).device(torch::kCPU)); + output.frameTypes = torch::empty( + {numFrames}, torch::dtype(torch::kInt32).device(torch::kCPU)); + + auto countsAccessor = output.counts.accessor(); + auto ptsAccessor = output.ptsSeconds.accessor(); + auto durationAccessor = output.durationSeconds.accessor(); + auto frameTypeAccessor = output.frameTypes.accessor(); + + std::vector> motionVectorsPerFrame(numFrames); + int32_t maxMvs = 0; + auto previousIndexInVideo = -1; + for (int64_t f = 0; f < numFrames; ++f) { + auto indexInOutput = indicesAreSorted ? f : argsort[f]; + auto indexInVideo = frameIndicesAccessor[indexInOutput]; + + if ((f > 0) && (indexInVideo == previousIndexInVideo)) { + auto previousIndexInOutput = indicesAreSorted ? f - 1 : argsort[f - 1]; + motionVectorsPerFrame[indexInOutput] = + motionVectorsPerFrame[previousIndexInOutput]; + countsAccessor[indexInOutput] = countsAccessor[previousIndexInOutput]; + ptsAccessor[indexInOutput] = ptsAccessor[previousIndexInOutput]; + durationAccessor[indexInOutput] = durationAccessor[previousIndexInOutput]; + frameTypeAccessor[indexInOutput] = + frameTypeAccessor[previousIndexInOutput]; + maxMvs = std::max(maxMvs, countsAccessor[indexInOutput]); + } else { + UniqueAVFrame avFrame = getAVFrameAtIndexInternal(indexInVideo); + motionVectorsPerFrame[indexInOutput] = extractMotionVectors(avFrame); + int32_t count = static_cast( + motionVectorsPerFrame[indexInOutput].size() / kMotionVectorNumFields); + countsAccessor[indexInOutput] = count; + maxMvs = std::max(maxMvs, count); + ptsAccessor[indexInOutput] = + ptsToSeconds(getPtsOrDts(avFrame), streamInfo.timeBase); + durationAccessor[indexInOutput] = + ptsToSeconds(getDuration(avFrame), streamInfo.timeBase); + frameTypeAccessor[indexInOutput] = + static_cast(av_get_picture_type_char(avFrame->pict_type)); + } + previousIndexInVideo = indexInVideo; + } + + output.data = torch::zeros( + {numFrames, maxMvs, kMotionVectorNumFields}, + torch::dtype(torch::kInt32).device(torch::kCPU)); + auto dataAccessor = output.data.accessor(); + for (int64_t i = 0; i < numFrames; ++i) { + const auto& flat = motionVectorsPerFrame[i]; + const int32_t count = countsAccessor[i]; + for (int32_t mv = 0; mv < count; ++mv) { + int base = mv * kMotionVectorNumFields; + for (int j = 0; j < kMotionVectorNumFields; ++j) { + dataAccessor[i][mv][j] = flat[static_cast(base + j)]; + } + } + } + + return output; +} + FrameBatchOutput SingleStreamDecoder::getFramesInRange( int64_t start, int64_t stop, diff --git a/src/torchcodec/_core/SingleStreamDecoder.h b/src/torchcodec/_core/SingleStreamDecoder.h index 8457b4d21..23c09bd28 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.h +++ b/src/torchcodec/_core/SingleStreamDecoder.h @@ -115,6 +115,10 @@ class SingleStreamDecoder { // Tensor. FrameBatchOutput getFramesAtIndices(const torch::Tensor& frameIndices); + // Returns motion vectors and metadata for frames at the given indices. + MotionVectorsBatchOutput getMotionVectorsAtIndices( + const torch::Tensor& frameIndices); + // Returns frames within a given range. The range is defined by [start, stop). // The values retrieved from the range are: [start, start+step, // start+(2*step), start+(3*step), ..., stop). The default for step is 1. @@ -177,6 +181,8 @@ class SingleStreamDecoder { int64_t frameIndex, std::optional preAllocatedOutputTensor = std::nullopt); + UniqueAVFrame getAVFrameAtIndexInternal(int64_t frameIndex); + // Exposed for _test_frame_pts_equality, which is used to test non-regression // of pts resolution (64 to 32 bit floats) double getPtsSecondsForFrame(int64_t frameIndex); @@ -305,7 +311,8 @@ class SingleStreamDecoder { AVMediaType mediaType, const torch::Device& device = torch::kCPU, const std::string_view deviceVariant = "ffmpeg", - std::optional ffmpegThreadCount = std::nullopt); + std::optional ffmpegThreadCount = std::nullopt, + bool exportMotionVectors = false); // Returns the "best" stream index for a given media type. The "best" is // determined by various heuristics in FFMPEG. diff --git a/src/torchcodec/_core/StreamOptions.h b/src/torchcodec/_core/StreamOptions.h index 5ed7bd6c5..cb3f06699 100644 --- a/src/torchcodec/_core/StreamOptions.h +++ b/src/torchcodec/_core/StreamOptions.h @@ -47,6 +47,9 @@ struct VideoStreamOptions { // Device variant (e.g., "ffmpeg", "beta", etc.) std::string_view deviceVariant = "ffmpeg"; + // If true, request FFmpeg to export motion vectors as side data. + bool exportMotionVectors = false; + // Encoding options std::optional codec; // Optional pixel format for video encoding (e.g., "yuv420p", "yuv444p") diff --git a/src/torchcodec/_core/__init__.py b/src/torchcodec/_core/__init__.py index add9efa90..548ff9a4b 100644 --- a/src/torchcodec/_core/__init__.py +++ b/src/torchcodec/_core/__init__.py @@ -40,6 +40,7 @@ get_frames_by_pts_in_range_audio, get_frames_in_range, get_json_metadata, + get_motion_vectors_at_indices, get_next_frame, scan_all_streams_to_update_metadata, seek_to_pts, diff --git a/src/torchcodec/_core/custom_ops.cpp b/src/torchcodec/_core/custom_ops.cpp index e35f62388..8f94e95d9 100644 --- a/src/torchcodec/_core/custom_ops.cpp +++ b/src/torchcodec/_core/custom_ops.cpp @@ -48,9 +48,9 @@ TORCH_LIBRARY(torchcodec_ns, m) { m.def( "_create_from_file_like(int file_like_context, str? seek_mode=None) -> Tensor"); m.def( - "_add_video_stream(Tensor(a!) decoder, *, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str device=\"cpu\", str device_variant=\"ffmpeg\", str transform_specs=\"\", (Tensor, Tensor, Tensor)? custom_frame_mappings=None, str? color_conversion_library=None) -> ()"); + "_add_video_stream(Tensor(a!) decoder, *, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str device=\"cpu\", str device_variant=\"ffmpeg\", str transform_specs=\"\", bool? export_mvs=None, (Tensor, Tensor, Tensor)? custom_frame_mappings=None, str? color_conversion_library=None) -> ()"); m.def( - "add_video_stream(Tensor(a!) decoder, *, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str device=\"cpu\", str device_variant=\"ffmpeg\", str transform_specs=\"\", (Tensor, Tensor, Tensor)? custom_frame_mappings=None) -> ()"); + "add_video_stream(Tensor(a!) decoder, *, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str device=\"cpu\", str device_variant=\"ffmpeg\", str transform_specs=\"\", bool? export_mvs=None, (Tensor, Tensor, Tensor)? custom_frame_mappings=None) -> ()"); m.def( "add_audio_stream(Tensor(a!) decoder, *, int? stream_index=None, int? sample_rate=None, int? num_channels=None) -> ()"); m.def("seek_to_pts(Tensor(a!) decoder, float seconds) -> ()"); @@ -61,6 +61,8 @@ TORCH_LIBRARY(torchcodec_ns, m) { "get_frame_at_index(Tensor(a!) decoder, *, int frame_index) -> (Tensor, Tensor, Tensor)"); m.def( "get_frames_at_indices(Tensor(a!) decoder, *, Tensor frame_indices) -> (Tensor, Tensor, Tensor)"); + m.def( + "get_motion_vectors_at_indices(Tensor(a!) decoder, *, Tensor frame_indices) -> (Tensor, Tensor, Tensor, Tensor, Tensor)"); m.def( "get_frames_in_range(Tensor(a!) decoder, *, int start, int stop, int? step=None) -> (Tensor, Tensor, Tensor)"); m.def( @@ -145,6 +147,23 @@ OpsFrameBatchOutput makeOpsFrameBatchOutput(FrameBatchOutput& batch) { return std::make_tuple(batch.data, batch.ptsSeconds, batch.durationSeconds); } +using OpsMotionVectorsBatchOutput = std::tuple< + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor>; + +OpsMotionVectorsBatchOutput makeOpsMotionVectorsBatchOutput( + MotionVectorsBatchOutput& batch) { + return std::make_tuple( + batch.data, + batch.counts, + batch.ptsSeconds, + batch.durationSeconds, + batch.frameTypes); +} + // The elements of this tuple are all tensors that represent the concatenation // of multiple audio frames: // 1. The frames data (concatenated) @@ -421,6 +440,7 @@ void _add_video_stream( std::string_view device = "cpu", std::string_view device_variant = "ffmpeg", std::string_view transform_specs = "", + std::optional export_mvs = std::nullopt, std::optional> custom_frame_mappings = std::nullopt, std::optional color_conversion_library = std::nullopt) { @@ -453,6 +473,7 @@ void _add_video_stream( videoStreamOptions.device = torch::Device(std::string(device)); videoStreamOptions.deviceVariant = device_variant; + videoStreamOptions.exportMotionVectors = export_mvs.value_or(false); std::vector transforms = makeTransforms(std::string(transform_specs)); @@ -478,6 +499,7 @@ void add_video_stream( std::string_view device = "cpu", std::string_view device_variant = "ffmpeg", std::string_view transform_specs = "", + std::optional export_mvs = std::nullopt, const std::optional< std::tuple>& custom_frame_mappings = std::nullopt) { @@ -489,6 +511,7 @@ void add_video_stream( device, device_variant, transform_specs, + export_mvs, custom_frame_mappings); } @@ -555,6 +578,14 @@ OpsFrameBatchOutput get_frames_at_indices( return makeOpsFrameBatchOutput(result); } +OpsMotionVectorsBatchOutput get_motion_vectors_at_indices( + torch::Tensor& decoder, + const torch::Tensor& frame_indices) { + auto videoDecoder = unwrapTensorToGetDecoder(decoder); + auto result = videoDecoder->getMotionVectorsAtIndices(frame_indices); + return makeOpsMotionVectorsBatchOutput(result); +} + // Return the frames inside a range as a single stacked Tensor. The range is // defined as [start, stop). OpsFrameBatchOutput get_frames_in_range( @@ -1082,6 +1113,7 @@ TORCH_LIBRARY_IMPL(torchcodec_ns, CPU, m) { m.impl("get_frame_at_pts", &get_frame_at_pts); m.impl("get_frame_at_index", &get_frame_at_index); m.impl("get_frames_at_indices", &get_frames_at_indices); + m.impl("get_motion_vectors_at_indices", &get_motion_vectors_at_indices); m.impl("get_frames_in_range", &get_frames_in_range); m.impl("get_frames_by_pts_in_range", &get_frames_by_pts_in_range); m.impl("get_frames_by_pts_in_range_audio", &get_frames_by_pts_in_range_audio); diff --git a/src/torchcodec/_core/ops.py b/src/torchcodec/_core/ops.py index 3188dfc7b..d33496daa 100644 --- a/src/torchcodec/_core/ops.py +++ b/src/torchcodec/_core/ops.py @@ -148,6 +148,9 @@ def expose_ffmpeg_dlls(): # noqa: F811 _get_frames_at_indices_tensor_input = ( torch.ops.torchcodec_ns.get_frames_at_indices.default ) +_get_motion_vectors_at_indices_tensor_input = ( + torch.ops.torchcodec_ns.get_motion_vectors_at_indices.default +) _get_frames_by_pts_tensor_input = torch.ops.torchcodec_ns.get_frames_by_pts.default get_frames_in_range = torch.ops.torchcodec_ns.get_frames_in_range.default get_frames_by_pts_in_range = torch.ops.torchcodec_ns.get_frames_by_pts_in_range.default @@ -277,10 +280,42 @@ def get_frames_at_indices( frame_indices = frame_indices.to(torch.int64) else: # Convert list to tensor for dispatch - frame_indices = torch.tensor(frame_indices) + if isinstance(frame_indices, (list, tuple)) and len(frame_indices) == 0: + frame_indices = torch.empty((0,), dtype=torch.int64) + else: + frame_indices = torch.tensor(frame_indices) + if ( + frame_indices.dtype != torch.int64 + and frame_indices.dtype != torch.bool + and not frame_indices.dtype.is_floating_point + and not frame_indices.dtype.is_complex + ): + frame_indices = frame_indices.to(torch.int64) return _get_frames_at_indices_tensor_input(decoder, frame_indices=frame_indices) +def get_motion_vectors_at_indices( + decoder: torch.Tensor, *, frame_indices: torch.Tensor | list[int] +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + if isinstance(frame_indices, torch.Tensor): + frame_indices = frame_indices.to(torch.int64) + else: + if isinstance(frame_indices, (list, tuple)) and len(frame_indices) == 0: + frame_indices = torch.empty((0,), dtype=torch.int64) + else: + frame_indices = torch.tensor(frame_indices) + if ( + frame_indices.dtype != torch.int64 + and frame_indices.dtype != torch.bool + and not frame_indices.dtype.is_floating_point + and not frame_indices.dtype.is_complex + ): + frame_indices = frame_indices.to(torch.int64) + return _get_motion_vectors_at_indices_tensor_input( + decoder, frame_indices=frame_indices + ) + + def get_frames_by_pts( decoder: torch.Tensor, *, timestamps: torch.Tensor | list[float] ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: @@ -408,6 +443,7 @@ def _add_video_stream_abstract( device: str = "cpu", device_variant: str = "ffmpeg", transform_specs: str = "", + export_mvs: bool | None = None, custom_frame_mappings: ( tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None ) = None, @@ -426,6 +462,7 @@ def add_video_stream_abstract( device: str = "cpu", device_variant: str = "ffmpeg", transform_specs: str = "", + export_mvs: bool | None = None, custom_frame_mappings: ( tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None ) = None, @@ -513,6 +550,21 @@ def get_frames_at_indices_abstract( ) +@register_fake("torchcodec_ns::get_motion_vectors_at_indices") +def get_motion_vectors_at_indices_abstract( + decoder: torch.Tensor, *, frame_indices: torch.Tensor | list[int] +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + mv_size = [get_ctx().new_dynamic_size() for _ in range(3)] + batch_size = [get_ctx().new_dynamic_size()] + return ( + torch.empty(mv_size, dtype=torch.int32), + torch.empty(batch_size, dtype=torch.int32), + torch.empty(batch_size, dtype=torch.float64), + torch.empty(batch_size, dtype=torch.float64), + torch.empty(batch_size, dtype=torch.int32), + ) + + @register_fake("torchcodec_ns::get_frames_in_range") def get_frames_in_range_abstract( decoder: torch.Tensor, diff --git a/src/torchcodec/_frame.py b/src/torchcodec/_frame.py index 2ceb890b7..2130b82d8 100644 --- a/src/torchcodec/_frame.py +++ b/src/torchcodec/_frame.py @@ -117,6 +117,101 @@ def __repr__(self): return _frame_repr(self) +@dataclass +class MotionVectorBatch(Iterable): + """Motion vectors and metadata for multiple video frames. + + The ``data`` tensor is padded to shape ``(N, max_mvs, 10)``. Each motion + vector has 10 int32 fields in this order: + 1) source + 2) w + 3) h + 4) src_x + 5) src_y + 6) dst_x + 7) dst_y + 8) motion_x + 9) motion_y + 10) motion_scale + + The motion components follow the FFmpeg convention. In particular, + ``dst_x`` and ``dst_y`` refer to the block position in the current frame, + and ``src_x`` and ``src_y`` refer to the corresponding block position in + the reference frame identified by ``source``. Coordinates can be + sub-pixel, represented via the fixed-point fields ``motion_x``, + ``motion_y``, and ``motion_scale``: + ``src_x = dst_x + motion_x / motion_scale`` and + ``src_y = dst_y + motion_y / motion_scale``. + + All tensors are on CPU. ``data``, ``counts``, and ``frame_types`` are int32. + ``pts_seconds`` and ``duration_seconds`` are float64. + + .. note:: + + If none of the requested frames contain motion vectors, ``max_mvs`` can + be 0 and ``data`` will have shape ``(N, 0, 10)``. + """ + + data: Tensor + """Motion vectors as a 3D tensor (N, max_mvs, 10) of int32.""" + counts: Tensor + """Number of motion vectors per frame (1D tensor of int32).""" + pts_seconds: Tensor + """PTS for each frame in seconds (1D tensor of float64).""" + duration_seconds: Tensor + """Duration for each frame in seconds (1D tensor of float64).""" + frame_types: Tensor + """Frame types as ASCII codes (1D tensor of int32).""" + + def __post_init__(self): + if self.data.ndim != 3: + raise ValueError(f"data must be 3-dimensional, got {self.data.shape = }") + leading_dim = self.data.shape[0] + for name, tensor in ( + ("counts", self.counts), + ("pts_seconds", self.pts_seconds), + ("duration_seconds", self.duration_seconds), + ("frame_types", self.frame_types), + ): + if tensor.ndim != 1 or tensor.shape[0] != leading_dim: + raise ValueError( + f"{name} must be 1D with length {leading_dim}, got {tensor.shape = }." + ) + + def __iter__(self) -> Iterator["MotionVectorBatch"]: + for i in range(len(self)): + yield self[i] + + def __getitem__(self, key) -> "MotionVectorBatch": + data = self.data[key] + counts = self.counts[key] + pts_seconds = self.pts_seconds[key] + duration_seconds = self.duration_seconds[key] + frame_types = self.frame_types[key] + + if data.ndim == 2: + data = data.unsqueeze(0) + if counts.ndim == 0: + counts = counts.unsqueeze(0) + pts_seconds = pts_seconds.unsqueeze(0) + duration_seconds = duration_seconds.unsqueeze(0) + frame_types = frame_types.unsqueeze(0) + + return MotionVectorBatch( + data=data, + counts=counts, + pts_seconds=pts_seconds, + duration_seconds=duration_seconds, + frame_types=frame_types, + ) + + def __len__(self): + return len(self.data) + + def __repr__(self): + return _frame_repr(self) + + @dataclass class AudioSamples(Iterable): """Audio samples with associated metadata.""" diff --git a/src/torchcodec/decoders/_video_decoder.py b/src/torchcodec/decoders/_video_decoder.py index c1bf72cbc..81473a4e3 100644 --- a/src/torchcodec/decoders/_video_decoder.py +++ b/src/torchcodec/decoders/_video_decoder.py @@ -15,7 +15,7 @@ import torch from torch import device as torch_device, nn, Tensor -from torchcodec import _core as core, Frame, FrameBatch +from torchcodec import _core as core, Frame, FrameBatch, MotionVectorBatch from torchcodec.decoders._decoder_utils import ( _get_cuda_backend, create_decoder, @@ -127,6 +127,10 @@ class VideoDecoder: :class:`~torchcodec.transforms.DecoderTransform` and :class:`~torchvision.transforms.v2.Transform` objects. Read more about this parameter in: TODO_DECODER_TRANSFORMS_TUTORIAL. + export_mvs (bool, optional): If True, request FFmpeg to export motion + vectors as side data. Required for + :meth:`~torchcodec.decoders.VideoDecoder.get_motion_vectors_at`. + CPU-only. Default: False. custom_frame_mappings (str, bytes, or file-like object, optional): Mapping of frames to their metadata, typically generated via ffprobe. This enables accurate frame seeking without requiring a full video scan. @@ -170,6 +174,7 @@ def __init__( device: str | torch_device | None = None, seek_mode: Literal["exact", "approximate"] = "exact", transforms: Sequence[DecoderTransform | nn.Module] | None = None, + export_mvs: bool = False, custom_frame_mappings: ( str | bytes | io.RawIOBase | io.BufferedReader | None ) = None, @@ -224,6 +229,9 @@ def __init__( elif isinstance(device, torch_device): device = str(device) + if export_mvs and not device.startswith("cpu"): + raise ValueError("export_mvs is only supported for CPU decoding.") + device_variant = _get_cuda_backend() transform_specs = _make_transform_specs( transforms, @@ -238,8 +246,10 @@ def __init__( device=device, device_variant=device_variant, transform_specs=transform_specs, + export_mvs=export_mvs, custom_frame_mappings=custom_frame_mappings_data, ) + self._export_mvs = export_mvs self._cpu_fallback = CpuFallbackStatus() if device.startswith("cuda"): @@ -375,6 +385,53 @@ def get_frames_at(self, indices: torch.Tensor | list[int]) -> FrameBatch: duration_seconds=duration_seconds, ) + def get_motion_vectors_at( + self, indices: torch.Tensor | list[int] + ) -> MotionVectorBatch: + """Return motion vectors at the given indices. + + .. note:: + + This requires ``export_mvs=True`` when creating the decoder. + Motion vectors are returned as a padded tensor of shape + ``(N, max_mvs, 10)``, along with a ``counts`` vector that indicates + how many vectors are valid per frame. Each motion vector has 10 + int32 fields in this order: source, w, h, src_x, src_y, dst_x, + dst_y, motion_x, motion_y, motion_scale. All outputs are on CPU. + ``data``, ``counts``, and ``frame_types`` are int32, and + ``pts_seconds`` and ``duration_seconds`` are float64. The motion + components follow the FFmpeg convention: + ``src_x = dst_x + motion_x / motion_scale`` and + ``src_y = dst_y + motion_y / motion_scale``. + This API is CPU-only. + + If none of the requested frames contain motion vectors, ``max_mvs`` + can be 0 and ``data`` will have shape ``(N, 0, 10)``. + + Args: + indices (torch.Tensor or list of int): The indices of the frames to retrieve. + + Returns: + MotionVectorBatch: Motion vectors and per-frame metadata. + """ + + if not self._export_mvs: + raise RuntimeError( + "Motion vector extraction requires export_mvs=True when creating the decoder." + ) + + data, counts, pts_seconds, duration_seconds, frame_types = ( + core.get_motion_vectors_at_indices(self._decoder, frame_indices=indices) + ) + + return MotionVectorBatch( + data=data, + counts=counts, + pts_seconds=pts_seconds, + duration_seconds=duration_seconds, + frame_types=frame_types, + ) + def get_frames_in_range(self, start: int, stop: int, step: int = 1) -> FrameBatch: """Return multiple frames at the given index range. diff --git a/test/test_decoders.py b/test/test_decoders.py index de7ec9d6c..2fe4fed35 100644 --- a/test/test_decoders.py +++ b/test/test_decoders.py @@ -11,7 +11,7 @@ import numpy import pytest import torch -from torchcodec import _core, ffmpeg_major_version, FrameBatch +from torchcodec import _core, ffmpeg_major_version, FrameBatch, MotionVectorBatch from torchcodec.decoders import ( AudioDecoder, AudioStreamMetadata, @@ -601,6 +601,92 @@ def test_get_frames_at(self, device, seek_mode): frames.duration_seconds, expected_duration_seconds, atol=1e-4, rtol=0 ) + def test_get_frames_at_empty_indices(self): + decoder = VideoDecoder(NASA_VIDEO.path, device="cpu") + frames = decoder.get_frames_at([]) + + assert isinstance(frames, FrameBatch) + assert frames.data.shape[0] == 0 + assert frames.pts_seconds.shape == (0,) + assert frames.duration_seconds.shape == (0,) + + def test_get_motion_vectors_at(self): + decoder = VideoDecoder(NASA_VIDEO.path, device="cpu", export_mvs=True) + motion_vectors = decoder.get_motion_vectors_at([0, 1, 2]) + + assert isinstance(motion_vectors, MotionVectorBatch) + assert motion_vectors.data.shape[0] == 3 + assert motion_vectors.data.shape[2] == 10 + assert motion_vectors.counts.shape == (3,) + assert motion_vectors.pts_seconds.shape == (3,) + assert motion_vectors.duration_seconds.shape == (3,) + assert motion_vectors.frame_types.shape == (3,) + assert motion_vectors.data.dtype == torch.int32 + assert motion_vectors.counts.dtype == torch.int32 + assert motion_vectors.frame_types.dtype == torch.int32 + assert motion_vectors.pts_seconds.dtype == torch.float64 + assert motion_vectors.duration_seconds.dtype == torch.float64 + assert int(motion_vectors.counts.max()) <= motion_vectors.data.shape[1] + + def test_get_motion_vectors_at_requires_export(self): + decoder = VideoDecoder(NASA_VIDEO.path) + with pytest.raises(RuntimeError, match="export_mvs=True"): + decoder.get_motion_vectors_at([0]) + + def test_get_motion_vectors_at_empty_indices(self): + decoder = VideoDecoder(NASA_VIDEO.path, device="cpu", export_mvs=True) + motion_vectors = decoder.get_motion_vectors_at([]) + + assert isinstance(motion_vectors, MotionVectorBatch) + assert motion_vectors.data.shape == (0, 0, 10) + assert motion_vectors.counts.shape == (0,) + assert motion_vectors.pts_seconds.shape == (0,) + assert motion_vectors.duration_seconds.shape == (0,) + assert motion_vectors.frame_types.shape == (0,) + assert motion_vectors.data.dtype == torch.int32 + assert motion_vectors.counts.dtype == torch.int32 + assert motion_vectors.frame_types.dtype == torch.int32 + assert motion_vectors.pts_seconds.dtype == torch.float64 + assert motion_vectors.duration_seconds.dtype == torch.float64 + + def test_get_motion_vectors_at_unsorted_and_duplicate_indices(self): + decoder = VideoDecoder(NASA_VIDEO.path, device="cpu", export_mvs=True) + indices = [2, 0, 2, 1] + batch = decoder.get_motion_vectors_at(indices) + + assert isinstance(batch, MotionVectorBatch) + assert batch.data.shape[0] == len(indices) + assert batch.data.shape[2] == 10 + + for pos, idx in enumerate(indices): + single = decoder.get_motion_vectors_at([idx]) + assert int(batch.counts[pos]) == int(single.counts[0]) + assert int(batch.frame_types[pos]) == int(single.frame_types[0]) + torch.testing.assert_close( + batch.pts_seconds[pos], single.pts_seconds[0], atol=1e-6, rtol=0 + ) + torch.testing.assert_close( + batch.duration_seconds[pos], + single.duration_seconds[0], + atol=1e-6, + rtol=0, + ) + + count = int(batch.counts[pos]) + assert count <= batch.data.shape[1] + if count > 0: + torch.testing.assert_close( + batch.data[pos, :count], single.data[0, :count] + ) + if batch.data.shape[1] > count: + assert torch.all(batch.data[pos, count:] == 0) + + torch.testing.assert_close(batch.data[0], batch.data[2]) + assert int(batch.counts[0]) == int(batch.counts[2]) + assert int(batch.frame_types[0]) == int(batch.frame_types[2]) + torch.testing.assert_close(batch.pts_seconds[0], batch.pts_seconds[2]) + torch.testing.assert_close(batch.duration_seconds[0], batch.duration_seconds[2]) + @pytest.mark.parametrize("device", all_supported_devices()) @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) def test_get_frames_at_fails(self, device, seek_mode): diff --git a/test/test_frame_dataclasses.py b/test/test_frame_dataclasses.py index 003e49530..2fe7955bf 100644 --- a/test/test_frame_dataclasses.py +++ b/test/test_frame_dataclasses.py @@ -1,6 +1,6 @@ import pytest import torch -from torchcodec import AudioSamples, Frame, FrameBatch +from torchcodec import AudioSamples, Frame, FrameBatch, MotionVectorBatch def test_unpacking(): @@ -144,6 +144,44 @@ def test_framebatch_indexing(): assert fb_fancy.data.shape == (1, C, H, W) +def test_motion_vector_batch_error(): + with pytest.raises(ValueError, match="data must be 3-dimensional"): + MotionVectorBatch( + data=torch.rand(2, 3), + counts=torch.zeros(2, dtype=torch.int32), + pts_seconds=torch.zeros(2), + duration_seconds=torch.zeros(2), + frame_types=torch.zeros(2, dtype=torch.int32), + ) + + with pytest.raises(ValueError, match="counts must be 1D"): + MotionVectorBatch( + data=torch.zeros(2, 1, 10), + counts=torch.zeros(2, 1, dtype=torch.int32), + pts_seconds=torch.zeros(2), + duration_seconds=torch.zeros(2), + frame_types=torch.zeros(2, dtype=torch.int32), + ) + + with pytest.raises(ValueError, match="frame_types must be 1D"): + MotionVectorBatch( + data=torch.zeros(2, 1, 10), + counts=torch.zeros(2, dtype=torch.int32), + pts_seconds=torch.zeros(2), + duration_seconds=torch.zeros(2), + frame_types=torch.zeros(2, 1, dtype=torch.int32), + ) + + with pytest.raises(ValueError, match="counts must be 1D with length 2"): + MotionVectorBatch( + data=torch.zeros(2, 1, 10), + counts=torch.zeros(3, dtype=torch.int32), + pts_seconds=torch.zeros(2), + duration_seconds=torch.zeros(2), + frame_types=torch.zeros(2, dtype=torch.int32), + ) + + def test_audio_samples_error(): with pytest.raises(ValueError, match="data must be 2-dimensional"): AudioSamples(