diff --git a/pkg/rain/query_common.go b/pkg/rain/query_common.go index 9861815..c195fec 100644 --- a/pkg/rain/query_common.go +++ b/pkg/rain/query_common.go @@ -22,8 +22,10 @@ type preparingQueryRunner interface { } type joinClause struct { - kind string - table selectTableSource + kind string + // OPTIMIZATION: table is a concrete struct instead of an interface to avoid + // interface boxing allocations during query construction and join operations. + table tableSource on schema.Predicate } @@ -37,33 +39,28 @@ type returningClause struct { label string } -type selectTableSource interface { - writeSQL(*compileContext) error +// tableSource represents a source for a query (either a table or a subquery). +// OPTIMIZATION: This is a concrete struct instead of an interface to avoid +// interface boxing allocations when used in slices (e.g., joins, FROM, USING). +type tableSource struct { + table *schema.TableDef + subquery *SelectQuery + alias string } -type tableDefSource struct { - table *schema.TableDef -} - -func (s tableDefSource) writeSQL(ctx *compileContext) error { - ctx.writeTable(s.table) - return nil -} - -type subqueryTableSource struct { - query *SelectQuery - alias string -} - -func (s subqueryTableSource) writeSQL(ctx *compileContext) error { +func (s tableSource) writeSQL(ctx *compileContext) error { + if s.table != nil { + ctx.writeTable(s.table) + return nil + } if strings.TrimSpace(s.alias) == "" { return errors.New("rain: subquery table source requires a non-empty alias") } - if s.query == nil { + if s.subquery == nil { return fmt.Errorf("rain: subquery table source %q requires a non-nil query", s.alias) } ctx.writeByte('(') - if err := s.query.writeSQL(ctx); err != nil { + if err := s.subquery.writeSQL(ctx); err != nil { return err } ctx.writeString(") AS ") diff --git a/pkg/rain/query_common_internal_test.go b/pkg/rain/query_common_internal_test.go index 491c5f8..1294a11 100644 --- a/pkg/rain/query_common_internal_test.go +++ b/pkg/rain/query_common_internal_test.go @@ -22,7 +22,7 @@ func TestQueryCommonHelpers(t *testing.T) { t.Run("SubqueryAliasValidation", func(t *testing.T) { ctx := newCompileContext(dialectForTest(t, "postgres")) defer releaseCompileContext(ctx) - if err := (subqueryTableSource{alias: " ", query: &SelectQuery{dialect: ctx.dialect, table: users.TableDef()}}).writeSQL(ctx); err == nil || !strings.Contains(err.Error(), "non-empty alias") { + if err := (tableSource{alias: " ", subquery: &SelectQuery{dialect: ctx.dialect, table: users.TableDef()}}).writeSQL(ctx); err == nil || !strings.Contains(err.Error(), "non-empty alias") { t.Fatalf("expected empty alias error, got %v", err) } }) @@ -30,7 +30,7 @@ func TestQueryCommonHelpers(t *testing.T) { t.Run("SubqueryNilQueryValidation", func(t *testing.T) { ctx := newCompileContext(dialectForTest(t, "postgres")) defer releaseCompileContext(ctx) - if err := (subqueryTableSource{alias: "u", query: nil}).writeSQL(ctx); err == nil || !strings.Contains(err.Error(), "non-nil query") { + if err := (tableSource{alias: "u", subquery: nil}).writeSQL(ctx); err == nil || !strings.Contains(err.Error(), "non-nil query") { t.Fatalf("expected nil query error, got %v", err) } }) @@ -38,16 +38,16 @@ func TestQueryCommonHelpers(t *testing.T) { t.Run("SubqueryWriteSQL", func(t *testing.T) { ctx := newCompileContext(dialectForTest(t, "postgres")) defer releaseCompileContext(ctx) - err := (subqueryTableSource{ + err := (tableSource{ alias: "u", - query: &SelectQuery{ + subquery: &SelectQuery{ dialect: ctx.dialect, table: users.TableDef(), cols: []schema.Expression{users.ID}, }, }).writeSQL(ctx) if err != nil { - t.Fatalf("subqueryTableSource.writeSQL returned error: %v", err) + t.Fatalf("tableSource.writeSQL returned error: %v", err) } if !strings.Contains(ctx.String(), `AS "u"`) { t.Fatalf("expected compiled subquery alias, got %q", ctx.String()) @@ -57,9 +57,9 @@ func TestQueryCommonHelpers(t *testing.T) { t.Run("NestedQueryError", func(t *testing.T) { ctx := newCompileContext(dialectForTest(t, "postgres")) defer releaseCompileContext(ctx) - if err := (subqueryTableSource{ - alias: "broken", - query: &SelectQuery{dialect: ctx.dialect}, + if err := (tableSource{ + alias: "broken", + subquery: &SelectQuery{dialect: ctx.dialect}, }).writeSQL(ctx); err == nil || !strings.Contains(err.Error(), "requires a table") { t.Fatalf("expected nested query error, got %v", err) } diff --git a/pkg/rain/query_delete.go b/pkg/rain/query_delete.go index 6e805ef..35f66a1 100644 --- a/pkg/rain/query_delete.go +++ b/pkg/rain/query_delete.go @@ -16,7 +16,7 @@ type DeleteQuery struct { dialect dialect.Dialect table *schema.TableDef where []schema.Predicate - using []selectTableSource + using []tableSource order []schema.OrderExpr limit int hasLimit bool @@ -49,14 +49,14 @@ func (q *DeleteQuery) Returning(exprs ...schema.Expression) *DeleteQuery { // Supported by PostgreSQL. func (q *DeleteQuery) Using(tables ...schema.TableReference) *DeleteQuery { for _, table := range tables { - q.using = append(q.using, tableDefSource{table: table.TableDef()}) + q.using = append(q.using, tableSource{table: table.TableDef()}) } return q } // UsingSubquery appends a subquery source for the DELETE ... USING clause. func (q *DeleteQuery) UsingSubquery(query *SelectQuery, alias string) *DeleteQuery { - q.using = append(q.using, subqueryTableSource{query: query, alias: alias}) + q.using = append(q.using, tableSource{subquery: query, alias: alias}) return q } diff --git a/pkg/rain/query_select.go b/pkg/rain/query_select.go index 38be5fd..7c9f50b 100644 --- a/pkg/rain/query_select.go +++ b/pkg/rain/query_select.go @@ -77,61 +77,61 @@ func (q *SelectQuery) Where(predicate schema.Predicate) *SelectQuery { // Join appends an INNER JOIN clause. func (q *SelectQuery) Join(table schema.TableReference, on schema.Predicate) *SelectQuery { - q.joins = append(q.joins, joinClause{kind: "INNER JOIN", table: tableDefSource{table: table.TableDef()}, on: on}) + q.joins = append(q.joins, joinClause{kind: "INNER JOIN", table: tableSource{table: table.TableDef()}, on: on}) return q } // LeftJoin appends a LEFT JOIN clause. func (q *SelectQuery) LeftJoin(table schema.TableReference, on schema.Predicate) *SelectQuery { - q.joins = append(q.joins, joinClause{kind: "LEFT JOIN", table: tableDefSource{table: table.TableDef()}, on: on}) + q.joins = append(q.joins, joinClause{kind: "LEFT JOIN", table: tableSource{table: table.TableDef()}, on: on}) return q } // RightJoin appends a RIGHT JOIN clause. func (q *SelectQuery) RightJoin(table schema.TableReference, on schema.Predicate) *SelectQuery { - q.joins = append(q.joins, joinClause{kind: "RIGHT JOIN", table: tableDefSource{table: table.TableDef()}, on: on}) + q.joins = append(q.joins, joinClause{kind: "RIGHT JOIN", table: tableSource{table: table.TableDef()}, on: on}) return q } // FullJoin appends a FULL JOIN clause. func (q *SelectQuery) FullJoin(table schema.TableReference, on schema.Predicate) *SelectQuery { - q.joins = append(q.joins, joinClause{kind: "FULL JOIN", table: tableDefSource{table: table.TableDef()}, on: on}) + q.joins = append(q.joins, joinClause{kind: "FULL JOIN", table: tableSource{table: table.TableDef()}, on: on}) return q } // CrossJoin appends a CROSS JOIN clause. func (q *SelectQuery) CrossJoin(table schema.TableReference) *SelectQuery { - q.joins = append(q.joins, joinClause{kind: "CROSS JOIN", table: tableDefSource{table: table.TableDef()}}) + q.joins = append(q.joins, joinClause{kind: "CROSS JOIN", table: tableSource{table: table.TableDef()}}) return q } // JoinSubquery appends an INNER JOIN against a subquery source. func (q *SelectQuery) JoinSubquery(query *SelectQuery, alias string, on schema.Predicate) *SelectQuery { - q.joins = append(q.joins, joinClause{kind: "INNER JOIN", table: subqueryTableSource{query: query, alias: alias}, on: on}) + q.joins = append(q.joins, joinClause{kind: "INNER JOIN", table: tableSource{subquery: query, alias: alias}, on: on}) return q } // LeftJoinSubquery appends a LEFT JOIN against a subquery source. func (q *SelectQuery) LeftJoinSubquery(query *SelectQuery, alias string, on schema.Predicate) *SelectQuery { - q.joins = append(q.joins, joinClause{kind: "LEFT JOIN", table: subqueryTableSource{query: query, alias: alias}, on: on}) + q.joins = append(q.joins, joinClause{kind: "LEFT JOIN", table: tableSource{subquery: query, alias: alias}, on: on}) return q } // RightJoinSubquery appends a RIGHT JOIN against a subquery source. func (q *SelectQuery) RightJoinSubquery(query *SelectQuery, alias string, on schema.Predicate) *SelectQuery { - q.joins = append(q.joins, joinClause{kind: "RIGHT JOIN", table: subqueryTableSource{query: query, alias: alias}, on: on}) + q.joins = append(q.joins, joinClause{kind: "RIGHT JOIN", table: tableSource{subquery: query, alias: alias}, on: on}) return q } // FullJoinSubquery appends a FULL JOIN against a subquery source. func (q *SelectQuery) FullJoinSubquery(query *SelectQuery, alias string, on schema.Predicate) *SelectQuery { - q.joins = append(q.joins, joinClause{kind: "FULL JOIN", table: subqueryTableSource{query: query, alias: alias}, on: on}) + q.joins = append(q.joins, joinClause{kind: "FULL JOIN", table: tableSource{subquery: query, alias: alias}, on: on}) return q } // CrossJoinSubquery appends a CROSS JOIN against a subquery source. func (q *SelectQuery) CrossJoinSubquery(query *SelectQuery, alias string) *SelectQuery { - q.joins = append(q.joins, joinClause{kind: "CROSS JOIN", table: subqueryTableSource{query: query, alias: alias}}) + q.joins = append(q.joins, joinClause{kind: "CROSS JOIN", table: tableSource{subquery: query, alias: alias}}) return q } diff --git a/pkg/rain/query_update.go b/pkg/rain/query_update.go index e084e88..b0041ca 100644 --- a/pkg/rain/query_update.go +++ b/pkg/rain/query_update.go @@ -19,7 +19,7 @@ type UpdateQuery struct { values []assignment rows []map[schema.ColumnReference]any where []schema.Predicate - from []selectTableSource + from []tableSource order []schema.OrderExpr limit int hasLimit bool @@ -72,14 +72,14 @@ func (q *UpdateQuery) Returning(exprs ...schema.Expression) *UpdateQuery { // Supported by PostgreSQL and SQLite (3.33.0+). func (q *UpdateQuery) From(tables ...schema.TableReference) *UpdateQuery { for _, table := range tables { - q.from = append(q.from, tableDefSource{table: table.TableDef()}) + q.from = append(q.from, tableSource{table: table.TableDef()}) } return q } // FromSubquery appends a subquery source for the UPDATE ... FROM clause. func (q *UpdateQuery) FromSubquery(query *SelectQuery, alias string) *UpdateQuery { - q.from = append(q.from, subqueryTableSource{query: query, alias: alias}) + q.from = append(q.from, tableSource{subquery: query, alias: alias}) return q } @@ -159,25 +159,17 @@ func (q *UpdateQuery) compile() (compiledQuery, error) { return compiledQuery{}, fmt.Errorf("rain: cannot update view %q", q.table.Name) } - assignments, err := q.updateAssignments() - if err != nil { + ctx := newCompileContext(q.dialect) + defer releaseCompileContext(ctx) + + if err := q.writeSQL(ctx); err != nil { return compiledQuery{}, err } - if len(assignments) == 0 { - return compiledQuery{}, errors.New("rain: update query requires at least one assignment") - } if len(q.where) == 0 && !q.unbounded { return compiledQuery{}, errors.New("rain: update query requires at least one WHERE predicate; call Unbounded() to allow all rows") } - ctx := newCompileContext(q.dialect) - defer releaseCompileContext(ctx) - - if err := q.writeSQLInternal(ctx, assignments); err != nil { - return compiledQuery{}, err - } - return ctx.compiledQuery(), ctx.err } @@ -190,6 +182,10 @@ func (q *UpdateQuery) writeSQL(ctx *compileContext) error { } func (q *UpdateQuery) writeSQLInternal(ctx *compileContext, assignments []assignment) error { + if len(assignments) == 0 { + return errors.New("rain: update query requires at least one assignment") + } + if err := writeCTEs(ctx, q.ctes, "update"); err != nil { return err }