diff --git a/setup.py b/setup.py index 1f4a1a7d17f..859eed3b13b 100644 --- a/setup.py +++ b/setup.py @@ -23,6 +23,7 @@ USE_JPEG = os.getenv("TORCHVISION_USE_JPEG", "1") == "1" USE_WEBP = os.getenv("TORCHVISION_USE_WEBP", "1") == "1" USE_NVJPEG = os.getenv("TORCHVISION_USE_NVJPEG", "1") == "1" +USE_ROCJPEG = os.getenv("TORCHVISION_USE_ROCJPEG", "1") == "1" NVCC_FLAGS = os.getenv("NVCC_FLAGS", None) TORCHVISION_INCLUDE = os.environ.get("TORCHVISION_INCLUDE", "") @@ -45,6 +46,7 @@ print(f"{USE_JPEG = }") print(f"{USE_WEBP = }") print(f"{USE_NVJPEG = }") +print(f"{USE_ROCJPEG = }") print(f"{NVCC_FLAGS = }") print(f"{TORCHVISION_INCLUDE = }") print(f"{TORCHVISION_LIBRARY = }") @@ -344,18 +346,31 @@ def make_image_extension(): else: warnings.warn("Building torchvision without WEBP support") - if USE_NVJPEG and (torch.cuda.is_available() or FORCE_CUDA): - nvjpeg_found = CUDA_HOME is not None and (Path(CUDA_HOME) / "include/nvjpeg.h").exists() - - if nvjpeg_found: - print("Building torchvision with NVJPEG image support") - libraries.append("nvjpeg") - define_macros += [("NVJPEG_FOUND", 1)] - Extension = CUDAExtension + if IS_ROCM: + if USE_ROCJPEG and (torch.cuda.is_available() or FORCE_CUDA): + rocjpeg_found = ROCM_HOME is not None and (Path(ROCM_HOME) / "include/rocjpeg/rocjpeg.h").exists() + if rocjpeg_found: + print("Building torchvision with ROCJPEG image support") + libraries.append("rocjpeg") + define_macros += [("ROCJPEG_FOUND", 1)] + Extension = CUDAExtension + else: + warnings.warn("Building torchvision without ROCJPEG support") else: + warnings.warn("Building torchvision without ROCJPEG support") + else: + if USE_NVJPEG and (torch.cuda.is_available() or FORCE_CUDA): + nvjpeg_found = CUDA_HOME is not None and (Path(CUDA_HOME) / "include/nvjpeg.h").exists() + + if nvjpeg_found: + print("Building torchvision with NVJPEG image support") + libraries.append("nvjpeg") + define_macros += [("NVJPEG_FOUND", 1)] + Extension = CUDAExtension + else: + warnings.warn("Building torchvision without NVJPEG support") + elif USE_NVJPEG: warnings.warn("Building torchvision without NVJPEG support") - elif USE_NVJPEG: - warnings.warn("Building torchvision without NVJPEG support") return Extension( name="torchvision.image", diff --git a/test/test_image.py b/test/test_image.py index b11dd67ca12..8f105dc4c5c 100644 --- a/test/test_image.py +++ b/test/test_image.py @@ -426,12 +426,15 @@ def test_decode_jpegs_cuda(mode, scripted): futures = [executor.submit(decode_fn, encoded_images, mode, "cuda") for _ in range(num_workers)] decoded_images_threaded = [future.result() for future in futures] assert len(decoded_images_threaded) == num_workers + # rocJPEG's color conversion differs slightly from nvJPEG, so it needs a + # looser tolerance against the CPU reference. + tol = 2.5 if torch.version.hip is not None else 2 for decoded_images in decoded_images_threaded: assert len(decoded_images) == len(encoded_images) for decoded_image_cuda, decoded_image_cpu in zip(decoded_images, decoded_images_cpu): assert decoded_image_cuda.shape == decoded_image_cpu.shape assert decoded_image_cuda.dtype == decoded_image_cpu.dtype == torch.uint8 - assert (decoded_image_cuda.cpu().float() - decoded_image_cpu.cpu().float()).abs().mean() < 2 + assert (decoded_image_cuda.cpu().float() - decoded_image_cpu.cpu().float()).abs().mean() < tol @needs_cuda diff --git a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp index 0f5cf01548d..f2f03a08687 100644 --- a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp +++ b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.cpp @@ -1,5 +1,5 @@ #include "decode_jpegs_cuda.h" -#if !NVJPEG_FOUND +#if !NVJPEG_FOUND && !ROCJPEG_FOUND namespace vision { namespace image { std::vector decode_jpegs_cuda( @@ -11,25 +11,39 @@ std::vector decode_jpegs_cuda( } } // namespace image } // namespace vision +#endif -#else -#include -#include +#if NVJPEG_FOUND || ROCJPEG_FOUND #include -#include -#include -#include -#include +#include #include #include -#include -#include -#include + namespace vision { namespace image { - +namespace { std::mutex decoderMutex; -std::unique_ptr cudaJpegDecoder; +std::unique_ptr gpuJpegDecoder; + +std::vector validate_and_make_contiguous( + const std::vector& encoded_images) { + std::vector contig_images; + contig_images.reserve(encoded_images.size()); + for (auto& encoded_image : encoded_images) { + STD_TORCH_CHECK( + encoded_image.dtype() == torch::kU8, "Expected a torch.uint8 tensor"); + STD_TORCH_CHECK( + !encoded_image.is_cuda(), "The input tensor must be on CPU"); + STD_TORCH_CHECK( + encoded_image.dim() == 1 && encoded_image.numel() > 0, + "Expected a non empty 1-dimensional tensor"); + // The decoder backends require contiguous input; contiguous() is a no-op + // when the tensor already is. + contig_images.push_back(encoded_image.contiguous()); + } + return contig_images; +} +} // namespace std::vector decode_jpegs_cuda( const std::vector& encoded_images, @@ -39,32 +53,54 @@ std::vector decode_jpegs_cuda( "torchvision.csrc.io.image.cuda.decode_jpegs_cuda.decode_jpegs_cuda"); std::lock_guard lock(decoderMutex); - std::vector contig_images; - contig_images.reserve(encoded_images.size()); STD_TORCH_CHECK( device.is_cuda(), "Expected the device parameter to be a cuda device"); - for (auto& encoded_image : encoded_images) { - STD_TORCH_CHECK( - encoded_image.dtype() == torch::kU8, "Expected a torch.uint8 tensor"); - - STD_TORCH_CHECK( - !encoded_image.is_cuda(), - "The input tensor must be on CPU when decoding with nvjpeg") + std::vector contig_images = + validate_and_make_contiguous(encoded_images); - STD_TORCH_CHECK( - encoded_image.dim() == 1 && encoded_image.numel() > 0, - "Expected a non empty 1-dimensional tensor"); + at::cuda::CUDAGuard device_guard(device); - // nvjpeg requires images to be contiguous - if (encoded_image.is_contiguous()) { - contig_images.push_back(encoded_image); + if (gpuJpegDecoder == nullptr || device != gpuJpegDecoder->target_device) { + if (gpuJpegDecoder != nullptr) { + gpuJpegDecoder.reset(new GpuJpegDecoder(device)); } else { - contig_images.push_back(encoded_image.contiguous()); + gpuJpegDecoder = std::make_unique(device); + std::atexit([]() { gpuJpegDecoder.reset(); }); } } + try { + return gpuJpegDecoder->decode_images(contig_images, mode); + } catch (const std::exception& e) { + STD_TORCH_CHECK(false, "Error while decoding JPEG images: ", e.what()); + } +} + +} // namespace image +} // namespace vision +#endif + +#if NVJPEG_FOUND +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +namespace vision { +namespace image { + +std::vector CUDAJpegDecoder::decode_images( + const std::vector& encoded_images, + vision::image::ImageReadMode mode) { int major_version; int minor_version; nvjpegStatus_t get_major_property_status = @@ -86,17 +122,6 @@ std::vector decode_jpegs_cuda( "Make sure to rely on CUDA 11.6 or above before using decode_jpeg(..., device='cuda')."); } - at::cuda::CUDAGuard device_guard(device); - - if (cudaJpegDecoder == nullptr || device != cudaJpegDecoder->target_device) { - if (cudaJpegDecoder != nullptr) { - cudaJpegDecoder.reset(new CUDAJpegDecoder(device)); - } else { - cudaJpegDecoder = std::make_unique(device); - std::atexit([]() { cudaJpegDecoder.reset(); }); - } - } - nvjpegOutputFormat_t output_format; switch (mode) { @@ -118,19 +143,15 @@ std::vector decode_jpegs_cuda( false, "The provided mode is not supported for JPEG decoding on GPU"); } - try { - at::cuda::CUDAEvent event; - auto result = cudaJpegDecoder->decode_images(contig_images, output_format); - auto current_stream{ - device.has_index() ? at::cuda::getCurrentCUDAStream( - cudaJpegDecoder->original_device.index()) - : at::cuda::getCurrentCUDAStream()}; - event.record(cudaJpegDecoder->stream); - event.block(current_stream); - return result; - } catch (const std::exception& e) { - STD_TORCH_CHECK(false, "Error while decoding JPEG images: ", e.what()); - } + at::cuda::CUDAEvent event; + auto result = decode_images(encoded_images, output_format); + auto current_stream{ + target_device.has_index() + ? at::cuda::getCurrentCUDAStream(original_device.index()) + : at::cuda::getCurrentCUDAStream()}; + event.record(stream); + event.block(current_stream); + return result; } CUDAJpegDecoder::CUDAJpegDecoder(const torch::Device& target_device) @@ -595,4 +616,148 @@ std::vector CUDAJpegDecoder::decode_images( } // namespace image } // namespace vision +#elif ROCJPEG_FOUND + +#include + +namespace vision { +namespace image { + +namespace { +constexpr uint32_t kRocJpegPitchAlignment = 16; + +uint32_t align_up(uint32_t value, uint32_t alignment) { + return (value + alignment - 1) & ~(alignment - 1); +} +} // namespace + +RocJpegDecoder::RocJpegDecoder(const torch::Device& target_device) + : target_device{target_device} { + int device_id = target_device.has_index() ? target_device.index() + : c10::cuda::current_device(); + CHECK_HIP(hipSetDevice(device_id)); + CHECK_ROCJPEG( + rocJpegCreate(ROCJPEG_BACKEND_HARDWARE, device_id, &rocjpeg_handle_)); +} + +RocJpegDecoder::~RocJpegDecoder() { + rocJpegDestroy(rocjpeg_handle_); + for (auto stream_handle : rocjpeg_stream_handles_) { + rocJpegStreamDestroy(stream_handle); + } +} + +// Reuse existing rocJPEG stream handles and create only the missing ones. +void RocJpegDecoder::ensure_stream_handle_count(std::size_t num_handles) { + while (rocjpeg_stream_handles_.size() < num_handles) { + RocJpegStreamHandle stream_handle; + CHECK_ROCJPEG(rocJpegStreamCreate(&stream_handle)); + rocjpeg_stream_handles_.push_back(stream_handle); + } +} + +std::vector RocJpegDecoder::decode_images( + const std::vector& encoded_images, + vision::image::ImageReadMode mode) { + const std::size_t num_images = encoded_images.size(); + ensure_stream_handle_count(num_images); + + std::vector decode_params(num_images); + std::vector output_images(num_images); + std::vector output_tensors(num_images); + + for (std::size_t i = 0; i < num_images; ++i) { + CHECK_ROCJPEG(rocJpegStreamParse( + static_cast(encoded_images[i].data_ptr()), + encoded_images[i].numel(), + rocjpeg_stream_handles_[i])); + + uint8_t num_components = 0; + RocJpegChromaSubsampling subsampling = ROCJPEG_CSS_UNKNOWN; + uint32_t widths[ROCJPEG_MAX_COMPONENT] = {}; + uint32_t heights[ROCJPEG_MAX_COMPONENT] = {}; + CHECK_ROCJPEG(rocJpegGetImageInfo( + rocjpeg_handle_, + rocjpeg_stream_handles_[i], + &num_components, + &subsampling, + widths, + heights)); + + const uint32_t width = widths[0]; + const uint32_t height = heights[0]; + STD_TORCH_CHECK( + width >= 64 && height >= 64, + "Image resolution ", + width, + "x", + height, + " is below the rocJPEG hardware JPEG decoder minimum of 64x64"); + STD_TORCH_CHECK( + subsampling != ROCJPEG_CSS_411 && subsampling != ROCJPEG_CSS_UNKNOWN, + "The image chroma subsampling is not supported by the rocJPEG hardware JPEG decoder"); + + RocJpegOutputFormat image_output_format; + uint32_t num_channels; + switch (mode) { + case vision::image::IMAGE_READ_MODE_UNCHANGED: + // torchvision's UNCHANGED mode is expected to match the CPU/nvJPEG + // behavior: grayscale JPEGs return one channel, while color JPEGs + // return RGB. + if (num_components == 1) { + image_output_format = ROCJPEG_OUTPUT_Y; + num_channels = 1; + } else { + image_output_format = ROCJPEG_OUTPUT_RGB_PLANAR; + num_channels = 3; + } + break; + case vision::image::IMAGE_READ_MODE_GRAY: + image_output_format = ROCJPEG_OUTPUT_Y; + num_channels = 1; + break; + case vision::image::IMAGE_READ_MODE_RGB: + image_output_format = ROCJPEG_OUTPUT_RGB_PLANAR; + num_channels = 3; + break; + default: + STD_TORCH_CHECK( + false, + "The provided mode is not supported for JPEG decoding on GPU"); + } + + // rocJPEG writes rows at a 16-byte-aligned pitch, so allocate a buffer + // padded to that alignment and return a view of the valid region. + uint32_t pitch = align_up(width, kRocJpegPitchAlignment); + auto buffer = torch::empty( + {int64_t(num_channels), + int64_t(align_up(height, kRocJpegPitchAlignment)), + int64_t(pitch)}, + torch::dtype(torch::kU8).device(target_device)); + + decode_params[i].output_format = image_output_format; + for (uint32_t c = 0; c < num_channels; ++c) { + output_images[i].channel[c] = buffer[c].data_ptr(); + output_images[i].pitch[c] = pitch; + } + output_tensors[i] = buffer.narrow(1, 0, height).narrow(2, 0, width); + } + + // Choosing a batch size that is a multiple of the available JPEG cores is + // recommended. + CHECK_ROCJPEG(rocJpegDecodeBatched( + rocjpeg_handle_, + rocjpeg_stream_handles_.data(), + static_cast(num_images), + decode_params.data(), + output_images.data())); + // rocJpegDecodeBatched synchronizes rocJPEG's internal HIP stream before + // returning, so the decoded output tensors are ready for PyTorch streams. + + return output_tensors; +} + +} // namespace image +} // namespace vision + #endif diff --git a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.h b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.h index 6f72d9e35b2..4b4a84a7def 100644 --- a/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.h +++ b/torchvision/csrc/io/image/cuda/decode_jpegs_cuda.h @@ -1,9 +1,11 @@ #pragma once #include +#include #include #include "../common.h" #if NVJPEG_FOUND + #include #include @@ -16,13 +18,16 @@ class CUDAJpegDecoder { std::vector decode_images( const std::vector& encoded_images, - const nvjpegOutputFormat_t& output_format); + vision::image::ImageReadMode mode); const torch::Device original_device; const torch::Device target_device; const c10::cuda::CUDAStream stream; private: + std::vector decode_images( + const std::vector& encoded_images, + const nvjpegOutputFormat_t& output_format); std::tuple< std::vector, std::vector, @@ -40,6 +45,59 @@ class CUDAJpegDecoder { bool hw_decode_available{false}; nvjpegHandle_t nvjpeg_handle; }; +using GpuJpegDecoder = CUDAJpegDecoder; +} // namespace image +} // namespace vision + +#elif ROCJPEG_FOUND + +#include +#include + +// rocJPEG decode API documentation: +// https://rocm.docs.amd.com/projects/rocJPEG/en/latest/how-to/rocjpeg-decoding-a-jpeg-stream.html + +namespace vision { +namespace image { +class RocJpegDecoder { + public: + RocJpegDecoder(const torch::Device& target_device); + ~RocJpegDecoder(); + + std::vector decode_images( + const std::vector& encoded_images, + vision::image::ImageReadMode mode); + + const torch::Device target_device; + + private: + void ensure_stream_handle_count(std::size_t num_handles); + + std::vector rocjpeg_stream_handles_; + RocJpegHandle rocjpeg_handle_; +}; +using GpuJpegDecoder = RocJpegDecoder; } // namespace image } // namespace vision + +#define CHECK_ROCJPEG(call) \ + { \ + RocJpegStatus rocjpeg_status = (call); \ + STD_TORCH_CHECK( \ + rocjpeg_status == ROCJPEG_STATUS_SUCCESS, \ + #call, \ + " returned ", \ + rocJpegGetErrorName(rocjpeg_status)); \ + } + +#define CHECK_HIP(call) \ + { \ + hipError_t hip_status = (call); \ + STD_TORCH_CHECK( \ + hip_status == hipSuccess, \ + #call, \ + " failed with status: ", \ + hipGetErrorName(hip_status)); \ + } + #endif