From f4fc92eee795606e9d4a1f5919f9c0f26bfb9e44 Mon Sep 17 00:00:00 2001 From: xytpai Date: Thu, 8 Jan 2026 03:18:27 +0000 Subject: [PATCH 01/22] add rocjpeg support --- setup.py | 17 +- test/test_image.py | 16 +- .../csrc/io/image/cuda/decode_jpegs_cuda.cpp | 308 +++++++++++++++++- .../csrc/io/image/cuda/decode_jpegs_cuda.h | 32 ++ 4 files changed, 361 insertions(+), 12 deletions(-) diff --git a/setup.py b/setup.py index 6181007924e..5ad744e5061 100644 --- a/setup.py +++ b/setup.py @@ -22,6 +22,7 @@ USE_JPEG = os.getenv("TORCHVISION_USE_JPEG", "1") == "1" USE_WEBP = os.getenv("TORCHVISION_USE_WEBP", "1") == "1" USE_NVJPEG = os.getenv("TORCHVISION_USE_NVJPEG", "1") == "1" +USE_ROCJPEG = os.getenv("TORCHVISION_USE_ROCJPEG", "1") == "1" NVCC_FLAGS = os.getenv("NVCC_FLAGS", None) # Note: the GPU video decoding stuff used to be called "video codec", which # isn't an accurate or descriptive name considering there are at least 2 other @@ -52,6 +53,7 @@ print(f"{USE_JPEG = }") print(f"{USE_WEBP = }") print(f"{USE_NVJPEG = }") +print(f"{USE_ROCJPEG = }") print(f"{NVCC_FLAGS = }") print(f"{USE_CPU_VIDEO_DECODER = }") print(f"{USE_GPU_VIDEO_DECODER = }") @@ -350,18 +352,23 @@ def make_image_extension(): else: warnings.warn("Building torchvision without WEBP support") - if USE_NVJPEG and (torch.cuda.is_available() or FORCE_CUDA): + if (USE_NVJPEG or USE_ROCJPEG) and (torch.cuda.is_available() or FORCE_CUDA): nvjpeg_found = CUDA_HOME is not None and (Path(CUDA_HOME) / "include/nvjpeg.h").exists() - + rocjpeg_found = ROCM_HOME is not None and (Path(ROCM_HOME) / "include/rocjpeg/rocjpeg.h").exists() if nvjpeg_found: print("Building torchvision with NVJPEG image support") libraries.append("nvjpeg") define_macros += [("NVJPEG_FOUND", 1)] Extension = CUDAExtension + elif rocjpeg_found: + print("Building torchvision with ROCJPEG image support") + libraries.append("rocjpeg") + define_macros += [("ROCJPEG_FOUND", 1)] + Extension = CUDAExtension else: - warnings.warn("Building torchvision without NVJPEG support") - elif USE_NVJPEG: - warnings.warn("Building torchvision without NVJPEG support") + warnings.warn("Building torchvision without NVJPEG or ROCJPEG support") + elif (USE_NVJPEG or USE_ROCJPEG): + warnings.warn("Building torchvision without NVJPEG or ROCJPEG support") return Extension( name="torchvision.image", diff --git a/test/test_image.py b/test/test_image.py index b11dd67ca12..e30b5695241 100644 --- a/test/test_image.py +++ b/test/test_image.py @@ -406,8 +406,10 @@ def test_read_interlaced_png(): @needs_cuda -@pytest.mark.parametrize("mode", [ImageReadMode.UNCHANGED, ImageReadMode.GRAY, ImageReadMode.RGB]) -@pytest.mark.parametrize("scripted", (False, True)) +# @pytest.mark.parametrize("mode", [ImageReadMode.UNCHANGED, ImageReadMode.GRAY, ImageReadMode.RGB]) +@pytest.mark.parametrize("mode", [ImageReadMode.RGB]) +# @pytest.mark.parametrize("scripted", (False, True)) +@pytest.mark.parametrize("scripted", (False, )) def test_decode_jpegs_cuda(mode, scripted): encoded_images = [] for jpeg_path in get_images(IMAGE_ROOT, ".jpg"): @@ -415,15 +417,17 @@ def test_decode_jpegs_cuda(mode, scripted): continue encoded_image = read_file(jpeg_path) encoded_images.append(encoded_image) + encoded_images = encoded_images[:3] + # encoded_images = [encoded_images[0], encoded_images[2], encoded_images[1]] decoded_images_cpu = decode_jpeg(encoded_images, mode=mode) decode_fn = torch.jit.script(decode_jpeg) if scripted else decode_jpeg # test multithreaded decoding # in the current version we prevent this by using a lock but we still want to test it - num_workers = 10 + num_workers = 1 with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor: - futures = [executor.submit(decode_fn, encoded_images, mode, "cuda") for _ in range(num_workers)] + futures = [executor.submit(decode_fn, encoded_images, mode, "cuda:0") for _ in range(num_workers)] decoded_images_threaded = [future.result() for future in futures] assert len(decoded_images_threaded) == num_workers for decoded_images in decoded_images_threaded: @@ -431,7 +435,9 @@ def test_decode_jpegs_cuda(mode, scripted): for decoded_image_cuda, decoded_image_cpu in zip(decoded_images, decoded_images_cpu): assert decoded_image_cuda.shape == decoded_image_cpu.shape assert decoded_image_cuda.dtype == decoded_image_cpu.dtype == torch.uint8 - assert (decoded_image_cuda.cpu().float() - decoded_image_cpu.cpu().float()).abs().mean() < 2 + print(decoded_image_cuda.contiguous()) + print(decoded_image_cpu.contiguous().cpu()) + assert (decoded_image_cuda.contiguous().cpu().float() - decoded_image_cpu.cpu().float()).abs().mean() < 5 @needs_cuda diff --git a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp index 85aa6c760c1..9afb18abf32 100644 --- a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp +++ b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp @@ -1,5 +1,5 @@ #include "decode_jpegs_cuda.h" -#if !NVJPEG_FOUND +#if !NVJPEG_FOUND && !ROCJPEG_FOUND namespace vision { namespace image { std::vector decode_jpegs_cuda( @@ -11,8 +11,9 @@ std::vector decode_jpegs_cuda( } } // namespace image } // namespace vision +#endif -#else +#if NVJPEG_FOUND #include #include #include @@ -600,3 +601,306 @@ std::vector CUDAJpegDecoder::decode_images( } // namespace vision #endif + +#if ROCJPEG_FOUND + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace vision { +namespace image { + +std::mutex decoderMutex; +std::unique_ptr rocJpegDecoder; + +std::vector decode_jpegs_cuda( + const std::vector& encoded_images, + vision::image::ImageReadMode mode, + torch::Device device) { + C10_LOG_API_USAGE_ONCE( + "torchvision.csrc.io.image.cuda.decode_jpegs_cuda.decode_jpegs_cuda"); + + std::lock_guard lock(decoderMutex); + std::vector contig_images; + contig_images.reserve(encoded_images.size()); + + TORCH_CHECK( + device.is_cuda(), "Expected the device parameter to be a cuda device"); + + for (auto& encoded_image : encoded_images) { + TORCH_CHECK( + encoded_image.dtype() == torch::kU8, "Expected a torch.uint8 tensor"); + + TORCH_CHECK( + !encoded_image.is_cuda(), + "The input tensor must be on CPU when decoding with nvjpeg") + + TORCH_CHECK( + encoded_image.dim() == 1 && encoded_image.numel() > 0, + "Expected a non empty 1-dimensional tensor"); + + // nvjpeg requires images to be contiguous + if (encoded_image.is_contiguous()) { + contig_images.push_back(encoded_image); + } else { + contig_images.push_back(encoded_image.contiguous()); + } + } + + at::cuda::CUDAGuard device_guard(device); + + if (rocJpegDecoder == nullptr || device != rocJpegDecoder->target_device) { + if (rocJpegDecoder != nullptr) { + rocJpegDecoder.reset(new RocJpegDecoder(device)); + } else { + rocJpegDecoder = std::make_unique(device); + std::atexit([]() { rocJpegDecoder.reset(); }); + } + } + + RocJpegOutputFormat output_format; + + switch (mode) { + case vision::image::IMAGE_READ_MODE_UNCHANGED: + output_format = ROCJPEG_OUTPUT_NATIVE; + break; + case vision::image::IMAGE_READ_MODE_GRAY: + output_format = ROCJPEG_OUTPUT_Y; + break; + case vision::image::IMAGE_READ_MODE_RGB: + output_format = ROCJPEG_OUTPUT_RGB_PLANAR; + break; + default: + TORCH_CHECK( + false, "The provided mode is not supported for JPEG decoding on GPU"); + } + + try { + at::cuda::CUDAEvent event; + auto result = rocJpegDecoder->decode_images(contig_images, output_format); + auto current_stream{ + device.has_index() ? at::cuda::getCurrentCUDAStream( + rocJpegDecoder->original_device.index()) + : at::cuda::getCurrentCUDAStream()}; + event.record(rocJpegDecoder->stream); + event.block(current_stream); + return result; + } catch (const std::exception& e) { + if (typeid(e) != typeid(std::runtime_error)) { + TORCH_CHECK(false, "Error while decoding JPEG images: ", e.what()); + } else { + throw; + } + } +} + +RocJpegDecoder::RocJpegDecoder(const torch::Device& target_device) + : original_device{torch::kCUDA, c10::cuda::current_device()}, + target_device{target_device}, + stream{ + target_device.has_index() + ? at::cuda::getStreamFromPool(false, target_device.index()) + : at::cuda::getStreamFromPool(false)} { + int device_id = target_device.index(); + CHECK_HIP(hipSetDevice(device_id)); + RocJpegStatus status; + RocJpegBackend rocjpeg_backend = ROCJPEG_BACKEND_HARDWARE; + + status = rocJpegCreate(rocjpeg_backend, device_id, &rocjpeg_handle); + TORCH_CHECK( + status == ROCJPEG_STATUS_SUCCESS, + "Failed to initialize rocjpeg with hardware backend"); + + status = rocJpegStreamCreate(&rocjpeg_stream_handles[0]); + TORCH_CHECK( + status == ROCJPEG_STATUS_SUCCESS, "Failed to initialize rocjpeg stream"); + + status = rocJpegStreamCreate(&rocjpeg_stream_handles[1]); + TORCH_CHECK( + status == ROCJPEG_STATUS_SUCCESS, "Failed to initialize rocjpeg stream"); +} + +RocJpegDecoder::~RocJpegDecoder() { + rocJpegDestroy(rocjpeg_handle); + rocJpegStreamDestroy(rocjpeg_stream_handles[0]); + rocJpegStreamDestroy(rocjpeg_stream_handles[1]); +} + +static inline int align(int value, int alignment) { + return (value + alignment - 1) & ~(alignment - 1); +} + +std::vector RocJpegDecoder::decode_images( + const std::vector& encoded_images, + const RocJpegOutputFormat& output_format) { + /* + This function decodes a batch of jpeg bitstreams. + + Args: + - encoded_images (std::vector): a vector of tensors + containing the jpeg bitstreams to be decoded + - output_format (RocJpegOutputFormat): ROCJPEG_OUTPUT_RGB, ROCJPEG_OUTPUT_Y + or ROCJPEG_OUTPUT_NATIVE + - device (torch::Device): The desired CUDA device for the returned Tensors + + Returns: + - output_tensors (std::vector): a vector of Tensors + containing the decoded images + */ + + int num_images = encoded_images.size(); + std::vector output_tensors{num_images}; + RocJpegStatus rocjpeg_status; + cudaError_t cudaStatus; + + // baseline JPEGs can be batch decoded with hardware support + std::vector channels(num_images); + + cudaStatus = cudaStreamSynchronize(stream); + TORCH_CHECK( + cudaStatus == cudaSuccess, + "Failed to synchronize CUDA stream: ", + cudaStatus); + + constexpr int batch_size = 2; + RocJpegUtils rocjpeg_utils; + std::string chroma_sub_sampling = ""; + uint8_t num_components; + RocJpegChromaSubsampling temp_subsampling; + std::vector temp_widths(ROCJPEG_MAX_COMPONENT, 0); + std::vector temp_heights(ROCJPEG_MAX_COMPONENT, 0); + RocJpegDecodeParams decode_params = {}; + decode_params.output_format = output_format; + std::vector decode_params_batch; + decode_params_batch.resize(batch_size, decode_params); + std::vector output_images; + output_images.resize(batch_size); + int current_batch_size = 0; + uint32_t channel_sizes[ROCJPEG_MAX_COMPONENT] = {}; + uint32_t num_channels = 0; + std::vector> prior_channel_sizes; + prior_channel_sizes.resize( + batch_size, std::vector(ROCJPEG_MAX_COMPONENT, 0)); + + for (int i = 0; i < num_images; i += batch_size) { + int batch_end = std::min(i + batch_size, num_images); + for (int j = i; j < batch_end; j++) { + int index = j - i; + rocjpeg_status = rocJpegStreamParse( + (unsigned char*)encoded_images[j].data_ptr(), + encoded_images[j].numel(), + rocjpeg_stream_handles[index]); + if (rocjpeg_status != ROCJPEG_STATUS_SUCCESS) { + TORCH_CHECK( + false, + "ERROR: Failed to parse the input jpeg stream with ", + rocJpegGetErrorName(rocjpeg_status)); + } + CHECK_ROCJPEG(rocJpegGetImageInfo( + rocjpeg_handle, + rocjpeg_stream_handles[index], + &num_components, + &temp_subsampling, + temp_widths.data(), + temp_heights.data())); + rocjpeg_utils.GetChromaSubsamplingStr( + temp_subsampling, chroma_sub_sampling); + if (temp_widths[0] < 64 || temp_heights[0] < 64) { + TORCH_CHECK( + false, "The image resolution is not supported by VCN Hardware"); + } + if (temp_subsampling == ROCJPEG_CSS_411 || + temp_subsampling == ROCJPEG_CSS_UNKNOWN) { + TORCH_CHECK( + false, "The chroma sub-sampling is not supported by VCN Hardware"); + } + if (rocjpeg_utils.GetChannelPitchAndSizes( + decode_params_batch[index], + temp_subsampling, + temp_widths.data(), + temp_heights.data(), + num_channels, + output_images[index], + channel_sizes)) { + TORCH_CHECK(false, "ERROR: Failed to get the channel pitch and sizes"); + } + + uint32_t roi_width = decode_params_batch[index].crop_rectangle.right - decode_params_batch[index].crop_rectangle.left; + uint32_t roi_height = decode_params_batch[index].crop_rectangle.bottom - decode_params_batch[index].crop_rectangle.top; + bool is_roi_valid = (roi_width > 0 && roi_height > 0 && roi_width <= temp_widths[0] && roi_height <= temp_heights[0]) ? true : false; + std::cout << "is_roi_valid: " << is_roi_valid << "\n"; + uint32_t width = is_roi_valid ? align(roi_width, 16) : align(temp_widths[0], 16); + uint32_t height = is_roi_valid ? align(roi_height, 16) : align(temp_heights[0], 16); + auto output_tensor = torch::zeros( + {int64_t(num_channels), + int64_t(height), + int64_t(width)}, + torch::dtype(torch::kU8).device(target_device)); + channels[j] = num_channels; + + // for (int n = 0; n < (int)num_channels; n++) { + // output_images[current_batch_size].channel[n] = + // output_tensor[n].data_ptr(); + // } + + // allocate memory for each channel and reuse them if the sizes remain + // unchanged for a new image. + for (int c = 0; c < (int)num_channels; c++) { + output_images[index].channel[c] = output_tensor[c].data_ptr(); + } + // for (int c = (int)num_channels; c < ROCJPEG_MAX_COMPONENT; c++) { + // output_images[index].channel[c] = NULL; + // output_images[index].pitch[c] = 0; + // } + // output_tensors[j] = output_tensor; // output_tensor.narrow(1, 0, temp_heights[0]).narrow(2, 0, temp_widths[0]); + current_batch_size++; + output_tensors[j] = output_tensor.narrow(1, 0, temp_heights[0]).narrow(2, 0, temp_widths[0]); + } + + // if (current_batch_size == 2) { + if (current_batch_size > 0) { + CHECK_ROCJPEG(rocJpegDecodeBatched( + rocjpeg_handle, + rocjpeg_stream_handles, + current_batch_size, + decode_params_batch.data(), + output_images.data())); + } + + current_batch_size = 0; + } + + cudaStatus = cudaStreamSynchronize(stream); + TORCH_CHECK( + cudaStatus == cudaSuccess, + "Failed to synchronize CUDA stream: ", + cudaStatus); + + // prune extraneous channels from single channel images + if (output_format == ROCJPEG_OUTPUT_NATIVE) { + for (std::vector::size_type i = 0; i < output_tensors.size(); + ++i) { + if (channels[i] == 1) { + output_tensors[i] = output_tensors[i][0].unsqueeze(0).clone(); + } + } + } + + cudaDeviceSynchronize(); + return output_tensors; +} + +} // namespace image +} // namespace vision + +#endif diff --git a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.h b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.h index 6f72d9e35b2..a052e5e354e 100644 --- a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.h +++ b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.h @@ -4,6 +4,7 @@ #include "../common.h" #if NVJPEG_FOUND + #include #include @@ -42,4 +43,35 @@ class CUDAJpegDecoder { }; } // namespace image } // namespace vision + +#endif + +#if ROCJPEG_FOUND + +#include +#include +#include "rocjpeg_samples_utils.h" + +namespace vision { +namespace image { +class RocJpegDecoder { + public: + RocJpegDecoder(const torch::Device& target_device); + ~RocJpegDecoder(); + + std::vector decode_images( + const std::vector& encoded_images, + const RocJpegOutputFormat& output_format); + + const torch::Device original_device; + const torch::Device target_device; + const c10::cuda::CUDAStream stream; + + private: + RocJpegStreamHandle rocjpeg_stream_handles[2]; + RocJpegHandle rocjpeg_handle; +}; +} // namespace image +} // namespace vision + #endif From a371c3e57b4c838f3a707dc08e91c90d6e969f5c Mon Sep 17 00:00:00 2001 From: xytpai Date: Thu, 8 Jan 2026 03:21:08 +0000 Subject: [PATCH 02/22] update rocjpeg utils --- .../io/image/cuda/rocjpeg_samples_utils.h | 567 ++++++++++++++++++ 1 file changed, 567 insertions(+) create mode 100644 torchvision/csrc/io/image/cuda/rocjpeg_samples_utils.h diff --git a/torchvision/csrc/io/image/cuda/rocjpeg_samples_utils.h b/torchvision/csrc/io/image/cuda/rocjpeg_samples_utils.h new file mode 100644 index 00000000000..3a9595dcc4f --- /dev/null +++ b/torchvision/csrc/io/image/cuda/rocjpeg_samples_utils.h @@ -0,0 +1,567 @@ +/* +Copyright (c) 2024 - 2025 Advanced Micro Devices, Inc. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +*/ +#ifndef ROC_JPEG_SAMPLES_COMMON +#define ROC_JPEG_SAMPLES_COMMON +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#if __cplusplus >= 201703L && __has_include() + #include + namespace fs = std::filesystem; +#else + #include + namespace fs = std::experimental::filesystem; +#endif +#include +#include "rocjpeg/rocjpeg.h" + +#define CHECK_ROCJPEG(call) { \ + RocJpegStatus rocjpeg_status = (call); \ + if (rocjpeg_status != ROCJPEG_STATUS_SUCCESS) { \ + std::cerr << #call << " returned " << rocJpegGetErrorName(rocjpeg_status) << " at " << __FILE__ << ":" << __LINE__ << std::endl;\ + exit(1); \ + } \ +} + +#define CHECK_HIP(call) { \ + hipError_t hip_status = (call); \ + if (hip_status != hipSuccess) { \ + std::cout << "HIP failure: 'status: " << hipGetErrorName(hip_status) << "' at " << __FILE__ << ":" << __LINE__ << std::endl;\ + exit(1); \ + } \ +} + +/** + * @class RocJpegUtils + * @brief Utility class for rocJPEG samples. + * + * This class provides utility functions for rocJPEG samples, such as parsing command line arguments, + * getting file paths, initializing HIP device, getting chroma subsampling string, getting channel pitch and sizes, + * getting output file extension, and saving images. + */ +class RocJpegUtils { +public: + /** + * @brief Parses the command line arguments. + * + * This function parses the command line arguments and sets the corresponding variables. + * + * @param input_path The input path. + * @param output_file_path The output file path. + * @param save_images Flag indicating whether to save images. + * @param device_id The device ID. + * @param rocjpeg_backend The rocJPEG backend. + * @param decode_params The rocJPEG decode parameters. + * @param num_threads The number of threads. + * @param crop The crop rectangle. + * @param argc The number of command line arguments. + * @param argv The command line arguments. + */ + static void ParseCommandLine(std::string &input_path, std::string &output_file_path, bool &save_images, int &device_id, + RocJpegBackend &rocjpeg_backend, RocJpegDecodeParams &decode_params, int *num_threads, int *batch_size, int argc, char *argv[]) { + if(argc <= 1) { + ShowHelpAndExit("", num_threads != nullptr, batch_size != nullptr); + } + for (int i = 1; i < argc; i++) { + if (!strcmp(argv[i], "-h")) { + ShowHelpAndExit("", num_threads != nullptr, batch_size != nullptr); + } + if (!strcmp(argv[i], "-i")) { + if (++i == argc) { + ShowHelpAndExit("-i", num_threads != nullptr, batch_size != nullptr); + } + input_path = argv[i]; + continue; + } + if (!strcmp(argv[i], "-o")) { + if (++i == argc) { + ShowHelpAndExit("-o", num_threads != nullptr, batch_size != nullptr); + } + output_file_path = argv[i]; + save_images = true; + continue; + } + if (!strcmp(argv[i], "-d")) { + if (++i == argc) { + ShowHelpAndExit("-d", num_threads != nullptr, batch_size != nullptr); + } + device_id = atoi(argv[i]); + continue; + } + if (!strcmp(argv[i], "-be")) { + if (++i == argc) { + ShowHelpAndExit("-be", num_threads != nullptr, batch_size != nullptr); + } + rocjpeg_backend = static_cast(atoi(argv[i])); + continue; + } + if (!strcmp(argv[i], "-fmt")) { + if (++i == argc) { + ShowHelpAndExit("-fmt", num_threads != nullptr, batch_size != nullptr); + } + std::string selected_output_format = argv[i]; + if (selected_output_format == "native") { + decode_params.output_format = ROCJPEG_OUTPUT_NATIVE; + } else if (selected_output_format == "yuv_planar") { + decode_params.output_format = ROCJPEG_OUTPUT_YUV_PLANAR; + } else if (selected_output_format == "y") { + decode_params.output_format = ROCJPEG_OUTPUT_Y; + } else if (selected_output_format == "rgb") { + decode_params.output_format = ROCJPEG_OUTPUT_RGB; + } else if (selected_output_format == "rgb_planar") { + decode_params.output_format = ROCJPEG_OUTPUT_RGB_PLANAR; + } else { + ShowHelpAndExit(argv[i], num_threads != nullptr); + } + continue; + } + if (!strcmp(argv[i], "-t")) { + if (++i == argc) { + ShowHelpAndExit("-t", num_threads != nullptr, batch_size != nullptr); + } + if (num_threads != nullptr) { + *num_threads = atoi(argv[i]); + if (*num_threads <= 0 || *num_threads > 32) { + ShowHelpAndExit(argv[i], num_threads != nullptr, batch_size != nullptr); + } + } + continue; + } + if (!strcmp(argv[i], "-b")) { + if (++i == argc) { + ShowHelpAndExit("-b", num_threads != nullptr, batch_size != nullptr); + } + if (batch_size != nullptr) + *batch_size = atoi(argv[i]); + continue; + } + if (!strcmp(argv[i], "-crop")) { + if (++i == argc || 4 != sscanf(argv[i], "%hd,%hd,%hd,%hd", &decode_params.crop_rectangle.left, &decode_params.crop_rectangle.top, &decode_params.crop_rectangle.right, &decode_params.crop_rectangle.bottom)) { + ShowHelpAndExit("-crop"); + } + if ((&decode_params.crop_rectangle.right - &decode_params.crop_rectangle.left) % 2 == 1 || (&decode_params.crop_rectangle.bottom - &decode_params.crop_rectangle.top) % 2 == 1) { + std::cout << "output crop rectangle must have width and height of even numbers" << std::endl; + exit(1); + } + continue; + } + ShowHelpAndExit(argv[i], num_threads != nullptr, batch_size != nullptr); + } + } + + /** + * Checks if a file is a JPEG file. + * + * @param filePath The path to the file to be checked. + * @return True if the file is a JPEG file, false otherwise. + */ + static bool IsJPEG(const std::string& filePath) { + std::ifstream file(filePath, std::ios::binary); + if (!file.is_open()) { + std::cerr << "Failed to open file: " << filePath << std::endl; + return false; + } + + unsigned char buffer[2]; + file.read(reinterpret_cast(buffer), 2); + file.close(); + + // The first two bytes of every JPEG stream are always 0xFFD8, which represents the Start of Image (SOI) marker. + return buffer[0] == 0xFF && buffer[1] == 0xD8; + } + + /** + * @brief Gets the file paths. + * + * This function gets the file paths based on the input path and sets the corresponding variables. + * + * @param input_path The input path. + * @param file_paths The vector to store the file paths. + * @param is_dir Flag indicating whether the input path is a directory. + * @param is_file Flag indicating whether the input path is a file. + * @return True if successful, false otherwise. + */ + static bool GetFilePaths(std::string &input_path, std::vector &file_paths, bool &is_dir, bool &is_file) { + std::cout << "Reading images from disk, please wait!" << std::endl; + if (!fs::exists(input_path)) { + std::cerr << "ERROR: the input path does not exist!" << std::endl; + return false; + } + is_dir = fs::is_directory(input_path); + is_file = fs::is_regular_file(input_path); + if (is_dir) { + for (const auto &entry : fs::recursive_directory_iterator(input_path)) { + if (fs::is_regular_file(entry) && IsJPEG(entry.path().string())) { + file_paths.push_back(entry.path().string()); + } + } + } else if (is_file && IsJPEG(input_path)) { + file_paths.push_back(input_path); + } else { + std::cerr << "ERROR: the input path does not contain JPEG files!" << std::endl; + return false; + } + return true; + } + + /** + * @brief Initializes the HIP device. + * + * This function initializes the HIP device with the specified device ID. + * + * @param device_id The device ID. + * @return True if successful, false otherwise. + */ + static bool InitHipDevice(int device_id) { + int num_devices; + hipDeviceProp_t hip_dev_prop; + CHECK_HIP(hipGetDeviceCount(&num_devices)); + if (num_devices < 1) { + std::cerr << "ERROR: didn't find any GPU!" << std::endl; + return false; + } + if (device_id >= num_devices) { + std::cerr << "ERROR: the requested device_id is not found!" << std::endl; + return false; + } + CHECK_HIP(hipSetDevice(device_id)); + CHECK_HIP(hipGetDeviceProperties(&hip_dev_prop, device_id)); + + std::cout << "Using GPU device " << device_id << ": " << hip_dev_prop.name << "[" << hip_dev_prop.gcnArchName << "] on PCI bus " << + std::setfill('0') << std::setw(2) << std::right << std::hex << hip_dev_prop.pciBusID << ":" << std::setfill('0') << std::setw(2) << + std::right << std::hex << hip_dev_prop.pciDomainID << "." << hip_dev_prop.pciDeviceID << std::dec << std::endl; + + return true; + } + + /** + * @brief Gets the chroma subsampling string. + * + * This function gets the chroma subsampling string based on the specified subsampling value. + * + * @param subsampling The chroma subsampling value. + * @param chroma_sub_sampling The string to store the chroma subsampling. + */ + void GetChromaSubsamplingStr(RocJpegChromaSubsampling subsampling, std::string &chroma_sub_sampling) { + switch (subsampling) { + case ROCJPEG_CSS_444: + chroma_sub_sampling = "YUV 4:4:4"; + break; + case ROCJPEG_CSS_440: + chroma_sub_sampling = "YUV 4:4:0"; + break; + case ROCJPEG_CSS_422: + chroma_sub_sampling = "YUV 4:2:2"; + break; + case ROCJPEG_CSS_420: + chroma_sub_sampling = "YUV 4:2:0"; + break; + case ROCJPEG_CSS_411: + chroma_sub_sampling = "YUV 4:1:1"; + break; + case ROCJPEG_CSS_400: + chroma_sub_sampling = "YUV 4:0:0"; + break; + case ROCJPEG_CSS_UNKNOWN: + chroma_sub_sampling = "UNKNOWN"; + break; + default: + chroma_sub_sampling = ""; + break; + } + } + + /** + * @brief Gets the channel pitch and sizes. + * + * This function gets the channel pitch and sizes based on the specified output format, chroma subsampling, + * output image, and channel sizes. + * + * @param decode_params The decode parameters that specify the output format and crop rectangle. + * @param subsampling The chroma subsampling. + * @param widths The array to store the channel widths. + * @param heights The array to store the channel heights. + * @param num_channels The number of channels. + * @param output_image The output image. + * @param channel_sizes The array to store the channel sizes. + * @return The channel pitch. + */ + int GetChannelPitchAndSizes(RocJpegDecodeParams decode_params, RocJpegChromaSubsampling subsampling, uint32_t *widths, uint32_t *heights, + uint32_t &num_channels, RocJpegImage &output_image, uint32_t *channel_sizes) { + + bool is_roi_valid = false; + uint32_t roi_width; + uint32_t roi_height; + roi_width = decode_params.crop_rectangle.right - decode_params.crop_rectangle.left; + roi_height = decode_params.crop_rectangle.bottom - decode_params.crop_rectangle.top; + if (roi_width > 0 && roi_height > 0 && roi_width <= widths[0] && roi_height <= heights[0]) { + is_roi_valid = true; + } + switch (decode_params.output_format) { + case ROCJPEG_OUTPUT_NATIVE: + switch (subsampling) { + case ROCJPEG_CSS_444: + num_channels = 3; + output_image.pitch[2] = output_image.pitch[1] = output_image.pitch[0] = is_roi_valid ? align(roi_width, mem_alignment) : align(widths[0], mem_alignment); + channel_sizes[2] = channel_sizes[1] = channel_sizes[0] = output_image.pitch[0] * (is_roi_valid ? align(roi_height, mem_alignment) : align(heights[0], mem_alignment)); + break; + case ROCJPEG_CSS_440: + num_channels = 3; + output_image.pitch[2] = output_image.pitch[1] = output_image.pitch[0] = is_roi_valid ? align(roi_width, mem_alignment) : align(widths[0], mem_alignment); + channel_sizes[0] = output_image.pitch[0] * (is_roi_valid ? align(roi_height, mem_alignment) : align(heights[0], mem_alignment)); + channel_sizes[2] = channel_sizes[1] = output_image.pitch[0] * (is_roi_valid ? align(roi_height >> 1, mem_alignment) : align(heights[0] >> 1, mem_alignment)); + break; + case ROCJPEG_CSS_422: + num_channels = 1; + output_image.pitch[0] = (is_roi_valid ? align(roi_width, mem_alignment) : align(widths[0], mem_alignment)) * 2; + channel_sizes[0] = output_image.pitch[0] * (is_roi_valid ? align(roi_height, mem_alignment) : align(heights[0], mem_alignment)); + break; + case ROCJPEG_CSS_420: + num_channels = 2; + output_image.pitch[1] = output_image.pitch[0] = is_roi_valid ? align(roi_width, mem_alignment) : align(widths[0], mem_alignment); + channel_sizes[0] = output_image.pitch[0] * (is_roi_valid ? align(roi_height, mem_alignment) : align(heights[0], mem_alignment)); + channel_sizes[1] = output_image.pitch[1] * (is_roi_valid ? align(roi_height >> 1, mem_alignment) : align(heights[0] >> 1, mem_alignment)); + break; + case ROCJPEG_CSS_400: + num_channels = 1; + output_image.pitch[0] = is_roi_valid ? align(roi_width, mem_alignment) : align(widths[0], mem_alignment); + channel_sizes[0] = output_image.pitch[0] * (is_roi_valid ? align(roi_height, mem_alignment) : align(heights[0], mem_alignment)); + break; + default: + std::cout << "Unknown chroma subsampling!" << std::endl; + return EXIT_FAILURE; + } + break; + case ROCJPEG_OUTPUT_YUV_PLANAR: + if (subsampling == ROCJPEG_CSS_400) { + num_channels = 1; + output_image.pitch[0] = is_roi_valid ? align(roi_width, mem_alignment) : align(widths[0], mem_alignment); + channel_sizes[0] = output_image.pitch[0] * (is_roi_valid ? align(roi_height, mem_alignment) : align(heights[0], mem_alignment)); + } else { + num_channels = 3; + output_image.pitch[0] = is_roi_valid ? align(roi_width, mem_alignment) : align(widths[0], mem_alignment); + output_image.pitch[1] = is_roi_valid ? align(roi_width, mem_alignment) : align(widths[1], mem_alignment); + output_image.pitch[2] = is_roi_valid ? align(roi_width, mem_alignment) : align(widths[2], mem_alignment); + channel_sizes[0] = output_image.pitch[0] * (is_roi_valid ? align(roi_height, mem_alignment) : align(heights[0], mem_alignment)); + channel_sizes[1] = output_image.pitch[1] * (is_roi_valid ? align(roi_height, mem_alignment) : align(heights[1], mem_alignment)); + channel_sizes[2] = output_image.pitch[2] * (is_roi_valid ? align(roi_height, mem_alignment) : align(heights[2], mem_alignment)); + } + break; + case ROCJPEG_OUTPUT_Y: + num_channels = 1; + output_image.pitch[0] = is_roi_valid ? align(roi_width, mem_alignment) : align(widths[0], mem_alignment); + channel_sizes[0] = output_image.pitch[0] * (is_roi_valid ? align(roi_height, mem_alignment) : align(heights[0], mem_alignment)); + break; + case ROCJPEG_OUTPUT_RGB: + num_channels = 1; + output_image.pitch[0] = (is_roi_valid ? align(roi_width, mem_alignment) : align(widths[0], mem_alignment)) * 3; + channel_sizes[0] = output_image.pitch[0] * (is_roi_valid ? align(roi_height, mem_alignment) : align(heights[0], mem_alignment)); + break; + case ROCJPEG_OUTPUT_RGB_PLANAR: + num_channels = 3; + output_image.pitch[2] = output_image.pitch[1] = output_image.pitch[0] = is_roi_valid ? align(roi_width, mem_alignment) : align(widths[0], mem_alignment); + channel_sizes[2] = channel_sizes[1] = channel_sizes[0] = output_image.pitch[0] * (is_roi_valid ? align(roi_height, mem_alignment) : align(heights[0], mem_alignment)); + break; + default: + std::cout << "Unknown output format!" << std::endl; + return EXIT_FAILURE; + } + return EXIT_SUCCESS; + } + + /** + * @brief Gets the output file extension. + * + * This function gets the output file extension based on the specified output format, base file name, + * image width, image height, and file name for saving. + * + * @param output_format The output format. + * @param base_file_name The base file name. + * @param image_width The image width. + * @param image_height The image height. + * @param file_name_for_saving The string to store the file name for saving. + */ + void GetOutputFileExt(RocJpegOutputFormat output_format, std::string &base_file_name, uint32_t image_width, uint32_t image_height, RocJpegChromaSubsampling subsampling, std::string &file_name_for_saving) { + std::string file_extension; + std::string::size_type const p(base_file_name.find_last_of('.')); + std::string file_name_no_ext = base_file_name.substr(0, p); + std::string format_description = ""; + switch (output_format) { + case ROCJPEG_OUTPUT_NATIVE: + file_extension = "yuv"; + switch (subsampling) { + case ROCJPEG_CSS_444: + format_description = "444"; + break; + case ROCJPEG_CSS_440: + format_description = "440"; + break; + case ROCJPEG_CSS_422: + format_description = "422_yuyv"; + break; + case ROCJPEG_CSS_420: + format_description = "nv12"; + break; + case ROCJPEG_CSS_400: + format_description = "400"; + break; + default: + std::cout << "Unknown chroma subsampling!" << std::endl; + return; + } + break; + case ROCJPEG_OUTPUT_YUV_PLANAR: + file_extension = "yuv"; + format_description = "planar"; + break; + case ROCJPEG_OUTPUT_Y: + file_extension = "yuv"; + format_description = "400"; + break; + case ROCJPEG_OUTPUT_RGB: + file_extension = "rgb"; + format_description = "packed"; + break; + case ROCJPEG_OUTPUT_RGB_PLANAR: + file_extension = "rgb"; + format_description = "planar"; + break; + default: + file_extension = ""; + break; + } + file_name_for_saving += "//" + file_name_no_ext + "_" + std::to_string(image_width) + "x" + + std::to_string(image_height) + "_" + format_description + "." + file_extension; + } + +private: + static const int mem_alignment = 16; + /** + * @brief Shows the help message and exits. + * + * This function shows the help message and exits the program. + * + * @param option The option to display in the help message (optional). + * @param show_threads Flag indicating whether to show the number of threads in the help message. + */ + static void ShowHelpAndExit(const char *option = nullptr, bool show_threads = false, bool show_batch_size = false) { + std::cout << "Options:\n" + "-i [input path] - input path to a single JPEG image or a directory containing JPEG images - [required]\n" + "-be [backend] - select rocJPEG backend (0 for hardware-accelerated JPEG decoding using VCN,\n" + " 1 for hybrid JPEG decoding using CPU and GPU HIP kernels (currently not supported)) [optional - default: 0]\n" + "-fmt [output format] - select rocJPEG output format for decoding, one of the [native, yuv_planar, y, rgb, rgb_planar] - [optional - default: native]\n" + "-o [output path] - path to an output file or a path to an existing directory - write decoded images to a file or an existing directory based on selected output format - [optional]\n" + "-crop [crop rectangle] - crop rectangle for output in a comma-separated format: left,top,right,bottom - [optional]\n" + "-d [device id] - specify the GPU device id for the desired device (use 0 for the first device, 1 for the second device, and so on) [optional - default: 0]\n"; + if (show_threads) { + std::cout << "-t [threads] - number of threads (<= 32) for parallel JPEG decoding - [optional - default: 1]\n"; + } + if (show_batch_size) { + std::cout << "-b [batch_size] - decode images from input by batches of a specified size - [optional - default: 1]\n"; + } + exit(0); + } + /** + * @brief Aligns a value to a specified alignment. + * + * This function takes a value and aligns it to the specified alignment. It returns the aligned value. + * + * @param value The value to be aligned. + * @param alignment The alignment value. + * @return The aligned value. + */ + static inline int align(int value, int alignment) { + return (value + alignment - 1) & ~(alignment - 1); + } +}; + +class ThreadPool { + public: + ThreadPool(int nthreads) : shutdown_(false) { + // Create the specified number of threads + threads_.reserve(nthreads); + for (int i = 0; i < nthreads; ++i) + threads_.emplace_back(std::bind(&ThreadPool::ThreadEntry, this, i)); + } + + ~ThreadPool() {} + + void JoinThreads() { + { + // Unblock any threads and tell them to stop + std::unique_lock lock(mutex_); + shutdown_ = true; + cond_var_.notify_all(); + } + + // Wait for all threads to stop + for (auto& thread : threads_) + thread.join(); + } + + void ExecuteJob(std::function func) { + // Place a job on the queue and unblock a thread + std::unique_lock lock(mutex_); + decode_jobs_queue_.emplace(std::move(func)); + cond_var_.notify_one(); + } + + protected: + void ThreadEntry(int i) { + std::function execute_decode_job; + + while (true) { + { + std::unique_lock lock(mutex_); + cond_var_.wait(lock, [&] {return shutdown_ || !decode_jobs_queue_.empty();}); + if (decode_jobs_queue_.empty()) { + // No jobs to do; shutting down + return; + } + + execute_decode_job = std::move(decode_jobs_queue_.front()); + decode_jobs_queue_.pop(); + } + + // Execute the decode job without holding any locks + execute_decode_job(); + } + } + + std::mutex mutex_; + std::condition_variable cond_var_; + bool shutdown_; + std::queue> decode_jobs_queue_; + std::vector threads_; +}; + +#endif //ROC_JPEG_SAMPLES_COMMON From e4c4fd0f1b9eeb5218042437ea15d7cb537abff8 Mon Sep 17 00:00:00 2001 From: Yutao Xu Date: Thu, 8 Jan 2026 11:29:14 +0800 Subject: [PATCH 03/22] rm cout --- torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp index 9afb18abf32..d0c2f2cdacf 100644 --- a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp +++ b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp @@ -838,7 +838,6 @@ std::vector RocJpegDecoder::decode_images( uint32_t roi_width = decode_params_batch[index].crop_rectangle.right - decode_params_batch[index].crop_rectangle.left; uint32_t roi_height = decode_params_batch[index].crop_rectangle.bottom - decode_params_batch[index].crop_rectangle.top; bool is_roi_valid = (roi_width > 0 && roi_height > 0 && roi_width <= temp_widths[0] && roi_height <= temp_heights[0]) ? true : false; - std::cout << "is_roi_valid: " << is_roi_valid << "\n"; uint32_t width = is_roi_valid ? align(roi_width, 16) : align(temp_widths[0], 16); uint32_t height = is_roi_valid ? align(roi_height, 16) : align(temp_heights[0], 16); auto output_tensor = torch::zeros( From 3d9041c4012f2e265a8da057fc7655b6680c2aed Mon Sep 17 00:00:00 2001 From: xytpai Date: Fri, 16 Jan 2026 09:53:44 +0000 Subject: [PATCH 04/22] refine code --- setup.py | 2 +- test/test_image.py | 16 +- .../csrc/io/image/cuda/decode_jpegs_cuda.cpp | 233 +++++++++++++++--- .../csrc/io/image/cuda/decode_jpegs_cuda.h | 22 +- 4 files changed, 232 insertions(+), 41 deletions(-) diff --git a/setup.py b/setup.py index 5ad744e5061..4b9559eb630 100644 --- a/setup.py +++ b/setup.py @@ -22,7 +22,7 @@ USE_JPEG = os.getenv("TORCHVISION_USE_JPEG", "1") == "1" USE_WEBP = os.getenv("TORCHVISION_USE_WEBP", "1") == "1" USE_NVJPEG = os.getenv("TORCHVISION_USE_NVJPEG", "1") == "1" -USE_ROCJPEG = os.getenv("TORCHVISION_USE_ROCJPEG", "1") == "1" +USE_ROCJPEG = os.getenv("TORCHVISION_USE_ROCJPEG", "0") == "1" NVCC_FLAGS = os.getenv("NVCC_FLAGS", None) # Note: the GPU video decoding stuff used to be called "video codec", which # isn't an accurate or descriptive name considering there are at least 2 other diff --git a/test/test_image.py b/test/test_image.py index e30b5695241..b11dd67ca12 100644 --- a/test/test_image.py +++ b/test/test_image.py @@ -406,10 +406,8 @@ def test_read_interlaced_png(): @needs_cuda -# @pytest.mark.parametrize("mode", [ImageReadMode.UNCHANGED, ImageReadMode.GRAY, ImageReadMode.RGB]) -@pytest.mark.parametrize("mode", [ImageReadMode.RGB]) -# @pytest.mark.parametrize("scripted", (False, True)) -@pytest.mark.parametrize("scripted", (False, )) +@pytest.mark.parametrize("mode", [ImageReadMode.UNCHANGED, ImageReadMode.GRAY, ImageReadMode.RGB]) +@pytest.mark.parametrize("scripted", (False, True)) def test_decode_jpegs_cuda(mode, scripted): encoded_images = [] for jpeg_path in get_images(IMAGE_ROOT, ".jpg"): @@ -417,17 +415,15 @@ def test_decode_jpegs_cuda(mode, scripted): continue encoded_image = read_file(jpeg_path) encoded_images.append(encoded_image) - encoded_images = encoded_images[:3] - # encoded_images = [encoded_images[0], encoded_images[2], encoded_images[1]] decoded_images_cpu = decode_jpeg(encoded_images, mode=mode) decode_fn = torch.jit.script(decode_jpeg) if scripted else decode_jpeg # test multithreaded decoding # in the current version we prevent this by using a lock but we still want to test it - num_workers = 1 + num_workers = 10 with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor: - futures = [executor.submit(decode_fn, encoded_images, mode, "cuda:0") for _ in range(num_workers)] + futures = [executor.submit(decode_fn, encoded_images, mode, "cuda") for _ in range(num_workers)] decoded_images_threaded = [future.result() for future in futures] assert len(decoded_images_threaded) == num_workers for decoded_images in decoded_images_threaded: @@ -435,9 +431,7 @@ def test_decode_jpegs_cuda(mode, scripted): for decoded_image_cuda, decoded_image_cpu in zip(decoded_images, decoded_images_cpu): assert decoded_image_cuda.shape == decoded_image_cpu.shape assert decoded_image_cuda.dtype == decoded_image_cpu.dtype == torch.uint8 - print(decoded_image_cuda.contiguous()) - print(decoded_image_cpu.contiguous().cpu()) - assert (decoded_image_cuda.contiguous().cpu().float() - decoded_image_cpu.cpu().float()).abs().mean() < 5 + assert (decoded_image_cuda.cpu().float() - decoded_image_cpu.cpu().float()).abs().mean() < 2 @needs_cuda diff --git a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp index d0c2f2cdacf..9b974b7dc03 100644 --- a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp +++ b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp @@ -671,18 +671,13 @@ std::vector decode_jpegs_cuda( RocJpegOutputFormat output_format; switch (mode) { - case vision::image::IMAGE_READ_MODE_UNCHANGED: - output_format = ROCJPEG_OUTPUT_NATIVE; - break; - case vision::image::IMAGE_READ_MODE_GRAY: - output_format = ROCJPEG_OUTPUT_Y; - break; case vision::image::IMAGE_READ_MODE_RGB: output_format = ROCJPEG_OUTPUT_RGB_PLANAR; break; default: TORCH_CHECK( - false, "The provided mode is not supported for JPEG decoding on GPU"); + false, + "The provided mode is not supported for ROCJPEG decoding on GPU"); } try { @@ -736,8 +731,184 @@ RocJpegDecoder::~RocJpegDecoder() { rocJpegStreamDestroy(rocjpeg_stream_handles[1]); } +static constexpr int mem_alignment = 16; + static inline int align(int value, int alignment) { - return (value + alignment - 1) & ~(alignment - 1); + return (value + alignment - 1) & ~(alignment - 1); +} + +void getChromaSubsamplingStr( + RocJpegChromaSubsampling subsampling, + std::string& chroma_sub_sampling) { + switch (subsampling) { + case ROCJPEG_CSS_444: + chroma_sub_sampling = "YUV 4:4:4"; + break; + case ROCJPEG_CSS_440: + chroma_sub_sampling = "YUV 4:4:0"; + break; + case ROCJPEG_CSS_422: + chroma_sub_sampling = "YUV 4:2:2"; + break; + case ROCJPEG_CSS_420: + chroma_sub_sampling = "YUV 4:2:0"; + break; + case ROCJPEG_CSS_411: + chroma_sub_sampling = "YUV 4:1:1"; + break; + case ROCJPEG_CSS_400: + chroma_sub_sampling = "YUV 4:0:0"; + break; + case ROCJPEG_CSS_UNKNOWN: + chroma_sub_sampling = "UNKNOWN"; + break; + default: + chroma_sub_sampling = ""; + break; + } +} + +int getChannelPitchAndSizes( + RocJpegDecodeParams decode_params, + RocJpegChromaSubsampling subsampling, + uint32_t* widths, + uint32_t* heights, + uint32_t& num_channels, + RocJpegImage& output_image, + uint32_t* channel_sizes) { + bool is_roi_valid = false; + uint32_t roi_width; + uint32_t roi_height; + roi_width = + decode_params.crop_rectangle.right - decode_params.crop_rectangle.left; + roi_height = + decode_params.crop_rectangle.bottom - decode_params.crop_rectangle.top; + if (roi_width > 0 && roi_height > 0 && roi_width <= widths[0] && + roi_height <= heights[0]) { + is_roi_valid = true; + } + switch (decode_params.output_format) { + case ROCJPEG_OUTPUT_NATIVE: + switch (subsampling) { + case ROCJPEG_CSS_444: + num_channels = 3; + output_image.pitch[2] = output_image.pitch[1] = + output_image.pitch[0] = + is_roi_valid ? align(roi_width, mem_alignment) + : align(widths[0], mem_alignment); + channel_sizes[2] = channel_sizes[1] = channel_sizes[0] = + output_image.pitch[0] * + (is_roi_valid ? align(roi_height, mem_alignment) + : align(heights[0], mem_alignment)); + break; + case ROCJPEG_CSS_440: + num_channels = 3; + output_image.pitch[2] = output_image.pitch[1] = + output_image.pitch[0] = + is_roi_valid ? align(roi_width, mem_alignment) + : align(widths[0], mem_alignment); + channel_sizes[0] = output_image.pitch[0] * + (is_roi_valid ? align(roi_height, mem_alignment) + : align(heights[0], mem_alignment)); + channel_sizes[2] = channel_sizes[1] = output_image.pitch[0] * + (is_roi_valid ? align(roi_height >> 1, mem_alignment) + : align(heights[0] >> 1, mem_alignment)); + break; + case ROCJPEG_CSS_422: + num_channels = 1; + output_image.pitch[0] = + (is_roi_valid ? align(roi_width, mem_alignment) + : align(widths[0], mem_alignment)) * + 2; + channel_sizes[0] = output_image.pitch[0] * + (is_roi_valid ? align(roi_height, mem_alignment) + : align(heights[0], mem_alignment)); + break; + case ROCJPEG_CSS_420: + num_channels = 2; + output_image.pitch[1] = output_image.pitch[0] = is_roi_valid + ? align(roi_width, mem_alignment) + : align(widths[0], mem_alignment); + channel_sizes[0] = output_image.pitch[0] * + (is_roi_valid ? align(roi_height, mem_alignment) + : align(heights[0], mem_alignment)); + channel_sizes[1] = output_image.pitch[1] * + (is_roi_valid ? align(roi_height >> 1, mem_alignment) + : align(heights[0] >> 1, mem_alignment)); + break; + case ROCJPEG_CSS_400: + num_channels = 1; + output_image.pitch[0] = is_roi_valid + ? align(roi_width, mem_alignment) + : align(widths[0], mem_alignment); + channel_sizes[0] = output_image.pitch[0] * + (is_roi_valid ? align(roi_height, mem_alignment) + : align(heights[0], mem_alignment)); + break; + default: + std::cout << "Unknown chroma subsampling!" << std::endl; + return EXIT_FAILURE; + } + break; + case ROCJPEG_OUTPUT_YUV_PLANAR: + if (subsampling == ROCJPEG_CSS_400) { + num_channels = 1; + output_image.pitch[0] = is_roi_valid ? align(roi_width, mem_alignment) + : align(widths[0], mem_alignment); + channel_sizes[0] = output_image.pitch[0] * + (is_roi_valid ? align(roi_height, mem_alignment) + : align(heights[0], mem_alignment)); + } else { + num_channels = 3; + output_image.pitch[0] = is_roi_valid ? align(roi_width, mem_alignment) + : align(widths[0], mem_alignment); + output_image.pitch[1] = is_roi_valid ? align(roi_width, mem_alignment) + : align(widths[1], mem_alignment); + output_image.pitch[2] = is_roi_valid ? align(roi_width, mem_alignment) + : align(widths[2], mem_alignment); + channel_sizes[0] = output_image.pitch[0] * + (is_roi_valid ? align(roi_height, mem_alignment) + : align(heights[0], mem_alignment)); + channel_sizes[1] = output_image.pitch[1] * + (is_roi_valid ? align(roi_height, mem_alignment) + : align(heights[1], mem_alignment)); + channel_sizes[2] = output_image.pitch[2] * + (is_roi_valid ? align(roi_height, mem_alignment) + : align(heights[2], mem_alignment)); + } + break; + case ROCJPEG_OUTPUT_Y: + num_channels = 1; + output_image.pitch[0] = is_roi_valid ? align(roi_width, mem_alignment) + : align(widths[0], mem_alignment); + channel_sizes[0] = output_image.pitch[0] * + (is_roi_valid ? align(roi_height, mem_alignment) + : align(heights[0], mem_alignment)); + break; + case ROCJPEG_OUTPUT_RGB: + num_channels = 1; + output_image.pitch[0] = (is_roi_valid ? align(roi_width, mem_alignment) + : align(widths[0], mem_alignment)) * + 3; + channel_sizes[0] = output_image.pitch[0] * + (is_roi_valid ? align(roi_height, mem_alignment) + : align(heights[0], mem_alignment)); + break; + case ROCJPEG_OUTPUT_RGB_PLANAR: + num_channels = 3; + output_image.pitch[2] = output_image.pitch[1] = output_image.pitch[0] = + is_roi_valid ? align(roi_width, mem_alignment) + : align(widths[0], mem_alignment); + channel_sizes[2] = channel_sizes[1] = channel_sizes[0] = + output_image.pitch[0] * + (is_roi_valid ? align(roi_height, mem_alignment) + : align(heights[0], mem_alignment)); + break; + default: + std::cout << "Unknown output format!" << std::endl; + return EXIT_FAILURE; + } + return EXIT_SUCCESS; } std::vector RocJpegDecoder::decode_images( @@ -757,7 +928,7 @@ std::vector RocJpegDecoder::decode_images( - output_tensors (std::vector): a vector of Tensors containing the decoded images */ - + int num_images = encoded_images.size(); std::vector output_tensors{num_images}; RocJpegStatus rocjpeg_status; @@ -773,7 +944,6 @@ std::vector RocJpegDecoder::decode_images( cudaStatus); constexpr int batch_size = 2; - RocJpegUtils rocjpeg_utils; std::string chroma_sub_sampling = ""; uint8_t num_components; RocJpegChromaSubsampling temp_subsampling; @@ -797,9 +967,9 @@ std::vector RocJpegDecoder::decode_images( for (int j = i; j < batch_end; j++) { int index = j - i; rocjpeg_status = rocJpegStreamParse( - (unsigned char*)encoded_images[j].data_ptr(), - encoded_images[j].numel(), - rocjpeg_stream_handles[index]); + (unsigned char*)encoded_images[j].data_ptr(), + encoded_images[j].numel(), + rocjpeg_stream_handles[index]); if (rocjpeg_status != ROCJPEG_STATUS_SUCCESS) { TORCH_CHECK( false, @@ -813,8 +983,7 @@ std::vector RocJpegDecoder::decode_images( &temp_subsampling, temp_widths.data(), temp_heights.data())); - rocjpeg_utils.GetChromaSubsamplingStr( - temp_subsampling, chroma_sub_sampling); + getChromaSubsamplingStr(temp_subsampling, chroma_sub_sampling); if (temp_widths[0] < 64 || temp_heights[0] < 64) { TORCH_CHECK( false, "The image resolution is not supported by VCN Hardware"); @@ -824,7 +993,7 @@ std::vector RocJpegDecoder::decode_images( TORCH_CHECK( false, "The chroma sub-sampling is not supported by VCN Hardware"); } - if (rocjpeg_utils.GetChannelPitchAndSizes( + if (getChannelPitchAndSizes( decode_params_batch[index], temp_subsampling, temp_widths.data(), @@ -835,15 +1004,21 @@ std::vector RocJpegDecoder::decode_images( TORCH_CHECK(false, "ERROR: Failed to get the channel pitch and sizes"); } - uint32_t roi_width = decode_params_batch[index].crop_rectangle.right - decode_params_batch[index].crop_rectangle.left; - uint32_t roi_height = decode_params_batch[index].crop_rectangle.bottom - decode_params_batch[index].crop_rectangle.top; - bool is_roi_valid = (roi_width > 0 && roi_height > 0 && roi_width <= temp_widths[0] && roi_height <= temp_heights[0]) ? true : false; - uint32_t width = is_roi_valid ? align(roi_width, 16) : align(temp_widths[0], 16); - uint32_t height = is_roi_valid ? align(roi_height, 16) : align(temp_heights[0], 16); + uint32_t roi_width = decode_params_batch[index].crop_rectangle.right - + decode_params_batch[index].crop_rectangle.left; + uint32_t roi_height = decode_params_batch[index].crop_rectangle.bottom - + decode_params_batch[index].crop_rectangle.top; + bool is_roi_valid = + (roi_width > 0 && roi_height > 0 && roi_width <= temp_widths[0] && + roi_height <= temp_heights[0]) + ? true + : false; + uint32_t width = is_roi_valid ? align(roi_width, mem_alignment) + : align(temp_widths[0], mem_alignment); + uint32_t height = is_roi_valid ? align(roi_height, mem_alignment) + : align(temp_heights[0], mem_alignment); auto output_tensor = torch::zeros( - {int64_t(num_channels), - int64_t(height), - int64_t(width)}, + {int64_t(num_channels), int64_t(height), int64_t(width)}, torch::dtype(torch::kU8).device(target_device)); channels[j] = num_channels; @@ -855,15 +1030,17 @@ std::vector RocJpegDecoder::decode_images( // allocate memory for each channel and reuse them if the sizes remain // unchanged for a new image. for (int c = 0; c < (int)num_channels; c++) { - output_images[index].channel[c] = output_tensor[c].data_ptr(); + output_images[index].channel[c] = output_tensor[c].data_ptr(); } // for (int c = (int)num_channels; c < ROCJPEG_MAX_COMPONENT; c++) { // output_images[index].channel[c] = NULL; // output_images[index].pitch[c] = 0; // } - // output_tensors[j] = output_tensor; // output_tensor.narrow(1, 0, temp_heights[0]).narrow(2, 0, temp_widths[0]); + // output_tensors[j] = output_tensor; // output_tensor.narrow(1, 0, + // temp_heights[0]).narrow(2, 0, temp_widths[0]); current_batch_size++; - output_tensors[j] = output_tensor.narrow(1, 0, temp_heights[0]).narrow(2, 0, temp_widths[0]); + output_tensors[j] = output_tensor.narrow(1, 0, temp_heights[0]) + .narrow(2, 0, temp_widths[0]); } // if (current_batch_size == 2) { @@ -874,7 +1051,7 @@ std::vector RocJpegDecoder::decode_images( current_batch_size, decode_params_batch.data(), output_images.data())); - } + } current_batch_size = 0; } diff --git a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.h b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.h index a052e5e354e..5c0fa56113b 100644 --- a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.h +++ b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.h @@ -50,7 +50,6 @@ class CUDAJpegDecoder { #include #include -#include "rocjpeg_samples_utils.h" namespace vision { namespace image { @@ -74,4 +73,25 @@ class RocJpegDecoder { } // namespace image } // namespace vision +#define CHECK_ROCJPEG(call) \ + { \ + RocJpegStatus rocjpeg_status = (call); \ + if (rocjpeg_status != ROCJPEG_STATUS_SUCCESS) { \ + std::cerr << #call << " returned " \ + << rocJpegGetErrorName(rocjpeg_status) << " at " << __FILE__ \ + << ":" << __LINE__ << std::endl; \ + exit(1); \ + } \ + } + +#define CHECK_HIP(call) \ + { \ + hipError_t hip_status = (call); \ + if (hip_status != hipSuccess) { \ + std::cout << "HIP failure: 'status: " << hipGetErrorName(hip_status) \ + << "' at " << __FILE__ << ":" << __LINE__ << std::endl; \ + exit(1); \ + } \ + } + #endif From 1d299860339b60ca1d5fd488e4149e429bcd33f6 Mon Sep 17 00:00:00 2001 From: xytpai Date: Fri, 16 Jan 2026 09:57:41 +0000 Subject: [PATCH 05/22] rm unused file --- .../io/image/cuda/rocjpeg_samples_utils.h | 567 ------------------ 1 file changed, 567 deletions(-) delete mode 100644 torchvision/csrc/io/image/cuda/rocjpeg_samples_utils.h diff --git a/torchvision/csrc/io/image/cuda/rocjpeg_samples_utils.h b/torchvision/csrc/io/image/cuda/rocjpeg_samples_utils.h deleted file mode 100644 index 3a9595dcc4f..00000000000 --- a/torchvision/csrc/io/image/cuda/rocjpeg_samples_utils.h +++ /dev/null @@ -1,567 +0,0 @@ -/* -Copyright (c) 2024 - 2025 Advanced Micro Devices, Inc. All rights reserved. - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in -all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -THE SOFTWARE. -*/ -#ifndef ROC_JPEG_SAMPLES_COMMON -#define ROC_JPEG_SAMPLES_COMMON -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#if __cplusplus >= 201703L && __has_include() - #include - namespace fs = std::filesystem; -#else - #include - namespace fs = std::experimental::filesystem; -#endif -#include -#include "rocjpeg/rocjpeg.h" - -#define CHECK_ROCJPEG(call) { \ - RocJpegStatus rocjpeg_status = (call); \ - if (rocjpeg_status != ROCJPEG_STATUS_SUCCESS) { \ - std::cerr << #call << " returned " << rocJpegGetErrorName(rocjpeg_status) << " at " << __FILE__ << ":" << __LINE__ << std::endl;\ - exit(1); \ - } \ -} - -#define CHECK_HIP(call) { \ - hipError_t hip_status = (call); \ - if (hip_status != hipSuccess) { \ - std::cout << "HIP failure: 'status: " << hipGetErrorName(hip_status) << "' at " << __FILE__ << ":" << __LINE__ << std::endl;\ - exit(1); \ - } \ -} - -/** - * @class RocJpegUtils - * @brief Utility class for rocJPEG samples. - * - * This class provides utility functions for rocJPEG samples, such as parsing command line arguments, - * getting file paths, initializing HIP device, getting chroma subsampling string, getting channel pitch and sizes, - * getting output file extension, and saving images. - */ -class RocJpegUtils { -public: - /** - * @brief Parses the command line arguments. - * - * This function parses the command line arguments and sets the corresponding variables. - * - * @param input_path The input path. - * @param output_file_path The output file path. - * @param save_images Flag indicating whether to save images. - * @param device_id The device ID. - * @param rocjpeg_backend The rocJPEG backend. - * @param decode_params The rocJPEG decode parameters. - * @param num_threads The number of threads. - * @param crop The crop rectangle. - * @param argc The number of command line arguments. - * @param argv The command line arguments. - */ - static void ParseCommandLine(std::string &input_path, std::string &output_file_path, bool &save_images, int &device_id, - RocJpegBackend &rocjpeg_backend, RocJpegDecodeParams &decode_params, int *num_threads, int *batch_size, int argc, char *argv[]) { - if(argc <= 1) { - ShowHelpAndExit("", num_threads != nullptr, batch_size != nullptr); - } - for (int i = 1; i < argc; i++) { - if (!strcmp(argv[i], "-h")) { - ShowHelpAndExit("", num_threads != nullptr, batch_size != nullptr); - } - if (!strcmp(argv[i], "-i")) { - if (++i == argc) { - ShowHelpAndExit("-i", num_threads != nullptr, batch_size != nullptr); - } - input_path = argv[i]; - continue; - } - if (!strcmp(argv[i], "-o")) { - if (++i == argc) { - ShowHelpAndExit("-o", num_threads != nullptr, batch_size != nullptr); - } - output_file_path = argv[i]; - save_images = true; - continue; - } - if (!strcmp(argv[i], "-d")) { - if (++i == argc) { - ShowHelpAndExit("-d", num_threads != nullptr, batch_size != nullptr); - } - device_id = atoi(argv[i]); - continue; - } - if (!strcmp(argv[i], "-be")) { - if (++i == argc) { - ShowHelpAndExit("-be", num_threads != nullptr, batch_size != nullptr); - } - rocjpeg_backend = static_cast(atoi(argv[i])); - continue; - } - if (!strcmp(argv[i], "-fmt")) { - if (++i == argc) { - ShowHelpAndExit("-fmt", num_threads != nullptr, batch_size != nullptr); - } - std::string selected_output_format = argv[i]; - if (selected_output_format == "native") { - decode_params.output_format = ROCJPEG_OUTPUT_NATIVE; - } else if (selected_output_format == "yuv_planar") { - decode_params.output_format = ROCJPEG_OUTPUT_YUV_PLANAR; - } else if (selected_output_format == "y") { - decode_params.output_format = ROCJPEG_OUTPUT_Y; - } else if (selected_output_format == "rgb") { - decode_params.output_format = ROCJPEG_OUTPUT_RGB; - } else if (selected_output_format == "rgb_planar") { - decode_params.output_format = ROCJPEG_OUTPUT_RGB_PLANAR; - } else { - ShowHelpAndExit(argv[i], num_threads != nullptr); - } - continue; - } - if (!strcmp(argv[i], "-t")) { - if (++i == argc) { - ShowHelpAndExit("-t", num_threads != nullptr, batch_size != nullptr); - } - if (num_threads != nullptr) { - *num_threads = atoi(argv[i]); - if (*num_threads <= 0 || *num_threads > 32) { - ShowHelpAndExit(argv[i], num_threads != nullptr, batch_size != nullptr); - } - } - continue; - } - if (!strcmp(argv[i], "-b")) { - if (++i == argc) { - ShowHelpAndExit("-b", num_threads != nullptr, batch_size != nullptr); - } - if (batch_size != nullptr) - *batch_size = atoi(argv[i]); - continue; - } - if (!strcmp(argv[i], "-crop")) { - if (++i == argc || 4 != sscanf(argv[i], "%hd,%hd,%hd,%hd", &decode_params.crop_rectangle.left, &decode_params.crop_rectangle.top, &decode_params.crop_rectangle.right, &decode_params.crop_rectangle.bottom)) { - ShowHelpAndExit("-crop"); - } - if ((&decode_params.crop_rectangle.right - &decode_params.crop_rectangle.left) % 2 == 1 || (&decode_params.crop_rectangle.bottom - &decode_params.crop_rectangle.top) % 2 == 1) { - std::cout << "output crop rectangle must have width and height of even numbers" << std::endl; - exit(1); - } - continue; - } - ShowHelpAndExit(argv[i], num_threads != nullptr, batch_size != nullptr); - } - } - - /** - * Checks if a file is a JPEG file. - * - * @param filePath The path to the file to be checked. - * @return True if the file is a JPEG file, false otherwise. - */ - static bool IsJPEG(const std::string& filePath) { - std::ifstream file(filePath, std::ios::binary); - if (!file.is_open()) { - std::cerr << "Failed to open file: " << filePath << std::endl; - return false; - } - - unsigned char buffer[2]; - file.read(reinterpret_cast(buffer), 2); - file.close(); - - // The first two bytes of every JPEG stream are always 0xFFD8, which represents the Start of Image (SOI) marker. - return buffer[0] == 0xFF && buffer[1] == 0xD8; - } - - /** - * @brief Gets the file paths. - * - * This function gets the file paths based on the input path and sets the corresponding variables. - * - * @param input_path The input path. - * @param file_paths The vector to store the file paths. - * @param is_dir Flag indicating whether the input path is a directory. - * @param is_file Flag indicating whether the input path is a file. - * @return True if successful, false otherwise. - */ - static bool GetFilePaths(std::string &input_path, std::vector &file_paths, bool &is_dir, bool &is_file) { - std::cout << "Reading images from disk, please wait!" << std::endl; - if (!fs::exists(input_path)) { - std::cerr << "ERROR: the input path does not exist!" << std::endl; - return false; - } - is_dir = fs::is_directory(input_path); - is_file = fs::is_regular_file(input_path); - if (is_dir) { - for (const auto &entry : fs::recursive_directory_iterator(input_path)) { - if (fs::is_regular_file(entry) && IsJPEG(entry.path().string())) { - file_paths.push_back(entry.path().string()); - } - } - } else if (is_file && IsJPEG(input_path)) { - file_paths.push_back(input_path); - } else { - std::cerr << "ERROR: the input path does not contain JPEG files!" << std::endl; - return false; - } - return true; - } - - /** - * @brief Initializes the HIP device. - * - * This function initializes the HIP device with the specified device ID. - * - * @param device_id The device ID. - * @return True if successful, false otherwise. - */ - static bool InitHipDevice(int device_id) { - int num_devices; - hipDeviceProp_t hip_dev_prop; - CHECK_HIP(hipGetDeviceCount(&num_devices)); - if (num_devices < 1) { - std::cerr << "ERROR: didn't find any GPU!" << std::endl; - return false; - } - if (device_id >= num_devices) { - std::cerr << "ERROR: the requested device_id is not found!" << std::endl; - return false; - } - CHECK_HIP(hipSetDevice(device_id)); - CHECK_HIP(hipGetDeviceProperties(&hip_dev_prop, device_id)); - - std::cout << "Using GPU device " << device_id << ": " << hip_dev_prop.name << "[" << hip_dev_prop.gcnArchName << "] on PCI bus " << - std::setfill('0') << std::setw(2) << std::right << std::hex << hip_dev_prop.pciBusID << ":" << std::setfill('0') << std::setw(2) << - std::right << std::hex << hip_dev_prop.pciDomainID << "." << hip_dev_prop.pciDeviceID << std::dec << std::endl; - - return true; - } - - /** - * @brief Gets the chroma subsampling string. - * - * This function gets the chroma subsampling string based on the specified subsampling value. - * - * @param subsampling The chroma subsampling value. - * @param chroma_sub_sampling The string to store the chroma subsampling. - */ - void GetChromaSubsamplingStr(RocJpegChromaSubsampling subsampling, std::string &chroma_sub_sampling) { - switch (subsampling) { - case ROCJPEG_CSS_444: - chroma_sub_sampling = "YUV 4:4:4"; - break; - case ROCJPEG_CSS_440: - chroma_sub_sampling = "YUV 4:4:0"; - break; - case ROCJPEG_CSS_422: - chroma_sub_sampling = "YUV 4:2:2"; - break; - case ROCJPEG_CSS_420: - chroma_sub_sampling = "YUV 4:2:0"; - break; - case ROCJPEG_CSS_411: - chroma_sub_sampling = "YUV 4:1:1"; - break; - case ROCJPEG_CSS_400: - chroma_sub_sampling = "YUV 4:0:0"; - break; - case ROCJPEG_CSS_UNKNOWN: - chroma_sub_sampling = "UNKNOWN"; - break; - default: - chroma_sub_sampling = ""; - break; - } - } - - /** - * @brief Gets the channel pitch and sizes. - * - * This function gets the channel pitch and sizes based on the specified output format, chroma subsampling, - * output image, and channel sizes. - * - * @param decode_params The decode parameters that specify the output format and crop rectangle. - * @param subsampling The chroma subsampling. - * @param widths The array to store the channel widths. - * @param heights The array to store the channel heights. - * @param num_channels The number of channels. - * @param output_image The output image. - * @param channel_sizes The array to store the channel sizes. - * @return The channel pitch. - */ - int GetChannelPitchAndSizes(RocJpegDecodeParams decode_params, RocJpegChromaSubsampling subsampling, uint32_t *widths, uint32_t *heights, - uint32_t &num_channels, RocJpegImage &output_image, uint32_t *channel_sizes) { - - bool is_roi_valid = false; - uint32_t roi_width; - uint32_t roi_height; - roi_width = decode_params.crop_rectangle.right - decode_params.crop_rectangle.left; - roi_height = decode_params.crop_rectangle.bottom - decode_params.crop_rectangle.top; - if (roi_width > 0 && roi_height > 0 && roi_width <= widths[0] && roi_height <= heights[0]) { - is_roi_valid = true; - } - switch (decode_params.output_format) { - case ROCJPEG_OUTPUT_NATIVE: - switch (subsampling) { - case ROCJPEG_CSS_444: - num_channels = 3; - output_image.pitch[2] = output_image.pitch[1] = output_image.pitch[0] = is_roi_valid ? align(roi_width, mem_alignment) : align(widths[0], mem_alignment); - channel_sizes[2] = channel_sizes[1] = channel_sizes[0] = output_image.pitch[0] * (is_roi_valid ? align(roi_height, mem_alignment) : align(heights[0], mem_alignment)); - break; - case ROCJPEG_CSS_440: - num_channels = 3; - output_image.pitch[2] = output_image.pitch[1] = output_image.pitch[0] = is_roi_valid ? align(roi_width, mem_alignment) : align(widths[0], mem_alignment); - channel_sizes[0] = output_image.pitch[0] * (is_roi_valid ? align(roi_height, mem_alignment) : align(heights[0], mem_alignment)); - channel_sizes[2] = channel_sizes[1] = output_image.pitch[0] * (is_roi_valid ? align(roi_height >> 1, mem_alignment) : align(heights[0] >> 1, mem_alignment)); - break; - case ROCJPEG_CSS_422: - num_channels = 1; - output_image.pitch[0] = (is_roi_valid ? align(roi_width, mem_alignment) : align(widths[0], mem_alignment)) * 2; - channel_sizes[0] = output_image.pitch[0] * (is_roi_valid ? align(roi_height, mem_alignment) : align(heights[0], mem_alignment)); - break; - case ROCJPEG_CSS_420: - num_channels = 2; - output_image.pitch[1] = output_image.pitch[0] = is_roi_valid ? align(roi_width, mem_alignment) : align(widths[0], mem_alignment); - channel_sizes[0] = output_image.pitch[0] * (is_roi_valid ? align(roi_height, mem_alignment) : align(heights[0], mem_alignment)); - channel_sizes[1] = output_image.pitch[1] * (is_roi_valid ? align(roi_height >> 1, mem_alignment) : align(heights[0] >> 1, mem_alignment)); - break; - case ROCJPEG_CSS_400: - num_channels = 1; - output_image.pitch[0] = is_roi_valid ? align(roi_width, mem_alignment) : align(widths[0], mem_alignment); - channel_sizes[0] = output_image.pitch[0] * (is_roi_valid ? align(roi_height, mem_alignment) : align(heights[0], mem_alignment)); - break; - default: - std::cout << "Unknown chroma subsampling!" << std::endl; - return EXIT_FAILURE; - } - break; - case ROCJPEG_OUTPUT_YUV_PLANAR: - if (subsampling == ROCJPEG_CSS_400) { - num_channels = 1; - output_image.pitch[0] = is_roi_valid ? align(roi_width, mem_alignment) : align(widths[0], mem_alignment); - channel_sizes[0] = output_image.pitch[0] * (is_roi_valid ? align(roi_height, mem_alignment) : align(heights[0], mem_alignment)); - } else { - num_channels = 3; - output_image.pitch[0] = is_roi_valid ? align(roi_width, mem_alignment) : align(widths[0], mem_alignment); - output_image.pitch[1] = is_roi_valid ? align(roi_width, mem_alignment) : align(widths[1], mem_alignment); - output_image.pitch[2] = is_roi_valid ? align(roi_width, mem_alignment) : align(widths[2], mem_alignment); - channel_sizes[0] = output_image.pitch[0] * (is_roi_valid ? align(roi_height, mem_alignment) : align(heights[0], mem_alignment)); - channel_sizes[1] = output_image.pitch[1] * (is_roi_valid ? align(roi_height, mem_alignment) : align(heights[1], mem_alignment)); - channel_sizes[2] = output_image.pitch[2] * (is_roi_valid ? align(roi_height, mem_alignment) : align(heights[2], mem_alignment)); - } - break; - case ROCJPEG_OUTPUT_Y: - num_channels = 1; - output_image.pitch[0] = is_roi_valid ? align(roi_width, mem_alignment) : align(widths[0], mem_alignment); - channel_sizes[0] = output_image.pitch[0] * (is_roi_valid ? align(roi_height, mem_alignment) : align(heights[0], mem_alignment)); - break; - case ROCJPEG_OUTPUT_RGB: - num_channels = 1; - output_image.pitch[0] = (is_roi_valid ? align(roi_width, mem_alignment) : align(widths[0], mem_alignment)) * 3; - channel_sizes[0] = output_image.pitch[0] * (is_roi_valid ? align(roi_height, mem_alignment) : align(heights[0], mem_alignment)); - break; - case ROCJPEG_OUTPUT_RGB_PLANAR: - num_channels = 3; - output_image.pitch[2] = output_image.pitch[1] = output_image.pitch[0] = is_roi_valid ? align(roi_width, mem_alignment) : align(widths[0], mem_alignment); - channel_sizes[2] = channel_sizes[1] = channel_sizes[0] = output_image.pitch[0] * (is_roi_valid ? align(roi_height, mem_alignment) : align(heights[0], mem_alignment)); - break; - default: - std::cout << "Unknown output format!" << std::endl; - return EXIT_FAILURE; - } - return EXIT_SUCCESS; - } - - /** - * @brief Gets the output file extension. - * - * This function gets the output file extension based on the specified output format, base file name, - * image width, image height, and file name for saving. - * - * @param output_format The output format. - * @param base_file_name The base file name. - * @param image_width The image width. - * @param image_height The image height. - * @param file_name_for_saving The string to store the file name for saving. - */ - void GetOutputFileExt(RocJpegOutputFormat output_format, std::string &base_file_name, uint32_t image_width, uint32_t image_height, RocJpegChromaSubsampling subsampling, std::string &file_name_for_saving) { - std::string file_extension; - std::string::size_type const p(base_file_name.find_last_of('.')); - std::string file_name_no_ext = base_file_name.substr(0, p); - std::string format_description = ""; - switch (output_format) { - case ROCJPEG_OUTPUT_NATIVE: - file_extension = "yuv"; - switch (subsampling) { - case ROCJPEG_CSS_444: - format_description = "444"; - break; - case ROCJPEG_CSS_440: - format_description = "440"; - break; - case ROCJPEG_CSS_422: - format_description = "422_yuyv"; - break; - case ROCJPEG_CSS_420: - format_description = "nv12"; - break; - case ROCJPEG_CSS_400: - format_description = "400"; - break; - default: - std::cout << "Unknown chroma subsampling!" << std::endl; - return; - } - break; - case ROCJPEG_OUTPUT_YUV_PLANAR: - file_extension = "yuv"; - format_description = "planar"; - break; - case ROCJPEG_OUTPUT_Y: - file_extension = "yuv"; - format_description = "400"; - break; - case ROCJPEG_OUTPUT_RGB: - file_extension = "rgb"; - format_description = "packed"; - break; - case ROCJPEG_OUTPUT_RGB_PLANAR: - file_extension = "rgb"; - format_description = "planar"; - break; - default: - file_extension = ""; - break; - } - file_name_for_saving += "//" + file_name_no_ext + "_" + std::to_string(image_width) + "x" - + std::to_string(image_height) + "_" + format_description + "." + file_extension; - } - -private: - static const int mem_alignment = 16; - /** - * @brief Shows the help message and exits. - * - * This function shows the help message and exits the program. - * - * @param option The option to display in the help message (optional). - * @param show_threads Flag indicating whether to show the number of threads in the help message. - */ - static void ShowHelpAndExit(const char *option = nullptr, bool show_threads = false, bool show_batch_size = false) { - std::cout << "Options:\n" - "-i [input path] - input path to a single JPEG image or a directory containing JPEG images - [required]\n" - "-be [backend] - select rocJPEG backend (0 for hardware-accelerated JPEG decoding using VCN,\n" - " 1 for hybrid JPEG decoding using CPU and GPU HIP kernels (currently not supported)) [optional - default: 0]\n" - "-fmt [output format] - select rocJPEG output format for decoding, one of the [native, yuv_planar, y, rgb, rgb_planar] - [optional - default: native]\n" - "-o [output path] - path to an output file or a path to an existing directory - write decoded images to a file or an existing directory based on selected output format - [optional]\n" - "-crop [crop rectangle] - crop rectangle for output in a comma-separated format: left,top,right,bottom - [optional]\n" - "-d [device id] - specify the GPU device id for the desired device (use 0 for the first device, 1 for the second device, and so on) [optional - default: 0]\n"; - if (show_threads) { - std::cout << "-t [threads] - number of threads (<= 32) for parallel JPEG decoding - [optional - default: 1]\n"; - } - if (show_batch_size) { - std::cout << "-b [batch_size] - decode images from input by batches of a specified size - [optional - default: 1]\n"; - } - exit(0); - } - /** - * @brief Aligns a value to a specified alignment. - * - * This function takes a value and aligns it to the specified alignment. It returns the aligned value. - * - * @param value The value to be aligned. - * @param alignment The alignment value. - * @return The aligned value. - */ - static inline int align(int value, int alignment) { - return (value + alignment - 1) & ~(alignment - 1); - } -}; - -class ThreadPool { - public: - ThreadPool(int nthreads) : shutdown_(false) { - // Create the specified number of threads - threads_.reserve(nthreads); - for (int i = 0; i < nthreads; ++i) - threads_.emplace_back(std::bind(&ThreadPool::ThreadEntry, this, i)); - } - - ~ThreadPool() {} - - void JoinThreads() { - { - // Unblock any threads and tell them to stop - std::unique_lock lock(mutex_); - shutdown_ = true; - cond_var_.notify_all(); - } - - // Wait for all threads to stop - for (auto& thread : threads_) - thread.join(); - } - - void ExecuteJob(std::function func) { - // Place a job on the queue and unblock a thread - std::unique_lock lock(mutex_); - decode_jobs_queue_.emplace(std::move(func)); - cond_var_.notify_one(); - } - - protected: - void ThreadEntry(int i) { - std::function execute_decode_job; - - while (true) { - { - std::unique_lock lock(mutex_); - cond_var_.wait(lock, [&] {return shutdown_ || !decode_jobs_queue_.empty();}); - if (decode_jobs_queue_.empty()) { - // No jobs to do; shutting down - return; - } - - execute_decode_job = std::move(decode_jobs_queue_.front()); - decode_jobs_queue_.pop(); - } - - // Execute the decode job without holding any locks - execute_decode_job(); - } - } - - std::mutex mutex_; - std::condition_variable cond_var_; - bool shutdown_; - std::queue> decode_jobs_queue_; - std::vector threads_; -}; - -#endif //ROC_JPEG_SAMPLES_COMMON From 15d8f1113958795c80acb3ff05cd1958cc3cde32 Mon Sep 17 00:00:00 2001 From: xytpai Date: Fri, 16 Jan 2026 10:09:43 +0000 Subject: [PATCH 06/22] refine code 2 --- setup.py | 6 ++--- .../csrc/io/image/cuda/decode_jpegs_cuda.cpp | 27 ++----------------- 2 files changed, 5 insertions(+), 28 deletions(-) diff --git a/setup.py b/setup.py index 4b9559eb630..24a43c01778 100644 --- a/setup.py +++ b/setup.py @@ -22,7 +22,7 @@ USE_JPEG = os.getenv("TORCHVISION_USE_JPEG", "1") == "1" USE_WEBP = os.getenv("TORCHVISION_USE_WEBP", "1") == "1" USE_NVJPEG = os.getenv("TORCHVISION_USE_NVJPEG", "1") == "1" -USE_ROCJPEG = os.getenv("TORCHVISION_USE_ROCJPEG", "0") == "1" +USE_ROCJPEG = os.getenv("TORCHVISION_USE_ROCJPEG", "1") == "1" NVCC_FLAGS = os.getenv("NVCC_FLAGS", None) # Note: the GPU video decoding stuff used to be called "video codec", which # isn't an accurate or descriptive name considering there are at least 2 other @@ -355,12 +355,12 @@ def make_image_extension(): if (USE_NVJPEG or USE_ROCJPEG) and (torch.cuda.is_available() or FORCE_CUDA): nvjpeg_found = CUDA_HOME is not None and (Path(CUDA_HOME) / "include/nvjpeg.h").exists() rocjpeg_found = ROCM_HOME is not None and (Path(ROCM_HOME) / "include/rocjpeg/rocjpeg.h").exists() - if nvjpeg_found: + if nvjpeg_found and USE_NVJPEG: print("Building torchvision with NVJPEG image support") libraries.append("nvjpeg") define_macros += [("NVJPEG_FOUND", 1)] Extension = CUDAExtension - elif rocjpeg_found: + elif rocjpeg_found and USE_ROCJPEG: print("Building torchvision with ROCJPEG image support") libraries.append("rocjpeg") define_macros += [("ROCJPEG_FOUND", 1)] diff --git a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp index 9b974b7dc03..e5b49e1067e 100644 --- a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp +++ b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp @@ -920,8 +920,7 @@ std::vector RocJpegDecoder::decode_images( Args: - encoded_images (std::vector): a vector of tensors containing the jpeg bitstreams to be decoded - - output_format (RocJpegOutputFormat): ROCJPEG_OUTPUT_RGB, ROCJPEG_OUTPUT_Y - or ROCJPEG_OUTPUT_NATIVE + - output_format (RocJpegOutputFormat): ROCJPEG_OUTPUT_RGB - device (torch::Device): The desired CUDA device for the returned Tensors Returns: @@ -1017,27 +1016,16 @@ std::vector RocJpegDecoder::decode_images( : align(temp_widths[0], mem_alignment); uint32_t height = is_roi_valid ? align(roi_height, mem_alignment) : align(temp_heights[0], mem_alignment); - auto output_tensor = torch::zeros( + auto output_tensor = torch::empty( {int64_t(num_channels), int64_t(height), int64_t(width)}, torch::dtype(torch::kU8).device(target_device)); channels[j] = num_channels; - // for (int n = 0; n < (int)num_channels; n++) { - // output_images[current_batch_size].channel[n] = - // output_tensor[n].data_ptr(); - // } - // allocate memory for each channel and reuse them if the sizes remain // unchanged for a new image. for (int c = 0; c < (int)num_channels; c++) { output_images[index].channel[c] = output_tensor[c].data_ptr(); } - // for (int c = (int)num_channels; c < ROCJPEG_MAX_COMPONENT; c++) { - // output_images[index].channel[c] = NULL; - // output_images[index].pitch[c] = 0; - // } - // output_tensors[j] = output_tensor; // output_tensor.narrow(1, 0, - // temp_heights[0]).narrow(2, 0, temp_widths[0]); current_batch_size++; output_tensors[j] = output_tensor.narrow(1, 0, temp_heights[0]) .narrow(2, 0, temp_widths[0]); @@ -1062,17 +1050,6 @@ std::vector RocJpegDecoder::decode_images( "Failed to synchronize CUDA stream: ", cudaStatus); - // prune extraneous channels from single channel images - if (output_format == ROCJPEG_OUTPUT_NATIVE) { - for (std::vector::size_type i = 0; i < output_tensors.size(); - ++i) { - if (channels[i] == 1) { - output_tensors[i] = output_tensors[i][0].unsqueeze(0).clone(); - } - } - } - - cudaDeviceSynchronize(); return output_tensors; } From b68f0ef7bda712722ed13c07101c5efa52cc6f35 Mon Sep 17 00:00:00 2001 From: xytpai Date: Sat, 13 Jun 2026 14:57:19 +0000 Subject: [PATCH 07/22] full format support --- test/test_image.py | 4 +-- .../csrc/io/image/cuda/decode_jpegs_cuda.cpp | 29 ++++++++++++++++--- .../csrc/io/image/cuda/decode_jpegs_cuda.h | 3 +- 3 files changed, 29 insertions(+), 7 deletions(-) diff --git a/test/test_image.py b/test/test_image.py index b11dd67ca12..a6e6a367798 100644 --- a/test/test_image.py +++ b/test/test_image.py @@ -423,7 +423,7 @@ def test_decode_jpegs_cuda(mode, scripted): num_workers = 10 with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor: - futures = [executor.submit(decode_fn, encoded_images, mode, "cuda") for _ in range(num_workers)] + futures = [executor.submit(decode_fn, encoded_images, mode, f"cuda:{torch.cuda.current_device()}") for _ in range(num_workers)] decoded_images_threaded = [future.result() for future in futures] assert len(decoded_images_threaded) == num_workers for decoded_images in decoded_images_threaded: @@ -431,7 +431,7 @@ def test_decode_jpegs_cuda(mode, scripted): for decoded_image_cuda, decoded_image_cpu in zip(decoded_images, decoded_images_cpu): assert decoded_image_cuda.shape == decoded_image_cpu.shape assert decoded_image_cuda.dtype == decoded_image_cpu.dtype == torch.uint8 - assert (decoded_image_cuda.cpu().float() - decoded_image_cpu.cpu().float()).abs().mean() < 2 + assert (decoded_image_cuda.cpu().float() - decoded_image_cpu.cpu().float()).abs().mean() < 2.5 @needs_cuda diff --git a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp index 5f777f6f200..49c40d30523 100644 --- a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp +++ b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp @@ -665,8 +665,16 @@ std::vector decode_jpegs_cuda( } RocJpegOutputFormat output_format; + bool prune_single_channel = false; switch (mode) { + case vision::image::IMAGE_READ_MODE_UNCHANGED: + output_format = ROCJPEG_OUTPUT_RGB_PLANAR; + prune_single_channel = true; + break; + case vision::image::IMAGE_READ_MODE_GRAY: + output_format = ROCJPEG_OUTPUT_Y; + break; case vision::image::IMAGE_READ_MODE_RGB: output_format = ROCJPEG_OUTPUT_RGB_PLANAR; break; @@ -678,7 +686,8 @@ std::vector decode_jpegs_cuda( try { at::cuda::CUDAEvent event; - auto result = rocJpegDecoder->decode_images(contig_images, output_format); + auto result = rocJpegDecoder->decode_images( + contig_images, output_format, prune_single_channel); auto current_stream{ device.has_index() ? at::cuda::getCurrentCUDAStream( rocJpegDecoder->original_device.index()) @@ -702,7 +711,9 @@ RocJpegDecoder::RocJpegDecoder(const torch::Device& target_device) target_device.has_index() ? at::cuda::getStreamFromPool(false, target_device.index()) : at::cuda::getStreamFromPool(false)} { - int device_id = target_device.index(); + int device_id = + target_device.has_index() ? target_device.index() + : c10::cuda::current_device(); CHECK_HIP(hipSetDevice(device_id)); RocJpegStatus status; RocJpegBackend rocjpeg_backend = ROCJPEG_BACKEND_HARDWARE; @@ -909,7 +920,8 @@ int getChannelPitchAndSizes( std::vector RocJpegDecoder::decode_images( const std::vector& encoded_images, - const RocJpegOutputFormat& output_format) { + const RocJpegOutputFormat& output_format, + bool prune_single_channel) { /* This function decodes a batch of jpeg bitstreams. @@ -1015,7 +1027,7 @@ std::vector RocJpegDecoder::decode_images( auto output_tensor = torch::empty( {int64_t(num_channels), int64_t(height), int64_t(width)}, torch::dtype(torch::kU8).device(target_device)); - channels[j] = num_channels; + channels[j] = num_components; // allocate memory for each channel and reuse them if the sizes remain // unchanged for a new image. @@ -1046,6 +1058,15 @@ std::vector RocJpegDecoder::decode_images( "Failed to synchronize CUDA stream: ", cudaStatus); + if (prune_single_channel) { + for (std::vector::size_type i = 0; i < output_tensors.size(); + ++i) { + if (channels[i] == 1) { + output_tensors[i] = output_tensors[i][0].unsqueeze(0).clone(); + } + } + } + return output_tensors; } diff --git a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.h b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.h index 5c0fa56113b..fe132dd5fc6 100644 --- a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.h +++ b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.h @@ -60,7 +60,8 @@ class RocJpegDecoder { std::vector decode_images( const std::vector& encoded_images, - const RocJpegOutputFormat& output_format); + const RocJpegOutputFormat& output_format, + bool prune_single_channel); const torch::Device original_device; const torch::Device target_device; From e113fcc16165812b3adb63ffdbb1131014965101 Mon Sep 17 00:00:00 2001 From: xytpai Date: Sat, 13 Jun 2026 15:09:20 +0000 Subject: [PATCH 08/22] remove stream dependency --- test/test_image.py | 2 +- .../csrc/io/image/cuda/decode_jpegs_cuda.cpp | 38 +++---------------- .../csrc/io/image/cuda/decode_jpegs_cuda.h | 7 +--- 3 files changed, 7 insertions(+), 40 deletions(-) diff --git a/test/test_image.py b/test/test_image.py index a6e6a367798..91b86dc20fd 100644 --- a/test/test_image.py +++ b/test/test_image.py @@ -423,7 +423,7 @@ def test_decode_jpegs_cuda(mode, scripted): num_workers = 10 with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor: - futures = [executor.submit(decode_fn, encoded_images, mode, f"cuda:{torch.cuda.current_device()}") for _ in range(num_workers)] + futures = [executor.submit(decode_fn, encoded_images, mode, "cuda") for _ in range(num_workers)] decoded_images_threaded = [future.result() for future in futures] assert len(decoded_images_threaded) == num_workers for decoded_images in decoded_images_threaded: diff --git a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp index 49c40d30523..2c893cdd400 100644 --- a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp +++ b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp @@ -596,14 +596,10 @@ std::vector CUDAJpegDecoder::decode_images( } // namespace image } // namespace vision -#endif - -#if ROCJPEG_FOUND +#elif ROCJPEG_FOUND #include -#include #include -#include #include #include #include @@ -685,16 +681,10 @@ std::vector decode_jpegs_cuda( } try { - at::cuda::CUDAEvent event; - auto result = rocJpegDecoder->decode_images( + // rocJPEG owns and synchronizes its internal HIP stream; there is no + // caller-visible CUDA stream here to record an event on. + return rocJpegDecoder->decode_images( contig_images, output_format, prune_single_channel); - auto current_stream{ - device.has_index() ? at::cuda::getCurrentCUDAStream( - rocJpegDecoder->original_device.index()) - : at::cuda::getCurrentCUDAStream()}; - event.record(rocJpegDecoder->stream); - event.block(current_stream); - return result; } catch (const std::exception& e) { if (typeid(e) != typeid(std::runtime_error)) { TORCH_CHECK(false, "Error while decoding JPEG images: ", e.what()); @@ -705,12 +695,7 @@ std::vector decode_jpegs_cuda( } RocJpegDecoder::RocJpegDecoder(const torch::Device& target_device) - : original_device{torch::kCUDA, c10::cuda::current_device()}, - target_device{target_device}, - stream{ - target_device.has_index() - ? at::cuda::getStreamFromPool(false, target_device.index()) - : at::cuda::getStreamFromPool(false)} { + : target_device{target_device} { int device_id = target_device.has_index() ? target_device.index() : c10::cuda::current_device(); @@ -939,17 +924,10 @@ std::vector RocJpegDecoder::decode_images( int num_images = encoded_images.size(); std::vector output_tensors{num_images}; RocJpegStatus rocjpeg_status; - cudaError_t cudaStatus; // baseline JPEGs can be batch decoded with hardware support std::vector channels(num_images); - cudaStatus = cudaStreamSynchronize(stream); - TORCH_CHECK( - cudaStatus == cudaSuccess, - "Failed to synchronize CUDA stream: ", - cudaStatus); - constexpr int batch_size = 2; std::string chroma_sub_sampling = ""; uint8_t num_components; @@ -1052,12 +1030,6 @@ std::vector RocJpegDecoder::decode_images( current_batch_size = 0; } - cudaStatus = cudaStreamSynchronize(stream); - TORCH_CHECK( - cudaStatus == cudaSuccess, - "Failed to synchronize CUDA stream: ", - cudaStatus); - if (prune_single_channel) { for (std::vector::size_type i = 0; i < output_tensors.size(); ++i) { diff --git a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.h b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.h index fe132dd5fc6..7dcf7fd6d97 100644 --- a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.h +++ b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.h @@ -44,11 +44,8 @@ class CUDAJpegDecoder { } // namespace image } // namespace vision -#endif - -#if ROCJPEG_FOUND +#elif ROCJPEG_FOUND -#include #include namespace vision { @@ -63,9 +60,7 @@ class RocJpegDecoder { const RocJpegOutputFormat& output_format, bool prune_single_channel); - const torch::Device original_device; const torch::Device target_device; - const c10::cuda::CUDAStream stream; private: RocJpegStreamHandle rocjpeg_stream_handles[2]; From 85b55f1dcaf4e6fdc6696ef65154629ef4d15909 Mon Sep 17 00:00:00 2001 From: xytpai Date: Sat, 13 Jun 2026 15:19:29 +0000 Subject: [PATCH 09/22] make batch-size dynamic --- .../csrc/io/image/cuda/decode_jpegs_cuda.cpp | 33 ++++++++++--------- .../csrc/io/image/cuda/decode_jpegs_cuda.h | 5 ++- 2 files changed, 21 insertions(+), 17 deletions(-) diff --git a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp index 2c893cdd400..e492f30b39c 100644 --- a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp +++ b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp @@ -707,20 +707,24 @@ RocJpegDecoder::RocJpegDecoder(const torch::Device& target_device) TORCH_CHECK( status == ROCJPEG_STATUS_SUCCESS, "Failed to initialize rocjpeg with hardware backend"); - - status = rocJpegStreamCreate(&rocjpeg_stream_handles[0]); - TORCH_CHECK( - status == ROCJPEG_STATUS_SUCCESS, "Failed to initialize rocjpeg stream"); - - status = rocJpegStreamCreate(&rocjpeg_stream_handles[1]); - TORCH_CHECK( - status == ROCJPEG_STATUS_SUCCESS, "Failed to initialize rocjpeg stream"); } RocJpegDecoder::~RocJpegDecoder() { rocJpegDestroy(rocjpeg_handle); - rocJpegStreamDestroy(rocjpeg_stream_handles[0]); - rocJpegStreamDestroy(rocjpeg_stream_handles[1]); + for (auto stream_handle : rocjpeg_stream_handles) { + rocJpegStreamDestroy(stream_handle); + } +} + +void RocJpegDecoder::ensure_stream_handles(std::size_t num_handles) { + while (rocjpeg_stream_handles.size() < num_handles) { + RocJpegStreamHandle stream_handle; + RocJpegStatus status = rocJpegStreamCreate(&stream_handle); + TORCH_CHECK( + status == ROCJPEG_STATUS_SUCCESS, + "Failed to initialize rocjpeg stream"); + rocjpeg_stream_handles.push_back(stream_handle); + } } static constexpr int mem_alignment = 16; @@ -928,7 +932,8 @@ std::vector RocJpegDecoder::decode_images( // baseline JPEGs can be batch decoded with hardware support std::vector channels(num_images); - constexpr int batch_size = 2; + const int batch_size = num_images; + ensure_stream_handles(static_cast(batch_size)); std::string chroma_sub_sampling = ""; uint8_t num_components; RocJpegChromaSubsampling temp_subsampling; @@ -943,9 +948,6 @@ std::vector RocJpegDecoder::decode_images( int current_batch_size = 0; uint32_t channel_sizes[ROCJPEG_MAX_COMPONENT] = {}; uint32_t num_channels = 0; - std::vector> prior_channel_sizes; - prior_channel_sizes.resize( - batch_size, std::vector(ROCJPEG_MAX_COMPONENT, 0)); for (int i = 0; i < num_images; i += batch_size) { int batch_end = std::min(i + batch_size, num_images); @@ -1017,11 +1019,10 @@ std::vector RocJpegDecoder::decode_images( .narrow(2, 0, temp_widths[0]); } - // if (current_batch_size == 2) { if (current_batch_size > 0) { CHECK_ROCJPEG(rocJpegDecodeBatched( rocjpeg_handle, - rocjpeg_stream_handles, + rocjpeg_stream_handles.data(), current_batch_size, decode_params_batch.data(), output_images.data())); diff --git a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.h b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.h index 7dcf7fd6d97..133eeccc979 100644 --- a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.h +++ b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.h @@ -1,4 +1,5 @@ #pragma once +#include #include #include #include "../common.h" @@ -63,7 +64,9 @@ class RocJpegDecoder { const torch::Device target_device; private: - RocJpegStreamHandle rocjpeg_stream_handles[2]; + void ensure_stream_handles(std::size_t num_handles); + + std::vector rocjpeg_stream_handles; RocJpegHandle rocjpeg_handle; }; } // namespace image From dd23f0e816893969f3b7c398cfd52d03a2fe43af Mon Sep 17 00:00:00 2001 From: xytpai Date: Sat, 13 Jun 2026 15:34:41 +0000 Subject: [PATCH 10/22] resolve remaining comments --- .../csrc/io/image/cuda/decode_jpegs_cuda.cpp | 190 ++---------------- .../csrc/io/image/cuda/decode_jpegs_cuda.h | 35 ++-- 2 files changed, 38 insertions(+), 187 deletions(-) diff --git a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp index e492f30b39c..bc1a9420ebe 100644 --- a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp +++ b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp @@ -635,13 +635,13 @@ std::vector decode_jpegs_cuda( TORCH_CHECK( !encoded_image.is_cuda(), - "The input tensor must be on CPU when decoding with nvjpeg") + "The input tensor must be on CPU when decoding with rocJPEG") TORCH_CHECK( encoded_image.dim() == 1 && encoded_image.numel() > 0, "Expected a non empty 1-dimensional tensor"); - // nvjpeg requires images to be contiguous + // rocJPEG requires images to be contiguous if (encoded_image.is_contiguous()) { contig_images.push_back(encoded_image); } else { @@ -696,9 +696,8 @@ std::vector decode_jpegs_cuda( RocJpegDecoder::RocJpegDecoder(const torch::Device& target_device) : target_device{target_device} { - int device_id = - target_device.has_index() ? target_device.index() - : c10::cuda::current_device(); + int device_id = target_device.has_index() ? target_device.index() + : c10::cuda::current_device(); CHECK_HIP(hipSetDevice(device_id)); RocJpegStatus status; RocJpegBackend rocjpeg_backend = ROCJPEG_BACKEND_HARDWARE; @@ -733,45 +732,12 @@ static inline int align(int value, int alignment) { return (value + alignment - 1) & ~(alignment - 1); } -void getChromaSubsamplingStr( - RocJpegChromaSubsampling subsampling, - std::string& chroma_sub_sampling) { - switch (subsampling) { - case ROCJPEG_CSS_444: - chroma_sub_sampling = "YUV 4:4:4"; - break; - case ROCJPEG_CSS_440: - chroma_sub_sampling = "YUV 4:4:0"; - break; - case ROCJPEG_CSS_422: - chroma_sub_sampling = "YUV 4:2:2"; - break; - case ROCJPEG_CSS_420: - chroma_sub_sampling = "YUV 4:2:0"; - break; - case ROCJPEG_CSS_411: - chroma_sub_sampling = "YUV 4:1:1"; - break; - case ROCJPEG_CSS_400: - chroma_sub_sampling = "YUV 4:0:0"; - break; - case ROCJPEG_CSS_UNKNOWN: - chroma_sub_sampling = "UNKNOWN"; - break; - default: - chroma_sub_sampling = ""; - break; - } -} - -int getChannelPitchAndSizes( +void setOutputChannelPitches( RocJpegDecodeParams decode_params, - RocJpegChromaSubsampling subsampling, uint32_t* widths, uint32_t* heights, uint32_t& num_channels, - RocJpegImage& output_image, - uint32_t* channel_sizes) { + RocJpegImage& output_image) { bool is_roi_valid = false; uint32_t roi_width; uint32_t roi_height; @@ -784,127 +750,20 @@ int getChannelPitchAndSizes( is_roi_valid = true; } switch (decode_params.output_format) { - case ROCJPEG_OUTPUT_NATIVE: - switch (subsampling) { - case ROCJPEG_CSS_444: - num_channels = 3; - output_image.pitch[2] = output_image.pitch[1] = - output_image.pitch[0] = - is_roi_valid ? align(roi_width, mem_alignment) - : align(widths[0], mem_alignment); - channel_sizes[2] = channel_sizes[1] = channel_sizes[0] = - output_image.pitch[0] * - (is_roi_valid ? align(roi_height, mem_alignment) - : align(heights[0], mem_alignment)); - break; - case ROCJPEG_CSS_440: - num_channels = 3; - output_image.pitch[2] = output_image.pitch[1] = - output_image.pitch[0] = - is_roi_valid ? align(roi_width, mem_alignment) - : align(widths[0], mem_alignment); - channel_sizes[0] = output_image.pitch[0] * - (is_roi_valid ? align(roi_height, mem_alignment) - : align(heights[0], mem_alignment)); - channel_sizes[2] = channel_sizes[1] = output_image.pitch[0] * - (is_roi_valid ? align(roi_height >> 1, mem_alignment) - : align(heights[0] >> 1, mem_alignment)); - break; - case ROCJPEG_CSS_422: - num_channels = 1; - output_image.pitch[0] = - (is_roi_valid ? align(roi_width, mem_alignment) - : align(widths[0], mem_alignment)) * - 2; - channel_sizes[0] = output_image.pitch[0] * - (is_roi_valid ? align(roi_height, mem_alignment) - : align(heights[0], mem_alignment)); - break; - case ROCJPEG_CSS_420: - num_channels = 2; - output_image.pitch[1] = output_image.pitch[0] = is_roi_valid - ? align(roi_width, mem_alignment) - : align(widths[0], mem_alignment); - channel_sizes[0] = output_image.pitch[0] * - (is_roi_valid ? align(roi_height, mem_alignment) - : align(heights[0], mem_alignment)); - channel_sizes[1] = output_image.pitch[1] * - (is_roi_valid ? align(roi_height >> 1, mem_alignment) - : align(heights[0] >> 1, mem_alignment)); - break; - case ROCJPEG_CSS_400: - num_channels = 1; - output_image.pitch[0] = is_roi_valid - ? align(roi_width, mem_alignment) - : align(widths[0], mem_alignment); - channel_sizes[0] = output_image.pitch[0] * - (is_roi_valid ? align(roi_height, mem_alignment) - : align(heights[0], mem_alignment)); - break; - default: - std::cout << "Unknown chroma subsampling!" << std::endl; - return EXIT_FAILURE; - } - break; - case ROCJPEG_OUTPUT_YUV_PLANAR: - if (subsampling == ROCJPEG_CSS_400) { - num_channels = 1; - output_image.pitch[0] = is_roi_valid ? align(roi_width, mem_alignment) - : align(widths[0], mem_alignment); - channel_sizes[0] = output_image.pitch[0] * - (is_roi_valid ? align(roi_height, mem_alignment) - : align(heights[0], mem_alignment)); - } else { - num_channels = 3; - output_image.pitch[0] = is_roi_valid ? align(roi_width, mem_alignment) - : align(widths[0], mem_alignment); - output_image.pitch[1] = is_roi_valid ? align(roi_width, mem_alignment) - : align(widths[1], mem_alignment); - output_image.pitch[2] = is_roi_valid ? align(roi_width, mem_alignment) - : align(widths[2], mem_alignment); - channel_sizes[0] = output_image.pitch[0] * - (is_roi_valid ? align(roi_height, mem_alignment) - : align(heights[0], mem_alignment)); - channel_sizes[1] = output_image.pitch[1] * - (is_roi_valid ? align(roi_height, mem_alignment) - : align(heights[1], mem_alignment)); - channel_sizes[2] = output_image.pitch[2] * - (is_roi_valid ? align(roi_height, mem_alignment) - : align(heights[2], mem_alignment)); - } - break; case ROCJPEG_OUTPUT_Y: num_channels = 1; output_image.pitch[0] = is_roi_valid ? align(roi_width, mem_alignment) : align(widths[0], mem_alignment); - channel_sizes[0] = output_image.pitch[0] * - (is_roi_valid ? align(roi_height, mem_alignment) - : align(heights[0], mem_alignment)); - break; - case ROCJPEG_OUTPUT_RGB: - num_channels = 1; - output_image.pitch[0] = (is_roi_valid ? align(roi_width, mem_alignment) - : align(widths[0], mem_alignment)) * - 3; - channel_sizes[0] = output_image.pitch[0] * - (is_roi_valid ? align(roi_height, mem_alignment) - : align(heights[0], mem_alignment)); break; case ROCJPEG_OUTPUT_RGB_PLANAR: num_channels = 3; output_image.pitch[2] = output_image.pitch[1] = output_image.pitch[0] = is_roi_valid ? align(roi_width, mem_alignment) : align(widths[0], mem_alignment); - channel_sizes[2] = channel_sizes[1] = channel_sizes[0] = - output_image.pitch[0] * - (is_roi_valid ? align(roi_height, mem_alignment) - : align(heights[0], mem_alignment)); break; default: - std::cout << "Unknown output format!" << std::endl; - return EXIT_FAILURE; + TORCH_CHECK(false, "Unsupported rocJPEG output format"); } - return EXIT_SUCCESS; } std::vector RocJpegDecoder::decode_images( @@ -917,8 +776,8 @@ std::vector RocJpegDecoder::decode_images( Args: - encoded_images (std::vector): a vector of tensors containing the jpeg bitstreams to be decoded - - output_format (RocJpegOutputFormat): ROCJPEG_OUTPUT_RGB - - device (torch::Device): The desired CUDA device for the returned Tensors + - output_format (RocJpegOutputFormat): ROCJPEG_OUTPUT_Y or + ROCJPEG_OUTPUT_RGB_PLANAR Returns: - output_tensors (std::vector): a vector of Tensors @@ -934,7 +793,6 @@ std::vector RocJpegDecoder::decode_images( const int batch_size = num_images; ensure_stream_handles(static_cast(batch_size)); - std::string chroma_sub_sampling = ""; uint8_t num_components; RocJpegChromaSubsampling temp_subsampling; std::vector temp_widths(ROCJPEG_MAX_COMPONENT, 0); @@ -946,7 +804,6 @@ std::vector RocJpegDecoder::decode_images( std::vector output_images; output_images.resize(batch_size); int current_batch_size = 0; - uint32_t channel_sizes[ROCJPEG_MAX_COMPONENT] = {}; uint32_t num_channels = 0; for (int i = 0; i < num_images; i += batch_size) { @@ -970,7 +827,6 @@ std::vector RocJpegDecoder::decode_images( &temp_subsampling, temp_widths.data(), temp_heights.data())); - getChromaSubsamplingStr(temp_subsampling, chroma_sub_sampling); if (temp_widths[0] < 64 || temp_heights[0] < 64) { TORCH_CHECK( false, "The image resolution is not supported by VCN Hardware"); @@ -980,16 +836,12 @@ std::vector RocJpegDecoder::decode_images( TORCH_CHECK( false, "The chroma sub-sampling is not supported by VCN Hardware"); } - if (getChannelPitchAndSizes( - decode_params_batch[index], - temp_subsampling, - temp_widths.data(), - temp_heights.data(), - num_channels, - output_images[index], - channel_sizes)) { - TORCH_CHECK(false, "ERROR: Failed to get the channel pitch and sizes"); - } + setOutputChannelPitches( + decode_params_batch[index], + temp_widths.data(), + temp_heights.data(), + num_channels, + output_images[index]); uint32_t roi_width = decode_params_batch[index].crop_rectangle.right - decode_params_batch[index].crop_rectangle.left; @@ -1031,12 +883,12 @@ std::vector RocJpegDecoder::decode_images( current_batch_size = 0; } - if (prune_single_channel) { - for (std::vector::size_type i = 0; i < output_tensors.size(); - ++i) { - if (channels[i] == 1) { - output_tensors[i] = output_tensors[i][0].unsqueeze(0).clone(); - } + for (std::vector::size_type i = 0; i < output_tensors.size(); + ++i) { + if (prune_single_channel && channels[i] == 1) { + output_tensors[i] = output_tensors[i][0].unsqueeze(0).clone(); + } else { + output_tensors[i] = output_tensors[i].contiguous(); } } diff --git a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.h b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.h index 133eeccc979..2db2a21bccc 100644 --- a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.h +++ b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.h @@ -1,6 +1,6 @@ #pragma once -#include #include +#include #include #include "../common.h" @@ -72,25 +72,24 @@ class RocJpegDecoder { } // namespace image } // namespace vision -#define CHECK_ROCJPEG(call) \ - { \ - RocJpegStatus rocjpeg_status = (call); \ - if (rocjpeg_status != ROCJPEG_STATUS_SUCCESS) { \ - std::cerr << #call << " returned " \ - << rocJpegGetErrorName(rocjpeg_status) << " at " << __FILE__ \ - << ":" << __LINE__ << std::endl; \ - exit(1); \ - } \ +#define CHECK_ROCJPEG(call) \ + { \ + RocJpegStatus rocjpeg_status = (call); \ + TORCH_CHECK( \ + rocjpeg_status == ROCJPEG_STATUS_SUCCESS, \ + #call, \ + " returned ", \ + rocJpegGetErrorName(rocjpeg_status)); \ } -#define CHECK_HIP(call) \ - { \ - hipError_t hip_status = (call); \ - if (hip_status != hipSuccess) { \ - std::cout << "HIP failure: 'status: " << hipGetErrorName(hip_status) \ - << "' at " << __FILE__ << ":" << __LINE__ << std::endl; \ - exit(1); \ - } \ +#define CHECK_HIP(call) \ + { \ + hipError_t hip_status = (call); \ + TORCH_CHECK( \ + hip_status == hipSuccess, \ + #call, \ + " failed with status: ", \ + hipGetErrorName(hip_status)); \ } #endif From 722a4afa6b847e9448773bc6c73b7cee0ac9cf5f Mon Sep 17 00:00:00 2001 From: Jeff Daily Date: Wed, 17 Jun 2026 23:54:37 -0700 Subject: [PATCH 11/22] [ROCm] Clean up rocJPEG decode and share GPU JPEG scaffolding (#2) The rocJPEG path duplicated the entire decode_jpegs_cuda() entry point (input validation, decoder lifecycle, error handling) from the nvJPEG path, and carried a large amount of rocJPEG-sample boilerplate that was dead or misleading in this context. This reworks the file so the backend-agnostic orchestration lives once and is shared by both backends, and each backend only implements what is genuinely backend-specific. Review in this order: 1. decode_jpegs_cuda.h: both decoders expose a uniform decode_images(images, mode), and a GpuJpegDecoder type alias selects the compiled-in backend. 2. decode_jpegs_cuda.cpp shared region (NVJPEG_FOUND || ROCJPEG_FOUND): the single decode_jpegs_cuda() entry point plus a validate_and_make_contiguous() helper. The input validation, the device guard, the decoder singleton lifecycle, and the error wrapper are no longer duplicated between backends. 3. nvJPEG block: the mode-to-format mapping, the version-property warning, and the event-based stream synchronization move into CUDAJpegDecoder::decode_images(mode); the existing nvJPEG internals are otherwise untouched. 4. rocJPEG block: rewritten to drop the dead ROI/crop handling (the decode params were always zero-initialized, so it never ran), the vestigial single-pass batch loop, the misleading memory-reuse comment, the unused iostream/fstream/sstream includes, and the bespoke typeid-based catch. STD_TORCH_CHECK is used throughout to match the surrounding code. The CUDA JPEG decode test tolerance bump is now gated on ROCm so it does not weaken the nvJPEG assertion. The nvJPEG path cannot be built on a ROCm host; its changes are mechanical relocations of existing lines and rely on CUDA CI for confirmation. Test Plan: Built against ROCm 7.2.1 and rocJPEG 1.4.0 on a gfx90a GPU: ``` USE_ROCJPEG=1 PYTORCH_ROCM_ARCH=gfx90a pip install -e . --no-build-isolation ``` Ran the GPU JPEG decode tests: ``` python -m pytest test/test_image.py \ -k "test_decode_jpegs_cuda or test_decode_jpeg_cuda_errors or test_decode_jpeg_cuda_device_param" ``` All 8 selected tests pass (UNCHANGED/GRAY/RGB, scripted and eager, plus the error and device-parameter cases). The unrelated test_encode_*_cuda failures are pre-existing: GPU JPEG encode is nvJPEG-only and is not part of this change. Authored with assistance from Claude (Anthropic). --- test/test_image.py | 5 +- .../csrc/io/image/cuda/decode_jpegs_cuda.cpp | 457 ++++++------------ .../csrc/io/image/cuda/decode_jpegs_cuda.h | 15 +- 3 files changed, 172 insertions(+), 305 deletions(-) diff --git a/test/test_image.py b/test/test_image.py index 91b86dc20fd..8f105dc4c5c 100644 --- a/test/test_image.py +++ b/test/test_image.py @@ -426,12 +426,15 @@ def test_decode_jpegs_cuda(mode, scripted): futures = [executor.submit(decode_fn, encoded_images, mode, "cuda") for _ in range(num_workers)] decoded_images_threaded = [future.result() for future in futures] assert len(decoded_images_threaded) == num_workers + # rocJPEG's color conversion differs slightly from nvJPEG, so it needs a + # looser tolerance against the CPU reference. + tol = 2.5 if torch.version.hip is not None else 2 for decoded_images in decoded_images_threaded: assert len(decoded_images) == len(encoded_images) for decoded_image_cuda, decoded_image_cpu in zip(decoded_images, decoded_images_cpu): assert decoded_image_cuda.shape == decoded_image_cpu.shape assert decoded_image_cuda.dtype == decoded_image_cpu.dtype == torch.uint8 - assert (decoded_image_cuda.cpu().float() - decoded_image_cpu.cpu().float()).abs().mean() < 2.5 + assert (decoded_image_cuda.cpu().float() - decoded_image_cpu.cpu().float()).abs().mean() < tol @needs_cuda diff --git a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp index bc1a9420ebe..4b657b0b2d6 100644 --- a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp +++ b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp @@ -13,24 +13,37 @@ std::vector decode_jpegs_cuda( } // namespace vision #endif -#if NVJPEG_FOUND -#include -#include +#if NVJPEG_FOUND || ROCJPEG_FOUND #include -#include -#include -#include -#include +#include #include #include -#include -#include -#include + namespace vision { namespace image { - +namespace { std::mutex decoderMutex; -std::unique_ptr cudaJpegDecoder; +std::unique_ptr gpuJpegDecoder; + +std::vector validate_and_make_contiguous( + const std::vector& encoded_images) { + std::vector contig_images; + contig_images.reserve(encoded_images.size()); + for (auto& encoded_image : encoded_images) { + STD_TORCH_CHECK( + encoded_image.dtype() == torch::kU8, "Expected a torch.uint8 tensor"); + STD_TORCH_CHECK( + !encoded_image.is_cuda(), "The input tensor must be on CPU"); + STD_TORCH_CHECK( + encoded_image.dim() == 1 && encoded_image.numel() > 0, + "Expected a non empty 1-dimensional tensor"); + // The decoder backends require contiguous input; contiguous() is a no-op + // when the tensor already is. + contig_images.push_back(encoded_image.contiguous()); + } + return contig_images; +} +} // namespace std::vector decode_jpegs_cuda( const std::vector& encoded_images, @@ -40,32 +53,54 @@ std::vector decode_jpegs_cuda( "torchvision.csrc.io.image.cuda.decode_jpegs_cuda.decode_jpegs_cuda"); std::lock_guard lock(decoderMutex); - std::vector contig_images; - contig_images.reserve(encoded_images.size()); STD_TORCH_CHECK( device.is_cuda(), "Expected the device parameter to be a cuda device"); - for (auto& encoded_image : encoded_images) { - STD_TORCH_CHECK( - encoded_image.dtype() == torch::kU8, "Expected a torch.uint8 tensor"); - - STD_TORCH_CHECK( - !encoded_image.is_cuda(), - "The input tensor must be on CPU when decoding with nvjpeg") + std::vector contig_images = + validate_and_make_contiguous(encoded_images); - STD_TORCH_CHECK( - encoded_image.dim() == 1 && encoded_image.numel() > 0, - "Expected a non empty 1-dimensional tensor"); + at::cuda::CUDAGuard device_guard(device); - // nvjpeg requires images to be contiguous - if (encoded_image.is_contiguous()) { - contig_images.push_back(encoded_image); + if (gpuJpegDecoder == nullptr || device != gpuJpegDecoder->target_device) { + if (gpuJpegDecoder != nullptr) { + gpuJpegDecoder.reset(new GpuJpegDecoder(device)); } else { - contig_images.push_back(encoded_image.contiguous()); + gpuJpegDecoder = std::make_unique(device); + std::atexit([]() { gpuJpegDecoder.reset(); }); } } + try { + return gpuJpegDecoder->decode_images(contig_images, mode); + } catch (const std::exception& e) { + STD_TORCH_CHECK(false, "Error while decoding JPEG images: ", e.what()); + } +} + +} // namespace image +} // namespace vision +#endif + +#if NVJPEG_FOUND +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +namespace vision { +namespace image { + +std::vector CUDAJpegDecoder::decode_images( + const std::vector& encoded_images, + vision::image::ImageReadMode mode) { int major_version; int minor_version; nvjpegStatus_t get_major_property_status = @@ -87,17 +122,6 @@ std::vector decode_jpegs_cuda( "Make sure to rely on CUDA 11.6 or above before using decode_jpeg(..., device='cuda')."); } - at::cuda::CUDAGuard device_guard(device); - - if (cudaJpegDecoder == nullptr || device != cudaJpegDecoder->target_device) { - if (cudaJpegDecoder != nullptr) { - cudaJpegDecoder.reset(new CUDAJpegDecoder(device)); - } else { - cudaJpegDecoder = std::make_unique(device); - std::atexit([]() { cudaJpegDecoder.reset(); }); - } - } - nvjpegOutputFormat_t output_format; switch (mode) { @@ -119,19 +143,15 @@ std::vector decode_jpegs_cuda( false, "The provided mode is not supported for JPEG decoding on GPU"); } - try { - at::cuda::CUDAEvent event; - auto result = cudaJpegDecoder->decode_images(contig_images, output_format); - auto current_stream{ - device.has_index() ? at::cuda::getCurrentCUDAStream( - cudaJpegDecoder->original_device.index()) - : at::cuda::getCurrentCUDAStream()}; - event.record(cudaJpegDecoder->stream); - event.block(current_stream); - return result; - } catch (const std::exception& e) { - STD_TORCH_CHECK(false, "Error while decoding JPEG images: ", e.what()); - } + at::cuda::CUDAEvent event; + auto result = decode_images(encoded_images, output_format); + auto current_stream{ + target_device.has_index() + ? at::cuda::getCurrentCUDAStream(original_device.index()) + : at::cuda::getCurrentCUDAStream()}; + event.record(stream); + event.block(current_stream); + return result; } CUDAJpegDecoder::CUDAJpegDecoder(const torch::Device& target_device) @@ -599,113 +619,25 @@ std::vector CUDAJpegDecoder::decode_images( #elif ROCJPEG_FOUND #include -#include -#include -#include -#include -#include -#include -#include -#include -#include namespace vision { namespace image { -std::mutex decoderMutex; -std::unique_ptr rocJpegDecoder; - -std::vector decode_jpegs_cuda( - const std::vector& encoded_images, - vision::image::ImageReadMode mode, - torch::Device device) { - C10_LOG_API_USAGE_ONCE( - "torchvision.csrc.io.image.cuda.decode_jpegs_cuda.decode_jpegs_cuda"); - - std::lock_guard lock(decoderMutex); - std::vector contig_images; - contig_images.reserve(encoded_images.size()); - - TORCH_CHECK( - device.is_cuda(), "Expected the device parameter to be a cuda device"); - - for (auto& encoded_image : encoded_images) { - TORCH_CHECK( - encoded_image.dtype() == torch::kU8, "Expected a torch.uint8 tensor"); - - TORCH_CHECK( - !encoded_image.is_cuda(), - "The input tensor must be on CPU when decoding with rocJPEG") - - TORCH_CHECK( - encoded_image.dim() == 1 && encoded_image.numel() > 0, - "Expected a non empty 1-dimensional tensor"); - - // rocJPEG requires images to be contiguous - if (encoded_image.is_contiguous()) { - contig_images.push_back(encoded_image); - } else { - contig_images.push_back(encoded_image.contiguous()); - } - } - - at::cuda::CUDAGuard device_guard(device); - - if (rocJpegDecoder == nullptr || device != rocJpegDecoder->target_device) { - if (rocJpegDecoder != nullptr) { - rocJpegDecoder.reset(new RocJpegDecoder(device)); - } else { - rocJpegDecoder = std::make_unique(device); - std::atexit([]() { rocJpegDecoder.reset(); }); - } - } - - RocJpegOutputFormat output_format; - bool prune_single_channel = false; +namespace { +constexpr uint32_t kRocJpegPitchAlignment = 16; - switch (mode) { - case vision::image::IMAGE_READ_MODE_UNCHANGED: - output_format = ROCJPEG_OUTPUT_RGB_PLANAR; - prune_single_channel = true; - break; - case vision::image::IMAGE_READ_MODE_GRAY: - output_format = ROCJPEG_OUTPUT_Y; - break; - case vision::image::IMAGE_READ_MODE_RGB: - output_format = ROCJPEG_OUTPUT_RGB_PLANAR; - break; - default: - TORCH_CHECK( - false, - "The provided mode is not supported for ROCJPEG decoding on GPU"); - } - - try { - // rocJPEG owns and synchronizes its internal HIP stream; there is no - // caller-visible CUDA stream here to record an event on. - return rocJpegDecoder->decode_images( - contig_images, output_format, prune_single_channel); - } catch (const std::exception& e) { - if (typeid(e) != typeid(std::runtime_error)) { - TORCH_CHECK(false, "Error while decoding JPEG images: ", e.what()); - } else { - throw; - } - } +uint32_t align_up(uint32_t value, uint32_t alignment) { + return (value + alignment - 1) & ~(alignment - 1); } +} // namespace RocJpegDecoder::RocJpegDecoder(const torch::Device& target_device) : target_device{target_device} { int device_id = target_device.has_index() ? target_device.index() : c10::cuda::current_device(); CHECK_HIP(hipSetDevice(device_id)); - RocJpegStatus status; - RocJpegBackend rocjpeg_backend = ROCJPEG_BACKEND_HARDWARE; - - status = rocJpegCreate(rocjpeg_backend, device_id, &rocjpeg_handle); - TORCH_CHECK( - status == ROCJPEG_STATUS_SUCCESS, - "Failed to initialize rocjpeg with hardware backend"); + CHECK_ROCJPEG( + rocJpegCreate(ROCJPEG_BACKEND_HARDWARE, device_id, &rocjpeg_handle)); } RocJpegDecoder::~RocJpegDecoder() { @@ -718,174 +650,101 @@ RocJpegDecoder::~RocJpegDecoder() { void RocJpegDecoder::ensure_stream_handles(std::size_t num_handles) { while (rocjpeg_stream_handles.size() < num_handles) { RocJpegStreamHandle stream_handle; - RocJpegStatus status = rocJpegStreamCreate(&stream_handle); - TORCH_CHECK( - status == ROCJPEG_STATUS_SUCCESS, - "Failed to initialize rocjpeg stream"); + CHECK_ROCJPEG(rocJpegStreamCreate(&stream_handle)); rocjpeg_stream_handles.push_back(stream_handle); } } -static constexpr int mem_alignment = 16; - -static inline int align(int value, int alignment) { - return (value + alignment - 1) & ~(alignment - 1); -} - -void setOutputChannelPitches( - RocJpegDecodeParams decode_params, - uint32_t* widths, - uint32_t* heights, - uint32_t& num_channels, - RocJpegImage& output_image) { - bool is_roi_valid = false; - uint32_t roi_width; - uint32_t roi_height; - roi_width = - decode_params.crop_rectangle.right - decode_params.crop_rectangle.left; - roi_height = - decode_params.crop_rectangle.bottom - decode_params.crop_rectangle.top; - if (roi_width > 0 && roi_height > 0 && roi_width <= widths[0] && - roi_height <= heights[0]) { - is_roi_valid = true; - } - switch (decode_params.output_format) { - case ROCJPEG_OUTPUT_Y: - num_channels = 1; - output_image.pitch[0] = is_roi_valid ? align(roi_width, mem_alignment) - : align(widths[0], mem_alignment); +std::vector RocJpegDecoder::decode_images( + const std::vector& encoded_images, + vision::image::ImageReadMode mode) { + RocJpegOutputFormat output_format; + bool prune_single_channel = false; + switch (mode) { + case vision::image::IMAGE_READ_MODE_UNCHANGED: + // rocJPEG has no "unchanged" output; decode to RGB and drop the extra + // channels from grayscale images afterwards, matching the nvJPEG path. + output_format = ROCJPEG_OUTPUT_RGB_PLANAR; + prune_single_channel = true; + break; + case vision::image::IMAGE_READ_MODE_GRAY: + output_format = ROCJPEG_OUTPUT_Y; break; - case ROCJPEG_OUTPUT_RGB_PLANAR: - num_channels = 3; - output_image.pitch[2] = output_image.pitch[1] = output_image.pitch[0] = - is_roi_valid ? align(roi_width, mem_alignment) - : align(widths[0], mem_alignment); + case vision::image::IMAGE_READ_MODE_RGB: + output_format = ROCJPEG_OUTPUT_RGB_PLANAR; break; default: - TORCH_CHECK(false, "Unsupported rocJPEG output format"); + STD_TORCH_CHECK( + false, "The provided mode is not supported for JPEG decoding on GPU"); } -} + const uint32_t num_channels = output_format == ROCJPEG_OUTPUT_Y ? 1 : 3; -std::vector RocJpegDecoder::decode_images( - const std::vector& encoded_images, - const RocJpegOutputFormat& output_format, - bool prune_single_channel) { - /* - This function decodes a batch of jpeg bitstreams. + const std::size_t num_images = encoded_images.size(); + ensure_stream_handles(num_images); - Args: - - encoded_images (std::vector): a vector of tensors - containing the jpeg bitstreams to be decoded - - output_format (RocJpegOutputFormat): ROCJPEG_OUTPUT_Y or - ROCJPEG_OUTPUT_RGB_PLANAR + std::vector decode_params(num_images); + std::vector output_images(num_images); + std::vector output_tensors(num_images); + std::vector source_channels(num_images); - Returns: - - output_tensors (std::vector): a vector of Tensors - containing the decoded images - */ + for (std::size_t i = 0; i < num_images; ++i) { + CHECK_ROCJPEG(rocJpegStreamParse( + static_cast(encoded_images[i].data_ptr()), + encoded_images[i].numel(), + rocjpeg_stream_handles[i])); + + uint8_t num_components = 0; + RocJpegChromaSubsampling subsampling = ROCJPEG_CSS_UNKNOWN; + uint32_t widths[ROCJPEG_MAX_COMPONENT] = {}; + uint32_t heights[ROCJPEG_MAX_COMPONENT] = {}; + CHECK_ROCJPEG(rocJpegGetImageInfo( + rocjpeg_handle, + rocjpeg_stream_handles[i], + &num_components, + &subsampling, + widths, + heights)); - int num_images = encoded_images.size(); - std::vector output_tensors{num_images}; - RocJpegStatus rocjpeg_status; - - // baseline JPEGs can be batch decoded with hardware support - std::vector channels(num_images); - - const int batch_size = num_images; - ensure_stream_handles(static_cast(batch_size)); - uint8_t num_components; - RocJpegChromaSubsampling temp_subsampling; - std::vector temp_widths(ROCJPEG_MAX_COMPONENT, 0); - std::vector temp_heights(ROCJPEG_MAX_COMPONENT, 0); - RocJpegDecodeParams decode_params = {}; - decode_params.output_format = output_format; - std::vector decode_params_batch; - decode_params_batch.resize(batch_size, decode_params); - std::vector output_images; - output_images.resize(batch_size); - int current_batch_size = 0; - uint32_t num_channels = 0; - - for (int i = 0; i < num_images; i += batch_size) { - int batch_end = std::min(i + batch_size, num_images); - for (int j = i; j < batch_end; j++) { - int index = j - i; - rocjpeg_status = rocJpegStreamParse( - (unsigned char*)encoded_images[j].data_ptr(), - encoded_images[j].numel(), - rocjpeg_stream_handles[index]); - if (rocjpeg_status != ROCJPEG_STATUS_SUCCESS) { - TORCH_CHECK( - false, - "ERROR: Failed to parse the input jpeg stream with ", - rocJpegGetErrorName(rocjpeg_status)); - } - CHECK_ROCJPEG(rocJpegGetImageInfo( - rocjpeg_handle, - rocjpeg_stream_handles[index], - &num_components, - &temp_subsampling, - temp_widths.data(), - temp_heights.data())); - if (temp_widths[0] < 64 || temp_heights[0] < 64) { - TORCH_CHECK( - false, "The image resolution is not supported by VCN Hardware"); - } - if (temp_subsampling == ROCJPEG_CSS_411 || - temp_subsampling == ROCJPEG_CSS_UNKNOWN) { - TORCH_CHECK( - false, "The chroma sub-sampling is not supported by VCN Hardware"); - } - setOutputChannelPitches( - decode_params_batch[index], - temp_widths.data(), - temp_heights.data(), - num_channels, - output_images[index]); - - uint32_t roi_width = decode_params_batch[index].crop_rectangle.right - - decode_params_batch[index].crop_rectangle.left; - uint32_t roi_height = decode_params_batch[index].crop_rectangle.bottom - - decode_params_batch[index].crop_rectangle.top; - bool is_roi_valid = - (roi_width > 0 && roi_height > 0 && roi_width <= temp_widths[0] && - roi_height <= temp_heights[0]) - ? true - : false; - uint32_t width = is_roi_valid ? align(roi_width, mem_alignment) - : align(temp_widths[0], mem_alignment); - uint32_t height = is_roi_valid ? align(roi_height, mem_alignment) - : align(temp_heights[0], mem_alignment); - auto output_tensor = torch::empty( - {int64_t(num_channels), int64_t(height), int64_t(width)}, - torch::dtype(torch::kU8).device(target_device)); - channels[j] = num_components; - - // allocate memory for each channel and reuse them if the sizes remain - // unchanged for a new image. - for (int c = 0; c < (int)num_channels; c++) { - output_images[index].channel[c] = output_tensor[c].data_ptr(); - } - current_batch_size++; - output_tensors[j] = output_tensor.narrow(1, 0, temp_heights[0]) - .narrow(2, 0, temp_widths[0]); - } + const uint32_t width = widths[0]; + const uint32_t height = heights[0]; + STD_TORCH_CHECK( + width >= 64 && height >= 64, + "Image resolution ", + width, + "x", + height, + " is below the VCN hardware JPEG decoder minimum of 64x64"); + STD_TORCH_CHECK( + subsampling != ROCJPEG_CSS_411 && subsampling != ROCJPEG_CSS_UNKNOWN, + "The image chroma subsampling is not supported by the VCN hardware JPEG decoder"); + + // VCN writes rows at a 16-byte-aligned pitch, so allocate a buffer padded + // to that alignment and return a view of the valid region. + const uint32_t pitch = align_up(width, kRocJpegPitchAlignment); + auto buffer = torch::empty( + {int64_t(num_channels), + int64_t(align_up(height, kRocJpegPitchAlignment)), + int64_t(pitch)}, + torch::dtype(torch::kU8).device(target_device)); - if (current_batch_size > 0) { - CHECK_ROCJPEG(rocJpegDecodeBatched( - rocjpeg_handle, - rocjpeg_stream_handles.data(), - current_batch_size, - decode_params_batch.data(), - output_images.data())); + decode_params[i].output_format = output_format; + for (uint32_t c = 0; c < num_channels; ++c) { + output_images[i].channel[c] = buffer[c].data_ptr(); + output_images[i].pitch[c] = pitch; } - - current_batch_size = 0; + source_channels[i] = num_components; + output_tensors[i] = buffer.narrow(1, 0, height).narrow(2, 0, width); } - for (std::vector::size_type i = 0; i < output_tensors.size(); - ++i) { - if (prune_single_channel && channels[i] == 1) { + CHECK_ROCJPEG(rocJpegDecodeBatched( + rocjpeg_handle, + rocjpeg_stream_handles.data(), + static_cast(num_images), + decode_params.data(), + output_images.data())); + + for (std::size_t i = 0; i < num_images; ++i) { + if (prune_single_channel && source_channels[i] == 1) { output_tensors[i] = output_tensors[i][0].unsqueeze(0).clone(); } else { output_tensors[i] = output_tensors[i].contiguous(); diff --git a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.h b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.h index 2db2a21bccc..4dbc7e654f4 100644 --- a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.h +++ b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.h @@ -18,13 +18,16 @@ class CUDAJpegDecoder { std::vector decode_images( const std::vector& encoded_images, - const nvjpegOutputFormat_t& output_format); + vision::image::ImageReadMode mode); const torch::Device original_device; const torch::Device target_device; const c10::cuda::CUDAStream stream; private: + std::vector decode_images( + const std::vector& encoded_images, + const nvjpegOutputFormat_t& output_format); std::tuple< std::vector, std::vector, @@ -42,11 +45,13 @@ class CUDAJpegDecoder { bool hw_decode_available{false}; nvjpegHandle_t nvjpeg_handle; }; +using GpuJpegDecoder = CUDAJpegDecoder; } // namespace image } // namespace vision #elif ROCJPEG_FOUND +#include #include namespace vision { @@ -58,8 +63,7 @@ class RocJpegDecoder { std::vector decode_images( const std::vector& encoded_images, - const RocJpegOutputFormat& output_format, - bool prune_single_channel); + vision::image::ImageReadMode mode); const torch::Device target_device; @@ -69,13 +73,14 @@ class RocJpegDecoder { std::vector rocjpeg_stream_handles; RocJpegHandle rocjpeg_handle; }; +using GpuJpegDecoder = RocJpegDecoder; } // namespace image } // namespace vision #define CHECK_ROCJPEG(call) \ { \ RocJpegStatus rocjpeg_status = (call); \ - TORCH_CHECK( \ + STD_TORCH_CHECK( \ rocjpeg_status == ROCJPEG_STATUS_SUCCESS, \ #call, \ " returned ", \ @@ -85,7 +90,7 @@ class RocJpegDecoder { #define CHECK_HIP(call) \ { \ hipError_t hip_status = (call); \ - TORCH_CHECK( \ + STD_TORCH_CHECK( \ hip_status == hipSuccess, \ #call, \ " failed with status: ", \ From a319739e3cf9ae5965c4f937a3ba1549d69f4450 Mon Sep 17 00:00:00 2001 From: xytpai Date: Thu, 18 Jun 2026 08:27:34 +0000 Subject: [PATCH 12/22] refine IMAGE_READ_MODE_UNCHANGED --- .../csrc/io/image/cuda/decode_jpegs_cuda.cpp | 57 ++++++++++++++----- 1 file changed, 43 insertions(+), 14 deletions(-) diff --git a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp index 4b657b0b2d6..3017ca67374 100644 --- a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp +++ b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp @@ -659,13 +659,9 @@ std::vector RocJpegDecoder::decode_images( const std::vector& encoded_images, vision::image::ImageReadMode mode) { RocJpegOutputFormat output_format; - bool prune_single_channel = false; switch (mode) { case vision::image::IMAGE_READ_MODE_UNCHANGED: - // rocJPEG has no "unchanged" output; decode to RGB and drop the extra - // channels from grayscale images afterwards, matching the nvJPEG path. - output_format = ROCJPEG_OUTPUT_RGB_PLANAR; - prune_single_channel = true; + output_format = ROCJPEG_OUTPUT_NATIVE; break; case vision::image::IMAGE_READ_MODE_GRAY: output_format = ROCJPEG_OUTPUT_Y; @@ -677,7 +673,6 @@ std::vector RocJpegDecoder::decode_images( STD_TORCH_CHECK( false, "The provided mode is not supported for JPEG decoding on GPU"); } - const uint32_t num_channels = output_format == ROCJPEG_OUTPUT_Y ? 1 : 3; const std::size_t num_images = encoded_images.size(); ensure_stream_handles(num_images); @@ -720,14 +715,52 @@ std::vector RocJpegDecoder::decode_images( // VCN writes rows at a 16-byte-aligned pitch, so allocate a buffer padded // to that alignment and return a view of the valid region. - const uint32_t pitch = align_up(width, kRocJpegPitchAlignment); + uint32_t pitch = align_up(width, kRocJpegPitchAlignment); + uint32_t num_channels; + switch (output_format) { + case ROCJPEG_OUTPUT_NATIVE: + switch (subsampling) { + case ROCJPEG_CSS_444: + case ROCJPEG_CSS_440: + case ROCJPEG_CSS_420: + num_channels = 3; + break; + case ROCJPEG_CSS_422: + num_channels = 1; + pitch = align_up(width * 2, kRocJpegPitchAlignment); + break; + case ROCJPEG_CSS_400: + num_channels = 1; + break; + default: + TORCH_CHECK(false, "Unsupported rocJPEG native chroma subsampling"); + } + break; + case ROCJPEG_OUTPUT_Y: + num_channels = 1; + break; + case ROCJPEG_OUTPUT_RGB_PLANAR: + num_channels = 3; + break; + default: + TORCH_CHECK(false, "Unsupported rocJPEG output format"); + } + auto buffer = torch::empty( {int64_t(num_channels), int64_t(align_up(height, kRocJpegPitchAlignment)), int64_t(pitch)}, torch::dtype(torch::kU8).device(target_device)); - - decode_params[i].output_format = output_format; + + auto image_output_format = output_format; + if (output_format == ROCJPEG_OUTPUT_NATIVE) { + // ROCJPEG_OUTPUT_NATIVE returns YUV/native layouts whose channel count and + // plane sizes depend on chroma subsampling. torchvision's UNCHANGED mode is + // expected to match the CPU/nvJPEG behavior: grayscale JPEGs return one + // channel, while color JPEGs return RGB. Decode to that compatible layout. + image_output_format = num_components == 1 ? ROCJPEG_OUTPUT_Y : ROCJPEG_OUTPUT_RGB_PLANAR; + } + decode_params[i].output_format = image_output_format; for (uint32_t c = 0; c < num_channels; ++c) { output_images[i].channel[c] = buffer[c].data_ptr(); output_images[i].pitch[c] = pitch; @@ -744,11 +777,7 @@ std::vector RocJpegDecoder::decode_images( output_images.data())); for (std::size_t i = 0; i < num_images; ++i) { - if (prune_single_channel && source_channels[i] == 1) { - output_tensors[i] = output_tensors[i][0].unsqueeze(0).clone(); - } else { - output_tensors[i] = output_tensors[i].contiguous(); - } + output_tensors[i] = output_tensors[i].contiguous(); } return output_tensors; From 4b719087c557d71077af7413e749add3507feffd Mon Sep 17 00:00:00 2001 From: xytpai Date: Thu, 18 Jun 2026 08:39:35 +0000 Subject: [PATCH 13/22] rm dead code & refine comment --- .../csrc/io/image/cuda/decode_jpegs_cuda.cpp | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp index 3017ca67374..7094e414bd0 100644 --- a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp +++ b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp @@ -88,11 +88,7 @@ std::vector decode_jpegs_cuda( #include #include #include -#include -#include #include -#include -#include #include #include namespace vision { @@ -751,14 +747,16 @@ std::vector RocJpegDecoder::decode_images( int64_t(align_up(height, kRocJpegPitchAlignment)), int64_t(pitch)}, torch::dtype(torch::kU8).device(target_device)); - + auto image_output_format = output_format; if (output_format == ROCJPEG_OUTPUT_NATIVE) { - // ROCJPEG_OUTPUT_NATIVE returns YUV/native layouts whose channel count and - // plane sizes depend on chroma subsampling. torchvision's UNCHANGED mode is - // expected to match the CPU/nvJPEG behavior: grayscale JPEGs return one - // channel, while color JPEGs return RGB. Decode to that compatible layout. - image_output_format = num_components == 1 ? ROCJPEG_OUTPUT_Y : ROCJPEG_OUTPUT_RGB_PLANAR; + // ROCJPEG_OUTPUT_NATIVE returns YUV/native layouts whose channel count + // and plane sizes depend on chroma subsampling. torchvision's UNCHANGED + // mode is expected to match the CPU/nvJPEG behavior: grayscale JPEGs + // return one channel, while color JPEGs return RGB. Decode to that + // compatible layout. + image_output_format = + num_components == 1 ? ROCJPEG_OUTPUT_Y : ROCJPEG_OUTPUT_RGB_PLANAR; } decode_params[i].output_format = image_output_format; for (uint32_t c = 0; c < num_channels; ++c) { @@ -769,6 +767,8 @@ std::vector RocJpegDecoder::decode_images( output_tensors[i] = buffer.narrow(1, 0, height).narrow(2, 0, width); } + // Choosing a batch size that is a multiple of the available JPEG cores is + // recommended. CHECK_ROCJPEG(rocJpegDecodeBatched( rocjpeg_handle, rocjpeg_stream_handles.data(), From 7ce968fffcffe502a0708ecf5d61e95f7a350332 Mon Sep 17 00:00:00 2001 From: xytpai Date: Thu, 18 Jun 2026 08:43:41 +0000 Subject: [PATCH 14/22] recover nv path --- torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp index 7094e414bd0..44cb48afdbb 100644 --- a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp +++ b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp @@ -88,7 +88,11 @@ std::vector decode_jpegs_cuda( #include #include #include +#include +#include #include +#include +#include #include #include namespace vision { From 248894cc74a86a73bdca9740a4d9966e5126cfa9 Mon Sep 17 00:00:00 2001 From: xytpai Date: Thu, 18 Jun 2026 17:08:49 +0000 Subject: [PATCH 15/22] resolve comments --- .../csrc/io/image/cuda/decode_jpegs_cuda.cpp | 56 +++++++------------ 1 file changed, 21 insertions(+), 35 deletions(-) diff --git a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp index 44cb48afdbb..9c045296dbf 100644 --- a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp +++ b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp @@ -680,7 +680,6 @@ std::vector RocJpegDecoder::decode_images( std::vector decode_params(num_images); std::vector output_images(num_images); std::vector output_tensors(num_images); - std::vector source_channels(num_images); for (std::size_t i = 0; i < num_images; ++i) { CHECK_ROCJPEG(rocJpegStreamParse( @@ -708,34 +707,27 @@ std::vector RocJpegDecoder::decode_images( width, "x", height, - " is below the VCN hardware JPEG decoder minimum of 64x64"); + " is below the rocJPEG hardware JPEG decoder minimum of 64x64"); STD_TORCH_CHECK( subsampling != ROCJPEG_CSS_411 && subsampling != ROCJPEG_CSS_UNKNOWN, - "The image chroma subsampling is not supported by the VCN hardware JPEG decoder"); + "The image chroma subsampling is not supported by the rocJPEG hardware JPEG decoder"); - // VCN writes rows at a 16-byte-aligned pitch, so allocate a buffer padded + auto image_output_format = output_format; + if (output_format == ROCJPEG_OUTPUT_NATIVE) { + // ROCJPEG_OUTPUT_NATIVE returns YUV/native layouts whose channel count + // and plane sizes depend on chroma subsampling. torchvision's UNCHANGED + // mode is expected to match the CPU/nvJPEG behavior: grayscale JPEGs + // return one channel, while color JPEGs return RGB. Decode to that + // compatible layout. + image_output_format = + num_components == 1 ? ROCJPEG_OUTPUT_Y : ROCJPEG_OUTPUT_RGB_PLANAR; + } + + // rocJPEG writes rows at a 16-byte-aligned pitch, so allocate a buffer padded // to that alignment and return a view of the valid region. uint32_t pitch = align_up(width, kRocJpegPitchAlignment); uint32_t num_channels; - switch (output_format) { - case ROCJPEG_OUTPUT_NATIVE: - switch (subsampling) { - case ROCJPEG_CSS_444: - case ROCJPEG_CSS_440: - case ROCJPEG_CSS_420: - num_channels = 3; - break; - case ROCJPEG_CSS_422: - num_channels = 1; - pitch = align_up(width * 2, kRocJpegPitchAlignment); - break; - case ROCJPEG_CSS_400: - num_channels = 1; - break; - default: - TORCH_CHECK(false, "Unsupported rocJPEG native chroma subsampling"); - } - break; + switch (image_output_format) { case ROCJPEG_OUTPUT_Y: num_channels = 1; break; @@ -743,7 +735,7 @@ std::vector RocJpegDecoder::decode_images( num_channels = 3; break; default: - TORCH_CHECK(false, "Unsupported rocJPEG output format"); + STD_TORCH_CHECK(false, "Unsupported rocJPEG output format"); } auto buffer = torch::empty( @@ -752,22 +744,11 @@ std::vector RocJpegDecoder::decode_images( int64_t(pitch)}, torch::dtype(torch::kU8).device(target_device)); - auto image_output_format = output_format; - if (output_format == ROCJPEG_OUTPUT_NATIVE) { - // ROCJPEG_OUTPUT_NATIVE returns YUV/native layouts whose channel count - // and plane sizes depend on chroma subsampling. torchvision's UNCHANGED - // mode is expected to match the CPU/nvJPEG behavior: grayscale JPEGs - // return one channel, while color JPEGs return RGB. Decode to that - // compatible layout. - image_output_format = - num_components == 1 ? ROCJPEG_OUTPUT_Y : ROCJPEG_OUTPUT_RGB_PLANAR; - } decode_params[i].output_format = image_output_format; for (uint32_t c = 0; c < num_channels; ++c) { output_images[i].channel[c] = buffer[c].data_ptr(); output_images[i].pitch[c] = pitch; } - source_channels[i] = num_components; output_tensors[i] = buffer.narrow(1, 0, height).narrow(2, 0, width); } @@ -780,6 +761,11 @@ std::vector RocJpegDecoder::decode_images( decode_params.data(), output_images.data())); + // rocJPEG owns its internal HIP stream and does not expose it to callers. + // Synchronize before copying the padded views below so the copies cannot race + // with device writes from rocJpegDecodeBatched(). + CHECK_HIP(hipDeviceSynchronize()); + for (std::size_t i = 0; i < num_images; ++i) { output_tensors[i] = output_tensors[i].contiguous(); } From 802cac27f950c2a9e7ee1583e716614e538f35fd Mon Sep 17 00:00:00 2001 From: xytpai Date: Thu, 18 Jun 2026 17:33:19 +0000 Subject: [PATCH 16/22] apply clang-format --- torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp index 9c045296dbf..2d95703d59a 100644 --- a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp +++ b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp @@ -723,8 +723,8 @@ std::vector RocJpegDecoder::decode_images( num_components == 1 ? ROCJPEG_OUTPUT_Y : ROCJPEG_OUTPUT_RGB_PLANAR; } - // rocJPEG writes rows at a 16-byte-aligned pitch, so allocate a buffer padded - // to that alignment and return a view of the valid region. + // rocJPEG writes rows at a 16-byte-aligned pitch, so allocate a buffer + // padded to that alignment and return a view of the valid region. uint32_t pitch = align_up(width, kRocJpegPitchAlignment); uint32_t num_channels; switch (image_output_format) { From d942228df719534c87f194b01588333eef4d4c03 Mon Sep 17 00:00:00 2001 From: xytpai Date: Fri, 19 Jun 2026 03:12:49 +0000 Subject: [PATCH 17/22] Separate rocJPEG and nvJPEG setup blocks --- setup.py | 40 ++++++++++++++++++++++++---------------- 1 file changed, 24 insertions(+), 16 deletions(-) diff --git a/setup.py b/setup.py index 4bb174f2144..859eed3b13b 100644 --- a/setup.py +++ b/setup.py @@ -346,23 +346,31 @@ def make_image_extension(): else: warnings.warn("Building torchvision without WEBP support") - if (USE_NVJPEG or USE_ROCJPEG) and (torch.cuda.is_available() or FORCE_CUDA): - nvjpeg_found = CUDA_HOME is not None and (Path(CUDA_HOME) / "include/nvjpeg.h").exists() - rocjpeg_found = ROCM_HOME is not None and (Path(ROCM_HOME) / "include/rocjpeg/rocjpeg.h").exists() - if nvjpeg_found and USE_NVJPEG: - print("Building torchvision with NVJPEG image support") - libraries.append("nvjpeg") - define_macros += [("NVJPEG_FOUND", 1)] - Extension = CUDAExtension - elif rocjpeg_found and USE_ROCJPEG: - print("Building torchvision with ROCJPEG image support") - libraries.append("rocjpeg") - define_macros += [("ROCJPEG_FOUND", 1)] - Extension = CUDAExtension + if IS_ROCM: + if USE_ROCJPEG and (torch.cuda.is_available() or FORCE_CUDA): + rocjpeg_found = ROCM_HOME is not None and (Path(ROCM_HOME) / "include/rocjpeg/rocjpeg.h").exists() + if rocjpeg_found: + print("Building torchvision with ROCJPEG image support") + libraries.append("rocjpeg") + define_macros += [("ROCJPEG_FOUND", 1)] + Extension = CUDAExtension + else: + warnings.warn("Building torchvision without ROCJPEG support") else: - warnings.warn("Building torchvision without NVJPEG or ROCJPEG support") - elif (USE_NVJPEG or USE_ROCJPEG): - warnings.warn("Building torchvision without NVJPEG or ROCJPEG support") + warnings.warn("Building torchvision without ROCJPEG support") + else: + if USE_NVJPEG and (torch.cuda.is_available() or FORCE_CUDA): + nvjpeg_found = CUDA_HOME is not None and (Path(CUDA_HOME) / "include/nvjpeg.h").exists() + + if nvjpeg_found: + print("Building torchvision with NVJPEG image support") + libraries.append("nvjpeg") + define_macros += [("NVJPEG_FOUND", 1)] + Extension = CUDAExtension + else: + warnings.warn("Building torchvision without NVJPEG support") + elif USE_NVJPEG: + warnings.warn("Building torchvision without NVJPEG support") return Extension( name="torchvision.image", From 758139341a434308abe667b02ad2aafba6385f37 Mon Sep 17 00:00:00 2001 From: xytpai Date: Fri, 19 Jun 2026 03:23:34 +0000 Subject: [PATCH 18/22] add _ suffix for private class members --- .../csrc/io/image/cuda/decode_jpegs_cuda.cpp | 20 +++++++++---------- .../csrc/io/image/cuda/decode_jpegs_cuda.h | 4 ++-- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp index 2d95703d59a..79a48d26222 100644 --- a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp +++ b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp @@ -637,21 +637,21 @@ RocJpegDecoder::RocJpegDecoder(const torch::Device& target_device) : c10::cuda::current_device(); CHECK_HIP(hipSetDevice(device_id)); CHECK_ROCJPEG( - rocJpegCreate(ROCJPEG_BACKEND_HARDWARE, device_id, &rocjpeg_handle)); + rocJpegCreate(ROCJPEG_BACKEND_HARDWARE, device_id, &rocjpeg_handle_)); } RocJpegDecoder::~RocJpegDecoder() { - rocJpegDestroy(rocjpeg_handle); - for (auto stream_handle : rocjpeg_stream_handles) { + rocJpegDestroy(rocjpeg_handle_); + for (auto stream_handle : rocjpeg_stream_handles_) { rocJpegStreamDestroy(stream_handle); } } void RocJpegDecoder::ensure_stream_handles(std::size_t num_handles) { - while (rocjpeg_stream_handles.size() < num_handles) { + while (rocjpeg_stream_handles_.size() < num_handles) { RocJpegStreamHandle stream_handle; CHECK_ROCJPEG(rocJpegStreamCreate(&stream_handle)); - rocjpeg_stream_handles.push_back(stream_handle); + rocjpeg_stream_handles_.push_back(stream_handle); } } @@ -685,15 +685,15 @@ std::vector RocJpegDecoder::decode_images( CHECK_ROCJPEG(rocJpegStreamParse( static_cast(encoded_images[i].data_ptr()), encoded_images[i].numel(), - rocjpeg_stream_handles[i])); + rocjpeg_stream_handles_[i])); uint8_t num_components = 0; RocJpegChromaSubsampling subsampling = ROCJPEG_CSS_UNKNOWN; uint32_t widths[ROCJPEG_MAX_COMPONENT] = {}; uint32_t heights[ROCJPEG_MAX_COMPONENT] = {}; CHECK_ROCJPEG(rocJpegGetImageInfo( - rocjpeg_handle, - rocjpeg_stream_handles[i], + rocjpeg_handle_, + rocjpeg_stream_handles_[i], &num_components, &subsampling, widths, @@ -755,8 +755,8 @@ std::vector RocJpegDecoder::decode_images( // Choosing a batch size that is a multiple of the available JPEG cores is // recommended. CHECK_ROCJPEG(rocJpegDecodeBatched( - rocjpeg_handle, - rocjpeg_stream_handles.data(), + rocjpeg_handle_, + rocjpeg_stream_handles_.data(), static_cast(num_images), decode_params.data(), output_images.data())); diff --git a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.h b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.h index 4dbc7e654f4..3c5bfb734ca 100644 --- a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.h +++ b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.h @@ -70,8 +70,8 @@ class RocJpegDecoder { private: void ensure_stream_handles(std::size_t num_handles); - std::vector rocjpeg_stream_handles; - RocJpegHandle rocjpeg_handle; + std::vector rocjpeg_stream_handles_; + RocJpegHandle rocjpeg_handle_; }; using GpuJpegDecoder = RocJpegDecoder; } // namespace image From a4073b0a01276192b536df588ef73e804cbb252b Mon Sep 17 00:00:00 2001 From: xytpai Date: Fri, 19 Jun 2026 03:37:01 +0000 Subject: [PATCH 19/22] just return padded tensor in its original layout --- torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp index 79a48d26222..5d11a0237ca 100644 --- a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp +++ b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp @@ -762,14 +762,10 @@ std::vector RocJpegDecoder::decode_images( output_images.data())); // rocJPEG owns its internal HIP stream and does not expose it to callers. - // Synchronize before copying the padded views below so the copies cannot race - // with device writes from rocJpegDecodeBatched(). + // Synchronize before returning the padded CHW views so subsequent PyTorch + // operations cannot race with device writes from rocJpegDecodeBatched(). CHECK_HIP(hipDeviceSynchronize()); - for (std::size_t i = 0; i < num_images; ++i) { - output_tensors[i] = output_tensors[i].contiguous(); - } - return output_tensors; } From a2572c8106420928ef29605947b3462c929cd31d Mon Sep 17 00:00:00 2001 From: xytpai Date: Mon, 22 Jun 2026 05:30:53 +0000 Subject: [PATCH 20/22] rm unnecessary sync --- torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp index 5d11a0237ca..e7e547a204b 100644 --- a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp +++ b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp @@ -760,11 +760,8 @@ std::vector RocJpegDecoder::decode_images( static_cast(num_images), decode_params.data(), output_images.data())); - - // rocJPEG owns its internal HIP stream and does not expose it to callers. - // Synchronize before returning the padded CHW views so subsequent PyTorch - // operations cannot race with device writes from rocJpegDecodeBatched(). - CHECK_HIP(hipDeviceSynchronize()); + // rocJpegDecodeBatched synchronizes rocJPEG's internal HIP stream before + // returning, so the decoded output tensors are ready for PyTorch streams. return output_tensors; } From b413e54b923865296211458124644d757334818e Mon Sep 17 00:00:00 2001 From: xytpai Date: Mon, 22 Jun 2026 05:50:12 +0000 Subject: [PATCH 21/22] refine code --- .../csrc/io/image/cuda/decode_jpegs_cuda.cpp | 63 ++++++++----------- .../csrc/io/image/cuda/decode_jpegs_cuda.h | 2 +- 2 files changed, 28 insertions(+), 37 deletions(-) diff --git a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp index e7e547a204b..f2f03a08687 100644 --- a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp +++ b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp @@ -647,7 +647,8 @@ RocJpegDecoder::~RocJpegDecoder() { } } -void RocJpegDecoder::ensure_stream_handles(std::size_t num_handles) { +// Reuse existing rocJPEG stream handles and create only the missing ones. +void RocJpegDecoder::ensure_stream_handle_count(std::size_t num_handles) { while (rocjpeg_stream_handles_.size() < num_handles) { RocJpegStreamHandle stream_handle; CHECK_ROCJPEG(rocJpegStreamCreate(&stream_handle)); @@ -658,24 +659,8 @@ void RocJpegDecoder::ensure_stream_handles(std::size_t num_handles) { std::vector RocJpegDecoder::decode_images( const std::vector& encoded_images, vision::image::ImageReadMode mode) { - RocJpegOutputFormat output_format; - switch (mode) { - case vision::image::IMAGE_READ_MODE_UNCHANGED: - output_format = ROCJPEG_OUTPUT_NATIVE; - break; - case vision::image::IMAGE_READ_MODE_GRAY: - output_format = ROCJPEG_OUTPUT_Y; - break; - case vision::image::IMAGE_READ_MODE_RGB: - output_format = ROCJPEG_OUTPUT_RGB_PLANAR; - break; - default: - STD_TORCH_CHECK( - false, "The provided mode is not supported for JPEG decoding on GPU"); - } - const std::size_t num_images = encoded_images.size(); - ensure_stream_handles(num_images); + ensure_stream_handle_count(num_images); std::vector decode_params(num_images); std::vector output_images(num_images); @@ -712,32 +697,38 @@ std::vector RocJpegDecoder::decode_images( subsampling != ROCJPEG_CSS_411 && subsampling != ROCJPEG_CSS_UNKNOWN, "The image chroma subsampling is not supported by the rocJPEG hardware JPEG decoder"); - auto image_output_format = output_format; - if (output_format == ROCJPEG_OUTPUT_NATIVE) { - // ROCJPEG_OUTPUT_NATIVE returns YUV/native layouts whose channel count - // and plane sizes depend on chroma subsampling. torchvision's UNCHANGED - // mode is expected to match the CPU/nvJPEG behavior: grayscale JPEGs - // return one channel, while color JPEGs return RGB. Decode to that - // compatible layout. - image_output_format = - num_components == 1 ? ROCJPEG_OUTPUT_Y : ROCJPEG_OUTPUT_RGB_PLANAR; - } - - // rocJPEG writes rows at a 16-byte-aligned pitch, so allocate a buffer - // padded to that alignment and return a view of the valid region. - uint32_t pitch = align_up(width, kRocJpegPitchAlignment); + RocJpegOutputFormat image_output_format; uint32_t num_channels; - switch (image_output_format) { - case ROCJPEG_OUTPUT_Y: + switch (mode) { + case vision::image::IMAGE_READ_MODE_UNCHANGED: + // torchvision's UNCHANGED mode is expected to match the CPU/nvJPEG + // behavior: grayscale JPEGs return one channel, while color JPEGs + // return RGB. + if (num_components == 1) { + image_output_format = ROCJPEG_OUTPUT_Y; + num_channels = 1; + } else { + image_output_format = ROCJPEG_OUTPUT_RGB_PLANAR; + num_channels = 3; + } + break; + case vision::image::IMAGE_READ_MODE_GRAY: + image_output_format = ROCJPEG_OUTPUT_Y; num_channels = 1; break; - case ROCJPEG_OUTPUT_RGB_PLANAR: + case vision::image::IMAGE_READ_MODE_RGB: + image_output_format = ROCJPEG_OUTPUT_RGB_PLANAR; num_channels = 3; break; default: - STD_TORCH_CHECK(false, "Unsupported rocJPEG output format"); + STD_TORCH_CHECK( + false, + "The provided mode is not supported for JPEG decoding on GPU"); } + // rocJPEG writes rows at a 16-byte-aligned pitch, so allocate a buffer + // padded to that alignment and return a view of the valid region. + uint32_t pitch = align_up(width, kRocJpegPitchAlignment); auto buffer = torch::empty( {int64_t(num_channels), int64_t(align_up(height, kRocJpegPitchAlignment)), diff --git a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.h b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.h index 3c5bfb734ca..ebda5a5a6db 100644 --- a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.h +++ b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.h @@ -68,7 +68,7 @@ class RocJpegDecoder { const torch::Device target_device; private: - void ensure_stream_handles(std::size_t num_handles); + void ensure_stream_handle_count(std::size_t num_handles); std::vector rocjpeg_stream_handles_; RocJpegHandle rocjpeg_handle_; From 0fe060abc348c0ed69a13fdd67a7ef6e0e34d578 Mon Sep 17 00:00:00 2001 From: xytpai Date: Mon, 22 Jun 2026 05:59:05 +0000 Subject: [PATCH 22/22] add rocjpeg doc link --- torchvision/csrc/io/image/cuda/decode_jpegs_cuda.h | 3 +++ 1 file changed, 3 insertions(+) diff --git a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.h b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.h index ebda5a5a6db..4b4a84a7def 100644 --- a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.h +++ b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.h @@ -54,6 +54,9 @@ using GpuJpegDecoder = CUDAJpegDecoder; #include #include +// rocJPEG decode API documentation: +// https://rocm.docs.amd.com/projects/rocJPEG/en/latest/how-to/rocjpeg-decoding-a-jpeg-stream.html + namespace vision { namespace image { class RocJpegDecoder {