Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions src/torchcodec/_core/DeviceInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ DeviceInterfaceMap& getDeviceMap() {
return deviceMap;
}

std::string getDeviceTypeString(const std::string& device) {
std::string_view getDeviceTypeString(std::string_view device) {
size_t pos = device.find(':');
if (pos == std::string::npos) {
if (pos == std::string_view::npos) {
return device;
}
return device.substr(0, pos);
Expand All @@ -34,7 +34,7 @@ std::string getDeviceTypeString(const std::string& device) {
// TODO_STABLE_ABI: we might need to support more device types, i.e. those from
// https://github.com/pytorch/pytorch/blob/main/torch/headeronly/core/DeviceType.h
// Ideally we'd remove this helper?
StableDeviceType parseDeviceType(const std::string& deviceType) {
StableDeviceType parseDeviceType(std::string_view deviceType) {
if (deviceType == "cpu") {
return kStableCPU;
} else if (deviceType == "cuda") {
Expand Down Expand Up @@ -67,10 +67,10 @@ bool registerDeviceInterface(
}

void validateDeviceInterface(
const std::string& device,
const std::string& variant) {
std::string_view device,
std::string_view variant) {
std::scoped_lock lock(g_interface_mutex);
std::string deviceType = getDeviceTypeString(device);
std::string_view deviceType = getDeviceTypeString(device);

DeviceInterfaceMap& deviceMap = getDeviceMap();

Expand Down Expand Up @@ -98,7 +98,7 @@ void validateDeviceInterface(

std::unique_ptr<DeviceInterface> createDeviceInterface(
const StableDevice& device,
const std::string_view variant) {
std::string_view variant) {
DeviceInterfaceKey key(device.type(), variant);
std::scoped_lock lock(g_interface_mutex);
DeviceInterfaceMap& deviceMap = getDeviceMap();
Expand Down
13 changes: 8 additions & 5 deletions src/torchcodec/_core/DeviceInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <memory>
#include <stdexcept>
#include <string>
#include <string_view>
#include "FFMPEGCommon.h"
#include "Frame.h"
#include "StableABICompat.h"
Expand All @@ -21,7 +22,9 @@ namespace facebook::torchcodec {
// Key for device interface registration with device type + variant support
struct DeviceInterfaceKey {
StableDeviceType deviceType;
std::string_view variant = "ffmpeg"; // e.g., "ffmpeg", "beta", etc.
// This key is stored in the global device-interface registry, so it must own
// its variant string.
std::string variant = "ffmpeg"; // e.g., "ffmpeg", "beta", etc.

bool operator<(const DeviceInterfaceKey& other) const {
if (deviceType != other.deviceType) {
Expand All @@ -32,7 +35,7 @@ struct DeviceInterfaceKey {

explicit DeviceInterfaceKey(StableDeviceType type) : deviceType(type) {}

DeviceInterfaceKey(StableDeviceType type, const std::string_view& variant)
DeviceInterfaceKey(StableDeviceType type, std::string_view variant)
: deviceType(type), variant(variant) {}
};

Expand Down Expand Up @@ -177,12 +180,12 @@ TORCHCODEC_THIRD_PARTY_API bool registerDeviceInterface(
const CreateDeviceInterfaceFn createInterface);

FORCE_PUBLIC_VISIBILITY void validateDeviceInterface(
const std::string& device,
const std::string& variant);
std::string_view device,
std::string_view variant);

std::unique_ptr<DeviceInterface> createDeviceInterface(
const StableDevice& device,
const std::string_view variant = "ffmpeg");
std::string_view variant = "ffmpeg");

torch::stable::Tensor rgbAVFrameToTensor(const UniqueAVFrame& avFrame);

Expand Down
23 changes: 12 additions & 11 deletions src/torchcodec/_core/Encoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,13 +135,13 @@ AudioEncoder::~AudioEncoder() {
AudioEncoder::AudioEncoder(
const torch::stable::Tensor& samples,
int sampleRate,
std::string_view fileName,
const std::string& fileName,
const AudioStreamOptions& audioStreamOptions)
: samples_(validateSamples(samples)), inSampleRate_(sampleRate) {
setFFmpegLogLevel();
AVFormatContext* avFormatContext = nullptr;
int status = avformat_alloc_output_context2(
&avFormatContext, nullptr, nullptr, fileName.data());
&avFormatContext, nullptr, nullptr, fileName.c_str());

STD_TORCH_CHECK(
avFormatContext != nullptr,
Expand All @@ -152,7 +152,7 @@ AudioEncoder::AudioEncoder(
getFFMPEGErrorStringFromErrorCode(status));
avFormatContext_.reset(avFormatContext);

status = avio_open(&avFormatContext_->pb, fileName.data(), AVIO_FLAG_WRITE);
status = avio_open(&avFormatContext_->pb, fileName.c_str(), AVIO_FLAG_WRITE);
STD_TORCH_CHECK(
status >= 0,
"avio_open failed. The destination file is ",
Expand All @@ -166,7 +166,7 @@ AudioEncoder::AudioEncoder(
AudioEncoder::AudioEncoder(
const torch::stable::Tensor& samples,
int sampleRate,
std::string_view formatName,
const std::string& formatName,
std::unique_ptr<AVIOContextHolder> avioContextHolder,
const AudioStreamOptions& audioStreamOptions)
: samples_(validateSamples(samples)),
Expand All @@ -175,7 +175,7 @@ AudioEncoder::AudioEncoder(
setFFmpegLogLevel();
AVFormatContext* avFormatContext = nullptr;
int status = avformat_alloc_output_context2(
&avFormatContext, nullptr, formatName.data(), nullptr);
&avFormatContext, nullptr, formatName.c_str(), nullptr);

STD_TORCH_CHECK(
avFormatContext != nullptr,
Expand Down Expand Up @@ -684,15 +684,15 @@ VideoEncoder::~VideoEncoder() {
VideoEncoder::VideoEncoder(
const torch::stable::Tensor& frames,
double frameRate,
std::string_view fileName,
const std::string& fileName,
const VideoStreamOptions& videoStreamOptions)
: frames_(validateFrames(frames)), inFrameRate_(frameRate) {
setFFmpegLogLevel();

// Allocate output format context
AVFormatContext* avFormatContext = nullptr;
int status = avformat_alloc_output_context2(
&avFormatContext, nullptr, nullptr, fileName.data());
&avFormatContext, nullptr, nullptr, fileName.c_str());

STD_TORCH_CHECK(
avFormatContext != nullptr,
Expand All @@ -703,7 +703,7 @@ VideoEncoder::VideoEncoder(
getFFMPEGErrorStringFromErrorCode(status));
avFormatContext_.reset(avFormatContext);

status = avio_open(&avFormatContext_->pb, fileName.data(), AVIO_FLAG_WRITE);
status = avio_open(&avFormatContext_->pb, fileName.c_str(), AVIO_FLAG_WRITE);
STD_TORCH_CHECK(
status >= 0,
"avio_open failed. The destination file is ",
Expand All @@ -716,18 +716,19 @@ VideoEncoder::VideoEncoder(
VideoEncoder::VideoEncoder(
const torch::stable::Tensor& frames,
double frameRate,
std::string_view formatName,
const std::string& formatName,
std::unique_ptr<AVIOContextHolder> avioContextHolder,
const VideoStreamOptions& videoStreamOptions)
: frames_(validateFrames(frames)),
inFrameRate_(frameRate),
avioContextHolder_(std::move(avioContextHolder)) {
setFFmpegLogLevel();
// Map mkv -> matroska when used as format name
formatName = (formatName == "mkv") ? "matroska" : formatName;
const std::string normalizedFormatName =
(formatName == "mkv") ? "matroska" : formatName;
AVFormatContext* avFormatContext = nullptr;
int status = avformat_alloc_output_context2(
&avFormatContext, nullptr, formatName.data(), nullptr);
&avFormatContext, nullptr, normalizedFormatName.c_str(), nullptr);

STD_TORCH_CHECK(
avFormatContext != nullptr,
Expand Down
8 changes: 4 additions & 4 deletions src/torchcodec/_core/Encoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@ class FORCE_PUBLIC_VISIBILITY AudioEncoder {
AudioEncoder(
const torch::stable::Tensor& samples,
int sampleRate,
std::string_view fileName,
const std::string& fileName,
const AudioStreamOptions& audioStreamOptions);

AudioEncoder(
const torch::stable::Tensor& samples,
int sampleRate,
std::string_view formatName,
const std::string& formatName,
std::unique_ptr<AVIOContextHolder> avioContextHolder,
const AudioStreamOptions& audioStreamOptions);

Expand Down Expand Up @@ -144,13 +144,13 @@ class FORCE_PUBLIC_VISIBILITY VideoEncoder {
VideoEncoder(
const torch::stable::Tensor& frames,
double frameRate,
std::string_view fileName,
const std::string& fileName,
const VideoStreamOptions& videoStreamOptions);

VideoEncoder(
const torch::stable::Tensor& frames,
double frameRate,
std::string_view formatName,
const std::string& formatName,
std::unique_ptr<AVIOContextHolder> avioContextHolder,
const VideoStreamOptions& videoStreamOptions);

Expand Down
2 changes: 1 addition & 1 deletion src/torchcodec/_core/SingleStreamDecoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,7 @@ void SingleStreamDecoder::addStream(
int streamIndex,
AVMediaType mediaType,
const StableDevice& device,
const std::string_view deviceVariant,
std::string_view deviceVariant,
std::optional<int> ffmpegThreadCount) {
STD_TORCH_CHECK(
activeStreamIndex_ == NO_ACTIVE_STREAM,
Expand Down
2 changes: 1 addition & 1 deletion src/torchcodec/_core/SingleStreamDecoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ class FORCE_PUBLIC_VISIBILITY SingleStreamDecoder {
int streamIndex,
AVMediaType mediaType,
const StableDevice& device = StableDevice(kStableCPU),
const std::string_view deviceVariant = "ffmpeg",
std::string_view deviceVariant = "ffmpeg",
std::optional<int> ffmpegThreadCount = std::nullopt);

// Returns the "best" stream index for a given media type. The "best" is
Expand Down
6 changes: 3 additions & 3 deletions src/torchcodec/_core/StreamOptions.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
#include <map>
#include <optional>
#include <string>
#include <string_view>
#include "StableABICompat.h"

namespace facebook::torchcodec {
Expand Down Expand Up @@ -44,8 +43,9 @@ struct VideoStreamOptions {
// Note: This is not used for video encoding, because device is determined by
// the device of the input frame tensor.
StableDevice device = StableDevice(kStableCPU);
// Device variant (e.g., "ffmpeg", "beta", etc.)
std::string_view deviceVariant = "ffmpeg";
// Device variant is stored in StreamInfo and reused after the op call
// returns, so it must own its contents.
std::string deviceVariant = "ffmpeg";

// Encoding options
std::optional<std::string> codec;
Expand Down
7 changes: 7 additions & 0 deletions test/VideoDecoderTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include <filesystem>
#include <fstream>
#include <iostream>
#include <type_traits>

#ifdef FBCODE_BUILD
#include "tools/cxx/Resources.h"
Expand All @@ -27,6 +28,12 @@ C10_DEFINE_bool(

namespace facebook::torchcodec {

static_assert(
std::is_same_v<decltype(VideoStreamOptions{}.deviceVariant), std::string>);
static_assert(std::is_same_v<
decltype(DeviceInterfaceKey(kStableCPU).variant),
std::string>);

inline torch::stable::Tensor toStableTensor(const torch::Tensor& tensor) {
torch::Tensor* p = new torch::Tensor(tensor);
return torch::stable::Tensor(reinterpret_cast<AtenTensorHandle>(p));
Expand Down