Skip to content
10 changes: 4 additions & 6 deletions internal/api/v1beta1connect/audit_record.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,7 @@ func (h *ConnectHandler) ListAuditRecords(ctx context.Context, request *connect.
case errors.Is(err, auditrecord.ErrNotFound):
return nil, connect.NewError(connect.CodeNotFound, err)
default:
errorLogger.LogUnexpectedError(ctx, request, "ListAuditRecords", err)
return nil, connect.NewError(connect.CodeInternal, ErrInternalServerError)
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("ListAuditRecords: %w", err))
}
}
pbRecords := make([]*frontierv1beta1.AuditRecord, 0)
Expand Down Expand Up @@ -191,8 +190,7 @@ func (h *ConnectHandler) ExportAuditRecords(ctx context.Context, request *connec
case errors.Is(err, auditrecord.ErrNotFound):
return connect.NewError(connect.CodeNotFound, err)
default:
errorLogger.LogUnexpectedError(ctx, request, "ExportAuditRecords", err)
return connect.NewError(connect.CodeInternal, ErrInternalServerError)
return connect.NewError(connect.CodeInternal, fmt.Errorf("ExportAuditRecords: %w", err))
}
}
// Stream the data using io.Reader
Expand Down Expand Up @@ -273,15 +271,15 @@ func streamReaderInChunks(reader io.Reader, contentType string, stream *connect.
break
}
if err != nil {
return connect.NewError(connect.CodeInternal, ErrInternalServerError)
return connect.NewError(connect.CodeInternal, fmt.Errorf("streamReaderInChunks: %w", err))
}
if n > 0 {
msg := &httpbody.HttpBody{
ContentType: contentType,
Data: buffer[:n],
}
if sendErr := stream.Send(msg); sendErr != nil {
return connect.NewError(connect.CodeInternal, ErrInternalServerError)
return connect.NewError(connect.CodeInternal, fmt.Errorf("streamReaderInChunks: %w", sendErr))
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion internal/api/v1beta1connect/audit_record_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1449,7 +1449,7 @@ func TestHandler_ListAuditRecords(t *testing.T) {
},
}),
want: nil,
wantErr: connect.NewError(connect.CodeInternal, ErrInternalServerError),
wantErr: connect.NewError(connect.CodeInternal, errors.New("database connection failed")),
},
{
name: "should return invalid argument error for bad input",
Expand Down
53 changes: 12 additions & 41 deletions internal/api/v1beta1connect/authenticate.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@ import (
)

func (h *ConnectHandler) Authenticate(ctx context.Context, request *connect.Request[frontierv1beta1.AuthenticateRequest]) (*connect.Response[frontierv1beta1.AuthenticateResponse], error) {
errorLogger := NewErrorLogger()

returnToURL := h.authnService.SanitizeReturnToURL(request.Msg.GetReturnTo())
callbackURL := h.authnService.SanitizeCallbackURL(request.Msg.GetCallbackUrl())

Expand All @@ -41,8 +39,7 @@ func (h *ConnectHandler) Authenticate(ctx context.Context, request *connect.Requ
}
return resp, nil
} else if err != nil && !errors.Is(err, frontiersession.ErrNoSession) {
errorLogger.LogUnexpectedError(ctx, request, "Authenticate", err)
return nil, connect.NewError(connect.CodeInternal, ErrInternalServerError)
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("Authenticate: %w", err))
}

if (request.Msg.GetStrategyName() == authenticate.MailLinkAuthMethod.String() || request.Msg.GetStrategyName() == authenticate.MailOTPAuthMethod.String()) && !isValidEmail(request.Msg.GetEmail()) {
Expand All @@ -57,10 +54,7 @@ func (h *ConnectHandler) Authenticate(ctx context.Context, request *connect.Requ
Email: request.Msg.GetEmail(),
})
if err != nil {
errorLogger.LogUnexpectedError(ctx, request, "Authenticate", err,
"strategy", request.Msg.GetStrategyName(),
"email", request.Msg.GetEmail())
return nil, connect.NewError(connect.CodeInternal, ErrInternalServerError)
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("Authenticate: strategy=%s email=%s: %w", request.Msg.GetStrategyName(), request.Msg.GetEmail(), err))
}

