diff --git a/src/torchcodec/_core/DeviceInterface.cpp b/src/torchcodec/_core/DeviceInterface.cpp index d26380180..4b6e6f89b 100644 --- a/src/torchcodec/_core/DeviceInterface.cpp +++ b/src/torchcodec/_core/DeviceInterface.cpp @@ -22,9 +22,9 @@ DeviceInterfaceMap& getDeviceMap() { return deviceMap; } -std::string getDeviceTypeString(const std::string& device) { +std::string_view getDeviceTypeString(std::string_view device) { size_t pos = device.find(':'); - if (pos == std::string::npos) { + if (pos == std::string_view::npos) { return device; } return device.substr(0, pos); @@ -34,7 +34,7 @@ std::string getDeviceTypeString(const std::string& device) { // TODO_STABLE_ABI: we might need to support more device types, i.e. those from // https://github.com/pytorch/pytorch/blob/main/torch/headeronly/core/DeviceType.h // Ideally we'd remove this helper? -StableDeviceType parseDeviceType(const std::string& deviceType) { +StableDeviceType parseDeviceType(std::string_view deviceType) { if (deviceType == "cpu") { return kStableCPU; } else if (deviceType == "cuda") { @@ -67,10 +67,10 @@ bool registerDeviceInterface( } void validateDeviceInterface( - const std::string& device, - const std::string& variant) { + std::string_view device, + std::string_view variant) { std::scoped_lock lock(g_interface_mutex); - std::string deviceType = getDeviceTypeString(device); + std::string_view deviceType = getDeviceTypeString(device); DeviceInterfaceMap& deviceMap = getDeviceMap(); @@ -98,7 +98,7 @@ void validateDeviceInterface( std::unique_ptr createDeviceInterface( const StableDevice& device, - const std::string_view variant) { + std::string_view variant) { DeviceInterfaceKey key(device.type(), variant); std::scoped_lock lock(g_interface_mutex); DeviceInterfaceMap& deviceMap = getDeviceMap(); diff --git a/src/torchcodec/_core/DeviceInterface.h b/src/torchcodec/_core/DeviceInterface.h index 6b5388d26..9c2d4aea0 100644 --- a/src/torchcodec/_core/DeviceInterface.h +++ b/src/torchcodec/_core/DeviceInterface.h @@ -10,6 +10,7 @@ #include #include #include +#include #include "FFMPEGCommon.h" #include "Frame.h" #include "StableABICompat.h" @@ -21,7 +22,9 @@ namespace facebook::torchcodec { // Key for device interface registration with device type + variant support struct DeviceInterfaceKey { StableDeviceType deviceType; - std::string_view variant = "ffmpeg"; // e.g., "ffmpeg", "beta", etc. + // This key is stored in the global device-interface registry, so it must own + // its variant string. + std::string variant = "ffmpeg"; // e.g., "ffmpeg", "beta", etc. bool operator<(const DeviceInterfaceKey& other) const { if (deviceType != other.deviceType) { @@ -32,7 +35,7 @@ struct DeviceInterfaceKey { explicit DeviceInterfaceKey(StableDeviceType type) : deviceType(type) {} - DeviceInterfaceKey(StableDeviceType type, const std::string_view& variant) + DeviceInterfaceKey(StableDeviceType type, std::string_view variant) : deviceType(type), variant(variant) {} }; @@ -177,12 +180,12 @@ TORCHCODEC_THIRD_PARTY_API bool registerDeviceInterface( const CreateDeviceInterfaceFn createInterface); FORCE_PUBLIC_VISIBILITY void validateDeviceInterface( - const std::string& device, - const std::string& variant); + std::string_view device, + std::string_view variant); std::unique_ptr createDeviceInterface( const StableDevice& device, - const std::string_view variant = "ffmpeg"); + std::string_view variant = "ffmpeg"); torch::stable::Tensor rgbAVFrameToTensor(const UniqueAVFrame& avFrame); diff --git a/src/torchcodec/_core/Encoder.cpp b/src/torchcodec/_core/Encoder.cpp index 2d2a512d4..e979d9cfb 100644 --- a/src/torchcodec/_core/Encoder.cpp +++ b/src/torchcodec/_core/Encoder.cpp @@ -135,13 +135,13 @@ AudioEncoder::~AudioEncoder() { AudioEncoder::AudioEncoder( const torch::stable::Tensor& samples, int sampleRate, - std::string_view fileName, + const std::string& fileName, const AudioStreamOptions& audioStreamOptions) : samples_(validateSamples(samples)), inSampleRate_(sampleRate) { setFFmpegLogLevel(); AVFormatContext* avFormatContext = nullptr; int status = avformat_alloc_output_context2( - &avFormatContext, nullptr, nullptr, fileName.data()); + &avFormatContext, nullptr, nullptr, fileName.c_str()); STD_TORCH_CHECK( avFormatContext != nullptr, @@ -152,7 +152,7 @@ AudioEncoder::AudioEncoder( getFFMPEGErrorStringFromErrorCode(status)); avFormatContext_.reset(avFormatContext); - status = avio_open(&avFormatContext_->pb, fileName.data(), AVIO_FLAG_WRITE); + status = avio_open(&avFormatContext_->pb, fileName.c_str(), AVIO_FLAG_WRITE); STD_TORCH_CHECK( status >= 0, "avio_open failed. The destination file is ", @@ -166,7 +166,7 @@ AudioEncoder::AudioEncoder( AudioEncoder::AudioEncoder( const torch::stable::Tensor& samples, int sampleRate, - std::string_view formatName, + const std::string& formatName, std::unique_ptr avioContextHolder, const AudioStreamOptions& audioStreamOptions) : samples_(validateSamples(samples)), @@ -175,7 +175,7 @@ AudioEncoder::AudioEncoder( setFFmpegLogLevel(); AVFormatContext* avFormatContext = nullptr; int status = avformat_alloc_output_context2( - &avFormatContext, nullptr, formatName.data(), nullptr); + &avFormatContext, nullptr, formatName.c_str(), nullptr); STD_TORCH_CHECK( avFormatContext != nullptr, @@ -684,7 +684,7 @@ VideoEncoder::~VideoEncoder() { VideoEncoder::VideoEncoder( const torch::stable::Tensor& frames, double frameRate, - std::string_view fileName, + const std::string& fileName, const VideoStreamOptions& videoStreamOptions) : frames_(validateFrames(frames)), inFrameRate_(frameRate) { setFFmpegLogLevel(); @@ -692,7 +692,7 @@ VideoEncoder::VideoEncoder( // Allocate output format context AVFormatContext* avFormatContext = nullptr; int status = avformat_alloc_output_context2( - &avFormatContext, nullptr, nullptr, fileName.data()); + &avFormatContext, nullptr, nullptr, fileName.c_str()); STD_TORCH_CHECK( avFormatContext != nullptr, @@ -703,7 +703,7 @@ VideoEncoder::VideoEncoder( getFFMPEGErrorStringFromErrorCode(status)); avFormatContext_.reset(avFormatContext); - status = avio_open(&avFormatContext_->pb, fileName.data(), AVIO_FLAG_WRITE); + status = avio_open(&avFormatContext_->pb, fileName.c_str(), AVIO_FLAG_WRITE); STD_TORCH_CHECK( status >= 0, "avio_open failed. The destination file is ", @@ -716,7 +716,7 @@ VideoEncoder::VideoEncoder( VideoEncoder::VideoEncoder( const torch::stable::Tensor& frames, double frameRate, - std::string_view formatName, + const std::string& formatName, std::unique_ptr avioContextHolder, const VideoStreamOptions& videoStreamOptions) : frames_(validateFrames(frames)), @@ -724,10 +724,11 @@ VideoEncoder::VideoEncoder( avioContextHolder_(std::move(avioContextHolder)) { setFFmpegLogLevel(); // Map mkv -> matroska when used as format name - formatName = (formatName == "mkv") ? "matroska" : formatName; + const std::string normalizedFormatName = + (formatName == "mkv") ? "matroska" : formatName; AVFormatContext* avFormatContext = nullptr; int status = avformat_alloc_output_context2( - &avFormatContext, nullptr, formatName.data(), nullptr); + &avFormatContext, nullptr, normalizedFormatName.c_str(), nullptr); STD_TORCH_CHECK( avFormatContext != nullptr, diff --git a/src/torchcodec/_core/Encoder.h b/src/torchcodec/_core/Encoder.h index 5969f8753..793eaa10e 100644 --- a/src/torchcodec/_core/Encoder.h +++ b/src/torchcodec/_core/Encoder.h @@ -19,13 +19,13 @@ class FORCE_PUBLIC_VISIBILITY AudioEncoder { AudioEncoder( const torch::stable::Tensor& samples, int sampleRate, - std::string_view fileName, + const std::string& fileName, const AudioStreamOptions& audioStreamOptions); AudioEncoder( const torch::stable::Tensor& samples, int sampleRate, - std::string_view formatName, + const std::string& formatName, std::unique_ptr avioContextHolder, const AudioStreamOptions& audioStreamOptions); @@ -144,13 +144,13 @@ class FORCE_PUBLIC_VISIBILITY VideoEncoder { VideoEncoder( const torch::stable::Tensor& frames, double frameRate, - std::string_view fileName, + const std::string& fileName, const VideoStreamOptions& videoStreamOptions); VideoEncoder( const torch::stable::Tensor& frames, double frameRate, - std::string_view formatName, + const std::string& formatName, std::unique_ptr avioContextHolder, const VideoStreamOptions& videoStreamOptions); diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp index fedd97f61..a5f83b14a 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.cpp +++ b/src/torchcodec/_core/SingleStreamDecoder.cpp @@ -440,7 +440,7 @@ void SingleStreamDecoder::addStream( int streamIndex, AVMediaType mediaType, const StableDevice& device, - const std::string_view deviceVariant, + std::string_view deviceVariant, std::optional ffmpegThreadCount) { STD_TORCH_CHECK( activeStreamIndex_ == NO_ACTIVE_STREAM, diff --git a/src/torchcodec/_core/SingleStreamDecoder.h b/src/torchcodec/_core/SingleStreamDecoder.h index 33264e545..c0e26f35c 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.h +++ b/src/torchcodec/_core/SingleStreamDecoder.h @@ -318,7 +318,7 @@ class FORCE_PUBLIC_VISIBILITY SingleStreamDecoder { int streamIndex, AVMediaType mediaType, const StableDevice& device = StableDevice(kStableCPU), - const std::string_view deviceVariant = "ffmpeg", + std::string_view deviceVariant = "ffmpeg", std::optional ffmpegThreadCount = std::nullopt); // Returns the "best" stream index for a given media type. The "best" is diff --git a/src/torchcodec/_core/StreamOptions.h b/src/torchcodec/_core/StreamOptions.h index 6cab3c8e8..9780dbe05 100644 --- a/src/torchcodec/_core/StreamOptions.h +++ b/src/torchcodec/_core/StreamOptions.h @@ -9,7 +9,6 @@ #include #include #include -#include #include "StableABICompat.h" namespace facebook::torchcodec { @@ -44,8 +43,9 @@ struct VideoStreamOptions { // Note: This is not used for video encoding, because device is determined by // the device of the input frame tensor. StableDevice device = StableDevice(kStableCPU); - // Device variant (e.g., "ffmpeg", "beta", etc.) - std::string_view deviceVariant = "ffmpeg"; + // Device variant is stored in StreamInfo and reused after the op call + // returns, so it must own its contents. + std::string deviceVariant = "ffmpeg"; // Encoding options std::optional codec; diff --git a/test/VideoDecoderTest.cpp b/test/VideoDecoderTest.cpp index 346679f75..5aa3e4233 100644 --- a/test/VideoDecoderTest.cpp +++ b/test/VideoDecoderTest.cpp @@ -13,6 +13,7 @@ #include #include #include +#include #ifdef FBCODE_BUILD #include "tools/cxx/Resources.h" @@ -27,6 +28,12 @@ C10_DEFINE_bool( namespace facebook::torchcodec { +static_assert( + std::is_same_v); +static_assert(std::is_same_v< + decltype(DeviceInterfaceKey(kStableCPU).variant), + std::string>); + inline torch::stable::Tensor toStableTensor(const torch::Tensor& tensor) { torch::Tensor* p = new torch::Tensor(tensor); return torch::stable::Tensor(reinterpret_cast(p));