diff --git a/pkg/rain/query_select.go b/pkg/rain/query_select.go index 0318ccf..8072645 100644 --- a/pkg/rain/query_select.go +++ b/pkg/rain/query_select.go @@ -184,6 +184,7 @@ func (q *SelectQuery) WithRelations(names ...string) *SelectQuery { type RelationConfig struct { Where schema.Predicate OrderBy []schema.OrderExpr + Columns []schema.Expression } // Relation configures filters and ordering for a named relation. @@ -629,6 +630,12 @@ func (q *SelectQuery) writeJoins(ctx *compileContext) error { return nil } +// First executes the SELECT query with an implicit LIMIT 1 and scans the first row into dest. +// Returns sql.ErrNoRows if no result is found. +func (q *SelectQuery) First(ctx context.Context, dest any) error { + return q.clone().Limit(1).Scan(ctx, dest) +} + // Scan executes the SELECT query and scans results into dest. func (q *SelectQuery) Scan(ctx context.Context, dest any) error { if q.runner == nil { diff --git a/pkg/rain/rain.go b/pkg/rain/rain.go index 93db036..2535093 100644 --- a/pkg/rain/rain.go +++ b/pkg/rain/rain.go @@ -279,12 +279,17 @@ func (db *DB) QueryRow(ctx context.Context, query string, args ...any) *sql.Row // Begin starts a new transaction. func (db *DB) Begin(ctx context.Context) (*Tx, error) { + return db.BeginTx(ctx, nil) +} + +// BeginTx starts a new transaction with the provided options. +func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) { primary := db.primaryHandle() if primary.db == nil { return nil, ErrNoConnection } - tx, err := primary.db.BeginTx(ctx, nil) + tx, err := primary.db.BeginTx(ctx, opts) if err != nil { return nil, err } @@ -294,7 +299,12 @@ func (db *DB) Begin(ctx context.Context) (*Tx, error) { // RunInTx executes fn in a transaction, rolling back on error and committing on success. func (db *DB) RunInTx(ctx context.Context, fn func(*Tx) error) error { - tx, err := db.Begin(ctx) + return db.RunInTxOpts(ctx, nil, fn) +} + +// RunInTxOpts executes fn in a transaction with the provided options, rolling back on error and committing on success. +func (db *DB) RunInTxOpts(ctx context.Context, opts *sql.TxOptions, fn func(*Tx) error) error { + tx, err := db.BeginTx(ctx, opts) if err != nil { return err } diff --git a/pkg/rain/relation_loading.go b/pkg/rain/relation_loading.go index 0ab7978..d34d277 100644 --- a/pkg/rain/relation_loading.go +++ b/pkg/rain/relation_loading.go @@ -295,6 +295,7 @@ func (q *SelectQuery) loadRelatedManyToManyRows( batchDest := reflect.New(reflect.SliceOf(relatedElemType)) targetQuery := &SelectQuery{runner: q.runner, dialect: q.dialect, table: tableDefSource{table: relation.TargetTable}} + ensureTargetColumnSelected(targetQuery, config.Columns, relation.TargetColumn) if config.Where != nil { targetQuery.Where(config.Where) } @@ -376,6 +377,7 @@ func (q *SelectQuery) loadRelatedRows( end := min(start+relationBatchSize, len(sourceKeys)) batchDest := reflect.New(reflect.SliceOf(relatedElemType)) query := &SelectQuery{runner: q.runner, dialect: q.dialect, table: tableDefSource{table: relation.TargetTable}} + ensureTargetColumnSelected(query, config.Columns, relation.TargetColumn) if config.Where != nil { query.Where(config.Where) } @@ -544,6 +546,26 @@ func setRelationValue(parent reflect.Value, relationName string, relationType sc } } +func ensureTargetColumnSelected(query *SelectQuery, columns []schema.Expression, targetCol *schema.ColumnDef) { + if len(columns) == 0 { + return + } + + query.Column(columns...) + + // Ensure mapping target column is selected. + found := false + for _, col := range columns { + if cr, ok := col.(schema.ColumnReference); ok && cr.ColumnDef() == targetCol { + found = true + break + } + } + if !found { + query.Column(schema.Ref(targetCol)) + } +} + func dereferenceModelValue(value reflect.Value) reflect.Value { current := value for current.IsValid() && current.Kind() == reflect.Pointer {