Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 20 additions & 4 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ import (
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
"google.golang.org/api/idtoken"
"google.golang.org/api/impersonate"
)

type contextKey string
Expand Down Expand Up @@ -66,6 +67,8 @@ var (
"handle http2 requests (allows grpc calls)")
flagAuthorizationHeader = flag.String("authorization-header", "X-Serverless-Authorization",
"header to provide the bearer token")
flagImpersonateServiceAccount = flag.String("impersonate-service-account", "",
"impersonate the specified service account to obtain ID tokens")
)

func main() {
Expand Down Expand Up @@ -130,7 +133,7 @@ func realMain(ctx context.Context) error {

// Get the best token source. Cloud Run expects the audience parameter to be
// the URL of the service.
tokenSource, err := findTokenSource(ctx, *flagToken, audience)
tokenSource, err := findTokenSource(ctx, *flagToken, audience, *flagImpersonateServiceAccount)
if err != nil {
return fmt.Errorf("failed to find token source: %w", err)
}
Expand Down Expand Up @@ -299,15 +302,28 @@ func createServer(bind *url.URL, proxy *httputil.ReverseProxy, enableHttp2 bool)

// findTokenSource fetches the reusable/cached oauth2 token source. If rawToken
// is provided, that token is used as a static value and the audience parameter
// is ignored. Othwerise, this attempts to get the renewable token from the
// is ignored. Otherwise, this attempts to get the renewable token from the
// environment (via Application Default Credentials).
func findTokenSource(ctx context.Context, rawToken, audience string) (oauth2.TokenSource, error) {
func findTokenSource(ctx context.Context, rawToken, audience, impersonateServiceAccount string) (oauth2.TokenSource, error) {
// Prefer supplied value, usually from the flag.
if rawToken != "" {
token := &oauth2.Token{AccessToken: rawToken}
return oauth2.StaticTokenSource(token), nil
}

// If impersonation is requested, use the impersonate package to get ID tokens.
if impersonateServiceAccount != "" {
tokenSource, err := impersonate.IDTokenSource(ctx, impersonate.IDTokenConfig{
TargetPrincipal: impersonateServiceAccount,
Audience: audience,
IncludeEmail: true,
})
if err != nil {
return nil, fmt.Errorf("failed to create impersonated token source: %w", err)
}
return tokenSource, nil
}

// Try to use the idtoken package, which will use the metadata service.
// However, the idtoken package does not work with gcloud's ADC, so we need to
// handle that case by falling back to default ADC search. However, the
Expand All @@ -316,7 +332,7 @@ func findTokenSource(ctx context.Context, rawToken, audience string) (oauth2.Tok
tokenSource, err := idtoken.NewTokenSource(ctx, audience)
if err != nil {
// Return any unexpected error.
if !strings.Contains(err.Error(), "credential must be service_account") {
if !strings.Contains(err.Error(), "unsupported credentials type") {
return nil, fmt.Errorf("failed to get idtoken source: %w", err)
}

Expand Down
11 changes: 5 additions & 6 deletions main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func TestBuildProxy(t *testing.T) {
Host: fmt.Sprintf("localhost:%d", testRandomPort(t)),
}

src, err := findTokenSource(ctx, "mytoken", "aud")
src, err := findTokenSource(ctx, "mytoken", "aud", "")
if err != nil {
t.Fatal(err)
}
Expand All @@ -80,7 +80,7 @@ func TestBuildProxy(t *testing.T) {
w := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodGet, "/", nil)
proxy.ServeHTTP(w, r)
if !called{
if !called {
t.Errorf("handler not called")
}
})
Expand All @@ -94,7 +94,7 @@ func TestFindTokenSource(t *testing.T) {
t.Run("static", func(t *testing.T) {
t.Parallel()

src, err := findTokenSource(ctx, "mytoken", "aud")
src, err := findTokenSource(ctx, "mytoken", "aud", "")
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -184,7 +184,6 @@ func TestHttp2(t *testing.T) {
called = true
})


srv := httptest.NewUnstartedServer(mux)
srv.EnableHTTP2 = true
srv.StartTLS()
Expand All @@ -199,7 +198,7 @@ func TestHttp2(t *testing.T) {
Host: fmt.Sprintf("localhost:%d", testRandomPort(t)),
}

src, err := findTokenSource(ctx, "mytoken", "aud")
src, err := findTokenSource(ctx, "mytoken", "aud", "")
if err != nil {
t.Fatal(err)
}
Expand All @@ -214,7 +213,7 @@ func TestHttp2(t *testing.T) {

proxy.ServeHTTP(w, r)

if !called{
if !called {
t.Errorf("handler not called")
}
defer srv.Close()
Expand Down