// set location header for redirect to start auth?
Expand All @@ -76,16 +70,12 @@ func (h *ConnectHandler) Authenticate(ctx context.Context, request *connect.Requ
if request.Msg.GetStrategyName() == authenticate.PassKeyAuthMethod.String() {
userCredentils := &structpb.Value{}
if err = userCredentils.UnmarshalJSON(response.StateConfig["options"].([]byte)); err != nil {
errorLogger.LogUnexpectedError(ctx, request, "Authenticate", err,
"strategy", authenticate.PassKeyAuthMethod.String())
return nil, connect.NewError(connect.CodeInternal, ErrInternalServerError)
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("Authenticate: strategy=%s: %w", authenticate.PassKeyAuthMethod.String(), err))
}
typeValue, ok := response.Flow.Metadata["passkey_type"].(string)
if !ok {
err = fmt.Errorf("passkey_type metadata is not a string")
errorLogger.LogUnexpectedError(ctx, request, "Authenticate", err,
"strategy", authenticate.PassKeyAuthMethod.String())
return nil, connect.NewError(connect.CodeInternal, ErrInternalServerError)
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("Authenticate: strategy=%s: %w", authenticate.PassKeyAuthMethod.String(), err))
}
stringValue := &structpb.Value{
Kind: &structpb.Value_StringValue{
Expand Down Expand Up @@ -123,10 +113,7 @@ func (h *ConnectHandler) AuthCallback(ctx context.Context, request *connect.Requ
"state", request.Msg.GetState())
return nil, connect.NewError(connect.CodeInvalidArgument, err)
}
errorLogger.LogUnexpectedError(ctx, request, "AuthCallback", err,
"strategy", request.Msg.GetStrategyName(),
"state", request.Msg.GetState())
return nil, connect.NewError(connect.CodeInternal, ErrInternalServerError)
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("AuthCallback: strategy=%s state=%s: %w", request.Msg.GetStrategyName(), request.Msg.GetState(), err))
}

// Extract session metadata from request headers
Expand All @@ -135,9 +122,7 @@ func (h *ConnectHandler) AuthCallback(ctx context.Context, request *connect.Requ
// registration/login complete, build a session
session, err := h.sessionService.Create(ctx, response.User.ID, sessionMetadata)
if err != nil {
errorLogger.LogUnexpectedError(ctx, request, "AuthCallback", err,
"user_id", response.User.ID)
return nil, connect.NewError(connect.CodeInternal, ErrInternalServerError)
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("AuthCallback: user_id=%s: %w", response.User.ID, err))
}
// create response and set headers
resp := connect.NewResponse(&frontierv1beta1.AuthCallbackResponse{})
Expand Down Expand Up @@ -182,9 +167,7 @@ func (h *ConnectHandler) AuthToken(ctx context.Context, request *connect.Request
authenticate.ClientCredentialsClientAssertion,
authenticate.JWTGrantClientAssertion)
if err != nil {
errorLogger.LogUnexpectedError(ctx, request, "AuthToken", err,
"grant_type", request.Msg.GetGrantType())
return nil, connect.NewError(connect.CodeInternal, ErrInternalServerError)
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("AuthToken: grant_type=%s: %w", request.Msg.GetGrantType(), err))
}

if principal.Type == schema.ServiceUserPrincipal {
Expand All @@ -197,16 +180,13 @@ func (h *ConnectHandler) AuthToken(ctx context.Context, request *connect.Request
if errors.Is(err, organization.ErrDisabled) {
return nil, connect.NewError(connect.CodeFailedPrecondition, ErrOrgDisabled)
}
return nil, connect.NewError(connect.CodeInternal, ErrInternalServerError)
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("AuthToken.orgService.Get: org_id=%s service_user_id=%s: %w", orgId, principal.ServiceUser.ID, err))
}
}

