Feat: Add Model RF-DETR#333
Conversation
|
Can we implement a vignette for this perhaps? @cregouby |
cregouby
left a comment
There was a problem hiding this comment.
praise This is massive, thanks for it
todo see inline.
There was a problem hiding this comment.
improvement Could you move the added line into the ## New models section ?
There was a problem hiding this comment.
todo missing Please add a representative example to the model documentation.
todo missing Please mention the attribution in a code comment # This code is modified from ...
todo Please fix merge conflicts
There was a problem hiding this comment.
improvement Could we be more specific on each tests through getting a little deeper than the output shape ?
suggestion You should use the expect_coco_model_detects_cat(model) for each and every pretrained model. See code in
torchvision/tests/testthat/helper-torchvision.R
Lines 69 to 88 in 536a80a
question Could we also test for
expect_bbox_to_be_xyxy() from torchvision/tests/testthat/helper-torchvision.R
Lines 36 to 65 in 536a80a
| batch_size <- value$size(1) | ||
| n_heads <- value$size(2) | ||
| head_dim <- value$size(3) | ||
| len_query <- sampling_locations$size(2) | ||
| n_levels <- sampling_locations$size(4) | ||
| n_points <- sampling_locations$size(5) |
There was a problem hiding this comment.
suggestion you could make use of zeallot %<-% for readability
| } | ||
| ) | ||
|
|
||
| ms_deform_attn_core_pytorch <- function(value, spatial_shapes, sampling_locations, attention_weights, |
There was a problem hiding this comment.
thought very strange name knowing that there is zero pytorch in the game... OK, this comes forme the code in roboflow repo. Then please mention the attribution as code coment.
| ms_deform_attn <- nn_module( | ||
| "ms_deform_attn", | ||
| initialize = function(d_model = 256, n_levels = 4, n_heads = 8, n_points = 4) { | ||
| self$d_model <- d_model | ||
| self$n_levels <- n_levels | ||
| self$n_heads <- n_heads | ||
| self$n_points <- n_points | ||
| self$sampling_offsets <- nn_linear(d_model, n_heads * n_levels * n_points * 2) | ||
| self$attention_weights <- nn_linear(d_model, n_heads * n_levels * n_points) | ||
| self$value_proj <- nn_linear(d_model, d_model) | ||
| self$output_proj <- nn_linear(d_model, d_model) | ||
| self$reset_parameters() | ||
| }, | ||
| reset_parameters = function() { | ||
| nn_init_constant_(self$sampling_offsets$weight, 0) | ||
| thetas <- torch_arange(0, self$n_heads - 1, dtype = torch_float32()) * (2 * pi / self$n_heads) | ||
| grid_init <- torch_stack(list(thetas$cos(), thetas$sin()), dim = -1) | ||
| grid_init <- grid_init / grid_init$abs()$max(dim = -1, keepdim = TRUE)[[1]] | ||
| grid_init <- grid_init$view(c(self$n_heads, 1, 1, 2))$'repeat'(c(1, self$n_levels, self$n_points, 1)) | ||
| for (i in seq_len(self$n_points)) { | ||
| grid_init[, , i, ] <- grid_init[, , i, ] * i | ||
| } | ||
| self$sampling_offsets$bias <- nn_parameter(grid_init$view(-1)) | ||
| nn_init_constant_(self$attention_weights$weight, 0) | ||
| nn_init_constant_(self$attention_weights$bias, 0) | ||
| nn_init_xavier_uniform_(self$value_proj$weight) | ||
| nn_init_constant_(self$value_proj$bias, 0) | ||
| nn_init_xavier_uniform_(self$output_proj$weight) | ||
| nn_init_constant_(self$output_proj$bias, 0) | ||
| }, | ||
| forward = function(query, reference_points, input_flatten, input_spatial_shapes, | ||
| input_level_start_index, input_padding_mask = NULL, | ||
| input_spatial_shapes_hw = NULL) { | ||
| batch_size <- query$size(1) | ||
| len_query <- query$size(2) | ||
| value <- self$value_proj(input_flatten) | ||
| if (!is.null(input_padding_mask)) { | ||
| value <- value$masked_fill(input_padding_mask$unsqueeze(3), 0) | ||
| } | ||
| sampling_offsets <- self$sampling_offsets(query)$view(c( | ||
| batch_size, len_query, self$n_heads, self$n_levels, self$n_points, 2 | ||
| )) | ||
| attention_weights <- self$attention_weights(query)$view(c( | ||
| batch_size, len_query, self$n_heads, self$n_levels * self$n_points | ||
| )) | ||
| if (reference_points$size(-1) == 2) { | ||
| offset_normalizer <- torch_stack(list( | ||
| input_spatial_shapes[, 2], input_spatial_shapes[, 1] | ||
| ), dim = -1) | ||
| sampling_locations <- reference_points$unsqueeze(3)$unsqueeze(5) + | ||
| sampling_offsets / offset_normalizer$unsqueeze(1)$unsqueeze(1)$unsqueeze(4) | ||
| } else { | ||
| sampling_locations <- reference_points[, , NULL, , NULL, 1:2] + | ||
| sampling_offsets / self$n_points * reference_points[, , NULL, , NULL, 3:4] * 0.5 | ||
| } | ||
| attention_weights <- nnf_softmax(attention_weights, dim = -1) | ||
| value <- value$transpose(2, 3)$contiguous()$view(c( | ||
| batch_size, self$n_heads, self$d_model %/% self$n_heads, -1 | ||
| )) | ||
| output <- ms_deform_attn_core_pytorch( | ||
| value, input_spatial_shapes, sampling_locations, attention_weights, | ||
| input_spatial_shapes_hw | ||
| ) | ||
| self$output_proj(output) | ||
| } | ||
| ) |
There was a problem hiding this comment.
todo please factorize this code with
torchvision/R/models-lw_detr.R
Lines 315 to 404 in 8d19bac
as the initialization is the same and forward here only additionnaly manages 2D reference_points.
suggestion rename it
detr_ms_deform_attn
| out$pred_boxes <- ref_enc | ||
| } | ||
| } | ||
| out |
There was a problem hiding this comment.
todo Please make the output data model identical to all other objects detection models so that it helps visualization: we expect out to have a $detections with each detection item having names c("boxes", "labels", "scores")
| mlp_module <- nn_module( | ||
| "mlp_module", | ||
| initialize = function(input_dim, hidden_dim, output_dim, num_layers) { | ||
| self$num_layers <- num_layers | ||
| h <- rep(hidden_dim, num_layers - 1) | ||
| dims <- c(input_dim, h, output_dim) | ||
| self$layers <- nn_module_list(lapply(seq_len(num_layers), function(i) { | ||
| nn_linear(dims[i], dims[i + 1]) | ||
| })) | ||
| }, | ||
| forward = function(x) { | ||
| for (i in seq_len(self$num_layers)) { | ||
| x <- self$layers[[i]](x) | ||
| if (i < self$num_layers) x <- nnf_relu(x) | ||
| } | ||
| x | ||
| } | ||
| ) | ||
|
|
There was a problem hiding this comment.
suggestion looks like this should be factorized with / reused from
torchvision/R/models-lw_detr.R
Lines 409 to 429 in 8d19bac

This PR adds :
Closes #327