Skip to content

TensorIterator Guide

Yutao Xu edited this page Aug 15, 2025 · 3 revisions

TensorIterator is a helper abstraction that lets you iterate over the elements of one or more tensors while it handles the complicated parts for you, such as broadcasting, type promotion, and memory strides.

Binary Add Operator

For the binary add operator, we only need to construct a TensorIterator and pass it to the kernel.

// Out-variant of element-wise addition.
//
// Registers `out` as the iterator's output and `left` / `right` as its inputs,
// builds the iterator, hands it to add_kernel(), and returns `out` so calls
// can be chained.
Tensor &add_out(Tensor &out, const Tensor &left, const Tensor &right) {
    auto iter = TensorIterator()
                    .add_output(out)
                    .add_input(left)
                    .add_input(right)
                    .build_for_loops();
    add_kernel(iter);
    return out;
}

Build a TensorIterator

A simplified view of the member variables of TensorIterator is shown below:

// Simplified view of the state a TensorIterator carries while it is built and
// then consumed by a kernel. Only the data members are shown.
class TensorIterator {
    // Per-operand arrays, indexed by operand position (0 .. num_tensors_ - 1).
    Tensor *tensors_[MAX_TENSORS];
    // Per-operand properties; reorder_dimensions() consults `will_resize`.
    TensorProp tensor_props_[MAX_TENSORS];
    // Refreshed by update_data_pointers() as the last step of build().
    void *data_ptr_[MAX_TENSORS];

    // Iteration shape, filled per dimension by compute_broadcasted_shape() and
    // later compacted in place by coalesce_dimensions().
    int64_t shape_[MAX_TENSOR_DIMS];
    // Stride in bytes for each operand and dimension; 0 marks a dimension the
    // operand is broadcast along (see compute_broadcasted_strides()).
    int64_t stride_bytes_[MAX_TENSORS][MAX_TENSOR_DIMS];
    // Dimension permutation; reorder_dimensions() starts it as n-1, ..., 1, 0
    // and then sorts it.
    int64_t perm_[MAX_TENSOR_DIMS];
    // Not referenced by any of the code shown alongside this class.
    // NOTE(review): presumably per-dimension offsets of a sub-view — confirm.
    int64_t view_offsets_[MAX_TENSOR_DIMS];

    // Updated by check_and_compute_common_device() during build().
    // NOTE(review): -1 presumably means "not determined yet" — confirm.
    int common_device_ = -1;
    // Operand counts. num_tensors_ bounds every per-operand loop.
    // NOTE(review): presumably maintained by add_output()/add_input() — confirm.
    int num_outputs_ = 0;
    int num_inputs_ = 0;
    int num_tensors_ = 0;
    // Number of iteration dimensions; set by check_and_compute_dim() and
    // reduced by coalesce_dimensions().
    int ndim_ = 0;
    // Configuration flags; their consumers are outside the code shown here.
    bool resize_outputs_ = true;
    bool accumulate_ = false;
    bool final_output_ = true;
    bool is_reduction_ = false;
    bool check_mem_overlap_ = true;
    int64_t reduce_dim_ = 0;
    // Inferred from the input tensors by compute_common_dtype().
    ScalarType common_dtype_ = ScalarType::Undefined;
};

When build() is used to construct a TensorIterator, it roughly goes through the following steps:

// Finalizes the iterator after its operands have been registered.
//
// Runs the steps below in the order listed and returns *this so the call can
// end a builder chain. Later steps consume state written by earlier ones:
// compute_broadcasted_strides() reads shape_, and reorder_dimensions() /
// coalesce_dimensions() read both shape_ and stride_bytes_.
TensorIterator &TensorIterator::build() {
    // check whether all defined tensors are in the same device, then update common_device_
    check_and_compute_common_device();
    // check whether all defined tensors have the same dim, then update ndim_
    check_and_compute_dim();
    // infer common_dtype from input tensors
    compute_common_dtype();
    // automatic output allocation for reduction
    allocate_reduction_output_if_need();
    // set is_output and is_read_write flags on appropriate tensors
    mark_outputs();
    // check that the defined outputs have no internal overlap
    // and do not share memory with inputs
    check_mem_overlaps();
    // compute and check the broadcasted shape through input tensors
    compute_broadcasted_shape();
    // mark outputs for resizing if necessary
    mark_resize_outputs();
    // compute each defined tensor's stride after broadcasting
    compute_broadcasted_strides();
    // re-order dimensions to improve coalescing
    reorder_dimensions();
    // allocate the output tensor if it's not provided
    allocate_outputs();
    // coalesce adjacent dimensions when possible
    coalesce_dimensions();
    // update data_ptr_
    update_data_pointers();
    return *this;
}

