diff --git a/main.go b/main.go index 90d871e..aa98f7b 100644 --- a/main.go +++ b/main.go @@ -39,6 +39,7 @@ import ( "golang.org/x/oauth2" "golang.org/x/oauth2/google" "google.golang.org/api/idtoken" + "google.golang.org/api/impersonate" ) type contextKey string @@ -66,6 +67,8 @@ var ( "handle http2 requests (allows grpc calls)") flagAuthorizationHeader = flag.String("authorization-header", "X-Serverless-Authorization", "header to provide the bearer token") + flagImpersonateServiceAccount = flag.String("impersonate-service-account", "", + "impersonate the specified service account to obtain ID tokens") ) func main() { @@ -130,7 +133,7 @@ func realMain(ctx context.Context) error { // Get the best token source. Cloud Run expects the audience parameter to be // the URL of the service. - tokenSource, err := findTokenSource(ctx, *flagToken, audience) + tokenSource, err := findTokenSource(ctx, *flagToken, audience, *flagImpersonateServiceAccount) if err != nil { return fmt.Errorf("failed to find token source: %w", err) } @@ -299,15 +302,28 @@ func createServer(bind *url.URL, proxy *httputil.ReverseProxy, enableHttp2 bool) // findTokenSource fetches the reusable/cached oauth2 token source. If rawToken // is provided, that token is used as a static value and the audience parameter -// is ignored. Othwerise, this attempts to get the renewable token from the +// is ignored. Otherwise, this attempts to get the renewable token from the // environment (via Application Default Credentials). -func findTokenSource(ctx context.Context, rawToken, audience string) (oauth2.TokenSource, error) { +func findTokenSource(ctx context.Context, rawToken, audience, impersonateServiceAccount string) (oauth2.TokenSource, error) { // Prefer supplied value, usually from the flag. if rawToken != "" { token := &oauth2.Token{AccessToken: rawToken} return oauth2.StaticTokenSource(token), nil } + // If impersonation is requested, use the impersonate package to get ID tokens. + if impersonateServiceAccount != "" { + tokenSource, err := impersonate.IDTokenSource(ctx, impersonate.IDTokenConfig{ + TargetPrincipal: impersonateServiceAccount, + Audience: audience, + IncludeEmail: true, + }) + if err != nil { + return nil, fmt.Errorf("failed to create impersonated token source: %w", err) + } + return tokenSource, nil + } + // Try to use the idtoken package, which will use the metadata service. // However, the idtoken package does not work with gcloud's ADC, so we need to // handle that case by falling back to default ADC search. However, the @@ -316,7 +332,7 @@ func findTokenSource(ctx context.Context, rawToken, audience string) (oauth2.Tok tokenSource, err := idtoken.NewTokenSource(ctx, audience) if err != nil { // Return any unexpected error. - if !strings.Contains(err.Error(), "credential must be service_account") { + if !strings.Contains(err.Error(), "unsupported credentials type") { return nil, fmt.Errorf("failed to get idtoken source: %w", err) } diff --git a/main_test.go b/main_test.go index 94b1717..688b35f 100644 --- a/main_test.go +++ b/main_test.go @@ -67,7 +67,7 @@ func TestBuildProxy(t *testing.T) { Host: fmt.Sprintf("localhost:%d", testRandomPort(t)), } - src, err := findTokenSource(ctx, "mytoken", "aud") + src, err := findTokenSource(ctx, "mytoken", "aud", "") if err != nil { t.Fatal(err) } @@ -80,7 +80,7 @@ func TestBuildProxy(t *testing.T) { w := httptest.NewRecorder() r := httptest.NewRequest(http.MethodGet, "/", nil) proxy.ServeHTTP(w, r) - if !called{ + if !called { t.Errorf("handler not called") } }) @@ -94,7 +94,7 @@ func TestFindTokenSource(t *testing.T) { t.Run("static", func(t *testing.T) { t.Parallel() - src, err := findTokenSource(ctx, "mytoken", "aud") + src, err := findTokenSource(ctx, "mytoken", "aud", "") if err != nil { t.Fatal(err) } @@ -184,7 +184,6 @@ func TestHttp2(t *testing.T) { called = true }) - srv := httptest.NewUnstartedServer(mux) srv.EnableHTTP2 = true srv.StartTLS() @@ -199,7 +198,7 @@ func TestHttp2(t *testing.T) { Host: fmt.Sprintf("localhost:%d", testRandomPort(t)), } - src, err := findTokenSource(ctx, "mytoken", "aud") + src, err := findTokenSource(ctx, "mytoken", "aud", "") if err != nil { t.Fatal(err) } @@ -214,7 +213,7 @@ func TestHttp2(t *testing.T) { proxy.ServeHTTP(w, r) - if !called{ + if !called { t.Errorf("handler not called") } defer srv.Close()