diff --git a/pkg/rain/query_select.go b/pkg/rain/query_select.go index 0318ccf..b64193c 100644 --- a/pkg/rain/query_select.go +++ b/pkg/rain/query_select.go @@ -184,6 +184,7 @@ func (q *SelectQuery) WithRelations(names ...string) *SelectQuery { type RelationConfig struct { Where schema.Predicate OrderBy []schema.OrderExpr + Columns []schema.Expression } // Relation configures filters and ordering for a named relation. @@ -629,6 +630,12 @@ func (q *SelectQuery) writeJoins(ctx *compileContext) error { return nil } +// First executes the SELECT query with an implicit LIMIT 1 and scans the result into dest. +// If no rows are found, it returns sql.ErrNoRows. +func (q *SelectQuery) First(ctx context.Context, dest any) error { + return q.clone().Limit(1).Scan(ctx, dest) +} + // Scan executes the SELECT query and scans results into dest. func (q *SelectQuery) Scan(ctx context.Context, dest any) error { if q.runner == nil { diff --git a/pkg/rain/query_select_test.go b/pkg/rain/query_select_test.go index 69e613a..55380ba 100644 --- a/pkg/rain/query_select_test.go +++ b/pkg/rain/query_select_test.go @@ -117,6 +117,27 @@ func TestSelectErgonomicsToSQL(t *testing.T) { } } +func TestSelectFirst(t *testing.T) { + t.Parallel() + + db, err := rain.OpenDialect("postgres") + if err != nil { + t.Fatalf("OpenDialect returned error: %v", err) + } + users, _ := defineTables() + + // First uses a clone with LIMIT 1, verify SQL rendering + sqlText, _, err := db.Select().Table(users).Column(users.ID, users.Email).Where(users.ID.Eq(int64(1))).Limit(1).ToSQL() + if err != nil { + t.Fatalf("ToSQL returned error: %v", err) + } + + wantSQL := `SELECT "users"."id", "users"."email" FROM "users" WHERE "users"."id" = $1 LIMIT 1` + if sqlText != wantSQL { + t.Fatalf("unexpected SQL:\nwant: %s\ngot: %s", wantSQL, sqlText) + } +} + func TestSelectJoinsToSQL(t *testing.T) { t.Parallel() diff --git a/pkg/rain/relation_loading.go b/pkg/rain/relation_loading.go index 0ab7978..4574082 100644 --- a/pkg/rain/relation_loading.go +++ b/pkg/rain/relation_loading.go @@ -295,6 +295,10 @@ func (q *SelectQuery) loadRelatedManyToManyRows( batchDest := reflect.New(reflect.SliceOf(relatedElemType)) targetQuery := &SelectQuery{runner: q.runner, dialect: q.dialect, table: tableDefSource{table: relation.TargetTable}} + if len(config.Columns) > 0 { + targetQuery.Column(config.Columns...) + ensureTargetColumnSelected(targetQuery, relation.TargetColumn) + } if config.Where != nil { targetQuery.Where(config.Where) } @@ -376,6 +380,10 @@ func (q *SelectQuery) loadRelatedRows( end := min(start+relationBatchSize, len(sourceKeys)) batchDest := reflect.New(reflect.SliceOf(relatedElemType)) query := &SelectQuery{runner: q.runner, dialect: q.dialect, table: tableDefSource{table: relation.TargetTable}} + if len(config.Columns) > 0 { + query.Column(config.Columns...) + ensureTargetColumnSelected(query, relation.TargetColumn) + } if config.Where != nil { query.Where(config.Where) } @@ -544,6 +552,20 @@ func setRelationValue(parent reflect.Value, relationName string, relationType sc } } +func ensureTargetColumnSelected(q *SelectQuery, targetColumn *schema.ColumnDef) { + if len(q.cols) == 0 { + return + } + + for _, col := range q.cols { + if ref, ok := col.(schema.ColumnReference); ok && ref.ColumnDef().Name == targetColumn.Name && ref.ColumnDef().Table.Name == targetColumn.Table.Name { + return + } + } + + q.Column(schema.Ref(targetColumn)) +} + func dereferenceModelValue(value reflect.Value) reflect.Value { current := value for current.IsValid() && current.Kind() == reflect.Pointer { diff --git a/pkg/rain/sqlite_integration_test.go b/pkg/rain/sqlite_integration_test.go index 1a9386a..6fdccc6 100644 --- a/pkg/rain/sqlite_integration_test.go +++ b/pkg/rain/sqlite_integration_test.go @@ -1901,3 +1901,89 @@ func TestSQLiteIntegrationUpsertFilters(t *testing.T) { t.Fatalf("expected Name to be Custom Updated, got %q", row.Name) } } + +func TestSQLiteIntegrationSelectiveRelationColumns(t *testing.T) { + t.Parallel() + + ctx := context.Background() + db := openSQLiteTestDB(t) + fixture := defineSQLiteRichTables() + createSQLiteRichSchema(t, ctx, db, fixture) + seeded := seedSQLiteRichFixture(t, ctx, db, fixture) + + type postSubset struct { + ID int64 `db:"id"` + Title string `db:"title"` + UserID int64 `db:"user_id"` + } + + type userWithPostSubset struct { + ID int64 `db:"id"` + Email string `db:"email"` + Posts []postSubset `rain:"relation:posts"` + } + + var results []userWithPostSubset + err := db.Select(). + Table(fixture.users). + Where(fixture.users.ID.Eq(seeded.AliceID)). + Relation("posts", rain.RelationConfig{ + Columns: []schema.Expression{fixture.posts.ID, fixture.posts.Title}, + OrderBy: []schema.OrderExpr{fixture.posts.ID.Asc()}, + }). + Scan(ctx, &results) + if err != nil { + t.Fatalf("scan with selective relation columns failed: %v", err) + } + + if len(results) != 1 { + t.Fatalf("expected 1 user, got %d", len(results)) + } + + alice := results[0] + if len(alice.Posts) != 2 { + t.Fatalf("expected 2 posts for Alice, got %d", len(alice.Posts)) + } + + // Verify that the selective columns were loaded. + // Since we're using a subset struct, only those fields are present. + // The mapping relies on UserID, which we didn't explicitly request in RelationConfig.Columns, + // but the ORM should have included it automatically. + if alice.Posts[0].Title == "" { + t.Fatalf("expected post title to be populated") + } +} + +func TestSQLiteIntegrationFirst(t *testing.T) { + t.Parallel() + + ctx := context.Background() + db := openSQLiteTestDB(t) + fixture := defineSQLiteRichTables() + createSQLiteRichSchema(t, ctx, db, fixture) + _ = seedSQLiteRichFixture(t, ctx, db, fixture) + + var user sqliteRichAuthorRow + err := db.Select(). + Table(fixture.users). + OrderBy(fixture.users.ID.Asc()). + First(ctx, &user) + if err != nil { + t.Fatalf("First failed: %v", err) + } + + if user.Email != "alice@example.com" { + t.Fatalf("expected alice@example.com, got %q", user.Email) + } + + // Test ErrNoRows + var none sqliteRichAuthorRow + err = db.Select(). + Table(fixture.users). + Where(fixture.users.ID.Eq(int64(99999))). + First(ctx, &none) + + if !errors.Is(err, sql.ErrNoRows) { + t.Fatalf("expected sql.ErrNoRows, got %v", err) + } +}