Compute Broadcasted Shape

TensorIterator::shape_ is inferred from every defined tensor registered with the iterator (undefined ones are skipped), which means that all of them must be broadcastable to a single shape.

// Computes the broadcasted iteration shape and stores it in shape_.
//
// For each dimension, the sizes of all defined tensors are folded together:
// sizes must either match or one of them must be 1, in which case the other
// size wins. The compatibility condition is handed to CHECK_FAIL for every
// tensor after the first defined one. Undefined tensors are skipped.
//
// If no tensor is defined at all, every dimension gets size 1 instead of an
// indeterminate value.
void TensorIterator::compute_broadcasted_shape() {
    for (int i = ndim_ - 1; i >= 0; --i) {
        bool is_first = true;
        // Start from 1, the neutral size for broadcasting, so that shape_[i]
        // is well defined even when every tensor is skipped below.
        int64_t sz = 1;
        for (int j = 0; j < num_tensors_; ++j) {
            if (!tensors_[j]->defined()) continue;
            if (is_first) {
                // The first defined tensor seeds the size for this dimension.
                sz = tensors_[j]->shape(i);
                is_first = false;
            } else {
                auto sz_ = tensors_[j]->shape(i);
                CHECK_FAIL(sz == sz_ || sz == 1 || sz_ == 1);
                // A size of 1 stretches to the other operand's size.
                sz = sz == 1 ? sz_ : sz;
            }
        }
        shape_[i] = sz;
    }
}

Compute Broadcasted Stride

If a tensor is broadcast along a dimension (its size there is 1 while the broadcasted size is not), we only need to set its stride in that dimension to 0.

void TensorIterator::compute_broadcasted_strides() {
    for (int id = 0; id < num_tensors_; ++id) {
        auto t = tensors_[id];
        if (!t->defined()) continue;
        auto element_size_in_bytes = t->element_size_in_bytes();
        for (int i = ndim_ - 1; i >= 0; --i) {
            if (t->shape(i) == 1 && shape_[i] != 1) {
                stride_bytes_[id][i] = 0;
            } else {
                stride_bytes_[id][i] = t->stride(i) * element_size_in_bytes;
            }
        }
    }
}

Reorder Dimensions

We move the fastest-moving (smallest-stride) dimension to the front to facilitate the subsequent coalescing step.

// Sorts the dimensions so that smaller strides come first, then applies the
// resulting permutation through permute_dimensions().
void TensorIterator::reorder_dimensions() {
    // Seed perm_ with the fully reversed order: n-1, n-2, ..., 1, 0.
    for (int slot = 0; slot < ndim_; ++slot) {
        perm_[slot] = ndim_ - 1 - slot;
    }

    // Three-way comparison of two dimensions:
    //    1  -> lhs should come after rhs
    //   -1  -> lhs should come before rhs
    //    0  -> no tensor could decide (ambiguous)
    auto compare_dims = [&](size_t lhs, size_t rhs) -> int {
        for (int t = 0; t < num_tensors_; ++t) {
            // Undefined tensors and tensors that will be resized carry no
            // usable stride information.
            if (!tensors_[t]->defined() || tensor_props_[t].will_resize) {
                continue;
            }
            const int64_t lhs_stride = stride_bytes_[t][lhs];
            const int64_t rhs_stride = stride_bytes_[t][rhs];
            // A zero stride means this tensor is broadcast along one of the
            // two dimensions; let the next tensor decide.
            if (lhs_stride == 0 || rhs_stride == 0) {
                continue;
            }
            if (lhs_stride < rhs_stride) {
                return -1;
            }
            if (lhs_stride > rhs_stride) {
                return 1;
            }
            // Equal strides: the smaller extent goes in front. Only a definite
            // "swap" ends the search; otherwise the next tensor is consulted.
            if (shape_[lhs] > shape_[rhs]) {
                return 1;
            }
        }
        return 0;
    };

    // Insertion sort that tolerates ambiguous comparisons: an ambiguous result
    // neither swaps nor stops, so the scan continues past that position.
    for (int i = 1; i < ndim_; ++i) {
        int moving = i;
        for (int probe = i - 1; probe >= 0; --probe) {
            const int verdict = compare_dims(perm_[probe], perm_[moving]);
            if (verdict < 0) {
                break;
            }
            if (verdict > 0) {
                std::swap(perm_[probe], perm_[moving]);
                moving = probe;
            }
        }
    }

    permute_dimensions();
}

