Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 132 additions & 0 deletions pkg/rain/expressions_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
package rain

import (
"testing"

"github.com/hyperlocalise/rain-orm/pkg/schema"
)

func TestBinaryAndConcatExpressionsToSQL(t *testing.T) {
type UsersTable struct {
schema.TableModel
ID *schema.Column[int64]
Age *schema.Column[int32]
Name *schema.Column[string]
Score *schema.Column[float64]
}

Users := schema.Define("users", func(t *UsersTable) {
t.ID = t.BigSerial("id").PrimaryKey()
t.Age = t.Integer("age").NotNull()
t.Name = t.Text("name").NotNull()
t.Score = t.Double("score").NotNull()
})

t.Run("Arithmetic", func(t *testing.T) {
db := MustOpenDialect("postgres")

tests := []struct {
name string
expr schema.Expression
wantSQL string
wantArgs []any
}{
{
name: "Add",
expr: Users.Age.Add(int32(10)),
wantSQL: `("users"."age" + $1)`,
wantArgs: []any{int32(10)},
},
{
name: "Sub",
expr: Users.Age.Sub(int32(5)),
wantSQL: `("users"."age" - $1)`,
wantArgs: []any{int32(5)},
},
{
name: "Mul",
expr: Users.Score.Mul(1.5),
wantSQL: `("users"."score" * $1)`,
wantArgs: []any{1.5},
},
{
name: "Div",
expr: Users.Score.Div(2.0),
wantSQL: `("users"."score" / $1)`,
wantArgs: []any{2.0},
},
{
name: "Mod",
expr: Users.Age.Mod(int32(2)),
wantSQL: `("users"."age" % $1)`,
wantArgs: []any{int32(2)},
},
{
name: "NestedArithmetic",
expr: Users.Age.Add(int32(10)).Mul(int32(2)),
wantSQL: `(("users"."age" + $1) * $2)`,
wantArgs: []any{int32(10), int32(2)},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
query := db.Select(tt.expr).From(Users)
gotSQL, gotArgs, err := query.ToSQL()
if err != nil {
t.Fatalf("ToSQL() error = %v", err)
}
expectedSQL := "SELECT " + tt.wantSQL + " FROM \"users\""
if gotSQL != expectedSQL {
t.Errorf("got SQL %q, want %q", gotSQL, expectedSQL)
}
if len(gotArgs) != len(tt.wantArgs) {
t.Errorf("got %d args, want %d", len(gotArgs), len(tt.wantArgs))
}
for i := range gotArgs {
if gotArgs[i] != tt.wantArgs[i] {
t.Errorf("arg %d: got %v, want %v", i, gotArgs[i], tt.wantArgs[i])
}
}
})
}
})

t.Run("Concat", func(t *testing.T) {
postgres := MustOpenDialect("postgres")
mysql := MustOpenDialect("mysql")

tests := []struct {
name string
db *DB
expr schema.Expression
wantSQL string
}{
{
name: "PostgresConcat",
db: postgres,
expr: schema.Concat(Users.Name, " (", Users.Age, ")"),
wantSQL: `SELECT ("users"."name" || $1 || "users"."age" || $2) FROM "users"`,
},
{
name: "MySQLConcat",
db: mysql,
expr: schema.Concat(Users.Name, " (", Users.Age, ")"),
wantSQL: "SELECT CONCAT(`users`.`name`, ?, `users`.`age`, ?) FROM `users`",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
query := tt.db.Select(tt.expr).From(Users)
gotSQL, _, err := query.ToSQL()
if err != nil {
t.Fatalf("ToSQL() error = %v", err)
}
if gotSQL != tt.wantSQL {
t.Errorf("got SQL %q, want %q", gotSQL, tt.wantSQL)
}
})
}
})
}
54 changes: 54 additions & 0 deletions pkg/rain/query_compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,29 @@ func (c *compileContext) writeExpressionInContext(expr schema.Expression, contex
c.args = append(c.args, nil)
c.writeString(c.dialect.Placeholder(index))
case schema.ComparisonExpr:
switch value.Operator {
case "=", "<>", ">", ">=", "<", "<=", "LIKE", "NOT LIKE", "ILIKE", "NOT ILIKE":
// ok
default:
return fmt.Errorf("rain: invalid comparison operator %q", value.Operator)
}
if err := c.writeExpression(value.Left); err != nil {
return err
}
c.writeByte(' ')
c.writeString(value.Operator)
c.writeByte(' ')
if err := c.writeExpression(value.Right); err != nil {
return err
}
case schema.BinaryExpr:
switch value.Operator {
case "+", "-", "*", "/", "%":
// ok
default:
return fmt.Errorf("rain: invalid binary operator %q", value.Operator)
}
c.writeByte('(')
if err := c.writeExpression(value.Left); err != nil {
return err
}
Expand All @@ -379,6 +402,37 @@ func (c *compileContext) writeExpressionInContext(expr schema.Expression, contex
if err := c.writeExpression(value.Right); err != nil {
return err
}
c.writeByte(')')
case schema.ConcatExpr:
if len(value.Exprs) < 2 {
return errors.New("rain: CONCAT requires at least two expressions")
}
switch c.dialect.Name() {
case "postgres", "sqlite":
c.writeByte('(')
for idx, expr := range value.Exprs {
if idx > 0 {
c.writeString(" || ")
}
if err := c.writeExpression(expr); err != nil {
return err
}
}
c.writeByte(')')
case "mysql":
c.writeString("CONCAT(")
for idx, expr := range value.Exprs {
if idx > 0 {
c.writeString(", ")
}
if err := c.writeExpression(expr); err != nil {
return err
}
}
c.writeByte(')')
default:
return fmt.Errorf("rain: CONCAT is not implemented for %s dialect", c.dialect.Name())
}
case schema.InExpr:
if len(value.Values) == 0 {
return errors.New("rain: IN predicate requires at least one value")
Expand Down
87 changes: 87 additions & 0 deletions pkg/rain/sqlite_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1955,3 +1955,90 @@ func TestSQLiteIntegrationFirst(t *testing.T) {
}
})
}

