Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,5 @@
# Executables
*.exe

# Internal Test files
*_test.go

# internal tools
/tools
70 changes: 70 additions & 0 deletions pkg/connector/auth_recovery.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
package connector

import (
"context"
"fmt"

"github.com/highesttt/matrix-line-messenger/pkg/line"
)

type lineCallDeps[T any] struct {
newClient func() *line.Client
recover func(context.Context) error
isAuthError func(error) bool
call func(*line.Client) (T, error)
}

func callLineWithRecovery[T any](ctx context.Context, client *line.Client, deps lineCallDeps[T]) (*line.Client, T, error) {
if client == nil {
client = deps.newClient()
}
res, err := deps.call(client)
if err == nil || !deps.isAuthError(err) {
return client, res, err
}

if errRecover := deps.recover(ctx); errRecover != nil {
var zero T
return client, zero, fmt.Errorf("failed to recover token after LINE auth error: %w", errRecover)
}

client = deps.newClient()
res, err = deps.call(client)
return client, res, err
}

func (lc *LineClient) isTokenError(err error) bool {
if line.IsNoUsableE2EEGroupKey(err) || line.IsNoUsableE2EEPublicKey(err) {
return false
}
return line.IsAuthError(err)
}

func (lc *LineClient) callLine(ctx context.Context, call func(*line.Client) error) (*line.Client, error) {
return lc.callLineUsing(ctx, nil, call)
}

func (lc *LineClient) callLineUsing(ctx context.Context, client *line.Client, call func(*line.Client) error) (*line.Client, error) {
client, _, err := callLineWithRecovery(ctx, client, lineCallDeps[struct{}]{
newClient: func() *line.Client { return lc.newClient() },
recover: lc.recoverToken,
isAuthError: lc.isTokenError,
call: func(client *line.Client) (struct{}, error) {
return struct{}{}, call(client)
},
})
return client, err
}

func callLineResult[T any](lc *LineClient, ctx context.Context, call func(*line.Client) (T, error)) (*line.Client, T, error) {
return callLineResultUsing(lc, ctx, nil, call)
}

func callLineResultUsing[T any](lc *LineClient, ctx context.Context, client *line.Client, call func(*line.Client) (T, error)) (*line.Client, T, error) {
return callLineWithRecovery(ctx, client, lineCallDeps[T]{
newClient: func() *line.Client { return lc.newClient() },
recover: lc.recoverToken,
isAuthError: lc.isTokenError,
call: call,
})
}
261 changes: 261 additions & 0 deletions pkg/connector/auth_recovery_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,261 @@
package connector

import (
"context"
"errors"
"strings"
"sync"
"sync/atomic"
"testing"
"time"

"github.com/highesttt/matrix-line-messenger/pkg/line"
)

var (
errAuthRequired = errors.New(`API error 400: {"code":10051,"message":"RESPONSE_ERROR","data":{"name":"TalkException","code":119,"reason":"Access token refresh required"}}`)
errNotMember = errors.New(`API error 400: {"code":10051,"data":{"name":"TalkException","code":10,"reason":"not a member"}}`)
errNetwork = errors.New("request failed: dial tcp: i/o timeout")
)

func TestCallLineWithRecovery(t *testing.T) {
tests := []struct {
name string
callErrors []error
recoverErr error
wantCalls int
wantRecover int
wantErr error
wantErrPrefix string
}{
{
name: "success without recovery",
callErrors: []error{nil},
wantCalls: 1,
},
{
name: "non auth error is returned without recovery",
callErrors: []error{errNotMember},
wantCalls: 1,
wantRecover: 0,
wantErr: errNotMember,
},
{
name: "network error is returned without recovery",
callErrors: []error{errNetwork},
wantCalls: 1,
wantRecover: 0,
wantErr: errNetwork,
},
{
name: "auth error recovers and retries once",
callErrors: []error{errAuthRequired, nil},
wantCalls: 2,
wantRecover: 1,
},
{
name: "recovery failure is returned without retry",
callErrors: []error{errAuthRequired},
recoverErr: errors.New("refresh failed"),
wantCalls: 1,
wantRecover: 1,
wantErrPrefix: "failed to recover token after LINE auth error",
},
{
name: "retry auth error is not retried again",
callErrors: []error{errAuthRequired, errAuthRequired},
wantCalls: 2,
wantRecover: 1,
wantErr: errAuthRequired,
},
{
name: "retry non auth error is returned to caller",
callErrors: []error{errAuthRequired, errors.New("Extension does not support file upload")},
wantCalls: 2,
wantRecover: 1,
wantErrPrefix: "Extension does not support file upload",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var calls int
var recoveries int

_, _, err := callLineWithRecovery(context.Background(), nil, lineCallDeps[struct{}]{
newClient: func() *line.Client {
return line.NewClient("token")
},
recover: func(context.Context) error {
recoveries++
return tt.recoverErr
},
isAuthError: line.IsAuthError,
call: func(*line.Client) (struct{}, error) {
err := tt.callErrors[calls]
calls++
return struct{}{}, err
},
})

if calls != tt.wantCalls {
t.Fatalf("calls = %d, want %d", calls, tt.wantCalls)
}
if recoveries != tt.wantRecover {
t.Fatalf("recoveries = %d, want %d", recoveries, tt.wantRecover)
}
if tt.wantErr != nil && !errors.Is(err, tt.wantErr) {
t.Fatalf("err = %v, want %v", err, tt.wantErr)
}
if tt.wantErrPrefix != "" {
if err == nil || !strings.Contains(err.Error(), tt.wantErrPrefix) {
t.Fatalf("err = %v, want containing %q", err, tt.wantErrPrefix)
}
}
if tt.wantErr == nil && tt.wantErrPrefix == "" && err != nil {
t.Fatalf("unexpected err: %v", err)
}
})
}
}