Another reason for moving the fastest-moving dimension to the front is that it allows us to easily write unrolled loops during offset calculation:

// Converts a linear element index into one offset per operand.
class OffsetCalculator {
    // Decomposes `linear_idx` dimension by dimension, starting at dimension 0:
    // each step divides by that dimension's size, keeps the quotient for the
    // next dimension, and adds remainder * stride to every operand's offset.
    //
    // The dimension loop runs to the constant MAX_TENSOR_DIMS and exits early
    // once `dims` is reached; all loops carry `#pragma unroll`.
    HOST_DEVICE offset_type get(index_t linear_idx) const {
        offset_type result;
#pragma unroll
        for (int t = 0; t < NARGS; ++t) {
            result[t] = 0;
        }

#pragma unroll
        for (int d = 0; d < MAX_TENSOR_DIMS; ++d) {
            if (d == dims) {
                break;
            }
            auto qr = sizes_[d].divmod(linear_idx);
            linear_idx = qr.div;

#pragma unroll
            for (int t = 0; t < NARGS; ++t) {
                result[t] += qr.mod * strides_[d][t];
            }
        }
        return result;
    }
};

Coalesce Dimensions

To reduce the number of loop iterations needed for offset calculation, we can coalesce adjacent dimensions.

// Merges adjacent dimensions in place wherever that keeps every operand's
// addressing unchanged, then shrinks ndim_ to the number of dimensions left.
void TensorIterator::coalesce_dimensions() {
    if (ndim_ <= 1) return;

    // Two adjacent dimensions can be merged when either has extent 1, or when
    // for every operand: shape[inner] * stride[inner] == stride[outer].
    auto mergeable = [&](int inner, int outer) {
        auto inner_extent = shape_[inner];
        auto outer_extent = shape_[outer];
        if (inner_extent == 1 || outer_extent == 1) {
            return true;
        }
        for (int t = 0; t < num_tensors_; ++t) {
            if (inner_extent * stride_bytes_[t][inner] != stride_bytes_[t][outer]) {
                return false;
            }
        }
        return true;
    };

    // Overwrites every operand's stride at `dst` with its stride at `src`.
    auto copy_strides = [&](int dst, int src) {
        for (int t = 0; t < num_tensors_; ++t) {
            stride_bytes_[t][dst] = stride_bytes_[t][src];
        }
    };

    // `write` is the last dimension of the compacted result, `read` scans the
    // original dimensions.
    int write = 0;
    for (int read = 1; read < ndim_; ++read) {
        if (mergeable(write, read)) {
            // An extent-1 target has no meaningful stride of its own, so it
            // takes over the strides of the dimension being folded in.
            if (shape_[write] == 1) {
                copy_strides(write, read);
            }
            shape_[write] *= shape_[read];
            continue;
        }
        // Not mergeable: open a new output dimension and move `read` into it.
        ++write;
        if (write != read) {
            copy_strides(write, read);
            shape_[write] = shape_[read];
        }
    }

    ndim_ = write + 1;
}

Loops Kernel

Finally, here is a brief implementation of the loops kernel:

// Element-wise add over an already-built iterator.
//
// Dispatches on the iterator's common dtype and runs gpu_kernel() with a
// binary functor that adds its two arguments in the accumulate type of the
// dispatched scalar type.
void add_kernel(TensorIterator &iter) {
    DISPATCH_BASIC_TYPES(iter.common_dtype(), "add_kernel", [&]() {
        using acc_t = acc_type<scalar_t>; // with acc type
        const auto plus = [](acc_t lhs, acc_t rhs) { return lhs + rhs; };
        gpu_kernel(iter, plus);
    });
}

// Entry point for running functor `f` over `iter`.
//
// When the iterator can be addressed with 32-bit indices it goes straight to
// gpu_kernel_impl(). Otherwise it is split via with_32bit_indexing() and each
// piece is processed recursively.
template <typename func_t>
void gpu_kernel(TensorIterator &iter, const func_t &f) {
    if (iter.can_use_32bit_indexing()) {
        gpu_kernel_impl(iter, f);
        return;
    }
    for (auto &piece : iter.with_32bit_indexing()) {
        gpu_kernel(piece, f);
    }
}

