Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 18 additions & 21 deletions pkg/rain/query_common.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@ type preparingQueryRunner interface {
}

type joinClause struct {
kind string
table selectTableSource
kind string
// OPTIMIZATION: table is a concrete struct instead of an interface to avoid
// interface boxing allocations during query construction and join operations.
table tableSource
on schema.Predicate
}

Expand All @@ -37,33 +39,28 @@ type returningClause struct {
label string
}

type selectTableSource interface {
writeSQL(*compileContext) error
// tableSource represents a source for a query (either a table or a subquery).
// OPTIMIZATION: This is a concrete struct instead of an interface to avoid
// interface boxing allocations when used in slices (e.g., joins, FROM, USING).
type tableSource struct {
table *schema.TableDef
subquery *SelectQuery
alias string
}
Comment on lines +45 to 49

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Silent alias discard for table-backed sources

The unified tableSource struct exposes an alias field that is silently ignored whenever table != nil. Any future caller within the package that writes tableSource{table: t, alias: "x"} expecting the alias to be rendered will get no error and no alias in the SQL output. Adding a comment on the field (or a validation in writeSQL) would make this invariant explicit.

Prompt To Fix With AI
This is a comment left during a code review.
Path: pkg/rain/query_common.go
Line: 45-49

Comment:
**Silent `alias` discard for table-backed sources**

The unified `tableSource` struct exposes an `alias` field that is silently ignored whenever `table != nil`. Any future caller within the package that writes `tableSource{table: t, alias: "x"}` expecting the alias to be rendered will get no error and no alias in the SQL output. Adding a comment on the field (or a validation in `writeSQL`) would make this invariant explicit.

How can I resolve this? If you propose a fix, please make it concise.

Fix in Codex


type tableDefSource struct {
table *schema.TableDef
}

func (s tableDefSource) writeSQL(ctx *compileContext) error {
ctx.writeTable(s.table)
return nil
}

type subqueryTableSource struct {
query *SelectQuery
alias string
}

func (s subqueryTableSource) writeSQL(ctx *compileContext) error {
func (s tableSource) writeSQL(ctx *compileContext) error {
if s.table != nil {
ctx.writeTable(s.table)
return nil
}
if strings.TrimSpace(s.alias) == "" {
return errors.New("rain: subquery table source requires a non-empty alias")
}
if s.query == nil {
if s.subquery == nil {
return fmt.Errorf("rain: subquery table source %q requires a non-nil query", s.alias)
}
ctx.writeByte('(')
if err := s.query.writeSQL(ctx); err != nil {
if err := s.subquery.writeSQL(ctx); err != nil {
return err
}
ctx.writeString(") AS ")
Expand Down
16 changes: 8 additions & 8 deletions pkg/rain/query_common_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,32 +22,32 @@ func TestQueryCommonHelpers(t *testing.T) {
t.Run("SubqueryAliasValidation", func(t *testing.T) {
ctx := newCompileContext(dialectForTest(t, "postgres"))
defer releaseCompileContext(ctx)
if err := (subqueryTableSource{alias: " ", query: &SelectQuery{dialect: ctx.dialect, table: users.TableDef()}}).writeSQL(ctx); err == nil || !strings.Contains(err.Error(), "non-empty alias") {
if err := (tableSource{alias: " ", subquery: &SelectQuery{dialect: ctx.dialect, table: users.TableDef()}}).writeSQL(ctx); err == nil || !strings.Contains(err.Error(), "non-empty alias") {
t.Fatalf("expected empty alias error, got %v", err)
}
})

t.Run("SubqueryNilQueryValidation", func(t *testing.T) {
ctx := newCompileContext(dialectForTest(t, "postgres"))
defer releaseCompileContext(ctx)
if err := (subqueryTableSource{alias: "u", query: nil}).writeSQL(ctx); err == nil || !strings.Contains(err.Error(), "non-nil query") {
if err := (tableSource{alias: "u", subquery: nil}).writeSQL(ctx); err == nil || !strings.Contains(err.Error(), "non-nil query") {
t.Fatalf("expected nil query error, got %v", err)
}
})

t.Run("SubqueryWriteSQL", func(t *testing.T) {
ctx := newCompileContext(dialectForTest(t, "postgres"))
defer releaseCompileContext(ctx)
err := (subqueryTableSource{
err := (tableSource{
alias: "u",
query: &SelectQuery{
subquery: &SelectQuery{
dialect: ctx.dialect,
table: users.TableDef(),
cols: []schema.Expression{users.ID},
},
}).writeSQL(ctx)
if err != nil {
t.Fatalf("subqueryTableSource.writeSQL returned error: %v", err)
t.Fatalf("tableSource.writeSQL returned error: %v", err)
}
if !strings.Contains(ctx.String(), `AS "u"`) {
t.Fatalf("expected compiled subquery alias, got %q", ctx.String())
Expand All @@ -57,9 +57,9 @@ func TestQueryCommonHelpers(t *testing.T) {
t.Run("NestedQueryError", func(t *testing.T) {
ctx := newCompileContext(dialectForTest(t, "postgres"))
defer releaseCompileContext(ctx)
if err := (subqueryTableSource{
alias: "broken",
query: &SelectQuery{dialect: ctx.dialect},
if err := (tableSource{
alias: "broken",
subquery: &SelectQuery{dialect: ctx.dialect},
}).writeSQL(ctx); err == nil || !strings.Contains(err.Error(), "requires a table") {
t.Fatalf("expected nested query error, got %v", err)
}
Expand Down
6 changes: 3 additions & 3 deletions pkg/rain/query_delete.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ type DeleteQuery struct {
dialect dialect.Dialect
table *schema.TableDef
where []schema.Predicate
using []selectTableSource
using []tableSource
order []schema.OrderExpr
limit int
hasLimit bool
Expand Down Expand Up @@ -49,14 +49,14 @@ func (q *DeleteQuery) Returning(exprs ...schema.Expression) *DeleteQuery {
// Supported by PostgreSQL.
func (q *DeleteQuery) Using(tables ...schema.TableReference) *DeleteQuery {
for _, table := range tables {
q.using = append(q.using, tableDefSource{table: table.TableDef()})
q.using = append(q.using, tableSource{table: table.TableDef()})
}
return q
}

// UsingSubquery appends a subquery source for the DELETE ... USING clause.
func (q *DeleteQuery) UsingSubquery(query *SelectQuery, alias string) *DeleteQuery {
q.using = append(q.using, subqueryTableSource{query: query, alias: alias})
q.using = append(q.using, tableSource{subquery: query, alias: alias})
return q
}

Expand Down
20 changes: 10 additions & 10 deletions pkg/rain/query_select.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,61 +77,61 @@ func (q *SelectQuery) Where(predicate schema.Predicate) *SelectQuery {

// Join appends an INNER JOIN clause.
func (q *SelectQuery) Join(table schema.TableReference, on schema.Predicate) *SelectQuery {
q.joins = append(q.joins, joinClause{kind: "INNER JOIN", table: tableDefSource{table: table.TableDef()}, on: on})
q.joins = append(q.joins, joinClause{kind: "INNER JOIN", table: tableSource{table: table.TableDef()}, on: on})
return q
}

// LeftJoin appends a LEFT JOIN clause.
func (q *SelectQuery) LeftJoin(table schema.TableReference, on schema.Predicate) *SelectQuery {
q.joins = append(q.joins, joinClause{kind: "LEFT JOIN", table: tableDefSource{table: table.TableDef()}, on: on})
q.joins = append(q.joins, joinClause{kind: "LEFT JOIN", table: tableSource{table: table.TableDef()}, on: on})
return q
}

// RightJoin appends a RIGHT JOIN clause.
func (q *SelectQuery) RightJoin(table schema.TableReference, on schema.Predicate) *SelectQuery {
q.joins = append(q.joins, joinClause{kind: "RIGHT JOIN", table: tableDefSource{table: table.TableDef()}, on: on})
q.joins = append(q.joins, joinClause{kind: "RIGHT JOIN", table: tableSource{table: table.TableDef()}, on: on})
return q
}

// FullJoin appends a FULL JOIN clause.
func (q *SelectQuery) FullJoin(table schema.TableReference, on schema.Predicate) *SelectQuery {
q.joins = append(q.joins, joinClause{kind: "FULL JOIN", table: tableDefSource{table: table.TableDef()}, on: on})
q.joins = append(q.joins, joinClause{kind: "FULL JOIN", table: tableSource{table: table.TableDef()}, on: on})
return q
}

// CrossJoin appends a CROSS JOIN clause.
func (q *SelectQuery) CrossJoin(table schema.TableReference) *SelectQuery {
q.joins = append(q.joins, joinClause{kind: "CROSS JOIN", table: tableDefSource{table: table.TableDef()}})
q.joins = append(q.joins, joinClause{kind: "CROSS JOIN", table: tableSource{table: table.TableDef()}})
return q
}

// JoinSubquery appends an INNER JOIN against a subquery source.
func (q *SelectQuery) JoinSubquery(query *SelectQuery, alias string, on schema.Predicate) *SelectQuery {
q.joins = append(q.joins, joinClause{kind: "INNER JOIN", table: subqueryTableSource{query: query, alias: alias}, on: on})
q.joins = append(q.joins, joinClause{kind: "INNER JOIN", table: tableSource{subquery: query, alias: alias}, on: on})
return q
}

// LeftJoinSubquery appends a LEFT JOIN against a subquery source.
func (q *SelectQuery) LeftJoinSubquery(query *SelectQuery, alias string, on schema.Predicate) *SelectQuery {
q.joins = append(q.joins, joinClause{kind: "LEFT JOIN", table: subqueryTableSource{query: query, alias: alias}, on: on})
q.joins = append(q.joins, joinClause{kind: "LEFT JOIN", table: tableSource{subquery: query, alias: alias}, on: on})
return q
}

// RightJoinSubquery appends a RIGHT JOIN against a subquery source.
func (q *SelectQuery) RightJoinSubquery(query *SelectQuery, alias string, on schema.Predicate) *SelectQuery {
q.joins = append(q.joins, joinClause{kind: "RIGHT JOIN", table: subqueryTableSource{query: query, alias: alias}, on: on})
q.joins = append(q.joins, joinClause{kind: "RIGHT JOIN", table: tableSource{subquery: query, alias: alias}, on: on})
return q
}

// FullJoinSubquery appends a FULL JOIN against a subquery source.
func (q *SelectQuery) FullJoinSubquery(query *SelectQuery, alias string, on schema.Predicate) *SelectQuery {
q.joins = append(q.joins, joinClause{kind: "FULL JOIN", table: subqueryTableSource{query: query, alias: alias}, on: on})
q.joins = append(q.joins, joinClause{kind: "FULL JOIN", table: tableSource{subquery: query, alias: alias}, on: on})
return q
}

// CrossJoinSubquery appends a CROSS JOIN against a subquery source.
func (q *SelectQuery) CrossJoinSubquery(query *SelectQuery, alias string) *SelectQuery {
q.joins = append(q.joins, joinClause{kind: "CROSS JOIN", table: subqueryTableSource{query: query, alias: alias}})
q.joins = append(q.joins, joinClause{kind: "CROSS JOIN", table: tableSource{subquery: query, alias: alias}})
return q
}

Expand Down
26 changes: 11 additions & 15 deletions pkg/rain/query_update.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ type UpdateQuery struct {
values []assignment
rows []map[schema.ColumnReference]any
where []schema.Predicate
from []selectTableSource
from []tableSource
order []schema.OrderExpr
limit int
hasLimit bool
Expand Down Expand Up @@ -72,14 +72,14 @@ func (q *UpdateQuery) Returning(exprs ...schema.Expression) *UpdateQuery {
// Supported by PostgreSQL and SQLite (3.33.0+).
func (q *UpdateQuery) From(tables ...schema.TableReference) *UpdateQuery {
for _, table := range tables {
q.from = append(q.from, tableDefSource{table: table.TableDef()})
q.from = append(q.from, tableSource{table: table.TableDef()})
}
return q
}

// FromSubquery appends a subquery source for the UPDATE ... FROM clause.
func (q *UpdateQuery) FromSubquery(query *SelectQuery, alias string) *UpdateQuery {
q.from = append(q.from, subqueryTableSource{query: query, alias: alias})
q.from = append(q.from, tableSource{subquery: query, alias: alias})
return q
}

Expand Down Expand Up @@ -159,25 +159,17 @@ func (q *UpdateQuery) compile() (compiledQuery, error) {
return compiledQuery{}, fmt.Errorf("rain: cannot update view %q", q.table.Name)
}

assignments, err := q.updateAssignments()
if err != nil {
ctx := newCompileContext(q.dialect)
defer releaseCompileContext(ctx)

if err := q.writeSQL(ctx); err != nil {
return compiledQuery{}, err
}
if len(assignments) == 0 {
return compiledQuery{}, errors.New("rain: update query requires at least one assignment")
}

if len(q.where) == 0 && !q.unbounded {
return compiledQuery{}, errors.New("rain: update query requires at least one WHERE predicate; call Unbounded() to allow all rows")
}

ctx := newCompileContext(q.dialect)
defer releaseCompileContext(ctx)

if err := q.writeSQLInternal(ctx, assignments); err != nil {
return compiledQuery{}, err
}

return ctx.compiledQuery(), ctx.err
}

Expand All @@ -190,6 +182,10 @@ func (q *UpdateQuery) writeSQL(ctx *compileContext) error {
}

func (q *UpdateQuery) writeSQLInternal(ctx *compileContext, assignments []assignment) error {
if len(assignments) == 0 {
return errors.New("rain: update query requires at least one assignment")
}

if err := writeCTEs(ctx, q.ctes, "update"); err != nil {
return err
}
Expand Down