func TestCallLineWithRecoveryReusesClientUntilRecovery(t *testing.T) {
ctx := context.Background()
initialClient := line.NewClient("initial")
refreshedClient := line.NewClient("refreshed")
var newClients int
var calls []string

client, _, err := callLineWithRecovery(ctx, initialClient, lineCallDeps[struct{}]{
newClient: func() *line.Client {
newClients++
return refreshedClient
},
recover: func(context.Context) error {
return nil
},
isAuthError: line.IsAuthError,
call: func(client *line.Client) (struct{}, error) {
calls = append(calls, client.AccessToken)
if len(calls) == 1 {
return struct{}{}, errAuthRequired
}
return struct{}{}, nil
},
})
if err != nil {
t.Fatalf("unexpected err: %v", err)
}
if client != refreshedClient {
t.Fatal("expected recovered client to be returned")
}
if newClients != 1 {
t.Fatalf("new clients = %d, want 1", newClients)
}
if len(calls) != 2 || calls[0] != "initial" || calls[1] != "refreshed" {
t.Fatalf("calls used clients %v, want [initial refreshed]", calls)
}
}

func TestCallLineWithRecoveryUsesProvidedClientWithoutRecreating(t *testing.T) {
ctx := context.Background()
initialClient := line.NewClient("initial")
var newClients int

client, _, err := callLineWithRecovery(ctx, initialClient, lineCallDeps[struct{}]{
newClient: func() *line.Client {
newClients++
return line.NewClient("unexpected")
},
recover: func(context.Context) error { return nil },
isAuthError: line.IsAuthError,
call: func(client *line.Client) (struct{}, error) {
if client.AccessToken != "initial" {
t.Fatalf("client token = %q, want initial", client.AccessToken)
}
return struct{}{}, nil
},
})
if err != nil {
t.Fatalf("unexpected err: %v", err)
}
if client != initialClient {
t.Fatal("expected provided client to be returned")
}
if newClients != 0 {
t.Fatalf("new clients = %d, want 0", newClients)
}
}

func TestLineClientIsTokenErrorExcludesE2EEErrors(t *testing.T) {
lc := &LineClient{}
if !lc.isTokenError(errAuthRequired) {
t.Fatal("expected auth-required error to be classified as token error")
}
if lc.isTokenError(line.ErrNoUsableE2EEGroupKey) {
t.Fatal("E2EE group key errors must not trigger token recovery")
}
if lc.isTokenError(line.ErrNoUsableE2EEPublicKey) {
t.Fatal("E2EE public key errors must not trigger token recovery")
}
}

func TestRunTokenRecoverySkipsRecentRecovery(t *testing.T) {
lc := &LineClient{recoverTime: time.Now()}
var calls int

err := lc.runTokenRecovery(context.Background(), func(context.Context) error {
calls++
return nil
})
if err != nil {
t.Fatalf("unexpected err: %v", err)
}
if calls != 0 {
t.Fatalf("recovery calls = %d, want 0", calls)
}
}

func TestRunTokenRecoverySerializesConcurrentRecovery(t *testing.T) {
var lc LineClient
var calls int32
started := make(chan struct{})
release := make(chan struct{})

recover := func(context.Context) error {
if atomic.AddInt32(&calls, 1) == 1 {
close(started)
<-release
}
return nil
}

var wg sync.WaitGroup
errs := make(chan error, 4)
for i := 0; i < 4; i++ {
wg.Add(1)
go func() {
defer wg.Done()
errs <- lc.runTokenRecovery(context.Background(), recover)
}()
}

select {
case <-started:
case <-time.After(time.Second):
t.Fatal("first recovery did not start")
}

close(release)
wg.Wait()
close(errs)

for err := range errs {
if err != nil {
t.Fatalf("unexpected err: %v", err)
}
}
if got := atomic.LoadInt32(&calls); got != 1 {
t.Fatalf("recovery calls = %d, want 1", got)
}
}
Loading