Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
168 changes: 157 additions & 11 deletions src/query/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ type BoxedBatchStream = SendableRecordBatchStream;

/// Result type returned by query execution: either collected batches or a streaming adapter, plus field names.
type QueryResult = Result<(Either<Vec<RecordBatch>, BoxedBatchStream>, Vec<String>), ExecuteError>;
/// Number of top groups kept by a grouped counts query when the request does not set `top_k`.
const DEFAULT_COUNTS_TOP_K: usize = 10;

// pub static QUERY_SESSION: Lazy<SessionContext> =
// Lazy::new(|| Query::create_session_context(PARSEABLE.storage()));
Expand Down Expand Up @@ -495,6 +496,9 @@ pub struct CountConditions {
pub conditions: Option<Conditions>,
/// GroupBy columns
pub group_by: Option<Vec<String>>,
/// Optional number of top group values to return.
#[serde(alias = "topk", alias = "top_k")]
pub top_k: Option<usize>,
}

/// Request for counts, received from API/SQL query.
Expand All @@ -513,6 +517,10 @@ pub struct CountsRequest {
pub conditions: Option<CountConditions>,
}

/// Wraps `value` in double quotes so it can be spliced into generated SQL as an identifier.
///
/// Every `"` already present in `value` is doubled, so the returned text always
/// starts and ends with exactly one unescaped quote.
fn quote_identifier(value: &str) -> String {
    // Room for the two delimiters plus the identifier; escaped quotes grow the buffer on demand.
    let mut quoted = String::with_capacity(value.len() + 2);
    quoted.push('"');
    for ch in value.chars() {
        if ch == '"' {
            // Emit the extra quote that escapes the one pushed just below.
            quoted.push('"');
        }
        quoted.push(ch);
    }
    quoted.push('"');
    quoted
}