token, err := h.getAccessToken(ctx, principal, request.Header().Values(consts.ProjectRequestKey), request)
if err != nil {
errorLogger.LogUnexpectedError(ctx, request, "AuthToken", err,
"principal_id", principal.ID,
"principal_type", principal.Type)
return nil, connect.NewError(connect.CodeInternal, ErrInternalServerError)
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("AuthToken: principal_id=%s principal_type=%s: %w", principal.ID, principal.Type, err))
}

resp := connect.NewResponse(&frontierv1beta1.AuthTokenResponse{
Expand Down Expand Up @@ -270,15 +250,11 @@ func (h *ConnectHandler) getAccessToken(ctx context.Context, principal authentic
}

func (h *ConnectHandler) AuthLogout(ctx context.Context, request *connect.Request[frontierv1beta1.AuthLogoutRequest]) (*connect.Response[frontierv1beta1.AuthLogoutResponse], error) {
errorLogger := NewErrorLogger()

// delete user session if exists
sessionID, err := h.getLoggedInSessionID(ctx)
if err == nil {
if err = h.sessionService.Delete(ctx, sessionID); err != nil {
errorLogger.LogUnexpectedError(ctx, request, "AuthLogout", err,
"session_id", sessionID.String())
return nil, connect.NewError(connect.CodeInternal, ErrInternalServerError)
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("AuthLogout: session_id=%s: %w", sessionID.String(), err))
}
}

Expand All @@ -298,7 +274,6 @@ func (h *ConnectHandler) getLoggedInSessionID(ctx context.Context) (uuid.UUID, e
}

func (h *ConnectHandler) GetLoggedInPrincipal(ctx context.Context, via ...authenticate.ClientAssertion) (authenticate.Principal, error) {
errorLogger := NewErrorLogger()
principal, err := h.authnService.GetPrincipal(ctx, via...)
if err != nil {
switch {
Expand All @@ -312,8 +287,7 @@ func (h *ConnectHandler) GetLoggedInPrincipal(ctx context.Context, via ...authen
errors.Is(err, patErrors.ErrDisabled):
return principal, connect.NewError(connect.CodeUnauthenticated, ErrUnauthenticated)
default:
errorLogger.LogUnexpectedError(ctx, nil, "GetLoggedInPrincipal", err)
return principal, connect.NewError(connect.CodeInternal, ErrInternalServerError)
return principal, connect.NewError(connect.CodeInternal, fmt.Errorf("GetLoggedInPrincipal: %w", err))
}
}
return principal, nil
Expand All @@ -331,13 +305,10 @@ func (h *ConnectHandler) ListAuthStrategies(ctx context.Context, request *connec
}

func (h *ConnectHandler) GetJWKs(ctx context.Context, request *connect.Request[frontierv1beta1.GetJWKsRequest]) (*connect.Response[frontierv1beta1.GetJWKsResponse], error) {
errorLogger := NewErrorLogger()

keySet := h.authnService.JWKs(ctx)
jwks, err := toJSONWebKey(keySet)
if err != nil {
errorLogger.LogUnexpectedError(ctx, request, "GetJWKs", err)
return nil, connect.NewError(connect.CodeInternal, ErrInternalServerError)
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("GetJWKs: %w", err))
}
return connect.NewResponse(&frontierv1beta1.GetJWKsResponse{
Keys: jwks.Keys,
Expand Down
12 changes: 3 additions & 9 deletions internal/api/v1beta1connect/authorize.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ func handleAuthErr(err error) error {
errors.Is(err, resource.ErrNotExist):
return connect.NewError(connect.CodeNotFound, ErrNotFound)
default:
return connect.NewError(connect.CodeInternal, ErrInternalServerError)
return connect.NewError(connect.CodeInternal, fmt.Errorf("handleAuthErr: %w", err))
}
}

Expand All @@ -112,19 +112,13 @@ func (h *ConnectHandler) IsSuperUser(ctx context.Context, request connect.AnyReq
return connect.NewError(connect.CodePermissionDenied, ErrUnauthorized)
case schema.UserPrincipal:
if ok, err := h.userService.IsSudo(ctx, currentUser.ID, schema.PlatformSudoPermission); err != nil {
errorLogger.LogUnexpectedError(ctx, request, "IsSuperUser", err,
"user_id", currentUser.ID,
"permission", schema.PlatformSudoPermission)
return connect.NewError(connect.CodeInternal, ErrInternalServerError)
return connect.NewError(connect.CodeInternal, fmt.Errorf("IsSuperUser: user_id=%s permission=%s: %w", currentUser.ID, schema.PlatformSudoPermission, err))
} else if ok {
return nil
}
case schema.ServiceUserPrincipal:
if ok, err := h.serviceUserService.IsSudo(ctx, currentUser.ID, schema.PlatformSudoPermission); err != nil {
errorLogger.LogUnexpectedError(ctx, request, "IsSuperUser", err,
"service_user_id", currentUser.ID,
"permission", schema.PlatformSudoPermission)
return connect.NewError(connect.CodeInternal, ErrInternalServerError)
return connect.NewError(connect.CodeInternal, fmt.Errorf("IsSuperUser: service_user_id=%s permission=%s: %w", currentUser.ID, schema.PlatformSudoPermission, err))
} else if ok {
return nil
}
Expand Down
29 changes: 6 additions & 23 deletions internal/api/v1beta1connect/billing_check.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,14 @@ package v1beta1connect
import (
"context"
"errors"
"fmt"

"connectrpc.com/connect"
"github.com/raystack/frontier/billing/customer"
frontierv1beta1 "github.com/raystack/frontier/proto/v1beta1"
)

func (h *ConnectHandler) CheckFeatureEntitlement(ctx context.Context, request *connect.Request[frontierv1beta1.CheckFeatureEntitlementRequest]) (*connect.Response[frontierv1beta1.CheckFeatureEntitlementResponse], error) {
errorLogger := NewErrorLogger()

// Always infer billing_id from org_id
cust, err := h.customerService.GetByOrgID(ctx, request.Msg.GetOrgId())
if err != nil {
Expand All @@ -21,18 +20,12 @@ func (h *ConnectHandler) CheckFeatureEntitlement(ctx context.Context, request *c
if errors.Is(err, customer.ErrInvalidUUID) || errors.Is(err, customer.ErrInvalidID) {
return nil, connect.NewError(connect.CodeInvalidArgument, err)
}
errorLogger.LogServiceError(ctx, request, "CheckFeatureEntitlement.GetByOrgID", err,
"org_id", request.Msg.GetOrgId())
return nil, connect.NewError(connect.CodeInternal, ErrInternalServerError)
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("CheckFeatureEntitlement.GetByOrgID: org_id=%s: %w", request.Msg.GetOrgId(), err))
}

checkStatus, err := h.entitlementService.Check(ctx, cust.ID, request.Msg.GetFeature())
if err != nil {
errorLogger.LogServiceError(ctx, request, "CheckFeatureEntitlement", err,
"billing_id", cust.ID,
"org_id", request.Msg.GetOrgId(),
"feature", request.Msg.GetFeature())
return nil, connect.NewError(connect.CodeInternal, ErrInternalServerError)
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("CheckFeatureEntitlement: billing_id=%s org_id=%s feature=%s: %w", cust.ID, request.Msg.GetOrgId(), request.Msg.GetFeature(), err))
}

return connect.NewResponse(&frontierv1beta1.CheckFeatureEntitlementResponse{
Expand All @@ -41,15 +34,11 @@ func (h *ConnectHandler) CheckFeatureEntitlement(ctx context.Context, request *c
}

func (h *ConnectHandler) CheckCreditEntitlement(ctx context.Context, request *connect.Request[frontierv1beta1.CheckCreditEntitlementRequest]) (*connect.Response[frontierv1beta1.CheckCreditEntitlementResponse], error) {
errorLogger := NewErrorLogger()

customerList, err := h.customerService.List(ctx, customer.Filter{
OrgID: request.Msg.GetOrgId(),
})
if err != nil {
errorLogger.LogServiceError(ctx, request, "CheckCreditEntitlement.List", err,
"org_id", request.Msg.GetOrgId())
return nil, connect.NewError(connect.CodeInternal, ErrInternalServerError)
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("CheckCreditEntitlement.List: org_id=%s: %w", request.Msg.GetOrgId(), err))
}

if len(customerList) == 0 {
Expand All @@ -59,18 +48,12 @@ func (h *ConnectHandler) CheckCreditEntitlement(ctx context.Context, request *co
customer := customerList[0]
customerDetails, err := h.customerService.GetDetails(ctx, customer.ID)
if err != nil {
errorLogger.LogServiceError(ctx, request, "CheckCreditEntitlement.GetDetails", err,
"customer_id", customer.ID,
"org_id", request.Msg.GetOrgId())
return nil, connect.NewError(connect.CodeInternal, ErrInternalServerError)
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("CheckCreditEntitlement.GetDetails: customer_id=%s org_id=%s: %w", customer.ID, request.Msg.GetOrgId(), err))
}

creditBalance, err := h.creditService.GetBalance(ctx, customer.ID)
if err != nil {
errorLogger.LogServiceError(ctx, request, "CheckCreditEntitlement.GetBalance", err,
"customer_id", customer.ID,
"org_id", request.Msg.GetOrgId())
return nil, connect.NewError(connect.CodeInternal, ErrInternalServerError)
return nil, connect.NewError(connect.CodeInternal, fmt.Errorf("CheckCreditEntitlement.GetBalance: customer_id=%s org_id=%s: %w", customer.ID, request.Msg.GetOrgId(), err))
}

if creditBalance-request.Msg.GetAmount() >= customerDetails.CreditMin {
Expand Down
8 changes: 4 additions & 4 deletions internal/api/v1beta1connect/billing_check_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func TestConnectHandler_CheckFeatureEntitlement(t *testing.T) {
es.EXPECT().Check(mock.Anything, "billing-123", "feature-abc").Return(false, errors.New("service error"))
},
want: nil,
wantErr: ErrInternalServerError,
wantErr: errors.New("service error"),
errCode: connect.CodeInternal,
},
{
Expand Down Expand Up @@ -152,7 +152,7 @@ func TestConnectHandler_CheckCreditEntitlement(t *testing.T) {
}).Return(nil, errors.New("service error"))
},
want: nil,
wantErr: ErrInternalServerError,
wantErr: errors.New("service error"),
errCode: connect.CodeInternal,
},
{
Expand Down Expand Up @@ -183,7 +183,7 @@ func TestConnectHandler_CheckCreditEntitlement(t *testing.T) {
cs.EXPECT().GetDetails(mock.Anything, "customer-123").Return(customer.Details{}, errors.New("service error"))
},
want: nil,
wantErr: ErrInternalServerError,
wantErr: errors.New("service error"),
errCode: connect.CodeInternal,
},
{
Expand All @@ -202,7 +202,7 @@ func TestConnectHandler_CheckCreditEntitlement(t *testing.T) {
crs.EXPECT().GetBalance(mock.Anything, "customer-123").Return(int64(0), errors.New("service error"))
},
want: nil,
wantErr: ErrInternalServerError,
wantErr: errors.New("service error"),
errCode: connect.CodeInternal,
},
{
Expand Down
Loading
Loading