Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 29 additions & 7 deletions internal/datasource/wikidata/gowikidata.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,25 +6,47 @@ type Gowikidata struct {
}

func (w Gowikidata) NewSearch(search string, language string) (SearchEntitiesRequest, error) {
return gowikidata.NewSearch(search, language)
req, err := gowikidata.NewSearch(search, language)
if err != nil {
return nil, err
}
return searchRequest{url: req.URL}, nil
}

func (w Gowikidata) NewGetEntities(ids []string) (GetEntitiesRequest, error) {
request, err := gowikidata.NewGetEntities(ids)
if err != nil {
return nil, err
}
return entitiesRequest{req: request}, nil
}

type searchRequest struct {
url string
}

return EntitiesRequest{request}, err
func (s searchRequest) Get() (*gowikidata.SearchEntitiesResponse, error) {
var response gowikidata.SearchEntitiesResponse
if err := apiHTTPClient.getJSON(s.url, &response); err != nil {
return nil, err
}
return &response, nil
}

type EntitiesRequest struct {
type entitiesRequest struct {
req *gowikidata.WikiDataGetEntitiesRequest
}

func (e EntitiesRequest) SetProps(props []string) {
func (e entitiesRequest) SetProps(props []string) {
e.req.SetProps(props)
}
func (e EntitiesRequest) SetLanguages(languages []string) {
func (e entitiesRequest) SetLanguages(languages []string) {
e.req.SetLanguages(languages)
}
func (e EntitiesRequest) Get() (*map[string]gowikidata.Entity, error) {
return e.req.Get()
func (e entitiesRequest) Get() (*map[string]gowikidata.Entity, error) {
var response gowikidata.GetEntitiesResponse
if err := apiHTTPClient.getJSON(e.req.URL, &response); err != nil {
return nil, err
}
return &response.Entities, nil
}
114 changes: 114 additions & 0 deletions internal/datasource/wikidata/httpclient.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
package wikidata

import (
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"strconv"
"sync"
"time"
)

const defaultRetryAfter = 60 * time.Second

var apiHTTPClient = newRateLimitedHTTPClient()

type rateLimitedHTTPClient struct {
mu sync.Mutex
resumeAt time.Time
client *http.Client
}

func newRateLimitedHTTPClient() *rateLimitedHTTPClient {
return &rateLimitedHTTPClient{
client: &http.Client{},
}
}

func (c *rateLimitedHTTPClient) getJSON(url string, dest any) error {
for {
c.waitUntil("")

req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
return err
}
for key, value := range wikidataAPIHeaders() {
req.Header.Set(key, value)
}

resp, err := c.client.Do(req)
if err != nil {
return err
}

if resp.StatusCode == http.StatusTooManyRequests {
_ = resp.Body.Close()
c.waitUntil(resp.Header.Get("Retry-After"))
continue
}

if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
_ = resp.Body.Close()
return fmt.Errorf("request failed with status code %d: %s", resp.StatusCode, body)
}

body, err := io.ReadAll(resp.Body)
_ = resp.Body.Close()
if err != nil {
return err
}
return json.Unmarshal(body, dest)
}
}

// waitUntil blocks until the shared pause deadline has passed. When retryAfterHeader
// is non-empty, the deadline is extended using the Retry-After header value.
func (c *rateLimitedHTTPClient) waitUntil(retryAfterHeader string) {
c.mu.Lock()
var rateLimitWait time.Duration
if retryAfterHeader != "" {
rateLimitWait = parseRetryAfter(retryAfterHeader)
resumeAt := time.Now().Add(rateLimitWait)
if resumeAt.After(c.resumeAt) {
c.resumeAt = resumeAt
}
}
wait := time.Until(c.resumeAt)
c.mu.Unlock()
if retryAfterHeader != "" {
log.Printf("Wikidata rate limit reached, waiting %s before retrying", rateLimitWait)
}
if wait > 0 {
time.Sleep(wait)
}
}

func parseRetryAfter(header string) time.Duration {
if header == "" {
return defaultRetryAfter
}
if seconds, err := strconv.Atoi(header); err == nil {
if seconds <= 0 {
return defaultRetryAfter
}
return time.Duration(seconds) * time.Second
}
if retryTime, err := http.ParseTime(header); err == nil {
wait := time.Until(retryTime)
if wait <= 0 {
return defaultRetryAfter
}
return wait
}
return defaultRetryAfter
}

func wikidataAPIHeaders() map[string]string {
return map[string]string{
"User-Agent": "Mozilla/5.0 (compatible; coreander/1.0; +https://github.com/svera/coreander)",
}
}
93 changes: 93 additions & 0 deletions internal/datasource/wikidata/httpclient_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
package wikidata

import (
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
"time"
)

func TestParseRetryAfter(t *testing.T) {
t.Run("seconds", func(t *testing.T) {
if got := parseRetryAfter("30"); got != 30*time.Second {
t.Fatalf("expected 30s, got %s", got)
}
})

t.Run("http date", func(t *testing.T) {
retryTime := time.Now().Add(45 * time.Second).UTC()
got := parseRetryAfter(retryTime.Format(http.TimeFormat))
if got < 44*time.Second || got > 46*time.Second {
t.Fatalf("expected about 45s, got %s", got)
}
})

t.Run("missing header uses default", func(t *testing.T) {
if got := parseRetryAfter(""); got != defaultRetryAfter {
t.Fatalf("expected default %s, got %s", defaultRetryAfter, got)
}
})
}

func TestRateLimitedHTTPClientWaitsForRetryAfter(t *testing.T) {
var requests atomic.Int32
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if requests.Add(1) == 1 {
w.Header().Set("Retry-After", "1")
w.WriteHeader(http.StatusTooManyRequests)
return
}
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`{"searchinfo":[],"search":[]}`))
}))
defer server.Close()

client := newRateLimitedHTTPClient()
client.client = server.Client()

start := time.Now()
var payload map[string]any
if err := client.getJSON(server.URL, &payload); err != nil {
t.Fatal(err)
}
if requests.Load() != 2 {
t.Fatalf("expected 2 requests, got %d", requests.Load())
}
if elapsed := time.Since(start); elapsed < time.Second {
t.Fatalf("expected to wait at least 1s after 429, took %s", elapsed)
}
}

func TestRateLimitedHTTPClientBlocksConcurrentRequests(t *testing.T) {
var requests atomic.Int32
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
count := requests.Add(1)
if count == 1 {
w.Header().Set("Retry-After", "1")
w.WriteHeader(http.StatusTooManyRequests)
return
}
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`{"entities":{}}`))
}))
defer server.Close()

client := newRateLimitedHTTPClient()
client.client = server.Client()

start := time.Now()
done := make(chan struct{}, 2)
for range 2 {
go func() {
var payload map[string]any
_ = client.getJSON(server.URL, &payload)
done <- struct{}{}
}()
}
<-done
<-done
if elapsed := time.Since(start); elapsed < time.Second {
t.Fatalf("expected concurrent requests to honor shared wait, took %s", elapsed)
}
}
4 changes: 4 additions & 0 deletions internal/datasource/wikidata/mockserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ func NewMockServer(t *testing.T, fixturePath string) *httptest.Server {
}
if queryValues.Get("action") == "wbgetentities" {
id := queryValues.Get("ids")
if strings.Contains(id, "|") {
parts := strings.Split(id, "|")
id = parts[0]
}
returnResponse(fmt.Sprintf("wbgetentities-%s", id), w, fixturePath)
return
}
Expand Down
Loading