From aacfe91a309ff35302c11f8255dcf31c14f7d058 Mon Sep 17 00:00:00 2001 From: Nikhil Sinha Date: Wed, 24 Jun 2026 15:35:41 +0700 Subject: [PATCH] feat: topk in group by in counts api current: the counts api returns the date-binned result for each distinct value of the field(s) present in group by; if a field used has high cardinality, the response will be bloated. Prism uses this response to build a chart, which might crash during rendering. change: add an optional topk param to group by, defaulting to 10; the api returns only the top-k results for the group by field(s) --- src/query/mod.rs | 168 +++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 157 insertions(+), 11 deletions(-) diff --git a/src/query/mod.rs b/src/query/mod.rs index 253b3f04f..5e154e5ad 100644 --- a/src/query/mod.rs +++ b/src/query/mod.rs @@ -80,6 +80,7 @@ type BoxedBatchStream = SendableRecordBatchStream; /// Result type returned by query execution: either collected batches or a streaming adapter, plus field names. type QueryResult = Result<(Either, BoxedBatchStream>, Vec), ExecuteError>; +const DEFAULT_COUNTS_TOP_K: usize = 10; // pub static QUERY_SESSION: Lazy = // Lazy::new(|| Query::create_session_context(PARSEABLE.storage())); @@ -495,6 +496,9 @@ pub struct CountConditions { pub conditions: Option, /// GroupBy columns pub group_by: Option>, + /// Optional number of top group values to return. + #[serde(alias = "topk", alias = "top_k")] + pub top_k: Option, } /// Request for counts, received from API/SQL query. 
@@ -513,6 +517,10 @@ pub struct CountsRequest { pub conditions: Option, } +fn quote_identifier(value: &str) -> String { + format!("\"{}\"", value.replace('"', "\"\"")) +} + impl CountsRequest { /// This function is supposed to read maninfest files for the given stream, /// get the sum of `num_rows` between the `startTime` and `endTime`, @@ -652,17 +660,19 @@ impl CountsRequest { let time_range = TimeRange::parse_human_time(&self.start_time, &self.end_time)?; let table_name = &self.stream; + let table_ref = quote_identifier(table_name); + let time_column_ref = format!("{}.{}", table_ref, quote_identifier(&time_column)); let start_time_col_name = "_bin_start_time_"; let end_time_col_name = "_bin_end_time_"; let bin_interval = count_api_bin_interval(&time_range.start, &time_range.end); let date_bin = format!( - "CAST(DATE_BIN('{bin_interval}', \"{table_name}\".\"{time_column}\", TIMESTAMP '{DATE_BIN_EPOCH_ANCHOR}') AS TEXT) as {start_time_col_name}, DATE_BIN('{bin_interval}', \"{table_name}\".\"{time_column}\", TIMESTAMP '{DATE_BIN_EPOCH_ANCHOR}') + INTERVAL '{bin_interval}' as {end_time_col_name}" + "CAST(DATE_BIN('{bin_interval}', {time_column_ref}, TIMESTAMP '{DATE_BIN_EPOCH_ANCHOR}') AS TEXT) as {start_time_col_name}, DATE_BIN('{bin_interval}', {time_column_ref}, TIMESTAMP '{DATE_BIN_EPOCH_ANCHOR}') + INTERVAL '{bin_interval}' as {end_time_col_name}" ); let group_by_cols = count_conditions .group_by .as_ref() - .map(|cols| cols.iter().map(|c| format!("\"{c}\"")).collect::>()) + .map(|cols| cols.iter().map(|c| quote_identifier(c)).collect::>()) .unwrap_or_default(); let group_clause = if group_by_cols.is_empty() { @@ -671,17 +681,47 @@ impl CountsRequest { format!(", {}", group_by_cols.join(", ")) }; - let query = if let Some(conditions) = &count_conditions.conditions { + let where_clause = if let Some(conditions) = &count_conditions.conditions { let f = get_filter_string(conditions).map_err(QueryError::CustomError)?; - format!( - "SELECT {date_bin}{group_clause}, 
COUNT(*) as count FROM \"{table_name}\" WHERE {} GROUP BY {end_time_col_name},{start_time_col_name}{group_clause} ORDER BY {end_time_col_name}{group_clause}", - f - ) + format!(" WHERE {f}") } else { - format!( - "SELECT {date_bin}{group_clause}, COUNT(*) as count FROM \"{table_name}\" GROUP BY {end_time_col_name},{start_time_col_name}{group_clause} ORDER BY {end_time_col_name}{group_clause}", - ) + String::default() }; + + let query = format!( + "SELECT {date_bin}{group_clause}, COUNT(*) as count FROM {table_ref}{where_clause} GROUP BY {end_time_col_name},{start_time_col_name}{group_clause} ORDER BY {end_time_col_name}{group_clause}", + ); + + if group_by_cols.is_empty() { + return Ok(query); + } + + let top_k = count_conditions.top_k.unwrap_or(DEFAULT_COUNTS_TOP_K); + + if top_k == 0 { + return Err(QueryError::CustomError( + "topK must be greater than 0".to_string(), + )); + } + + let top_group_cols = group_by_cols.join(", "); + let top_group_join = group_by_cols + .iter() + .map(|col| format!("(gc.{col} = tg.{col} OR (gc.{col} IS NULL AND tg.{col} IS NULL))")) + .join(" AND "); + let top_group_select = group_by_cols + .iter() + .map(|col| format!(", gc.{col} AS {col}")) + .join(""); + let top_group_order = group_by_cols + .iter() + .map(|col| format!(", gc.{col}")) + .join(""); + + let query = format!( + "WITH grouped_counts AS (SELECT {date_bin}{group_clause}, COUNT(*) as count FROM {table_ref}{where_clause} GROUP BY {end_time_col_name},{start_time_col_name}{group_clause}), top_groups AS (SELECT {top_group_cols} FROM grouped_counts GROUP BY {top_group_cols} ORDER BY SUM(\"count\") DESC LIMIT {top_k}) SELECT gc.{start_time_col_name} AS {start_time_col_name}, gc.{end_time_col_name} AS {end_time_col_name}{top_group_select}, gc.\"count\" AS count FROM grouped_counts gc INNER JOIN top_groups tg ON {top_group_join} ORDER BY gc.{end_time_col_name}{top_group_order}", + ); + Ok(query) } } @@ -1016,7 +1056,113 @@ impl PartitionedMetricMonitor { mod tests { use 
serde_json::json; - use crate::query::flatten_objects_for_count; + use crate::query::{ + CountConditions, CountsRequest, flatten_objects_for_count, resolve_stream_names, + }; + + #[test] + fn test_count_conditions_accepts_top_k() { + let conditions: CountConditions = serde_json::from_value(json!({ + "groupBy": ["host"], + "topK": 5 + })) + .unwrap(); + assert_eq!(conditions.top_k, Some(5)); + + let conditions: CountConditions = serde_json::from_value(json!({ + "groupBy": ["host"], + "topk": 3 + })) + .unwrap(); + assert_eq!(conditions.top_k, Some(3)); + } + + #[tokio::test] + async fn test_counts_sql_applies_top_k_to_group_by() { + let request = CountsRequest { + stream: "logs".to_string(), + start_time: "2024-01-01T00:00:00Z".to_string(), + end_time: "2024-01-01T01:00:00Z".to_string(), + num_bins: None, + conditions: Some(CountConditions { + conditions: None, + group_by: Some(vec!["host".to_string(), "service.name".to_string()]), + top_k: Some(5), + }), + }; + + let sql = request.get_df_sql("p_timestamp".to_string()).await.unwrap(); + + assert!(sql.starts_with("WITH grouped_counts AS")); + assert!(sql.contains("SELECT \"host\", \"service.name\" FROM grouped_counts")); + assert!(sql.contains("GROUP BY \"host\", \"service.name\"")); + assert!(sql.contains("ORDER BY SUM(\"count\") DESC LIMIT 5")); + assert!(sql.contains( + "(gc.\"host\" = tg.\"host\" OR (gc.\"host\" IS NULL AND tg.\"host\" IS NULL))" + )); + assert!(sql.contains("ORDER BY gc._bin_end_time_, gc.\"host\", gc.\"service.name\"")); + } + + #[tokio::test] + async fn test_counts_sql_defaults_top_k_for_group_by() { + let request = CountsRequest { + stream: "logs".to_string(), + start_time: "2024-01-01T00:00:00Z".to_string(), + end_time: "2024-01-01T01:00:00Z".to_string(), + num_bins: None, + conditions: Some(CountConditions { + conditions: None, + group_by: Some(vec!["host".to_string()]), + top_k: None, + }), + }; + + let sql = request.get_df_sql("p_timestamp".to_string()).await.unwrap(); + + 
assert!(sql.starts_with("WITH grouped_counts AS")); + assert!(sql.contains("ORDER BY SUM(\"count\") DESC LIMIT 10")); + } + + #[tokio::test] + async fn test_counts_top_k_sql_resolves_source_stream() { + let request = CountsRequest { + stream: "logs".to_string(), + start_time: "2024-01-01T00:00:00Z".to_string(), + end_time: "2024-01-01T01:00:00Z".to_string(), + num_bins: None, + conditions: Some(CountConditions { + conditions: None, + group_by: Some(vec!["host".to_string()]), + top_k: Some(5), + }), + }; + + let sql = request.get_df_sql("p_timestamp".to_string()).await.unwrap(); + let streams = resolve_stream_names(&sql).unwrap(); + + assert_eq!(streams, vec!["logs"]); + } + + #[tokio::test] + async fn test_counts_sql_rejects_zero_top_k_for_group_by() { + let request = CountsRequest { + stream: "logs".to_string(), + start_time: "2024-01-01T00:00:00Z".to_string(), + end_time: "2024-01-01T01:00:00Z".to_string(), + num_bins: None, + conditions: Some(CountConditions { + conditions: None, + group_by: Some(vec!["host".to_string()]), + top_k: Some(0), + }), + }; + + let err = request + .get_df_sql("p_timestamp".to_string()) + .await + .unwrap_err(); + assert!(err.to_string().contains("topK must be greater than 0")); + } #[test] fn test_flat_simple() {