// Chooses one of four launch strategies for functor `f` over `iter`, keyed on
// two conditions: whether dynamic casting is needed and whether the operands
// are contiguous.
//
// The elided section at the top is expected to define the locals used below
// (dynamic_casting, contiguous, data, numel, traits, arg0_t, ntensors).
template <typename func_t>
void gpu_kernel_impl(TensorIterator &iter, const func_t &f) {
    /* ... */
    if (!dynamic_casting) {
        if (contiguous) {
            // No cast, contiguous: vectorized launch with a trivial offset
            // calculator; the vector width is queried from the data pointers.
            int vec_size = memory_access::can_vectorize_up_to<func_t>(data);
            auto input_calc = TrivialOffsetCalculator<traits::arity>();
            launch_vectorized_kernel(numel, f, data, input_calc, vec_size);
        } else {
            // No cast, strided: legacy kernel with an offset calculator that
            // covers arity + 1 operands.
            auto offset_calc = make_offset_calculator<traits::arity + 1>(iter);
            // Unroll by 2 when arg0_t is at least 4 bytes wide, by 4 otherwise.
            constexpr int unroll_factor = sizeof(arg0_t) >= 4 ? 2 : 4;
            auto fn = LegacyKernelNoCastFunctor<decltype(offset_calc), arg0_t, decltype(data), func_t>(
                offset_calc, data, f);
            launch_legacy_kernel<unroll_factor>(numel, fn);
        }
    } else {
        if (contiguous) {
            // Cast, contiguous: unrolled kernel with casting loader/storer.
            // Operands 1..arity supply the load dtypes, operand 0 the store dtype.
            memory::array<ScalarType, traits::arity> dtypes;
            for (int i = 0; i < traits::arity; i++) {
                dtypes[i] = iter.dtype(i + 1);
            }
            auto loader = memory_access::LoadWithCast<traits::arity>(dtypes);
            auto storer = memory_access::StoreWithCast(iter.dtype(0));
            auto input_offset_calculator = TrivialOffsetCalculator<traits::arity>();
            auto output_offset_calculator = TrivialOffsetCalculator<1>();
            launch_unrolled_kernel<4>(numel, f, data, input_offset_calculator, output_offset_calculator, loader, storer);
        } else {
            // Cast, strided: legacy kernel whose functor receives the dtype of
            // every operand.
            memory::array<ScalarType, ntensors> dtypes;
            for (int i = 0; i < ntensors; i++) {
                dtypes[i] = iter.dtype(i);
            }
            auto offset_calc = make_offset_calculator<traits::arity + 1>(iter);
            auto fn = LegacyKernelCastFunctor<decltype(offset_calc), arg0_t, decltype(data), func_t, decltype(dtypes)>(
                offset_calc, data, f, dtypes);
            launch_legacy_kernel<4>(numel, fn);
        }
    }
}

Legacy-Loops-Kernel vs. Unrolled-Loops-Kernel

// Index-driven element-wise kernel.
//
// Each block covers a window of nt * vt linear indices. A thread starts at its
// own threadIdx.x inside that window and calls f(index) for up to vt indices
// spaced nt apart, skipping indices that are not below N.
// NOTE(review): this presumably expects a launch with nt threads per block — confirm.
template <int nt, int vt, typename func_t>
C10_LAUNCH_BOUNDS_2(nt, 4)
__global__ void legacy_elementwise_kernel(int N, func_t f) {
  int lane = threadIdx.x;
  int elems_per_block = nt * vt;
  int pos = elems_per_block * blockIdx.x + lane;
#pragma unroll
  for (int step = 0; step < vt; ++step) {
    if (pos < N) {
      f(pos);
      pos += nt;
    }
  }
}

// Load / compute / store helper for one thread of an unrolled element-wise kernel.
//
// The policy object does all memory access: it loads up to policy_t::tws
// argument tuples for the selected block, `f` is applied to every slot the
// policy reports as in bounds, and the policy stores the results.
// With reverted_idx == true, blocks are visited in reverse order
// (gridDim.x - blockIdx.x - 1 instead of blockIdx.x).
template <bool reverted_idx = false, typename func_t, typename policy_t>
__device__ inline void unrolled_elementwise_kernel_helper(func_t f, policy_t policy) {
  using fn_traits = function_traits<func_t>;
  using result_t = typename fn_traits::result_type;
  using packed_args_t = typename fn_traits::ArgsTuple;
  constexpr int per_thread = policy_t::tws;

  // Block index handed to the policy, optionally mirrored.
  int block = reverted_idx ? gridDim.x - blockIdx.x - 1 : blockIdx.x;

  result_t outputs[per_thread];
  packed_args_t inputs[per_thread];

  // load
  policy.load(inputs, block);

  // compute
  #pragma unroll
  for (int k = 0; k < per_thread; ++k) {
    if (policy.check_inbounds(k)) {
      outputs[k] = c10::guts::apply(f, inputs[k]);
    }
  }

  // store
  policy.store(outputs, block);
}