From f55b99c5641fe6785379e40c007f41af851a2ff3 Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Wed, 17 Jun 2026 22:51:31 +0000 Subject: [PATCH 1/2] feat(rain): add support for arithmetic and concat expressions Implement arithmetic operators (+, -, *, /, %) and string concatenation in the Rain ORM query builder. - Add BinaryExpr and ConcatExpr to pkg/schema - Add Add, Sub, Mul, Div, Mod methods to Column[T] and AnyColumn - Support expression chaining (e.g., col.Add(1).Mul(2)) - Implement dialect-aware concatenation (|| for Postgres/SQLite, CONCAT() for MySQL) - Support aliasing via .As() on new expression types - Update query compiler to render expressions with correct precedence - Add unit and SQLite integration tests Co-authored-by: cungminh2710 <8063319+cungminh2710@users.noreply.github.com> --- pkg/rain/expressions_test.go | 132 +++++++++++++++++++++++ pkg/rain/query_compile.go | 42 ++++++++ pkg/rain/sqlite_integration_test.go | 53 ++++++++++ pkg/schema/schema.go | 159 ++++++++++++++++++++++++++++ 4 files changed, 386 insertions(+) create mode 100644 pkg/rain/expressions_test.go diff --git a/pkg/rain/expressions_test.go b/pkg/rain/expressions_test.go new file mode 100644 index 0000000..f7f378b --- /dev/null +++ b/pkg/rain/expressions_test.go @@ -0,0 +1,132 @@ +package rain + +import ( + "testing" + + "github.com/hyperlocalise/rain-orm/pkg/schema" +) + +func TestBinaryAndConcatExpressionsToSQL(t *testing.T) { + type UsersTable struct { + schema.TableModel + ID *schema.Column[int64] + Age *schema.Column[int32] + Name *schema.Column[string] + Score *schema.Column[float64] + } + + Users := schema.Define("users", func(t *UsersTable) { + t.ID = t.BigSerial("id").PrimaryKey() + t.Age = t.Integer("age").NotNull() + t.Name = t.Text("name").NotNull() + t.Score = t.Double("score").NotNull() + }) + + t.Run("Arithmetic", func(t *testing.T) { + db := MustOpenDialect("postgres") + + tests := []struct { + name string + expr schema.Expression + wantSQL string + wantArgs []any + }{ + { + name: "Add", + expr: Users.Age.Add(int32(10)), + wantSQL: `("users"."age" + $1)`, + wantArgs: []any{int32(10)}, + }, + { + name: "Sub", + expr: Users.Age.Sub(int32(5)), + wantSQL: `("users"."age" - $1)`, + wantArgs: []any{int32(5)}, + }, + { + name: "Mul", + expr: Users.Score.Mul(1.5), + wantSQL: `("users"."score" * $1)`, + wantArgs: []any{1.5}, + }, + { + name: "Div", + expr: Users.Score.Div(2.0), + wantSQL: `("users"."score" / $1)`, + wantArgs: []any{2.0}, + }, + { + name: "Mod", + expr: Users.Age.Mod(int32(2)), + wantSQL: `("users"."age" % $1)`, + wantArgs: []any{int32(2)}, + }, + { + name: "NestedArithmetic", + expr: Users.Age.Add(int32(10)).Mul(int32(2)), + wantSQL: `(("users"."age" + $1) * $2)`, + wantArgs: []any{int32(10), int32(2)}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + query := db.Select(tt.expr).From(Users) + gotSQL, gotArgs, err := query.ToSQL() + if err != nil { + t.Fatalf("ToSQL() error = %v", err) + } + expectedSQL := "SELECT " + tt.wantSQL + " FROM \"users\"" + if gotSQL != expectedSQL { + t.Errorf("got SQL %q, want %q", gotSQL, expectedSQL) + } + if len(gotArgs) != len(tt.wantArgs) { + t.Errorf("got %d args, want %d", len(gotArgs), len(tt.wantArgs)) + } + for i := range gotArgs { + if gotArgs[i] != tt.wantArgs[i] { + t.Errorf("arg %d: got %v, want %v", i, gotArgs[i], tt.wantArgs[i]) + } + } + }) + } + }) + + t.Run("Concat", func(t *testing.T) { + postgres := MustOpenDialect("postgres") + mysql := MustOpenDialect("mysql") + + tests := []struct { + name string + db *DB + expr schema.Expression + wantSQL string + }{ + { + name: "PostgresConcat", + db: postgres, + expr: schema.Concat(Users.Name, " (", Users.Age, ")"), + wantSQL: `SELECT ("users"."name" || $1 || "users"."age" || $2) FROM "users"`, + }, + { + name: "MySQLConcat", + db: mysql, + expr: schema.Concat(Users.Name, " (", Users.Age, ")"), + wantSQL: "SELECT CONCAT(`users`.`name`, ?, `users`.`age`, ?) FROM `users`", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + query := tt.db.Select(tt.expr).From(Users) + gotSQL, _, err := query.ToSQL() + if err != nil { + t.Fatalf("ToSQL() error = %v", err) + } + if gotSQL != tt.wantSQL { + t.Errorf("got SQL %q, want %q", gotSQL, tt.wantSQL) + } + }) + } + }) +} diff --git a/pkg/rain/query_compile.go b/pkg/rain/query_compile.go index 9dca1ec..69a901a 100644 --- a/pkg/rain/query_compile.go +++ b/pkg/rain/query_compile.go @@ -379,6 +379,48 @@ func (c *compileContext) writeExpressionInContext(expr schema.Expression, contex if err := c.writeExpression(value.Right); err != nil { return err } + case schema.BinaryExpr: + c.writeByte('(') + if err := c.writeExpression(value.Left); err != nil { + return err + } + c.writeByte(' ') + c.writeString(value.Operator) + c.writeByte(' ') + if err := c.writeExpression(value.Right); err != nil { + return err + } + c.writeByte(')') + case schema.ConcatExpr: + if len(value.Exprs) < 2 { + return errors.New("rain: CONCAT requires at least two expressions") + } + switch c.dialect.Name() { + case "postgres", "sqlite": + c.writeByte('(') + for idx, expr := range value.Exprs { + if idx > 0 { + c.writeString(" || ") + } + if err := c.writeExpression(expr); err != nil { + return err + } + } + c.writeByte(')') + case "mysql": + c.writeString("CONCAT(") + for idx, expr := range value.Exprs { + if idx > 0 { + c.writeString(", ") + } + if err := c.writeExpression(expr); err != nil { + return err + } + } + c.writeByte(')') + default: + return fmt.Errorf("rain: CONCAT is not implemented for %s dialect", c.dialect.Name()) + } case schema.InExpr: if len(value.Values) == 0 { return errors.New("rain: IN predicate requires at least one value") diff --git a/pkg/rain/sqlite_integration_test.go b/pkg/rain/sqlite_integration_test.go index 286ca54..1a3dc73 100644 --- a/pkg/rain/sqlite_integration_test.go +++ b/pkg/rain/sqlite_integration_test.go @@ -1931,6 +1931,59 @@ func TestSQLiteIntegrationFirst(t *testing.T) { t.Fatalf("expected bob@example.com, got %q", row.Email) } }) +} + +func TestSQLiteIntegrationArithmeticAndConcat(t *testing.T) { + t.Parallel() + + ctx := context.Background() + db := openSQLiteTestDB(t) + users, _, _ := defineSQLiteTables() + createSQLiteSchema(t, ctx, db) + + // Seed data + if _, err := db.Insert().Table(users).Values( + map[schema.ColumnReference]any{users.ID: 1, users.Email: "alice@example.com", users.Name: "Alice", users.Active: true}, + ).Exec(ctx); err != nil { + t.Fatalf("seed data failed: %v", err) + } + + t.Run("Arithmetic", func(t *testing.T) { + var results []struct { + Val int64 `db:"val"` + } + if err := db.Select(users.ID.Add(int64(10)).As("val")).From(users).Scan(ctx, &results); err != nil { + t.Fatalf("Add failed: %v", err) + } + if len(results) == 0 || results[0].Val != 11 { + t.Fatalf("expected 11, got %+v", results) + } + + results = nil + if err := db.Select(users.ID.Add(int64(10)).Mul(int64(2)).As("val")).From(users).Scan(ctx, &results); err != nil { + t.Fatalf("Nested arithmetic failed: %v", err) + } + if len(results) == 0 || results[0].Val != 22 { + t.Fatalf("expected 22, got %+v", results) + } + }) + + t.Run("Concat", func(t *testing.T) { + var results []struct { + Val string `db:"val"` + } + if err := db.Select(schema.Concat(users.Name, " (", users.Email, ")").As("val")).From(users).Scan(ctx, &results); err != nil { + t.Fatalf("Concat failed: %v", err) + } + if len(results) == 0 { + t.Fatalf("no results for concat") + } + result := results[0].Val + expected := "Alice (alice@example.com)" + if result != expected { + t.Fatalf("expected %q, got %q", expected, result) + } + }) t.Run("ErrNoRows", func(t *testing.T) { var row sqliteUserRow diff --git a/pkg/schema/schema.go b/pkg/schema/schema.go index 61dd81a..5f5e676 100644 --- a/pkg/schema/schema.go +++ b/pkg/schema/schema.go @@ -609,6 +609,31 @@ func (c *AnyColumn) InSubquery(subquery Expression) InExpr { return InExpr{Left: c, Values: []Expression{subquery}} } +// Add adds a value or expression to this column. +func (c *AnyColumn) Add(value any) BinaryExpr { + return BinaryExpr{Left: c, Operator: "+", Right: wrapValue(value)} +} + +// Sub subtracts a value or expression from this column. +func (c *AnyColumn) Sub(value any) BinaryExpr { + return BinaryExpr{Left: c, Operator: "-", Right: wrapValue(value)} +} + +// Mul multiplies this column by a value or expression. +func (c *AnyColumn) Mul(value any) BinaryExpr { + return BinaryExpr{Left: c, Operator: "*", Right: wrapValue(value)} +} + +// Div divides this column by a value or expression. +func (c *AnyColumn) Div(value any) BinaryExpr { + return BinaryExpr{Left: c, Operator: "/", Right: wrapValue(value)} +} + +// Mod calculates the remainder of this column divided by a value or expression. +func (c *AnyColumn) Mod(value any) BinaryExpr { + return BinaryExpr{Left: c, Operator: "%", Right: wrapValue(value)} +} + // NotInSubquery compares this column to the result of a subquery. func (c *AnyColumn) NotInSubquery(subquery Expression) InExpr { return InExpr{Left: c, Values: []Expression{subquery}, Negated: true} @@ -861,6 +886,56 @@ func (c *Column[T]) InSubquery(subquery Expression) InExpr { return InExpr{Left: c, Values: []Expression{subquery}} } +// Add adds a value to this column. +func (c *Column[T]) Add(value T) BinaryExpr { + return BinaryExpr{Left: c, Operator: "+", Right: ValueExpr{Value: value}} +} + +// AddExpr adds a SQL expression to this column. +func (c *Column[T]) AddExpr(expr Expression) BinaryExpr { + return BinaryExpr{Left: c, Operator: "+", Right: expr} +} + +// Sub subtracts a value from this column. +func (c *Column[T]) Sub(value T) BinaryExpr { + return BinaryExpr{Left: c, Operator: "-", Right: ValueExpr{Value: value}} +} + +// SubExpr subtracts a SQL expression from this column. +func (c *Column[T]) SubExpr(expr Expression) BinaryExpr { + return BinaryExpr{Left: c, Operator: "-", Right: expr} +} + +// Mul multiplies this column by a value. +func (c *Column[T]) Mul(value T) BinaryExpr { + return BinaryExpr{Left: c, Operator: "*", Right: ValueExpr{Value: value}} +} + +// MulExpr multiplies this column by a SQL expression. +func (c *Column[T]) MulExpr(expr Expression) BinaryExpr { + return BinaryExpr{Left: c, Operator: "*", Right: expr} +} + +// Div divides this column by a value. +func (c *Column[T]) Div(value T) BinaryExpr { + return BinaryExpr{Left: c, Operator: "/", Right: ValueExpr{Value: value}} +} + +// DivExpr divides this column by a SQL expression. +func (c *Column[T]) DivExpr(expr Expression) BinaryExpr { + return BinaryExpr{Left: c, Operator: "/", Right: expr} +} + +// Mod calculates the remainder of this column divided by a value. +func (c *Column[T]) Mod(value T) BinaryExpr { + return BinaryExpr{Left: c, Operator: "%", Right: ValueExpr{Value: value}} +} + +// ModExpr calculates the remainder of this column divided by a SQL expression. +func (c *Column[T]) ModExpr(expr Expression) BinaryExpr { + return BinaryExpr{Left: c, Operator: "%", Right: expr} +} + // NotInSubquery compares this column to the result of a subquery. func (c *Column[T]) NotInSubquery(subquery Expression) InExpr { return InExpr{Left: c, Values: []Expression{subquery}, Negated: true} @@ -974,6 +1049,57 @@ type NullCheckExpr struct { func (NullCheckExpr) isExpression() {} func (NullCheckExpr) isPredicate() {} +// BinaryExpr represents an arithmetic operation. +type BinaryExpr struct { + Left Expression + Operator string + Right Expression +} + +func (BinaryExpr) isExpression() {} + +// As aliases this binary expression in a SELECT list. +func (b BinaryExpr) As(alias string) AliasExpr { + return As(b, alias) +} + +// Add adds a value or expression to this binary expression. +func (b BinaryExpr) Add(value any) BinaryExpr { + return BinaryExpr{Left: b, Operator: "+", Right: wrapValue(value)} +} + +// Sub subtracts a value or expression from this binary expression. +func (b BinaryExpr) Sub(value any) BinaryExpr { + return BinaryExpr{Left: b, Operator: "-", Right: wrapValue(value)} +} + +// Mul multiplies this binary expression by a value or expression. +func (b BinaryExpr) Mul(value any) BinaryExpr { + return BinaryExpr{Left: b, Operator: "*", Right: wrapValue(value)} +} + +// Div divides this binary expression by a value or expression. +func (b BinaryExpr) Div(value any) BinaryExpr { + return BinaryExpr{Left: b, Operator: "/", Right: wrapValue(value)} +} + +// Mod calculates the remainder of this binary expression divided by a value or expression. +func (b BinaryExpr) Mod(value any) BinaryExpr { + return BinaryExpr{Left: b, Operator: "%", Right: wrapValue(value)} +} + +// ConcatExpr represents a SQL concatenation. +type ConcatExpr struct { + Exprs []Expression +} + +func (ConcatExpr) isExpression() {} + +// As aliases this concatenation expression in a SELECT list. +func (c ConcatExpr) As(alias string) AliasExpr { + return As(c, alias) +} + // LogicalExpr groups predicates with AND or OR. type LogicalExpr struct { Operator string @@ -1304,6 +1430,25 @@ func NotExists(subquery Expression) ExistsExpr { return ExistsExpr{Subquery: subquery, Negated: true} } +// Concat renders a SQL concatenation of multiple expressions or values. +func Concat(values ...any) ConcatExpr { + if len(values) < 2 { + panic("schema: Concat requires at least two values") + } + exprs := make([]Expression, 0, len(values)) + for _, v := range values { + exprs = append(exprs, wrapValue(v)) + } + return ConcatExpr{Exprs: exprs} +} + +func wrapValue(v any) Expression { + if expr, ok := v.(Expression); ok { + return expr + } + return ValueExpr{Value: v} +} + // IndexBuilder configures a table index. type IndexBuilder struct { table *TableDef @@ -1654,6 +1799,20 @@ func cloneExpressionForTable(expr Expression, table *TableDef) Expression { return cloned case AliasExpr: return AliasExpr{Expr: cloneExpressionForTable(value.Expr, table), Alias: value.Alias} + case BinaryExpr: + return BinaryExpr{ + Left: cloneExpressionForTable(value.Left, table), + Operator: value.Operator, + Right: cloneExpressionForTable(value.Right, table), + } + case ConcatExpr: + cloned := ConcatExpr{ + Exprs: make([]Expression, 0, len(value.Exprs)), + } + for _, expr := range value.Exprs { + cloned.Exprs = append(cloned.Exprs, cloneExpressionForTable(expr, table)) + } + return cloned case RawExpr: cloned := RawExpr{SQL: value.SQL, Args: make([]any, 0, len(value.Args))} for _, arg := range value.Args { From 208cf1bfc103bd712cce871a06abfb9cb2a70f15 Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Wed, 17 Jun 2026 23:54:34 +0000 Subject: [PATCH 2/2] feat(rain): add support for arithmetic and concat expressions Implement arithmetic operators (+, -, *, /, %) and string concatenation in the Rain ORM query builder. - Add BinaryExpr and ConcatExpr to pkg/schema - Add Add, Sub, Mul, Div, Mod methods to Column[T] and AnyColumn - Support expression chaining (e.g., col.Add(1).Mul(2)) - Implement dialect-aware concatenation (|| for Postgres/SQLite, CONCAT() for MySQL) - Support aliasing via .As() on new expression types - Update query compiler to render expressions with correct precedence - Add operator validation allowlist to prevent SQL injection - Add unit and SQLite integration tests covering all operators Co-authored-by: cungminh2710 <8063319+cungminh2710@users.noreply.github.com> --- pkg/rain/query_compile.go | 12 ++++ pkg/rain/sqlite_integration_test.go | 94 ++++++++++++++++++++--------- 2 files changed, 76 insertions(+), 30 deletions(-) diff --git a/pkg/rain/query_compile.go b/pkg/rain/query_compile.go index 69a901a..de5896f 100644 --- a/pkg/rain/query_compile.go +++ b/pkg/rain/query_compile.go @@ -370,6 +370,12 @@ func (c *compileContext) writeExpressionInContext(expr schema.Expression, contex c.args = append(c.args, nil) c.writeString(c.dialect.Placeholder(index)) case schema.ComparisonExpr: + switch value.Operator { + case "=", "<>", ">", ">=", "<", "<=", "LIKE", "NOT LIKE", "ILIKE", "NOT ILIKE": + // ok + default: + return fmt.Errorf("rain: invalid comparison operator %q", value.Operator) + } if err := c.writeExpression(value.Left); err != nil { return err } @@ -380,6 +386,12 @@ func (c *compileContext) writeExpressionInContext(expr schema.Expression, contex return err } case schema.BinaryExpr: + switch value.Operator { + case "+", "-", "*", "/", "%": + // ok + default: + return fmt.Errorf("rain: invalid binary operator %q", value.Operator) + } c.writeByte('(') if err := c.writeExpression(value.Left); err != nil { return err diff --git a/pkg/rain/sqlite_integration_test.go b/pkg/rain/sqlite_integration_test.go index 1a3dc73..595f911 100644 --- a/pkg/rain/sqlite_integration_test.go +++ b/pkg/rain/sqlite_integration_test.go @@ -1931,6 +1931,29 @@ func TestSQLiteIntegrationFirst(t *testing.T) { t.Fatalf("expected bob@example.com, got %q", row.Email) } }) + + t.Run("ErrNoRows", func(t *testing.T) { + var row sqliteUserRow + err := db.Select(). + Table(users). + Where(users.Email.Eq("missing@example.com")). + First(ctx, &row) + + if !errors.Is(err, sql.ErrNoRows) { + t.Fatalf("expected sql.ErrNoRows, got %v", err) + } + }) + + t.Run("RejectSlice", func(t *testing.T) { + var rows []sqliteUserRow + err := db.Select(). + Table(users). + First(ctx, &rows) + + if err == nil || !strings.Contains(err.Error(), "First destination must be a non-nil pointer to a struct") { + t.Fatalf("expected error rejecting slice, got %v", err) + } + }) } func TestSQLiteIntegrationArithmeticAndConcat(t *testing.T) { @@ -1943,7 +1966,7 @@ func TestSQLiteIntegrationArithmeticAndConcat(t *testing.T) { // Seed data if _, err := db.Insert().Table(users).Values( - map[schema.ColumnReference]any{users.ID: 1, users.Email: "alice@example.com", users.Name: "Alice", users.Active: true}, + map[schema.ColumnReference]any{users.ID: 10, users.Email: "alice@example.com", users.Name: "Alice", users.Active: true}, ).Exec(ctx); err != nil { t.Fatalf("seed data failed: %v", err) } @@ -1952,19 +1975,53 @@ func TestSQLiteIntegrationArithmeticAndConcat(t *testing.T) { var results []struct { Val int64 `db:"val"` } - if err := db.Select(users.ID.Add(int64(10)).As("val")).From(users).Scan(ctx, &results); err != nil { + + // Add: 10 + 5 = 15 + if err := db.Select(users.ID.Add(int64(5)).As("val")).From(users).Scan(ctx, &results); err != nil { t.Fatalf("Add failed: %v", err) } - if len(results) == 0 || results[0].Val != 11 { - t.Fatalf("expected 11, got %+v", results) + if len(results) == 0 || results[0].Val != 15 { + t.Fatalf("expected 15, got %+v", results) + } + + // Sub: 10 - 3 = 5 + if err := db.Select(users.ID.Sub(int64(3)).As("val")).From(users).Scan(ctx, &results); err != nil { + t.Fatalf("Sub failed: %v", err) + } + if len(results) == 0 || results[0].Val != 7 { + t.Fatalf("expected 7, got %+v", results) + } + + // Mul: 10 * 2 = 20 + if err := db.Select(users.ID.Mul(int64(2)).As("val")).From(users).Scan(ctx, &results); err != nil { + t.Fatalf("Mul failed: %v", err) + } + if len(results) == 0 || results[0].Val != 20 { + t.Fatalf("expected 20, got %+v", results) + } + + // Div: 10 / 3 = 3 (integer division) + if err := db.Select(users.ID.Div(int64(3)).As("val")).From(users).Scan(ctx, &results); err != nil { + t.Fatalf("Div failed: %v", err) + } + if len(results) == 0 || results[0].Val != 3 { + t.Fatalf("expected 3, got %+v", results) } - results = nil + // Mod: 10 % 3 = 1 + if err := db.Select(users.ID.Mod(int64(3)).As("val")).From(users).Scan(ctx, &results); err != nil { + t.Fatalf("Mod failed: %v", err) + } + if len(results) == 0 || results[0].Val != 1 { + t.Fatalf("expected 1, got %+v", results) + } + + // Nested: (10 + 10) * 2 = 40 if err := db.Select(users.ID.Add(int64(10)).Mul(int64(2)).As("val")).From(users).Scan(ctx, &results); err != nil { t.Fatalf("Nested arithmetic failed: %v", err) } - if len(results) == 0 || results[0].Val != 22 { - t.Fatalf("expected 22, got %+v", results) + if len(results) == 0 || results[0].Val != 40 { + t.Fatalf("expected 40, got %+v", results) } }) @@ -1984,27 +2041,4 @@ func TestSQLiteIntegrationArithmeticAndConcat(t *testing.T) { t.Fatalf("expected %q, got %q", expected, result) } }) - - t.Run("ErrNoRows", func(t *testing.T) { - var row sqliteUserRow - err := db.Select(). - Table(users). - Where(users.Email.Eq("missing@example.com")). - First(ctx, &row) - - if !errors.Is(err, sql.ErrNoRows) { - t.Fatalf("expected sql.ErrNoRows, got %v", err) - } - }) - - t.Run("RejectSlice", func(t *testing.T) { - var rows []sqliteUserRow - err := db.Select(). - Table(users). - First(ctx, &rows) - - if err == nil || !strings.Contains(err.Error(), "First destination must be a non-nil pointer to a struct") { - t.Fatalf("expected error rejecting slice, got %v", err) - } - }) }