impl CountsRequest {
/// This function is supposed to read manifest files for the given stream,
/// get the sum of `num_rows` between the `startTime` and `endTime`,
Expand Down Expand Up @@ -652,17 +660,19 @@ impl CountsRequest {
let time_range = TimeRange::parse_human_time(&self.start_time, &self.end_time)?;

let table_name = &self.stream;
let table_ref = quote_identifier(table_name);
let time_column_ref = format!("{}.{}", table_ref, quote_identifier(&time_column));
let start_time_col_name = "_bin_start_time_";
let end_time_col_name = "_bin_end_time_";
let bin_interval = count_api_bin_interval(&time_range.start, &time_range.end);
let date_bin = format!(
"CAST(DATE_BIN('{bin_interval}', \"{table_name}\".\"{time_column}\", TIMESTAMP '{DATE_BIN_EPOCH_ANCHOR}') AS TEXT) as {start_time_col_name}, DATE_BIN('{bin_interval}', \"{table_name}\".\"{time_column}\", TIMESTAMP '{DATE_BIN_EPOCH_ANCHOR}') + INTERVAL '{bin_interval}' as {end_time_col_name}"
"CAST(DATE_BIN('{bin_interval}', {time_column_ref}, TIMESTAMP '{DATE_BIN_EPOCH_ANCHOR}') AS TEXT) as {start_time_col_name}, DATE_BIN('{bin_interval}', {time_column_ref}, TIMESTAMP '{DATE_BIN_EPOCH_ANCHOR}') + INTERVAL '{bin_interval}' as {end_time_col_name}"
);

let group_by_cols = count_conditions
.group_by
.as_ref()
.map(|cols| cols.iter().map(|c| format!("\"{c}\"")).collect::<Vec<_>>())
.map(|cols| cols.iter().map(|c| quote_identifier(c)).collect::<Vec<_>>())
.unwrap_or_default();

let group_clause = if group_by_cols.is_empty() {
Expand All @@ -671,17 +681,47 @@ impl CountsRequest {
format!(", {}", group_by_cols.join(", "))
};

let query = if let Some(conditions) = &count_conditions.conditions {
let where_clause = if let Some(conditions) = &count_conditions.conditions {
let f = get_filter_string(conditions).map_err(QueryError::CustomError)?;
format!(
"SELECT {date_bin}{group_clause}, COUNT(*) as count FROM \"{table_name}\" WHERE {} GROUP BY {end_time_col_name},{start_time_col_name}{group_clause} ORDER BY {end_time_col_name}{group_clause}",
f
)
format!(" WHERE {f}")
} else {
format!(
"SELECT {date_bin}{group_clause}, COUNT(*) as count FROM \"{table_name}\" GROUP BY {end_time_col_name},{start_time_col_name}{group_clause} ORDER BY {end_time_col_name}{group_clause}",
)
String::default()
};

let query = format!(
"SELECT {date_bin}{group_clause}, COUNT(*) as count FROM {table_ref}{where_clause} GROUP BY {end_time_col_name},{start_time_col_name}{group_clause} ORDER BY {end_time_col_name}{group_clause}",
);

if group_by_cols.is_empty() {
return Ok(query);
}

let top_k = count_conditions.top_k.unwrap_or(DEFAULT_COUNTS_TOP_K);

if top_k == 0 {
return Err(QueryError::CustomError(
"topK must be greater than 0".to_string(),
));
}

let top_group_cols = group_by_cols.join(", ");
let top_group_join = group_by_cols
.iter()
.map(|col| format!("(gc.{col} = tg.{col} OR (gc.{col} IS NULL AND tg.{col} IS NULL))"))
.join(" AND ");
let top_group_select = group_by_cols
.iter()
.map(|col| format!(", gc.{col} AS {col}"))
.join("");
let top_group_order = group_by_cols
.iter()
.map(|col| format!(", gc.{col}"))
.join("");

let query = format!(
"WITH grouped_counts AS (SELECT {date_bin}{group_clause}, COUNT(*) as count FROM {table_ref}{where_clause} GROUP BY {end_time_col_name},{start_time_col_name}{group_clause}), top_groups AS (SELECT {top_group_cols} FROM grouped_counts GROUP BY {top_group_cols} ORDER BY SUM(\"count\") DESC LIMIT {top_k}) SELECT gc.{start_time_col_name} AS {start_time_col_name}, gc.{end_time_col_name} AS {end_time_col_name}{top_group_select}, gc.\"count\" AS count FROM grouped_counts gc INNER JOIN top_groups tg ON {top_group_join} ORDER BY gc.{end_time_col_name}{top_group_order}",
Comment thread
coderabbitai[bot] marked this conversation as resolved.
);

Ok(query)
}
}
Expand Down Expand Up @@ -1016,7 +1056,113 @@ impl PartitionedMetricMonitor {
mod tests {
use serde_json::json;

use crate::query::flatten_objects_for_count;
use crate::query::{
CountConditions, CountsRequest, flatten_objects_for_count, resolve_stream_names,
};

#[test]
fn test_count_conditions_accepts_top_k() {
    // The `topK` spelling populates `top_k`.
    let camel_case_payload = json!({ "groupBy": ["host"], "topK": 5 });
    let parsed: CountConditions = serde_json::from_value(camel_case_payload).unwrap();
    assert_eq!(parsed.top_k, Some(5));

    // The all-lowercase `topk` alias populates the same field.
    let lower_case_payload = json!({ "groupBy": ["host"], "topk": 3 });
    let parsed: CountConditions = serde_json::from_value(lower_case_payload).unwrap();
    assert_eq!(parsed.top_k, Some(3));
}

#[tokio::test]
async fn test_counts_sql_applies_top_k_to_group_by() {
    // Two group-by columns, one containing a dot, with an explicit top-k of 5.
    let grouping = CountConditions {
        conditions: None,
        group_by: Some(vec!["host".to_string(), "service.name".to_string()]),
        top_k: Some(5),
    };
    let counts_request = CountsRequest {
        stream: "logs".to_string(),
        start_time: "2024-01-01T00:00:00Z".to_string(),
        end_time: "2024-01-01T01:00:00Z".to_string(),
        num_bins: None,
        conditions: Some(grouping),
    };

    let generated = counts_request
        .get_df_sql("p_timestamp".to_string())
        .await
        .unwrap();

    assert!(generated.starts_with("WITH grouped_counts AS"));

    // Every fragment below must appear somewhere in the generated statement.
    let expected_fragments = [
        "SELECT \"host\", \"service.name\" FROM grouped_counts",
        "GROUP BY \"host\", \"service.name\"",
        "ORDER BY SUM(\"count\") DESC LIMIT 5",
        "(gc.\"host\" = tg.\"host\" OR (gc.\"host\" IS NULL AND tg.\"host\" IS NULL))",
        "ORDER BY gc._bin_end_time_, gc.\"host\", gc.\"service.name\"",
    ];
    for fragment in expected_fragments.iter() {
        assert!(
            generated.contains(*fragment),
            "missing SQL fragment: {fragment}"
        );
    }
}

#[tokio::test]
async fn test_counts_sql_defaults_top_k_for_group_by() {
    // Grouping is requested but `top_k` is left unset.
    let grouping = CountConditions {
        conditions: None,
        group_by: Some(vec!["host".to_string()]),
        top_k: None,
    };
    let counts_request = CountsRequest {
        stream: "logs".to_string(),
        start_time: "2024-01-01T00:00:00Z".to_string(),
        end_time: "2024-01-01T01:00:00Z".to_string(),
        num_bins: None,
        conditions: Some(grouping),
    };

    let generated = counts_request
        .get_df_sql("p_timestamp".to_string())
        .await
        .unwrap();

    // The top-k form of the query is still produced, with a limit of 10.
    assert!(generated.starts_with("WITH grouped_counts AS"));
    assert!(generated.contains("ORDER BY SUM(\"count\") DESC LIMIT 10"));
}

#[tokio::test]
async fn test_counts_top_k_sql_resolves_source_stream() {
    let grouping = CountConditions {
        conditions: None,
        group_by: Some(vec!["host".to_string()]),
        top_k: Some(5),
    };
    let counts_request = CountsRequest {
        stream: "logs".to_string(),
        start_time: "2024-01-01T00:00:00Z".to_string(),
        end_time: "2024-01-01T01:00:00Z".to_string(),
        num_bins: None,
        conditions: Some(grouping),
    };

    let generated = counts_request
        .get_df_sql("p_timestamp".to_string())
        .await
        .unwrap();

    // Only the real source stream should be reported for the generated statement.
    let resolved = resolve_stream_names(&generated).unwrap();
    assert_eq!(resolved, vec!["logs"]);
}

#[tokio::test]
async fn test_counts_sql_rejects_zero_top_k_for_group_by() {
    // A grouped request that asks for zero top groups.
    let grouping = CountConditions {
        conditions: None,
        group_by: Some(vec!["host".to_string()]),
        top_k: Some(0),
    };
    let counts_request = CountsRequest {
        stream: "logs".to_string(),
        start_time: "2024-01-01T00:00:00Z".to_string(),
        end_time: "2024-01-01T01:00:00Z".to_string(),
        num_bins: None,
        conditions: Some(grouping),
    };

    let outcome = counts_request.get_df_sql("p_timestamp".to_string()).await;

    // SQL generation must fail, and the error text must name the constraint.
    let message = outcome.unwrap_err().to_string();
    assert!(message.contains("topK must be greater than 0"));
}

#[test]
fn test_flat_simple() {
Expand Down
Loading