diff --git a/src/query/mod.rs b/src/query/mod.rs index 253b3f04f..5e154e5ad 100644 --- a/src/query/mod.rs +++ b/src/query/mod.rs @@ -80,6 +80,7 @@ type BoxedBatchStream = SendableRecordBatchStream; /// Result type returned by query execution: either collected batches or a streaming adapter, plus field names. type QueryResult = Result<(Either, BoxedBatchStream>, Vec), ExecuteError>; +const DEFAULT_COUNTS_TOP_K: usize = 10; // pub static QUERY_SESSION: Lazy = // Lazy::new(|| Query::create_session_context(PARSEABLE.storage())); @@ -495,6 +496,9 @@ pub struct CountConditions { pub conditions: Option, /// GroupBy columns pub group_by: Option>, + /// Optional number of top group values to return. + #[serde(alias = "topk", alias = "top_k")] + pub top_k: Option, } /// Request for counts, received from API/SQL query. @@ -513,6 +517,10 @@ pub struct CountsRequest { pub conditions: Option, } +fn quote_identifier(value: &str) -> String { + format!("\"{}\"", value.replace('"', "\"\"")) +} + impl CountsRequest { /// This function is supposed to read maninfest files for the given stream, /// get the sum of `num_rows` between the `startTime` and `endTime`, @@ -652,17 +660,19 @@ impl CountsRequest { let time_range = TimeRange::parse_human_time(&self.start_time, &self.end_time)?; let table_name = &self.stream; + let table_ref = quote_identifier(table_name); + let time_column_ref = format!("{}.{}", table_ref, quote_identifier(&time_column)); let start_time_col_name = "_bin_start_time_"; let end_time_col_name = "_bin_end_time_"; let bin_interval = count_api_bin_interval(&time_range.start, &time_range.end); let date_bin = format!( - "CAST(DATE_BIN('{bin_interval}', \"{table_name}\".\"{time_column}\", TIMESTAMP '{DATE_BIN_EPOCH_ANCHOR}') AS TEXT) as {start_time_col_name}, DATE_BIN('{bin_interval}', \"{table_name}\".\"{time_column}\", TIMESTAMP '{DATE_BIN_EPOCH_ANCHOR}') + INTERVAL '{bin_interval}' as {end_time_col_name}" + "CAST(DATE_BIN('{bin_interval}', {time_column_ref}, TIMESTAMP '{DATE_BIN_EPOCH_ANCHOR}') AS TEXT) as {start_time_col_name}, DATE_BIN('{bin_interval}', {time_column_ref}, TIMESTAMP '{DATE_BIN_EPOCH_ANCHOR}') + INTERVAL '{bin_interval}' as {end_time_col_name}" ); let group_by_cols = count_conditions .group_by .as_ref() - .map(|cols| cols.iter().map(|c| format!("\"{c}\"")).collect::>()) + .map(|cols| cols.iter().map(|c| quote_identifier(c)).collect::>()) .unwrap_or_default(); let group_clause = if group_by_cols.is_empty() { @@ -671,17 +681,47 @@ impl CountsRequest { format!(", {}", group_by_cols.join(", ")) }; - let query = if let Some(conditions) = &count_conditions.conditions { + let where_clause = if let Some(conditions) = &count_conditions.conditions { let f = get_filter_string(conditions).map_err(QueryError::CustomError)?; - format!( - "SELECT {date_bin}{group_clause}, COUNT(*) as count FROM \"{table_name}\" WHERE {} GROUP BY {end_time_col_name},{start_time_col_name}{group_clause} ORDER BY {end_time_col_name}{group_clause}", - f - ) + format!(" WHERE {f}") } else { - format!( - "SELECT {date_bin}{group_clause}, COUNT(*) as count FROM \"{table_name}\" GROUP BY {end_time_col_name},{start_time_col_name}{group_clause} ORDER BY {end_time_col_name}{group_clause}", - ) + String::default() }; + + let query = format!( + "SELECT {date_bin}{group_clause}, COUNT(*) as count FROM {table_ref}{where_clause} GROUP BY {end_time_col_name},{start_time_col_name}{group_clause} ORDER BY {end_time_col_name}{group_clause}", + ); + + if group_by_cols.is_empty() { + return Ok(query); + } + + let top_k = count_conditions.top_k.unwrap_or(DEFAULT_COUNTS_TOP_K); + + if top_k == 0 { + return Err(QueryError::CustomError( + "topK must be greater than 0".to_string(), + )); + } + + let top_group_cols = group_by_cols.join(", "); + let top_group_join = group_by_cols + .iter() + .map(|col| format!("(gc.{col} = tg.{col} OR (gc.{col} IS NULL AND tg.{col} IS NULL))")) + .join(" AND "); + let top_group_select = group_by_cols + .iter() + .map(|col| format!(", gc.{col} AS {col}")) + .join(""); + let top_group_order = group_by_cols + .iter() + .map(|col| format!(", gc.{col}")) + .join(""); + + let query = format!( + "WITH grouped_counts AS (SELECT {date_bin}{group_clause}, COUNT(*) as count FROM {table_ref}{where_clause} GROUP BY {end_time_col_name},{start_time_col_name}{group_clause}), top_groups AS (SELECT {top_group_cols} FROM grouped_counts GROUP BY {top_group_cols} ORDER BY SUM(\"count\") DESC LIMIT {top_k}) SELECT gc.{start_time_col_name} AS {start_time_col_name}, gc.{end_time_col_name} AS {end_time_col_name}{top_group_select}, gc.\"count\" AS count FROM grouped_counts gc INNER JOIN top_groups tg ON {top_group_join} ORDER BY gc.{end_time_col_name}{top_group_order}", + ); + Ok(query) } } @@ -1016,7 +1056,113 @@ impl PartitionedMetricMonitor { mod tests { use serde_json::json; - use crate::query::flatten_objects_for_count; + use crate::query::{ + CountConditions, CountsRequest, flatten_objects_for_count, resolve_stream_names, + }; + + #[test] + fn test_count_conditions_accepts_top_k() { + let conditions: CountConditions = serde_json::from_value(json!({ + "groupBy": ["host"], + "topK": 5 + })) + .unwrap(); + assert_eq!(conditions.top_k, Some(5)); + + let conditions: CountConditions = serde_json::from_value(json!({ + "groupBy": ["host"], + "topk": 3 + })) + .unwrap(); + assert_eq!(conditions.top_k, Some(3)); + } + + #[tokio::test] + async fn test_counts_sql_applies_top_k_to_group_by() { + let request = CountsRequest { + stream: "logs".to_string(), + start_time: "2024-01-01T00:00:00Z".to_string(), + end_time: "2024-01-01T01:00:00Z".to_string(), + num_bins: None, + conditions: Some(CountConditions { + conditions: None, + group_by: Some(vec!["host".to_string(), "service.name".to_string()]), + top_k: Some(5), + }), + }; + + let sql = request.get_df_sql("p_timestamp".to_string()).await.unwrap(); + + assert!(sql.starts_with("WITH grouped_counts AS")); + assert!(sql.contains("SELECT \"host\", \"service.name\" FROM grouped_counts")); + assert!(sql.contains("GROUP BY \"host\", \"service.name\"")); + assert!(sql.contains("ORDER BY SUM(\"count\") DESC LIMIT 5")); + assert!(sql.contains( + "(gc.\"host\" = tg.\"host\" OR (gc.\"host\" IS NULL AND tg.\"host\" IS NULL))" + )); + assert!(sql.contains("ORDER BY gc._bin_end_time_, gc.\"host\", gc.\"service.name\"")); + } + + #[tokio::test] + async fn test_counts_sql_defaults_top_k_for_group_by() { + let request = CountsRequest { + stream: "logs".to_string(), + start_time: "2024-01-01T00:00:00Z".to_string(), + end_time: "2024-01-01T01:00:00Z".to_string(), + num_bins: None, + conditions: Some(CountConditions { + conditions: None, + group_by: Some(vec!["host".to_string()]), + top_k: None, + }), + }; + + let sql = request.get_df_sql("p_timestamp".to_string()).await.unwrap(); + + assert!(sql.starts_with("WITH grouped_counts AS")); + assert!(sql.contains("ORDER BY SUM(\"count\") DESC LIMIT 10")); + } + + #[tokio::test] + async fn test_counts_top_k_sql_resolves_source_stream() { + let request = CountsRequest { + stream: "logs".to_string(), + start_time: "2024-01-01T00:00:00Z".to_string(), + end_time: "2024-01-01T01:00:00Z".to_string(), + num_bins: None, + conditions: Some(CountConditions { + conditions: None, + group_by: Some(vec!["host".to_string()]), + top_k: Some(5), + }), + }; + + let sql = request.get_df_sql("p_timestamp".to_string()).await.unwrap(); + let streams = resolve_stream_names(&sql).unwrap(); + + assert_eq!(streams, vec!["logs"]); + } + + #[tokio::test] + async fn test_counts_sql_rejects_zero_top_k_for_group_by() { + let request = CountsRequest { + stream: "logs".to_string(), + start_time: "2024-01-01T00:00:00Z".to_string(), + end_time: "2024-01-01T01:00:00Z".to_string(), + num_bins: None, + conditions: Some(CountConditions { + conditions: None, + group_by: Some(vec!["host".to_string()]), + top_k: Some(0), + }), + }; + + let err = request + .get_df_sql("p_timestamp".to_string()) + .await + .unwrap_err(); + assert!(err.to_string().contains("topK must be greater than 0")); + } #[test] fn test_flat_simple() {