Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 6 additions & 13 deletions pkg/rain/query_common.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,6 @@ type tableDefSource struct {
table *schema.TableDef
}

func tableDefFromSelectSource(source selectTableSource) *schema.TableDef {
if table, ok := source.(tableDefSource); ok {
return table.table
}
return nil
}

func (s tableDefSource) writeSQL(ctx *compileContext) error {
ctx.writeTable(s.table)
return nil
Expand Down Expand Up @@ -121,7 +114,7 @@ func writeCTEs(ctx *compileContext, ctes []cteDefinition, label string) error {
return nil
}

func writeOrderLimit(ctx *compileContext, order []schema.OrderExpr, limit *int, offset *int, featureOrder, featureLimit dialect.Feature) error {
func writeOrderLimit(ctx *compileContext, order []schema.OrderExpr, limit int, hasLimit bool, offset int, hasOffset bool, featureOrder, featureLimit dialect.Feature) error {
if len(order) > 0 {
if featureOrder != dialect.FeatureUnlimited && !dialect.HasFeature(ctx.dialect.Features(), featureOrder) {
return fmt.Errorf("rain: ORDER BY is not supported for this query type in %s dialect", ctx.dialect.Name())
Expand All @@ -146,20 +139,20 @@ func writeOrderLimit(ctx *compileContext, order []schema.OrderExpr, limit *int,
}
}

if limit != nil || (offset != nil && *offset > 0) {
if hasLimit || hasOffset {
if featureLimit != dialect.FeatureUnlimited && !dialect.HasFeature(ctx.dialect.Features(), featureLimit) {
return fmt.Errorf("rain: LIMIT/OFFSET is not supported for this query type in %s dialect", ctx.dialect.Name())
}
l := -1
if limit != nil {
l = *limit
if hasLimit {
l = limit
if l < 0 {
return errors.New("rain: LIMIT must be non-negative")
}
}
o := 0
if offset != nil {
o = *offset
if hasOffset {
o = offset
if o < 0 {
return errors.New("rain: OFFSET must be non-negative")
}
Expand Down
11 changes: 2 additions & 9 deletions pkg/rain/query_common_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,10 @@ func TestQueryCommonHelpers(t *testing.T) {

users, _ := defineInternalQueryTables()

if got := tableDefFromSelectSource(tableDefSource{table: users.TableDef()}); got != users.TableDef() {
t.Fatalf("expected tableDefFromSelectSource to return the table, got %#v", got)
}
if got := tableDefFromSelectSource(subqueryTableSource{}); got != nil {
t.Fatalf("expected non-table select source to return nil, got %#v", got)
}

t.Run("SubqueryAliasValidation", func(t *testing.T) {
ctx := newCompileContext(dialectForTest(t, "postgres"))
defer releaseCompileContext(ctx)
if err := (subqueryTableSource{alias: " ", query: &SelectQuery{dialect: ctx.dialect, table: tableDefSource{table: users.TableDef()}}}).writeSQL(ctx); err == nil || !strings.Contains(err.Error(), "non-empty alias") {
if err := (subqueryTableSource{alias: " ", query: &SelectQuery{dialect: ctx.dialect, table: users.TableDef()}}).writeSQL(ctx); err == nil || !strings.Contains(err.Error(), "non-empty alias") {
t.Fatalf("expected empty alias error, got %v", err)
}
})
Expand All @@ -49,7 +42,7 @@ func TestQueryCommonHelpers(t *testing.T) {
alias: "u",
query: &SelectQuery{
dialect: ctx.dialect,
table: tableDefSource{table: users.TableDef()},
table: users.TableDef(),
cols: []schema.Expression{users.ID},
},
}).writeSQL(ctx)
Expand Down
26 changes: 13 additions & 13 deletions pkg/rain/query_compile_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,14 @@ func TestQueryBuilderAndHelperErrors(t *testing.T) {
if _, _, err := db.Select().ToSQL(); err == nil || !strings.Contains(err.Error(), "requires a table") {
t.Fatalf("expected select table error, got %v", err)
}
selectNoRunner := &SelectQuery{dialect: db.Dialect(), table: tableDefSource{table: users.TableDef()}}
selectNoRunner := &SelectQuery{dialect: db.Dialect(), table: users.TableDef()}
if err := selectNoRunner.Scan(context.Background(), &internalUserRow{}); !errors.Is(err, ErrNoConnection) {
t.Fatalf("expected select scan ErrNoConnection, got %v", err)
}
if _, err := selectNoRunner.Prepare(context.Background()); !errors.Is(err, ErrNoConnection) {
t.Fatalf("expected select prepare ErrNoConnection, got %v", err)
}
selectUnsupportedPrepare := &SelectQuery{runner: &countingRunner{}, dialect: db.Dialect(), table: tableDefSource{table: users.TableDef()}}
selectUnsupportedPrepare := &SelectQuery{runner: &countingRunner{}, dialect: db.Dialect(), table: users.TableDef()}
if _, err := selectUnsupportedPrepare.Prepare(context.Background()); !errors.Is(err, ErrPrepareNotSupported) {
t.Fatalf("expected select prepare ErrPrepareNotSupported, got %v", err)
}
Expand Down Expand Up @@ -303,7 +303,7 @@ func TestCompiledQueryBindValidation(t *testing.T) {
users, _ := defineInternalQueryTables()
compiled, err := (&SelectQuery{
dialect: dialectForTest(t, "postgres"),
table: tableDefSource{table: users.TableDef()},
table: users.TableDef(),
where: []schema.Predicate{
schema.And(
users.Email.EqExpr(schema.Placeholder("email")),
Expand Down Expand Up @@ -468,7 +468,7 @@ func TestNewOperatorsSQL(t *testing.T) {
name: "Exists",
expr: schema.Exists(&SelectQuery{
dialect: d,
table: tableDefSource{table: users.TableDef()},
table: users.TableDef(),
where: []schema.Predicate{users.ID.Eq(1)},
}),
wantSQL: `EXISTS (SELECT * FROM "users" WHERE "users"."id" = $1)`,
Expand All @@ -478,7 +478,7 @@ func TestNewOperatorsSQL(t *testing.T) {
name: "NotExists",
expr: schema.NotExists(&SelectQuery{
dialect: d,
table: tableDefSource{table: users.TableDef()},
table: users.TableDef(),
where: []schema.Predicate{users.ID.Eq(1)},
}),
wantSQL: `NOT EXISTS (SELECT * FROM "users" WHERE "users"."id" = $1)`,
Expand Down Expand Up @@ -515,20 +515,20 @@ func TestCompoundQueryInternals(t *testing.T) {
t.Run("cacheOptions preserved in non-flattening wrapSetOp", func(t *testing.T) {
q1 := &SelectQuery{
dialect: d,
table: tableDefSource{table: users.TableDef()},
table: users.TableDef(),
cacheOptions: &queryCacheOptions{ttl: 5 * time.Minute},
}
q2 := &SelectQuery{dialect: d, table: tableDefSource{table: users.TableDef()}}
q2 := &SelectQuery{dialect: d, table: users.TableDef()}
union := q1.Union(q2)
if union.cacheOptions == nil || union.cacheOptions.ttl != 5*time.Minute {
t.Fatalf("expected cacheOptions to propagate, got %#v", union.cacheOptions)
}
})

t.Run("cacheOptions preserved in flattening wrapSetOp", func(t *testing.T) {
q1 := &SelectQuery{dialect: d, table: tableDefSource{table: users.TableDef()}}
q2 := &SelectQuery{dialect: d, table: tableDefSource{table: users.TableDef()}}
q3 := &SelectQuery{dialect: d, table: tableDefSource{table: users.TableDef()}}
q1 := &SelectQuery{dialect: d, table: users.TableDef()}
q2 := &SelectQuery{dialect: d, table: users.TableDef()}
q3 := &SelectQuery{dialect: d, table: users.TableDef()}
base := q1.Union(q2)
base.cacheOptions = &queryCacheOptions{ttl: 5 * time.Minute}
union := base.Union(q3)
Expand All @@ -540,12 +540,12 @@ func TestCompoundQueryInternals(t *testing.T) {
t.Run("compileExists on compound query", func(t *testing.T) {
q1 := &SelectQuery{
dialect: d,
table: tableDefSource{table: users.TableDef()},
table: users.TableDef(),
where: []schema.Predicate{users.ID.Eq(1)},
}
q2 := &SelectQuery{
dialect: d,
table: tableDefSource{table: users.TableDef()},
table: users.TableDef(),
where: []schema.Predicate{users.ID.Eq(2)},
}
union := q1.Union(q2)
Expand All @@ -560,7 +560,7 @@ func TestCompoundQueryInternals(t *testing.T) {
})

t.Run("isBareCompound", func(t *testing.T) {
op := &SelectQuery{dialect: d, table: tableDefSource{table: users.TableDef()}}
op := &SelectQuery{dialect: d, table: users.TableDef()}
bare := &SelectQuery{dialect: d, firstOperand: op}
if !bare.isBareCompound() {
t.Fatalf("expected bare compound")
Expand Down
10 changes: 7 additions & 3 deletions pkg/rain/query_delete.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@ type DeleteQuery struct {
where []schema.Predicate
using []selectTableSource
order []schema.OrderExpr
limit *int
limit int
hasLimit bool
offset int
hasOffset bool
ctes []cteDefinition
returning []schema.Expression
unbounded bool
Expand Down Expand Up @@ -73,7 +76,8 @@ func (q *DeleteQuery) OrderBy(order ...schema.OrderExpr) *DeleteQuery {
// Limit sets the LIMIT clause.
// Supported by MySQL and SQLite.
func (q *DeleteQuery) Limit(limit int) *DeleteQuery {
q.limit = &limit
q.limit = limit
q.hasLimit = true
return q
}

Expand Down Expand Up @@ -175,7 +179,7 @@ func (q *DeleteQuery) writeSQL(ctx *compileContext) error {
}
}

if err := writeOrderLimit(ctx, q.order, q.limit, nil, dialect.FeatureDeleteOrder, dialect.FeatureDeleteLimit); err != nil {
if err := writeOrderLimit(ctx, q.order, q.limit, q.hasLimit, q.offset, q.hasOffset, dialect.FeatureDeleteOrder, dialect.FeatureDeleteLimit); err != nil {
return err
}

Expand Down
Loading