diff --git a/pkg/rain/query_common.go b/pkg/rain/query_common.go index 3aa7eb0..9861815 100644 --- a/pkg/rain/query_common.go +++ b/pkg/rain/query_common.go @@ -45,13 +45,6 @@ type tableDefSource struct { table *schema.TableDef } -func tableDefFromSelectSource(source selectTableSource) *schema.TableDef { - if table, ok := source.(tableDefSource); ok { - return table.table - } - return nil -} - func (s tableDefSource) writeSQL(ctx *compileContext) error { ctx.writeTable(s.table) return nil @@ -121,7 +114,7 @@ func writeCTEs(ctx *compileContext, ctes []cteDefinition, label string) error { return nil } -func writeOrderLimit(ctx *compileContext, order []schema.OrderExpr, limit *int, offset *int, featureOrder, featureLimit dialect.Feature) error { +func writeOrderLimit(ctx *compileContext, order []schema.OrderExpr, limit int, hasLimit bool, offset int, hasOffset bool, featureOrder, featureLimit dialect.Feature) error { if len(order) > 0 { if featureOrder != dialect.FeatureUnlimited && !dialect.HasFeature(ctx.dialect.Features(), featureOrder) { return fmt.Errorf("rain: ORDER BY is not supported for this query type in %s dialect", ctx.dialect.Name()) @@ -146,20 +139,20 @@ func writeOrderLimit(ctx *compileContext, order []schema.OrderExpr, limit *int, } } - if limit != nil || (offset != nil && *offset > 0) { + if hasLimit || hasOffset { if featureLimit != dialect.FeatureUnlimited && !dialect.HasFeature(ctx.dialect.Features(), featureLimit) { return fmt.Errorf("rain: LIMIT/OFFSET is not supported for this query type in %s dialect", ctx.dialect.Name()) } l := -1 - if limit != nil { - l = *limit + if hasLimit { + l = limit if l < 0 { return errors.New("rain: LIMIT must be non-negative") } } o := 0 - if offset != nil { - o = *offset + if hasOffset { + o = offset if o < 0 { return errors.New("rain: OFFSET must be non-negative") } diff --git a/pkg/rain/query_common_internal_test.go b/pkg/rain/query_common_internal_test.go index c50ae72..491c5f8 100644 --- a/pkg/rain/query_common_internal_test.go +++ b/pkg/rain/query_common_internal_test.go @@ -19,17 +19,10 @@ func TestQueryCommonHelpers(t *testing.T) { users, _ := defineInternalQueryTables() - if got := tableDefFromSelectSource(tableDefSource{table: users.TableDef()}); got != users.TableDef() { - t.Fatalf("expected tableDefFromSelectSource to return the table, got %#v", got) - } - if got := tableDefFromSelectSource(subqueryTableSource{}); got != nil { - t.Fatalf("expected non-table select source to return nil, got %#v", got) - } - t.Run("SubqueryAliasValidation", func(t *testing.T) { ctx := newCompileContext(dialectForTest(t, "postgres")) defer releaseCompileContext(ctx) - if err := (subqueryTableSource{alias: " ", query: &SelectQuery{dialect: ctx.dialect, table: tableDefSource{table: users.TableDef()}}}).writeSQL(ctx); err == nil || !strings.Contains(err.Error(), "non-empty alias") { + if err := (subqueryTableSource{alias: " ", query: &SelectQuery{dialect: ctx.dialect, table: users.TableDef()}}).writeSQL(ctx); err == nil || !strings.Contains(err.Error(), "non-empty alias") { t.Fatalf("expected empty alias error, got %v", err) } }) @@ -49,7 +42,7 @@ func TestQueryCommonHelpers(t *testing.T) { alias: "u", query: &SelectQuery{ dialect: ctx.dialect, - table: tableDefSource{table: users.TableDef()}, + table: users.TableDef(), cols: []schema.Expression{users.ID}, }, }).writeSQL(ctx) diff --git a/pkg/rain/query_compile_internal_test.go b/pkg/rain/query_compile_internal_test.go index 54e5a0e..304f396 100644 --- a/pkg/rain/query_compile_internal_test.go +++ b/pkg/rain/query_compile_internal_test.go @@ -24,14 +24,14 @@ func TestQueryBuilderAndHelperErrors(t *testing.T) { if _, _, err := db.Select().ToSQL(); err == nil || !strings.Contains(err.Error(), "requires a table") { t.Fatalf("expected select table error, got %v", err) } - selectNoRunner := &SelectQuery{dialect: db.Dialect(), table: tableDefSource{table: users.TableDef()}} + selectNoRunner := &SelectQuery{dialect: db.Dialect(), table: users.TableDef()} if err := selectNoRunner.Scan(context.Background(), &internalUserRow{}); !errors.Is(err, ErrNoConnection) { t.Fatalf("expected select scan ErrNoConnection, got %v", err) } if _, err := selectNoRunner.Prepare(context.Background()); !errors.Is(err, ErrNoConnection) { t.Fatalf("expected select prepare ErrNoConnection, got %v", err) } - selectUnsupportedPrepare := &SelectQuery{runner: &countingRunner{}, dialect: db.Dialect(), table: tableDefSource{table: users.TableDef()}} + selectUnsupportedPrepare := &SelectQuery{runner: &countingRunner{}, dialect: db.Dialect(), table: users.TableDef()} if _, err := selectUnsupportedPrepare.Prepare(context.Background()); !errors.Is(err, ErrPrepareNotSupported) { t.Fatalf("expected select prepare ErrPrepareNotSupported, got %v", err) } @@ -303,7 +303,7 @@ func TestCompiledQueryBindValidation(t *testing.T) { users, _ := defineInternalQueryTables() compiled, err := (&SelectQuery{ dialect: dialectForTest(t, "postgres"), - table: tableDefSource{table: users.TableDef()}, + table: users.TableDef(), where: []schema.Predicate{ schema.And( users.Email.EqExpr(schema.Placeholder("email")), @@ -468,7 +468,7 @@ func TestNewOperatorsSQL(t *testing.T) { name: "Exists", expr: schema.Exists(&SelectQuery{ dialect: d, - table: tableDefSource{table: users.TableDef()}, + table: users.TableDef(), where: []schema.Predicate{users.ID.Eq(1)}, }), wantSQL: `EXISTS (SELECT * FROM "users" WHERE "users"."id" = $1)`, @@ -478,7 +478,7 @@ func TestNewOperatorsSQL(t *testing.T) { name: "NotExists", expr: schema.NotExists(&SelectQuery{ dialect: d, - table: tableDefSource{table: users.TableDef()}, + table: users.TableDef(), where: []schema.Predicate{users.ID.Eq(1)}, }), wantSQL: `NOT EXISTS (SELECT * FROM "users" WHERE "users"."id" = $1)`, @@ -515,10 +515,10 @@ func TestCompoundQueryInternals(t *testing.T) { t.Run("cacheOptions preserved in non-flattening wrapSetOp", func(t *testing.T) { q1 := &SelectQuery{ dialect: d, - table: tableDefSource{table: users.TableDef()}, + table: users.TableDef(), cacheOptions: &queryCacheOptions{ttl: 5 * time.Minute}, } - q2 := &SelectQuery{dialect: d, table: tableDefSource{table: users.TableDef()}} + q2 := &SelectQuery{dialect: d, table: users.TableDef()} union := q1.Union(q2) if union.cacheOptions == nil || union.cacheOptions.ttl != 5*time.Minute { t.Fatalf("expected cacheOptions to propagate, got %#v", union.cacheOptions) @@ -526,9 +526,9 @@ func TestCompoundQueryInternals(t *testing.T) { }) t.Run("cacheOptions preserved in flattening wrapSetOp", func(t *testing.T) { - q1 := &SelectQuery{dialect: d, table: tableDefSource{table: users.TableDef()}} - q2 := &SelectQuery{dialect: d, table: tableDefSource{table: users.TableDef()}} - q3 := &SelectQuery{dialect: d, table: tableDefSource{table: users.TableDef()}} + q1 := &SelectQuery{dialect: d, table: users.TableDef()} + q2 := &SelectQuery{dialect: d, table: users.TableDef()} + q3 := &SelectQuery{dialect: d, table: users.TableDef()} base := q1.Union(q2) base.cacheOptions = &queryCacheOptions{ttl: 5 * time.Minute} union := base.Union(q3) @@ -540,12 +540,12 @@ func TestCompoundQueryInternals(t *testing.T) { t.Run("compileExists on compound query", func(t *testing.T) { q1 := &SelectQuery{ dialect: d, - table: tableDefSource{table: users.TableDef()}, + table: users.TableDef(), where: []schema.Predicate{users.ID.Eq(1)}, } q2 := &SelectQuery{ dialect: d, - table: tableDefSource{table: users.TableDef()}, + table: users.TableDef(), where: []schema.Predicate{users.ID.Eq(2)}, } union := q1.Union(q2) @@ -560,7 +560,7 @@ func TestCompoundQueryInternals(t *testing.T) { }) t.Run("isBareCompound", func(t *testing.T) { - op := &SelectQuery{dialect: d, table: tableDefSource{table: users.TableDef()}} + op := &SelectQuery{dialect: d, table: users.TableDef()} bare := &SelectQuery{dialect: d, firstOperand: op} if !bare.isBareCompound() { t.Fatalf("expected bare compound") diff --git a/pkg/rain/query_delete.go b/pkg/rain/query_delete.go index 6b7590f..6e805ef 100644 --- a/pkg/rain/query_delete.go +++ b/pkg/rain/query_delete.go @@ -18,7 +18,10 @@ type DeleteQuery struct { where []schema.Predicate using []selectTableSource order []schema.OrderExpr - limit *int + limit int + hasLimit bool + offset int + hasOffset bool ctes []cteDefinition returning []schema.Expression unbounded bool @@ -73,7 +76,8 @@ func (q *DeleteQuery) OrderBy(order ...schema.OrderExpr) *DeleteQuery { // Limit sets the LIMIT clause. // Supported by MySQL and SQLite. func (q *DeleteQuery) Limit(limit int) *DeleteQuery { - q.limit = &limit + q.limit = limit + q.hasLimit = true return q } @@ -175,7 +179,7 @@ func (q *DeleteQuery) writeSQL(ctx *compileContext) error { } } - if err := writeOrderLimit(ctx, q.order, q.limit, nil, dialect.FeatureDeleteOrder, dialect.FeatureDeleteLimit); err != nil { + if err := writeOrderLimit(ctx, q.order, q.limit, q.hasLimit, q.offset, q.hasOffset, dialect.FeatureDeleteOrder, dialect.FeatureDeleteLimit); err != nil { return err } diff --git a/pkg/rain/query_select.go b/pkg/rain/query_select.go index ec695d6..38be5fd 100644 --- a/pkg/rain/query_select.go +++ b/pkg/rain/query_select.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "reflect" + "strings" "github.com/hyperlocalise/rain-orm/pkg/dialect" "github.com/hyperlocalise/rain-orm/pkg/schema" @@ -17,7 +18,9 @@ type SelectQuery struct { runner queryRunner dialect dialect.Dialect cache QueryCache - table selectTableSource + table *schema.TableDef + tableSubquery *SelectQuery + tableAlias string cols []schema.Expression where []schema.Predicate joins []joinClause @@ -29,8 +32,10 @@ type SelectQuery struct { setOps []setOperation distinct bool distinctOn []schema.Expression - limit *int - offset *int + limit int + hasLimit bool + offset int + hasOffset bool relationNames []string relationConfigs map[string]RelationConfig cacheOptions *queryCacheOptions @@ -39,7 +44,9 @@ type SelectQuery struct { // Table sets the table source for the query. func (q *SelectQuery) Table(table schema.TableReference) *SelectQuery { - q.table = tableDefSource{table: table.TableDef()} + q.table = table.TableDef() + q.tableSubquery = nil + q.tableAlias = "" return q } @@ -50,7 +57,9 @@ func (q *SelectQuery) From(table schema.TableReference) *SelectQuery { // TableSubquery sets a subquery source for the query's FROM clause. func (q *SelectQuery) TableSubquery(query *SelectQuery, alias string) *SelectQuery { - q.table = subqueryTableSource{query: query, alias: alias} + q.tableSubquery = query + q.tableAlias = alias + q.table = nil return q } @@ -165,13 +174,15 @@ func (q *SelectQuery) OrderBy(order ...schema.OrderExpr) *SelectQuery { // Limit sets the LIMIT clause. func (q *SelectQuery) Limit(limit int) *SelectQuery { - q.limit = &limit + q.limit = limit + q.hasLimit = true return q } // Offset sets the OFFSET clause. func (q *SelectQuery) Offset(offset int) *SelectQuery { - q.offset = &offset + q.offset = offset + q.hasOffset = true return q } @@ -310,6 +321,9 @@ func (q *SelectQuery) CloneForTable(table *schema.TableDef) any { func (q *SelectQuery) clone() *SelectQuery { newQ := *q + if q.tableSubquery != nil { + newQ.tableSubquery = q.tableSubquery.clone() + } newQ.cols = append([]schema.Expression(nil), q.cols...) newQ.where = append([]schema.Predicate(nil), q.where...) newQ.joins = append([]joinClause(nil), q.joins...) @@ -326,14 +340,6 @@ func (q *SelectQuery) clone() *SelectQuery { newQ.relationConfigs[k] = v } } - if q.limit != nil { - l := *q.limit - newQ.limit = &l - } - if q.offset != nil { - o := *q.offset - newQ.offset = &o - } if q.locking != nil { copyLocking := *q.locking copyLocking.of = append([]schema.TableReference(nil), q.locking.of...) @@ -385,8 +391,9 @@ func (q *SelectQuery) withSQLiteInsertSelectConflictWhereChanged() (*SelectQuery func (q *SelectQuery) isBareCompound() bool { return q.firstOperand != nil && - len(q.order) == 0 && q.limit == nil && q.offset == nil && - !q.distinct && len(q.distinctOn) == 0 && len(q.cols) == 0 && q.table == nil && + len(q.order) == 0 && !q.hasLimit && !q.hasOffset && + !q.distinct && len(q.distinctOn) == 0 && len(q.cols) == 0 && + q.table == nil && q.tableSubquery == nil && len(q.where) == 0 && len(q.joins) == 0 && len(q.groupBy) == 0 && len(q.having) == 0 && len(q.relationNames) == 0 && len(q.ctes) == 0 && @@ -469,13 +476,16 @@ func (q *SelectQuery) writeSQL(ctx *compileContext) error { return err } } - if err := writeOrderLimit(ctx, q.order, q.limit, q.offset, dialect.FeatureUnlimited, dialect.FeatureUnlimited); err != nil { + if err := writeOrderLimit(ctx, q.order, q.limit, q.hasLimit, q.offset, q.hasOffset, dialect.FeatureUnlimited, dialect.FeatureUnlimited); err != nil { return err } return q.writeLocking(ctx) } - if q.table == nil { + if q.table == nil && q.tableSubquery == nil { + if q.tableAlias != "" { + return errors.New("rain: subquery table source requires a non-nil query") + } return errors.New("rain: select query requires a table") } @@ -511,7 +521,7 @@ func (q *SelectQuery) writeSQL(ctx *compileContext) error { } ctx.writeString(" FROM ") - if err := q.table.writeSQL(ctx); err != nil { + if err := q.writeTableSourceSQL(ctx); err != nil { return err } @@ -545,7 +555,7 @@ func (q *SelectQuery) writeSQL(ctx *compileContext) error { } } - if err := writeOrderLimit(ctx, q.order, q.limit, q.offset, dialect.FeatureUnlimited, dialect.FeatureUnlimited); err != nil { + if err := writeOrderLimit(ctx, q.order, q.limit, q.hasLimit, q.offset, q.hasOffset, dialect.FeatureUnlimited, dialect.FeatureUnlimited); err != nil { return err } @@ -605,7 +615,7 @@ func (q *SelectQuery) writeCompoundOperandSQL(ctx *compileContext) error { } // Use parentheses if the operand has its own ORDER BY, LIMIT, locking, or is itself a compound query. // Flattening is handled during builder chaining in wrapSetOp. - useParens := len(q.order) > 0 || q.limit != nil || q.offset != nil || q.locking != nil || q.firstOperand != nil + useParens := len(q.order) > 0 || q.hasLimit || q.hasOffset || q.locking != nil || q.firstOperand != nil if useParens { ctx.writeByte('(') } @@ -837,7 +847,7 @@ func (q *SelectQuery) scanValidationTable() *schema.TableDef { if len(q.joins) > 0 { return nil } - return tableDefFromSelectSource(q.table) + return q.table } func (q *SelectQuery) writeCachedSelectResult(ctx context.Context, key string, options *queryCacheOptions, value *cachedSelectRows) error { @@ -886,7 +896,10 @@ func (q *SelectQuery) toAggregateSQL(selection string) (string, []any, error) { } func (q *SelectQuery) compile() (compiledQuery, error) { - if q.table == nil && q.firstOperand == nil { + if q.table == nil && q.tableSubquery == nil && q.firstOperand == nil { + if q.tableAlias != "" { + return compiledQuery{}, errors.New("rain: subquery table source requires a non-nil query") + } return compiledQuery{}, errors.New("rain: select query requires a table") } @@ -904,7 +917,7 @@ func (q *SelectQuery) compile() (compiledQuery, error) { if len(q.cols) > 0 { return compiledQuery{}, errors.New("rain: compound queries do not support Column()") } - if q.table != nil { + if q.table != nil || q.tableSubquery != nil { return compiledQuery{}, errors.New("rain: compound queries do not support Table() (already has operands)") } if len(q.where) > 0 { @@ -939,7 +952,10 @@ func (q *SelectQuery) compileAggregate(selection string) (compiledQuery, error) if q.firstOperand != nil { return compiledQuery{}, errors.New("rain: aggregate helpers do not support compound queries (UNION, etc.); use TableSubquery as a workaround") } - if q.table == nil { + if q.table == nil && q.tableSubquery == nil { + if q.tableAlias != "" { + return compiledQuery{}, errors.New("rain: subquery table source requires a non-nil query") + } return compiledQuery{}, errors.New("rain: select query requires a table") } if len(q.ctes) > 0 { @@ -957,7 +973,7 @@ func (q *SelectQuery) compileAggregate(selection string) (compiledQuery, error) ctx.writeString("SELECT ") ctx.writeString(selection) ctx.writeString(" FROM ") - if err := q.table.writeSQL(ctx); err != nil { + if err := q.writeTableSourceSQL(ctx); err != nil { return compiledQuery{}, err } @@ -986,6 +1002,26 @@ func (q *SelectQuery) compileExists() (compiledQuery, error) { return wrapExistsCompiled(compiled) } +func (q *SelectQuery) writeTableSourceSQL(ctx *compileContext) error { + if q.table != nil { + ctx.writeTable(q.table) + return nil + } + if strings.TrimSpace(q.tableAlias) == "" { + return errors.New("rain: subquery table source requires a non-empty alias") + } + if q.tableSubquery == nil { + return errors.New("rain: subquery table source requires a non-nil query") + } + ctx.writeByte('(') + if err := q.tableSubquery.writeSQL(ctx); err != nil { + return err + } + ctx.writeString(") AS ") + ctx.writeQuotedIdentifier(q.tableAlias) + return nil +} + func wrapExistsCompiled(compiled compiledQuery) (compiledQuery, error) { // NOTE: This shallow copies the input compiledQuery and wraps the SQL. // The argPlan and args slices are shared with the original. This is safe diff --git a/pkg/rain/query_update.go b/pkg/rain/query_update.go index b74bf25..e084e88 100644 --- a/pkg/rain/query_update.go +++ b/pkg/rain/query_update.go @@ -21,7 +21,10 @@ type UpdateQuery struct { where []schema.Predicate from []selectTableSource order []schema.OrderExpr - limit *int + limit int + hasLimit bool + offset int + hasOffset bool ctes []cteDefinition returning []schema.Expression unbounded bool @@ -96,7 +99,8 @@ func (q *UpdateQuery) OrderBy(order ...schema.OrderExpr) *UpdateQuery { // Limit sets the LIMIT clause. // Supported by MySQL and SQLite. func (q *UpdateQuery) Limit(limit int) *UpdateQuery { - q.limit = &limit + q.limit = limit + q.hasLimit = true return q } @@ -211,7 +215,7 @@ func (q *UpdateQuery) writeSQLInternal(ctx *compileContext, assignments []assign if !dialect.HasFeature(ctx.dialect.Features(), dialect.FeatureUpdateFrom) { return fmt.Errorf("rain: UPDATE ... FROM is not supported by %s dialect", ctx.dialect.Name()) } - if ctx.dialect.Name() == "sqlite" && (len(q.order) > 0 || q.limit != nil) { + if ctx.dialect.Name() == "sqlite" && (len(q.order) > 0 || q.hasLimit) { return errors.New("rain: SQLite does not support combining UPDATE ... FROM with ORDER BY or LIMIT") } ctx.writeString(" FROM ") @@ -232,7 +236,7 @@ func (q *UpdateQuery) writeSQLInternal(ctx *compileContext, assignments []assign } } - if err := writeOrderLimit(ctx, q.order, q.limit, nil, dialect.FeatureUpdateOrder, dialect.FeatureUpdateLimit); err != nil { + if err := writeOrderLimit(ctx, q.order, q.limit, q.hasLimit, q.offset, q.hasOffset, dialect.FeatureUpdateOrder, dialect.FeatureUpdateLimit); err != nil { return err } diff --git a/pkg/rain/relation_loading.go b/pkg/rain/relation_loading.go index 0ab7978..a6a9243 100644 --- a/pkg/rain/relation_loading.go +++ b/pkg/rain/relation_loading.go @@ -37,12 +37,11 @@ func (q *SelectQuery) scanRowsWithRelations(ctx context.Context, rows *sql.Rows, return fmt.Errorf("rain: destination must point to a struct or slice") } - tableSource, ok := q.table.(tableDefSource) - if !ok { + if q.table == nil { return fmt.Errorf("rain: relation loading requires a concrete table source") } - relationTree, err := buildRelationLoadTree(tableSource.table, q.relationNames, q.relationConfigs) + relationTree, err := buildRelationLoadTree(q.table, q.relationNames, q.relationConfigs) if err != nil { return err } @@ -54,7 +53,7 @@ func (q *SelectQuery) scanRowsWithRelations(ctx context.Context, rows *sql.Rows, containerPtr = slicePtr.Interface() } - if err := scanRowsAgainstTable(rows, containerPtr, tableSource.table); err != nil { + if err := scanRowsAgainstTable(rows, containerPtr, q.table); err != nil { return err } @@ -257,7 +256,7 @@ func (q *SelectQuery) loadRelatedManyToManyRows( batchKeys := sourceKeys[start:end] var batchPairs []pair - joinQuery := &SelectQuery{runner: q.runner, dialect: q.dialect, table: tableDefSource{table: relation.JoinTable}} + joinQuery := &SelectQuery{runner: q.runner, dialect: q.dialect, table: relation.JoinTable} if err := joinQuery. Column(schema.Ref(relation.JoinSourceColumn).As("s"), schema.Ref(relation.JoinTargetColumn).As("t")). Where(schema.Ref(relation.JoinSourceColumn).In(batchKeys...)). @@ -294,7 +293,7 @@ func (q *SelectQuery) loadRelatedManyToManyRows( batchKeys := uniqueTargetKeys[start:end] batchDest := reflect.New(reflect.SliceOf(relatedElemType)) - targetQuery := &SelectQuery{runner: q.runner, dialect: q.dialect, table: tableDefSource{table: relation.TargetTable}} + targetQuery := &SelectQuery{runner: q.runner, dialect: q.dialect, table: relation.TargetTable} if config.Where != nil { targetQuery.Where(config.Where) } @@ -375,7 +374,7 @@ func (q *SelectQuery) loadRelatedRows( for start := 0; start < len(sourceKeys); start += relationBatchSize { end := min(start+relationBatchSize, len(sourceKeys)) batchDest := reflect.New(reflect.SliceOf(relatedElemType)) - query := &SelectQuery{runner: q.runner, dialect: q.dialect, table: tableDefSource{table: relation.TargetTable}} + query := &SelectQuery{runner: q.runner, dialect: q.dialect, table: relation.TargetTable} if config.Where != nil { query.Where(config.Where) }