-
Notifications
You must be signed in to change notification settings - Fork 2
TensorIterator Guide
TensorIterator is a helper abstraction that lets you iterate over the elements of one or more tensors while it handles the complicated parts for you, such as broadcasting, type promotion, and memory strides.
For a binary add operator, we only need to construct a TensorIterator and pass it to the kernel.
// Out-variant of binary add: registers `out` as the iterator's output and
// `left`/`right` as its inputs, hands the built iterator to add_kernel, and
// returns `out`.
Tensor &add_out(Tensor &out, const Tensor &left, const Tensor &right) {
// Operands are registered output first, then the two inputs.
auto iter = TensorIterator().add_output(out).add_input(left).add_input(right).build_for_loops();
// NOTE(review): the kernel presumably writes the result into `out` through the iterator - confirm.
add_kernel(iter);
return out;
}The member variables of TensorIterator can be simply accessed using the following:
// State carried by a TensorIterator. Only the data members are listed here.
class TensorIterator {
// Operand tensors; an entry may be undefined (the build steps test defined()).
Tensor *tensors_[MAX_TENSORS];
// Per-operand properties, e.g. the will_resize flag read by reorder_dimensions().
TensorProp tensor_props_[MAX_TENSORS];
// Per-operand data pointers; refreshed by update_data_pointers().
void *data_ptr_[MAX_TENSORS];
// Iteration shape; filled in by compute_broadcasted_shape().
int64_t shape_[MAX_TENSOR_DIMS];
// Per-operand, per-dimension strides in bytes; 0 on broadcast dimensions.
int64_t stride_bytes_[MAX_TENSORS][MAX_TENSOR_DIMS];
// Dimension permutation chosen by reorder_dimensions().
int64_t perm_[MAX_TENSOR_DIMS];
// NOTE(review): presumably per-dimension offsets of a sub-iterator view - not used in the code shown; confirm.
int64_t view_offsets_[MAX_TENSOR_DIMS];
// Device shared by the operands; stays -1 until check_and_compute_common_device() updates it.
int common_device_ = -1;
// NOTE(review): presumably the number of registered outputs / inputs - confirm against add_output()/add_input().
int num_outputs_ = 0;
int num_inputs_ = 0;
// Number of operand slots; the build steps loop over tensors_[0 .. num_tensors_).
int num_tensors_ = 0;
// Number of iteration dimensions; coalesce_dimensions() may shrink it.
int ndim_ = 0;
// NOTE(review): the flags below are not read in the code shown; meanings are inferred from their names - confirm.
bool resize_outputs_ = true;
bool accumulate_ = false;
bool final_output_ = true;
bool is_reduction_ = false;
bool check_mem_overlap_ = true;
// NOTE(review): presumably the dimension being reduced when is_reduction_ is set - confirm.
int64_t reduce_dim_ = 0;
// Common dtype; stays Undefined until compute_common_dtype() infers it from the inputs.
ScalarType common_dtype_ = ScalarType::Undefined;
};When using build to construct a TensorIterator, the process will roughly go through the following steps:
// Runs the construction pipeline for this iterator and returns *this so the
// call can be chained. The steps run in the fixed order below; the comment
// above each call states what that step does.
TensorIterator &TensorIterator::build() {
// check whether all defined tensors are in the same device, then update common_device_
check_and_compute_common_device();
// check whether all defined tensors have the same dim, then update ndim_
check_and_compute_dim();
// infer common_dtype from input tensors
compute_common_dtype();
// automatic output allocation for reduction
allocate_reduction_output_if_need();
// set is_output and is_read_write flags on appropriate tensors
mark_outputs();
// check that the defined outputs have no internal overlap
// and do not share memory with inputs
check_mem_overlaps();
// compute and check the broadcasted shape through input tensors
compute_broadcasted_shape();
// mark outputs for resizing if necessary
mark_resize_outputs();
// compute each defined tensor's stride after broadcasting
compute_broadcasted_strides();
// re-order dimensions to improve coalescing
reorder_dimensions();
// allocate the output tensor if it's not provided
allocate_outputs();
// coalesce adjacent dimensions when possible
coalesce_dimensions();
// update data_ptr_
update_data_pointers();
return *this;
}TensorIterator::shape_ is inferred based on all the input tensors, which means that all inputs must be broadcastable to a single shape.
// Computes the broadcast iteration shape and stores it in shape_.
//
// For every dimension, the extents of all defined tensors are folded together:
// the first defined tensor seeds the extent, and each later tensor must either
// match it or one of the two extents must be 1, in which case the other extent
// is kept. The compatibility condition is handed to CHECK_FAIL.
// NOTE(review): CHECK_FAIL presumably aborts when the condition is false -
// confirm the macro's polarity.
// NOTE(review): assumes every defined tensor has ndim_ dimensions - confirm.
void TensorIterator::compute_broadcasted_shape() {
  for (int i = ndim_ - 1; i >= 0; --i) {
    bool is_first = true;
    // Start from 1 (the broadcast identity) so shape_[i] is well defined even
    // when no tensor is defined; an uninitialized value would otherwise be
    // read and stored below.
    int64_t sz = 1;
    for (int j = 0; j < num_tensors_; ++j) {
      if (!tensors_[j]->defined()) continue;
      if (is_first) {
        sz = tensors_[j]->shape(i);
        is_first = false;
      } else {
        auto sz_ = tensors_[j]->shape(i);
        CHECK_FAIL(sz == sz_ || sz == 1 || sz_ == 1);
        // An extent of 1 broadcasts to the other tensor's extent.
        sz = sz == 1 ? sz_ : sz;
      }
    }
    shape_[i] = sz;
  }
}
If a dimension is broadcastable, we only need to set its stride to 0.
void TensorIterator::compute_broadcasted_strides() {
for (int id = 0; id < num_tensors_; ++id) {
auto t = tensors_[id];
if (!t->defined()) continue;
auto element_size_in_bytes = t->element_size_in_bytes();
for (int i = ndim_ - 1; i >= 0; --i) {
if (t->shape(i) == 1 && shape_[i] != 1) {
stride_bytes_[id][i] = 0;
} else {
stride_bytes_[id][i] = t->stride(i) * element_size_in_bytes;
}
}
}
}We move the fast-moving dimension to the front to facilitate subsequent coalescing operations.
// Chooses a dimension order in which dimensions with smaller byte strides come
// first, records it in perm_, and then applies it via permute_dimensions().
void TensorIterator::reorder_dimensions() {
// initialize perm with n-1, n-2, ..., 1, 0
for (int i = ndim_ - 1, ct = 0; i >= 0; --i) {
perm_[ct++] = i;
}
// returns 1 if the dim0 should come after dim1, -1 if dim0 should come
// before dim1, and 0 if the comparison is ambiguous.
// Tensors are consulted in index order; the first one that gives a decisive
// answer wins.
auto should_swap = [&](size_t dim0, size_t dim1) {
for (int arg = 0; arg < num_tensors_; ++arg) {
// ignore undefined or incorrectly sized tensors
if (!tensors_[arg]->defined() || tensor_props_[arg].will_resize) {
continue;
}
int64_t stride0 = stride_bytes_[arg][dim0];
int64_t stride1 = stride_bytes_[arg][dim1];
if (stride0 == 0 || stride1 == 0) {
// move on to the next input if one of the dimensions is broadcasted
continue;
} else if (stride0 < stride1) {
return -1;
} else if (stride0 > stride1) {
return 1;
} else {
// for equal strides, the dimension with smaller size goes front
auto t_dim0 = shape_[dim0];
auto t_dim1 = shape_[dim1];
// return only if dimensions should be swapped, otherwise move on to the next tensor
if (t_dim0 > t_dim1) {
return 1;
}
}
}
return 0;
};
// insertion sort with support for ambiguous comparisons
for (int i = 1; i < ndim_; ++i) {
// dim1 tracks the current position of the element being inserted
int dim1 = i;
for (int dim0 = i - 1; dim0 >= 0; dim0--) {
int comparison = should_swap(perm_[dim0], perm_[dim1]);
if (comparison > 0) {
std::swap(perm_[dim0], perm_[dim1]);
dim1 = dim0;
} else if (comparison < 0) {
// already in order relative to dim0: stop moving this element
break;
}
// comparison == 0: ambiguous, keep scanning towards the front without swapping
}
}
permute_dimensions();
}Besides, another reason for moving the fast-moving dimension to the front is that it allows us to easily write unrolled loops during offset calculation:
// Converts a linear element index into one offset per operand.
// NOTE(review): NARGS, dims, sizes_, strides_, offset_type and index_t are
// declared outside this excerpt.
class OffsetCalculator {
// Returns the offsets for linear_idx: the index is decomposed dimension by
// dimension, starting at dimension 0, and each coordinate is multiplied by
// that dimension's stride for every operand.
HOST_DEVICE offset_type get(index_t linear_idx) const {
offset_type offsets;
#pragma unroll
for (int arg = 0; arg < NARGS; arg++) {
offsets[arg] = 0;
}
// The loop bound is the compile-time MAX_TENSOR_DIMS; the runtime dimension
// count `dims` ends the loop early.
#pragma unroll
for (int dim = 0; dim < MAX_TENSOR_DIMS; ++dim) {
if (dim == dims) {
break;
}
// mod is the coordinate along `dim`; div is the index left for the remaining dimensions
auto divmod = sizes_[dim].divmod(linear_idx);
linear_idx = divmod.div;
#pragma unroll
for (int arg = 0; arg < NARGS; arg++) {
offsets[arg] += divmod.mod * strides_[dim][arg];
}
}
return offsets;
}
};To reduce the number of loops for offset calculation, we can coalesce dimensions.
// Merges adjacent dimensions in place wherever that keeps every operand's
// addressing unchanged, then shrinks ndim_ to the number of dimensions left.
void TensorIterator::coalesce_dimensions() {
  if (ndim_ <= 1) {
    return;
  }

  // Two neighbouring dimensions can be merged when either has extent 1, or
  // when for every operand: shape[inner] * stride[inner] == stride[outer].
  const auto mergeable = [&](int inner, int outer) {
    const auto inner_extent = shape_[inner];
    if (inner_extent == 1 || shape_[outer] == 1) {
      return true;
    }
    for (int t = 0; t < num_tensors_; ++t) {
      if (inner_extent * stride_bytes_[t][inner] != stride_bytes_[t][outer]) {
        return false;
      }
    }
    return true;
  };

  // Overwrites every operand's stride at `dst` with its stride at `src`.
  const auto copy_strides = [&](int dst, int src) {
    for (int t = 0; t < num_tensors_; ++t) {
      stride_bytes_[t][dst] = stride_bytes_[t][src];
    }
  };

  // `last` is the index of the most recent output dimension; `dim` scans the
  // remaining input dimensions.
  int last = 0;
  for (int dim = 1; dim < ndim_; ++dim) {
    if (!mergeable(last, dim)) {
      // Start a new output dimension and move `dim` into that slot if needed.
      ++last;
      if (last != dim) {
        copy_strides(last, dim);
        shape_[last] = shape_[dim];
      }
      continue;
    }
    // An extent-1 dimension carries no meaningful stride, so adopt the
    // stride of the dimension being folded into it.
    if (shape_[last] == 1) {
      copy_strides(last, dim);
    }
    shape_[last] *= shape_[dim];
  }
  ndim_ = last + 1;
}
Finally is a brief implementation of the loops kernel:
// Elementwise add over the operands described by `iter`, dispatched on the
// iterator's common dtype.
void add_kernel(TensorIterator &iter) {
  DISPATCH_BASIC_TYPES(iter.common_dtype(), "add_kernel", [&]() {
    // The functor takes and returns the accumulation type for scalar_t.
    using acc_t = acc_type<scalar_t>;
    auto sum = [](acc_t lhs, acc_t rhs) { return lhs + rhs; };
    gpu_kernel(iter, sum);
  });
}
// Runs `f` over `iter`. An iterator that cannot use 32-bit indexing is split
// into sub-iterators via with_32bit_indexing(), and each one is processed
// recursively; otherwise the work goes straight to gpu_kernel_impl.
template <typename func_t>
void gpu_kernel(TensorIterator &iter, const func_t &f) {
  if (iter.can_use_32bit_indexing()) {
    gpu_kernel_impl(iter, f);
    return;
  }
  for (auto &sub_iter : iter.with_32bit_indexing()) {
    gpu_kernel(sub_iter, f);
  }
}
// Picks one of four launch paths for `f`, keyed on whether dynamic casting is
// needed and whether the operands are contiguous.
// NOTE(review): dynamic_casting, contiguous, data, numel, traits, arg0_t and
// ntensors come from the setup elided at the "/* ... */" placeholder.
template <typename func_t>
void gpu_kernel_impl(TensorIterator &iter, const func_t &f) {
/* ... */
if (!dynamic_casting) {
if (contiguous) {
// no casting, contiguous: vectorized kernel with a trivial offset calculator
int vec_size = memory_access::can_vectorize_up_to<func_t>(data);
auto input_calc = TrivialOffsetCalculator<traits::arity>();
launch_vectorized_kernel(numel, f, data, input_calc, vec_size);
} else {
// no casting, strided: legacy kernel with a real offset calculator
// NOTE(review): arity + 1 presumably counts the inputs plus one output - confirm.
auto offset_calc = make_offset_calculator<traits::arity + 1>(iter);
// unroll by 2 when arg0_t is at least 4 bytes wide, otherwise by 4
constexpr int unroll_factor = sizeof(arg0_t) >= 4 ? 2 : 4;
auto fn = LegacyKernelNoCastFunctor<decltype(offset_calc), arg0_t, decltype(data), func_t>(
offset_calc, data, f);
launch_legacy_kernel<unroll_factor>(numel, fn);
}
} else {
if (contiguous) {
// casting, contiguous: the loader gets the dtypes of operands 1..arity,
// the storer gets the dtype of operand 0
memory::array<ScalarType, traits::arity> dtypes;
for (int i = 0; i < traits::arity; i++) {
dtypes[i] = iter.dtype(i + 1);
}
auto loader = memory_access::LoadWithCast<traits::arity>(dtypes);
auto storer = memory_access::StoreWithCast(iter.dtype(0));
auto input_offset_calculator = TrivialOffsetCalculator<traits::arity>();
auto output_offset_calculator = TrivialOffsetCalculator<1>();
launch_unrolled_kernel<4>(numel, f, data, input_offset_calculator, output_offset_calculator, loader, storer);
} else {
// casting, strided: legacy kernel that receives the dtypes of all operands
memory::array<ScalarType, ntensors> dtypes;
for (int i = 0; i < ntensors; i++) {
dtypes[i] = iter.dtype(i);
}
auto offset_calc = make_offset_calculator<traits::arity + 1>(iter);
auto fn = LegacyKernelCastFunctor<decltype(offset_calc), arg0_t, decltype(data), func_t, decltype(dtypes)>(
offset_calc, data, f, dtypes);
launch_legacy_kernel<4>(numel, fn);
}
}
}template <int nt, int vt, typename func_t>
// legacy_elementwise_kernel: calls f(idx) for up to vt indices per thread.
// A thread starts at nt * vt * blockIdx.x + threadIdx.x and advances by nt
// after each call; indices at or beyond N are skipped.
// NOTE(review): nt is presumably the block size used at launch - confirm in launch_legacy_kernel.
C10_LAUNCH_BOUNDS_2(nt, 4)
__global__ void legacy_elementwise_kernel(int N, func_t f) {
int tid = threadIdx.x;
// number of indices covered by one block
int nv = nt * vt;
int idx = nv * blockIdx.x + tid;
#pragma unroll
for (int i = 0; i < vt; i++) {
if (idx < N) {
f(idx);
idx += nt;
}
}
}
// Per-thread body of an unrolled elementwise kernel: loads the argument
// tuples through `policy`, applies `f` to each in-bounds one, and stores the
// results through `policy`.
// When reverted_idx is true the block index is mirrored
// (gridDim.x - blockIdx.x - 1) before being passed to load/store.
template <bool reverted_idx = false, typename func_t, typename policy_t>
__device__ inline void unrolled_elementwise_kernel_helper(func_t f, policy_t policy) {
using traits = function_traits<func_t>;
using return_t = typename traits::result_type;
using args_t = typename traits::ArgsTuple;
// number of elements each thread handles, taken from the policy
constexpr int elems_per_thread = policy_t::tws;
int idx = blockIdx.x;
if constexpr (reverted_idx)
idx = gridDim.x - blockIdx.x - 1;
return_t results[elems_per_thread];
args_t args[elems_per_thread];
// load
policy.load(args, idx);
// compute
#pragma unroll
for (int i = 0; i < elems_per_thread; i++) {
// slots the policy reports as out of bounds are left uncomputed
if (policy.check_inbounds(i)) {
results[i] = c10::guts::apply(f, args[i]);
}
}
// store
policy.store(results, idx);
}