From 4361989d36a099488f19c67898b3801363a67c4c Mon Sep 17 00:00:00 2001 From: Minh Vu Date: Sat, 13 Jun 2026 17:41:02 +0200 Subject: [PATCH] fix(server): include nodes and creds in request dedupe Signed-off-by: Minh Vu --- pkg/server/http_server.go | 2 + pkg/server/http_server_test.go | 55 ++++++++++ pkg/server/trailing_delay_queue.go | 3 +- pkg/server/trailing_delay_queue_test.go | 73 +++++++++++++ pkg/topology/request.go | 134 +++++++++++++++++++++++- pkg/topology/request_test.go | 105 +++++++++++++++++++ 6 files changed, 367 insertions(+), 5 deletions(-) diff --git a/pkg/server/http_server.go b/pkg/server/http_server.go index d16000dc..d47cce91 100644 --- a/pkg/server/http_server.go +++ b/pkg/server/http_server.go @@ -204,6 +204,8 @@ func readRequest(w http.ResponseWriter, r *http.Request) *topology.Request { } } + tr.Provider.Creds = checkCredentials(tr.Provider.Creds, srv.cfg.Credentials) + klog.Info(tr.String()) if err = validate(tr); err != nil { diff --git a/pkg/server/http_server_test.go b/pkg/server/http_server_test.go index cdecca90..b62ff107 100644 --- a/pkg/server/http_server_test.go +++ b/pkg/server/http_server_test.go @@ -621,6 +621,61 @@ func TestReadRequest(t *testing.T) { } } +func TestReadRequestAppliesEffectiveCredentials(t *testing.T) { + configCreds := map[string]any{"token": "config-token"} + payloadCreds := map[string]any{"token": "payload-token"} + + srv = &HttpServer{ + cfg: &config.Config{ + Provider: "test", + Engine: "slurm", + Credentials: configCreds, + }, + } + + testCases := []struct { + name string + payload string + expected map[string]any + }{ + { + name: "Test readRequest with config credentials", + payload: fmt.Sprintf(simpleSlurmPayload, "test"), + expected: configCreds, + }, + { + name: "Test readRequest with payload credentials", + payload: `{ + "provider": { + "name": "test", + "creds": { + "token": "payload-token" + } + }, + "engine": { + "name": "slurm" + } + }`, + expected: payloadCreds, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + r := &http.Request{ + Method: http.MethodPost, + Body: io.NopCloser(bytes.NewBuffer([]byte(tc.payload))), + } + + w := httptest.NewRecorder() + req := readRequest(w, r) + + require.NotNil(t, req) + require.Equal(t, tc.expected, req.Provider.Creds) + }) + } +} + func readInvalidRequest(t *testing.T, payload, msg string) { r := &http.Request{ Method: http.MethodPost, diff --git a/pkg/server/trailing_delay_queue.go b/pkg/server/trailing_delay_queue.go index f8f74376..11fea2f8 100644 --- a/pkg/server/trailing_delay_queue.go +++ b/pkg/server/trailing_delay_queue.go @@ -122,7 +122,8 @@ func (q *TrailingDelayQueue) Get(hash string) *Completion { defer q.mutex.Unlock() if res, ok := q.store.Get(hash); ok { - return res.(*Completion) + completion := *(res.(*Completion)) + return &completion } return &Completion{ diff --git a/pkg/server/trailing_delay_queue_test.go b/pkg/server/trailing_delay_queue_test.go index 9d72248a..40b263fc 100644 --- a/pkg/server/trailing_delay_queue_test.go +++ b/pkg/server/trailing_delay_queue_test.go @@ -88,6 +88,79 @@ func TestVaryingPayload(t *testing.T) { queue.Shutdown() } +func TestVaryingPayloadByNodesAndCredentials(t *testing.T) { + processItem := func(item any) (any, *httperr.Error) { + return item, nil + } + + queue := NewTrailingDelayQueue(processItem, 10*time.Millisecond) + defer queue.Shutdown() + + requests := []*topology.Request{ + { + Provider: topology.Provider{ + Name: "test", + Creds: map[string]any{"token": "a"}, + }, + Engine: topology.Engine{Name: "slurm"}, + Nodes: []topology.ComputeInstances{ + { + Region: "region", + Instances: map[string]string{"instance-1": "node-1"}, + }, + }, + }, + { + Provider: topology.Provider{ + Name: "test", + Creds: map[string]any{"token": "a"}, + }, + Engine: topology.Engine{Name: "slurm"}, + Nodes: []topology.ComputeInstances{ + { + Region: "region", + Instances: map[string]string{"instance-2": "node-2"}, + }, + }, + }, + { + Provider: topology.Provider{ + Name: "test", + Creds: map[string]any{"token": "b"}, + }, + Engine: topology.Engine{Name: "slurm"}, + Nodes: []topology.ComputeInstances{ + { + Region: "region", + Instances: map[string]string{"instance-1": "node-1"}, + }, + }, + }, + } + + submissions := make([]string, 0, len(requests)) + for _, request := range requests { + uid, err := queue.Submit(request) + require.NoError(t, err) + submissions = append(submissions, uid) + } + + for i := 1; i < len(submissions); i++ { + require.NotEqual(t, submissions[i], submissions[i-1]) + } + require.NotEqual(t, submissions[0], submissions[2]) + + require.Eventually(t, func() bool { + for i, uid := range submissions { + res := queue.Get(uid) + if res.Status != http.StatusOK || res.Ret != requests[i] { + return false + } + } + return true + }, time.Second, 10*time.Millisecond) +} + func TestLRU(t *testing.T) { cache, _ := lru.New(3) diff --git a/pkg/topology/request.go b/pkg/topology/request.go index 1c89ea69..1e172153 100644 --- a/pkg/topology/request.go +++ b/pkg/topology/request.go @@ -17,11 +17,22 @@ package topology import ( + "crypto/hmac" + "crypto/rand" + "crypto/sha256" + "encoding/hex" "encoding/json" "fmt" "hash/fnv" "sort" "strings" + "sync" +) + +var ( + credentialHashKey []byte + credentialHashKeyErr error + credentialHashKeyOnce sync.Once ) type Request struct { @@ -46,6 +57,28 @@ type ComputeInstances struct { Instances map[string]string `json:"instances"` // : map } +type requestHashData struct { + Provider providerHashData `json:"provider"` + Engine Engine `json:"engine"` + Nodes []computeInstancesHash `json:"nodes,omitempty"` +} + +type providerHashData struct { + Name string `json:"name"` + Params map[string]any `json:"params,omitempty"` + CredentialDigest string `json:"credentialDigest,omitempty"` +} + +type computeInstancesHash struct { + Region string `json:"region"` + Instances []instanceHash `json:"instances"` +} + +type instanceHash struct { + ID string `json:"id"` + Node string `json:"node"` +} + func NewRequest(prv Provider, eng Engine) *Request { return &Request{ Provider: prv, @@ -141,19 +174,112 @@ func GetNodeNameMap(cis []ComputeInstances) map[string]bool { } func (p *Request) Hash() (string, error) { - dataToHash := Request{ - Provider: Provider{ - Name: p.Provider.Name, - Params: p.Provider.Params, + credentialDigest, err := getCredentialDigest(p.Provider.Creds) + if err != nil { + return "", err + } + + dataToHash := requestHashData{ + Provider: providerHashData{ + Name: p.Provider.Name, + Params: p.Provider.Params, + CredentialDigest: credentialDigest, }, Engine: Engine{ Name: p.Engine.Name, Params: p.Engine.Params, }, + Nodes: canonicalComputeInstances(p.Nodes), } return GetHash(dataToHash) } +func canonicalComputeInstances(nodes []ComputeInstances) []computeInstancesHash { + if len(nodes) == 0 { + return nil + } + + canonical := make([]computeInstancesHash, 0, len(nodes)) + for _, nodeGroup := range nodes { + instances := make([]instanceHash, 0, len(nodeGroup.Instances)) + for id, node := range nodeGroup.Instances { + instances = append(instances, instanceHash{ + ID: id, + Node: node, + }) + } + sort.Slice(instances, func(i, j int) bool { + if instances[i].ID != instances[j].ID { + return instances[i].ID < instances[j].ID + } + return instances[i].Node < instances[j].Node + }) + + canonical = append(canonical, computeInstancesHash{ + Region: nodeGroup.Region, + Instances: instances, + }) + } + + sort.Slice(canonical, func(i, j int) bool { + if canonical[i].Region != canonical[j].Region { + return canonical[i].Region < canonical[j].Region + } + if len(canonical[i].Instances) != len(canonical[j].Instances) { + return len(canonical[i].Instances) < len(canonical[j].Instances) + } + for idx := range canonical[i].Instances { + if canonical[i].Instances[idx].ID != canonical[j].Instances[idx].ID { + return canonical[i].Instances[idx].ID < canonical[j].Instances[idx].ID + } + if canonical[i].Instances[idx].Node != canonical[j].Instances[idx].Node { + return canonical[i].Instances[idx].Node < canonical[j].Instances[idx].Node + } + } + return false + }) + + return canonical +} + +func getCredentialDigest(creds map[string]any) (string, error) { + if len(creds) == 0 { + return "", nil + } + + data, err := json.Marshal(creds) + if err != nil { + return "", fmt.Errorf("failed to marshal credentials for hashing: %v", err) + } + + key, err := getCredentialHashKey() + if err != nil { + return "", err + } + + mac := hmac.New(sha256.New, key) + _, _ = mac.Write(data) + return hex.EncodeToString(mac.Sum(nil)), nil +} + +func getCredentialHashKey() ([]byte, error) { + credentialHashKeyOnce.Do(func() { + credentialHashKey, credentialHashKeyErr = newCredentialHashKey() + }) + if credentialHashKeyErr != nil { + return nil, credentialHashKeyErr + } + return credentialHashKey, nil +} + +func newCredentialHashKey() ([]byte, error) { + key := make([]byte, sha256.Size) + if _, err := rand.Read(key); err != nil { + return nil, fmt.Errorf("failed to generate credential hash key: %v", err) + } + return key, nil +} + func GetHash(obj any) (string, error) { data, err := json.Marshal(obj) if err != nil { diff --git a/pkg/topology/request_test.go b/pkg/topology/request_test.go index bc15807d..bc56cdb2 100644 --- a/pkg/topology/request_test.go +++ b/pkg/topology/request_test.go @@ -177,3 +177,108 @@ func TestGetNodeNames(t *testing.T) { require.ElementsMatch(t, nodeList, GetNodeNameList(cis)) require.Equal(t, nodeMap, GetNodeNameMap(cis)) } + +func TestRequestHashIncludesCanonicalNodes(t *testing.T) { + newRequest := func(nodes []ComputeInstances) *Request { + return &Request{ + Provider: Provider{ + Name: "test", + Params: map[string]any{"providerParam": "value"}, + }, + Engine: Engine{ + Name: "slurm", + Params: map[string]any{"engineParam": "value"}, + }, + Nodes: nodes, + } + } + + nodes := []ComputeInstances{ + { + Region: "region-b", + Instances: map[string]string{"instance-2": "node-2", "instance-1": "node-1"}, + }, + { + Region: "region-a", + Instances: map[string]string{"instance-3": "node-3"}, + }, + } + reorderedNodes := []ComputeInstances{ + { + Region: "region-a", + Instances: map[string]string{"instance-3": "node-3"}, + }, + { + Region: "region-b", + Instances: map[string]string{"instance-1": "node-1", "instance-2": "node-2"}, + }, + } + differentNodes := []ComputeInstances{ + { + Region: "region-b", + Instances: map[string]string{"instance-2": "node-20", "instance-1": "node-1"}, + }, + { + Region: "region-a", + Instances: map[string]string{"instance-3": "node-3"}, + }, + } + emptyNodeGroup := []ComputeInstances{{Region: "region-empty"}} + + hash, err := newRequest(nodes).Hash() + require.NoError(t, err) + reorderedHash, err := newRequest(reorderedNodes).Hash() + require.NoError(t, err) + differentHash, err := newRequest(differentNodes).Hash() + require.NoError(t, err) + noNodesHash, err := newRequest(nil).Hash() + require.NoError(t, err) + emptyNodeGroupHash, err := newRequest(emptyNodeGroup).Hash() + require.NoError(t, err) + + require.Equal(t, hash, reorderedHash) + require.NotEqual(t, hash, differentHash) + require.NotEqual(t, noNodesHash, emptyNodeGroupHash) +} + +func TestRequestHashIncludesCredentialDigest(t *testing.T) { + newRequest := func(creds map[string]any) *Request { + return &Request{ + Provider: Provider{ + Name: "aws", + Creds: creds, + Params: map[string]any{"trimTiers": 1}, + }, + Engine: Engine{ + Name: "slurm", + Params: map[string]any{"plugin": TopologyBlock}, + }, + } + } + + creds := map[string]any{ + "accessKeyId": "id", + "secretAccessKey": "secret", + } + reorderedCreds := map[string]any{ + "secretAccessKey": "secret", + "accessKeyId": "id", + } + differentCreds := map[string]any{ + "accessKeyId": "id", + "secretAccessKey": "other-secret", + } + + hash, err := newRequest(creds).Hash() + require.NoError(t, err) + reorderedHash, err := newRequest(reorderedCreds).Hash() + require.NoError(t, err) + differentHash, err := newRequest(differentCreds).Hash() + require.NoError(t, err) + noCredsHash, err := newRequest(nil).Hash() + require.NoError(t, err) + + require.Equal(t, hash, reorderedHash) + require.NotEqual(t, hash, differentHash) + require.NotEqual(t, hash, noCredsHash) +}