3 changes: 2 additions & 1 deletion specs/discovery-and-jwks-simplification.md
@@ -168,6 +168,7 @@ public Set<String> keyIds(); // unmodifiable, JWKS-endpoint or
public void refresh();
public int consecutiveFailures();
public Instant lastFailedRefresh();
public Instant lastRefreshAttempt();
public Instant lastSuccessfulRefresh();
public Instant nextDueAt();
@Override public void close();
@@ -185,7 +186,7 @@ For `JWKS.of(...)`:
- `resolve` and `get` work identically to the remote case.
- `refresh()` is a no-op (returns normally — the snapshot is already complete).
- `consecutiveFailures()` returns 0.
- `lastFailedRefresh()`, `lastSuccessfulRefresh()`, `nextDueAt()` return null.
- `lastFailedRefresh()`, `lastRefreshAttempt()`, `lastSuccessfulRefresh()`, `nextDueAt()` return null.
- `close()` is a no-op (no scheduler, no inflight worker, no thread to interrupt).
- `JWKS.of()` with no keys (or `of(List.of())`) is permitted, returns a non-null instance, and is not rejected at construction. `keys()` and `keyIds()` return empty collections; `get(kid)` returns null for any input; `resolve(header)` raises `MissingVerifierException` for any header. This is the same behavior as a remote-backed `JWKS` whose snapshot happens to be empty.

39 changes: 25 additions & 14 deletions specs/jwks-source.md
@@ -63,8 +63,9 @@ public final class JWKSource implements VerifierResolver, AutoCloseable {
// Observability (lock-free reads off the current snapshot)
public Instant lastSuccessfulRefresh(); // null if no successful refresh yet
public Instant lastFailedRefresh(); // null if no failure since the last success
public Instant lastRefreshAttempt(); // null if no attempt yet; advances on success and failure
public int consecutiveFailures(); // 0 on the success path
public Instant nextDueAt(); // earliest time at which a refresh is allowed to start
public Instant nextDueAt(); // scheduler's next-eligible refresh time; on-miss uses lastRefreshAttempt + minRefreshInterval
public Set<String> currentKids(); // unmodifiable snapshot of kids in the cache at call time
}
```
@@ -123,7 +124,7 @@ public enum CacheControlPolicy {
|---|---|---|
| `scheduledRefresh` | `false` | Most callers want lazy-warm + miss-driven refresh; opt in to a background thread. |
| `refreshInterval` | `60 minutes` | Matches the `max-age` most IdPs publish on JWKS responses; conservative wrt rotation. |
| `refreshOnMiss` | `true` | Unknown `kid` should trigger a fetch. The combination of singleflight + `nextDueAt` bounds amplification. |
| `refreshOnMiss` | `true` | Unknown `kid` should trigger a fetch. Singleflight + the `minRefreshInterval` on-miss debounce bound amplification. |
| `refreshTimeout` | `2 seconds` | Bounds blocking on `resolve()` during a miss. Long enough for healthy networks, short enough to fail fast on a wedged IdP. |
| `minRefreshInterval` | `30 seconds` | Floor for both the scheduler tick rate and the on-miss debounce. Hard cap on amplification under attack. |
| `cacheControlPolicy` | `CLAMP` | Honor IdP-published `max-age` when sane, but never refresh more often than `minRefreshInterval` and never wait longer than `refreshInterval` between refreshes. |
@@ -197,18 +198,19 @@ The cache state is a single immutable snapshot:
record Snapshot(
Map<String, Verifier> byKid,
Instant fetchedAt, // time of the snapshot's last successful fetch (Instant.EPOCH if never)
Instant nextDueAt, // earliest time at which a refresh is allowed to start
Instant nextDueAt, // earliest time at which the scheduler may start a refresh
int consecutiveFailures, // 0 on the success path
Instant lastFailedRefresh // null if no recorded failure since the last success
Instant lastFailedRefresh, // null if no recorded failure since the last success
Instant lastAttemptAt // time of the snapshot's last refresh attempt, success or failure (Instant.EPOCH if never)
) {}
```

It is held in `AtomicReference<Snapshot>`. Reads (`resolve()`, `currentKids()`, observability getters) load the reference and read fields off the snapshot — there are no locks on the read path.

`build()` performs a synchronous initial load, bounded by `refreshTimeout`:

- **On success:** snapshot installed with `consecutiveFailures=0`, `fetchedAt=now`, `nextDueAt` per §2.4, `lastFailedRefresh=null`.
- **On failure:** snapshot installed with `byKid=emptyMap`, `consecutiveFailures=1`, `fetchedAt=Instant.EPOCH`, `lastFailedRefresh=now`, `nextDueAt` per the failure path in §2.7. Failure is logged at `error`. `build()` returns normally; `lastSuccessfulRefresh()` returns `null`.
- **On success:** snapshot installed with `consecutiveFailures=0`, `fetchedAt=now`, `lastAttemptAt=now`, `nextDueAt` per §2.4, `lastFailedRefresh=null`.
- **On failure:** snapshot installed with `byKid=emptyMap`, `consecutiveFailures=1`, `fetchedAt=Instant.EPOCH`, `lastFailedRefresh=now`, `lastAttemptAt=now`, `nextDueAt` per the failure path in §2.7. Failure is logged at `error`. `build()` returns normally; `lastSuccessfulRefresh()` returns `null`.

Operators wanting fail-fast on initial load check `lastSuccessfulRefresh() == null` after `build()` and act accordingly. The library does not throw from `build()` on a network failure, by design — it preserves the same "availability over freshness" stance as the runtime failure path (§2.7), so a brief IdP outage at boot does not make the application unstartable.
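
A minimal sketch of the operator-side fail-fast check described above. `RefreshStatusSketch` and `StartupCheckSketch` are hypothetical names introduced for illustration; only the `lastSuccessfulRefresh()` getter and its "null if no successful refresh yet" contract come from the spec.

```java
import java.time.Instant;

// Hypothetical stand-in for the observability surface of the built source.
// Only lastSuccessfulRefresh() is taken from the spec; the interface is an assumption.
interface RefreshStatusSketch {
  Instant lastSuccessfulRefresh(); // null if no successful refresh yet
}

final class StartupCheckSketch {
  // Turns the "build() returns normally on failure" behavior into a
  // fail-fast startup check on the caller's side.
  static void requireInitialLoad(RefreshStatusSketch status) {
    if (status.lastSuccessfulRefresh() == null) {
      throw new IllegalStateException("Initial JWKS load failed; refusing to start");
    }
  }
}
```

An application that prefers availability over freshness would skip this check and rely on the later on-miss or scheduled refresh instead.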

@@ -221,13 +223,13 @@ Operators wanting fail-fast on initial load check `lastSuccessfulRefresh() == nu
if !v.canVerify(header.alg()): return null
return v
4. if !refreshOnMiss: return null
5. if now < snapshot.nextDueAt: return null // bounded by minRefreshInterval-derived window
5. if now < snapshot.lastAttemptAt + minRefreshInterval: return null // on-miss debounce
6. fresh = singleflight.refresh() // blocks up to refreshTimeout
7. v = fresh.byKid.get(header.kid())
8. apply step 3's canVerify check; return v or null
```

Step 5 is the DoS gate: even if 10,000 concurrent decoders all see the same unknown `kid`, only the first one past the `nextDueAt` window starts a fetch; the rest see `nextDueAt > now` and return `null` immediately.
Step 5 is the DoS gate: even if 10,000 concurrent decoders all see the same unknown `kid`, only the first one past the `minRefreshInterval` debounce starts a fetch; the rest return `null` immediately. The debounce is intentionally distinct from the scheduler's `nextDueAt` (§2.4) — see §2.4.1 for why.

If step 6's await elapses at `refreshTimeout` before the in-flight refresh completes, the in-flight fetch continues asynchronously; the await returns the current `ref.get()` (the pre-refresh snapshot), the on-miss path returns `null`, and a later decode benefits from the eventually-installed snapshot. The timeout is not a refresh failure; see §2.7.4.
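
The timeout behavior of step 6 can be sketched with standard `java.util.concurrent` types. `AwaitSketch.awaitOrCurrent` is a hypothetical helper, not the library's method, and falling back to the current snapshot on `ExecutionException` and on interruption is an assumption made to keep the example small.

```java
import java.time.Duration;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicReference;

final class AwaitSketch {
  // Waits up to `timeout` for the in-flight refresh. If the wait elapses, the
  // caller gets whatever snapshot is currently installed; the future is not
  // cancelled, so the fetch keeps running and can install its result later.
  static <T> T awaitOrCurrent(CompletableFuture<T> inFlight, AtomicReference<T> ref, Duration timeout) {
    try {
      return inFlight.get(timeout.toMillis(), TimeUnit.MILLISECONDS);
    } catch (TimeoutException e) {
      return ref.get(); // a timeout is not treated as a refresh failure
    } catch (InterruptedException e) {
      Thread.currentThread().interrupt(); // preserve the interrupt flag
      return ref.get();
    } catch (ExecutionException e) {
      return ref.get(); // assumption: serve the prior snapshot when the refresh itself failed
    }
  }
}
```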

@@ -249,21 +251,30 @@ Synchronous, blocking, singleflight-coalesced. If a refresh is already in flight

The snapshot is updated per §2.7 (prior keys preserved, `consecutiveFailures` incremented, `nextDueAt` advanced) before the exception leaves the method, *except* for `TIMEOUT` — which does not signal a refresh failure (see §2.7.4). Operators can dispatch on `e.reason()` (e.g., escalate `NON_2XX` to a health probe, swallow `TIMEOUT` quietly) without inspecting the cause chain.

`refresh()` ignores `nextDueAt`. The gate exists to defend against amplification on the on-miss / scheduler paths, not to throttle deliberate operator action.
`refresh()` ignores both `nextDueAt` and the on-miss debounce. Those gates exist to defend against amplification on the scheduler and on-miss paths respectively, not to throttle deliberate operator action.

If the source has been closed, `refresh()` is a no-op and logs at `debug`.

### 2.4 The `nextDueAt` watermark

`nextDueAt` is the unified "when is the next refresh allowed to start" signal. It is consulted by both the scheduler tick (§2.5) and the on-miss path (§2.2), so the two paths cannot fight each other.
`nextDueAt` is the scheduler's "when is the next refresh allowed to start" signal. It is consulted only by the scheduler tick (§2.5). The on-miss path uses a separate debounce (§2.4.1).

After a successful refresh:
- `nextDueAt = max(now + minRefreshInterval, now + chosenInterval)` where `chosenInterval` depends on `cacheControlPolicy` (see §2.6).
- `nextDueAt = now + chosenInterval` where `chosenInterval` depends on `cacheControlPolicy` (see §2.6). `chosenInterval` is itself clamped to `[minRefreshInterval, refreshInterval]`, so `nextDueAt` is always at least `now + minRefreshInterval`.

After a failed refresh:
- `nextDueAt = now + backoff(consecutiveFailures)` (see §2.7).
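
As a rough illustration of the success-path rule, the sketch below clamps a server-published `max-age` into `[minRefreshInterval, refreshInterval]` and derives `nextDueAt` from it. `NextDueSketch` and its method names are hypothetical; the sketch assumes `minRefreshInterval <= refreshInterval` and does not cover the case where no `max-age` is published.

```java
import java.time.Duration;
import java.time.Instant;

final class NextDueSketch {
  // Clamps maxAge into [min, max]; assumes min <= max.
  static Duration clamp(Duration maxAge, Duration min, Duration max) {
    if (maxAge.compareTo(min) < 0) return min;
    if (maxAge.compareTo(max) > 0) return max;
    return maxAge;
  }

  // nextDueAt after a successful refresh: now + clamped interval.
  // Because the interval is never below min, no extra floor is needed here.
  static Instant nextDueAfterSuccess(Instant now, Duration maxAge, Duration min, Duration max) {
    return now.plus(clamp(maxAge, min, max));
  }
}
```

With the default `minRefreshInterval` of 30 seconds and `refreshInterval` of 60 minutes, a 5-second `max-age` yields a 30-second interval, a 10-minute `max-age` is used unchanged, and a 24-hour `max-age` is capped at 60 minutes.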

The on-miss path checks `now < nextDueAt` to decide whether to debounce. The scheduler tick checks the same condition before dispatching a refresh. There is no second cooldown variable.
### 2.4.1 The on-miss debounce

The on-miss path (§2.2) uses `lastAttemptAt + minRefreshInterval` as its debounce, independent of `nextDueAt`. `lastAttemptAt` is set to `now` on every refresh attempt — success or failure.

The two watermarks serve different purposes and cannot be unified:

- `nextDueAt` paces the scheduler against the IdP's published cache directive (often 5–60 minutes via `Cache-Control: max-age`). Unifying it with the on-miss debounce would block on-miss refresh for the full `chosenInterval` after a successful fetch, defeating the rotation use case: an IdP that rotates keys mid-interval would produce JWTs with unknown `kid`s that the source refuses to fetch keys for.
- The on-miss debounce caps amplification: 10,000 concurrent unknown-`kid` resolves are coalesced by singleflight, and subsequent waves are throttled to one fetch per `minRefreshInterval` (default 30s).

The two watermarks are independent. A successful refresh advances both; a refresh inside the `nextDueAt` window dispatched from the on-miss path will subsequently advance the scheduler's `nextDueAt` too (via the singleflight worker installing a fresh snapshot).
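
The difference between the two gates can be made concrete with a small sketch. `RefreshGatesSketch` is a hypothetical class; the on-miss predicate mirrors the `lastAttemptAt + minRefreshInterval` comparison described above, while the scheduler predicate is an assumption about how the tick consults `nextDueAt`.

```java
import java.time.Duration;
import java.time.Instant;

final class RefreshGatesSketch {
  // Scheduler gate (assumed form): a tick may dispatch once nextDueAt is reached.
  static boolean schedulerMayRefresh(Instant now, Instant nextDueAt) {
    return !now.isBefore(nextDueAt);
  }

  // On-miss gate: allowed once minRefreshInterval has elapsed since the last
  // attempt, regardless of whether that attempt succeeded or failed.
  static boolean onMissMayRefresh(Instant now, Instant lastAttemptAt, Duration minRefreshInterval) {
    return !now.isBefore(lastAttemptAt.plus(minRefreshInterval));
  }
}
```

Take a successful refresh at `t0` with a 60-minute `chosenInterval` and a 30-second `minRefreshInterval`. Five minutes later the scheduler gate is still closed, but the on-miss gate is open, so a token signed with a freshly rotated `kid` can still trigger a fetch.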

### 2.5 Scheduler tick

@@ -312,9 +323,9 @@ When a refresh raises (network failure, non-2xx response, parse failure, etc.):
#### 2.7.3 Caller-visible behavior during failure

- `resolve()` continues to return cached verifiers from the prior successful snapshot.
- Misses against unknown `kid`s return `null` immediately once `nextDueAt` is in the future.
- Misses against unknown `kid`s return `null` immediately while the on-miss debounce (`lastRefreshAttempt + minRefreshInterval`) is in the future.
- `lastSuccessfulRefresh()` does not advance; an integrator monitoring this can alert on staleness.
- `lastFailedRefresh()` advances to `now`; `consecutiveFailures()` increments.
- `lastFailedRefresh()` advances to `now`; `lastRefreshAttempt()` also advances to `now`; `consecutiveFailures()` increments.

There is no separate "circuit breaker open" state; the exponential-backoff `nextDueAt` *is* the circuit. After enough consecutive failures, `nextDueAt` settles at `now + refreshInterval` and the source effectively reverts to "try every full interval until something changes".

32 changes: 17 additions & 15 deletions src/main/java/org/lattejava/jwt/jwks/JWKS.java
@@ -60,7 +60,7 @@ private JWKS(Builder b) {
this.source = b.source;
this.staticMode = false;
this.url = b.url();
this.ref.set(new Snapshot(List.of(), Map.of(), Map.of(), Instant.EPOCH, Instant.EPOCH, 0, null));
this.ref.set(new Snapshot(List.of(), Map.of(), Map.of(), Instant.EPOCH, Instant.EPOCH, 0, null, Instant.EPOCH));
CompletableFuture<Snapshot> initial = singleflightRefresh();
try {
initial.get(refreshTimeout.toMillis(), TimeUnit.MILLISECONDS);
@@ -134,7 +134,8 @@ private JWKS(List<JSONWebKey> staticKeys) {
Instant.EPOCH,
Instant.EPOCH,
0,
null));
null,
Instant.EPOCH));
}

// --- Public static methods ---
@@ -321,10 +322,6 @@ private static JWKSFetchException classifyFetchFailure(String msg, Throwable cau
return new JWKSFetchException(JWKSFetchException.Reason.PARSE, msg, cause);
}

private static Duration maxOf(Duration a, Duration b) {
return a.compareTo(b) >= 0 ? a : b;
}

private static List<JSONWebKey> parseJWKSResponseKeys(HttpURLConnection conn, InputStream is, FetchLimits limits) {
Map<String, Object> map = HardenedJSON.parse(is, limits);
Object keys = map.get("keys");
@@ -397,6 +394,12 @@ public Instant lastFailedRefresh() {
return ref.get().lastFailedRefresh();
}

public Instant lastRefreshAttempt() {
if (staticMode) return null;
Snapshot s = ref.get();
return s.lastAttemptAt().equals(Instant.EPOCH) ? null : s.lastAttemptAt();
}

public Instant lastSuccessfulRefresh() {
if (staticMode) return null;
Snapshot s = ref.get();
@@ -461,7 +464,7 @@ public Verifier resolve(Header header) {
if (!refreshOnMiss) return null;

Instant now = Instant.now(clock);
if (now.isBefore(snapshot.nextDueAt())) return null;
if (now.isBefore(snapshot.lastAttemptAt().plus(minRefreshInterval))) return null;

CompletableFuture<Snapshot> fut = singleflightRefresh();
try {
@@ -485,8 +488,7 @@ public Verifier resolve(Header header) {

/**
* Returns the {@link Duration} to use for {@code nextDueAt}. Honors the server's {@code Cache-Control: max-age} when
* {@link CacheControlPolicy#CLAMP} is configured, clamped into {@code [minRefreshInterval, refreshInterval]}; the
* caller applies the {@code minRefreshInterval} floor again as a final guard.
* {@link CacheControlPolicy#CLAMP} is configured, clamped into {@code [minRefreshInterval, refreshInterval]}.
*/
private Duration chosenInterval(JWKSResponse resp) {
if (cacheControlPolicy == CacheControlPolicy.IGNORE) return refreshInterval;
@@ -585,15 +587,14 @@ private Snapshot doRefreshOrThrow(Snapshot prev) {
throw new JWKSFetchException(JWKSFetchException.Reason.EMPTY_RESULT,
"JWKS refresh produced no usable keys after JWK conversion");
}
Duration chosen = chosenInterval(resp);
Instant nextDue = now.plus(maxOf(minRefreshInterval, chosen));
Instant nextDue = now.plus(chosenInterval(resp));
if (logger.isInfoEnabled()) {
logger.info("JWKS refresh succeeded; kids=[" + byKid.keySet() + "]");
}
List<JSONWebKey> allKeysSnapshot = Collections.unmodifiableList(new ArrayList<>(allKeys));
Map<String, Verifier> byKidSnapshot = Collections.unmodifiableMap(new LinkedHashMap<>(byKid));
Map<String, JSONWebKey> jwkByKidSnapshot = Collections.unmodifiableMap(new LinkedHashMap<>(jwkByKid));
return new Snapshot(allKeysSnapshot, byKidSnapshot, jwkByKidSnapshot, now, nextDue, 0, null);
return new Snapshot(allKeysSnapshot, byKidSnapshot, jwkByKidSnapshot, now, nextDue, 0, null, now);
}

/**
@@ -629,7 +630,7 @@ private Snapshot failureSnapshot(Snapshot prev, Instant now, Throwable cause) {
}
}
}
return new Snapshot(allKeys, byKid, jwkByKid, fetchedAt, nextDue, next, now);
return new Snapshot(allKeys, byKid, jwkByKid, fetchedAt, nextDue, next, now, now);
}

private JWKSResponse fetchFromSource() {
@@ -821,7 +822,7 @@ public JWKS build() {
JWKS jwks = new JWKS(this);
if (failFast && jwks.initialFetchFailure != null) {
Throwable f = jwks.initialFetchFailure;
if (jwks.scheduler != null) jwks.scheduler.shutdownNow();
jwks.close();
if (f instanceof JWKSFetchException jfe) throw jfe;
if (f instanceof OpenIDConnectException oce) throw oce;
throw new JWKSFetchException(JWKSFetchException.Reason.PARSE, "Initial JWKS fetch failed", f);
@@ -906,6 +907,7 @@ record Snapshot(
Instant fetchedAt,
Instant nextDueAt,
int consecutiveFailures,
Instant lastFailedRefresh) {
Instant lastFailedRefresh,
Instant lastAttemptAt) {
}
}