func TestSQLiteIntegrationArithmeticAndConcat(t *testing.T) {
t.Parallel()

ctx := context.Background()
db := openSQLiteTestDB(t)
users, _, _ := defineSQLiteTables()
createSQLiteSchema(t, ctx, db)

// Seed data
if _, err := db.Insert().Table(users).Values(
map[schema.ColumnReference]any{users.ID: 10, users.Email: "alice@example.com", users.Name: "Alice", users.Active: true},
).Exec(ctx); err != nil {
t.Fatalf("seed data failed: %v", err)
}

t.Run("Arithmetic", func(t *testing.T) {
var results []struct {
Val int64 `db:"val"`
}

// Add: 10 + 5 = 15
if err := db.Select(users.ID.Add(int64(5)).As("val")).From(users).Scan(ctx, &results); err != nil {
t.Fatalf("Add failed: %v", err)
}
if len(results) == 0 || results[0].Val != 15 {
t.Fatalf("expected 15, got %+v", results)
}

// Sub: 10 - 3 = 5
if err := db.Select(users.ID.Sub(int64(3)).As("val")).From(users).Scan(ctx, &results); err != nil {
t.Fatalf("Sub failed: %v", err)
}
if len(results) == 0 || results[0].Val != 7 {
t.Fatalf("expected 7, got %+v", results)
}

// Mul: 10 * 2 = 20
if err := db.Select(users.ID.Mul(int64(2)).As("val")).From(users).Scan(ctx, &results); err != nil {
t.Fatalf("Mul failed: %v", err)
}
if len(results) == 0 || results[0].Val != 20 {
t.Fatalf("expected 20, got %+v", results)
}

// Div: 10 / 3 = 3 (integer division)
if err := db.Select(users.ID.Div(int64(3)).As("val")).From(users).Scan(ctx, &results); err != nil {
t.Fatalf("Div failed: %v", err)
}
if len(results) == 0 || results[0].Val != 3 {
t.Fatalf("expected 3, got %+v", results)
}

// Mod: 10 % 3 = 1
if err := db.Select(users.ID.Mod(int64(3)).As("val")).From(users).Scan(ctx, &results); err != nil {
t.Fatalf("Mod failed: %v", err)
}
if len(results) == 0 || results[0].Val != 1 {
t.Fatalf("expected 1, got %+v", results)
}

// Nested: (10 + 10) * 2 = 40
if err := db.Select(users.ID.Add(int64(10)).Mul(int64(2)).As("val")).From(users).Scan(ctx, &results); err != nil {
t.Fatalf("Nested arithmetic failed: %v", err)
}
if len(results) == 0 || results[0].Val != 40 {
t.Fatalf("expected 40, got %+v", results)
}
})

t.Run("Concat", func(t *testing.T) {
var results []struct {
Val string `db:"val"`
}
if err := db.Select(schema.Concat(users.Name, " (", users.Email, ")").As("val")).From(users).Scan(ctx, &results); err != nil {
t.Fatalf("Concat failed: %v", err)
}
if len(results) == 0 {
t.Fatalf("no results for concat")
}
result := results[0].Val
expected := "Alice (alice@example.com)"
if result != expected {
t.Fatalf("expected %q, got %q", expected, result)
}
})
}
Loading